diff --git a/.clang-format b/.clang-format
index 8bf91de917..6fca7ffffb 100644
--- a/.clang-format
+++ b/.clang-format
@@ -1,66 +1,5 @@
 ---
-Language: Cpp
-# BasedOnStyle: LLVM
-AccessModifierOffset: -2
-AlignAfterOpenBracket: true
-AlignEscapedNewlinesLeft: false
-AlignOperands: true
-AlignTrailingComments: true
-AlignConsecutiveAssignments: false
-AllowAllParametersOfDeclarationOnNextLine: true
-AllowShortBlocksOnASingleLine: false
-AllowShortCaseLabelsOnASingleLine: false
-AllowShortIfStatementsOnASingleLine: false
-AllowShortLoopsOnASingleLine: false
-AllowShortFunctionsOnASingleLine: All
-AlwaysBreakAfterDefinitionReturnType: false
-AlwaysBreakTemplateDeclarations: false
-AlwaysBreakBeforeMultilineStrings: false
-BreakBeforeBinaryOperators: None
-BreakBeforeTernaryOperators: true
-BreakConstructorInitializersBeforeComma: false
-BinPackParameters: true
-BinPackArguments: true
-ColumnLimit: 80
-ConstructorInitializerAllOnOneLineOrOnePerLine: false
-ConstructorInitializerIndentWidth: 4
-DerivePointerAlignment: false
-ExperimentalAutoDetectBinPacking: false
-IndentCaseLabels: false
-IndentWrappedFunctionNames: false
-IndentFunctionDeclarationAfterType: false
-MaxEmptyLinesToKeep: 1
-KeepEmptyLinesAtTheStartOfBlocks: true
-NamespaceIndentation: None
-ObjCBlockIndentWidth: 2
-ObjCSpaceAfterProperty: false
-ObjCSpaceBeforeProtocolList: true
-PenaltyBreakBeforeFirstCallParameter: 19
-PenaltyBreakComment: 300
-PenaltyBreakString: 1000
-PenaltyBreakFirstLessLess: 120
-PenaltyExcessCharacter: 1000000
-PenaltyReturnTypeOnItsOwnLine: 60
-PointerAlignment: Right
-SpacesBeforeTrailingComments: 1
-Cpp11BracedListStyle: true
-Standard: Cpp11
-IndentWidth: 2
-TabWidth: 8
-UseTab: Never
-BreakBeforeBraces: Attach
-SpacesInParentheses: false
-SpacesInSquareBrackets: false
-SpacesInAngles: false
-SpaceInEmptyParentheses: false
-SpacesInCStyleCastParentheses: false
-SpaceAfterCStyleCast: false
-SpacesInContainerLiterals: true
-SpaceBeforeAssignmentOperators: true
-ContinuationIndentWidth: 4
-CommentPragmas: '^ IWYU pragma:'
-ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
-SpaceBeforeParens: ControlStatements
-DisableFormat: false
+Language: Cpp
+BasedOnStyle: LLVM
+DisableFormat: false
 ...
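For readers unfamiliar with the LLVM preset that the trimmed configuration now inherits: it implies, among other things, two-space indentation, an 80-column limit, attached braces, and pointers/references bound to the variable name. A minimal illustration of the resulting layout (the names below are hypothetical and not part of this diff):

    // Formatted as clang-format's LLVM base style lays it out:
    // 2-space indents, attached braces, &/* bound to the name, 80 columns.
    #include <vector>

    namespace example {
    int sumPositive(const std::vector<int> &counts, int *total) {
      int n = 0;
      for (const auto &c : counts) {
        if (c > 0) {
          n += c;
        }
      }
      *total = n;
      return n;
    }
    } // namespace example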
- diff --git a/src/examples/algorithms/HelloSP_TP.cpp b/src/examples/algorithms/HelloSP_TP.cpp index 9858a1f1f1..c34936576d 100644 --- a/src/examples/algorithms/HelloSP_TP.cpp +++ b/src/examples/algorithms/HelloSP_TP.cpp @@ -20,31 +20,33 @@ * --------------------------------------------------------------------- */ +#include // std::generate +#include // pow +#include // std::rand, std::srand +#include // std::time #include #include -#include // std::generate -#include // std::time -#include // std::rand, std::srand -#include // pow -#include "nupic/algorithms/SpatialPooler.hpp" #include "nupic/algorithms/Cells4.hpp" +#include "nupic/algorithms/SpatialPooler.hpp" #include "nupic/os/Timer.hpp" using namespace std; using namespace nupic; -using nupic::algorithms::spatial_pooler::SpatialPooler; using nupic::algorithms::Cells4::Cells4; +using nupic::algorithms::spatial_pooler::SpatialPooler; // function generator: -int RandomNumber01 () { return (rand()%2); } // returns random (binary) numbers from {0,1} +int RandomNumber01() { + return (rand() % 2); +} // returns random (binary) numbers from {0,1} -int main(int argc, const char * argv[]) -{ -const UInt DIM = 2048; // number of columns in SP, TP -const UInt DIM_INPUT = 10000; -const UInt TP_CELLS_PER_COL = 10; // cells per column in TP -const UInt EPOCHS = pow(10, 4); // number of iterations (calls to SP/TP compute() ) +int main(int argc, const char *argv[]) { + const UInt DIM = 2048; // number of columns in SP, TP + const UInt DIM_INPUT = 10000; + const UInt TP_CELLS_PER_COL = 10; // cells per column in TP + const UInt EPOCHS = + pow(10, 4); // number of iterations (calls to SP/TP compute() ) vector inputDim = {DIM_INPUT}; vector colDim = {DIM}; @@ -53,18 +55,19 @@ const UInt EPOCHS = pow(10, 4); // number of iterations (calls to SP/TP compute( vector input(DIM_INPUT); vector outSP(DIM); // active array, output of SP/TP const int _CELLS = DIM * TP_CELLS_PER_COL; - vector outTP(_CELLS); + vector outTP(_CELLS); Real rIn[DIM] = {}; // input for TP (must be Reals) Real rOut[_CELLS] = {}; // initialize SP, TP SpatialPooler sp(inputDim, colDim); - Cells4 tp(DIM, TP_CELLS_PER_COL, 12, 8, 15, 5, .5, .8, 1.0, .1, .1, 0.0, false, 42, true, false); + Cells4 tp(DIM, TP_CELLS_PER_COL, 12, 8, 15, 5, .5, .8, 1.0, .1, .1, 0.0, + false, 42, true, false); // Start a stopwatch timer Timer stopwatch(true); - //run + // run for (UInt e = 0; e < EPOCHS; e++) { generate(input.begin(), input.end(), RandomNumber01); fill(outSP.begin(), outSP.end(), 0); @@ -77,12 +80,12 @@ const UInt EPOCHS = pow(10, 4); // number of iterations (calls to SP/TP compute( tp.compute(rIn, rOut, true, true); - for (UInt i=0; i< _CELLS; i++) { + for (UInt i = 0; i < _CELLS; i++) { outTP[i] = (UInt)rOut[i]; } // print - if (e == EPOCHS-1) { + if (e == EPOCHS - 1) { cout << "Epoch = " << e << endl; cout << "SP=" << outSP << endl; cout << "TP=" << outTP << endl; @@ -90,7 +93,8 @@ const UInt EPOCHS = pow(10, 4); // number of iterations (calls to SP/TP compute( } stopwatch.stop(); - cout << "Total elapsed time = " << stopwatch.getElapsed() << " seconds" << endl; + cout << "Total elapsed time = " << stopwatch.getElapsed() << " seconds" + << endl; return 0; } diff --git a/src/examples/prototest.cpp b/src/examples/prototest.cpp index a34fa14af6..ccb1a710a0 100644 --- a/src/examples/prototest.cpp +++ b/src/examples/prototest.cpp @@ -20,8 +20,8 @@ * --------------------------------------------------------------------- */ -#include #include +#include #include #include #include @@ -38,8 +38,7 @@ using 
namespace std; using namespace nupic; using namespace nupic::algorithms::spatial_pooler; -void testSP() -{ +void testSP() { Random random(10); const UInt inputSize = 500; @@ -53,10 +52,8 @@ void testSP() sp1.initialize(inputDims, colDims); UInt input[inputSize]; - for (UInt i = 0; i < inputSize; ++i) - { - if (i < w) - { + for (UInt i = 0; i < inputSize; ++i) { + if (i < w) { input[i] = 1; } else { input[i] = 0; @@ -64,8 +61,7 @@ void testSP() } UInt output[numColumns]; - for (UInt i = 0; i < 10000; ++i) - { + for (UInt i = 0; i < 10000; ++i) { random.shuffle(input, input + inputSize); sp1.compute(input, true, output); } @@ -73,10 +69,8 @@ void testSP() // Now we reuse the last input to test after serialization vector activeColumnsBefore; - for (UInt i = 0; i < numColumns; ++i) - { - if (output[i] == 1) - { + for (UInt i = 0; i < numColumns; ++i) { + if (output[i] == 1) { activeColumnsBefore.push_back(i); } } @@ -94,8 +88,7 @@ void testSP() Real64 timeA = 0.0, timeC = 0.0; - for (UInt i = 0; i < 100; ++i) - { + for (UInt i = 0; i < 100; ++i) { // Create new input random.shuffle(input, input + inputSize); @@ -128,8 +121,7 @@ void testSP() timeA = timeA + testTimer.getElapsed(); } - for (UInt i = 0; i < numColumns; ++i) - { + for (UInt i = 0; i < numColumns; ++i) { NTA_CHECK(outputBaseline[i] == outputA[i]); } @@ -158,11 +150,9 @@ void testSP() timeC = timeC + testTimer.getElapsed(); } - for (UInt i = 0; i < numColumns; ++i) - { + for (UInt i = 0; i < numColumns; ++i) { NTA_CHECK(outputBaseline[i] == outputC[i]); } - } remove("outA.proto"); @@ -173,15 +163,13 @@ void testSP() cout << "Manual: " << timeC << endl; } -void testRandomIOStream(UInt n) -{ +void testRandomIOStream(UInt n) { Random r1(7); Random r2; nupic::Timer testTimer; testTimer.start(); - for (UInt i = 0; i < n; ++i) - { + for (UInt i = 0; i < n; ++i) { r1.getUInt32(); // Serialize @@ -209,15 +197,13 @@ void testRandomIOStream(UInt n) cout << "Cap'n Proto: " << testTimer.getElapsed() << endl; } -void testRandomManual(UInt n) -{ +void testRandomManual(UInt n) { Random r1(7); Random r2; nupic::Timer testTimer; testTimer.start(); - for (UInt i = 0; i < n; ++i) - { + for (UInt i = 0; i < n; ++i) { r1.getUInt32(); // Serialize @@ -245,8 +231,7 @@ void testRandomManual(UInt n) cout << "Manual: " << testTimer.getElapsed() << endl; } -int main(int argc, const char * argv[]) -{ +int main(int argc, const char *argv[]) { UInt n = 1000; cout << "Timing for Random serialization (smaller is better):" << endl; testRandomIOStream(n); diff --git a/src/examples/regions/HelloRegions.cpp b/src/examples/regions/HelloRegions.cpp index 4c7ba1758c..208803fac9 100644 --- a/src/examples/regions/HelloRegions.cpp +++ b/src/examples/regions/HelloRegions.cpp @@ -25,86 +25,83 @@ #include #include -#include #include +#include #include #include - - using namespace nupic; -int main(int argc, const char * argv[]) -{ - // Create network - Network net = Network(); +int main(int argc, const char *argv[]) { + // Create network + Network net = Network(); - // Add VectorFileSensor region to network - Region* region = net.addRegion("region", "VectorFileSensor", "{activeOutputCount: 1}"); + // Add VectorFileSensor region to network + Region *region = + net.addRegion("region", "VectorFileSensor", "{activeOutputCount: 1}"); - // Set region dimensions - Dimensions dims; - dims.push_back(1); + // Set region dimensions + Dimensions dims; + dims.push_back(1); - std::cout << "Setting region dimensions" << dims.toString() << std::endl; + std::cout << "Setting region dimensions" 
<< dims.toString() << std::endl; - region->setDimensions(dims); + region->setDimensions(dims); - // Load data - std::string path = Path::makeAbsolute("../../../src/examples/regions/Data.csv"); + // Load data + std::string path = + Path::makeAbsolute("../../../src/examples/regions/Data.csv"); - std::cout << "Loading data from " << path << std::endl; + std::cout << "Loading data from " << path << std::endl; - std::vector loadFileArgs; - loadFileArgs.push_back("loadFile"); - loadFileArgs.push_back(path); - loadFileArgs.push_back("2"); + std::vector loadFileArgs; + loadFileArgs.push_back("loadFile"); + loadFileArgs.push_back(path); + loadFileArgs.push_back("2"); - region->executeCommand(loadFileArgs); + region->executeCommand(loadFileArgs); - // Initialize network - std::cout << "Initializing network" << std::endl; + // Initialize network + std::cout << "Initializing network" << std::endl; - net.initialize(); + net.initialize(); - ArrayRef outputArray = region->getOutputData("dataOut"); + ArrayRef outputArray = region->getOutputData("dataOut"); - // Compute - std::cout << "Compute" << std::endl; + // Compute + std::cout << "Compute" << std::endl; - region->compute(); + region->compute(); - // Get output - Real64 *buffer = (Real64*) outputArray.getBuffer(); + // Get output + Real64 *buffer = (Real64 *)outputArray.getBuffer(); - for (size_t i = 0; i < outputArray.getCount(); i++) - { - std::cout << " " << i << " " << buffer[i] << "" << std::endl; - } + for (size_t i = 0; i < outputArray.getCount(); i++) { + std::cout << " " << i << " " << buffer[i] << "" << std::endl; + } - // Serialize - Network net2; - { - std::stringstream ss; - net.write(ss); - net2.read(ss); - } - net2.initialize(); + // Serialize + Network net2; + { + std::stringstream ss; + net.write(ss); + net2.read(ss); + } + net2.initialize(); - Region* region2 = net2.getRegions().getByName("region"); - region2->executeCommand(loadFileArgs); - ArrayRef outputArray2 = region2->getOutputData("dataOut"); - Real64 *buffer2 = (Real64*)outputArray2.getBuffer(); + Region *region2 = net2.getRegions().getByName("region"); + region2->executeCommand(loadFileArgs); + ArrayRef outputArray2 = region2->getOutputData("dataOut"); + Real64 *buffer2 = (Real64 *)outputArray2.getBuffer(); - net.run(1); - net2.run(1); + net.run(1); + net2.run(1); - NTA_ASSERT(outputArray2.getCount() == outputArray.getCount()); - for (size_t i = 0; i < outputArray.getCount(); i++) - { - std::cout << " " << i << " " << buffer[i] << " " << buffer2[i] - << std::endl; - } + NTA_ASSERT(outputArray2.getCount() == outputArray.getCount()); + for (size_t i = 0; i < outputArray.getCount(); i++) { + std::cout << " " << i << " " << buffer[i] << " " << buffer2[i] + << std::endl; + } - return 0; + return 0; } diff --git a/src/nupic/algorithms/Anomaly.cpp b/src/nupic/algorithms/Anomaly.cpp index 5dbfbf3ae8..c972601bf3 100644 --- a/src/nupic/algorithms/Anomaly.cpp +++ b/src/nupic/algorithms/Anomaly.cpp @@ -20,11 +20,11 @@ * --------------------------------------------------------------------- */ -#include -#include #include #include +#include #include +#include #include "nupic/algorithms/Anomaly.hpp" #include "nupic/utils/Log.hpp" @@ -32,22 +32,16 @@ using namespace std; -namespace nupic -{ +namespace nupic { -namespace algorithms -{ +namespace algorithms { -namespace anomaly -{ +namespace anomaly { - -Real32 computeRawAnomalyScore(const vector& active, - const vector& predicted) -{ +Real32 computeRawAnomalyScore(const vector &active, + const vector &predicted) { // Return 0 if no active 
columns are present - if (active.size() == 0) - { + if (active.size() == 0) { return 0.0f; } @@ -56,23 +50,19 @@ Real32 computeRawAnomalyScore(const vector& active, vector predictedActiveCols; // Calculate and return percent of active columns that were not predicted. - set_intersection(active_.begin(), active_.end(), - predicted_.begin(), predicted_.end(), - back_inserter(predictedActiveCols)); + set_intersection(active_.begin(), active_.end(), predicted_.begin(), + predicted_.end(), back_inserter(predictedActiveCols)); return (active.size() - predictedActiveCols.size()) / Real32(active.size()); } - Anomaly::Anomaly(UInt slidingWindowSize, AnomalyMode mode, Real32 binaryAnomalyThreshold) - : binaryThreshold_(binaryAnomalyThreshold) -{ + : binaryThreshold_(binaryAnomalyThreshold) { NTA_ASSERT(binaryAnomalyThreshold >= 0 && binaryAnomalyThreshold <= 1) << "binaryAnomalyThreshold must be within [0.0,1.0]"; mode_ = mode; - if (slidingWindowSize > 0) - { + if (slidingWindowSize > 0) { movingAverage_.reset(new nupic::util::MovingAverage(slidingWindowSize)); } @@ -81,33 +71,28 @@ Anomaly::Anomaly(UInt slidingWindowSize, AnomalyMode mode, << "C++ Anomaly implemented only for PURE mode!"; } - -Real32 Anomaly::compute( - const vector& active, const vector& predicted, - Real64 inputValue, UInt timestamp) -{ +Real32 Anomaly::compute(const vector &active, + const vector &predicted, Real64 inputValue, + UInt timestamp) { Real32 anomalyScore = computeRawAnomalyScore(active, predicted); Real32 score = anomalyScore; - switch(mode_) - { - case AnomalyMode::PURE: - score = anomalyScore; - break; - case AnomalyMode::LIKELIHOOD: - case AnomalyMode::WEIGHTED: - // Not implemented. Fail - NTA_ASSERT(mode_ == AnomalyMode::PURE) - << "C++ Anomaly implemented only for PURE mode!"; - break; + switch (mode_) { + case AnomalyMode::PURE: + score = anomalyScore; + break; + case AnomalyMode::LIKELIHOOD: + case AnomalyMode::WEIGHTED: + // Not implemented. Fail + NTA_ASSERT(mode_ == AnomalyMode::PURE) + << "C++ Anomaly implemented only for PURE mode!"; + break; } - if (movingAverage_) - { + if (movingAverage_) { score = movingAverage_->compute(score); } - if (binaryThreshold_) - { + if (binaryThreshold_) { score = (score >= binaryThreshold_) ? 1.0 : 0.0; } diff --git a/src/nupic/algorithms/Anomaly.hpp b/src/nupic/algorithms/Anomaly.hpp index 0739c3eeac..2f359f1a35 100644 --- a/src/nupic/algorithms/Anomaly.hpp +++ b/src/nupic/algorithms/Anomaly.hpp @@ -23,103 +23,94 @@ #ifndef NUPIC_ALGORITHMS_ANOMALY_HPP #define NUPIC_ALGORITHMS_ANOMALY_HPP - -#include #include // Needed for smart pointer templates -#include // Needed for for smart pointer templates #include +#include // Needed for for smart pointer templates +#include -namespace nupic -{ - - namespace util - { - class MovingAverage; // Forward declaration - } - - namespace algorithms - { - - namespace anomaly - { - - /** - * Computes the raw anomaly score. - * - * The raw anomaly score is the fraction of active columns not predicted. - * - * @param activeColumns: array of active column indices - * @param prevPredictedColumns: array of columns indices predicted in - * prev step - * @return anomaly score 0..1 (Real32) - */ - Real32 computeRawAnomalyScore(const std::vector& active, - const std::vector& predicted); - - - enum class AnomalyMode { PURE, LIKELIHOOD, WEIGHTED }; - +namespace nupic { - class Anomaly - { - public: - /** - * Utility class for generating anomaly scores in different ways. 
- * - * Supported modes: - * PURE - the raw anomaly score as computed by computeRawAnomalyScore - * LIKELIHOOD - uses the AnomalyLikelihood class on top of the raw - * anomaly scores (not implemented in C++) - * WEIGHTED - multiplies the likelihood result with the raw anomaly - * score that was used to generate the likelihood (not - * implemented in C++) - * - * @param slidingWindowSize (optional) - how many elements are - * summed up; enables moving average on final anomaly score; - * int >= 0 - * @param mode (optional) - (enum) how to compute anomaly; - * possible values are AnomalyMode:: - * - PURE - the default, how much anomal the value is; - * Real32 0..1 where 1=totally unexpected - * - LIKELIHOOD - uses the anomaly_likelihood code; - * models probability of receiving this value and - * anomalyScore - * - WEIGHTED - "pure" anomaly weighted by "likelihood" - * (anomaly * likelihood) - * @param binaryAnomalyThreshold (optional) - if set [0,1] anomaly - * score will be discretized to 1/0 - * (1 iff >= binaryAnomalyThreshold). The transformation is - * applied after moving average is computed. - */ - Anomaly(UInt slidingWindowSize=0, AnomalyMode mode=AnomalyMode::PURE, - Real32 binaryAnomalyThreshold=0); +namespace util { +class MovingAverage; // Forward declaration +} - /** - * Compute the anomaly score as the percent of active columns not - * predicted. - * - * @param active: array of active column indices - * @param predicted: array of columns indices predicted in this step - * (used for anomaly in step T+1) - * @param inputValue: (optional) value of current input to encoders - * (eg "cat" for category encoder) - * (used in anomaly-likelihood) - * @param timestamp: (optional) date timestamp when the sample occured - * (used in anomaly-likelihood) - * @return the computed anomaly score; Real32 0..1 - */ - Real32 compute(const std::vector& active, - const std::vector& predicted, - Real64 inputValue=0, UInt timestamp=0); +namespace algorithms { - private: - AnomalyMode mode_; - Real32 binaryThreshold_; - std::unique_ptr movingAverage_; +namespace anomaly { - }; - } // namespace anomaly - } // namespace algorithms +/** + * Computes the raw anomaly score. + * + * The raw anomaly score is the fraction of active columns not predicted. + * + * @param activeColumns: array of active column indices + * @param prevPredictedColumns: array of columns indices predicted in + * prev step + * @return anomaly score 0..1 (Real32) + */ +Real32 computeRawAnomalyScore(const std::vector &active, + const std::vector &predicted); + +enum class AnomalyMode { PURE, LIKELIHOOD, WEIGHTED }; + +class Anomaly { +public: + /** + * Utility class for generating anomaly scores in different ways. 
+ * + * Supported modes: + * PURE - the raw anomaly score as computed by computeRawAnomalyScore + * LIKELIHOOD - uses the AnomalyLikelihood class on top of the raw + * anomaly scores (not implemented in C++) + * WEIGHTED - multiplies the likelihood result with the raw anomaly + * score that was used to generate the likelihood (not + * implemented in C++) + * + * @param slidingWindowSize (optional) - how many elements are + * summed up; enables moving average on final anomaly score; + * int >= 0 + * @param mode (optional) - (enum) how to compute anomaly; + * possible values are AnomalyMode:: + * - PURE - the default, how much anomal the value is; + * Real32 0..1 where 1=totally unexpected + * - LIKELIHOOD - uses the anomaly_likelihood code; + * models probability of receiving this value and + * anomalyScore + * - WEIGHTED - "pure" anomaly weighted by "likelihood" + * (anomaly * likelihood) + * @param binaryAnomalyThreshold (optional) - if set [0,1] anomaly + * score will be discretized to 1/0 + * (1 iff >= binaryAnomalyThreshold). The transformation is + * applied after moving average is computed. + */ + Anomaly(UInt slidingWindowSize = 0, AnomalyMode mode = AnomalyMode::PURE, + Real32 binaryAnomalyThreshold = 0); + + /** + * Compute the anomaly score as the percent of active columns not + * predicted. + * + * @param active: array of active column indices + * @param predicted: array of columns indices predicted in this step + * (used for anomaly in step T+1) + * @param inputValue: (optional) value of current input to encoders + * (eg "cat" for category encoder) + * (used in anomaly-likelihood) + * @param timestamp: (optional) date timestamp when the sample occured + * (used in anomaly-likelihood) + * @return the computed anomaly score; Real32 0..1 + */ + Real32 compute(const std::vector &active, + const std::vector &predicted, Real64 inputValue = 0, + UInt timestamp = 0); + +private: + AnomalyMode mode_; + Real32 binaryThreshold_; + std::unique_ptr movingAverage_; +}; +} // namespace anomaly +} // namespace algorithms } // namespace nupic #endif // NUPIC_ALGORITHMS_ANOMALY_HPP diff --git a/src/nupic/algorithms/ArrayBuffer.hpp b/src/nupic/algorithms/ArrayBuffer.hpp index 0d3239ad23..528d7640cc 100644 --- a/src/nupic/algorithms/ArrayBuffer.hpp +++ b/src/nupic/algorithms/ArrayBuffer.hpp @@ -20,47 +20,46 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * This header file defines the data structures used for * facilitating the passing of numpy arrays from python * code to C code. - */ + */ #ifndef NTA_ARRAY_BUFFER_HPP #define NTA_ARRAY_BUFFER_HPP #ifdef __cplusplus extern "C" { -#endif // __cplusplus - +#endif // __cplusplus -// Structure that wraps the essential elements of +// Structure that wraps the essential elements of // a numpy array object. typedef struct _NUMPY_ARRAY { int nNumDims; - const int * pnDimensions; - const int * pnStrides; - const char * pData; -} NUMPY_ARRAY; + const int *pnDimensions; + const int *pnStrides; + const char *pData; +} NUMPY_ARRAY; // Bounding box typedef struct _BBOX { - int nLeft; - int nRight; - int nTop; - int nBottom; + int nLeft; + int nRight; + int nTop; + int nBottom; } BBOX; // Macros for clipping boxes #ifndef MIN -#define MIN(x, y) ((x) <= (y) ? (x) : (y)) +#define MIN(x, y) ((x) <= (y) ? (x) : (y)) #endif // MIN #ifndef MAX -#define MAX(x, y) ((x) <= (y) ? (y) : (x)) +#define MAX(x, y) ((x) <= (y) ? 
(y) : (x)) #endif // MAX #ifdef __cplusplus } -#endif // __cplusplus +#endif // __cplusplus #endif // NTA_ARRAY_BUFFER_HPP diff --git a/src/nupic/algorithms/BitHistory.cpp b/src/nupic/algorithms/BitHistory.cpp index 2d4f621fbf..f29e83b7a2 100644 --- a/src/nupic/algorithms/BitHistory.cpp +++ b/src/nupic/algorithms/BitHistory.cpp @@ -32,215 +32,177 @@ #include #include -namespace nupic -{ - namespace algorithms - { - namespace cla_classifier - { - - const Real64 DUTY_CYCLE_UPDATE_INTERVAL = pow(3.2, 32); - - BitHistory::BitHistory(UInt bitNum, int nSteps, Real64 alpha, - UInt verbosity) : - lastTotalUpdate_(-1), learnIteration_(0), alpha_(alpha), - verbosity_(verbosity) - { - stringstream ss; - ss << bitNum << "[" << nSteps << "]"; - id_ = ss.str(); - } - - void BitHistory::store(int iteration, int bucketIdx) - { - if (lastTotalUpdate_ == -1) - { - lastTotalUpdate_ = iteration; - } - - // Get the previous duty cycle, or 0.0 for new buckets. - map::const_iterator it = stats_.find(bucketIdx); - Real64 dc = 0.0; - if (it != stats_.end()) - { - dc = it->second; - } - - // Compute the new duty cycle, dcNew, at the iteration that the duty - // cycles are currently at. - Real64 denom = pow(1.0 - alpha_, iteration - lastTotalUpdate_); - Real64 dcNew = -1.0; - if (denom > 0.0) - { - dcNew = dc + (alpha_ / denom); - } - - if (denom < 0.00001 || dcNew > DUTY_CYCLE_UPDATE_INTERVAL) - { - // Update all duty cycles to the current iteration. - Real64 exp = pow(1.0 - alpha_, iteration - lastTotalUpdate_); - for (map::const_iterator i = stats_.begin(); - i != stats_.end(); ++i) - { - stats_[i->first] = i->second * exp; - } - - lastTotalUpdate_ = iteration; - - dc = stats_[bucketIdx] + alpha_; - } else { - dc = dcNew; - } - - // Set the new duty cycle for the specified bucket. - stats_[bucketIdx] = dc; - } - - void BitHistory::infer(int iteration, vector* votes) - { - Real64 total = 0.0; - // Set the vote for each bucket to the duty cycle value. - for (map::const_iterator it = stats_.begin(); - it != stats_.end(); ++it) - { - if (it->second > 0.0) - { - (*votes)[it->first] = it->second; - total += it->second; - } - } - - // Normalize the duty cycles. - if (total > 0.0) - { - for (auto & vote : *votes) - { - vote = vote / total; - } - } - } - - void BitHistory::save(ostream& outStream) const - { - // Write out a starting marker. - outStream << "BitHistory" << endl; - - // Save the simple variables. - outStream << id_ << " " - << lastTotalUpdate_ << " " - << learnIteration_ << " " - << alpha_ << " " - << verbosity_ << " " - << endl; - - // Save the bucket duty cycles. - outStream << stats_.size() << " "; - for (const auto & elem : stats_) - { - outStream << elem.first << " " << elem.second << " "; - } - outStream << endl; - - // Write out a termination marker. - outStream << "~BitHistory" << endl; - } - - void BitHistory::load(istream& inStream) - { - // Check the starting marker. - string marker; - inStream >> marker; - NTA_CHECK(marker == "BitHistory"); - - // Load the simple variables. - inStream >> id_ - >> lastTotalUpdate_ - >> learnIteration_ - >> alpha_ - >> verbosity_; - - // Load the bucket duty cycles. - UInt numBuckets; - int bucketIdx; - Real64 dutyCycle; - inStream >> numBuckets; - for (UInt i = 0; i < numBuckets; ++i) - { - inStream >> bucketIdx >> dutyCycle; - stats_.insert(pair(bucketIdx, dutyCycle)); - } - - // Check the termination marker. 
- inStream >> marker; - NTA_CHECK(marker == "~BitHistory"); - } - void BitHistory::write(BitHistoryProto::Builder& proto) const - { - proto.setId(id_.c_str()); - - auto statsList = proto.initStats(stats_.size()); - UInt i = 0; - for (const auto & elem : stats_) - { - auto stat = statsList[i]; - stat.setIndex(elem.first); - stat.setDutyCycle(elem.second); - i++; - } - - proto.setLastTotalUpdate(lastTotalUpdate_); - proto.setLearnIteration(learnIteration_); - proto.setAlpha(alpha_); - proto.setVerbosity(verbosity_); - } - - void BitHistory::read(BitHistoryProto::Reader& proto) - { - id_ = proto.getId().cStr(); - - stats_.clear(); - for (auto stat : proto.getStats()) - { - stats_[stat.getIndex()] = stat.getDutyCycle(); - } - - lastTotalUpdate_ = proto.getLastTotalUpdate(); - learnIteration_ = proto.getLearnIteration(); - alpha_ = proto.getAlpha(); - verbosity_ = proto.getVerbosity(); - } - - bool BitHistory::operator==(const BitHistory& other) const - { - if (id_ != other.id_ || - lastTotalUpdate_ != other.lastTotalUpdate_ || - learnIteration_ != other.learnIteration_ || - fabs(alpha_ - other.alpha_) > 0.000001 || - verbosity_ != other.verbosity_) - { - return false; - } - - if (stats_.size() != other.stats_.size()) - { - return false; - } - for (auto it = stats_.begin(); it != stats_.end(); it++) - { - if (fabs(it->second - other.stats_.at(it->first)) > 0.000001) - { - return false; - } - } - - return true; - } - - bool BitHistory::operator!=(const BitHistory& other) const - { - return !operator==(other); - } - - } // end namespace cla_classifier - } // end namespace algorithms +namespace nupic { +namespace algorithms { +namespace cla_classifier { + +const Real64 DUTY_CYCLE_UPDATE_INTERVAL = pow(3.2, 32); + +BitHistory::BitHistory(UInt bitNum, int nSteps, Real64 alpha, UInt verbosity) + : lastTotalUpdate_(-1), learnIteration_(0), alpha_(alpha), + verbosity_(verbosity) { + stringstream ss; + ss << bitNum << "[" << nSteps << "]"; + id_ = ss.str(); +} + +void BitHistory::store(int iteration, int bucketIdx) { + if (lastTotalUpdate_ == -1) { + lastTotalUpdate_ = iteration; + } + + // Get the previous duty cycle, or 0.0 for new buckets. + map::const_iterator it = stats_.find(bucketIdx); + Real64 dc = 0.0; + if (it != stats_.end()) { + dc = it->second; + } + + // Compute the new duty cycle, dcNew, at the iteration that the duty + // cycles are currently at. + Real64 denom = pow(1.0 - alpha_, iteration - lastTotalUpdate_); + Real64 dcNew = -1.0; + if (denom > 0.0) { + dcNew = dc + (alpha_ / denom); + } + + if (denom < 0.00001 || dcNew > DUTY_CYCLE_UPDATE_INTERVAL) { + // Update all duty cycles to the current iteration. + Real64 exp = pow(1.0 - alpha_, iteration - lastTotalUpdate_); + for (map::const_iterator i = stats_.begin(); i != stats_.end(); + ++i) { + stats_[i->first] = i->second * exp; + } + + lastTotalUpdate_ = iteration; + + dc = stats_[bucketIdx] + alpha_; + } else { + dc = dcNew; + } + + // Set the new duty cycle for the specified bucket. + stats_[bucketIdx] = dc; +} + +void BitHistory::infer(int iteration, vector *votes) { + Real64 total = 0.0; + // Set the vote for each bucket to the duty cycle value. + for (map::const_iterator it = stats_.begin(); it != stats_.end(); + ++it) { + if (it->second > 0.0) { + (*votes)[it->first] = it->second; + total += it->second; + } + } + + // Normalize the duty cycles. + if (total > 0.0) { + for (auto &vote : *votes) { + vote = vote / total; + } + } +} + +void BitHistory::save(ostream &outStream) const { + // Write out a starting marker. 
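One detail worth spelling out in store() above: rather than decaying every bucket's duty cycle by (1 - alpha)^Δt on each call, the touched bucket is incremented by alpha/denom in the frame of lastTotalUpdate_, and all buckets are only brought forward when denom becomes tiny or the value threatens to overflow. A stripped-down, standalone sketch of that lazy-decay bookkeeping (not the BitHistory API; the overflow guard is omitted):

    #include <cmath>
    #include <iostream>
    #include <map>

    // Duty cycles kept in the "frame" of lastUpdate; decay is applied lazily.
    struct DutyCycles {
      std::map<int, double> stats;
      int lastUpdate = -1;
      double alpha = 0.1;

      void store(int iteration, int bucket) {
        if (lastUpdate == -1)
          lastUpdate = iteration;
        double denom = std::pow(1.0 - alpha, iteration - lastUpdate);
        if (denom < 1e-5) {
          for (auto &kv : stats) // bring every bucket forward to 'iteration'
            kv.second *= denom;
          lastUpdate = iteration;
          stats[bucket] += alpha;
        } else {
          stats[bucket] += alpha / denom; // same update, expressed in old frame
        }
      }
    };

    int main() {
      DutyCycles dc;
      dc.store(0, 3);
      dc.store(10, 3);
      std::cout << dc.stats[3] << std::endl; // ~0.39, in the iteration-0 frame
    }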
+ outStream << "BitHistory" << endl; + + // Save the simple variables. + outStream << id_ << " " << lastTotalUpdate_ << " " << learnIteration_ << " " + << alpha_ << " " << verbosity_ << " " << endl; + + // Save the bucket duty cycles. + outStream << stats_.size() << " "; + for (const auto &elem : stats_) { + outStream << elem.first << " " << elem.second << " "; + } + outStream << endl; + + // Write out a termination marker. + outStream << "~BitHistory" << endl; +} + +void BitHistory::load(istream &inStream) { + // Check the starting marker. + string marker; + inStream >> marker; + NTA_CHECK(marker == "BitHistory"); + + // Load the simple variables. + inStream >> id_ >> lastTotalUpdate_ >> learnIteration_ >> alpha_ >> + verbosity_; + + // Load the bucket duty cycles. + UInt numBuckets; + int bucketIdx; + Real64 dutyCycle; + inStream >> numBuckets; + for (UInt i = 0; i < numBuckets; ++i) { + inStream >> bucketIdx >> dutyCycle; + stats_.insert(pair(bucketIdx, dutyCycle)); + } + + // Check the termination marker. + inStream >> marker; + NTA_CHECK(marker == "~BitHistory"); +} +void BitHistory::write(BitHistoryProto::Builder &proto) const { + proto.setId(id_.c_str()); + + auto statsList = proto.initStats(stats_.size()); + UInt i = 0; + for (const auto &elem : stats_) { + auto stat = statsList[i]; + stat.setIndex(elem.first); + stat.setDutyCycle(elem.second); + i++; + } + + proto.setLastTotalUpdate(lastTotalUpdate_); + proto.setLearnIteration(learnIteration_); + proto.setAlpha(alpha_); + proto.setVerbosity(verbosity_); +} + +void BitHistory::read(BitHistoryProto::Reader &proto) { + id_ = proto.getId().cStr(); + + stats_.clear(); + for (auto stat : proto.getStats()) { + stats_[stat.getIndex()] = stat.getDutyCycle(); + } + + lastTotalUpdate_ = proto.getLastTotalUpdate(); + learnIteration_ = proto.getLearnIteration(); + alpha_ = proto.getAlpha(); + verbosity_ = proto.getVerbosity(); +} + +bool BitHistory::operator==(const BitHistory &other) const { + if (id_ != other.id_ || lastTotalUpdate_ != other.lastTotalUpdate_ || + learnIteration_ != other.learnIteration_ || + fabs(alpha_ - other.alpha_) > 0.000001 || + verbosity_ != other.verbosity_) { + return false; + } + + if (stats_.size() != other.stats_.size()) { + return false; + } + for (auto it = stats_.begin(); it != stats_.end(); it++) { + if (fabs(it->second - other.stats_.at(it->first)) > 0.000001) { + return false; + } + } + + return true; +} + +bool BitHistory::operator!=(const BitHistory &other) const { + return !operator==(other); +} + +} // end namespace cla_classifier +} // end namespace algorithms } // end namespace nupic diff --git a/src/nupic/algorithms/BitHistory.hpp b/src/nupic/algorithms/BitHistory.hpp index eb34747729..d458b6e380 100644 --- a/src/nupic/algorithms/BitHistory.hpp +++ b/src/nupic/algorithms/BitHistory.hpp @@ -37,119 +37,114 @@ using namespace std; -namespace nupic -{ - namespace algorithms - { - namespace cla_classifier - { - - /** Class to store duty cycles for buckets for a single input bit. - * - * @b Responsibility - * The BitHistory is responsible for updating and relaying the duty - * cycles for the different buckets. - * - * TODO: Support serialization and deserialization. - * - */ - class BitHistory : public Serializable - { - public: - /** - * Constructor. - */ - BitHistory() {} - - /** - * Constructor. - * - * @param bitNum The input bit index that this BitHistory stores data - * for. - * @param nSteps The number of steps this BitHistory is storing duty - * cycles for. 
- * @param alpha The alpha to use when decaying the duty cycles. - * @param verbosity The logging verbosity to use. - * - */ - BitHistory(UInt bitNum, int nSteps, Real64 alpha, UInt verbosity); - - virtual ~BitHistory() {}; - - /** - * Update the duty cycle for the specified bucket index. - * - * @param iteration The current iteration. The difference between - * consecutive calls is used to determine how much to - * decay the previous duty cycle value. - * @param bucketIdx The bucket index to update. - * - */ - void store(int iteration, int bucketIdx); - - /** - * Sets the votes for each bucket when this cell is active. - * - * @param iteration The current iteration. - * @param votes A vector to populate with the votes for each bucket. - * - */ - void infer(int iteration, vector* votes); - - /** - * Save the state to the ostream. - */ - void save(ostream& outStream) const; - - /** - * Load state from istream. - */ - void load(istream& inStream); - - /** - * Save the state to the builder. - */ - using Serializable::write; - void write(BitHistoryProto::Builder& builder) const; - - /** - * Load state from reader. - */ - using Serializable::read; - void read(BitHistoryProto::Reader& proto); - - /** - * Check if the other instance matches this one. - * - * @param other an instance to compare to - * @returns true iff the other instance matches this one - */ - bool operator==(const BitHistory& other) const; - - /** - * Check if the other instance doesn't match this one. - * - * @param other an instance to compare to - * @returns true iff the other instance matches doesn't match this one - */ - bool operator!=(const BitHistory& other) const; - - private: - - string id_; - // Mapping from bucket index to the duty cycle values. - map stats_; - // Last iteration at which the duty cycles were updated to the present - // value. This is not done every iteration for efficiency reasons. - int lastTotalUpdate_; - int learnIteration_; - // The alpha to use when decaying the duty cycles. - Real64 alpha_; - UInt verbosity_; - }; // end class BitHistory - - } // end namespace cla_classifier - } // end namespace algorithms +namespace nupic { +namespace algorithms { +namespace cla_classifier { + +/** Class to store duty cycles for buckets for a single input bit. + * + * @b Responsibility + * The BitHistory is responsible for updating and relaying the duty + * cycles for the different buckets. + * + * TODO: Support serialization and deserialization. + * + */ +class BitHistory : public Serializable { +public: + /** + * Constructor. + */ + BitHistory() {} + + /** + * Constructor. + * + * @param bitNum The input bit index that this BitHistory stores data + * for. + * @param nSteps The number of steps this BitHistory is storing duty + * cycles for. + * @param alpha The alpha to use when decaying the duty cycles. + * @param verbosity The logging verbosity to use. + * + */ + BitHistory(UInt bitNum, int nSteps, Real64 alpha, UInt verbosity); + + virtual ~BitHistory(){}; + + /** + * Update the duty cycle for the specified bucket index. + * + * @param iteration The current iteration. The difference between + * consecutive calls is used to determine how much to + * decay the previous duty cycle value. + * @param bucketIdx The bucket index to update. + * + */ + void store(int iteration, int bucketIdx); + + /** + * Sets the votes for each bucket when this cell is active. + * + * @param iteration The current iteration. + * @param votes A vector to populate with the votes for each bucket. 
+ * + */ + void infer(int iteration, vector *votes); + + /** + * Save the state to the ostream. + */ + void save(ostream &outStream) const; + + /** + * Load state from istream. + */ + void load(istream &inStream); + + /** + * Save the state to the builder. + */ + using Serializable::write; + void write(BitHistoryProto::Builder &builder) const; + + /** + * Load state from reader. + */ + using Serializable::read; + void read(BitHistoryProto::Reader &proto); + + /** + * Check if the other instance matches this one. + * + * @param other an instance to compare to + * @returns true iff the other instance matches this one + */ + bool operator==(const BitHistory &other) const; + + /** + * Check if the other instance doesn't match this one. + * + * @param other an instance to compare to + * @returns true iff the other instance matches doesn't match this one + */ + bool operator!=(const BitHistory &other) const; + +private: + string id_; + // Mapping from bucket index to the duty cycle values. + map stats_; + // Last iteration at which the duty cycles were updated to the present + // value. This is not done every iteration for efficiency reasons. + int lastTotalUpdate_; + int learnIteration_; + // The alpha to use when decaying the duty cycles. + Real64 alpha_; + UInt verbosity_; +}; // end class BitHistory + +} // end namespace cla_classifier +} // end namespace algorithms } // end namespace nupic #endif // NTA_fast_cla_classifier_HPP diff --git a/src/nupic/algorithms/Cell.cpp b/src/nupic/algorithms/Cell.cpp index cda1e1a0bb..24e890d68e 100644 --- a/src/nupic/algorithms/Cell.cpp +++ b/src/nupic/algorithms/Cell.cpp @@ -20,18 +20,13 @@ * --------------------------------------------------------------------- */ -#include #include +#include using namespace nupic::algorithms::Cells4; using namespace nupic; - -Cell::Cell() -: _segments(0), -_freeSegments(0) -{ -} +Cell::Cell() : _segments(0), _freeSegments(0) {} //------------------------------------------------------------------------------ /** @@ -47,10 +42,8 @@ _freeSegments(0) */ static bool cellMatchPythonSegOrder = false; -void Cell::setSegmentOrder(bool matchPythonOrder) -{ - if (matchPythonOrder) - { +void Cell::setSegmentOrder(bool matchPythonOrder) { + if (matchPythonOrder) { std::cout << "*** Python segment match turned on for Cells4\n"; } cellMatchPythonSegOrder = matchPythonOrder; @@ -62,13 +55,10 @@ void Cell::setSegmentOrder(bool matchPythonOrder) * allocated ones that have been previously "freed" (but we kept * the memory allocated), or by allocating a new one. */ -UInt Cell::getFreeSegment(const Segment::InSynapses& synapses, - Real initFrequency, - bool sequenceSegmentFlag, - Real permConnected, - UInt iteration) -{ - NTA_ASSERT(! synapses.empty()); +UInt Cell::getFreeSegment(const Segment::InSynapses &synapses, + Real initFrequency, bool sequenceSegmentFlag, + Real permConnected, UInt iteration) { + NTA_ASSERT(!synapses.empty()); UInt segIdx = 0; @@ -88,7 +78,7 @@ UInt Cell::getFreeSegment(const Segment::InSynapses& synapses, if (_freeSegments.empty()) { segIdx = _segments.size(); - //TODO: Should we grow by larger amounts here? + // TODO: Should we grow by larger amounts here? 
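getFreeSegment() above prefers to hand back a slot recorded in _freeSegments and only grows _segments when no freed slot exists, so segment indices stay stable and previously allocated memory is reused. A generic sketch of that free-list pattern with plain standard-library types (simplified, not the Cells4 classes):

    #include <cassert>
    #include <string>
    #include <vector>

    // Slot pool: released slots go on a free list and are reused before the
    // backing vector is ever grown, so indices of live slots never move.
    class SlotPool {
      std::vector<std::string> slots_;
      std::vector<size_t> free_;

    public:
      size_t acquire(const std::string &value) {
        size_t idx;
        if (free_.empty()) {
          idx = slots_.size();
          slots_.resize(slots_.size() + 1); // grow only when nothing is free
        } else {
          idx = free_.back();
          free_.pop_back();
        }
        slots_[idx] = value;
        return idx;
      }

      void release(size_t idx) {
        slots_[idx].clear(); // keep the capacity, mark the slot reusable
        free_.push_back(idx);
      }
    };

    int main() {
      SlotPool pool;
      size_t a = pool.acquire("seg0");
      pool.release(a);
      size_t b = pool.acquire("seg1"); // reuses slot 0, no reallocation
      assert(a == b);
      return 0;
    }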
_segments.resize(_segments.size() + 1); } else { segIdx = _freeSegments.back(); @@ -100,9 +90,8 @@ UInt Cell::getFreeSegment(const Segment::InSynapses& synapses, NTA_ASSERT(not_in(segIdx, _freeSegments)); NTA_ASSERT(_segments[segIdx].empty()); // important in case we push_back - _segments[segIdx] = - Segment(synapses, initFrequency, sequenceSegmentFlag, - permConnected, iteration); + _segments[segIdx] = Segment(synapses, initFrequency, sequenceSegmentFlag, + permConnected, iteration); return segIdx; } @@ -111,48 +100,39 @@ UInt Cell::getFreeSegment(const Segment::InSynapses& synapses, /** * Update the duty cycle of each segment in this cell */ -void Cell::updateDutyCycle(UInt iterations) -{ - for (UInt i = 0; i != _segments.size(); ++i) - { - if (!_segments[i].empty()) - { +void Cell::updateDutyCycle(UInt iterations) { + for (UInt i = 0; i != _segments.size(); ++i) { + if (!_segments[i].empty()) { _segments[i].dutyCycle(iterations, false, false); } } } //----------------------------------------------------------------------------- -void Cell::write(CellProto::Builder& proto) const -{ +void Cell::write(CellProto::Builder &proto) const { auto segmentsProto = proto.initSegments(_segments.size()); - for (UInt i = 0; i < _segments.size(); ++i) - { + for (UInt i = 0; i < _segments.size(); ++i) { auto segProto = segmentsProto[i]; _segments[i].write(segProto); } } //----------------------------------------------------------------------------- -void Cell::read(CellProto::Reader& proto) -{ +void Cell::read(CellProto::Reader &proto) { auto segmentsProto = proto.getSegments(); _segments.resize(segmentsProto.size()); _freeSegments.resize(0); - for (UInt i = 0; i < segmentsProto.size(); ++i) - { + for (UInt i = 0; i < segmentsProto.size(); ++i) { auto segProto = segmentsProto[i]; _segments[i].read(segProto); - if (_segments[i].empty()) - { + if (_segments[i].empty()) { _freeSegments.push_back(i); } } } //----------------------------------------------------------------------------- -void Cell::save(std::ostream& outStream) const -{ +void Cell::save(std::ostream &outStream) const { outStream << _segments.size() << ' '; // TODO: save only non-empty segments for (UInt i = 0; i != _segments.size(); ++i) { @@ -162,8 +142,7 @@ void Cell::save(std::ostream& outStream) const } //---------------------------------------------------------------------------- -void Cell::load(std::istream& inStream) -{ +void Cell::load(std::istream &inStream) { UInt n = 0; inStream >> n; @@ -171,7 +150,7 @@ void Cell::load(std::istream& inStream) _segments.resize(n); _freeSegments.resize(0); - for (UInt i = 0; i != (UInt) n; ++i) { + for (UInt i = 0; i != (UInt)n; ++i) { _segments[i].load(inStream); if (_segments[i].empty()) _freeSegments.push_back(i); diff --git a/src/nupic/algorithms/Cell.hpp b/src/nupic/algorithms/Cell.hpp index 299ad40ef3..d17175c25d 100644 --- a/src/nupic/algorithms/Cell.hpp +++ b/src/nupic/algorithms/Cell.hpp @@ -23,256 +23,237 @@ #ifndef NTA_CELL_HPP #define NTA_CELL_HPP -#include +#include #include #include #include -#include +#include namespace nupic { - namespace algorithms { - namespace Cells4 { - - class Cells4; - - //-------------------------------------------------------------------------------- - //-------------------------------------------------------------------------------- - /** - * A Cell is a container for Segments. It maintains a list of active segments and - * a list of segments that have been "inactivated" because all their synapses were - * removed. 
The slots of inactivated segments are re-used, in contrast to the - * Python TP, which keeps its segments in a dynamic list and always allocates new - * segments at the end of this dynamic list. This difference is a source of - * mismatches in unit testing when comparing the Python TP to the C++ down to the - * segment level. - */ - class Cell : Serializable - { - private: - std::vector< Segment > _segments; // both 'active' and 'inactive' segments - std::vector _freeSegments; // slots of the 'inactive' segments - - public: - //-------------------------------------------------------------------------------- - Cell(); - - //-------------------------------------------------------------------------------- - bool empty() const { return _segments.size() == _freeSegments.size(); } - - //-------------------------------------------------------------------------------- - UInt nSynapses() const - { - UInt n = 0; - for (UInt i = 0; i != _segments.size(); ++i) - n += _segments[i].size(); - return n; - } - - //-------------------------------------------------------------------------------- - /** - * Returns size of _segments (see nSegments below). If using this to iterate, - * indices less than size() might contain indices of empty segments. - */ - UInt size() const { return _segments.size(); } - - //-------------------------------------------------------------------------------- - /** - * Returns number of segments that are not in the free list currently, i.e. that - * have at leat 1 synapse. - */ - UInt nSegments() const - { - NTA_ASSERT(_freeSegments.size() <= _segments.size()); - return _segments.size() - _freeSegments.size(); - } - - //-------------------------------------------------------------------------------- - /** - * Returns list of segments that are not empty. - */ - std::vector getNonEmptySegList() const - { - std::vector non_empties; - for (UInt i = 0; i != _segments.size(); ++i) - if (!_segments[i].empty()) - non_empties.push_back(i); - NTA_ASSERT(non_empties.size() == nSegments()); - return non_empties; - } - - //-------------------------------------------------------------------------------- - Segment& operator[](UInt segIdx) - { - NTA_ASSERT(segIdx < _segments.size()); - return _segments[segIdx]; - } - - //-------------------------------------------------------------------------------- - const Segment& operator[](UInt segIdx) const - { - NTA_ASSERT(segIdx < _segments.size()); - return _segments[segIdx]; - } - - //-------------------------------------------------------------------------------- - Segment& getSegment(UInt segIdx) - { - NTA_ASSERT(segIdx < _segments.size()); - return _segments[segIdx]; - } - - //-------------------------------------------------------------------------------- - /** - * Returns an empty segment to use, either from list of already - * allocated ones that have been previously "freed" (but we kept - * the memory allocated), or by allocating a new one. - */ - // TODO: rename method to "addToFreeSegment" ?? 
- UInt getFreeSegment(const Segment::InSynapses& synapses, - Real initFrequency, - bool sequenceSegmentFlag, - Real permConnected, - UInt iteration); - - //-------------------------------------------------------------------------------- - /** - * Whether we want to match python's segment ordering - */ - static void setSegmentOrder(bool matchPythonOrder); - - - //-------------------------------------------------------------------------------- - /** - * Update the duty cycle of each segment in this cell - */ - void updateDutyCycle(UInt iterations); - - //-------------------------------------------------------------------------------- - /** - * Rebalance the segment list. The segment list is compacted and all - * free segments are removed. The most frequent segment is placed at - * the head of the list. - * - * Note: outSynapses must be updated after a call to this. - */ - void rebalanceSegments() - { - //const std::vector &non_empties = getNonEmptySegList(); - UInt bestOne = getMostActiveSegment(); - - // Swap the best one with the 0'th one - if (bestOne != 0) { - Segment seg = _segments[0]; - _segments[0] = _segments[bestOne]; - _segments[bestOne] = seg; - } - - // Sort segments according to activation frequency - // TODO Comment not backed by code. Investigate whether the - // above-mentioned sort is needed. - - // Redo free segments list - _freeSegments.clear(); - for (UInt segIdx = 0; segIdx != _segments.size(); ++segIdx) { - if ( _segments[segIdx].empty() ) - releaseSegment(segIdx); - } - } - - //-------------------------------------------------------------------------------- - /** - * Returns index of segment with highest activation frequency. - * 0 means - */ - UInt getMostActiveSegment() - { - UInt bestIdx = 0; // Segment with highest totalActivations - UInt maxActivity = 0; // Value of highest totalActivations - - for (UInt i = 0; i < _segments.size(); ++i) { - if ( !_segments[i].empty() && - (_segments[i].getTotalActivations() > maxActivity) ){ - maxActivity = _segments[i].getTotalActivations(); - bestIdx = i; - } - } - - return bestIdx; - } - //-------------------------------------------------------------------------------- - /** - * Release a segment by putting it on the list of "freed" segments. We keep the - * memory instead of deallocating it each time, so that's it's fast to "allocate" - * a new segment next time. - * - * Assumes outSynapses has already been updated. - * TODO: a call to releaseSegment should delete any pending - * update for that segment in the update list. The - * cheapest way to do this is to maintain segment updates on a - * per cell basis. Currently there is a check in - * Cells4::adaptSegment for this case but that may be insufficient. - */ - void releaseSegment(UInt segIdx) - { - NTA_ASSERT(segIdx < _segments.size()); - - // TODO: check this - if (is_in(segIdx, _freeSegments)) { - return; - } - - // TODO: check this - NTA_ASSERT(not_in(segIdx, _freeSegments)); - - _segments[segIdx].clear(); // important in case we push_back later - _freeSegments.push_back(segIdx); - _segments[segIdx]._totalActivations = 0; - _segments[segIdx]._positiveActivations = 0; - - NTA_ASSERT(_segments[segIdx].empty()); - NTA_ASSERT(is_in(segIdx, _freeSegments)); - } - - // The comment below is so awesome, I had to leave it in! 
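rebalanceSegments() above does not actually sort the segment list; it swaps the most frequently activated segment into slot 0 and then rebuilds the free list. A compact standalone illustration of that promote-the-hottest-entry step, with activation counts standing in for segments:

    #include <algorithm>
    #include <iostream>
    #include <vector>

    int main() {
      // Stand-ins for Segment::getTotalActivations() of four segments.
      std::vector<unsigned> activations = {4, 0, 9, 2};

      // Find the most active entry (the empty-segment check is omitted here).
      auto best = std::max_element(activations.begin(), activations.end());

      // Promote it to the front, as rebalanceSegments() does for segment 0.
      if (best != activations.begin())
        std::iter_swap(activations.begin(), best);

      for (unsigned a : activations)
        std::cout << a << ' '; // prints: 9 0 4 2
      std::cout << std::endl;
      return 0;
    }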
- - //---------------------------------------------------------------------- - /** - * TODO: write - */ - bool invariants(Cells4* =nullptr) const - { - return true; - } - - //---------------------------------------------------------------------- - // PERSISTENCE - //---------------------------------------------------------------------- - UInt persistentSize() const - { - std::stringstream buff; - this->save(buff); - return buff.str().size(); - } - - //---------------------------------------------------------------------- - using Serializable::write; - virtual void write(CellProto::Builder& proto) const override; - - //---------------------------------------------------------------------- - using Serializable::read; - virtual void read(CellProto::Reader& proto) override; - - //---------------------------------------------------------------------- - void save(std::ostream& outStream) const; - - //---------------------------------------------------------------------- - void load(std::istream& inStream); - - }; - - // end namespace +namespace algorithms { +namespace Cells4 { + +class Cells4; + +//-------------------------------------------------------------------------------- +//-------------------------------------------------------------------------------- +/** + * A Cell is a container for Segments. It maintains a list of active segments + * and a list of segments that have been "inactivated" because all their + * synapses were removed. The slots of inactivated segments are re-used, in + * contrast to the Python TP, which keeps its segments in a dynamic list and + * always allocates new segments at the end of this dynamic list. This + * difference is a source of mismatches in unit testing when comparing the + * Python TP to the C++ down to the segment level. + */ +class Cell : Serializable { +private: + std::vector _segments; // both 'active' and 'inactive' segments + std::vector _freeSegments; // slots of the 'inactive' segments + +public: + //-------------------------------------------------------------------------------- + Cell(); + + //-------------------------------------------------------------------------------- + bool empty() const { return _segments.size() == _freeSegments.size(); } + + //-------------------------------------------------------------------------------- + UInt nSynapses() const { + UInt n = 0; + for (UInt i = 0; i != _segments.size(); ++i) + n += _segments[i].size(); + return n; + } + + //-------------------------------------------------------------------------------- + /** + * Returns size of _segments (see nSegments below). If using this to iterate, + * indices less than size() might contain indices of empty segments. + */ + UInt size() const { return _segments.size(); } + + //-------------------------------------------------------------------------------- + /** + * Returns number of segments that are not in the free list currently, i.e. + * that have at leat 1 synapse. + */ + UInt nSegments() const { + NTA_ASSERT(_freeSegments.size() <= _segments.size()); + return _segments.size() - _freeSegments.size(); + } + + //-------------------------------------------------------------------------------- + /** + * Returns list of segments that are not empty. 
+ */ + std::vector getNonEmptySegList() const { + std::vector non_empties; + for (UInt i = 0; i != _segments.size(); ++i) + if (!_segments[i].empty()) + non_empties.push_back(i); + NTA_ASSERT(non_empties.size() == nSegments()); + return non_empties; + } + + //-------------------------------------------------------------------------------- + Segment &operator[](UInt segIdx) { + NTA_ASSERT(segIdx < _segments.size()); + return _segments[segIdx]; + } + + //-------------------------------------------------------------------------------- + const Segment &operator[](UInt segIdx) const { + NTA_ASSERT(segIdx < _segments.size()); + return _segments[segIdx]; + } + + //-------------------------------------------------------------------------------- + Segment &getSegment(UInt segIdx) { + NTA_ASSERT(segIdx < _segments.size()); + return _segments[segIdx]; + } + + //-------------------------------------------------------------------------------- + /** + * Returns an empty segment to use, either from list of already + * allocated ones that have been previously "freed" (but we kept + * the memory allocated), or by allocating a new one. + */ + // TODO: rename method to "addToFreeSegment" ?? + UInt getFreeSegment(const Segment::InSynapses &synapses, Real initFrequency, + bool sequenceSegmentFlag, Real permConnected, + UInt iteration); + + //-------------------------------------------------------------------------------- + /** + * Whether we want to match python's segment ordering + */ + static void setSegmentOrder(bool matchPythonOrder); + + //-------------------------------------------------------------------------------- + /** + * Update the duty cycle of each segment in this cell + */ + void updateDutyCycle(UInt iterations); + + //-------------------------------------------------------------------------------- + /** + * Rebalance the segment list. The segment list is compacted and all + * free segments are removed. The most frequent segment is placed at + * the head of the list. + * + * Note: outSynapses must be updated after a call to this. + */ + void rebalanceSegments() { + // const std::vector &non_empties = getNonEmptySegList(); + UInt bestOne = getMostActiveSegment(); + + // Swap the best one with the 0'th one + if (bestOne != 0) { + Segment seg = _segments[0]; + _segments[0] = _segments[bestOne]; + _segments[bestOne] = seg; + } + + // Sort segments according to activation frequency + // TODO Comment not backed by code. Investigate whether the + // above-mentioned sort is needed. + + // Redo free segments list + _freeSegments.clear(); + for (UInt segIdx = 0; segIdx != _segments.size(); ++segIdx) { + if (_segments[segIdx].empty()) + releaseSegment(segIdx); } } -} -#endif // NTA_CELL_HPP + //-------------------------------------------------------------------------------- + /** + * Returns index of segment with highest activation frequency. + * 0 means + */ + UInt getMostActiveSegment() { + UInt bestIdx = 0; // Segment with highest totalActivations + UInt maxActivity = 0; // Value of highest totalActivations + + for (UInt i = 0; i < _segments.size(); ++i) { + if (!_segments[i].empty() && + (_segments[i].getTotalActivations() > maxActivity)) { + maxActivity = _segments[i].getTotalActivations(); + bestIdx = i; + } + } + + return bestIdx; + } + //-------------------------------------------------------------------------------- + /** + * Release a segment by putting it on the list of "freed" segments. 
We keep + * the memory instead of deallocating it each time, so that's it's fast to + * "allocate" a new segment next time. + * + * Assumes outSynapses has already been updated. + * TODO: a call to releaseSegment should delete any pending + * update for that segment in the update list. The + * cheapest way to do this is to maintain segment updates on a + * per cell basis. Currently there is a check in + * Cells4::adaptSegment for this case but that may be insufficient. + */ + void releaseSegment(UInt segIdx) { + NTA_ASSERT(segIdx < _segments.size()); + + // TODO: check this + if (is_in(segIdx, _freeSegments)) { + return; + } + + // TODO: check this + NTA_ASSERT(not_in(segIdx, _freeSegments)); + + _segments[segIdx].clear(); // important in case we push_back later + _freeSegments.push_back(segIdx); + _segments[segIdx]._totalActivations = 0; + _segments[segIdx]._positiveActivations = 0; + + NTA_ASSERT(_segments[segIdx].empty()); + NTA_ASSERT(is_in(segIdx, _freeSegments)); + } + + // The comment below is so awesome, I had to leave it in! + + //---------------------------------------------------------------------- + /** + * TODO: write + */ + bool invariants(Cells4 * = nullptr) const { return true; } + + //---------------------------------------------------------------------- + // PERSISTENCE + //---------------------------------------------------------------------- + UInt persistentSize() const { + std::stringstream buff; + this->save(buff); + return buff.str().size(); + } + + //---------------------------------------------------------------------- + using Serializable::write; + virtual void write(CellProto::Builder &proto) const override; + + //---------------------------------------------------------------------- + using Serializable::read; + virtual void read(CellProto::Reader &proto) override; + + //---------------------------------------------------------------------- + void save(std::ostream &outStream) const; + + //---------------------------------------------------------------------- + void load(std::istream &inStream); +}; + +// end namespace +} // namespace Cells4 +} // namespace algorithms +} // namespace nupic +#endif // NTA_CELL_HPP diff --git a/src/nupic/algorithms/Cells4.cpp b/src/nupic/algorithms/Cells4.cpp index 06211b9a68..128ce3c9ab 100644 --- a/src/nupic/algorithms/Cells4.cpp +++ b/src/nupic/algorithms/Cells4.cpp @@ -24,14 +24,13 @@ #include #include -#include #include -#include // numeric_limits +#include // numeric_limits #include #include +#include #include "cycle_counter.hpp" -#include #include #include #include @@ -39,10 +38,11 @@ #include #include #include // is_in -#include // binary_save +#include // binary_save #include #include #include +#include using namespace nupic::algorithms::Cells4; @@ -67,84 +67,59 @@ static nupic::Timer chooseCellsTimer; #endif -Cells4::Cells4(UInt nColumns, UInt nCellsPerCol, - UInt activationThreshold, - UInt minThreshold, - UInt newSynapseCount, - UInt segUpdateValidDuration, - Real permInitial, - Real permConnected, - Real permMax, - Real permDec, - Real permInc, - Real globalDecay, - bool doPooling, - int seed, - bool initFromCpp, - bool checkSynapseConsistency) - : _rng(seed < 0 ? 
rand() : seed) -{ +Cells4::Cells4(UInt nColumns, UInt nCellsPerCol, UInt activationThreshold, + UInt minThreshold, UInt newSynapseCount, + UInt segUpdateValidDuration, Real permInitial, + Real permConnected, Real permMax, Real permDec, Real permInc, + Real globalDecay, bool doPooling, int seed, bool initFromCpp, + bool checkSynapseConsistency) + : _rng(seed < 0 ? rand() : seed) { _version = VERSION; - initialize(nColumns, - nCellsPerCol, - activationThreshold, - minThreshold, - newSynapseCount, - segUpdateValidDuration, - permInitial, - permConnected, - permMax, - permDec, - permInc, - globalDecay, - doPooling, - initFromCpp, - checkSynapseConsistency); -} - -Cells4::~Cells4() -{ + initialize(nColumns, nCellsPerCol, activationThreshold, minThreshold, + newSynapseCount, segUpdateValidDuration, permInitial, + permConnected, permMax, permDec, permInc, globalDecay, doPooling, + initFromCpp, checkSynapseConsistency); +} + +Cells4::~Cells4() { if (_ownsMemory) { - delete [] _cellConfidenceT; - delete [] _cellConfidenceT1; - delete [] _colConfidenceT; - delete [] _colConfidenceT1; + delete[] _cellConfidenceT; + delete[] _cellConfidenceT1; + delete[] _colConfidenceT; + delete[] _colConfidenceT1; } - delete [] _cellConfidenceCandidate; - delete [] _colConfidenceCandidate; - delete [] _tmpInputBuffer; + delete[] _cellConfidenceCandidate; + delete[] _colConfidenceCandidate; + delete[] _tmpInputBuffer; } //-------------------------------------------------------------------------------- -// Utility routines used in this file to print list of active columns and cell indices +// Utility routines used in this file to print list of active columns and cell +// indices //-------------------------------------------------------------------------------- -static void printActiveColumns(std::ostream& out, const std::vector & activeColumns) -{ +static void printActiveColumns(std::ostream &out, + const std::vector &activeColumns) { out << "["; - for (auto & activeColumn : activeColumns) { + for (auto &activeColumn : activeColumns) { out << " " << activeColumn; } out << "]"; } -static void printCell(UInt srcCellIdx, UInt nCellsPerCol) -{ - UInt col = (UInt) (srcCellIdx / nCellsPerCol); - UInt cell = srcCellIdx - col*nCellsPerCol; +static void printCell(UInt srcCellIdx, UInt nCellsPerCol) { + UInt col = (UInt)(srcCellIdx / nCellsPerCol); + UInt cell = srcCellIdx - col * nCellsPerCol; std::cout << "[" << col << "," << cell << "] "; } - - //-------------------------------------------------------------------------------- -bool Cells4::isActive(UInt cellIdx, UInt segIdx, const CState& state) const -{ +bool Cells4::isActive(UInt cellIdx, UInt segIdx, const CState &state) const { { NTA_ASSERT(cellIdx < nCells()); NTA_ASSERT(segIdx < _cells[cellIdx].size()); } - const Segment& seg = _cells[cellIdx][segIdx]; + const Segment &seg = _cells[cellIdx][segIdx]; if (seg.size() < _activationThreshold) return false; @@ -155,8 +130,8 @@ bool Cells4::isActive(UInt cellIdx, UInt segIdx, const CState& state) const //-------------------------------------------------------------------------------- /** * Push a segmentUpdate data structure containing a list of proposed changes - * to segment s into our SegmentUpdate list. Return false if no update was actually - * pushed (this can happen if we didn't find any new synapses). + * to segment s into our SegmentUpdate list. Return false if no update was + * actually pushed (this can happen if we didn't find any new synapses). 
* * Let activeSynapses be the list of active synapses where the * originating cells have their activeState output = 1 at time step t. @@ -169,28 +144,27 @@ bool Cells4::isActive(UInt cellIdx, UInt segIdx, const CState& state) const * NOTE: called getSegmentActiveSynapses in Python * */ -bool Cells4::computeUpdate(UInt cellIdx, UInt segIdx, CStateIndexed& activeState, - bool sequenceSegmentFlag, bool newSynapsesFlag) -{ +bool Cells4::computeUpdate(UInt cellIdx, UInt segIdx, + CStateIndexed &activeState, bool sequenceSegmentFlag, + bool newSynapsesFlag) { { NTA_ASSERT(cellIdx < nCells()); - NTA_ASSERT(segIdx == (UInt) - 1 || segIdx < _cells[cellIdx].size()); + NTA_ASSERT(segIdx == (UInt)-1 || segIdx < _cells[cellIdx].size()); } static std::vector newSynapses; - newSynapses.clear(); // purge residual data + newSynapses.clear(); // purge residual data - if (segIdx != (UInt) -1) { // not a new segment + if (segIdx != (UInt)-1) { // not a new segment - Segment& segment = _cells[cellIdx][segIdx]; + Segment &segment = _cells[cellIdx][segIdx]; static UInt highWaterSize = 0; if (highWaterSize < segment.size()) { highWaterSize = segment.size(); newSynapses.reserve(highWaterSize); } - for (UInt i = 0; i < segment.size(); ++i) - { + for (UInt i = 0; i < segment.size(); ++i) { if (activeState.isSet(segment[i].srcCellIdx())) { newSynapses.push_back(segment[i].srcCellIdx()); } @@ -198,9 +172,10 @@ bool Cells4::computeUpdate(UInt cellIdx, UInt segIdx, CStateIndexed& activeState } if (newSynapsesFlag) { - int nSynToAdd = (int) _newSynapseCount - (int) newSynapses.size(); + int nSynToAdd = (int)_newSynapseCount - (int)newSynapses.size(); if (nSynToAdd > 0) { - chooseCellsToLearnFrom(cellIdx, segIdx, nSynToAdd, activeState, newSynapses); + chooseCellsToLearnFrom(cellIdx, segIdx, nSynToAdd, activeState, + newSynapses); } } @@ -209,8 +184,8 @@ bool Cells4::computeUpdate(UInt cellIdx, UInt segIdx, CStateIndexed& activeState if (newSynapses.empty()) return false; - SegmentUpdate update(cellIdx, segIdx, sequenceSegmentFlag, - _nLrnIterations, newSynapses); // TODO: Add this for invariants check + SegmentUpdate update(cellIdx, segIdx, sequenceSegmentFlag, _nLrnIterations, + newSynapses); // TODO: Add this for invariants check _segmentUpdates.push_back(update); return true; @@ -229,10 +204,8 @@ bool Cells4::computeUpdate(UInt cellIdx, UInt segIdx, CStateIndexed& activeState */ template -void Cells4::addOutSynapses(UInt dstCellIdx, UInt dstSegIdx, - It newSynapse, - It newSynapsesEnd) -{ +void Cells4::addOutSynapses(UInt dstCellIdx, UInt dstSegIdx, It newSynapse, + It newSynapsesEnd) { NTA_ASSERT(dstCellIdx < nCells()); NTA_ASSERT(dstSegIdx < _cells[dstCellIdx].size()); @@ -242,37 +215,34 @@ void Cells4::addOutSynapses(UInt dstCellIdx, UInt dstSegIdx, NTA_ASSERT(not_in(newOutSyn, _outSynapses[srcCellIdx])); _outSynapses[srcCellIdx].push_back(newOutSyn); } - } // explicit instantiations for the method above namespace nupic { - namespace algorithms { - namespace Cells4 { - template void Cells4::addOutSynapses(nupic::UInt, nupic::UInt, - std::set::const_iterator, - std::set::const_iterator); - template void Cells4::addOutSynapses(nupic::UInt, nupic::UInt, - std::vector::const_iterator, - std::vector::const_iterator); - - } - } -} - +namespace algorithms { +namespace Cells4 { +template void Cells4::addOutSynapses(nupic::UInt, nupic::UInt, + std::set::const_iterator, + std::set::const_iterator); +template void Cells4::addOutSynapses(nupic::UInt, nupic::UInt, + std::vector::const_iterator, + std::vector::const_iterator); + 
+} // namespace Cells4 +} // namespace algorithms +} // namespace nupic //-------------------------------------------------------------------------------- /** * Erases an OutSynapses. See addOutSynapses just above. */ void Cells4::eraseOutSynapses(UInt dstCellIdx, UInt dstSegIdx, - const std::vector& srcCells) -{ + const std::vector &srcCells) { NTA_ASSERT(dstCellIdx < nCells()); NTA_ASSERT(dstSegIdx < _cells[dstCellIdx].size()); - for (auto & srcCellIdx : srcCells) { - OutSynapses& outSyns = _outSynapses[srcCellIdx]; + for (auto &srcCellIdx : srcCells) { + OutSynapses &outSyns = _outSynapses[srcCellIdx]; // TODO: binary search or faster for (UInt j = 0; j != outSyns.size(); ++j) if (outSyns[j].goesTo(dstCellIdx, dstSegIdx)) { @@ -289,13 +259,13 @@ void Cells4::eraseOutSynapses(UInt dstCellIdx, UInt dstSegIdx, * onto the current set of inputs by assuming the sequence started N * steps ago on start cells. */ -void Cells4::inferBacktrack(const std::vector & activeColumns) -{ +void Cells4::inferBacktrack(const std::vector &activeColumns) { //--------------------------------------------------------------------------- // How much input history have we accumulated? Is it enough to backtrack? // The current input is always at the end of self._prevInfPatterns, but // it is also evaluated as a potential starting point - if (_prevInfPatterns.empty()) return; + if (_prevInfPatterns.empty()) + return; TIMER(infBacktrackTimer.start()); @@ -314,7 +284,7 @@ void Cells4::inferBacktrack(const std::vector & activeColumns) // input history queue so that we don't waste time evaluating them again at // a later time step. static std::vector badPatterns; - badPatterns.clear(); // purge residual data + badPatterns.clear(); // purge residual data //--------------------------------------------------------------------------- // Let's go back in time and replay the recent inputs from start cells and @@ -328,8 +298,7 @@ void Cells4::inferBacktrack(const std::vector & activeColumns) // If we have a candidate already in the past, don't bother falling back // to start cells on the current input. 
- if ( (startOffset == currentTimeStepsOffset) && - (candConfidence != -1) ) + if ((startOffset == currentTimeStepsOffset) && (candConfidence != -1)) break; if (_verbosity >= 3) { @@ -342,60 +311,61 @@ void Cells4::inferBacktrack(const std::vector & activeColumns) // Play through starting from time t-startOffset inSequence = false; Real totalConfidence = 0; - for (UInt offset = startOffset; offset < _prevInfPatterns.size(); offset++) - { + for (UInt offset = startOffset; offset < _prevInfPatterns.size(); + offset++) { // If we are about to set the active columns for the current time step // based on what we predicted, capture and save the total confidence of // predicting the current input if (offset == currentTimeStepsOffset) { totalConfidence = 0; - for (auto & activeColumn : activeColumns) { + for (auto &activeColumn : activeColumns) { totalConfidence += _colConfidenceT[activeColumn]; } } // Compute activeState[t] given bottom-up and predictedState[t-1] _infPredictedStateT1 = _infPredictedStateT; - inSequence = inferPhase1(_prevInfPatterns[offset], - (offset == startOffset)); - if (!inSequence) break; + inSequence = + inferPhase1(_prevInfPatterns[offset], (offset == startOffset)); + if (!inSequence) + break; // Compute predictedState['t'] given activeState['t'] if (_verbosity >= 3) { std::cout << " backtrack: computing predictions from "; printActiveColumns(std::cout, _prevInfPatterns[offset]); std::cout << "\n"; - } inSequence = inferPhase2(); - if (!inSequence) break; + if (!inSequence) + break; } // If starting from startOffset got lost along the way, mark it as an // invalid start point. if (!inSequence) { badPatterns.push_back(startOffset); - } - else { + } else { candStartOffset = startOffset; // If we got to here, startOffset is a candidate starting point. - if (_verbosity >= 3 && (startOffset != currentTimeStepsOffset) ) { + if (_verbosity >= 3 && (startOffset != currentTimeStepsOffset)) { std::cout << "# Prediction confidence of current input after starting " - << _prevInfPatterns.size() - 1 - startOffset - << " steps ago: " << totalConfidence << "\n"; + << _prevInfPatterns.size() - 1 - startOffset + << " steps ago: " << totalConfidence << "\n"; } - if (candStartOffset == (Int) currentTimeStepsOffset) + if (candStartOffset == (Int)currentTimeStepsOffset) break; _infActiveStateCandidate = _infActiveStateT; _infPredictedStateCandidate = _infPredictedStateT; - memcpy(_cellConfidenceCandidate, _cellConfidenceT, _nCells * sizeof(_cellConfidenceT[0])); - memcpy(_colConfidenceCandidate, _colConfidenceT, _nColumns * sizeof(_colConfidenceT[0])); + memcpy(_cellConfidenceCandidate, _cellConfidenceT, + _nCells * sizeof(_cellConfidenceT[0])); + memcpy(_colConfidenceCandidate, _colConfidenceT, + _nColumns * sizeof(_colConfidenceT[0])); break; } - } //--------------------------------------------------------------------------- @@ -411,15 +381,17 @@ void Cells4::inferBacktrack(const std::vector & activeColumns) } else { if (_verbosity >= 3) { std::cout << "Locked on to current input by using start cells from " - << _prevInfPatterns.size() - 1 - candStartOffset << - " steps ago.\n"; + << _prevInfPatterns.size() - 1 - candStartOffset + << " steps ago.\n"; } // Install the candidate state, if it wasn't the last one we evaluated. 
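In isolation, the candidate-install step below copies the state buffers saved for the winning start offset over the live ones whenever that offset is not the current time step. A minimal standalone sketch of the pattern, using simplified stand-in buffers (BacktrackState, cellConf, colConf are illustrative names, not the real Cells4 members):

#include <cstring>
#include <vector>

struct BacktrackState {
  // Live and candidate confidence buffers; assumed to be equal-sized.
  std::vector<float> cellConf, cellConfCandidate;
  std::vector<float> colConf, colConfCandidate;
};

void installCandidate(BacktrackState &s, int candStartOffset,
                      int currentTimeStepsOffset) {
  if (candStartOffset != currentTimeStepsOffset) {
    // Mirrors the memcpy calls below: overwrite the current confidences with
    // the ones captured while replaying from the winning start offset.
    std::memcpy(s.cellConf.data(), s.cellConfCandidate.data(),
                s.cellConf.size() * sizeof(float));
    std::memcpy(s.colConf.data(), s.colConfCandidate.data(),
                s.colConf.size() * sizeof(float));
  }
}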
- if (candStartOffset != (Int) currentTimeStepsOffset) { + if (candStartOffset != (Int)currentTimeStepsOffset) { _infActiveStateT = _infActiveStateCandidate; _infPredictedStateT = _infPredictedStateCandidate; - memcpy(_cellConfidenceT, _cellConfidenceCandidate, _nCells * sizeof(_cellConfidenceCandidate[0])); - memcpy(_colConfidenceT, _colConfidenceCandidate, _nColumns * sizeof(_colConfidenceCandidate[0])); + memcpy(_cellConfidenceT, _cellConfidenceCandidate, + _nCells * sizeof(_cellConfidenceCandidate[0])); + memcpy(_colConfidenceT, _colConfidenceCandidate, + _nColumns * sizeof(_colConfidenceCandidate[0])); } } @@ -427,21 +399,18 @@ void Cells4::inferBacktrack(const std::vector & activeColumns) // Remove any useless patterns at the head of the previous input pattern // queue. UInt numPrevPatterns = _prevInfPatterns.size(); - for (UInt i = 0; i < numPrevPatterns; i++) - { + for (UInt i = 0; i < numPrevPatterns; i++) { std::vector::iterator result; result = find(badPatterns.begin(), badPatterns.end(), i); - if ( result != badPatterns.end() || - ( (candStartOffset != -1) && ((Int)i <= candStartOffset) ) ) - { + if (result != badPatterns.end() || + ((candStartOffset != -1) && ((Int)i <= candStartOffset))) { if (_verbosity >= 3) { std::cout << "Removing useless pattern from history "; printActiveColumns(std::cout, _prevInfPatterns[0]); std::cout << "\n"; } _prevInfPatterns.pop_front(); - } - else + } else break; } @@ -457,8 +426,7 @@ void Cells4::inferBacktrack(const std::vector & activeColumns) * A utility method called from learnBacktrack. This will backtrack * starting from the given startOffset in our prevLrnPatterns queue. */ -bool Cells4::learnBacktrackFrom(UInt startOffset, bool readOnly) -{ +bool Cells4::learnBacktrackFrom(UInt startOffset, bool readOnly) { // How much input history have we accumulated? 
// The current input is always at the end of self._prevInfPatterns (at // index -1), but it is also evaluated as a potential starting point by @@ -491,8 +459,7 @@ bool Cells4::learnBacktrackFrom(UInt startOffset, bool readOnly) //--------------------------------------------------------------------------- // Play through up to the current time step bool inSequence = true; - for (UInt offset = startOffset; offset < numPrevPatterns; offset++) - { + for (UInt offset = startOffset; offset < numPrevPatterns; offset++) { //-------------------------------------------------------------------------- // Copy predicted and active states into t-1 @@ -502,7 +469,7 @@ bool Cells4::learnBacktrackFrom(UInt startOffset, bool readOnly) // Apply segment updates from the last set of predictions if (!readOnly) { memset(_tmpInputBuffer, 0, _nColumns * sizeof(_tmpInputBuffer[0])); - for (auto & elem : _prevLrnPatterns[offset]) { + for (auto &elem : _prevLrnPatterns[offset]) { _tmpInputBuffer[elem] = 1; } processSegmentUpdates(_tmpInputBuffer, _learnPredictedStateT); @@ -512,8 +479,8 @@ bool Cells4::learnBacktrackFrom(UInt startOffset, bool readOnly) // Compute activeState[t] given bottom-up and predictedState[t-1] if (offset == startOffset) { _learnActiveStateT.resetAll(); - for (auto & elem : _prevLrnPatterns[offset]) { - UInt cellIdx = elem*_nCellsPerCol; + for (auto &elem : _prevLrnPatterns[offset]) { + UInt cellIdx = elem * _nCellsPerCol; _learnActiveStateT.set(cellIdx); inSequence = true; } @@ -523,7 +490,7 @@ bool Cells4::learnBacktrackFrom(UInt startOffset, bool readOnly) // Break out immediately if we fell out of sequence or reached the current // time step - if (!inSequence || (offset == currentTimeStepsOffset) ) + if (!inSequence || (offset == currentTimeStepsOffset)) break; //-------------------------------------------------------------------------- @@ -551,12 +518,12 @@ bool Cells4::learnBacktrackFrom(UInt startOffset, bool readOnly) * onto the current set of inputs by assuming the sequence started * up to N steps ago on start cells. */ -UInt Cells4::learnBacktrack() -{ +UInt Cells4::learnBacktrack() { // How much input history have we accumulated? // The current input is always at the end of self._prevInfPatterns (at // index -1), and is not a valid startingOffset to evaluate. - UInt numPrevPatterns = (_prevLrnPatterns.size() == 0 ? 0 : _prevLrnPatterns.size() - 1); + UInt numPrevPatterns = + (_prevLrnPatterns.size() == 0 ? 0 : _prevLrnPatterns.size() - 1); if (numPrevPatterns <= 0) { if (_verbosity >= 3) { std::cout << "lrnBacktrack: No available history to backtrack from\n"; @@ -569,7 +536,7 @@ UInt Cells4::learnBacktrack() // input history queue so that we don't waste time evaluating them again at // a later time step. 
static std::vector badPatterns; - badPatterns.clear(); // purge residual data + badPatterns.clear(); // purge residual data //--------------------------------------------------------------------------- // Let's go back in time and replay the recent inputs from start cells and @@ -617,25 +584,21 @@ UInt Cells4::learnBacktrack() // Remove any useless patterns at the head of the input pattern history // queue - for (UInt i = 0; i < (UInt) numPrevPatterns; i++) - { + for (UInt i = 0; i < (UInt)numPrevPatterns; i++) { std::vector::iterator result; result = find(badPatterns.begin(), badPatterns.end(), i); - if ( result != badPatterns.end() || (i <= startOffset) ) - { + if (result != badPatterns.end() || (i <= startOffset)) { if (_verbosity >= 3) { std::cout << "Removing useless pattern from history "; printActiveColumns(std::cout, _prevLrnPatterns[0]); std::cout << "\n"; } _prevLrnPatterns.pop_front(); - } - else + } else break; } return numPrevPatterns - startOffset; - } //---------------------------------------------------------------------- @@ -643,8 +606,7 @@ UInt Cells4::learnBacktrack() * Return the index of a cell in this column which is a good candidate * for adding a new segment. */ -UInt Cells4::getCellForNewSegment(UInt colIdx) -{ +UInt Cells4::getCellForNewSegment(UInt colIdx) { TIMER(getNewCellTimer.start()); UInt candidateCellIdx = 0; @@ -652,13 +614,12 @@ UInt Cells4::getCellForNewSegment(UInt colIdx) if (_maxSegmentsPerCell < 0) { if (_nCellsPerCol > 1) { // Don't ever choose the start cell (cell # 0) in each column - candidateCellIdx = _rng.getUInt32(_nCellsPerCol-1) + 1; - } - else { + candidateCellIdx = _rng.getUInt32(_nCellsPerCol - 1) + 1; + } else { candidateCellIdx = 0; } TIMER(getNewCellTimer.stop()); - return getCellIdx(colIdx,candidateCellIdx); + return getCellIdx(colIdx, candidateCellIdx); } // --------------------------------------------------------------------- @@ -671,28 +632,26 @@ UInt Cells4::getCellForNewSegment(UInt colIdx) // cell indices we choose in each column of a pattern will advance in // lockstep (i.e. we pick cell indices of 1, then cell indices of 2, etc.). 
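The candidate collection that follows reduces to: skip the start cell, keep every cell in the column whose segment count is still below the cap, then pick one at random. A compact standalone sketch of that rule (a plain rand() picker stands in for the NuPIC RNG, and segCountsPerCell is a hypothetical per-cell segment count, not a Cells4 member):

#include <cstdlib>
#include <vector>

// Returns the in-column index of a cell that can take a new segment, or -1
// if every eligible cell is full (the real code then frees the segment with
// the lowest duty cycle instead).
int pickCellForNewSegment(const std::vector<int> &segCountsPerCell,
                          int maxSegmentsPerCell) {
  std::vector<int> candidates;
  for (std::size_t i = 1; i < segCountsPerCell.size(); ++i) // never cell 0
    if (segCountsPerCell[i] < maxSegmentsPerCell)
      candidates.push_back(static_cast<int>(i));
  if (candidates.empty())
    return -1;
  return candidates[std::rand() % candidates.size()];
}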
static std::vector candidateCellIdxs; - candidateCellIdxs.clear(); // purge residual data - UInt minIdx = getCellIdx(colIdx,0), maxIdx = getCellIdx(colIdx,0); + candidateCellIdxs.clear(); // purge residual data + UInt minIdx = getCellIdx(colIdx, 0), maxIdx = getCellIdx(colIdx, 0); if (_nCellsPerCol > 0) { - minIdx = getCellIdx(colIdx,1); // Don't include startCell in the mix - maxIdx = getCellIdx(colIdx,_nCellsPerCol - 1); + minIdx = getCellIdx(colIdx, 1); // Don't include startCell in the mix + maxIdx = getCellIdx(colIdx, _nCellsPerCol - 1); } - for (UInt i = minIdx; i <= maxIdx; i++) - { - Int numSegs = (Int) _cells[i].size(); + for (UInt i = minIdx; i <= maxIdx; i++) { + Int numSegs = (Int)_cells[i].size(); if (numSegs < _maxSegmentsPerCell) { candidateCellIdxs.push_back(i); } } // If we found one, return with it - if (!candidateCellIdxs.empty()) - { + if (!candidateCellIdxs.empty()) { candidateCellIdx = - candidateCellIdxs[_rng.getUInt32(candidateCellIdxs.size())]; + candidateCellIdxs[_rng.getUInt32(candidateCellIdxs.size())]; if (_verbosity >= 5) { - std::cout << "Cell [" << colIdx - << "," << candidateCellIdx - getCellIdx(colIdx,0) + std::cout << "Cell [" << colIdx << "," + << candidateCellIdx - getCellIdx(colIdx, 0) << "] chosen for new segment, # of segs is " << _cells[candidateCellIdx].size() << "\n"; } @@ -704,14 +663,12 @@ UInt Cells4::getCellForNewSegment(UInt colIdx) // All cells in the column are full, find a segment with lowest duty // cycle to free up - UInt candidateSegmentIdx = (UInt) -1; + UInt candidateSegmentIdx = (UInt)-1; Real candidateSegmentDC = 1.0; // For each cell in this column - for (UInt i = minIdx; i <= maxIdx; i++) - { + for (UInt i = minIdx; i <= maxIdx; i++) { // For each non-empty segment in this cell - for (UInt segIdx= 0; segIdx < _cells[i].size(); segIdx++) - { + for (UInt segIdx = 0; segIdx < _cells[i].size(); segIdx++) { if (!_cells[i][segIdx].empty()) { Real dc = _cells[i][segIdx].dutyCycle(_nLrnIterations, false, false); if (dc < candidateSegmentDC) { @@ -726,18 +683,17 @@ UInt Cells4::getCellForNewSegment(UInt colIdx) // Free up the least used segment if (_verbosity >= 5) { std::cout << "Deleting segment #" << candidateSegmentIdx << " for cell[" - << colIdx << "," << candidateCellIdx - getCellIdx(colIdx,0) + << colIdx << "," << candidateCellIdx - getCellIdx(colIdx, 0) << "] to make room for new segment "; _cells[candidateCellIdx][candidateSegmentIdx].print(std::cout, - _nCellsPerCol); + _nCellsPerCol); std::cout << "\n"; - } // Remove this segment from cell and remove any pending updates to this // segment. Update outSynapses structure. std::vector synsToRemove; - synsToRemove.clear(); // purge residual data + synsToRemove.clear(); // purge residual data _cells[candidateCellIdx][candidateSegmentIdx].getSrcCellIndices(synsToRemove); eraseOutSynapses(candidateCellIdx, candidateSegmentIdx, synsToRemove); cleanUpdatesList(candidateCellIdx, candidateSegmentIdx); @@ -753,21 +709,21 @@ UInt Cells4::getCellForNewSegment(UInt colIdx) * Compute the learning active state given the predicted state and * the bottom-up input. */ -bool Cells4::learnPhase1(const std::vector & activeColumns, bool readOnly) -{ +bool Cells4::learnPhase1(const std::vector &activeColumns, + bool readOnly) { TIMER(learnPhase1Timer.start()); // Save previous active state (where?) 
and start out on a clean slate _learnActiveStateT.resetAll(); UInt numUnpredictedColumns = 0; - for (auto & activeColumn : activeColumns) { - UInt cell0 = activeColumn*_nCellsPerCol; + for (auto &activeColumn : activeColumns) { + UInt cell0 = activeColumn * _nCellsPerCol; // Find any predicting cell in this column (there is at most one) UInt numPredictedCells = 0, predictingCell = _nCellsPerCol; - for (UInt j= 0; j < _nCellsPerCol; j++) { - if (_learnPredictedStateT1.isSet(j+cell0)) { + for (UInt j = 0; j < _nCellsPerCol; j++) { + if (_learnPredictedStateT1.isSet(j + cell0)) { numPredictedCells++; predictingCell = j; } @@ -779,36 +735,34 @@ bool Cells4::learnPhase1(const std::vector & activeColumns, bool readOnly) if (numPredictedCells == 1) { NTA_ASSERT(predictingCell < _nCellsPerCol); _learnActiveStateT.set(cell0 + predictingCell); - } - else { + } else { //---------------------------------------------------------------------- // If no predicted cell, pick the closest matching one to reinforce, or // if none exists, create a new segment on a cell in that column numUnpredictedColumns++; - if (! readOnly) - { + if (!readOnly) { std::pair p; - p = getBestMatchingCellT1(activeColumn, _learnActiveStateT1, _minThreshold); + p = getBestMatchingCellT1(activeColumn, _learnActiveStateT1, + _minThreshold); UInt cellIdx = p.first, segIdx = p.second; // If we found a sequence segment, reinforce it - if (segIdx != (UInt) -1 && - _cells[cellIdx][segIdx].isSequenceSegment()) { + if (segIdx != (UInt)-1 && _cells[cellIdx][segIdx].isSequenceSegment()) { if (_verbosity >= 4) { std::cout << "Learn branch 0, found segment match: "; std::cout << " learning on col=" << activeColumn - << ", cellIdx=" << cellIdx << "\n"; + << ", cellIdx=" << cellIdx << "\n"; } _learnActiveStateT.set(cellIdx); - bool newUpdate = computeUpdate(cellIdx, segIdx, - _learnActiveStateT1, true, true); + bool newUpdate = + computeUpdate(cellIdx, segIdx, _learnActiveStateT1, true, true); _cells[cellIdx][segIdx]._totalActivations++; if (newUpdate) { // This will update the permanences, posActivationsCount, and the // lastActiveIteration (age). - const SegmentUpdate& update = _segmentUpdates.back(); + const SegmentUpdate &update = _segmentUpdates.back(); adaptSegment(update); _segmentUpdates.pop_back(); } @@ -822,17 +776,16 @@ bool Cells4::learnPhase1(const std::vector & activeColumns, bool readOnly) std::cout << "Learn branch 1, no match: "; std::cout << " learning on col=" << activeColumn << ", newCellIdxInCol=" - << newCellIdx - getCellIdx(activeColumn, 0) - << "\n"; + << newCellIdx - getCellIdx(activeColumn, 0) << "\n"; } _learnActiveStateT.set(newCellIdx); - bool newUpdate = computeUpdate(newCellIdx, (UInt) -1, + bool newUpdate = computeUpdate(newCellIdx, (UInt)-1, _learnActiveStateT1, true, true); // This will update the permanences, posActivationsCount, and the // lastActiveIteration (age). 
if (newUpdate) { - const SegmentUpdate& update = _segmentUpdates.back(); + const SegmentUpdate &update = _segmentUpdates.back(); adaptSegment(update); _segmentUpdates.pop_back(); } @@ -847,15 +800,14 @@ bool Cells4::learnPhase1(const std::vector & activeColumns, bool readOnly) //---------------------------------------------------------------------- // Determine if we are out of sequence or not and reset our PAM counter // if we are in sequence - return numUnpredictedColumns < activeColumns.size()/2; + return numUnpredictedColumns < activeColumns.size() / 2; } //-------------------------------------------------------------------------------- /** * Compute the predicted segments given the current set of active cells. */ -void Cells4::learnPhase2(bool readOnly) -{ +void Cells4::learnPhase2(bool readOnly) { // Compute number of active synapses per segment based on forward propagation TIMER(forwardLearnPropTimer.start()); computeForwardPropagation(_learnActiveStateT); @@ -872,15 +824,16 @@ void Cells4::learnPhase2(bool readOnly) std::pair p; p = getBestMatchingCellT(colIdx, _learnActiveStateT, _activationThreshold); UInt cellIdx = p.first, segIdx = p.second; - if (segIdx != (UInt) -1) { + if (segIdx != (UInt)-1) { // Turn on the predicted state for the best matching cell and queue // the pertinent segment up for an update, which will get processed if // the cell receives bottom up in the future. _learnPredictedStateT.set(cellIdx); if (!readOnly) { if (_verbosity >= 4) { - std::cout << "learnPhase2, learning on col=" << colIdx << ", cellIdx=" - << cellIdx << ", seg ID: " << segIdx << ", segment: "; + std::cout << "learnPhase2, learning on col=" << colIdx + << ", cellIdx=" << cellIdx << ", seg ID: " << segIdx + << ", segment: "; _cells[cellIdx][segIdx].print(std::cout, _nCellsPerCol); std::cout << "\n"; } @@ -899,9 +852,8 @@ void Cells4::learnPhase2(bool readOnly) /** * Update the learning state. Called from compute() */ -void Cells4::updateLearningState(const std::vector & activeColumns, - Real* input) -{ +void Cells4::updateLearningState(const std::vector &activeColumns, + Real *input) { // ========================================================================= // Copy over learning states to t-1 and reset state at t to 0 _learnActiveStateT1 = _learnActiveStateT; @@ -938,7 +890,7 @@ void Cells4::updateLearningState(const std::vector & activeColumns, //--------------------------------------------------------------------------- // For each column, turn on the predicted cell. At all times at most // 1 cell is active per column in the learn predicted state. - if (! _resetCalled) { + if (!_resetCalled) { bool inSequence = learnPhase1(activeColumns, false); if (inSequence) { _pamCounter = _pamLength; @@ -947,8 +899,8 @@ void Cells4::updateLearningState(const std::vector & activeColumns, // Print status of PAM counter, learned sequence length if (_verbosity >= 3) { - std::cout << "pamCounter = " << _pamCounter << ", learnedSeqLength = " - << _learnedSeqLength << "\n"; + std::cout << "pamCounter = " << _pamCounter + << ", learnedSeqLength = " << _learnedSeqLength << "\n"; } //--------------------------------------------------------------------------- @@ -956,30 +908,29 @@ void Cells4::updateLearningState(const std::vector & activeColumns, // 1.) A reset was just called // 2.) We have been too long out of sequence (the pamCounter has expired) // 3.) We have reached maximum allowed sequence length. 
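The three conditions listed above collapse into one predicate, matching the if-statement that follows; stated on its own (shouldStartOver is an illustrative helper, not a Cells4 method, and maxSeqLength == 0 means "no limit"):

// Start the learned sequence over when a reset was called, the PAM counter
// has expired, or the configured maximum sequence length has been reached.
bool shouldStartOver(bool resetCalled, unsigned pamCounter,
                     unsigned learnedSeqLength, unsigned maxSeqLength) {
  return resetCalled || pamCounter == 0 ||
         (maxSeqLength != 0 && learnedSeqLength >= maxSeqLength);
}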
- if ( _resetCalled || - ( _pamCounter==0) || - ( (_maxSeqLength != 0) && (_learnedSeqLength >= _maxSeqLength) ) - ) - { + if (_resetCalled || (_pamCounter == 0) || + ((_maxSeqLength != 0) && (_learnedSeqLength >= _maxSeqLength))) { if (_verbosity >= 3) { std::cout << "Starting over:"; printActiveColumns(std::cout, activeColumns); - if (_resetCalled) std::cout << "(reset was called)\n"; - else if (_pamCounter == 0) std::cout << "(PAM counter expired)\n"; - else std::cout << "(reached maxSeqLength)\n"; + if (_resetCalled) + std::cout << "(reset was called)\n"; + else if (_pamCounter == 0) + std::cout << "(PAM counter expired)\n"; + else + std::cout << "(reached maxSeqLength)\n"; } // Update average learned sequence length - this is a diagnostic statistic - UInt seqLength = ( _pamCounter == 0 ? - _learnedSeqLength - _pamLength : - _learnedSeqLength ); + UInt seqLength = + (_pamCounter == 0 ? _learnedSeqLength - _pamLength : _learnedSeqLength); if (_verbosity >= 3) std::cout << " learned sequence length was: " << seqLength << "\n"; _updateAvgLearnedSeqLength(seqLength); // Backtrack to an earlier starting point, if we find one UInt backsteps = 0; - if (! _resetCalled ) { + if (!_resetCalled) { TIMER(learnBacktrackTimer.start()); backsteps = learnBacktrack(); TIMER(learnBacktrackTimer.stop()); @@ -987,10 +938,10 @@ void Cells4::updateLearningState(const std::vector & activeColumns, // Start over in the current time step if reset was called, or we couldn't // backtrack - if (_resetCalled || backsteps==0) { + if (_resetCalled || backsteps == 0) { _learnActiveStateT.resetAll(); - for (auto & activeColumn : activeColumns) { - UInt cell0 = activeColumn*_nCellsPerCol; + for (auto &activeColumn : activeColumns) { + UInt cell0 = activeColumn * _nCellsPerCol; _learnActiveStateT.set(cell0); } @@ -1004,12 +955,10 @@ void Cells4::updateLearningState(const std::vector & activeColumns, // Clear out any old segment updates from prior sequences _segmentUpdates.clear(); - } // Done computing active state - // ========================================================================= // Phase 2 - Compute new predicted state. When computing predictions for // phase 2, we predict at most one cell per column (the one with the best @@ -1018,24 +967,23 @@ void Cells4::updateLearningState(const std::vector & activeColumns, learnPhase2(false); } - //-------------------------------------------------------------------------------- /** * Update the inference state. Called from compute() on every iteration */ -void Cells4::updateInferenceState(const std::vector & activeColumns) -{ +void Cells4::updateInferenceState(const std::vector &activeColumns) { //--------------------------------------------------------------------------- // Copy over inference related states to t-1 // We need to do a copy here in case the buffers are numpy allocated // A possible optimization here is to do a swap if Cells4 owns its memory. 
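The swap optimization mentioned in that comment, next to the copy the code actually performs, looks roughly like this in isolation (ConfidenceHistory and its members are simplified stand-ins; the swap is only safe when the object owns both allocations):

#include <utility>
#include <vector>

struct ConfidenceHistory {
  std::vector<float> confT, confT1; // current and previous time step
  bool ownsMemory = true;

  // Rotate "t" into "t-1". The copy is always safe; the swap avoids it but is
  // only valid when nothing external (e.g. a numpy-allocated buffer) aliases
  // these vectors, since confT is left holding stale data to be recomputed.
  void rotate() {
    if (ownsMemory)
      std::swap(confT1, confT);
    else
      confT1 = confT; // element-wise copy, as the memcpy in the patch does
  }
};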
_infActiveStateT1 = _infActiveStateT; _infPredictedStateT1 = _infPredictedStateT; - memcpy(_cellConfidenceT1, _cellConfidenceT, _nCells * sizeof(_cellConfidenceT[0])); + memcpy(_cellConfidenceT1, _cellConfidenceT, + _nCells * sizeof(_cellConfidenceT[0])); // Copy over previous column confidences - memcpy(_colConfidenceT1, _colConfidenceT, _nColumns * sizeof(_colConfidenceT[0])); - + memcpy(_colConfidenceT1, _colConfidenceT, + _nColumns * sizeof(_colConfidenceT[0])); //--------------------------------------------------------------------------- // Update our inference input history @@ -1084,9 +1032,8 @@ void Cells4::updateInferenceState(const std::vector & activeColumns) * Update the inference active state from the last set of predictions * and the current bottom-up. */ -bool Cells4::inferPhase1(const std::vector & activeColumns, - bool useStartCells) -{ +bool Cells4::inferPhase1(const std::vector &activeColumns, + bool useStartCells) { TIMER(infPhase1Timer.start()); //--------------------------------------------------------------------------- // Initialize current active state to 0 to start @@ -1098,9 +1045,8 @@ bool Cells4::inferPhase1(const std::vector & activeColumns, // If we are following a reset, activate only the start cell in each // column that has bottom-up UInt numPredictedColumns = 0; - if (useStartCells) - { - for (auto & activeColumn : activeColumns) { + if (useStartCells) { + for (auto &activeColumn : activeColumns) { UInt cellIdx = activeColumn * _nCellsPerCol; _infActiveStateT.set(cellIdx); } @@ -1108,14 +1054,12 @@ bool Cells4::inferPhase1(const std::vector & activeColumns, } // else, for each column turn on any predicted cells. If there are none, then // turn on all cells (burst the column) - else - { - for (auto & activeColumn : activeColumns) { + else { + for (auto &activeColumn : activeColumns) { UInt cellIdx = activeColumn * _nCellsPerCol; UInt numPredictingCells = 0; - for (UInt ci = cellIdx; ci < cellIdx + _nCellsPerCol; ci++) - { + for (UInt ci = cellIdx; ci < cellIdx + _nCellsPerCol; ci++) { if (_infPredictedStateT1.isSet(ci)) { numPredictingCells++; _infActiveStateT.set(ci); @@ -1124,12 +1068,10 @@ bool Cells4::inferPhase1(const std::vector & activeColumns, if (numPredictingCells > 0) { numPredictedColumns += 1; - } - else { - //std::cout << "inferPhase1 bursting col=" << activeColumns[i] << "\n"; - for (UInt ci = cellIdx; ci < cellIdx + _nCellsPerCol; ci++) - { - _infActiveStateT.set(ci); // whole column bursts + } else { + // std::cout << "inferPhase1 bursting col=" << activeColumns[i] << "\n"; + for (UInt ci = cellIdx; ci < cellIdx + _nCellsPerCol; ci++) { + _infActiveStateT.set(ci); // whole column bursts } } } @@ -1137,7 +1079,8 @@ bool Cells4::inferPhase1(const std::vector & activeColumns, TIMER(infPhase1Timer.stop()); // Did we predict this input well enough? - return (useStartCells || (numPredictedColumns >= 0.50 * activeColumns.size()) ); + return (useStartCells || + (numPredictedColumns >= 0.50 * activeColumns.size())); } //------------------------------------------------------------------------------ @@ -1146,8 +1089,7 @@ bool Cells4::inferPhase1(const std::vector & activeColumns, * then checks to insure that the predicted state is not over-saturated, * i.e. look too close like a burst. 
*/ -bool Cells4::inferPhase2() -{ +bool Cells4::inferPhase2() { // Compute number of active synapses per segment based on forward propagation TIMER(forwardInfPropTimer.start()); computeForwardPropagation(_infActiveStateT); @@ -1178,10 +1120,10 @@ bool Cells4::inferPhase2() // Run sanity check to ensure forward prop matches activity // calcuations (turned on in some tests) if (_checkSynapseConsistency) { - const Segment& seg = _cells[cellIdx][j]; - UInt numActiveSyns = seg.computeActivity( - _infActiveStateT, _permConnected, false); - NTA_CHECK( numActiveSyns == _inferActivity.get(cellIdx, j) ); + const Segment &seg = _cells[cellIdx][j]; + UInt numActiveSyns = + seg.computeActivity(_infActiveStateT, _permConnected, false); + NTA_CHECK(numActiveSyns == _inferActivity.get(cellIdx, j)); } // See if segment has a min number of active synapses @@ -1189,7 +1131,8 @@ bool Cells4::inferPhase2() // Incorporate the confidence into the owner cell and column // Use segment::getLastPosDutyCycle() here - Real dc = _cells[cellIdx][j].dutyCycle(_nLrnIterations, false, false); + Real dc = + _cells[cellIdx][j].dutyCycle(_nLrnIterations, false, false); _cellConfidenceT[cellIdx] += dc; _colConfidenceT[c] += dc; @@ -1213,30 +1156,31 @@ bool Cells4::inferPhase2() //--------------------------------------------------------------------------- // Normalize column confidences if (sumColConfidence > 0) { - for (UInt c = 0; c < _nColumns; c++) _colConfidenceT[c] /= sumColConfidence; - for (UInt i = 0; i < _nCells; i++) _cellConfidenceT[i] /= sumColConfidence; + for (UInt c = 0; c < _nColumns; c++) + _colConfidenceT[c] /= sumColConfidence; + for (UInt i = 0; i < _nCells; i++) + _cellConfidenceT[i] /= sumColConfidence; } - // Turn off timer before we return TIMER(infPhase2Timer.stop()); //--------------------------------------------------------------------------- // Are we predicting the required minimum number of columns? - return (numPredictedCols >= (0.5*_avgInputDensity)); + return (numPredictedCols >= (0.5 * _avgInputDensity)); } - //------------------------------------------------------------------------------ /** * Main compute routine, called for both learning and inference. 
*/ -void Cells4::compute(Real* input, Real* output, bool doInference, bool doLearning) -{ +void Cells4::compute(Real *input, Real *output, bool doInference, + bool doLearning) { TIMER(computeTimer.start()); NTA_CHECK(doInference || doLearning); - if (doLearning) _nLrnIterations++; + if (doLearning) + _nLrnIterations++; ++_nIterations; #ifdef CELLS4_TIMING @@ -1247,15 +1191,17 @@ void Cells4::compute(Real* input, Real* output, bool doInference, bool doLearnin } if (_verbosity >= 3) { - std::cout << "\n==== CPP Iteration: " << _nIterations << " =====" << std::endl; + std::cout << "\n==== CPP Iteration: " << _nIterations + << " =====" << std::endl; } #endif // Create array of active bottom up column indices for later use static std::vector activeColumns; - activeColumns.clear(); // purge residual data + activeColumns.clear(); // purge residual data for (UInt i = 0; i != _nColumns; ++i) { - if (input[i]) activeColumns.push_back(i); + if (input[i]) + activeColumns.push_back(i); } // Print active columns @@ -1265,12 +1211,11 @@ void Cells4::compute(Real* input, Real* output, bool doInference, bool doLearnin std::cout << "\n"; } - //--------------------------------------------------------------------------- // Update segment duty cycles if we are crossing a "tier" if (doLearning && Segment::atDutyCycleTier(_nLrnIterations)) { - for (auto & cell : _cells) { + for (auto &cell : _cells) { cell.updateDutyCycle(_nLrnIterations); } } @@ -1278,11 +1223,12 @@ void Cells4::compute(Real* input, Real* output, bool doInference, bool doLearnin //--------------------------------------------------------------------------- // Update average input density if (_avgInputDensity == 0.0) { - _avgInputDensity = (Real) activeColumns.size(); + _avgInputDensity = (Real)activeColumns.size(); } else { // TODO remove magic constants. should this be swarmed-over? 
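The running average updated just below is a plain exponential moving average seeded with the first raw count; stated on its own (updateAvgInputDensity is an illustrative helper, with decay = 0.99 taken from the COOL_DOWN constant):

#include <cstddef>

float updateAvgInputDensity(float avg, std::size_t numActiveColumns,
                            float decay = 0.99f) {
  if (avg == 0.0f) // first input: seed with the raw active-column count
    return static_cast<float>(numActiveColumns);
  return decay * avg + (1.0f - decay) * static_cast<float>(numActiveColumns);
}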
const auto COOL_DOWN = (Real)0.99; - _avgInputDensity = COOL_DOWN*_avgInputDensity + (1-COOL_DOWN)*(Real)activeColumns.size(); + _avgInputDensity = COOL_DOWN * _avgInputDensity + + (1 - COOL_DOWN) * (Real)activeColumns.size(); } //--------------------------------------------------------------------------- @@ -1328,22 +1274,30 @@ void Cells4::compute(Real* input, Real* output, bool doInference, bool doLearnin memset(output, 0, _nCells * sizeof(output[0])); // most output is zero #if SOME_STATES_NOT_INDEXED #if defined(NTA_ARCH_32) - const UInt multipleOf4 = 4 * (_nCells/4); + const UInt multipleOf4 = 4 * (_nCells / 4); UInt i; for (i = 0; i < multipleOf4; i += 4) { - UInt32 fourStates = * (UInt32 *)(_infPredictedStateT.arrayPtr() + i); + UInt32 fourStates = *(UInt32 *)(_infPredictedStateT.arrayPtr() + i); if (fourStates != 0) { - if ((fourStates & 0x000000ff) != 0) output[i + 0] = 1.0; - if ((fourStates & 0x0000ff00) != 0) output[i + 1] = 1.0; - if ((fourStates & 0x00ff0000) != 0) output[i + 2] = 1.0; - if ((fourStates & 0xff000000) != 0) output[i + 3] = 1.0; - } - fourStates = * (UInt32 *)(_infActiveStateT.arrayPtr() + i); + if ((fourStates & 0x000000ff) != 0) + output[i + 0] = 1.0; + if ((fourStates & 0x0000ff00) != 0) + output[i + 1] = 1.0; + if ((fourStates & 0x00ff0000) != 0) + output[i + 2] = 1.0; + if ((fourStates & 0xff000000) != 0) + output[i + 3] = 1.0; + } + fourStates = *(UInt32 *)(_infActiveStateT.arrayPtr() + i); if (fourStates != 0) { - if ((fourStates & 0x000000ff) != 0) output[i + 0] = 1.0; - if ((fourStates & 0x0000ff00) != 0) output[i + 1] = 1.0; - if ((fourStates & 0x00ff0000) != 0) output[i + 2] = 1.0; - if ((fourStates & 0xff000000) != 0) output[i + 3] = 1.0; + if ((fourStates & 0x000000ff) != 0) + output[i + 0] = 1.0; + if ((fourStates & 0x0000ff00) != 0) + output[i + 1] = 1.0; + if ((fourStates & 0x00ff0000) != 0) + output[i + 2] = 1.0; + if ((fourStates & 0xff000000) != 0) + output[i + 3] = 1.0; } } @@ -1351,36 +1305,51 @@ void Cells4::compute(Real* input, Real* output, bool doInference, bool doLearnin for (i = multipleOf4; i < _nCells; i++) { if (_infPredictedStateT.isSet(i)) { output[i] = 1.0; - } - else if (_infActiveStateT.isSet(i)) { + } else if (_infActiveStateT.isSet(i)) { output[i] = 1.0; } } #else - const UInt multipleOf8 = 8 * (_nCells/8); + const UInt multipleOf8 = 8 * (_nCells / 8); UInt i; for (i = 0; i < multipleOf8; i += 8) { - UInt64 eightStates = * (UInt64 *)(_infPredictedStateT.arrayPtr() + i); + UInt64 eightStates = *(UInt64 *)(_infPredictedStateT.arrayPtr() + i); if (eightStates != 0) { - if ((eightStates & 0x00000000000000ff) != 0) output[i + 0] = 1.0; - if ((eightStates & 0x000000000000ff00) != 0) output[i + 1] = 1.0; - if ((eightStates & 0x0000000000ff0000) != 0) output[i + 2] = 1.0; - if ((eightStates & 0x00000000ff000000) != 0) output[i + 3] = 1.0; - if ((eightStates & 0x000000ff00000000) != 0) output[i + 4] = 1.0; - if ((eightStates & 0x0000ff0000000000) != 0) output[i + 5] = 1.0; - if ((eightStates & 0x00ff000000000000) != 0) output[i + 6] = 1.0; - if ((eightStates & 0xff00000000000000) != 0) output[i + 7] = 1.0; - } - eightStates = * (UInt64 *)(_infActiveStateT.arrayPtr() + i); + if ((eightStates & 0x00000000000000ff) != 0) + output[i + 0] = 1.0; + if ((eightStates & 0x000000000000ff00) != 0) + output[i + 1] = 1.0; + if ((eightStates & 0x0000000000ff0000) != 0) + output[i + 2] = 1.0; + if ((eightStates & 0x00000000ff000000) != 0) + output[i + 3] = 1.0; + if ((eightStates & 0x000000ff00000000) != 0) + output[i + 4] = 1.0; + if 
((eightStates & 0x0000ff0000000000) != 0) + output[i + 5] = 1.0; + if ((eightStates & 0x00ff000000000000) != 0) + output[i + 6] = 1.0; + if ((eightStates & 0xff00000000000000) != 0) + output[i + 7] = 1.0; + } + eightStates = *(UInt64 *)(_infActiveStateT.arrayPtr() + i); if (eightStates != 0) { - if ((eightStates & 0x00000000000000ff) != 0) output[i + 0] = 1.0; - if ((eightStates & 0x000000000000ff00) != 0) output[i + 1] = 1.0; - if ((eightStates & 0x0000000000ff0000) != 0) output[i + 2] = 1.0; - if ((eightStates & 0x00000000ff000000) != 0) output[i + 3] = 1.0; - if ((eightStates & 0x000000ff00000000) != 0) output[i + 4] = 1.0; - if ((eightStates & 0x0000ff0000000000) != 0) output[i + 5] = 1.0; - if ((eightStates & 0x00ff000000000000) != 0) output[i + 6] = 1.0; - if ((eightStates & 0xff00000000000000) != 0) output[i + 7] = 1.0; + if ((eightStates & 0x00000000000000ff) != 0) + output[i + 0] = 1.0; + if ((eightStates & 0x000000000000ff00) != 0) + output[i + 1] = 1.0; + if ((eightStates & 0x0000000000ff0000) != 0) + output[i + 2] = 1.0; + if ((eightStates & 0x00000000ff000000) != 0) + output[i + 3] = 1.0; + if ((eightStates & 0x000000ff00000000) != 0) + output[i + 4] = 1.0; + if ((eightStates & 0x0000ff0000000000) != 0) + output[i + 5] = 1.0; + if ((eightStates & 0x00ff000000000000) != 0) + output[i + 6] = 1.0; + if ((eightStates & 0xff00000000000000) != 0) + output[i + 7] = 1.0; } } @@ -1388,8 +1357,7 @@ void Cells4::compute(Real* input, Real* output, bool doInference, bool doLearnin for (i = multipleOf8; i < _nCells; i++) { if (_infPredictedStateT.isSet(i)) { output[i] = 1.0; - } - else if (_infActiveStateT.isSet(i)) { + } else if (_infActiveStateT.isSet(i)) { output[i] = 1.0; } } @@ -1403,10 +1371,9 @@ void Cells4::compute(Real* input, Real* output, bool doInference, bool doLearnin cellsOn = _infActiveStateT.cellsOn(); for (iterOn = cellsOn.begin(); iterOn != cellsOn.end(); ++iterOn) output[*iterOn] = 1.0; -#endif // SOME_STATES_NOT_INDEXED +#endif // SOME_STATES_NOT_INDEXED - if (_checkSynapseConsistency) - { + if (_checkSynapseConsistency) { NTA_CHECK(invariants(true)); } TIMER(computeTimer.stop()); @@ -1416,20 +1383,18 @@ void Cells4::compute(Real* input, Real* output, bool doInference, bool doLearnin /** * Update our moving average of learned sequence length. */ -void Cells4::_updateAvgLearnedSeqLength(UInt prevSeqLength) -{ +void Cells4::_updateAvgLearnedSeqLength(UInt prevSeqLength) { Real alpha = 0.1; - if (_nLrnIterations < 100) alpha = 0.5; + if (_nLrnIterations < 100) + alpha = 0.5; if (_verbosity >= 5) { - std::cout << "_updateAvgLearnedSeqLength before = " - << _avgLearnedSeqLength << " prevSeqLength = " - << prevSeqLength << "\n"; + std::cout << "_updateAvgLearnedSeqLength before = " << _avgLearnedSeqLength + << " prevSeqLength = " << prevSeqLength << "\n"; } - _avgLearnedSeqLength = (1.0 - alpha)*_avgLearnedSeqLength + - alpha * (Real) prevSeqLength; + _avgLearnedSeqLength = + (1.0 - alpha) * _avgLearnedSeqLength + alpha * (Real)prevSeqLength; if (_verbosity >= 5) { - std::cout << " after = " - << _avgLearnedSeqLength << "\n"; + std::cout << " after = " << _avgLearnedSeqLength << "\n"; } } @@ -1437,14 +1402,13 @@ void Cells4::_updateAvgLearnedSeqLength(UInt prevSeqLength) /** * Go through the list of accumulated segment updates and process them. 
*/ -void Cells4::processSegmentUpdates(Real* input, const CState& predictedState) -{ +void Cells4::processSegmentUpdates(Real *input, const CState &predictedState) { static std::vector delUpdates; - delUpdates.clear(); // purge residual data + delUpdates.clear(); // purge residual data for (UInt i = 0; i != _segmentUpdates.size(); ++i) { - const SegmentUpdate& update = _segmentUpdates[i]; + const SegmentUpdate &update = _segmentUpdates[i]; if (_verbosity >= 4) { std::cout << "\n_nLrnIterations: " << _nLrnIterations @@ -1456,27 +1420,30 @@ void Cells4::processSegmentUpdates(Real* input, const CState& predictedState) // Decide whether to apply the update now. If update has expired, then // mark this update for deletion if (_nLrnIterations - update.timeStamp() > _segUpdateValidDuration) { - if (_verbosity >= 4) std::cout << " Expired, deleting now.\n"; + if (_verbosity >= 4) + std::cout << " Expired, deleting now.\n"; delUpdates.push_back(i); } // Update has not expired else { UInt cellIdx = update.cellIdx(); - UInt colIdx = (UInt) (cellIdx / _nCellsPerCol); + UInt colIdx = (UInt)(cellIdx / _nCellsPerCol); // If we received bottom up input, then adapt this segment and schedule // update for removal - if ( input[colIdx] == 1 ) { + if (input[colIdx] == 1) { - if (_verbosity >= 4) std::cout << " Applying update now.\n"; + if (_verbosity >= 4) + std::cout << " Applying update now.\n"; adaptSegment(update); delUpdates.push_back(i); } else { // We didn't receive bottom up input. If we are not (pooling and still // predicting) then delete this update - if ( ! (_doPooling && predictedState.isSet(cellIdx)) ) { - if (_verbosity >= 4) std::cout << " Deleting update now.\n"; + if (!(_doPooling && predictedState.isSet(cellIdx))) { + if (_verbosity >= 4) + std::cout << " Deleting update now.\n"; delUpdates.push_back(i); } } @@ -1486,7 +1453,6 @@ void Cells4::processSegmentUpdates(Real* input, const CState& predictedState) } // Loop over updates remove_at(delUpdates, _segmentUpdates); - } //---------------------------------------------------------------------- @@ -1494,27 +1460,25 @@ void Cells4::processSegmentUpdates(Real* input, const CState& predictedState) * Removes any updates that would be applied to the given col, * cellIdx, segIdx. */ -void Cells4::cleanUpdatesList(UInt cellIdx, UInt segIdx) -{ +void Cells4::cleanUpdatesList(UInt cellIdx, UInt segIdx) { static std::vector delUpdates; - delUpdates.clear(); // purge residual data + delUpdates.clear(); // purge residual data for (UInt i = 0; i != _segmentUpdates.size(); ++i) { // Get the cell and column associated with this update - const SegmentUpdate& update = _segmentUpdates[i]; + const SegmentUpdate &update = _segmentUpdates[i]; if (_verbosity >= 4) { std::cout << "\nIn cleanUpdatesList. _nLrnIterations: " << _nLrnIterations - << " checking segment: "; + << " checking segment: "; update.print(std::cout, true, _nCellsPerCol); std::cout << std::endl; } // Decide whether to remove update. Note: we can't remove update from // vector while we are iterating over it. - if ( (update.cellIdx() == cellIdx) && (segIdx == update.segIdx()) ) - { + if ((update.cellIdx() == cellIdx) && (segIdx == update.segIdx())) { if (_verbosity >= 4) { std::cout << " Removing it\n"; } @@ -1533,20 +1497,19 @@ void Cells4::cleanUpdatesList(UInt cellIdx, UInt segIdx) * Apply age-based global decay logic and remove segments/synapses * as appropriate. 
*/ -void Cells4::applyGlobalDecay() -{ +void Cells4::applyGlobalDecay() { UInt nSegmentsDecayed = 0, nSynapsesRemoved = 0; - if (_globalDecay != 0 && (_maxAge>0) && (_nLrnIterations % _maxAge == 0) ) { + if (_globalDecay != 0 && (_maxAge > 0) && (_nLrnIterations % _maxAge == 0)) { for (UInt cellIdx = 0; cellIdx != _nCells; ++cellIdx) { for (UInt segIdx = 0; segIdx != _cells[cellIdx].size(); ++segIdx) { - Segment& seg = segment(cellIdx, segIdx); + Segment &seg = segment(cellIdx, segIdx); UInt age = _nLrnIterations - seg._lastActiveIteration; - if ( age > _maxAge ) { + if (age > _maxAge) { static std::vector removedSynapses; - removedSynapses.clear(); // purge residual data + removedSynapses.clear(); // purge residual data nSegmentsDecayed++; seg.decaySynapses2(_globalDecay, removedSynapses, _permConnected); @@ -1558,21 +1521,20 @@ void Cells4::applyGlobalDecay() if (seg.empty()) { _cells[cellIdx].releaseSegment(segIdx); } - } } } if (_verbosity >= 3) { std::cout << "CPP Global decay decremented " << nSegmentsDecayed - << " segments and removed " - << nSynapsesRemoved << " synapses\n"; - std::cout << "_nLrnIterations = " << _nLrnIterations << ", _maxAge = " - << _maxAge << ", globalDecay = " << _globalDecay << "\n"; + << " segments and removed " << nSynapsesRemoved + << " synapses\n"; + std::cout << "_nLrnIterations = " << _nLrnIterations + << ", _maxAge = " << _maxAge + << ", globalDecay = " << _globalDecay << "\n"; } } // (_globalDecay) } - //-------------------------------------------------------------------------------- /** * Helper function for Cells4::adaptSegment. Generates lists of synapses to @@ -1583,13 +1545,11 @@ void Cells4::applyGlobalDecay() * */ void Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment( - const Segment& segment, - std::set& synapsesSet, - std::vector& inactiveSrcCellIdxs, - std::vector& inactiveSynapseIdxs, - std::vector& activeSrcCellIdxs, - std::vector& activeSynapseIdxs) -{ + const Segment &segment, std::set &synapsesSet, + std::vector &inactiveSrcCellIdxs, + std::vector &inactiveSynapseIdxs, + std::vector &activeSrcCellIdxs, + std::vector &activeSynapseIdxs) { // Purge residual data inactiveSrcCellIdxs.clear(); inactiveSynapseIdxs.clear(); @@ -1601,8 +1561,7 @@ void Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment( if (not_in(srcCellIdx, synapsesSet)) { inactiveSrcCellIdxs.push_back(srcCellIdx); inactiveSynapseIdxs.push_back(i); - } - else { + } else { activeSrcCellIdxs.push_back(srcCellIdx); synapsesSet.erase(srcCellIdx); // those that remain will be created activeSynapseIdxs.push_back(i); @@ -1610,7 +1569,6 @@ void Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment( } } - //-------------------------------------------------------------------------------- /** * Applies segment update information to a segment in a cell. @@ -1620,8 +1578,7 @@ void Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment( * TODO: This whole method is really ugly and could do with a serious cleanup. * */ -void Cells4::adaptSegment(const SegmentUpdate& update) -{ +void Cells4::adaptSegment(const SegmentUpdate &update) { TIMER(adaptSegmentTimer.start()); { // consistency checks: @@ -1632,7 +1589,7 @@ void Cells4::adaptSegment(const SegmentUpdate& update) UInt cellIdx = update.cellIdx(); UInt segIdx = update.segIdx(); - if (! update.isNewSegment()) { // modify an existing segment + if (!update.isNewSegment()) { // modify an existing segment // Sometimes you can have a pending update after a segment has already // been released. 
It's cheaper to deal with it here rather than do @@ -1642,13 +1599,13 @@ void Cells4::adaptSegment(const SegmentUpdate& update) return; } - Segment& segment = _cells[cellIdx][segIdx]; + Segment &segment = _cells[cellIdx][segIdx]; if (_verbosity >= 4) { - UInt col = (UInt) (cellIdx / _nCellsPerCol); - UInt cell = cellIdx - col*_nCellsPerCol; - std::cout << "Reinforcing segment " << segIdx << " for cell[" - << col<< "," << cell << "]\n before: "; + UInt col = (UInt)(cellIdx / _nCellsPerCol); + UInt cell = cellIdx - col * _nCellsPerCol; + std::cout << "Reinforcing segment " << segIdx << " for cell[" << col + << "," << cell << "]\n before: "; segment.print(std::cout, _nCellsPerCol); std::cout << std::endl; } @@ -1675,7 +1632,8 @@ void Cells4::adaptSegment(const SegmentUpdate& update) // NOTE: the following variables were declared static as a performance // optimization in the legacy code in order to reduce memory allocations. // The side effect is that this code is not thread-safe. At the time, - // processes were the mechanism for parllelizing execution of the algorithms. + // processes were the mechanism for parllelizing execution of the + // algorithms. // Tracks source cell indexes corresponding to synapses in // the given segment that have been removed during execution of this method @@ -1687,20 +1645,19 @@ void Cells4::adaptSegment(const SegmentUpdate& update) // Indexes of synapses within the current segment corresponding to synapses // that are inactive/active in ascending order; these variables correlate // with synToDec and synToInc. - static std::vector inactiveSegmentIndices, - activeSegmentIndices; + static std::vector inactiveSegmentIndices, activeSegmentIndices; // Purge residual data from static variable; the others will be purged by // _generateListsOfSynapsesToAdjustForAdaptSegment removed.clear(); _generateListsOfSynapsesToAdjustForAdaptSegment( - segment, synapsesSet, - synToDec, inactiveSegmentIndices, - synToInc, activeSegmentIndices); + segment, synapsesSet, synToDec, inactiveSegmentIndices, synToInc, + activeSegmentIndices); // Decrement permanences of inactive synapses - segment.updateSynapses(synToDec, - _permDec, _permMax, _permConnected, removed); + segment.updateSynapses(synToDec, -_permDec, _permMax, _permConnected, + removed); // If any synapses were removed as the result of permanence decrements, // regenerate affected parameters @@ -1709,28 +1666,27 @@ void Cells4::adaptSegment(const SegmentUpdate& update) synapsesSet.insert(update.begin(), update.end()); _generateListsOfSynapsesToAdjustForAdaptSegment( - segment, synapsesSet, - synToDec, inactiveSegmentIndices, - synToInc, activeSegmentIndices); + segment, synapsesSet, synToDec, inactiveSegmentIndices, synToInc, + activeSegmentIndices); } // Increment permanences of active synapses const UInt numRemovedBeforePermInc = removed.size(); - segment.updateSynapses(synToInc, _permInc, _permMax, _permConnected, removed); + segment.updateSynapses(synToInc, _permInc, _permMax, _permConnected, + removed); // Incrementing of permanences shouldn't remove synapses NTA_CHECK(removed.size() == numRemovedBeforePermInc); // If we have fixed resources, get rid of some old synapses, if necessary - if ( (_maxSynapsesPerSegment > 0) && - (synapsesSet.size() + segment.size() > (UInt) _maxSynapsesPerSegment) ) { + if ((_maxSynapsesPerSegment > 0) && + (synapsesSet.size() + segment.size() > (UInt)_maxSynapsesPerSegment)) { // TODO: What's preventing numToFree from exceeding segment.size()? If you // know, add a comment explaining it. 
If it exceeds, it will cause memory // corruption in Segment::freeNSynapses. - UInt numToFree = synapsesSet.size() + segment.size() - _maxSynapsesPerSegment; - segment.freeNSynapses(numToFree, - synToDec, inactiveSegmentIndices, - synToInc, activeSegmentIndices, - removed, _verbosity, + UInt numToFree = + synapsesSet.size() + segment.size() - _maxSynapsesPerSegment; + segment.freeNSynapses(numToFree, synToDec, inactiveSegmentIndices, + synToInc, activeSegmentIndices, removed, _verbosity, _nCellsPerCol, _permMax); } @@ -1763,10 +1719,9 @@ void Cells4::adaptSegment(const SegmentUpdate& update) for (UInt i = 0; i != update.size(); ++i) { synapses.push_back(InSynapse(update[i], _permInitial)); } - UInt segIdx = - _cells[cellIdx].getFreeSegment(synapses, _initSegFreq, - update.isSequenceSegment(), _permConnected, - _nLrnIterations); + UInt segIdx = _cells[cellIdx].getFreeSegment( + synapses, _initSegFreq, update.isSequenceSegment(), _permConnected, + _nLrnIterations); // Initialize the new segment's last active iteration and frequency related // counts @@ -1782,9 +1737,7 @@ void Cells4::adaptSegment(const SegmentUpdate& update) std::cout << std::endl; } - addOutSynapses(cellIdx, segIdx, update.begin(), update.end()); - } if (_checkSynapseConsistency) { @@ -1794,15 +1747,12 @@ void Cells4::adaptSegment(const SegmentUpdate& update) TIMER(adaptSegmentTimer.stop()); } - - // Rebalance segment lists for each cell -void Cells4::_rebalance() -{ +void Cells4::_rebalance() { std::cout << "Rebalancing\n"; _nIterationsSinceRebalance = _nLrnIterations; - for (auto & cell : _cells) { + for (auto &cell : _cells) { if (!cell.empty()) { cell.rebalanceSegments(); } @@ -1810,7 +1760,6 @@ void Cells4::_rebalance() // After rebalancing we need to redo the OutSynapses rebuildOutSynapses(); - } //-------------------------------------------------------------------------------- @@ -1818,18 +1767,17 @@ void Cells4::_rebalance() * Removes any old segment that has not been touched for maxAge iterations and * where the number of connected synapses is less than activation threshold. */ -void Cells4::trimOldSegments(UInt maxAge) -{ +void Cells4::trimOldSegments(UInt maxAge) { UInt nSegsRemoved = 0; for (UInt cellIdx = 0; cellIdx != _nCells; ++cellIdx) { for (UInt segIdx = 0; segIdx != _cells[cellIdx].size(); ++segIdx) { - Segment& seg = segment(cellIdx, segIdx); + Segment &seg = segment(cellIdx, segIdx); UInt age = _nLrnIterations - seg._lastActiveIteration; - if ( (age > maxAge) && (seg.nConnected() < _activationThreshold) ) { + if ((age > maxAge) && (seg.nConnected() < _activationThreshold)) { static std::vector removedSynapses; - removedSynapses.clear(); // purge residual data + removedSynapses.clear(); // purge residual data for (UInt i = 0; i != seg.size(); ++i) removedSynapses.push_back(seg[i].srcCellIdx()); @@ -1837,20 +1785,17 @@ void Cells4::trimOldSegments(UInt maxAge) eraseOutSynapses(cellIdx, segIdx, removedSynapses); _cells[cellIdx].releaseSegment(segIdx); nSegsRemoved++; - } } } std::cout << "In trimOldSegments. Removed " << nSegsRemoved << " segments\n"; NTA_CHECK(invariants()); - } // Clear out and rebuild the entire _outSynapses data structure // This is useful if segments have changed. -void Cells4::rebuildOutSynapses() -{ +void Cells4::rebuildOutSynapses() { // TODO: Is this logic sufficient? 
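The rebuild performed by the loops just below amounts to regenerating a reverse index: for every (cell, segment, synapse) triple, record under the source cell which (destination cell, segment) pair it feeds. A standalone sketch with simplified container types (the real code uses Segment and OutSynapse objects):

#include <cstddef>
#include <utility>
#include <vector>

using SegmentSrcCells = std::vector<std::size_t>;   // source cells of one segment
using CellSegments = std::vector<SegmentSrcCells>;  // segments of one cell
using OutSyn = std::pair<std::size_t, std::size_t>; // (dstCell, dstSegIdx)

std::vector<std::vector<OutSyn>>
rebuildOutSynapses(const std::vector<CellSegments> &cells) {
  std::vector<std::vector<OutSyn>> outSynapses(cells.size());
  for (std::size_t dstCell = 0; dstCell != cells.size(); ++dstCell)
    for (std::size_t segIdx = 0; segIdx != cells[dstCell].size(); ++segIdx)
      for (std::size_t srcCell : cells[dstCell][segIdx])
        outSynapses[srcCell].push_back({dstCell, segIdx});
  return outSynapses;
}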
_outSynapses.resize(_nCells); @@ -1863,7 +1808,7 @@ void Cells4::rebuildOutSynapses() // data structure for (UInt dstCellIdx = 0; dstCellIdx != _nCells; ++dstCellIdx) { for (UInt segIdx = 0; segIdx != _cells[dstCellIdx].size(); ++segIdx) { - const Segment& seg = _cells[dstCellIdx][segIdx]; + const Segment &seg = _cells[dstCellIdx][segIdx]; for (UInt synIdx = 0; synIdx != seg.size(); ++synIdx) { UInt srcCellIdx = seg.getSrcCellIdx(synIdx); OutSynapse newOutSyn(dstCellIdx, segIdx); @@ -1885,7 +1830,8 @@ void Cells4::rebuildOutSynapses() UInt destCol = (UInt) (syn.dstCellIdx() / _nCellsPerCol); UInt destCell = syn.dstCellIdx() - destCol*_nCellsPerCol; - std::cout << "\n [" << syn.dstCellIdx() << " : " << destCol << "," << destCell + std::cout << "\n [" << syn.dstCellIdx() << " : " << destCol << "," + << destCell << "] segment: " << syn.dstSegIdx() << ","; } @@ -1893,15 +1839,12 @@ void Cells4::rebuildOutSynapses() } std::cout << "\n"; */ - } - //-------------------------------------------------------------------------------- /** */ -void Cells4::reset() -{ +void Cells4::reset() { if (_verbosity >= 3) { std::cout << "\n==== RESET =====\n"; } @@ -1926,15 +1869,13 @@ void Cells4::reset() // Clear out input history _prevInfPatterns.clear(); _prevLrnPatterns.clear(); - //if (_nLrnIterations - _nIterationsSinceRebalance > 1000) { + // if (_nLrnIterations - _nIterationsSinceRebalance > 1000) { // //_rebalance(); //} } - //-------------------------------------------------------------------------------- -void Cells4::write(Cells4Proto::Builder& proto) const -{ +void Cells4::write(Cells4Proto::Builder &proto) const { proto.setVersion(version()); proto.setOwnsMemory(_ownsMemory); auto randomProto = proto.initRng(); @@ -1980,40 +1921,29 @@ void Cells4::write(Cells4Proto::Builder& proto) const _learnPredictedStateT1.write(learnPredictedStateT1Proto); auto cellListProto = proto.initCells(_nCells); - for (UInt i = 0; i < _nCells; ++i) - { + for (UInt i = 0; i < _nCells; ++i) { auto cellProto = cellListProto[i]; _cells[i].write(cellProto); } - auto segmentUpdatesListProto = proto.initSegmentUpdates( - _segmentUpdates.size()); - for (UInt i = 0; i < _segmentUpdates.size(); ++i) - { + auto segmentUpdatesListProto = + proto.initSegmentUpdates(_segmentUpdates.size()); + for (UInt i = 0; i < _segmentUpdates.size(); ++i) { auto segmentUpdateProto = segmentUpdatesListProto[i]; _segmentUpdates[i].write(segmentUpdateProto); } } - //-------------------------------------------------------------------------------- -void Cells4::read(Cells4Proto::Reader& proto) -{ +void Cells4::read(Cells4Proto::Reader &proto) { NTA_CHECK(proto.getVersion() == 2); - initialize(proto.getNColumns(), - proto.getNCellsPerCol(), - proto.getActivationThreshold(), - proto.getMinThreshold(), - proto.getNewSynapseCount(), - proto.getSegUpdateValidDuration(), - proto.getPermInitial(), - proto.getPermConnected(), - proto.getPermMax(), - proto.getPermDec(), - proto.getPermInc(), - proto.getGlobalDecay(), - proto.getDoPooling(), + initialize(proto.getNColumns(), proto.getNCellsPerCol(), + proto.getActivationThreshold(), proto.getMinThreshold(), + proto.getNewSynapseCount(), proto.getSegUpdateValidDuration(), + proto.getPermInitial(), proto.getPermConnected(), + proto.getPermMax(), proto.getPermDec(), proto.getPermInc(), + proto.getGlobalDecay(), proto.getDoPooling(), proto.getOwnsMemory()); auto randomProto = proto.getRng(); _rng.read(randomProto); @@ -2049,8 +1979,7 @@ void Cells4::read(Cells4Proto::Reader& proto) auto cellListProto = 
proto.getCells(); _nCells = cellListProto.size(); _cells.resize(_nCells); - for (UInt i = 0; i < cellListProto.size(); ++i) - { + for (UInt i = 0; i < cellListProto.size(); ++i) { auto cellProto = cellListProto[i]; _cells[i].read(cellProto); } @@ -2058,76 +1987,49 @@ void Cells4::read(Cells4Proto::Reader& proto) auto segmentUpdatesListProto = proto.getSegmentUpdates(); _segmentUpdates.clear(); _segmentUpdates.resize(segmentUpdatesListProto.size()); - for (UInt i = 0; i < segmentUpdatesListProto.size(); ++i) - { + for (UInt i = 0; i < segmentUpdatesListProto.size(); ++i) { auto segmentUpdateProto = segmentUpdatesListProto[i]; _segmentUpdates[i].read(segmentUpdateProto); } rebuildOutSynapses(); - if (_checkSynapseConsistency || (_nCells * _maxSegmentsPerCell < 100000)) - { + if (_checkSynapseConsistency || (_nCells * _maxSegmentsPerCell < 100000)) { NTA_CHECK(invariants(true)); } _version = VERSION; } - //-------------------------------------------------------------------------------- -void Cells4::save(std::ostream& outStream) const -{ +void Cells4::save(std::ostream &outStream) const { // Check invariants for smaller networks or if explicitly requested - if (_checkSynapseConsistency || (_nCells * _maxSegmentsPerCell < 100000) ) - { + if (_checkSynapseConsistency || (_nCells * _maxSegmentsPerCell < 100000)) { NTA_CHECK(invariants(true)); } - outStream << version() << " " - << _ownsMemory << " " - << _rng << " " - << _nColumns << " " - << _nCellsPerCol << " " - << _activationThreshold << " " - << _minThreshold << " " - << _newSynapseCount << " " - << _nIterations << " " - << _segUpdateValidDuration << " " - << _initSegFreq << " " - << _permInitial << " " - << _permConnected << " " - << _permMax << " " - << _permDec << " " - << _permInc << " " - << _globalDecay << " " - << _doPooling << " " - << _maxInfBacktrack << " " - << _maxLrnBacktrack << " " - << _pamLength << " " - << _maxAge << " " - << _avgInputDensity << " " - << _pamCounter << " " - << _maxSeqLength << " " - << _avgLearnedSeqLength << " " - << _nLrnIterations << " " - << _maxSegmentsPerCell << " " - << _maxSynapsesPerSegment << " " - << std::endl; + outStream << version() << " " << _ownsMemory << " " << _rng << " " + << _nColumns << " " << _nCellsPerCol << " " << _activationThreshold + << " " << _minThreshold << " " << _newSynapseCount << " " + << _nIterations << " " << _segUpdateValidDuration << " " + << _initSegFreq << " " << _permInitial << " " << _permConnected + << " " << _permMax << " " << _permDec << " " << _permInc << " " + << _globalDecay << " " << _doPooling << " " << _maxInfBacktrack + << " " << _maxLrnBacktrack << " " << _pamLength << " " << _maxAge + << " " << _avgInputDensity << " " << _pamCounter << " " + << _maxSeqLength << " " << _avgLearnedSeqLength << " " + << _nLrnIterations << " " << _maxSegmentsPerCell << " " + << _maxSynapsesPerSegment << " " << std::endl; // Additions in version 1. - outStream << _learnedSeqLength << " " - << _verbosity << " " - << _checkSynapseConsistency << " " - << _resetCalled << std::endl; - outStream << _learnActiveStateT << " " - << _learnActiveStateT1 << " " - << _learnPredictedStateT << " " - << _learnPredictedStateT1 << std::endl; + outStream << _learnedSeqLength << " " << _verbosity << " " + << _checkSynapseConsistency << " " << _resetCalled << std::endl; + outStream << _learnActiveStateT << " " << _learnActiveStateT1 << " " + << _learnPredictedStateT << " " << _learnPredictedStateT1 + << std::endl; // Additions in version 2. 
outStream << _segmentUpdates.size() << " "; - for (auto & elem : _segmentUpdates) - { + for (auto &elem : _segmentUpdates) { elem.save(outStream); } @@ -2144,9 +2046,9 @@ void Cells4::save(std::ostream& outStream) const /** * Save the state to the given file */ -void Cells4::saveToFile(std::string filePath) const -{ - OFStream outStream(filePath.c_str(), std::ios_base::out | std::ios_base::binary); +void Cells4::saveToFile(std::string filePath) const { + OFStream outStream(filePath.c_str(), + std::ios_base::out | std::ios_base::binary); // Request std::ios_base::failure exception upon logical or physical i/o error outStream.exceptions(std::ifstream::failbit | std::ifstream::badbit); @@ -2161,29 +2063,25 @@ void Cells4::saveToFile(std::string filePath) const /** * Load the state from the given file */ -void Cells4::loadFromFile(std::string filePath) -{ - IFStream outStream(filePath.c_str(), std::ios_base::in | std::ios_base::binary); +void Cells4::loadFromFile(std::string filePath) { + IFStream outStream(filePath.c_str(), + std::ios_base::in | std::ios_base::binary); load(outStream); } - - //------------------------------------------------------------------------------ /** * Need to load and re-propagate activities so that we can really persist * at any point, load back and resume inference at exactly the same point. */ -void Cells4::load(std::istream& inStream) -{ +void Cells4::load(std::istream &inStream) { std::string tag = ""; inStream >> tag; // If the checkpoint starts with "cellsV4" then it is the original, // otherwise the version is a UInt. UInt v = 0; std::stringstream ss; - if (tag != "cellsV4") - { + if (tag != "cellsV4") { ss << tag; ss >> v; } @@ -2196,68 +2094,38 @@ void Cells4::load(std::istream& inStream) inStream >> nColumns >> nCellsPerCol; - inStream >> _activationThreshold - >> _minThreshold - >> _newSynapseCount - >> nIterations - >> _segUpdateValidDuration - >> _initSegFreq - >> _permInitial - >> _permConnected - >> _permMax - >> _permDec - >> _permInc - >> _globalDecay - >> _doPooling; + inStream >> _activationThreshold >> _minThreshold >> _newSynapseCount >> + nIterations >> _segUpdateValidDuration >> _initSegFreq >> _permInitial >> + _permConnected >> _permMax >> _permDec >> _permInc >> _globalDecay >> + _doPooling; // TODO: clean up constructor/initialization and _segActivity below - initialize(nColumns, nCellsPerCol, - _activationThreshold, - _minThreshold, - _newSynapseCount, - _segUpdateValidDuration, - _permInitial, - _permConnected, - _permMax, - _permDec, - _permInc, - _globalDecay, - _doPooling, - _ownsMemory); + initialize(nColumns, nCellsPerCol, _activationThreshold, _minThreshold, + _newSynapseCount, _segUpdateValidDuration, _permInitial, + _permConnected, _permMax, _permDec, _permInc, _globalDecay, + _doPooling, _ownsMemory); _nIterations = nIterations; - inStream >> _maxInfBacktrack - >> _maxLrnBacktrack - >> _pamLength - >> _maxAge - >> _avgInputDensity - >> _pamCounter - >> _maxSeqLength - >> _avgLearnedSeqLength - >> _nLrnIterations - >> _maxSegmentsPerCell - >> _maxSynapsesPerSegment; - - if (v >= 1) - { - inStream >> _learnedSeqLength - >> _verbosity - >> _checkSynapseConsistency - >> _resetCalled; + inStream >> _maxInfBacktrack >> _maxLrnBacktrack >> _pamLength >> _maxAge >> + _avgInputDensity >> _pamCounter >> _maxSeqLength >> + _avgLearnedSeqLength >> _nLrnIterations >> _maxSegmentsPerCell >> + _maxSynapsesPerSegment; + + if (v >= 1) { + inStream >> _learnedSeqLength >> _verbosity >> _checkSynapseConsistency >> + _resetCalled; 
_learnActiveStateT.load(inStream); _learnActiveStateT1.load(inStream); _learnPredictedStateT.load(inStream); _learnPredictedStateT1.load(inStream); } - if (v >= 2) - { + if (v >= 2) { UInt n; _segmentUpdates.clear(); inStream >> n; - for (UInt i = 0; i < n; ++i) - { + for (UInt i = 0; i < n; ++i) { _segmentUpdates.push_back(SegmentUpdate()); _segmentUpdates[i].load(inStream); } @@ -2275,8 +2143,7 @@ void Cells4::load(std::istream& inStream) rebuildOutSynapses(); // Check invariants for smaller networks or if explicitly requested - if (_checkSynapseConsistency || (_nCells * _maxSegmentsPerCell < 100000) ) - { + if (_checkSynapseConsistency || (_nCells * _maxSegmentsPerCell < 100000)) { NTA_CHECK(invariants(true)); } @@ -2292,8 +2159,7 @@ void Cells4::load(std::istream& inStream) * each time synapses/segments are created/deleted. This test takes some time * but it's indispensable in development. */ -bool Cells4::invariants(bool verbose) const -{ +bool Cells4::invariants(bool verbose) const { using namespace std; set back_map; @@ -2319,7 +2185,7 @@ bool Cells4::invariants(bool verbose) const // Analyze InSynapses for (UInt j = 0; j != _cells[i].size(); ++j) { - const Segment& seg = _cells[i][j]; + const Segment &seg = _cells[i][j]; for (UInt k = 0; k != seg.size(); ++k) { @@ -2340,7 +2206,7 @@ bool Cells4::invariants(bool verbose) const // Analyze OutSynapses for (UInt j = 0; j != _outSynapses[i].size(); ++j) { - const OutSynapse& syn = _outSynapses[i][j]; + const OutSynapse &syn = _outSynapses[i][j]; stringstream buf; buf << syn.dstCellIdx() << '.' << syn.dstSegIdx() << '.' << i; @@ -2357,54 +2223,48 @@ bool Cells4::invariants(bool verbose) const consistent &= back_map == forward_map; if (!consistent) { - std::cout << "synapses inconsistent forward_map size=" - << forward_map.size() << " back_map size=" - << back_map.size() << std::endl; - //std::cout << "\nBack/forward maps: " + std::cout << "synapses inconsistent forward_map size=" << forward_map.size() + << " back_map size=" << back_map.size() << std::endl; + // std::cout << "\nBack/forward maps: " // << back_map.size() << " " << forward_map.size() << std::endl; - //set::iterator it1 = back_map.begin(); - //set::iterator it2 = forward_map.begin(); - //while (it1 != back_map.end() && it2 != forward_map.end()) + // set::iterator it1 = back_map.begin(); + // set::iterator it2 = forward_map.begin(); + // while (it1 != back_map.end() && it2 != forward_map.end()) // std::cout << *it1++ << " " << *it2++ << std::endl; - //while (it1 != back_map.end()) + // while (it1 != back_map.end()) // std::cout << *it1++ << std::endl; - //while (it2 != forward_map.end()) + // while (it2 != forward_map.end()) // std::cout << *it2++ << std::endl; } else { - //std::cout << "synapses consistent" << std::endl; + // std::cout << "synapses consistent" << std::endl; } return consistent; } - - -void -Cells4::addNewSegment(UInt colIdx, UInt cellIdxInCol, - bool sequenceSegmentFlag, - const std::vector >& extSynapses) -{ +void Cells4::addNewSegment( + UInt colIdx, UInt cellIdxInCol, bool sequenceSegmentFlag, + const std::vector> &extSynapses) { NTA_ASSERT(colIdx < nColumns()); NTA_ASSERT(cellIdxInCol < nCellsPerCol()); UInt cellIdx = colIdx * _nCellsPerCol + cellIdxInCol; static std::vector synapses; - synapses.resize(extSynapses.size()); // how many slots we need + synapses.resize(extSynapses.size()); // how many slots we need for (UInt i = 0; i != extSynapses.size(); ++i) synapses[i] = extSynapses[i].first * _nCellsPerCol + extSynapses[i].second; - SegmentUpdate 
update(cellIdx, (UInt) -1, sequenceSegmentFlag, - _nLrnIterations, synapses); // TODO: Add this for invariants check + SegmentUpdate update(cellIdx, (UInt)-1, sequenceSegmentFlag, _nLrnIterations, + synapses); // TODO: Add this for invariants check _segmentUpdates.push_back(update); } //-------------------------------------------------------------------------------- -void -Cells4::updateSegment(UInt colIdx, UInt cellIdxInCol, UInt segIdx, - const std::vector >& extSynapses) -{ +void Cells4::updateSegment( + UInt colIdx, UInt cellIdxInCol, UInt segIdx, + const std::vector> &extSynapses) { NTA_ASSERT(colIdx < nColumns()); NTA_ASSERT(cellIdxInCol < nCellsPerCol()); @@ -2412,7 +2272,7 @@ Cells4::updateSegment(UInt colIdx, UInt cellIdxInCol, UInt segIdx, bool sequenceSegmentFlag = segment(cellIdx, segIdx).isSequenceSegment(); static std::vector synapses; - synapses.resize(extSynapses.size()); // how many slots we need + synapses.resize(extSynapses.size()); // how many slots we need for (UInt i = 0; i != extSynapses.size(); ++i) synapses[i] = extSynapses[i].first * _nCellsPerCol + extSynapses[i].second; @@ -2426,47 +2286,35 @@ Cells4::updateSegment(UInt colIdx, UInt cellIdxInCol, UInt segIdx, /** * Simple helper function for allocating our numerous state variables */ -template void allocateState(It *&state, const UInt numElmts) -{ - state = new It [numElmts]; +template void allocateState(It *&state, const UInt numElmts) { + state = new It[numElmts]; memset(state, 0, numElmts * sizeof(It)); } -void Cells4::setCellSegmentOrder(bool matchPythonOrder) -{ +void Cells4::setCellSegmentOrder(bool matchPythonOrder) { Cell::setSegmentOrder(matchPythonOrder); } -void -Cells4::initialize(UInt nColumns, - UInt nCellsPerCol, - UInt activationThreshold, - UInt minThreshold, - UInt newSynapseCount, - UInt segUpdateValidDuration, - Real permInitial, - Real permConnected, - Real permMax, - Real permDec, - Real permInc, - Real globalDecay, - bool doPooling, - bool initFromCpp, - bool checkSynapseConsistency) -{ - _nColumns = nColumns; - _nCellsPerCol = nCellsPerCol; - _nCells = nColumns * nCellsPerCol; +void Cells4::initialize(UInt nColumns, UInt nCellsPerCol, + UInt activationThreshold, UInt minThreshold, + UInt newSynapseCount, UInt segUpdateValidDuration, + Real permInitial, Real permConnected, Real permMax, + Real permDec, Real permInc, Real globalDecay, + bool doPooling, bool initFromCpp, + bool checkSynapseConsistency) { + _nColumns = nColumns; + _nCellsPerCol = nCellsPerCol; + _nCells = nColumns * nCellsPerCol; NTA_CHECK(_nCells <= _MAX_CELLS); - _activationThreshold = activationThreshold; - _minThreshold = minThreshold; - _newSynapseCount = newSynapseCount; - _segUpdateValidDuration = segUpdateValidDuration; + _activationThreshold = activationThreshold; + _minThreshold = minThreshold; + _newSynapseCount = newSynapseCount; + _segUpdateValidDuration = segUpdateValidDuration; - _initSegFreq = 0.5; - _permInitial = permInitial; - _permConnected = permConnected; + _initSegFreq = 0.5; + _permInitial = permInitial; + _permConnected = permConnected; _permMax = permMax; _permDec = permDec; _permInc = permInc; @@ -2478,7 +2326,7 @@ Cells4::initialize(UInt nColumns, _nIterations = 0; _nLrnIterations = 0; - _pamCounter = _pamLength+1; + _pamCounter = _pamLength + 1; _maxInfBacktrack = 10; _maxLrnBacktrack = 5; _maxSeqLength = 0; @@ -2504,12 +2352,11 @@ Cells4::initialize(UInt nColumns, _infActiveStateT1.initialize(_nCells); _infPredictedStateT.initialize(_nCells); _infPredictedStateT1.initialize(_nCells); - 
allocateState(_cellConfidenceT, _nCells); - allocateState(_cellConfidenceT1, _nCells); - allocateState(_colConfidenceT, _nColumns); - allocateState(_colConfidenceT1, _nColumns); - } - else { + allocateState(_cellConfidenceT, _nCells); + allocateState(_cellConfidenceT1, _nCells); + allocateState(_colConfidenceT, _nColumns); + allocateState(_colConfidenceT1, _nColumns); + } else { _ownsMemory = false; } @@ -2522,9 +2369,9 @@ Cells4::initialize(UInt nColumns, _infPredictedBackup.initialize(_nCells); _infActiveStateCandidate.initialize(_nCells); _infPredictedStateCandidate.initialize(_nCells); - allocateState(_cellConfidenceCandidate, _nCells); - allocateState(_colConfidenceCandidate, _nColumns); - allocateState(_tmpInputBuffer, _nColumns); + allocateState(_cellConfidenceCandidate, _nCells); + allocateState(_colConfidenceCandidate, _nColumns); + allocateState(_tmpInputBuffer, _nColumns); // Internal timings and states used for optimization _nIterationsSinceRebalance = 0; @@ -2532,67 +2379,55 @@ Cells4::initialize(UInt nColumns, _checkSynapseConsistency = checkSynapseConsistency; if (_checkSynapseConsistency) std::cout << "*** Synapse consistency checking turned on for Cells4 ***\n"; - } -UInt Cells4::nSegments() const -{ +UInt Cells4::nSegments() const { UInt n = 0; for (UInt i = 0; i != _nCells; ++i) n += _cells[i].nSegments(); return n; } -UInt Cells4::__nSegmentsOnCell(UInt cellIdx) const -{ +UInt Cells4::__nSegmentsOnCell(UInt cellIdx) const { NTA_ASSERT(cellIdx < _nCells); return _cells[cellIdx].size(); } -UInt Cells4::nSegmentsOnCell(UInt colIdx, UInt cellIdxInCol) const -{ +UInt Cells4::nSegmentsOnCell(UInt colIdx, UInt cellIdxInCol) const { NTA_ASSERT(colIdx < nColumns()); NTA_ASSERT(cellIdxInCol < nCellsPerCol()); return _cells[colIdx * nCellsPerCol() + cellIdxInCol].nSegments(); } - -UInt Cells4::nSynapses() const -{ +UInt Cells4::nSynapses() const { UInt n = 0; for (UInt i = 0; i != _nCells; ++i) n += _cells[i].nSynapses(); return n; } - -UInt Cells4::nSynapsesInCell(UInt cellIdx) const -{ +UInt Cells4::nSynapsesInCell(UInt cellIdx) const { NTA_ASSERT(cellIdx < nCells()); return _cells[cellIdx].nSynapses(); } -Cell* Cells4::getCell(UInt colIdx, UInt cellIdxInCol) -{ +Cell *Cells4::getCell(UInt colIdx, UInt cellIdxInCol) { NTA_ASSERT(colIdx < nColumns()); NTA_ASSERT(cellIdxInCol < nCellsPerCol()); - return & _cells[colIdx * _nCellsPerCol + cellIdxInCol]; + return &_cells[colIdx * _nCellsPerCol + cellIdxInCol]; } -UInt Cells4::getCellIdx(UInt colIdx, UInt cellIdxInCol) -{ +UInt Cells4::getCellIdx(UInt colIdx, UInt cellIdxInCol) { NTA_ASSERT(colIdx < nColumns()); NTA_ASSERT(cellIdxInCol < nCellsPerCol()); return colIdx * _nCellsPerCol + cellIdxInCol; } -Segment* -Cells4::getSegment(UInt colIdx, UInt cellIdxInCol, UInt segIdx) -{ +Segment *Cells4::getSegment(UInt colIdx, UInt cellIdxInCol, UInt segIdx) { NTA_ASSERT(colIdx < nColumns()); NTA_ASSERT(cellIdxInCol < nCellsPerCol()); @@ -2600,22 +2435,17 @@ Cells4::getSegment(UInt colIdx, UInt cellIdxInCol, UInt segIdx) NTA_ASSERT(segIdx < _cells[cellIdx].size()); - return & segment(cellIdx, segIdx); + return &segment(cellIdx, segIdx); } - -Segment& Cells4::segment(UInt cellIdx, UInt segIdx) -{ +Segment &Cells4::segment(UInt cellIdx, UInt segIdx) { NTA_ASSERT(cellIdx < nCells()); NTA_ASSERT(segIdx < _cells[cellIdx].size()); return _cells[cellIdx][segIdx]; } - -std::vector -Cells4::getNonEmptySegList(UInt colIdx, UInt cellIdxInCol) -{ +std::vector Cells4::getNonEmptySegList(UInt colIdx, UInt cellIdxInCol) { NTA_ASSERT(colIdx < nColumns()); 
NTA_ASSERT(cellIdxInCol < nCellsPerCol()); @@ -2628,83 +2458,82 @@ Cells4::getNonEmptySegList(UInt colIdx, UInt cellIdxInCol) /** * Find weakly activated cell in column. */ -std::pair Cells4::getBestMatchingCellT(UInt colIdx, const CState& state, UInt minThreshold) -{ - { - NTA_ASSERT(colIdx < nColumns()); - } +std::pair Cells4::getBestMatchingCellT(UInt colIdx, + const CState &state, + UInt minThreshold) { + { NTA_ASSERT(colIdx < nColumns()); } - int start = colIdx * _nCellsPerCol, - end = start + _nCellsPerCol; + int start = colIdx * _nCellsPerCol, end = start + _nCellsPerCol; UInt best_cell = UInt(-1); - UInt best_seg = UInt(-1); + UInt best_seg = UInt(-1); UInt best_activity = minThreshold > 0 ? minThreshold - 1 : 0; // For each cell in the column - for (int ii = end - 1; ii >= start; --ii) { // reverse segment order to match Python logic + for (int ii = end - 1; ii >= start; + --ii) { // reverse segment order to match Python logic UInt i = UInt(ii); // Check synapse consistency for each segment if requested if (_checkSynapseConsistency) { for (UInt j = 0; j != _cells[i].size(); ++j) { - NTA_CHECK( segment(i,j).computeActivity(state, _permConnected, false) == _learnActivity.get(i, j) ); + NTA_CHECK(segment(i, j).computeActivity(state, _permConnected, false) == + _learnActivity.get(i, j)); } } - if (_learnActivity.get(i) > best_activity) { // if this cell may have a worthy segment + if (_learnActivity.get(i) > + best_activity) { // if this cell may have a worthy segment for (UInt j = 0; j != _cells[i].size(); ++j) { // Open: Does _cells[i].size() vary? UInt activity = _learnActivity.get(i, j); - if (best_activity < activity) { // if a new maximum - best_activity = activity; // set the new maximum - best_cell = i; // remember the cell - best_seg = j; // remember the segment + if (best_activity < activity) { // if a new maximum + best_activity = activity; // set the new maximum + best_cell = i; // remember the cell + best_seg = j; // remember the segment } if (_verbosity >= 6 && activity >= minThreshold) { std::cout << "getBestMatchingCell, learning on col=" << colIdx << ", segment: "; _cells[i][j].print(std::cout, _nCellsPerCol); std::cout << "\n"; - std::cout << "activity = " << activity << ", maxSegActivity = " - << best_activity << "\n"; + std::cout << "activity = " << activity + << ", maxSegActivity = " << best_activity << "\n"; } } // for each segment in cell } // if this cell may have a segment with a new maximum - } // for each cell in the column + } // for each cell in the column - return std::make_pair(best_cell, best_seg); // could be (-1,-1) + return std::make_pair(best_cell, best_seg); // could be (-1,-1) } //---------------------------------------------------------------------- /** * Find weakly activated cell in column. 
*/ -std::pair Cells4::getBestMatchingCellT1(UInt colIdx, const CState& state, UInt minThreshold) -{ - { - NTA_ASSERT(colIdx < nColumns()); - } +std::pair Cells4::getBestMatchingCellT1(UInt colIdx, + const CState &state, + UInt minThreshold) { + { NTA_ASSERT(colIdx < nColumns()); } UInt start = colIdx * _nCellsPerCol, end = start + _nCellsPerCol; - UInt best_cell = (UInt) -1; - std::pair best((UInt) -1, minThreshold); + UInt best_cell = (UInt)-1; + std::pair best((UInt)-1, minThreshold); // For each cell in column for (UInt i = start; i != end; ++i) { UInt maxSegActivity = 0, maxSegIdx = 0, activity = 0; - for (UInt j = 0; j != _cells[i].size(); ++j) { - if (segment(i,j).empty()) + if (segment(i, j).empty()) continue; - activity = segment(i,j).computeActivity(state, _permConnected, false); + activity = segment(i, j).computeActivity(state, _permConnected, false); if (activity > maxSegActivity) { maxSegActivity = activity; @@ -2715,8 +2544,8 @@ std::pair Cells4::getBestMatchingCellT1(UInt colIdx, const CState& s << ", segment: "; _cells[i][j].print(std::cout, _nCellsPerCol); std::cout << "\n"; - std::cout << "activity = " << activity << ", maxSegActivity = " - << maxSegActivity << "\n"; + std::cout << "activity = " << activity + << ", maxSegActivity = " << maxSegActivity << "\n"; } } // for each segment in cell @@ -2728,84 +2557,80 @@ std::pair Cells4::getBestMatchingCellT1(UInt colIdx, const CState& s } // for each cell in column - if (best_cell != (UInt) -1) + if (best_cell != (UInt)-1) return std::make_pair(best_cell, best.first); else - return std::make_pair((UInt) -1, (UInt) -1); + return std::make_pair((UInt)-1, (UInt)-1); } -void -Cells4::chooseCellsToLearnFrom(UInt cellIdx, UInt segIdx, - UInt nSynToAdd, CStateIndexed& state, - std::vector& srcCells) -{ +void Cells4::chooseCellsToLearnFrom(UInt cellIdx, UInt segIdx, UInt nSynToAdd, + CStateIndexed &state, + std::vector &srcCells) { // bail out if no cells requested if (nSynToAdd == 0) return; TIMER(chooseCellsTimer.start()); - // start with a sorted vector of all the cells that are on in the current state + // start with a sorted vector of all the cells that are on in the current + // state static std::vector vecCellBuffer; vecCellBuffer = state.cellsOn(true); // remove any cells already in this segment static std::vector vecPruned; - if (segIdx != (UInt) -1) { + if (segIdx != (UInt)-1) { // collect the sorted list of source cell indices Segment segThis = _cells[cellIdx][segIdx]; static std::vector vecAlreadyHave; if (vecAlreadyHave.capacity() < segThis.size()) vecAlreadyHave.reserve(segThis.size()); - vecAlreadyHave.clear(); // purge residual data + vecAlreadyHave.clear(); // purge residual data for (UInt i = 0; i != segThis.size(); ++i) vecAlreadyHave.push_back(segThis[i].srcCellIdx()); // remove any of these found in vecCellBuffer - if (vecPruned.size() < vecCellBuffer.size()) // ensure there is enough room for the results + if (vecPruned.size() < + vecCellBuffer.size()) // ensure there is enough room for the results vecPruned.resize(vecCellBuffer.size()); std::vector::iterator iterPruned; - iterPruned = std::set_difference(vecCellBuffer.begin(), - vecCellBuffer.end(), + iterPruned = std::set_difference(vecCellBuffer.begin(), vecCellBuffer.end(), vecAlreadyHave.begin(), - vecAlreadyHave.end(), - vecPruned.begin()); + vecAlreadyHave.end(), vecPruned.begin()); vecPruned.resize(iterPruned - vecPruned.begin()); - } - else { + } else { vecPruned = vecCellBuffer; } const UInt nbrCells = vecPruned.size(); // bail out if there are no cells 
left to process if (nbrCells == 0) { - TIMER(chooseCellsTimer.stop()); // turn off timer + TIMER(chooseCellsTimer.stop()); // turn off timer return; } // if we found fewer cells than requested, return all of them // The new ones are sorted, but we need to sort again if there were // any old ones. - bool fSortNeeded = !srcCells.empty(); // may be overridden below + bool fSortNeeded = !srcCells.empty(); // may be overridden below if (nbrCells <= nSynToAdd) { // since we use all of vecPruned, we don't need a random number srcCells.reserve(nbrCells + srcCells.size()); std::vector::iterator iterCellBuffer; - for (iterCellBuffer = vecPruned.begin(); iterCellBuffer != vecPruned.end(); ++iterCellBuffer) { + for (iterCellBuffer = vecPruned.begin(); iterCellBuffer != vecPruned.end(); + ++iterCellBuffer) { srcCells.push_back(*iterCellBuffer); } - } - else if (nSynToAdd == 1) { + } else if (nSynToAdd == 1) { // if just one cell requested, choose one at random srcCells.push_back(vecPruned[_rng.getUInt32(nbrCells)]); - } - else { + } else { // choose a random subset of the cells found, and append them to the // caller's array UInt start = srcCells.size(); srcCells.resize(srcCells.size() + nSynToAdd); - _rng.sample(&vecPruned.front(), vecPruned.size(), - &srcCells[start], nSynToAdd); + _rng.sample(&vecPruned.front(), vecPruned.size(), &srcCells[start], + nSynToAdd); fSortNeeded = true; } @@ -2818,10 +2643,8 @@ Cells4::chooseCellsToLearnFrom(UInt cellIdx, UInt segIdx, TIMER(chooseCellsTimer.stop()); } - std::pair Cells4::trimSegments(Real minPermanence, - UInt minNumSyns) -{ + UInt minNumSyns) { UInt nSegsRemoved = 0, nSynsRemoved = 0; // Fill in defaults @@ -2834,13 +2657,11 @@ std::pair Cells4::trimSegments(Real minPermanence, for (UInt segIdx = 0; segIdx != _cells[cellIdx].size(); ++segIdx) { static std::vector removedSynapses; - removedSynapses.clear(); // purge residual data + removedSynapses.clear(); // purge residual data - Segment& seg = segment(cellIdx, segIdx); + Segment &seg = segment(cellIdx, segIdx); - seg.decaySynapses(minPermanence, - removedSynapses, - minPermanence, false); + seg.decaySynapses(minPermanence, removedSynapses, minPermanence, false); if (seg.size() < minNumSyns) { @@ -2850,7 +2671,7 @@ std::pair Cells4::trimSegments(Real minPermanence, eraseOutSynapses(cellIdx, segIdx, removedSynapses); _cells[cellIdx].releaseSegment(segIdx); - ++ nSegsRemoved; + ++nSegsRemoved; } else { eraseOutSynapses(cellIdx, segIdx, removedSynapses); @@ -2871,8 +2692,7 @@ std::pair Cells4::trimSegments(Real minPermanence, // Debugging helpers //---------------------------------------------------------------------- -void Cells4::printState(UInt *state) -{ +void Cells4::printState(UInt *state) { for (UInt i = 0; i != nCellsPerCol(); ++i) { for (UInt c = 0; c != nColumns(); ++c) { if (c > 0 && c % 10 == 0) @@ -2884,8 +2704,7 @@ void Cells4::printState(UInt *state) } } -void Cells4::printStates() -{ +void Cells4::printStates() { // Print out active state for debugging if (true) { std::cout << "TP10X: Active T-1 \t T\n"; @@ -2947,8 +2766,7 @@ void Cells4::printStates() } } -void Cells4::dumpSegmentUpdates() -{ +void Cells4::dumpSegmentUpdates() { std::cout << _segmentUpdates.size() << " updates" << std::endl; for (UInt i = 0; i != _segmentUpdates.size(); ++i) { _segmentUpdates[i].print(std::cout, true); @@ -2957,21 +2775,18 @@ void Cells4::dumpSegmentUpdates() } // Print input pattern queue -void Cells4::dumpPrevPatterns(std::deque > &patterns) -{ +void Cells4::dumpPrevPatterns(std::deque> &patterns) { for 
(UInt p = 0; p < patterns.size(); p++) { std::cout << "Pattern " << p << ": "; - for (auto & elem : patterns[p]) { + for (auto &elem : patterns[p]) { std::cout << elem << " "; } std::cout << std::endl; } std::cout << std::endl; - } -void Cells4::print(std::ostream& outStream) const -{ +void Cells4::print(std::ostream &outStream) const { for (UInt i = 0; i != _nCells; ++i) { std::cout << "Cell #" << i << " "; for (UInt j = 0; j != _cells[i].size(); ++j) { @@ -2981,8 +2796,7 @@ void Cells4::print(std::ostream& outStream) const } } -std::ostream& operator<<(std::ostream& outStream, const Cells4& cells) -{ +std::ostream &operator<<(std::ostream &outStream, const Cells4 &cells) { cells.print(outStream); return outStream; } @@ -2992,8 +2806,7 @@ std::ostream& operator<<(std::ostream& outStream, const Cells4& cells) * Compute cell and segment activities using forward propagation * and the given state variable. */ -void Cells4::computeForwardPropagation(CStateIndexed& state) -{ +void Cells4::computeForwardPropagation(CStateIndexed &state) { // Zero out previous values // Using memset is quite a bit faster on laptops, but has almost no effect // on Neo15! @@ -3004,11 +2817,12 @@ void Cells4::computeForwardPropagation(CStateIndexed& state) // activity coming into a cell. // process all cells that are on in the current state - static std::vector vecCellBuffer ; + static std::vector vecCellBuffer; vecCellBuffer = state.cellsOn(); std::vector::iterator iterCellBuffer; - for (iterCellBuffer = vecCellBuffer.begin(); iterCellBuffer != vecCellBuffer.end(); ++iterCellBuffer) { - std::vector< OutSynapse >& os = _outSynapses[*iterCellBuffer]; + for (iterCellBuffer = vecCellBuffer.begin(); + iterCellBuffer != vecCellBuffer.end(); ++iterCellBuffer) { + std::vector &os = _outSynapses[*iterCellBuffer]; for (UInt j = 0; j != os.size(); ++j) { UInt dstCellIdx = os[j].dstCellIdx(); UInt dstSegIdx = os[j].dstSegIdx(); @@ -3029,8 +2843,7 @@ void Cells4::computeForwardPropagation(CStateIndexed& state) * to move all state array modifications from Python to C++. One * known offender is TP.py. */ -void Cells4::computeForwardPropagation(CState& state) -{ +void Cells4::computeForwardPropagation(CState &state) { // Zero out previous values // Using memset is quite a bit faster on laptops, but has almost no effect // on Neo15! @@ -3040,13 +2853,13 @@ void Cells4::computeForwardPropagation(CState& state) // links from each source cell. _cellActivity will be set to the total // activity coming into a cell. 
#ifdef NTA_ARCH_64 - const UInt multipleOf8 = 8 * (_nCells/8); + const UInt multipleOf8 = 8 * (_nCells / 8); UInt i; for (i = 0; i < multipleOf8; i += 8) { - UInt64 eightStates = * (UInt64 *)(state.arrayPtr() + i); - for (int k = 0; eightStates != 0 && k < 8; eightStates >>= 8, k++) { + UInt64 eightStates = *(UInt64 *)(state.arrayPtr() + i); + for (int k = 0; eightStates != 0 && k < 8; eightStates >>= 8, k++) { if ((eightStates & 0xff) != 0) { - std::vector< OutSynapse >& os = _outSynapses[i + k]; + std::vector &os = _outSynapses[i + k]; for (UInt j = 0; j != os.size(); ++j) { UInt dstCellIdx = os[j].dstCellIdx(); UInt dstSegIdx = os[j].dstSegIdx(); @@ -3059,7 +2872,7 @@ void Cells4::computeForwardPropagation(CState& state) // process the tail if (_nCells % 8) != 0 for (i = multipleOf8; i < _nCells; i++) { if (state.isSet(i)) { - std::vector< OutSynapse >& os = _outSynapses[i]; + std::vector &os = _outSynapses[i]; for (UInt j = 0; j != os.size(); ++j) { UInt dstCellIdx = os[j].dstCellIdx(); UInt dstSegIdx = os[j].dstSegIdx(); @@ -3068,13 +2881,13 @@ void Cells4::computeForwardPropagation(CState& state) } } #else - const UInt multipleOf4 = 4 * (_nCells/4); + const UInt multipleOf4 = 4 * (_nCells / 4); UInt i; for (i = 0; i < multipleOf4; i += 4) { - UInt32 fourStates = * (UInt32 *)(state.arrayPtr() + i); - for (int k = 0; fourStates != 0 && k < 4; fourStates >>= 8, k++) { + UInt32 fourStates = *(UInt32 *)(state.arrayPtr() + i); + for (int k = 0; fourStates != 0 && k < 4; fourStates >>= 8, k++) { if ((fourStates & 0xff) != 0) { - std::vector< OutSynapse >& os = _outSynapses[i + k]; + std::vector &os = _outSynapses[i + k]; for (UInt j = 0; j != os.size(); ++j) { UInt dstCellIdx = os[j].dstCellIdx(); UInt dstSegIdx = os[j].dstSegIdx(); @@ -3087,7 +2900,7 @@ void Cells4::computeForwardPropagation(CState& state) // process the tail if (_nCells % 4) != 0 for (i = multipleOf4; i < _nCells; i++) { if (state.isSet(i)) { - std::vector< OutSynapse >& os = _outSynapses[i]; + std::vector &os = _outSynapses[i]; for (UInt j = 0; j != os.size(); ++j) { UInt dstCellIdx = os[j].dstCellIdx(); UInt dstSegIdx = os[j].dstSegIdx(); @@ -3097,20 +2910,18 @@ void Cells4::computeForwardPropagation(CState& state) } #endif // NTA_ARCH_32/64 } -#endif // SOME_STATES_NOT_INDEXED - +#endif // SOME_STATES_NOT_INDEXED //-------------------------------------------------------------------------------- // Dump detailed Cells4 timing report to stdout //-------------------------------------------------------------------------------- -void Cells4::dumpTiming() -{ +void Cells4::dumpTiming() { #ifdef CELLS4_TIMING Real64 learnTime = learningTimer.getElapsed(), inferenceTime = inferenceTimer.getElapsed(); - std::cout << "Total time in compute: " << computeTimer.toString() << "\n"; - std::cout << "Total time in learning: " << learningTimer.toString() << "\n"; + std::cout << "Total time in compute: " << computeTimer.toString() << "\n"; + std::cout << "Total time in learning: " << learningTimer.toString() << "\n"; std::cout << "Total time in inference: " << inferenceTimer.toString() << "\n"; std::cout << "\n\nLearning breakdown:" << std::endl; @@ -3133,9 +2944,8 @@ void Cells4::dumpTiming() << std::setprecision(3) << 100.0 * chooseCellsTimer.getElapsed() / learnTime << "%\n"; std::cout << "adaptSegment: " << adaptSegmentTimer.toString() << " " - << std::setprecision(3) - << 100.0 * adaptSegmentTimer.getElapsed() / learnTime - << "%\n"; + << std::setprecision(3) + << 100.0 * adaptSegmentTimer.getElapsed() / learnTime << "%\n"; std::cout 
<< "Note: % is percentage of learning time\n"; std::cout << "\n\nInference breakdown:" << std::endl; @@ -3150,7 +2960,8 @@ void Cells4::dumpTiming() << 100.0 * infBacktrackTimer.getElapsed() / inferenceTime << "%\n"; std::cout << "Forward prop: " << forwardInfPropTimer.toString() << " " << std::setprecision(3) - << 100.0 * forwardInfPropTimer.getElapsed() / inferenceTime << "%\n"; + << 100.0 * forwardInfPropTimer.getElapsed() / inferenceTime + << "%\n"; std::cout << "Note: % is percentage of inference time\n"; #endif @@ -3159,8 +2970,7 @@ void Cells4::dumpTiming() //-------------------------------------------------------------------------------- // Reset timers and counters to 0 //-------------------------------------------------------------------------------- -void Cells4::resetTimers() -{ +void Cells4::resetTimers() { #ifdef CELLS4_TIMING computeTimer.reset(); inferenceTimer.reset(); diff --git a/src/nupic/algorithms/Cells4.hpp b/src/nupic/algorithms/Cells4.hpp index a8df99d30b..fa9f78f168 100644 --- a/src/nupic/algorithms/Cells4.hpp +++ b/src/nupic/algorithms/Cells4.hpp @@ -23,17 +23,16 @@ #ifndef NTA_Cells4_HPP #define NTA_Cells4_HPP -#include -#include +#include #include -#include #include +#include #include #include #include +#include #include -#include - +#include //----------------------------------------------------------------------- /** @@ -68,1142 +67,1079 @@ */ namespace nupic { - namespace algorithms { - namespace Cells4 { - - class Cell; - class Cells4; - class SegmentUpdate; - - /** - * Class CBasicActivity: - * Manage activity counters - * - * This class is used by CCellSegActivity. The counters stay well - * below 255, allowing us to use UChar elements. The biggest we - * have seen is 33. More important than the raw memory utilization - * is the reduced pressure on L2 cache. To see the difference, - * benchmark this version, then try again after changing - * - * CCellSegActivity _learnActivity; - * CCellSegActivity _inferActivity; - * - * to - * - * CCellSegActivity _learnActivity; - * CCellSegActivity _inferActivity; - * - * We leave this class and CCellSegActivity templated to simplify - * such testing. - * - * While we typically test on just one core, our production - * configuration may run one engine on each core, thereby increasing - * the pressure on L2. - * - * Counts are collected in one function, following a reset, and - * used in another. - * - * Collected in Used in - * _learnActivity computeForwardPropagation(CStateIndexed& state) getBestMatchingCellT - * _inferActivity computeForwardPropagation(CState& state) inferPhase2 - * - * The _segment counts are the ones that matter. The _cell counts - * are an optimization technique. They track the maximum count - * for all segments in that cell. Since segment counts are - * interesting only if they exceed a threshold, we can skip all of - * a cell's segments when the maximum is too small. - * - * Repeatedly resetting all the counters in large sparse arrays - * can be costly, and much of the work is unnecessary when most - * counters are already zero. To address this, we track which - * array elements are nonzero, and at reset time zero only those. - * If an array is not so sparse, this selective zeroing may be - * slower than a full memset(). We arbitrarily choose a threshold - * of 6.25%, past which we use memset() instead of selective - * zeroing. 
- */ - const UInt _MAX_CELLS = 1 << 18; // power of 2 allows efficient array indexing - const UInt _MAX_SEGS = 1 << 7; // power of 2 allows efficient array indexing - typedef unsigned char UChar; // custom type, since NTA_Byte = Byte is signed - - template - class CBasicActivity - { - public: - CBasicActivity() - { - _counter = nullptr; - _nonzero = nullptr; - _size = 0; - _dimension = 0; - } - ~CBasicActivity() - { - if (_counter != nullptr) - delete [] _counter; - if (_nonzero != nullptr) - delete [] _nonzero; - } - void initialize(UInt n) - { - if (_counter != nullptr) - delete [] _counter; - if (_nonzero != nullptr) - delete [] _nonzero; - _counter = new It[n]; // use typename here - memset(_counter, 0, n * sizeof(_counter[0])); - _nonzero = new UInt[n]; - _size = 0; - _dimension = n; - } - UInt get(UInt cellIdx) - { - return _counter[cellIdx]; - } - void add(UInt cellIdx, UInt incr) - { - // currently unused, but may need to resurrect - if (_counter[cellIdx] == 0) - _nonzero[_size++] = cellIdx; - _counter[cellIdx] += incr; - } - It increment(UInt cellIdx) // use typename here - { - // In the learning phase, the activity count appears never to - // reach 255. Is this a safe assumption? - if (_counter[cellIdx] != 0) - return ++_counter[cellIdx]; - _counter[cellIdx] = 1; // without this, the inefficient compiler reloads the value from memory, increments it and stores it back - _nonzero[_size++] = cellIdx; - return 1; - } - void max(UInt cellIdx, It val) // use typename here - { - const It curr = _counter[cellIdx]; // use typename here - if (val > curr) { - _counter[cellIdx] = val; - if (curr == 0) - _nonzero[_size++] = cellIdx; - } - } - void reset() - { +namespace algorithms { +namespace Cells4 { + +class Cell; +class Cells4; +class SegmentUpdate; + +/** + * Class CBasicActivity: + * Manage activity counters + * + * This class is used by CCellSegActivity. The counters stay well + * below 255, allowing us to use UChar elements. The biggest we + * have seen is 33. More important than the raw memory utilization + * is the reduced pressure on L2 cache. To see the difference, + * benchmark this version, then try again after changing + * + * CCellSegActivity _learnActivity; + * CCellSegActivity _inferActivity; + * + * to + * + * CCellSegActivity _learnActivity; + * CCellSegActivity _inferActivity; + * + * We leave this class and CCellSegActivity templated to simplify + * such testing. + * + * While we typically test on just one core, our production + * configuration may run one engine on each core, thereby increasing + * the pressure on L2. + * + * Counts are collected in one function, following a reset, and + * used in another. + * + * Collected in Used in + * _learnActivity computeForwardPropagation(CStateIndexed& state) + * getBestMatchingCellT _inferActivity computeForwardPropagation(CState& + * state) inferPhase2 + * + * The _segment counts are the ones that matter. The _cell counts + * are an optimization technique. They track the maximum count + * for all segments in that cell. Since segment counts are + * interesting only if they exceed a threshold, we can skip all of + * a cell's segments when the maximum is too small. + * + * Repeatedly resetting all the counters in large sparse arrays + * can be costly, and much of the work is unnecessary when most + * counters are already zero. To address this, we track which + * array elements are nonzero, and at reset time zero only those. + * If an array is not so sparse, this selective zeroing may be + * slower than a full memset(). 
We arbitrarily choose a threshold + * of 6.25%, past which we use memset() instead of selective + * zeroing. + */ +const UInt _MAX_CELLS = 1 << 18; // power of 2 allows efficient array indexing +const UInt _MAX_SEGS = 1 << 7; // power of 2 allows efficient array indexing +typedef unsigned char UChar; // custom type, since NTA_Byte = Byte is signed + +template class CBasicActivity { +public: + CBasicActivity() { + _counter = nullptr; + _nonzero = nullptr; + _size = 0; + _dimension = 0; + } + ~CBasicActivity() { + if (_counter != nullptr) + delete[] _counter; + if (_nonzero != nullptr) + delete[] _nonzero; + } + void initialize(UInt n) { + if (_counter != nullptr) + delete[] _counter; + if (_nonzero != nullptr) + delete[] _nonzero; + _counter = new It[n]; // use typename here + memset(_counter, 0, n * sizeof(_counter[0])); + _nonzero = new UInt[n]; + _size = 0; + _dimension = n; + } + UInt get(UInt cellIdx) { return _counter[cellIdx]; } + void add(UInt cellIdx, UInt incr) { + // currently unused, but may need to resurrect + if (_counter[cellIdx] == 0) + _nonzero[_size++] = cellIdx; + _counter[cellIdx] += incr; + } + It increment(UInt cellIdx) // use typename here + { + // In the learning phase, the activity count appears never to + // reach 255. Is this a safe assumption? + if (_counter[cellIdx] != 0) + return ++_counter[cellIdx]; + _counter[cellIdx] = + 1; // without this, the inefficient compiler reloads the value from + // memory, increments it and stores it back + _nonzero[_size++] = cellIdx; + return 1; + } + void max(UInt cellIdx, It val) // use typename here + { + const It curr = _counter[cellIdx]; // use typename here + if (val > curr) { + _counter[cellIdx] = val; + if (curr == 0) + _nonzero[_size++] = cellIdx; + } + } + void reset() { #define REPORT_ACTIVITY_STATISTICS 0 #if REPORT_ACTIVITY_STATISTICS - // report the statistics for this table - // Without a high water counter, we can't tell for sure if a - // UChar counter overflowed, but it's likely there was no - // overflow if all the other counters are below, say, 200. - if (_size == 0) { - std::cout << "Reset width=" << sizeof(It) << " all zeroes" << std::endl; - } - else { - static std::vector vectStat; - vectStat.clear(); - UInt ndxStat; - for (ndxStat = 0; ndxStat < _size; ndxStat++) - vectStat.push_back(_counter[_nonzero[ndxStat]]); - std::sort(vectStat.begin(), vectStat.end()); - std::cout << "Reset width=" << sizeof(It) - << " size=" << _dimension - << " nonzero=" << _size - << " min=" << UInt(vectStat.front()) - << " max=" << UInt(vectStat.back()) - << " med=" << UInt(vectStat[_size/2]) - << std::endl; - } + // report the statistics for this table + // Without a high water counter, we can't tell for sure if a + // UChar counter overflowed, but it's likely there was no + // overflow if all the other counters are below, say, 200. 
+ if (_size == 0) { + std::cout << "Reset width=" << sizeof(It) << " all zeroes" << std::endl; + } else { + static std::vector vectStat; + vectStat.clear(); + UInt ndxStat; + for (ndxStat = 0; ndxStat < _size; ndxStat++) + vectStat.push_back(_counter[_nonzero[ndxStat]]); + std::sort(vectStat.begin(), vectStat.end()); + std::cout << "Reset width=" << sizeof(It) << " size=" << _dimension + << " nonzero=" << _size << " min=" << UInt(vectStat.front()) + << " max=" << UInt(vectStat.back()) + << " med=" << UInt(vectStat[_size / 2]) << std::endl; + } #endif - // zero all the nonzero slots - if (_size < _dimension / 16) { // if fewer than 6.25% are nonzero - UInt ndx; // zero selectively - for (ndx = 0; ndx < _size; ndx++) - _counter[_nonzero[ndx]] = 0; - } - else { - memset(_counter, 0, _dimension * sizeof(_counter[0])); - } - - // no more nonzero slots - _size = 0; - } - private: - It * _counter; // use typename here - UInt * _nonzero; - UInt _size; - UInt _dimension; - }; - - template - class CCellSegActivity - { - public: - CCellSegActivity() - { - _cell.initialize(_MAX_CELLS); - _seg.initialize(_MAX_CELLS * _MAX_SEGS); - } - UInt get(UInt cellIdx) - { - return _cell.get(cellIdx); - } - UInt get(UInt cellIdx, UInt segIdx) - { - return _seg.get(cellIdx * _MAX_SEGS + segIdx); - } - void increment(UInt cellIdx, UInt segIdx) - { - _cell.max(cellIdx, _seg.increment(cellIdx * _MAX_SEGS + segIdx)); - } - void reset() - { - _cell.reset(); - _seg.reset(); - } - private: - CBasicActivity _cell; - CBasicActivity _seg; - }; - - class Cells4 : public Serializable - { - public: - - typedef Segment::InSynapses InSynapses; - typedef std::vector OutSynapses; - typedef std::vector SegmentUpdates; - static const UInt VERSION = 2; - - private: - nupic::Random _rng; - - //----------------------------------------------------------------------- - /** - * Temporal pooler parameters, typically set by the user. - * See TP.py for explanations. - */ - UInt _nColumns; - UInt _nCellsPerCol; - UInt _nCells; - UInt _activationThreshold; - UInt _minThreshold; - UInt _newSynapseCount; - UInt _nIterations; - UInt _nLrnIterations; - UInt _segUpdateValidDuration; - Real _initSegFreq; // TODO: Can we remove this? Used anywhere? - Real _permInitial; - Real _permConnected; - Real _permMax; - Real _permDec; - Real _permInc; - Real _globalDecay; - bool _doPooling; - UInt _pamLength; - UInt _maxInfBacktrack; - UInt _maxLrnBacktrack; - UInt _maxSeqLength; - UInt _learnedSeqLength; - Real _avgLearnedSeqLength; - UInt _maxAge; - UInt _verbosity; - Int _maxSegmentsPerCell; - Int _maxSynapsesPerSegment; - bool _checkSynapseConsistency; // If true, will perform time - // consuming invariance checks. - - //----------------------------------------------------------------------- - /** - * Internal variables. - */ - bool _resetCalled; // True if reset() was called since the - // last call to compute. - Real _avgInputDensity; // Average no. of non-zero inputs - UInt _pamCounter; // pamCounter gets reset to pamLength - // whenever we detect that the learning - // state is making good predictions - UInt _version; - - //----------------------------------------------------------------------- - /** - * The various inference and learning states. See TP.py documentation - * - * Note: 'T1' means 't-1' - * TODO: change to more compact data type (later) 2011-07-23 partly done. 
- */ + // zero all the nonzero slots + if (_size < _dimension / 16) { // if fewer than 6.25% are nonzero + UInt ndx; // zero selectively + for (ndx = 0; ndx < _size; ndx++) + _counter[_nonzero[ndx]] = 0; + } else { + memset(_counter, 0, _dimension * sizeof(_counter[0])); + } + + // no more nonzero slots + _size = 0; + } + +private: + It *_counter; // use typename here + UInt *_nonzero; + UInt _size; + UInt _dimension; +}; + +template class CCellSegActivity { +public: + CCellSegActivity() { + _cell.initialize(_MAX_CELLS); + _seg.initialize(_MAX_CELLS * _MAX_SEGS); + } + UInt get(UInt cellIdx) { return _cell.get(cellIdx); } + UInt get(UInt cellIdx, UInt segIdx) { + return _seg.get(cellIdx * _MAX_SEGS + segIdx); + } + void increment(UInt cellIdx, UInt segIdx) { + _cell.max(cellIdx, _seg.increment(cellIdx * _MAX_SEGS + segIdx)); + } + void reset() { + _cell.reset(); + _seg.reset(); + } + +private: + CBasicActivity _cell; + CBasicActivity _seg; +}; + +class Cells4 : public Serializable { +public: + typedef Segment::InSynapses InSynapses; + typedef std::vector OutSynapses; + typedef std::vector SegmentUpdates; + static const UInt VERSION = 2; + +private: + nupic::Random _rng; + + //----------------------------------------------------------------------- + /** + * Temporal pooler parameters, typically set by the user. + * See TP.py for explanations. + */ + UInt _nColumns; + UInt _nCellsPerCol; + UInt _nCells; + UInt _activationThreshold; + UInt _minThreshold; + UInt _newSynapseCount; + UInt _nIterations; + UInt _nLrnIterations; + UInt _segUpdateValidDuration; + Real _initSegFreq; // TODO: Can we remove this? Used anywhere? + Real _permInitial; + Real _permConnected; + Real _permMax; + Real _permDec; + Real _permInc; + Real _globalDecay; + bool _doPooling; + UInt _pamLength; + UInt _maxInfBacktrack; + UInt _maxLrnBacktrack; + UInt _maxSeqLength; + UInt _learnedSeqLength; + Real _avgLearnedSeqLength; + UInt _maxAge; + UInt _verbosity; + Int _maxSegmentsPerCell; + Int _maxSynapsesPerSegment; + bool _checkSynapseConsistency; // If true, will perform time + // consuming invariance checks. + + //----------------------------------------------------------------------- + /** + * Internal variables. + */ + bool _resetCalled; // True if reset() was called since the + // last call to compute. + Real _avgInputDensity; // Average no. of non-zero inputs + UInt _pamCounter; // pamCounter gets reset to pamLength + // whenever we detect that the learning + // state is making good predictions + UInt _version; + + //----------------------------------------------------------------------- + /** + * The various inference and learning states. See TP.py documentation + * + * Note: 'T1' means 't-1' + * TODO: change to more compact data type (later) 2011-07-23 partly done. 
+ */ #define SOME_STATES_NOT_INDEXED 1 #if SOME_STATES_NOT_INDEXED - CState _infActiveStateT; - CState _infActiveStateT1; - CState _infPredictedStateT; - CState _infPredictedStateT1; + CState _infActiveStateT; + CState _infActiveStateT1; + CState _infPredictedStateT; + CState _infPredictedStateT1; #else - CStateIndexed _infActiveStateT; - CStateIndexed _infActiveStateT1; - CStateIndexed _infPredictedStateT; - CStateIndexed _infPredictedStateT1; + CStateIndexed _infActiveStateT; + CStateIndexed _infActiveStateT1; + CStateIndexed _infPredictedStateT; + CStateIndexed _infPredictedStateT1; #endif - Real* _cellConfidenceT; - Real* _cellConfidenceT1; - Real* _colConfidenceT; - Real* _colConfidenceT1; - bool _ownsMemory; // If true, this class is responsible - // for managing memory of above - // eight arrays. - - - CStateIndexed _learnActiveStateT; - CStateIndexed _learnActiveStateT1; - CStateIndexed _learnPredictedStateT; - CStateIndexed _learnPredictedStateT1; - - Real* _cellConfidenceCandidate; - Real* _colConfidenceCandidate; - Real* _tmpInputBuffer; + Real *_cellConfidenceT; + Real *_cellConfidenceT1; + Real *_colConfidenceT; + Real *_colConfidenceT1; + bool _ownsMemory; // If true, this class is responsible + // for managing memory of above + // eight arrays. + + CStateIndexed _learnActiveStateT; + CStateIndexed _learnActiveStateT1; + CStateIndexed _learnPredictedStateT; + CStateIndexed _learnPredictedStateT1; + + Real *_cellConfidenceCandidate; + Real *_colConfidenceCandidate; + Real *_tmpInputBuffer; #if SOME_STATES_NOT_INDEXED - CState _infActiveStateCandidate; - CState _infPredictedStateCandidate; - CState _infActiveBackup; - CState _infPredictedBackup; + CState _infActiveStateCandidate; + CState _infPredictedStateCandidate; + CState _infActiveBackup; + CState _infPredictedBackup; #else - CStateIndexed _infActiveStateCandidate; - CStateIndexed _infPredictedStateCandidate; - CStateIndexed _infActiveBackup; - CStateIndexed _infPredictedBackup; + CStateIndexed _infActiveStateCandidate; + CStateIndexed _infPredictedStateCandidate; + CStateIndexed _infActiveBackup; + CStateIndexed _infPredictedBackup; #endif - //----------------------------------------------------------------------- - /** - * Internal data structures. - */ - std::vector< Cell > _cells; - std::deque > _prevInfPatterns; - std::deque > _prevLrnPatterns; - SegmentUpdates _segmentUpdates; - - //----------------------------------------------------------------------- - /** - * Internal data structures used for speed optimization. - */ - std::vector _outSynapses; - UInt _nIterationsSinceRebalance; - CCellSegActivity _learnActivity; - // _inferActivity and _learnActivity use identical data - // structures, and their use does not overlap - #define _inferActivity _learnActivity - - public: - //----------------------------------------------------------------------- - /** - * Default constructor needed when lifting from persistence. - */ - Cells4(UInt nColumns =0, UInt nCellsPerCol =0, - UInt activationThreshold =1, - UInt minThreshold =1, - UInt newSynapseCount =1, - UInt segUpdateValidDuration =1, - Real permInitial =.5, - Real permConnected =.8, - Real permMax =1, - Real permDec =.1, - Real permInc =.1, - Real globalDecay =0, - bool doPooling =false, - int seed =-1, - bool initFromCpp =false, - bool checkSynapseConsistency =false); - - - //---------------------------------------------------------------------- - /** - * This also called when lifting from persistence. 
- */ - void - initialize(UInt nColumns =0, UInt nCellsPerCol =0, - UInt activationThreshold =1, - UInt minThreshold =1, - UInt newSynapseCount =1, - UInt segUpdateValidDuration =1, - Real permInitial =.5, - Real permConnected =.8, - Real permMax =1, - Real permDec =.1, - Real permInc =.1, - Real globalDecay =.1, - bool doPooling =false, - bool initFromCpp =false, - bool checkSynapseConsistency =false); - - //---------------------------------------------------------------------- - ~Cells4(); - - //---------------------------------------------------------------------- - UInt version() const - { - return _version; - } - - //---------------------------------------------------------------------- - /** - * Call this when allocating numpy arrays, to have pointers use those - * arrays. - */ - void setStatePointers(Byte* infActiveT, Byte* infActiveT1, - Byte* infPredT, Byte* infPredT1, - Real* colConfidenceT, Real* colConfidenceT1, - Real* cellConfidenceT, Real* cellConfidenceT1) - { - if (_ownsMemory) { - delete [] _cellConfidenceT; - delete [] _cellConfidenceT1; - delete [] _colConfidenceT; - delete [] _colConfidenceT1; - } - - _ownsMemory = false; - - _infActiveStateT.usePythonMemory(infActiveT, _nCells); - _infActiveStateT1.usePythonMemory(infActiveT1, _nCells); - _infPredictedStateT.usePythonMemory(infPredT, _nCells); - _infPredictedStateT1.usePythonMemory(infPredT1, _nCells); - _cellConfidenceT = cellConfidenceT; - _cellConfidenceT1 = cellConfidenceT1; - _colConfidenceT = colConfidenceT; - _colConfidenceT1 = colConfidenceT1; - } - - //----------------------------------------------------------------------- - /** - * Use this when C++ allocates memory for the arrays, and Python needs to look - * at them. - */ - void getStatePointers(Byte*& activeT, Byte*& activeT1, - Byte*& predT, Byte*& predT1, - Real*& colConfidenceT, Real*& colConfidenceT1, - Real*& confidenceT, Real*& confidenceT1) const - { - NTA_ASSERT(_ownsMemory); - - activeT = _infActiveStateT.arrayPtr(); - activeT1 = _infActiveStateT1.arrayPtr(); - predT = _infPredictedStateT.arrayPtr(); - predT1 = _infPredictedStateT1.arrayPtr(); - confidenceT = _cellConfidenceT; - confidenceT1 = _cellConfidenceT1; - colConfidenceT = _colConfidenceT; - colConfidenceT1 = _colConfidenceT1; - } - - //----------------------------------------------------------------------- - /** - * Use this when Python needs to look up the learn states. 
- */ - void getLearnStatePointers(Byte*& activeT, Byte*& activeT1, - Byte*& predT, Byte*& predT1) const - { - activeT = _learnActiveStateT.arrayPtr(); - activeT1 = _learnActiveStateT1.arrayPtr(); - predT = _learnPredictedStateT.arrayPtr(); - predT1 = _learnPredictedStateT1.arrayPtr(); - } - - //---------------------------------------------------------------------- - /** - * Accessors for getting various member variables - */ - UInt nSegments() const; - UInt nCells() const { return _nCells; } - UInt nColumns() const { return _nColumns; } - UInt nCellsPerCol() const { return _nCellsPerCol; } - UInt getMinThreshold() const { return _minThreshold; } - Real getPermConnected() const { return _permConnected; } - UInt getVerbosity() const { return _verbosity; } - UInt getMaxAge() const { return _maxAge; } - UInt getPamLength() const { return _pamLength; } - UInt getMaxInfBacktrack() const { return _maxInfBacktrack;} - UInt getMaxLrnBacktrack() const { return _maxLrnBacktrack;} - UInt getPamCounter() const { return _pamCounter;} - UInt getMaxSeqLength() const { return _maxSeqLength;} - Real getAvgLearnedSeqLength() const { return _avgLearnedSeqLength;} - UInt getNLrnIterations() const { return _nLrnIterations;} - Int getMaxSegmentsPerCell() const { return _maxSegmentsPerCell;} - Int getMaxSynapsesPerSegment() const { return _maxSynapsesPerSegment;} - bool getCheckSynapseConsistency() const { return _checkSynapseConsistency;} - - - //---------------------------------------------------------------------- - /** - * Accessors for setting various member variables - */ - void setMaxInfBacktrack(UInt t) {_maxInfBacktrack = t;} - void setMaxLrnBacktrack(UInt t) {_maxLrnBacktrack = t;} - void setVerbosity(UInt v) {_verbosity = v; } - void setMaxAge(UInt a) {_maxAge = a; } - void setMaxSeqLength(UInt v) {_maxSeqLength = v;} - void setCheckSynapseConsistency(bool val) - { _checkSynapseConsistency = val;} - - void setMaxSegmentsPerCell(int maxSegs) { - if (maxSegs != -1) { - NTA_CHECK(maxSegs > 0); - NTA_CHECK(_globalDecay == 0.0); - NTA_CHECK(_maxAge == 0); - } - _maxSegmentsPerCell = maxSegs; - } - - void setMaxSynapsesPerCell(int maxSyns) { - if (maxSyns != -1) { - NTA_CHECK(maxSyns > 0); - NTA_CHECK(_globalDecay == 0.0); - NTA_CHECK(_maxAge == 0); - } - _maxSynapsesPerSegment = maxSyns; - } - - void setPamLength(UInt pl) - { - NTA_CHECK(pl > 0); - _pamLength = pl; - _pamCounter = _pamLength; - } - - - //----------------------------------------------------------------------- - /** - * Returns the number of segments currently in use on the given cell. - */ - UInt nSegmentsOnCell(UInt colIdx, UInt cellIdxInCol) const; - - //----------------------------------------------------------------------- - UInt nSynapses() const; - - //----------------------------------------------------------------------- - /** - * WRONG ONE if you want the current number of segments with actual synapses - * on the cell!!!! - * This one counts the total number of segments ever allocated on a cell, which - * includes empty segments that have been previously freed. - */ - UInt __nSegmentsOnCell(UInt cellIdx) const; - - //----------------------------------------------------------------------- - /** - * Total number of synapses in a given cell (at at given point, changes all the - * time). 
- */ - UInt nSynapsesInCell(UInt cellIdx) const; - - - //----------------------------------------------------------------------- - Cell* getCell(UInt colIdx, UInt cellIdxInCol); - - //----------------------------------------------------------------------- - UInt getCellIdx(UInt colIdx, UInt cellIdxInCol); - - //----------------------------------------------------------------------- - /** - * Can return a previously freed segment (segment size == 0) if called with a segIdx - * which is in the "free" list of the cell. - */ - Segment* - getSegment(UInt colIdx, UInt cellIdxInCol, UInt segIdx); - - //----------------------------------------------------------------------- - /** - * Can return a previously freed segment (segment size == 0) if called with a segIdx - * which is in the "free" list of the cell. - */ - Segment& segment(UInt cellIdx, UInt segIdx); - - //---------------------------------------------------------------------- - //---------------------------------------------------------------------- - // - // ROUTINES USED IN PERFORMING INFERENCE AND LEARNING - // - //---------------------------------------------------------------------- - //---------------------------------------------------------------------- - - //----------------------------------------------------------------------- - /** - * Main compute routine, called for both learning and inference. - * - * Parameters: - * =========== - * - * input: array representing bottom up input - * output: array representing inference output - * doInference: if true, inference output will be computed - * doLearning: if true, learning will occur - */ - void compute(Real* input, Real* output, bool doInference, bool doLearning); - - //----------------------------------------------------------------------- - /** - */ - void reset(); - - //---------------------------------------------------------------------- - bool isActive(UInt cellIdx, UInt segIdx, const CState& state) const; - - //---------------------------------------------------------------------- - /** - * Find weakly activated cell in column. - * - * Parameters: - * ========== - * colIdx: index of column in which to search - * state: the array of cell activities - * minThreshold: only consider segments with activity >= minThreshold - * useSegActivity: if true, use forward prop segment activity values - * - * Return value: index and segment of most activated segment whose - * activity is >= minThreshold. The index returned for the cell - * is between 0 and _nCells, *not* a cell index inside the column. - * If no cells are found, return ((UInt) -1, (UInt) -1). - */ - std::pair getBestMatchingCellT(UInt colIdx, const CState& state, UInt minThreshold); - std::pair getBestMatchingCellT1(UInt colIdx, const CState& state, UInt minThreshold); - - //---------------------------------------------------------------------- - /** - * Compute cell and segment activities using forward propagation - * and the given state variable. - * - * 2011-08-11: We will remove the CState& function if we can - * convert _infActiveStateT from a CState object to CStateIndexed - * without degrading performance. Conversion will also require us - * to move all state array modifications from Python to C++. One - * known offender is TP.py. - */ - void computeForwardPropagation(CStateIndexed& state); + //----------------------------------------------------------------------- + /** + * Internal data structures. 
+ */ + std::vector _cells; + std::deque> _prevInfPatterns; + std::deque> _prevLrnPatterns; + SegmentUpdates _segmentUpdates; + + //----------------------------------------------------------------------- + /** + * Internal data structures used for speed optimization. + */ + std::vector _outSynapses; + UInt _nIterationsSinceRebalance; + CCellSegActivity _learnActivity; +// _inferActivity and _learnActivity use identical data +// structures, and their use does not overlap +#define _inferActivity _learnActivity + +public: + //----------------------------------------------------------------------- + /** + * Default constructor needed when lifting from persistence. + */ + Cells4(UInt nColumns = 0, UInt nCellsPerCol = 0, UInt activationThreshold = 1, + UInt minThreshold = 1, UInt newSynapseCount = 1, + UInt segUpdateValidDuration = 1, Real permInitial = .5, + Real permConnected = .8, Real permMax = 1, Real permDec = .1, + Real permInc = .1, Real globalDecay = 0, bool doPooling = false, + int seed = -1, bool initFromCpp = false, + bool checkSynapseConsistency = false); + + //---------------------------------------------------------------------- + /** + * This also called when lifting from persistence. + */ + void initialize(UInt nColumns = 0, UInt nCellsPerCol = 0, + UInt activationThreshold = 1, UInt minThreshold = 1, + UInt newSynapseCount = 1, UInt segUpdateValidDuration = 1, + Real permInitial = .5, Real permConnected = .8, + Real permMax = 1, Real permDec = .1, Real permInc = .1, + Real globalDecay = .1, bool doPooling = false, + bool initFromCpp = false, + bool checkSynapseConsistency = false); + + //---------------------------------------------------------------------- + ~Cells4(); + + //---------------------------------------------------------------------- + UInt version() const { return _version; } + + //---------------------------------------------------------------------- + /** + * Call this when allocating numpy arrays, to have pointers use those + * arrays. + */ + void setStatePointers(Byte *infActiveT, Byte *infActiveT1, Byte *infPredT, + Byte *infPredT1, Real *colConfidenceT, + Real *colConfidenceT1, Real *cellConfidenceT, + Real *cellConfidenceT1) { + if (_ownsMemory) { + delete[] _cellConfidenceT; + delete[] _cellConfidenceT1; + delete[] _colConfidenceT; + delete[] _colConfidenceT1; + } + + _ownsMemory = false; + + _infActiveStateT.usePythonMemory(infActiveT, _nCells); + _infActiveStateT1.usePythonMemory(infActiveT1, _nCells); + _infPredictedStateT.usePythonMemory(infPredT, _nCells); + _infPredictedStateT1.usePythonMemory(infPredT1, _nCells); + _cellConfidenceT = cellConfidenceT; + _cellConfidenceT1 = cellConfidenceT1; + _colConfidenceT = colConfidenceT; + _colConfidenceT1 = colConfidenceT1; + } + + //----------------------------------------------------------------------- + /** + * Use this when C++ allocates memory for the arrays, and Python needs to look + * at them. 
+ */ + void getStatePointers(Byte *&activeT, Byte *&activeT1, Byte *&predT, + Byte *&predT1, Real *&colConfidenceT, + Real *&colConfidenceT1, Real *&confidenceT, + Real *&confidenceT1) const { + NTA_ASSERT(_ownsMemory); + + activeT = _infActiveStateT.arrayPtr(); + activeT1 = _infActiveStateT1.arrayPtr(); + predT = _infPredictedStateT.arrayPtr(); + predT1 = _infPredictedStateT1.arrayPtr(); + confidenceT = _cellConfidenceT; + confidenceT1 = _cellConfidenceT1; + colConfidenceT = _colConfidenceT; + colConfidenceT1 = _colConfidenceT1; + } + + //----------------------------------------------------------------------- + /** + * Use this when Python needs to look up the learn states. + */ + void getLearnStatePointers(Byte *&activeT, Byte *&activeT1, Byte *&predT, + Byte *&predT1) const { + activeT = _learnActiveStateT.arrayPtr(); + activeT1 = _learnActiveStateT1.arrayPtr(); + predT = _learnPredictedStateT.arrayPtr(); + predT1 = _learnPredictedStateT1.arrayPtr(); + } + + //---------------------------------------------------------------------- + /** + * Accessors for getting various member variables + */ + UInt nSegments() const; + UInt nCells() const { return _nCells; } + UInt nColumns() const { return _nColumns; } + UInt nCellsPerCol() const { return _nCellsPerCol; } + UInt getMinThreshold() const { return _minThreshold; } + Real getPermConnected() const { return _permConnected; } + UInt getVerbosity() const { return _verbosity; } + UInt getMaxAge() const { return _maxAge; } + UInt getPamLength() const { return _pamLength; } + UInt getMaxInfBacktrack() const { return _maxInfBacktrack; } + UInt getMaxLrnBacktrack() const { return _maxLrnBacktrack; } + UInt getPamCounter() const { return _pamCounter; } + UInt getMaxSeqLength() const { return _maxSeqLength; } + Real getAvgLearnedSeqLength() const { return _avgLearnedSeqLength; } + UInt getNLrnIterations() const { return _nLrnIterations; } + Int getMaxSegmentsPerCell() const { return _maxSegmentsPerCell; } + Int getMaxSynapsesPerSegment() const { return _maxSynapsesPerSegment; } + bool getCheckSynapseConsistency() const { return _checkSynapseConsistency; } + + //---------------------------------------------------------------------- + /** + * Accessors for setting various member variables + */ + void setMaxInfBacktrack(UInt t) { _maxInfBacktrack = t; } + void setMaxLrnBacktrack(UInt t) { _maxLrnBacktrack = t; } + void setVerbosity(UInt v) { _verbosity = v; } + void setMaxAge(UInt a) { _maxAge = a; } + void setMaxSeqLength(UInt v) { _maxSeqLength = v; } + void setCheckSynapseConsistency(bool val) { _checkSynapseConsistency = val; } + + void setMaxSegmentsPerCell(int maxSegs) { + if (maxSegs != -1) { + NTA_CHECK(maxSegs > 0); + NTA_CHECK(_globalDecay == 0.0); + NTA_CHECK(_maxAge == 0); + } + _maxSegmentsPerCell = maxSegs; + } + + void setMaxSynapsesPerCell(int maxSyns) { + if (maxSyns != -1) { + NTA_CHECK(maxSyns > 0); + NTA_CHECK(_globalDecay == 0.0); + NTA_CHECK(_maxAge == 0); + } + _maxSynapsesPerSegment = maxSyns; + } + + void setPamLength(UInt pl) { + NTA_CHECK(pl > 0); + _pamLength = pl; + _pamCounter = _pamLength; + } + + //----------------------------------------------------------------------- + /** + * Returns the number of segments currently in use on the given cell. 
+ */ + UInt nSegmentsOnCell(UInt colIdx, UInt cellIdxInCol) const; + + //----------------------------------------------------------------------- + UInt nSynapses() const; + + //----------------------------------------------------------------------- + /** + * WRONG ONE if you want the current number of segments with actual synapses + * on the cell!!!! + * This one counts the total number of segments ever allocated on a cell, + * which includes empty segments that have been previously freed. + */ + UInt __nSegmentsOnCell(UInt cellIdx) const; + + //----------------------------------------------------------------------- + /** + * Total number of synapses in a given cell (at at given point, changes all + * the time). + */ + UInt nSynapsesInCell(UInt cellIdx) const; + + //----------------------------------------------------------------------- + Cell *getCell(UInt colIdx, UInt cellIdxInCol); + + //----------------------------------------------------------------------- + UInt getCellIdx(UInt colIdx, UInt cellIdxInCol); + + //----------------------------------------------------------------------- + /** + * Can return a previously freed segment (segment size == 0) if called with a + * segIdx which is in the "free" list of the cell. + */ + Segment *getSegment(UInt colIdx, UInt cellIdxInCol, UInt segIdx); + + //----------------------------------------------------------------------- + /** + * Can return a previously freed segment (segment size == 0) if called with a + * segIdx which is in the "free" list of the cell. + */ + Segment &segment(UInt cellIdx, UInt segIdx); + + //---------------------------------------------------------------------- + //---------------------------------------------------------------------- + // + // ROUTINES USED IN PERFORMING INFERENCE AND LEARNING + // + //---------------------------------------------------------------------- + //---------------------------------------------------------------------- + + //----------------------------------------------------------------------- + /** + * Main compute routine, called for both learning and inference. + * + * Parameters: + * =========== + * + * input: array representing bottom up input + * output: array representing inference output + * doInference: if true, inference output will be computed + * doLearning: if true, learning will occur + */ + void compute(Real *input, Real *output, bool doInference, bool doLearning); + + //----------------------------------------------------------------------- + /** + */ + void reset(); + + //---------------------------------------------------------------------- + bool isActive(UInt cellIdx, UInt segIdx, const CState &state) const; + + //---------------------------------------------------------------------- + /** + * Find weakly activated cell in column. + * + * Parameters: + * ========== + * colIdx: index of column in which to search + * state: the array of cell activities + * minThreshold: only consider segments with activity >= minThreshold + * useSegActivity: if true, use forward prop segment activity values + * + * Return value: index and segment of most activated segment whose + * activity is >= minThreshold. The index returned for the cell + * is between 0 and _nCells, *not* a cell index inside the column. + * If no cells are found, return ((UInt) -1, (UInt) -1). 
+ */ + std::pair getBestMatchingCellT(UInt colIdx, const CState &state, + UInt minThreshold); + std::pair getBestMatchingCellT1(UInt colIdx, const CState &state, + UInt minThreshold); + + //---------------------------------------------------------------------- + /** + * Compute cell and segment activities using forward propagation + * and the given state variable. + * + * 2011-08-11: We will remove the CState& function if we can + * convert _infActiveStateT from a CState object to CStateIndexed + * without degrading performance. Conversion will also require us + * to move all state array modifications from Python to C++. One + * known offender is TP.py. + */ + void computeForwardPropagation(CStateIndexed &state); #if SOME_STATES_NOT_INDEXED - void computeForwardPropagation(CState& state); + void computeForwardPropagation(CState &state); #endif - //---------------------------------------------------------------------- - //---------------------------------------------------------------------- - // - // ROUTINES FOR PERFORMING INFERENCE - // - //---------------------------------------------------------------------- - //---------------------------------------------------------------------- - - //---------------------------------------------------------------------- - /** - * Update the inference state. Called from compute() on every iteration - * - * Parameters: - * =========== - * - * activeColumns: Indices of active columns - */ - void updateInferenceState(const std::vector & activeColumns); - - //---------------------------------------------------------------------- - /** - * Update the inference active state from the last set of predictions - * and the current bottom-up. - * - * Parameters: - * =========== - * - * activeColumns: Indices of active columns - * useStartCells: If true, ignore previous predictions and simply - * turn on the start cells in the active columns - * - * Return value: whether or not we are in a sequence. - * 'true' if the current input was sufficiently - * predicted, OR if we started over on startCells. - * 'false' indicates that the current input was NOT - * predicted, and we are now bursting on most columns. - * - */ - bool inferPhase1(const std::vector & activeColumns, bool useStartCells); - - //----------------------------------------------------------------------- - /** - * Phase 2 for the inference state. The computes the predicted state, - * then checks to insure that the predicted state is not over-saturated, - * i.e. look too close like a burst. This indicates that there were so - * many separate paths learned from the current input columns to the - * predicted input columns that bursting on the current input columns - * is most likely generated mix and match errors on cells in the - * predicted columns. If we detect this situation, we instead turn on - * only the start cells in the current active columns and re-generate - * the predicted state from those. - * - * Return value: 'true' if we have at least some guess as to the - * next input. 'false' indicates that we have reached - * the end of a learned sequence. - * - */ - bool inferPhase2(); - - //----------------------------------------------------------------------- - /** - * This "backtracks" our inference state, trying to see if we can lock - * onto the current set of inputs by assuming the sequence started N - * steps ago on start cells. 
For details please see documentation in - * TP.py - * - * Parameters: - * =========== - * - * activeColumns: Indices of active columns - */ - void inferBacktrack(const std::vector & activeColumns); - - //---------------------------------------------------------------------- - //---------------------------------------------------------------------- - // - // ROUTINES FOR PERFORMING LEARNING - // - //---------------------------------------------------------------------- - //---------------------------------------------------------------------- - - //---------------------------------------------------------------------- - /** - * Update the learning state. Called from compute() - * - * Parameters: - * =========== - * - * activeColumns: Indices of active columns - */ - void updateLearningState(const std::vector & activeColumns, - Real* input); - - //----------------------------------------------------------------------- - /** - * Compute the learning active state given the predicted state and - * the bottom-up input. - * - * - * Parameters: - * =========== - * - * activeColumns: Indices of active columns - * readOnly: True if being called from backtracking logic. - * This tells us not to increment any segment - * duty cycles or queue up any updates. - * - * Return value: 'true' if the current input was sufficiently - * predicted, OR if we started over on startCells. - * 'false' indicates that the current input was NOT - * predicted well enough to be considered inSequence - * - */ - bool learnPhase1(const std::vector & activeColumns, bool readOnly); - - //----------------------------------------------------------------------- - /** - * Compute the predicted segments given the current set of active cells. - * - * This computes the lrnPredictedState['t'] and queues up any segments - * that became active (and the list of active synapses for each - * segment) into the segmentUpdates queue - * - * Parameters: - * =========== - * - * readOnly: True if being called from backtracking logic. - * This tells us not to increment any segment - * duty cycles or queue up any updates. - * - */ - void learnPhase2(bool readOnly); - - //----------------------------------------------------------------------- - /** - * This "backtracks" our learning state, trying to see if we can lock - * onto the current set of inputs by assuming the sequence started - * up to N steps ago on start cells. - * - */ - UInt learnBacktrack(); - - //----------------------------------------------------------------------- - /** - * A utility method called from learnBacktrack. This will backtrack - * starting from the given startOffset in our prevLrnPatterns queue. - * - * It returns True if the backtrack was successful and we managed to get - * predictions all the way up to the current time step. - * - * If readOnly, then no segments are updated or modified, otherwise, all - * segment updates that belong to the given path are applied. - * - */ - bool learnBacktrackFrom(UInt startOffset, bool readOnly); - - //----------------------------------------------------------------------- - /** - * Update our moving average of learned sequence length. - */ - void _updateAvgLearnedSeqLength(UInt prevSeqLength); - - //---------------------------------------------------------------------- - /** - * Choose n random cells to learn from, using cells with activity in - * the state array. The passed in srcCells are excluded. 
- * - * Parameters: - * - cellIdx: the destination cell to pick sources for - * - segIdx: the destination segment to pick sources for - * - nSynToAdd: the numbers of synapses to add - * - * Return: - * - srcCells: contains the chosen source cell indices upon return - * - * NOTE: don't forget to keep cell indices sorted!!! - * TODO: make sure we don't pick a cell that's already a src for that seg - */ - void - chooseCellsToLearnFrom(UInt cellIdx, UInt segIdx, - UInt nSynToAdd, CStateIndexed& state, std::vector& srcCells); - - //---------------------------------------------------------------------- - /** - * Return the index of a cell in this column which is a good candidate - * for adding a new segment. - * - * When we have fixed size resources in effect, we insure that we pick a - * cell which does not already have the max number of allowed segments. - * If none exists, we choose the least used segment in the column to - * re-allocate. Note that this routine should never return the start - * cell (cellIdx 0) if we have more than one cell per column. - * - * Parameters: - * - colIdx: which column to look at - * - * Return: - * - cellIdx: index of the chosen cell - * - */ - UInt getCellForNewSegment(UInt colIdx); - - - //---------------------------------------------------------------------- - /** - * Insert a segmentUpdate data structure containing a list of proposed changes - * to segment segIdx. If newSynapses - * is true, then newSynapseCount - len(activeSynapses) synapses are added to - * activeSynapses. These synapses are randomly chosen from the set of cells - * that have learnState = 1 at timeStep. - * - * Return: true if a new segmentUpdate data structure was pushed onto - * the list. - * - * NOTE: called getSegmentActiveSynapses in Python - * - */ - bool computeUpdate(UInt cellIdx, UInt segIdx, CStateIndexed& activeState, - bool sequenceSegmentFlag, bool newSynapsesFlag); - - //---------------------------------------------------------------------- - /** - * Adds OutSynapses to the internal data structure that maintains OutSynapses - * for each InSynapses. This enables us to propagation activation forward, which - * is faster since activation is sparse. - * - * This is a templated method because sometimes we are called with - * std::set::const_iterator and sometimes with - * std::vector::const_iterator - */ - - template - void addOutSynapses(UInt dstCellIdx, UInt dstSegIdx, - It newSynapse, - It newSynapsesEnd); - - //---------------------------------------------------------------------- - /** - * Erases an OutSynapses. See addOutSynapses just above. - */ - void eraseOutSynapses(UInt dstCellIdx, UInt dstSegIdx, - const std::vector& srcCells); - - //---------------------------------------------------------------------- - /** - * Go through the list of accumulated segment updates and process them - * as follows: - * - * if the segment update is too old, remove the update - * - * elseif the cell received bottom-up input (activeColumns==1) update - * its permanences then positively adapt this segment - * - * elseif the cell is still being predicted, and pooling is on then leave it - * in the queue - * - * else remove it from the queue. 
- * - * Parameters: - * =========== - * - * activeColumns: array of _nColumns columns which are currently active - * predictedState: array of _nCells states representing predictions for each - * cell - * - */ - void processSegmentUpdates(Real* input, const CState& predictedState); - - //---------------------------------------------------------------------- - /** - * Removes any updates that would be applied to the given col, - * cellIdx, segIdx. - */ - void cleanUpdatesList(UInt cellIdx, UInt segIdx); - - //---------------------------------------------------------------------- - /** - * Apply age-based global decay logic and remove segments/synapses - * as appropriate. - */ - void applyGlobalDecay(); - - - //----------------------------------------------------------------------- - /** - * Private Helper function for Cells4::adaptSegment. Generates lists of - * synapses to decrement, increment, add, and remove. - * - * We break it out into a separate function to facilitate unit testing. - * - * On Entry, purges resudual data from inactiveSrcCellIdxs, - * inactiveSynapseIdxs, activeSrcCellIdxs, and activeSynapseIdxs. - * - * segment: The segment being adapted. - * - * synapsesSet: IN/OUT On entry, the union of source cell indexes - * corresponding to existing active synapses in the - * segment as well as new synapses to be created. On - * return, it's former self sans elements returned - * in activeSrcCellIdxs. The remaining elements - * correspond to new synapses to be created within - * the segment. - * - * inactiveSrcCellIdxs: OUT Source cell indexes corresponding to - * inactive synapses in the segment. Ordered by - * relative position of the corresponding InSynapses - * in the segment. The elements here correlate to - * elements in inactiveSynapseIdxs. - * - * inactiveSynapseIdxs: OUT Synapse indexes corresponding to inactive - * synapses in the segment. Sorted in - * ascending order. The elements here correlate to - * elements in inactiveSrcCellIdxs. - * - * activeSrcCellIdxs: OUT Source cell indexes corresponding to - * active synapses in the segment. Ordered by - * relative position of the corresponding InSynapses - * in the segment. The elements correlate to - * elements in activeSynapseIdxs. - * - * activeSynapseIdxs: OUT Synapse indexes corresponding to active - * synapses in the segment. In ascending order. The - * elements correlate to elements in - * activeSrcCellIdxs. - * - */ - static void _generateListsOfSynapsesToAdjustForAdaptSegment( - const Segment& segment, - std::set& synapsesSet, - std::vector& inactiveSrcCellIdxs, - std::vector& inactiveSynapseIdxs, - std::vector& activeSrcCellIdxs, - std::vector& activeSynapseIdxs); - - //----------------------------------------------------------------------- - /** - * Applies segment update information to a segment in a cell as follows: - * - * If the segment exists, synapses on the active list get their - * permanence counts incremented by permanenceInc. All other synapses - * get their permanence counts decremented by permanenceDec. If - * a synapse's permanence drops to zero, it is removed from the segment. - * If a segment does not have synapses anymore, it is removed from the - * Cell. We also increment the positiveActivations count of the segment. - * - * If the segment does not exist, it is created using the synapses in - * update. 
- * - * Parameters: - * =========== - * - * update: segmentUpdate instance - */ - void adaptSegment(const SegmentUpdate& update); - - //----------------------------------------------------------------------- - /** - * This method deletes all synapses where permanence value is strictly - * less than minPermanence. It also deletes all segments where the - * number of connected synapses is strictly less than minNumSyns+1. - * Returns the number of segments and synapses removed. - * - * Parameters: - * - * minPermanence: Any syn whose permamence is 0 or < minPermanence - * will be deleted. If 0 is passed in, then - * _permConnected is used. - * minNumSyns: Any segment with less than minNumSyns synapses - * remaining in it will be deleted. If 0 is passed - * in, then _activationThreshold is used. - * - */ - std::pair trimSegments(Real minPermanence, UInt minNumSyns); - - - //---------------------------------------------------------------------- - //---------------------------------------------------------------------- - // - // ROUTINES FOR PERSISTENCE - // - //---------------------------------------------------------------------- - //---------------------------------------------------------------------- - - /** - * TODO: compute, rather than writing to a buffer. - * TODO: move persistence to binary, faster and easier to compute expecte size. - */ - UInt persistentSize() const - { - // TODO: this won't scale! - std::stringstream tmp; - this->save(tmp); - return tmp.str().size(); - } - - //---------------------------------------------------------------------- - /** - * Write the state to a proto or file - */ - using Serializable::write; - virtual void write(Cells4Proto::Builder& proto) const override; - - //---------------------------------------------------------------------- - /** - * Read the state into a proto or file - */ - using Serializable::read; - virtual void read(Cells4Proto::Reader& proto) override; - - //---------------------------------------------------------------------- - /** - * Save the state to the given file - */ - void saveToFile(std::string filePath) const; - - //---------------------------------------------------------------------- - /** - * Load the state from the given file - */ - void loadFromFile(std::string filePath); - - //---------------------------------------------------------------------- - void save(std::ostream& outStream) const; - - //----------------------------------------------------------------------- - /** - * Need to load and re-propagate activities so that we can really persist - * at any point, load back and resume inference at exactly the same point. - */ - void load(std::istream& inStream); - - //----------------------------------------------------------------------- - void print(std::ostream& outStream) const; - - //---------------------------------------------------------------------- - //---------------------------------------------------------------------- - // - // MISC SUPPORT AND DEBUGGING ROUTINES - // - //---------------------------------------------------------------------- - //---------------------------------------------------------------------- - - // Set the Cell class segment order - void setCellSegmentOrder(bool matchPythonOrder); - - //---------------------------------------------------------------------- - /** - * Used in unit tests and debugging. 
- */ - void - addNewSegment(UInt colIdx, UInt cellIdxInCol, - bool sequenceSegmentFlag, - const std::vector >& extSynapses); - - void - updateSegment(UInt colIdx, UInt cellIdxInCol, UInt segIdx, - const std::vector >& extSynapses); - - //----------------------------------------------------------------------- - /** - * Rebalances and rebuilds internal structures for faster computing - * - */ - void _rebalance(); - void rebuildOutSynapses(); - void trimOldSegments(UInt age); - - //---------------------------------------------------------------------- - /** - * Various debugging helpers - */ - void printStates(); - void printState(UInt *state); - void dumpPrevPatterns(std::deque > &patterns); - void dumpSegmentUpdates(); - - //----------------------------------------------------------------------- - /** - * Returns list of indices of segments that are *not* empty in the free list. - */ - std::vector - getNonEmptySegList(UInt colIdx, UInt cellIdxInCol); - - - //----------------------------------------------------------------------- - /** - * Dump timing results to stdout - */ - void dumpTiming(); - - //----------------------------------------------------------------------- - // Reset all timers to 0 - //----------------------------------------------------------------------- - void resetTimers(); - - //----------------------------------------------------------------------- - // Invariants - //----------------------------------------------------------------------- - /** - * Performs a number of consistency checks. The test takes some time - * but is very helpful in development. The test is run during load/save. - * It is also run on every compute if _checkSynapseConsistency is true - */ - bool invariants(bool verbose = false) const; - - //----------------------------------------------------------------------- - // Statistics - //----------------------------------------------------------------------- - void stats() const - { - return; - } - - }; - - //----------------------------------------------------------------------- + //---------------------------------------------------------------------- + //---------------------------------------------------------------------- + // + // ROUTINES FOR PERFORMING INFERENCE + // + //---------------------------------------------------------------------- + //---------------------------------------------------------------------- + + //---------------------------------------------------------------------- + /** + * Update the inference state. Called from compute() on every iteration + * + * Parameters: + * =========== + * + * activeColumns: Indices of active columns + */ + void updateInferenceState(const std::vector &activeColumns); + + //---------------------------------------------------------------------- + /** + * Update the inference active state from the last set of predictions + * and the current bottom-up. + * + * Parameters: + * =========== + * + * activeColumns: Indices of active columns + * useStartCells: If true, ignore previous predictions and simply + * turn on the start cells in the active columns + * + * Return value: whether or not we are in a sequence. + * 'true' if the current input was sufficiently + * predicted, OR if we started over on startCells. + * 'false' indicates that the current input was NOT + * predicted, and we are now bursting on most columns. 
+ * + */ + bool inferPhase1(const std::vector &activeColumns, bool useStartCells); + + //----------------------------------------------------------------------- + /** + * Phase 2 for the inference state. The computes the predicted state, + * then checks to insure that the predicted state is not over-saturated, + * i.e. look too close like a burst. This indicates that there were so + * many separate paths learned from the current input columns to the + * predicted input columns that bursting on the current input columns + * is most likely generated mix and match errors on cells in the + * predicted columns. If we detect this situation, we instead turn on + * only the start cells in the current active columns and re-generate + * the predicted state from those. + * + * Return value: 'true' if we have at least some guess as to the + * next input. 'false' indicates that we have reached + * the end of a learned sequence. + * + */ + bool inferPhase2(); + + //----------------------------------------------------------------------- + /** + * This "backtracks" our inference state, trying to see if we can lock + * onto the current set of inputs by assuming the sequence started N + * steps ago on start cells. For details please see documentation in + * TP.py + * + * Parameters: + * =========== + * + * activeColumns: Indices of active columns + */ + void inferBacktrack(const std::vector &activeColumns); + + //---------------------------------------------------------------------- + //---------------------------------------------------------------------- + // + // ROUTINES FOR PERFORMING LEARNING + // + //---------------------------------------------------------------------- + //---------------------------------------------------------------------- + + //---------------------------------------------------------------------- + /** + * Update the learning state. Called from compute() + * + * Parameters: + * =========== + * + * activeColumns: Indices of active columns + */ + void updateLearningState(const std::vector &activeColumns, Real *input); + + //----------------------------------------------------------------------- + /** + * Compute the learning active state given the predicted state and + * the bottom-up input. + * + * + * Parameters: + * =========== + * + * activeColumns: Indices of active columns + * readOnly: True if being called from backtracking logic. + * This tells us not to increment any segment + * duty cycles or queue up any updates. + * + * Return value: 'true' if the current input was sufficiently + * predicted, OR if we started over on startCells. + * 'false' indicates that the current input was NOT + * predicted well enough to be considered inSequence + * + */ + bool learnPhase1(const std::vector &activeColumns, bool readOnly); + + //----------------------------------------------------------------------- + /** + * Compute the predicted segments given the current set of active cells. + * + * This computes the lrnPredictedState['t'] and queues up any segments + * that became active (and the list of active synapses for each + * segment) into the segmentUpdates queue + * + * Parameters: + * =========== + * + * readOnly: True if being called from backtracking logic. + * This tells us not to increment any segment + * duty cycles or queue up any updates. 
+ * + */ + void learnPhase2(bool readOnly); + + //----------------------------------------------------------------------- + /** + * This "backtracks" our learning state, trying to see if we can lock + * onto the current set of inputs by assuming the sequence started + * up to N steps ago on start cells. + * + */ + UInt learnBacktrack(); + + //----------------------------------------------------------------------- + /** + * A utility method called from learnBacktrack. This will backtrack + * starting from the given startOffset in our prevLrnPatterns queue. + * + * It returns True if the backtrack was successful and we managed to get + * predictions all the way up to the current time step. + * + * If readOnly, then no segments are updated or modified, otherwise, all + * segment updates that belong to the given path are applied. + * + */ + bool learnBacktrackFrom(UInt startOffset, bool readOnly); + + //----------------------------------------------------------------------- + /** + * Update our moving average of learned sequence length. + */ + void _updateAvgLearnedSeqLength(UInt prevSeqLength); + + //---------------------------------------------------------------------- + /** + * Choose n random cells to learn from, using cells with activity in + * the state array. The passed in srcCells are excluded. + * + * Parameters: + * - cellIdx: the destination cell to pick sources for + * - segIdx: the destination segment to pick sources for + * - nSynToAdd: the numbers of synapses to add + * + * Return: + * - srcCells: contains the chosen source cell indices upon return + * + * NOTE: don't forget to keep cell indices sorted!!! + * TODO: make sure we don't pick a cell that's already a src for that seg + */ + void chooseCellsToLearnFrom(UInt cellIdx, UInt segIdx, UInt nSynToAdd, + CStateIndexed &state, + std::vector &srcCells); + + //---------------------------------------------------------------------- + /** + * Return the index of a cell in this column which is a good candidate + * for adding a new segment. + * + * When we have fixed size resources in effect, we insure that we pick a + * cell which does not already have the max number of allowed segments. + * If none exists, we choose the least used segment in the column to + * re-allocate. Note that this routine should never return the start + * cell (cellIdx 0) if we have more than one cell per column. + * + * Parameters: + * - colIdx: which column to look at + * + * Return: + * - cellIdx: index of the chosen cell + * + */ + UInt getCellForNewSegment(UInt colIdx); + + //---------------------------------------------------------------------- + /** + * Insert a segmentUpdate data structure containing a list of proposed changes + * to segment segIdx. If newSynapses + * is true, then newSynapseCount - len(activeSynapses) synapses are added to + * activeSynapses. These synapses are randomly chosen from the set of cells + * that have learnState = 1 at timeStep. + * + * Return: true if a new segmentUpdate data structure was pushed onto + * the list. + * + * NOTE: called getSegmentActiveSynapses in Python + * + */ + bool computeUpdate(UInt cellIdx, UInt segIdx, CStateIndexed &activeState, + bool sequenceSegmentFlag, bool newSynapsesFlag); + + //---------------------------------------------------------------------- + /** + * Adds OutSynapses to the internal data structure that maintains OutSynapses + * for each InSynapses. This enables us to propagation activation forward, + * which is faster since activation is sparse. 
+ * + * This is a templated method because sometimes we are called with + * std::set::const_iterator and sometimes with + * std::vector::const_iterator + */ + + template + void addOutSynapses(UInt dstCellIdx, UInt dstSegIdx, It newSynapse, + It newSynapsesEnd); + + //---------------------------------------------------------------------- + /** + * Erases an OutSynapses. See addOutSynapses just above. + */ + void eraseOutSynapses(UInt dstCellIdx, UInt dstSegIdx, + const std::vector &srcCells); + + //---------------------------------------------------------------------- + /** + * Go through the list of accumulated segment updates and process them + * as follows: + * + * if the segment update is too old, remove the update + * + * elseif the cell received bottom-up input (activeColumns==1) update + * its permanences then positively adapt this segment + * + * elseif the cell is still being predicted, and pooling is on then leave it + * in the queue + * + * else remove it from the queue. + * + * Parameters: + * =========== + * + * activeColumns: array of _nColumns columns which are currently active + * predictedState: array of _nCells states representing predictions for each + * cell + * + */ + void processSegmentUpdates(Real *input, const CState &predictedState); + + //---------------------------------------------------------------------- + /** + * Removes any updates that would be applied to the given col, + * cellIdx, segIdx. + */ + void cleanUpdatesList(UInt cellIdx, UInt segIdx); + + //---------------------------------------------------------------------- + /** + * Apply age-based global decay logic and remove segments/synapses + * as appropriate. + */ + void applyGlobalDecay(); + + //----------------------------------------------------------------------- + /** + * Private Helper function for Cells4::adaptSegment. Generates lists of + * synapses to decrement, increment, add, and remove. + * + * We break it out into a separate function to facilitate unit testing. + * + * On Entry, purges resudual data from inactiveSrcCellIdxs, + * inactiveSynapseIdxs, activeSrcCellIdxs, and activeSynapseIdxs. + * + * segment: The segment being adapted. + * + * synapsesSet: IN/OUT On entry, the union of source cell indexes + * corresponding to existing active synapses in the + * segment as well as new synapses to be created. On + * return, it's former self sans elements returned + * in activeSrcCellIdxs. The remaining elements + * correspond to new synapses to be created within + * the segment. + * + * inactiveSrcCellIdxs: OUT Source cell indexes corresponding to + * inactive synapses in the segment. Ordered by + * relative position of the corresponding InSynapses + * in the segment. The elements here correlate to + * elements in inactiveSynapseIdxs. + * + * inactiveSynapseIdxs: OUT Synapse indexes corresponding to inactive + * synapses in the segment. Sorted in + * ascending order. The elements here correlate to + * elements in inactiveSrcCellIdxs. + * + * activeSrcCellIdxs: OUT Source cell indexes corresponding to + * active synapses in the segment. Ordered by + * relative position of the corresponding InSynapses + * in the segment. The elements correlate to + * elements in activeSynapseIdxs. + * + * activeSynapseIdxs: OUT Synapse indexes corresponding to active + * synapses in the segment. In ascending order. The + * elements correlate to elements in + * activeSrcCellIdxs. 
+ * + */ + static void _generateListsOfSynapsesToAdjustForAdaptSegment( + const Segment &segment, std::set &synapsesSet, + std::vector &inactiveSrcCellIdxs, + std::vector &inactiveSynapseIdxs, + std::vector &activeSrcCellIdxs, + std::vector &activeSynapseIdxs); + + //----------------------------------------------------------------------- + /** + * Applies segment update information to a segment in a cell as follows: + * + * If the segment exists, synapses on the active list get their + * permanence counts incremented by permanenceInc. All other synapses + * get their permanence counts decremented by permanenceDec. If + * a synapse's permanence drops to zero, it is removed from the segment. + * If a segment does not have synapses anymore, it is removed from the + * Cell. We also increment the positiveActivations count of the segment. + * + * If the segment does not exist, it is created using the synapses in + * update. + * + * Parameters: + * =========== + * + * update: segmentUpdate instance + */ + void adaptSegment(const SegmentUpdate &update); + + //----------------------------------------------------------------------- + /** + * This method deletes all synapses where permanence value is strictly + * less than minPermanence. It also deletes all segments where the + * number of connected synapses is strictly less than minNumSyns+1. + * Returns the number of segments and synapses removed. + * + * Parameters: + * + * minPermanence: Any syn whose permamence is 0 or < minPermanence + * will be deleted. If 0 is passed in, then + * _permConnected is used. + * minNumSyns: Any segment with less than minNumSyns synapses + * remaining in it will be deleted. If 0 is passed + * in, then _activationThreshold is used. + * + */ + std::pair trimSegments(Real minPermanence, UInt minNumSyns); + + //---------------------------------------------------------------------- + //---------------------------------------------------------------------- + // + // ROUTINES FOR PERSISTENCE + // + //---------------------------------------------------------------------- + //---------------------------------------------------------------------- + + /** + * TODO: compute, rather than writing to a buffer. + * TODO: move persistence to binary, faster and easier to compute expecte + * size. + */ + UInt persistentSize() const { + // TODO: this won't scale! 
+ std::stringstream tmp; + this->save(tmp); + return tmp.str().size(); + } + + //---------------------------------------------------------------------- + /** + * Write the state to a proto or file + */ + using Serializable::write; + virtual void write(Cells4Proto::Builder &proto) const override; + + //---------------------------------------------------------------------- + /** + * Read the state into a proto or file + */ + using Serializable::read; + virtual void read(Cells4Proto::Reader &proto) override; + + //---------------------------------------------------------------------- + /** + * Save the state to the given file + */ + void saveToFile(std::string filePath) const; + + //---------------------------------------------------------------------- + /** + * Load the state from the given file + */ + void loadFromFile(std::string filePath); + + //---------------------------------------------------------------------- + void save(std::ostream &outStream) const; + + //----------------------------------------------------------------------- + /** + * Need to load and re-propagate activities so that we can really persist + * at any point, load back and resume inference at exactly the same point. + */ + void load(std::istream &inStream); + + //----------------------------------------------------------------------- + void print(std::ostream &outStream) const; + + //---------------------------------------------------------------------- + //---------------------------------------------------------------------- + // + // MISC SUPPORT AND DEBUGGING ROUTINES + // + //---------------------------------------------------------------------- + //---------------------------------------------------------------------- + + // Set the Cell class segment order + void setCellSegmentOrder(bool matchPythonOrder); + + //---------------------------------------------------------------------- + /** + * Used in unit tests and debugging. + */ + void addNewSegment(UInt colIdx, UInt cellIdxInCol, bool sequenceSegmentFlag, + const std::vector> &extSynapses); + + void updateSegment(UInt colIdx, UInt cellIdxInCol, UInt segIdx, + const std::vector> &extSynapses); + + //----------------------------------------------------------------------- + /** + * Rebalances and rebuilds internal structures for faster computing + * + */ + void _rebalance(); + void rebuildOutSynapses(); + void trimOldSegments(UInt age); + + //---------------------------------------------------------------------- + /** + * Various debugging helpers + */ + void printStates(); + void printState(UInt *state); + void dumpPrevPatterns(std::deque> &patterns); + void dumpSegmentUpdates(); + + //----------------------------------------------------------------------- + /** + * Returns list of indices of segments that are *not* empty in the free list. + */ + std::vector getNonEmptySegList(UInt colIdx, UInt cellIdxInCol); + + //----------------------------------------------------------------------- + /** + * Dump timing results to stdout + */ + void dumpTiming(); + + //----------------------------------------------------------------------- + // Reset all timers to 0 + //----------------------------------------------------------------------- + void resetTimers(); + + //----------------------------------------------------------------------- + // Invariants + //----------------------------------------------------------------------- + /** + * Performs a number of consistency checks. The test takes some time + * but is very helpful in development. 
The test is run during load/save. + * It is also run on every compute if _checkSynapseConsistency is true + */ + bool invariants(bool verbose = false) const; + + //----------------------------------------------------------------------- + // Statistics + //----------------------------------------------------------------------- + void stats() const { return; } +}; + +//----------------------------------------------------------------------- #ifndef SWIG - std::ostream& operator<<(std::ostream& outStream, const Cells4& cells); +std::ostream &operator<<(std::ostream &outStream, const Cells4 &cells); #endif - //----------------------------------------------------------------------- - } // end namespace Cells4 - } // end namespace algorithms +//----------------------------------------------------------------------- +} // end namespace Cells4 +} // end namespace algorithms } // end namespace nupic - //----------------------------------------------------------------------- +//----------------------------------------------------------------------- #endif // NTA_Cells4_HPP diff --git a/src/nupic/algorithms/ClassifierResult.cpp b/src/nupic/algorithms/ClassifierResult.cpp index 8c24d6be04..1d5ac67299 100644 --- a/src/nupic/algorithms/ClassifierResult.cpp +++ b/src/nupic/algorithms/ClassifierResult.cpp @@ -30,53 +30,42 @@ using namespace std; -namespace nupic -{ - namespace algorithms - { - namespace cla_classifier - { +namespace nupic { +namespace algorithms { +namespace cla_classifier { - ClassifierResult::~ClassifierResult() - { - for (map*>::const_iterator it = result_.begin(); - it != result_.end(); ++it) - { - delete it->second; - } - } +ClassifierResult::~ClassifierResult() { + for (map *>::const_iterator it = result_.begin(); + it != result_.end(); ++it) { + delete it->second; + } +} - vector* ClassifierResult::createVector(Int step, UInt size, - Real64 value) - { - NTA_CHECK(result_.count(step) == 0) - << "The ClassifierResult cannot be reused!"; - vector* v = new vector(size, value); - result_.insert(pair*>(step, v)); - return v; - } +vector *ClassifierResult::createVector(Int step, UInt size, + Real64 value) { + NTA_CHECK(result_.count(step) == 0) + << "The ClassifierResult cannot be reused!"; + vector *v = new vector(size, value); + result_.insert(pair *>(step, v)); + return v; +} - bool ClassifierResult::operator==(const ClassifierResult& other) const - { - for (auto it = result_.begin(); it != result_.end(); it++) - { - auto thisVec = it->second; - auto otherVec = other.result_.at(it->first); - if (otherVec == nullptr || thisVec->size() != otherVec->size()) - { - return false; - } - for (UInt i = 0; i < thisVec->size(); i++) - { - if (fabs(thisVec->at(i) - otherVec->at(i)) > 0.000001) - { - return false; - } - } - } - return true; +bool ClassifierResult::operator==(const ClassifierResult &other) const { + for (auto it = result_.begin(); it != result_.end(); it++) { + auto thisVec = it->second; + auto otherVec = other.result_.at(it->first); + if (otherVec == nullptr || thisVec->size() != otherVec->size()) { + return false; + } + for (UInt i = 0; i < thisVec->size(); i++) { + if (fabs(thisVec->at(i) - otherVec->at(i)) > 0.000001) { + return false; } + } + } + return true; +} - } // end namespace cla_classifier - } // end namespace algorithms +} // end namespace cla_classifier +} // end namespace algorithms } // end namespace nupic diff --git a/src/nupic/algorithms/ClassifierResult.hpp b/src/nupic/algorithms/ClassifierResult.hpp index 28a60e6a32..880403caf4 100644 --- 
a/src/nupic/algorithms/ClassifierResult.hpp +++ b/src/nupic/algorithms/ClassifierResult.hpp @@ -30,83 +30,75 @@ using namespace std; -namespace nupic -{ - namespace algorithms - { - namespace cla_classifier - { +namespace nupic { +namespace algorithms { +namespace cla_classifier { - /** CLA classifier result class. - * - * @b Responsibility - * The ClassifierResult is responsible for storing result data and - * cleaning up the data when deleted. - * - */ - class ClassifierResult - { - public: - - /** - * Constructor. - */ - ClassifierResult() {} - - /** - * Destructor - frees memory allocated during lifespan. - */ - virtual ~ClassifierResult(); - - /** - * Creates and returns a vector for a given step. - * - * The vectors created are stored and can be accessed with the - * iterator methods. The vectors are owned by this class and are - * deleted in the destructor. - * - * @param step The prediction step to create a vector for. If -1, then - * a vector for the actual values to use for each bucket - * is returned. - * @param size The size of the desired vector. - * @param value The value to populate the vector with. - * - * @returns The specified vector. - */ - virtual vector* createVector(Int step, UInt size, Real64 value); - - /** - * Checks if the other instance has the exact same values. - * - * @param other The other instance to compare to. - * @returns True iff the other instance has the same values. - */ - virtual bool operator==(const ClassifierResult& other) const; - - /** - * Iterator method begin. - */ - virtual map*>::const_iterator begin() - { - return result_.begin(); - } - - /** - * Iterator method end. - */ - virtual map*>::const_iterator end() - { - return result_.end(); - } - - private: - - map*> result_; - - }; // end class ClassifierResult - - } // end namespace cla_classifier - } // end namespace algorithms +/** CLA classifier result class. + * + * @b Responsibility + * The ClassifierResult is responsible for storing result data and + * cleaning up the data when deleted. + * + */ +class ClassifierResult { +public: + /** + * Constructor. + */ + ClassifierResult() {} + + /** + * Destructor - frees memory allocated during lifespan. + */ + virtual ~ClassifierResult(); + + /** + * Creates and returns a vector for a given step. + * + * The vectors created are stored and can be accessed with the + * iterator methods. The vectors are owned by this class and are + * deleted in the destructor. + * + * @param step The prediction step to create a vector for. If -1, then + * a vector for the actual values to use for each bucket + * is returned. + * @param size The size of the desired vector. + * @param value The value to populate the vector with. + * + * @returns The specified vector. + */ + virtual vector *createVector(Int step, UInt size, Real64 value); + + /** + * Checks if the other instance has the exact same values. + * + * @param other The other instance to compare to. + * @returns True iff the other instance has the same values. + */ + virtual bool operator==(const ClassifierResult &other) const; + + /** + * Iterator method begin. + */ + virtual map *>::const_iterator begin() { + return result_.begin(); + } + + /** + * Iterator method end. 
+ */ + virtual map *>::const_iterator end() { + return result_.end(); + } + +private: + map *> result_; + +}; // end class ClassifierResult + +} // end namespace cla_classifier +} // end namespace algorithms } // end namespace nupic #endif // NTA_classifier_result_HPP diff --git a/src/nupic/algorithms/CondProbTable.cpp b/src/nupic/algorithms/CondProbTable.cpp index 05c77bcaec..027e24c29f 100644 --- a/src/nupic/algorithms/CondProbTable.cpp +++ b/src/nupic/algorithms/CondProbTable.cpp @@ -20,335 +20,314 @@ * --------------------------------------------------------------------- */ -/** @file - * - */ +/** @file + * + */ -#include "nupic/utils/Log.hpp" #include "nupic/algorithms/CondProbTable.hpp" +#include "nupic/utils/Log.hpp" using namespace std; namespace nupic { - - //////////////////////////////////////////////////////////////////////////// - // Constructor - ////////////////////////////////////////////////////////////////////////////// - CondProbTable::CondProbTable(const UInt hintNumCols, const UInt hintNumRows) - : hintNumCols_(hintNumCols), - hintNumRows_(hintNumRows), - tableP_(nullptr), - cleanTableP_(nullptr), - cleanTableValid_(false), - rowSums_(), - colSums_() - { - } - - //////////////////////////////////////////////////////////////////////////// - // Destructor - ////////////////////////////////////////////////////////////////////////////// - CondProbTable::~CondProbTable() - { - delete tableP_; - delete cleanTableP_; - } - - //////////////////////////////////////////////////////////////////////////// - // Get a row of the table - ////////////////////////////////////////////////////////////////////////////// - void CondProbTable::getRow(const UInt& row, vector& contents) - { - // Overwrite the contents - contents.resize(tableP_->nCols()); - tableP_->getRowToDense(row, contents.begin()); - } - //////////////////////////////////////////////////////////////////////////// - // Grow the # of rows - ////////////////////////////////////////////////////////////////////////////// - void CondProbTable::grow(const UInt& rows, const UInt& cols) - { - const char* errPrefix = "CondProbTable::grow() - "; - - // Allocate the matrix now if we haven't already - if (!tableP_) { - NTA_ASSERT(cols != 0) << errPrefix << "Must have non-zero columns"; - - if (hintNumRows_ != 0) - tableP_ = new SparseMatrix(hintNumRows_, cols); - else - tableP_ = new SparseMatrix(0,0); - - // Setup our column sums - colSums_.resize(cols, (Real)0); - } +//////////////////////////////////////////////////////////////////////////// +// Constructor +////////////////////////////////////////////////////////////////////////////// +CondProbTable::CondProbTable(const UInt hintNumCols, const UInt hintNumRows) + : hintNumCols_(hintNumCols), hintNumRows_(hintNumRows), tableP_(nullptr), + cleanTableP_(nullptr), cleanTableValid_(false), rowSums_(), colSums_() {} - UInt curRows = tableP_->nRows(); - UInt curCols = tableP_->nCols(); - UInt nextRows = max(rows, curRows); - UInt nextCols = max(cols, curCols); +//////////////////////////////////////////////////////////////////////////// +// Destructor +////////////////////////////////////////////////////////////////////////////// +CondProbTable::~CondProbTable() { + delete tableP_; + delete cleanTableP_; +} - if ((curRows < nextRows) || (curCols < nextCols)) - { - cleanTableValid_ = false; - tableP_->resize(nextRows, nextCols); +//////////////////////////////////////////////////////////////////////////// +// Get a row of the table 
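ClassifierResult, as declared above, owns every vector it creates and checks that a given step is only requested once, so callers use a fresh instance per inference call and let the destructor free the buffers. A minimal usage sketch, assuming the conventional "nupic/algorithms/ClassifierResult.hpp" include path and Real64-valued vectors (both assumptions, not taken from this diff):

#include "nupic/algorithms/ClassifierResult.hpp"
#include <iostream>
#include <vector>

using namespace nupic;
using nupic::algorithms::cla_classifier::ClassifierResult;

int main() {
  ClassifierResult result;

  // One vector per prediction step; step -1 conventionally holds the
  // actual bucket values.
  std::vector<Real64> *actuals = result.createVector(-1, 4, 0.0);
  std::vector<Real64> *oneStep = result.createVector(1, 4, 0.25);
  (*actuals)[2] = 10.5;

  // Iterate over all stored steps; the vectors are freed by ~ClassifierResult().
  for (auto it = result.begin(); it != result.end(); ++it) {
    std::cout << "step " << it->first << " size " << it->second->size()
              << std::endl;
  }

  (void)oneStep;
  return 0;
}

Note that operator== compares the stored vectors element-wise with a small absolute tolerance (1e-6), so results that differ only by rounding still compare equal.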
+////////////////////////////////////////////////////////////////////////////// +void CondProbTable::getRow(const UInt &row, vector &contents) { + // Overwrite the contents + contents.resize(tableP_->nCols()); + tableP_->getRowToDense(row, contents.begin()); +} - rowSums_.resize(nextRows); - colSums_.resize(nextCols); - } - } - - //////////////////////////////////////////////////////////////////////////// - // Update a row - //////////////////////////////////////////////////////////////////////////// - void CondProbTable::updateRow(const UInt& row, const vector& distribution) - { - //const char* errPrefix = "CondProbTable::updateRow() - "; - - // Grow the matrix if necessary - UInt cols = UInt(distribution.size()); - if (cols < hintNumCols_) - cols = hintNumCols_; - grow(row+1, cols); - - // Update the row +//////////////////////////////////////////////////////////////////////////// +// Grow the # of rows +////////////////////////////////////////////////////////////////////////////// +void CondProbTable::grow(const UInt &rows, const UInt &cols) { + const char *errPrefix = "CondProbTable::grow() - "; + + // Allocate the matrix now if we haven't already + if (!tableP_) { + NTA_ASSERT(cols != 0) << errPrefix << "Must have non-zero columns"; + + if (hintNumRows_ != 0) + tableP_ = new SparseMatrix(hintNumRows_, cols); + else + tableP_ = new SparseMatrix(0, 0); + + // Setup our column sums + colSums_.resize(cols, (Real)0); + } + + UInt curRows = tableP_->nRows(); + UInt curCols = tableP_->nCols(); + UInt nextRows = max(rows, curRows); + UInt nextCols = max(cols, curCols); + + if ((curRows < nextRows) || (curCols < nextCols)) { cleanTableValid_ = false; - tableP_->elementRowApply(row, std::plus(), distribution.begin()); - - // Update the row sums and column sums - Real rowSum = 0; - auto colSumsIter = colSums_.begin(); - CONST_LOOP(vector, iter, distribution) { - rowSum = rowSum + *iter; - *colSumsIter = *colSumsIter + *iter; - colSumsIter++; - } - rowSums_[row] += rowSum; + tableP_->resize(nextRows, nextCols); + + rowSums_.resize(nextRows); + colSums_.resize(nextCols); } +} - //////////////////////////////////////////////////////////////////////////// - // Infer, given vectors as inputs - ////////////////////////////////////////////////////////////////////////////// - void CondProbTable::inferRow(const vector& distribution, - vector& outScores, inferType infer) - { - const char* errPrefix = "CondProbTable::inferRow() - "; - - // Make sure they gave us the right source size - NTA_ASSERT(distribution.size() == tableP_->nCols()) - << errPrefix - << "input distribution vector should be " - << tableP_->nCols() << " wide"; - - // And the right output size - NTA_ASSERT(outScores.size() >= tableP_->nRows()) - << errPrefix - << "Output vector not large enough to hold all " - << tableP_->nRows() << " rows."; - - // Call the iterator version - inferRow(distribution.begin(), outScores.begin(), infer); +//////////////////////////////////////////////////////////////////////////// +// Update a row +//////////////////////////////////////////////////////////////////////////// +void CondProbTable::updateRow(const UInt &row, + const vector &distribution) { + // const char* errPrefix = "CondProbTable::updateRow() - "; + + // Grow the matrix if necessary + UInt cols = UInt(distribution.size()); + if (cols < hintNumCols_) + cols = hintNumCols_; + grow(row + 1, cols); + + // Update the row + cleanTableValid_ = false; + tableP_->elementRowApply(row, std::plus(), distribution.begin()); + + // Update the row sums and 
column sums + Real rowSum = 0; + auto colSumsIter = colSums_.begin(); + CONST_LOOP(vector, iter, distribution) { + rowSum = rowSum + *iter; + *colSumsIter = *colSumsIter + *iter; + colSumsIter++; } + rowSums_[row] += rowSum; +} + +//////////////////////////////////////////////////////////////////////////// +// Infer, given vectors as inputs +////////////////////////////////////////////////////////////////////////////// +void CondProbTable::inferRow(const vector &distribution, + vector &outScores, inferType infer) { + const char *errPrefix = "CondProbTable::inferRow() - "; + + // Make sure they gave us the right source size + NTA_ASSERT(distribution.size() == tableP_->nCols()) + << errPrefix << "input distribution vector should be " << tableP_->nCols() + << " wide"; + + // And the right output size + NTA_ASSERT(outScores.size() >= tableP_->nRows()) + << errPrefix << "Output vector not large enough to hold all " + << tableP_->nRows() << " rows."; + + // Call the iterator version + inferRow(distribution.begin(), outScores.begin(), infer); +} - //////////////////////////////////////////////////////////////////////////// - // Infer, given iterators as inputs - ////////////////////////////////////////////////////////////////////////////// - void CondProbTable::inferRow(vector::const_iterator distIter, - vector::iterator outIter, inferType infer) - { - const char* errPrefix = "CondProbTable::inferRow() - "; - - // Make sure we have a table - NTA_ASSERT(tableP_ != nullptr) +//////////////////////////////////////////////////////////////////////////// +// Infer, given iterators as inputs +////////////////////////////////////////////////////////////////////////////// +void CondProbTable::inferRow(vector::const_iterator distIter, + vector::iterator outIter, inferType infer) { + const char *errPrefix = "CondProbTable::inferRow() - "; + + // Make sure we have a table + NTA_ASSERT(tableP_ != nullptr) << errPrefix << "Must call updateRow at least once before doing inference"; - - // ---------------------------------------------------------------- - // Marginal inference - // ---------------------------------------------------------------- - if (infer == inferMarginal) { - - // Normalize by the column sums first - vector normDist; - LOOP(vector, iter, colSums_) { - normDist.push_back(*distIter / *iter); - ++distIter; - } - - tableP_->rightVecProd(normDist.begin(), outIter); - } - - // ---------------------------------------------------------------- - // Row evidence - // ---------------------------------------------------------------- - else if (infer == inferRowEvidence) { - - tableP_->rightVecProd(distIter, outIter); - - // Normalize by the row sums - LOOP(vector, iter, rowSums_) { - *outIter = *outIter / *iter; - ++outIter; - } - } - - // ---------------------------------------------------------------- - // Max product per row - // ---------------------------------------------------------------- - else if (infer == inferMaxProd) { - tableP_->vecMaxProd(distIter, outIter); + + // ---------------------------------------------------------------- + // Marginal inference + // ---------------------------------------------------------------- + if (infer == inferMarginal) { + + // Normalize by the column sums first + vector normDist; + LOOP(vector, iter, colSums_) { + normDist.push_back(*distIter / *iter); + ++distIter; } - - // ---------------------------------------------------------------- - // Viterbi, Use a "clean" CPD - // ---------------------------------------------------------------- - else if (infer 
== inferViterbi) { - - if (!cleanTableValid_) - makeCleanCPT(); - - // Do max product per row with clean CPD - cleanTableP_->vecMaxProd(distIter, outIter); - } - - // ---------------------------------------------------------------- - // Unknown inference method - // ---------------------------------------------------------------- - else - NTA_THROW << errPrefix << "Unknown inference type " << infer; + + tableP_->rightVecProd(normDist.begin(), outIter); } - //////////////////////////////////////////////////////////////////////////// - // make clean CPT - ////////////////////////////////////////////////////////////////////////////// - void CondProbTable::makeCleanCPT() - { - delete cleanTableP_; - - UInt nrows = tableP_->nRows(), ncols = tableP_->nCols(); - vector > col_max(ncols, make_pair(0, Real(0))); - - tableP_->colMax(col_max.begin()); - - cleanTableP_ = new SparseMatrix01(ncols, 1); - - for (UInt row = 0; row < nrows; ++row) { - vector nz; - for (UInt col = 0; col < ncols; ++col) - if (col_max[col].first == row) - nz.push_back(col); - cleanTableP_->addRow(UInt(nz.size()), nz.begin()); + // ---------------------------------------------------------------- + // Row evidence + // ---------------------------------------------------------------- + else if (infer == inferRowEvidence) { + + tableP_->rightVecProd(distIter, outIter); + + // Normalize by the row sums + LOOP(vector, iter, rowSums_) { + *outIter = *outIter / *iter; + ++outIter; } + } + + // ---------------------------------------------------------------- + // Max product per row + // ---------------------------------------------------------------- + else if (infer == inferMaxProd) { + tableP_->vecMaxProd(distIter, outIter); + } + + // ---------------------------------------------------------------- + // Viterbi, Use a "clean" CPD + // ---------------------------------------------------------------- + else if (infer == inferViterbi) { + + if (!cleanTableValid_) + makeCleanCPT(); + + // Do max product per row with clean CPD + cleanTableP_->vecMaxProd(distIter, outIter); + } + + // ---------------------------------------------------------------- + // Unknown inference method + // ---------------------------------------------------------------- + else + NTA_THROW << errPrefix << "Unknown inference type " << infer; +} + +//////////////////////////////////////////////////////////////////////////// +// make clean CPT +////////////////////////////////////////////////////////////////////////////// +void CondProbTable::makeCleanCPT() { + delete cleanTableP_; + + UInt nrows = tableP_->nRows(), ncols = tableP_->nCols(); + vector> col_max(ncols, make_pair(0, Real(0))); + + tableP_->colMax(col_max.begin()); + + cleanTableP_ = new SparseMatrix01(ncols, 1); + + for (UInt row = 0; row < nrows; ++row) { + vector nz; + for (UInt col = 0; col < ncols; ++col) + if (col_max[col].first == row) + nz.push_back(col); + cleanTableP_->addRow(UInt(nz.size()), nz.begin()); + } + + cleanTableValid_ = true; +} + +//////////////////////////////////////////////////////////////////////////// +// save state +////////////////////////////////////////////////////////////////////////////// +void CondProbTable::saveState(ostream &state) const { + const char *errPrefix = "CondProbTable::saveState() - "; + + NTA_CHECK(state.good()) << errPrefix << "- Bad stream"; - cleanTableValid_ = true; + state << "CondProbTable.V1 "; + + // Do we have a table yet? 
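updateRow() above grows the SparseMatrix on demand (never below the size hints) and accumulates each distribution into its row while keeping the cached row and column sums current; inferRow() then dispatches on the inferType: inferMarginal normalizes the evidence by the column sums before the right vector product, inferRowEvidence normalizes the product by the row sums, inferMaxProd takes the per-row max product, and inferViterbi runs max product against the "clean" CPT built lazily by makeCleanCPT(). A self-contained sketch of both calls, assuming Real-valued vectors and the "nupic/algorithms/CondProbTable.hpp" include path:

#include "nupic/algorithms/CondProbTable.hpp"
#include <vector>

using namespace nupic;

int main() {
  // Size hints only reserve space; updateRow() may grow past them.
  CondProbTable table(/*hintNumCols=*/3, /*hintNumRows=*/2);

  // Accumulate (possibly unnormalized) frequency distributions row by row.
  table.updateRow(0, std::vector<Real>{4, 1, 1});
  table.updateRow(1, std::vector<Real>{1, 1, 4});
  table.updateRow(1, std::vector<Real>{0, 2, 0});  // adds into row 1

  // Evidence over the columns; one output score per row.
  std::vector<Real> evidence{0.7, 0.2, 0.1};
  std::vector<Real> scores(table.numRows(), (Real)0);

  // Marginal: scores[r] ~ sum over cols of P(col|e) * P(row|col).
  table.inferRow(evidence, scores, CondProbTable::inferMarginal);

  // Viterbi: max product against the clean (one winner per column) CPT.
  table.inferRow(evidence, scores, CondProbTable::inferViterbi);

  return (table.numRows() == 2 && table.numColumns() == 3) ? 0 : 1;
}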
+ if (tableP_) { + state << "1 "; + state << tableP_->nCols() << " "; + tableP_->toCSR(state); + } else { + state << "0 "; + state << hintNumCols_ << " " << hintNumRows_; + } + + state << " "; +} + +//////////////////////////////////////////////////////////////////////////// +// read state +////////////////////////////////////////////////////////////////////////////// +void CondProbTable::readState(istream &state) { + const char *errPrefix = "CondProbTable::readState() - "; + ios::iostate excMask; + + NTA_CHECK(state.good()) << errPrefix << "- Bad stream"; + + // Turn on exceptions on the stream so we can watch for errors + excMask = state.exceptions(); + state.exceptions(ios_base::failbit | ios_base::badbit); + + // ----------------------------------------------------------------- + // Verify signature on the stream + // ----------------------------------------------------------------- + string str; + state >> str; + if (str != string("CondProbTable.V1")) { + NTA_THROW << errPrefix << "Invalid state specified"; + return; + } + + // Delete the old table + if (tableP_) { + delete tableP_; + tableP_ = nullptr; } - //////////////////////////////////////////////////////////////////////////// - // save state - ////////////////////////////////////////////////////////////////////////////// - void CondProbTable::saveState(ostream& state) const - { - const char* errPrefix = "CondProbTable::saveState() - "; - - NTA_CHECK(state.good()) << errPrefix << "- Bad stream"; - - state << "CondProbTable.V1 "; - - // Do we have a table yet? - if (tableP_) { - state << "1 "; - state << tableP_->nCols() << " "; - tableP_->toCSR(state); + cleanTableValid_ = false; + + // ----------------------------------------------------------------- + // Get # of columns then read in the old matrix + // ----------------------------------------------------------------- + try { + bool hasTable; + state >> hasTable; + if (hasTable) { + state >> hintNumCols_; + tableP_ = new SparseMatrix(0, hintNumCols_); + tableP_->fromCSR(state); } else { - state << "0 "; - state << hintNumCols_ << " " << hintNumRows_; + state >> hintNumCols_ >> hintNumRows_; } - state << " "; + } catch (exception &e) { + NTA_THROW << errPrefix << "Error reading from stream: " << e.what(); } - //////////////////////////////////////////////////////////////////////////// - // read state - ////////////////////////////////////////////////////////////////////////////// - void CondProbTable::readState(istream& state) - { - const char* errPrefix = "CondProbTable::readState() - "; - ios::iostate excMask; - - NTA_CHECK(state.good()) << errPrefix << "- Bad stream"; - - // Turn on exceptions on the stream so we can watch for errors - excMask = state.exceptions(); - state.exceptions(ios_base::failbit | ios_base::badbit); - - // ----------------------------------------------------------------- - // Verify signature on the stream - // ----------------------------------------------------------------- - string str; - state >> str; - if (str != string("CondProbTable.V1")) { - NTA_THROW << errPrefix << "Invalid state specified"; - return; - } - - // Delete the old table - if (tableP_) { - delete tableP_; - tableP_ = nullptr; - } + // ----------------------------------------------------------------- + // Init other vars if we have a table + // ----------------------------------------------------------------- + if (tableP_) { + // Update the row sums and column sums + rowSums_.resize(tableP_->nRows()); + colSums_.resize(tableP_->nCols()); - cleanTableValid_ = false; + auto rowIter 
= rowSums_.begin(); + vector row; + for (UInt r = 0; r < tableP_->nRows(); ++r, ++rowIter) { + getRow(r, row); - // ----------------------------------------------------------------- - // Get # of columns then read in the old matrix - // ----------------------------------------------------------------- - try { - bool hasTable; - state >> hasTable; - if (hasTable) { - state >> hintNumCols_; - tableP_ = new SparseMatrix (0, hintNumCols_); - tableP_->fromCSR(state); - } else { - state >> hintNumCols_ >> hintNumRows_; - } - - } catch (exception& e) { - NTA_THROW << errPrefix - << "Error reading from stream: " << e.what(); - } + // Get the row sum + Real rowSum = 0; + CONST_LOOP(vector, iter, row) { rowSum += *iter; } + *rowIter = rowSum; - // ----------------------------------------------------------------- - // Init other vars if we have a table - // ----------------------------------------------------------------- - if (tableP_) { - // Update the row sums and column sums - rowSums_.resize (tableP_->nRows()); - colSums_.resize (tableP_->nCols()); - - auto rowIter = rowSums_.begin(); - vector row; - for (UInt r=0; rnRows(); ++r, ++rowIter) { - getRow (r, row); - - // Get the row sum - Real rowSum = 0; - CONST_LOOP(vector, iter, row) { - rowSum += *iter; - } - *rowIter = rowSum; - - // Add to column sums - vector::const_iterator srcIter = row.begin(); - LOOP(vector, colIter, colSums_) { - *colIter = *colIter + *srcIter; - ++srcIter; - } + // Add to column sums + vector::const_iterator srcIter = row.begin(); + LOOP(vector, colIter, colSums_) { + *colIter = *colIter + *srcIter; + ++srcIter; } } - - // Restore exceptions mask - state.exceptions(excMask); } + + // Restore exceptions mask + state.exceptions(excMask); +} } // namespace nupic diff --git a/src/nupic/algorithms/CondProbTable.hpp b/src/nupic/algorithms/CondProbTable.hpp index d586d7925d..d1b5bfbb02 100644 --- a/src/nupic/algorithms/CondProbTable.hpp +++ b/src/nupic/algorithms/CondProbTable.hpp @@ -29,185 +29,206 @@ #include namespace nupic { - - //////////////////////////////////////////////////////////////////////////// - /// Conditional Probablity Table - /// - /// @b Responsibility - /// - Holds frequencies in a 2D grid of bins. - /// - /// @b Resources/Ownerships: - /// - none - /// - /// @b Notes: - /// Binning is not performed automatically by this class. Bin updates msut be done - /// one row at a time. This class uses nupic::SparseMatrix which is a compressed sparse row - /// matrix. Also maintains the row and column sumProp distributions. - /// - ////////////////////////////////////////////////////////////////////////////// - class CondProbTable - { - public: - typedef enum {inferViterbi, inferMarginal, inferMaxProd, inferRowEvidence} inferType; - - static inferType convertInferType(const std::string &name) - { - if(name == "0") return inferViterbi; - else if(name == "1") return inferMarginal; - else if(name == "maxProp") return inferViterbi; - else if(name == "sumProp") return inferMarginal; - else { - throw std::invalid_argument("'" + name + "' is not a valid " + +//////////////////////////////////////////////////////////////////////////// +/// Conditional Probablity Table +/// +/// @b Responsibility +/// - Holds frequencies in a 2D grid of bins. +/// +/// @b Resources/Ownerships: +/// - none +/// +/// @b Notes: +/// Binning is not performed automatically by this class. Bin updates msut be +/// done one row at a time. This class uses nupic::SparseMatrix which is a +/// compressed sparse row matrix. 
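saveState() writes a "CondProbTable.V1" signature followed by either the CSR form of the matrix or only the size hints, and readState() verifies the signature, rebuilds the matrix, and recomputes the row and column sums from the reloaded rows. A round-trip sketch through a std::stringstream, under the same include-path and Real-type assumptions as the previous sketch:

#include "nupic/algorithms/CondProbTable.hpp"
#include <sstream>
#include <vector>

using namespace nupic;

int main() {
  CondProbTable original;
  original.updateRow(0, std::vector<Real>{2, 0, 1});
  original.updateRow(1, std::vector<Real>{0, 3, 0});

  // Serialize to a string, then restore into a fresh table.
  std::stringstream state;
  original.saveState(state);

  CondProbTable restored;
  restored.readState(state);

  return (restored.numRows() == original.numRows() &&
          restored.numColumns() == original.numColumns()) ? 0 : 1;
}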
Also maintains the row and column sumProp +/// distributions. +/// +////////////////////////////////////////////////////////////////////////////// +class CondProbTable { +public: + typedef enum { + inferViterbi, + inferMarginal, + inferMaxProd, + inferRowEvidence + } inferType; + + static inferType convertInferType(const std::string &name) { + if (name == "0") + return inferViterbi; + else if (name == "1") + return inferMarginal; + else if (name == "maxProp") + return inferViterbi; + else if (name == "sumProp") + return inferMarginal; + else { + throw std::invalid_argument( + "'" + name + + "' is not a valid " "conditional probability table inference type."); - return inferViterbi; // Unused. - } + return inferViterbi; // Unused. } + } - ///////////////////////////////////////////////////////////////////////////////////// - /// Constructor - /// - /// Both the number of columns and the number of rows can be increased after - /// construction by calling updateRow(). - /// - /// @param hintNumCols Number of columns in the table. This can be increased later - /// via updateRow() but never decreased. - /// @param hintNumRows Number of rows in the table. This can be increased later - /// via updateRow() but never decreased. - /// - /////////////////////////////////////////////////////////////////////////////////// - CondProbTable(const UInt hintNumCols=0, const UInt hintNumRows=0); - - ///////////////////////////////////////////////////////////////////////////////////// - /// Destructor - /// - /////////////////////////////////////////////////////////////////////////////////// - virtual ~CondProbTable(); - - ///////////////////////////////////////////////////////////////////////////////////// - /// Return the number of rows in the table - /// - /// @retval number of rows - /////////////////////////////////////////////////////////////////////////////////// - UInt numRows (void) { - if (tableP_) - return UInt(tableP_->nRows()); - else - return hintNumRows_; - } - - ///////////////////////////////////////////////////////////////////////////////////// - /// Return the number of columns in the table. - /// - /// @retval number of rows - /////////////////////////////////////////////////////////////////////////////////// - UInt numColumns (void) { - if (tableP_) - return tableP_->nCols(); - else - return hintNumCols_; - } - - ///////////////////////////////////////////////////////////////////////////////////// - /// Update a row with the given distribution. - /// - /// @param row which row to update - /// @param distribution the distribution to update the row with - /////////////////////////////////////////////////////////////////////////////////// - void updateRow (const UInt& row, const std::vector& distribution); - - ///////////////////////////////////////////////////////////////////////////////////// - /// Return the probablity of the given distribution belonging to each row. - /// - /// Computes the sumProp probablity of each row given the input probability of - /// each column. - /// - /// The semantics are as follows: If the distribution is P(col|e) where e is - /// the evidence is col is the column, and the CPD represents P(row|col), then - /// this calculates sum(P(col|e) P(row|col)) = P(row|e). - /// - /// The available inference methods are: - /// inferMarginal - Normalizes the distribution over the columns - /// inferRowEvidence - Normalize the distribution over the rows. 
- /// inferMaxProd - Computes the max product between each element of distribution - /// and corresponding element of row. - /// inferViterbi - works on a "clean" probability table, produced by finding the - /// max element of each column, setting it to 1, and putting 0 in - /// all other elements of the column. - /// - /// @param distribution the distribution to test - length equal to # of columns - /// @param outScores the return probablity of distribution belonging to each row - - /// length equal to # of rows - /// @param method the method to use, one of either inferMarginal, inferMaxProd, - /// inferRowEvidence, or inferViterbi - /////////////////////////////////////////////////////////////////////////////////// - void inferRow (const std::vector& distribution, std::vector& outScores, - inferType infer=inferMarginal); - - ///////////////////////////////////////////////////////////////////////////////////// - /// Form of inferRow that takes iterators as input - /// - /// @param distribution the distribution to test - length equal to # of columns - /// @param outScores the return probablity of distribution belonging to each row - /// length equal to # of rows - /// @param method the method to use, one of either inferMarginal, inferMaxProd, - /// inferRowEvidence, or inferViterbi - /////////////////////////////////////////////////////////////////////////////////// - void inferRow (std::vector::const_iterator distribution, - std::vector::iterator outScores, inferType infer=inferMarginal); - - ///////////////////////////////////////////////////////////////////////////////////// - /// Get a row of the table out. - /// - /// @param row which row to get - /// @param contents the row contents are written here - /////////////////////////////////////////////////////////////////////////////////// - void getRow (const UInt& row, std::vector& contents); - - ///////////////////////////////////////////////////////////////////////////////////// - /// Get the entire table out as a sparse matrix - /// - /// @retval pointer to the table - /////////////////////////////////////////////////////////////////////////////////// - const SparseMatrix* getTable (void) const {return tableP_;} - - ///////////////////////////////////////////////////////////////////////////////////// - /// Save state to a stream - /// - /// @param state the stream to write to - /////////////////////////////////////////////////////////////////////////////////// - void saveState(std::ostream& state) const; - - ///////////////////////////////////////////////////////////////////////////////////// - /// Read state from a stream - /// - /// @param state the stream to read from - /////////////////////////////////////////////////////////////////////////////////// - void readState(std::istream& state); - - private: - ///////////////////////////////////////////////////////////////////////////////////// - /// Grow the matrix to have the given # of rows - /// - /// @rows number of rows to grow to - /// @cols number of columns to grow to - /////////////////////////////////////////////////////////////////////////////////// - void grow (const UInt& rows, const UInt& cols); - - ///////////////////////////////////////////////////////////////////////////////////// - /// Make a "clean CPT". This is a copy of the CPT table with only the max element - /// in each column kept and all others set to 0. 
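convertInferType(), introduced above, maps the legacy string names onto the enum: both "0" and "maxProp" select inferViterbi, while "1" and "sumProp" select inferMarginal, and any other name throws std::invalid_argument. A tiny sketch of that mapping:

#include "nupic/algorithms/CondProbTable.hpp"

using namespace nupic;

int main() {
  // "maxProp" deliberately selects the Viterbi (clean-CPT) path,
  // not inferMaxProd; unknown names throw std::invalid_argument.
  bool ok =
      CondProbTable::convertInferType("0") == CondProbTable::inferViterbi &&
      CondProbTable::convertInferType("1") == CondProbTable::inferMarginal &&
      CondProbTable::convertInferType("maxProp") == CondProbTable::inferViterbi &&
      CondProbTable::convertInferType("sumProp") == CondProbTable::inferMarginal;
  return ok ? 0 : 1;
}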
- /// - /////////////////////////////////////////////////////////////////////////////////// - void makeCleanCPT (void); - - UInt hintNumCols_; - UInt hintNumRows_; - SparseMatrix* tableP_; - SparseMatrix01* cleanTableP_; // for inferViterbi - bool cleanTableValid_; - std::vector rowSums_; - std::vector colSums_; - }; + ///////////////////////////////////////////////////////////////////////////////////// + /// Constructor + /// + /// Both the number of columns and the number of rows can be increased after + /// construction by calling updateRow(). + /// + /// @param hintNumCols Number of columns in the table. This can be increased + /// later + /// via updateRow() but never decreased. + /// @param hintNumRows Number of rows in the table. This can be increased + /// later + /// via updateRow() but never decreased. + /// + /////////////////////////////////////////////////////////////////////////////////// + CondProbTable(const UInt hintNumCols = 0, const UInt hintNumRows = 0); + + ///////////////////////////////////////////////////////////////////////////////////// + /// Destructor + /// + /////////////////////////////////////////////////////////////////////////////////// + virtual ~CondProbTable(); + + ///////////////////////////////////////////////////////////////////////////////////// + /// Return the number of rows in the table + /// + /// @retval number of rows + /////////////////////////////////////////////////////////////////////////////////// + UInt numRows(void) { + if (tableP_) + return UInt(tableP_->nRows()); + else + return hintNumRows_; + } + + ///////////////////////////////////////////////////////////////////////////////////// + /// Return the number of columns in the table. + /// + /// @retval number of rows + /////////////////////////////////////////////////////////////////////////////////// + UInt numColumns(void) { + if (tableP_) + return tableP_->nCols(); + else + return hintNumCols_; + } + + ///////////////////////////////////////////////////////////////////////////////////// + /// Update a row with the given distribution. + /// + /// @param row which row to update + /// @param distribution the distribution to update the row with + /////////////////////////////////////////////////////////////////////////////////// + void updateRow(const UInt &row, const std::vector &distribution); + + ///////////////////////////////////////////////////////////////////////////////////// + /// Return the probablity of the given distribution belonging to each row. + /// + /// Computes the sumProp probablity of each row given the input probability of + /// each column. + /// + /// The semantics are as follows: If the distribution is P(col|e) where e is + /// the evidence is col is the column, and the CPD represents P(row|col), then + /// this calculates sum(P(col|e) P(row|col)) = P(row|e). + /// + /// The available inference methods are: + /// inferMarginal - Normalizes the distribution over the columns + /// inferRowEvidence - Normalize the distribution over the rows. + /// inferMaxProd - Computes the max product between each element of + /// distribution + /// and corresponding element of row. + /// inferViterbi - works on a "clean" probability table, produced by finding + /// the + /// max element of each column, setting it to 1, and putting 0 + /// in all other elements of the column. 
+ /// + /// @param distribution the distribution to test - length equal to # of + /// columns + /// @param outScores the return probablity of distribution belonging to + /// each row - + /// length equal to # of rows + /// @param method the method to use, one of either inferMarginal, + /// inferMaxProd, + /// inferRowEvidence, or inferViterbi + /////////////////////////////////////////////////////////////////////////////////// + void inferRow(const std::vector &distribution, + std::vector &outScores, inferType infer = inferMarginal); + + ///////////////////////////////////////////////////////////////////////////////////// + /// Form of inferRow that takes iterators as input + /// + /// @param distribution the distribution to test - length equal to # of + /// columns + /// @param outScores the return probablity of distribution belonging to + /// each row + /// length equal to # of rows + /// @param method the method to use, one of either inferMarginal, + /// inferMaxProd, + /// inferRowEvidence, or inferViterbi + /////////////////////////////////////////////////////////////////////////////////// + void inferRow(std::vector::const_iterator distribution, + std::vector::iterator outScores, + inferType infer = inferMarginal); + + ///////////////////////////////////////////////////////////////////////////////////// + /// Get a row of the table out. + /// + /// @param row which row to get + /// @param contents the row contents are written here + /////////////////////////////////////////////////////////////////////////////////// + void getRow(const UInt &row, std::vector &contents); + + ///////////////////////////////////////////////////////////////////////////////////// + /// Get the entire table out as a sparse matrix + /// + /// @retval pointer to the table + /////////////////////////////////////////////////////////////////////////////////// + const SparseMatrix *getTable(void) const { return tableP_; } + + ///////////////////////////////////////////////////////////////////////////////////// + /// Save state to a stream + /// + /// @param state the stream to write to + /////////////////////////////////////////////////////////////////////////////////// + void saveState(std::ostream &state) const; + + ///////////////////////////////////////////////////////////////////////////////////// + /// Read state from a stream + /// + /// @param state the stream to read from + /////////////////////////////////////////////////////////////////////////////////// + void readState(std::istream &state); + +private: + ///////////////////////////////////////////////////////////////////////////////////// + /// Grow the matrix to have the given # of rows + /// + /// @rows number of rows to grow to + /// @cols number of columns to grow to + /////////////////////////////////////////////////////////////////////////////////// + void grow(const UInt &rows, const UInt &cols); + + ///////////////////////////////////////////////////////////////////////////////////// + /// Make a "clean CPT". This is a copy of the CPT table with only the max + /// element in each column kept and all others set to 0. 
+ /// + /////////////////////////////////////////////////////////////////////////////////// + void makeCleanCPT(void); + + UInt hintNumCols_; + UInt hintNumRows_; + SparseMatrix *tableP_; + SparseMatrix01 *cleanTableP_; // for inferViterbi + bool cleanTableValid_; + std::vector rowSums_; + std::vector colSums_; +}; } // namespace nupic diff --git a/src/nupic/algorithms/Connections.cpp b/src/nupic/algorithms/Connections.cpp index ff4cfb15f8..e14c60c28d 100644 --- a/src/nupic/algorithms/Connections.cpp +++ b/src/nupic/algorithms/Connections.cpp @@ -34,22 +34,17 @@ #include - -using std::vector; -using std::string; using std::endl; +using std::string; +using std::vector; using namespace nupic; using namespace nupic::algorithms::connections; static const Permanence EPSILON = 0.00001; -Connections::Connections(CellIdx numCells) -{ - initialize(numCells); -} +Connections::Connections(CellIdx numCells) { initialize(numCells); } -void Connections::initialize(CellIdx numCells) -{ +void Connections::initialize(CellIdx numCells) { cells_ = vector(numCells); // Every time a segment or synapse is created, we assign it an ordinal and @@ -61,146 +56,125 @@ void Connections::initialize(CellIdx numCells) nextEventToken_ = 0; } -UInt32 Connections::subscribe(ConnectionsEventHandler* handler) -{ +UInt32 Connections::subscribe(ConnectionsEventHandler *handler) { UInt32 token = nextEventToken_++; eventHandlers_[token] = handler; return token; } -void Connections::unsubscribe(UInt32 token) -{ +void Connections::unsubscribe(UInt32 token) { delete eventHandlers_.at(token); eventHandlers_.erase(token); } -Segment Connections::createSegment(CellIdx cell) -{ +Segment Connections::createSegment(CellIdx cell) { Segment segment; - if (destroyedSegments_.size() > 0) - { + if (destroyedSegments_.size() > 0) { segment = destroyedSegments_.back(); destroyedSegments_.pop_back(); - } - else - { + } else { segment = segments_.size(); segments_.push_back(SegmentData()); segmentOrdinals_.push_back(0); } - SegmentData& segmentData = segments_[segment]; + SegmentData &segmentData = segments_[segment]; segmentData.cell = cell; - CellData& cellData = cells_[cell]; + CellData &cellData = cells_[cell]; segmentOrdinals_[segment] = nextSegmentOrdinal_++; cellData.segments.push_back(segment); - for (auto h : eventHandlers_) - { + for (auto h : eventHandlers_) { h.second->onCreateSegment(segment); } return segment; } -Synapse Connections::createSynapse(Segment segment, - CellIdx presynapticCell, - Permanence permanence) -{ +Synapse Connections::createSynapse(Segment segment, CellIdx presynapticCell, + Permanence permanence) { NTA_CHECK(permanence > 0); Synapse synapse; - if (destroyedSynapses_.size() > 0) - { + if (destroyedSynapses_.size() > 0) { synapse = destroyedSynapses_.back(); destroyedSynapses_.pop_back(); - } - else - { + } else { synapse.flatIdx = synapses_.size(); synapses_.push_back(SynapseData()); synapseOrdinals_.push_back(0); } - SynapseData& synapseData = synapses_[synapse]; + SynapseData &synapseData = synapses_[synapse]; synapseData.segment = segment; synapseData.presynapticCell = presynapticCell; synapseData.permanence = permanence; - SegmentData& segmentData = segments_[segment]; + SegmentData &segmentData = segments_[segment]; synapseOrdinals_[synapse] = nextSynapseOrdinal_++; segmentData.synapses.push_back(synapse); synapsesForPresynapticCell_[presynapticCell].push_back(synapse); - for (auto h : eventHandlers_) - { + for (auto h : eventHandlers_) { h.second->onCreateSynapse(synapse); } return synapse; } -bool 
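createSegment() and createSynapse() above hand out lightweight integer-like handles, recycling slots from the destroyed lists before growing the flat vectors, and notify any subscribed event handlers. A minimal construction sketch, assuming the "nupic/algorithms/Connections.hpp" include path and Real32 permanences (assumptions for illustration):

#include "nupic/algorithms/Connections.hpp"
#include <iostream>

using namespace nupic;
using namespace nupic::algorithms::connections;

int main() {
  Connections connections(/*numCells=*/100);

  // Segments and synapses are identified by handles; destroyed slots are
  // reused by later createSegment()/createSynapse() calls.
  Segment segment = connections.createSegment(/*cell=*/10);
  Synapse synapse = connections.createSynapse(segment, /*presynapticCell=*/42,
                                              /*permanence=*/0.21f);

  connections.updateSynapsePermanence(synapse, 0.35f);

  std::cout << connections.numSegments() << " segments, "
            << connections.numSynapses() << " synapses" << std::endl;
  return 0;
}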
Connections::segmentExists_(Segment segment) const -{ - const SegmentData& segmentData = segments_[segment]; - const vector& segmentsOnCell = cells_[segmentData.cell].segments; - return (std::find(segmentsOnCell.begin(), segmentsOnCell.end(), segment) - != segmentsOnCell.end()); +bool Connections::segmentExists_(Segment segment) const { + const SegmentData &segmentData = segments_[segment]; + const vector &segmentsOnCell = cells_[segmentData.cell].segments; + return (std::find(segmentsOnCell.begin(), segmentsOnCell.end(), segment) != + segmentsOnCell.end()); } -bool Connections::synapseExists_(Synapse synapse) const -{ - const SynapseData& synapseData = synapses_[synapse]; - const vector& synapsesOnSegment = segments_[synapseData.segment].synapses; - return (std::find(synapsesOnSegment.begin(), synapsesOnSegment.end(), synapse) - != synapsesOnSegment.end()); +bool Connections::synapseExists_(Synapse synapse) const { + const SynapseData &synapseData = synapses_[synapse]; + const vector &synapsesOnSegment = + segments_[synapseData.segment].synapses; + return (std::find(synapsesOnSegment.begin(), synapsesOnSegment.end(), + synapse) != synapsesOnSegment.end()); } -void Connections::removeSynapseFromPresynapticMap_(Synapse synapse) -{ - const SynapseData& synapseData = synapses_[synapse]; - vector& presynapticSynapses = - synapsesForPresynapticCell_.at(synapseData.presynapticCell); +void Connections::removeSynapseFromPresynapticMap_(Synapse synapse) { + const SynapseData &synapseData = synapses_[synapse]; + vector &presynapticSynapses = + synapsesForPresynapticCell_.at(synapseData.presynapticCell); auto it = std::find(presynapticSynapses.begin(), presynapticSynapses.end(), synapse); NTA_ASSERT(it != presynapticSynapses.end()); presynapticSynapses.erase(it); - if (presynapticSynapses.size() == 0) - { + if (presynapticSynapses.size() == 0) { synapsesForPresynapticCell_.erase(synapseData.presynapticCell); } } -void Connections::destroySegment(Segment segment) -{ +void Connections::destroySegment(Segment segment) { NTA_ASSERT(segmentExists_(segment)); - for (auto h : eventHandlers_) - { + for (auto h : eventHandlers_) { h.second->onDestroySegment(segment); } - SegmentData& segmentData = segments_[segment]; - for (Synapse synapse : segmentData.synapses) - { + SegmentData &segmentData = segments_[segment]; + for (Synapse synapse : segmentData.synapses) { // Don't call destroySynapse, since it's unnecessary to do index-shifting. 
removeSynapseFromPresynapticMap_(synapse); destroyedSynapses_.push_back(synapse); } segmentData.synapses.clear(); - CellData& cellData = cells_[segmentData.cell]; + CellData &cellData = cells_[segmentData.cell]; const auto segmentOnCell = - std::lower_bound(cellData.segments.begin(), cellData.segments.end(), - segment, - [&](Segment a, Segment b) - { - return segmentOrdinals_[a] < segmentOrdinals_[b]; - }); + std::lower_bound(cellData.segments.begin(), cellData.segments.end(), + segment, [&](Segment a, Segment b) { + return segmentOrdinals_[a] < segmentOrdinals_[b]; + }); NTA_ASSERT(segmentOnCell != cellData.segments.end()); NTA_ASSERT(*segmentOnCell == segment); @@ -210,24 +184,20 @@ void Connections::destroySegment(Segment segment) destroyedSegments_.push_back(segment); } -void Connections::destroySynapse(Synapse synapse) -{ +void Connections::destroySynapse(Synapse synapse) { NTA_ASSERT(synapseExists_(synapse)); - for (auto h : eventHandlers_) - { + for (auto h : eventHandlers_) { h.second->onDestroySynapse(synapse); } removeSynapseFromPresynapticMap_(synapse); - SegmentData& segmentData = segments_[synapses_[synapse].segment]; + SegmentData &segmentData = segments_[synapses_[synapse].segment]; const auto synapseOnSegment = - std::lower_bound(segmentData.synapses.begin(), segmentData.synapses.end(), - synapse, - [&](Synapse a, Synapse b) - { - return synapseOrdinals_[a] < synapseOrdinals_[b]; - }); + std::lower_bound(segmentData.synapses.begin(), segmentData.synapses.end(), + synapse, [&](Synapse a, Synapse b) { + return synapseOrdinals_[a] < synapseOrdinals_[b]; + }); NTA_ASSERT(synapseOnSegment != segmentData.synapses.end()); NTA_ASSERT(*synapseOnSegment == synapse); @@ -238,100 +208,77 @@ void Connections::destroySynapse(Synapse synapse) } void Connections::updateSynapsePermanence(Synapse synapse, - Permanence permanence) -{ - for (auto h : eventHandlers_) - { + Permanence permanence) { + for (auto h : eventHandlers_) { h.second->onUpdateSynapsePermanence(synapse, permanence); } synapses_[synapse].permanence = permanence; } -const vector& Connections::segmentsForCell(CellIdx cell) const -{ +const vector &Connections::segmentsForCell(CellIdx cell) const { return cells_[cell].segments; } -Segment Connections::getSegment(CellIdx cell, SegmentIdx idx) const -{ +Segment Connections::getSegment(CellIdx cell, SegmentIdx idx) const { return cells_[cell].segments[idx]; } -const vector& Connections::synapsesForSegment(Segment segment) const -{ +const vector &Connections::synapsesForSegment(Segment segment) const { return segments_[segment].synapses; } -CellIdx Connections::cellForSegment(Segment segment) const -{ +CellIdx Connections::cellForSegment(Segment segment) const { return segments_[segment].cell; } -SegmentIdx Connections::idxOnCellForSegment(Segment segment) const -{ - const vector& segments = segmentsForCell(cellForSegment(segment)); +SegmentIdx Connections::idxOnCellForSegment(Segment segment) const { + const vector &segments = segmentsForCell(cellForSegment(segment)); const auto it = std::find(segments.begin(), segments.end(), segment); NTA_ASSERT(it != segments.end()); return std::distance(segments.begin(), it); } -void Connections::mapSegmentsToCells( - const Segment* segments_begin, const Segment* segments_end, - CellIdx* cells_begin) const -{ - CellIdx* out = cells_begin; +void Connections::mapSegmentsToCells(const Segment *segments_begin, + const Segment *segments_end, + CellIdx *cells_begin) const { + CellIdx *out = cells_begin; - for (auto segment = segments_begin; - 
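destroySegment() detaches all of a segment's synapses from the presynaptic map and pushes both the segment and its synapses onto the destroyed lists, which is why numSegments() and numSynapses() subtract those lists from the flat sizes. A sketch of that bookkeeping, under the same assumptions as the previous Connections sketch:

#include "nupic/algorithms/Connections.hpp"

using namespace nupic;
using namespace nupic::algorithms::connections;

int main() {
  Connections connections(50);
  Segment s0 = connections.createSegment(7);
  Segment s1 = connections.createSegment(7);
  connections.createSynapse(s0, 3, 0.5f);

  // Destroying a segment also retires its synapses, so both counters drop
  // and the per-cell segment list shrinks.
  connections.destroySegment(s0);

  bool ok = connections.numSegments() == 1 &&
            connections.numSynapses() == 0 &&
            connections.segmentsForCell(7).size() == 1 &&
            connections.cellForSegment(s1) == 7;
  return ok ? 0 : 1;
}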
segment != segments_end; - ++segment, ++out) - { + for (auto segment = segments_begin; segment != segments_end; + ++segment, ++out) { NTA_ASSERT(segmentExists_(*segment)); *out = segments_[*segment].cell; } } -Segment Connections::segmentForSynapse(Synapse synapse) const -{ +Segment Connections::segmentForSynapse(Synapse synapse) const { return synapses_[synapse].segment; } -const SegmentData& Connections::dataForSegment(Segment segment) const -{ +const SegmentData &Connections::dataForSegment(Segment segment) const { return segments_[segment]; } -const SynapseData& Connections::dataForSynapse(Synapse synapse) const -{ +const SynapseData &Connections::dataForSynapse(Synapse synapse) const { return synapses_[synapse]; } -UInt32 Connections::segmentFlatListLength() const -{ - return segments_.size(); -} +UInt32 Connections::segmentFlatListLength() const { return segments_.size(); } -bool Connections::compareSegments(Segment a, Segment b) const -{ - const SegmentData& aData = segments_[a]; - const SegmentData& bData = segments_[b]; - if (aData.cell < bData.cell) - { +bool Connections::compareSegments(Segment a, Segment b) const { + const SegmentData &aData = segments_[a]; + const SegmentData &bData = segments_[b]; + if (aData.cell < bData.cell) { return true; - } - else if (bData.cell < aData.cell) - { + } else if (bData.cell < aData.cell) { return false; - } - else - { + } else { return segmentOrdinals_[a] < segmentOrdinals_[b]; } } -vector Connections::synapsesForPresynapticCell( - CellIdx presynapticCell) const -{ +vector +Connections::synapsesForPresynapticCell(CellIdx presynapticCell) const { if (synapsesForPresynapticCell_.find(presynapticCell) == synapsesForPresynapticCell_.end()) return vector{}; @@ -339,8 +286,7 @@ vector Connections::synapsesForPresynapticCell( return synapsesForPresynapticCell_.at(presynapticCell); } -Synapse Connections::minPermanenceSynapse_(Segment segment) const -{ +Synapse Connections::minPermanenceSynapse_(Segment segment) const { // Use special EPSILON logic to compensate for floating point differences // between C++ and other environments. 
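The accessors above walk the data structure in both directions: dataForSynapse()/dataForSegment() expose the raw records, segmentForSynapse() and cellForSegment() lead from a synapse back to its owner, and synapsesForPresynapticCell() returns everything fed by one input cell (or an empty vector). A small sketch, assuming Synapse handles come back by value in a std::vector:

#include "nupic/algorithms/Connections.hpp"
#include <iostream>
#include <vector>

using namespace nupic;
using namespace nupic::algorithms::connections;

int main() {
  Connections connections(16);
  Segment segment = connections.createSegment(4);
  Synapse synapse = connections.createSynapse(segment, 2, 0.7f);

  // Walk back from the synapse handle to its owning segment and cell.
  const SynapseData &synData = connections.dataForSynapse(synapse);
  Segment owner = connections.segmentForSynapse(synapse);
  CellIdx cell = connections.cellForSegment(owner);

  std::cout << "presynaptic " << synData.presynapticCell
            << " permanence " << synData.permanence
            << " on cell " << cell << std::endl;

  // All synapses fed by presynaptic cell 2 (empty for cells with none).
  const std::vector<Synapse> found = connections.synapsesForPresynapticCell(2);
  return (found.size() == 1 && owner == segment) ? 0 : 1;
}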
@@ -348,10 +294,8 @@ Synapse Connections::minPermanenceSynapse_(Segment segment) const Permanence minPermanence = std::numeric_limits::max(); Synapse minSynapse; - for (Synapse synapse : segments_[segment].synapses) - { - if (synapses_[synapse].permanence < minPermanence - EPSILON) - { + for (Synapse synapse : segments_[segment].synapses) { + if (synapses_[synapse].permanence < minPermanence - EPSILON) { minSynapse = synapse; minPermanence = synapses_[synapse].permanence; found = true; @@ -364,25 +308,20 @@ Synapse Connections::minPermanenceSynapse_(Segment segment) const } void Connections::computeActivity( - vector& numActiveConnectedSynapsesForSegment, - vector& numActivePotentialSynapsesForSegment, - CellIdx activePresynapticCell, - Permanence connectedPermanence) const -{ + vector &numActiveConnectedSynapsesForSegment, + vector &numActivePotentialSynapsesForSegment, + CellIdx activePresynapticCell, Permanence connectedPermanence) const { NTA_ASSERT(numActiveConnectedSynapsesForSegment.size() == segments_.size()); NTA_ASSERT(numActivePotentialSynapsesForSegment.size() == segments_.size()); - if (synapsesForPresynapticCell_.count(activePresynapticCell)) - { + if (synapsesForPresynapticCell_.count(activePresynapticCell)) { for (Synapse synapse : - synapsesForPresynapticCell_.at(activePresynapticCell)) - { - const SynapseData& synapseData = synapses_[synapse]; + synapsesForPresynapticCell_.at(activePresynapticCell)) { + const SynapseData &synapseData = synapses_[synapse]; ++numActivePotentialSynapsesForSegment[synapseData.segment]; NTA_ASSERT(synapseData.permanence > 0); - if (synapseData.permanence >= connectedPermanence - EPSILON) - { + if (synapseData.permanence >= connectedPermanence - EPSILON) { ++numActiveConnectedSynapsesForSegment[synapseData.segment]; } } @@ -390,26 +329,21 @@ void Connections::computeActivity( } void Connections::computeActivity( - vector& numActiveConnectedSynapsesForSegment, - vector& numActivePotentialSynapsesForSegment, - const vector& activePresynapticCells, - Permanence connectedPermanence) const -{ + vector &numActiveConnectedSynapsesForSegment, + vector &numActivePotentialSynapsesForSegment, + const vector &activePresynapticCells, + Permanence connectedPermanence) const { NTA_ASSERT(numActiveConnectedSynapsesForSegment.size() == segments_.size()); NTA_ASSERT(numActivePotentialSynapsesForSegment.size() == segments_.size()); - for (CellIdx cell : activePresynapticCells) - { - if (synapsesForPresynapticCell_.count(cell)) - { - for (Synapse synapse : synapsesForPresynapticCell_.at(cell)) - { - const SynapseData& synapseData = synapses_[synapse]; + for (CellIdx cell : activePresynapticCells) { + if (synapsesForPresynapticCell_.count(cell)) { + for (Synapse synapse : synapsesForPresynapticCell_.at(cell)) { + const SynapseData &synapseData = synapses_[synapse]; ++numActivePotentialSynapsesForSegment[synapseData.segment]; NTA_ASSERT(synapseData.permanence > 0); - if (synapseData.permanence >= connectedPermanence - EPSILON) - { + if (synapseData.permanence >= connectedPermanence - EPSILON) { ++numActiveConnectedSynapsesForSegment[synapseData.segment]; } } @@ -417,38 +351,31 @@ void Connections::computeActivity( } } -template -static void saveFloat_(std::ostream& outStream, FloatType v) -{ +template +static void saveFloat_(std::ostream &outStream, FloatType v) { outStream << std::setprecision(std::numeric_limits::max_digits10) - << v - << " "; + << v << " "; } -void Connections::save(std::ostream& outStream) const -{ +void Connections::save(std::ostream 
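computeActivity() accumulates, per segment, how many synapses from the given active presynaptic cells are potential (any permanence > 0) and how many are connected (permanence >= connectedPermanence - EPSILON); the caller supplies both counter vectors, pre-sized to segmentFlatListLength(). A sketch, assuming UInt32 counters and CellIdx inputs (type names are assumptions, since the exact template arguments are not visible here):

#include "nupic/algorithms/Connections.hpp"
#include <vector>

using namespace nupic;
using namespace nupic::algorithms::connections;

int main() {
  Connections connections(100);
  Segment segment = connections.createSegment(0);
  connections.createSynapse(segment, 10, 0.6f);  // connected at a 0.5 threshold
  connections.createSynapse(segment, 11, 0.2f);  // potential only

  const std::vector<CellIdx> activeCells = {10, 11, 12};

  // One counter per segment flatIdx; both vectors must be pre-sized.
  std::vector<UInt32> numConnected(connections.segmentFlatListLength(), 0);
  std::vector<UInt32> numPotential(connections.segmentFlatListLength(), 0);

  connections.computeActivity(numConnected, numPotential, activeCells,
                              /*connectedPermanence=*/0.5f);

  // Expect numConnected[segment] == 1 and numPotential[segment] == 2.
  return (numConnected[segment] == 1 && numPotential[segment] == 2) ? 0 : 1;
}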
&outStream) const { // Write a starting marker. outStream << "Connections" << endl; outStream << Connections::VERSION << endl; - outStream << cells_.size() << " " - << endl; + outStream << cells_.size() << " " << endl; - for (CellData cellData : cells_) - { - const vector& segments = cellData.segments; + for (CellData cellData : cells_) { + const vector &segments = cellData.segments; outStream << segments.size() << " "; - for (Segment segment : segments) - { - const SegmentData& segmentData = segments_[segment]; + for (Segment segment : segments) { + const SegmentData &segmentData = segments_[segment]; - const vector& synapses = segmentData.synapses; + const vector &synapses = segmentData.synapses; outStream << synapses.size() << " "; - for (Synapse synapse : synapses) - { - const SynapseData& synapseData = synapses_[synapse]; + for (Synapse synapse : synapses) { + const SynapseData &synapseData = synapses_[synapse]; outStream << synapseData.presynapticCell << " "; saveFloat_(outStream, synapseData.permanence); } @@ -461,27 +388,23 @@ void Connections::save(std::ostream& outStream) const outStream << "~Connections" << endl; } -void Connections::write(ConnectionsProto::Builder& proto) const -{ +void Connections::write(ConnectionsProto::Builder &proto) const { proto.setVersion(Connections::VERSION); auto protoCells = proto.initCells(cells_.size()); - for (CellIdx i = 0; i < cells_.size(); ++i) - { - const vector& segments = cells_[i].segments; + for (CellIdx i = 0; i < cells_.size(); ++i) { + const vector &segments = cells_[i].segments; auto protoSegments = protoCells[i].initSegments(segments.size()); - for (SegmentIdx j = 0; j < (SegmentIdx)segments.size(); ++j) - { - const SegmentData& segmentData = segments_[segments[j]]; - const vector& synapses = segmentData.synapses; + for (SegmentIdx j = 0; j < (SegmentIdx)segments.size(); ++j) { + const SegmentData &segmentData = segments_[segments[j]]; + const vector &synapses = segmentData.synapses; auto protoSynapses = protoSegments[j].initSynapses(synapses.size()); - for (SynapseIdx k = 0; k < synapses.size(); ++k) - { - const SynapseData& synapseData = synapses_[synapses[k]]; + for (SynapseIdx k = 0; k < synapses.size(); ++k) { + const SynapseData &synapseData = synapses_[synapses[k]]; protoSynapses[k].setPresynapticCell(synapseData.presynapticCell); protoSynapses[k].setPermanence(synapseData.permanence); @@ -490,8 +413,7 @@ void Connections::write(ConnectionsProto::Builder& proto) const } } -void Connections::load(std::istream& inStream) -{ +void Connections::load(std::istream &inStream) { // Check the marker string marker; inStream >> marker; @@ -511,18 +433,15 @@ void Connections::load(std::istream& inStream) // This logic is complicated by the fact that old versions of the Connections // serialized "destroyed" segments and synapses, which we now ignore. 
cells_.resize(numCells); - for (UInt cell = 0; cell < numCells; cell++) - { - CellData& cellData = cells_[cell]; + for (UInt cell = 0; cell < numCells; cell++) { + CellData &cellData = cells_[cell]; UInt numSegments; inStream >> numSegments; - for (SegmentIdx j = 0; j < numSegments; j++) - { + for (SegmentIdx j = 0; j < numSegments; j++) { bool destroyedSegment = false; - if (version < 2) - { + if (version < 2) { inStream >> destroyedSegment; } @@ -531,8 +450,7 @@ void Connections::load(std::istream& inStream) SegmentData segmentData = {}; segmentData.cell = cell; - if (!destroyedSegment) - { + if (!destroyedSegment) { segment = segments_.size(); cellData.segments.push_back(segment); segments_.push_back(segmentData); @@ -543,23 +461,20 @@ void Connections::load(std::istream& inStream) UInt numSynapses; inStream >> numSynapses; - for (SynapseIdx k = 0; k < numSynapses; k++) - { + for (SynapseIdx k = 0; k < numSynapses; k++) { SynapseData synapseData = {}; inStream >> synapseData.presynapticCell; inStream >> synapseData.permanence; bool destroyedSynapse = false; - if (version < 2) - { + if (version < 2) { inStream >> destroyedSynapse; } - if (!destroyedSegment && !destroyedSynapse) - { + if (!destroyedSegment && !destroyedSynapse) { synapseData.segment = segment; - SegmentData& segmentData = segments_[segment]; + SegmentData &segmentData = segments_[segment]; Synapse synapse = {(UInt32)synapses_.size()}; segmentData.synapses.push_back(synapse); @@ -567,7 +482,7 @@ void Connections::load(std::istream& inStream) synapseOrdinals_.push_back(nextSynapseOrdinal_++); synapsesForPresynapticCell_[synapseData.presynapticCell].push_back( - synapse); + synapse); } } } @@ -577,8 +492,7 @@ void Connections::load(std::istream& inStream) NTA_CHECK(marker == "~Connections"); } -void Connections::read(ConnectionsProto::Reader& proto) -{ +void Connections::read(ConnectionsProto::Reader &proto) { // Check the saved version. 
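save() writes a plain-text form bracketed by "Connections"/"~Connections" markers with permanences at full float precision, and load() rebuilds the cells, segments, synapses, and the presynaptic index from it (skipping entries marked destroyed in pre-version-2 streams); write()/read() do the same over the capnp proto. A stringstream round-trip sketch that relies on operator== for verification (a sketch only, assuming the freshly constructed target holds no prior segments):

#include "nupic/algorithms/Connections.hpp"
#include <sstream>

using namespace nupic;
using namespace nupic::algorithms::connections;

int main() {
  Connections original(64);
  Segment segment = original.createSegment(5);
  original.createSynapse(segment, 17, 0.4f);

  // Serialize to text, then restore into another instance.
  std::stringstream stream;
  original.save(stream);

  Connections restored(1);  // cell count is replaced by load()
  restored.load(stream);

  return (restored == original) ? 0 : 1;
}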
UInt version = proto.getVersion(); NTA_CHECK(version <= Connections::VERSION); @@ -587,34 +501,29 @@ void Connections::read(ConnectionsProto::Reader& proto) initialize(protoCells.size()); - for (CellIdx cell = 0; cell < protoCells.size(); ++cell) - { - CellData& cellData = cells_[cell]; + for (CellIdx cell = 0; cell < protoCells.size(); ++cell) { + CellData &cellData = cells_[cell]; auto protoSegments = protoCells[cell].getSegments(); - for (SegmentIdx j = 0; j < (SegmentIdx)protoSegments.size(); ++j) - { + for (SegmentIdx j = 0; j < (SegmentIdx)protoSegments.size(); ++j) { Segment segment; { - const SegmentData segmentData = {vector(), - cell}; + const SegmentData segmentData = {vector(), cell}; segment = segments_.size(); cellData.segments.push_back(segment); segments_.push_back(segmentData); segmentOrdinals_.push_back(nextSegmentOrdinal_++); } - SegmentData& segmentData = segments_[segment]; + SegmentData &segmentData = segments_[segment]; auto protoSynapses = protoSegments[j].getSynapses(); - for (SynapseIdx k = 0; k < protoSynapses.size(); ++k) - { + for (SynapseIdx k = 0; k < protoSynapses.size(); ++k) { CellIdx presynapticCell = protoSynapses[k].getPresynapticCell(); SynapseData synapseData = {presynapticCell, - protoSynapses[k].getPermanence(), - segment}; + protoSynapses[k].getPermanence(), segment}; Synapse synapse = {(UInt32)synapses_.size()}; synapses_.push_back(synapseData); synapseOrdinals_.push_back(nextSynapseOrdinal_++); @@ -626,68 +535,55 @@ void Connections::read(ConnectionsProto::Reader& proto) } } -CellIdx Connections::numCells() const -{ - return cells_.size(); -} +CellIdx Connections::numCells() const { return cells_.size(); } -UInt Connections::numSegments() const -{ +UInt Connections::numSegments() const { return segments_.size() - destroyedSegments_.size(); } -UInt Connections::numSegments(CellIdx cell) const -{ +UInt Connections::numSegments(CellIdx cell) const { return cells_[cell].segments.size(); } -UInt Connections::numSynapses() const -{ +UInt Connections::numSynapses() const { return synapses_.size() - destroyedSynapses_.size(); } -UInt Connections::numSynapses(Segment segment) const -{ +UInt Connections::numSynapses(Segment segment) const { return segments_[segment].synapses.size(); } -bool Connections::operator==(const Connections &other) const -{ - if (cells_.size() != other.cells_.size()) return false; +bool Connections::operator==(const Connections &other) const { + if (cells_.size() != other.cells_.size()) + return false; - for (CellIdx i = 0; i < cells_.size(); ++i) - { - const CellData& cellData = cells_[i]; - const CellData& otherCellData = other.cells_[i]; + for (CellIdx i = 0; i < cells_.size(); ++i) { + const CellData &cellData = cells_[i]; + const CellData &otherCellData = other.cells_[i]; - if (cellData.segments.size() != otherCellData.segments.size()) - { + if (cellData.segments.size() != otherCellData.segments.size()) { return false; } - for (SegmentIdx j = 0; j < (SegmentIdx)cellData.segments.size(); ++j) - { + for (SegmentIdx j = 0; j < (SegmentIdx)cellData.segments.size(); ++j) { Segment segment = cellData.segments[j]; - const SegmentData& segmentData = segments_[segment]; + const SegmentData &segmentData = segments_[segment]; Segment otherSegment = otherCellData.segments[j]; - const SegmentData& otherSegmentData = other.segments_[otherSegment]; + const SegmentData &otherSegmentData = other.segments_[otherSegment]; if (segmentData.synapses.size() != otherSegmentData.synapses.size() || - segmentData.cell != otherSegmentData.cell) - { + 
segmentData.cell != otherSegmentData.cell) { return false; } - for (SynapseIdx k = 0; k < (SynapseIdx)segmentData.synapses.size(); ++k) - { + for (SynapseIdx k = 0; k < (SynapseIdx)segmentData.synapses.size(); ++k) { Synapse synapse = segmentData.synapses[k]; - const SynapseData& synapseData = synapses_[synapse]; + const SynapseData &synapseData = synapses_[synapse]; Synapse otherSynapse = otherSegmentData.synapses[k]; - const SynapseData& otherSynapseData = other.synapses_[otherSynapse]; + const SynapseData &otherSynapseData = other.synapses_[otherSynapse]; if (synapseData.presynapticCell != otherSynapseData.presynapticCell || - synapseData.permanence != otherSynapseData.permanence) - { + synapseData.permanence != otherSynapseData.permanence) { return false; } @@ -699,29 +595,28 @@ bool Connections::operator==(const Connections &other) const } if (synapsesForPresynapticCell_.size() != - other.synapsesForPresynapticCell_.size()) return false; + other.synapsesForPresynapticCell_.size()) + return false; for (auto i = synapsesForPresynapticCell_.begin(); - i != synapsesForPresynapticCell_.end(); ++i) - { - const vector& synapses = i->second; - const vector& otherSynapses = - other.synapsesForPresynapticCell_.at(i->first); + i != synapsesForPresynapticCell_.end(); ++i) { + const vector &synapses = i->second; + const vector &otherSynapses = + other.synapsesForPresynapticCell_.at(i->first); - if (synapses.size() != otherSynapses.size()) return false; + if (synapses.size() != otherSynapses.size()) + return false; - for (SynapseIdx j = 0; j < synapses.size(); ++j) - { + for (SynapseIdx j = 0; j < synapses.size(); ++j) { Synapse synapse = synapses[j]; - const SynapseData& synapseData = synapses_[synapse]; - const SegmentData& segmentData = segments_[synapseData.segment]; + const SynapseData &synapseData = synapses_[synapse]; + const SegmentData &segmentData = segments_[synapseData.segment]; Synapse otherSynapse = otherSynapses[j]; - const SynapseData& otherSynapseData = other.synapses_[otherSynapse]; - const SegmentData& otherSegmentData = - other.segments_[otherSynapseData.segment]; + const SynapseData &otherSynapseData = other.synapses_[otherSynapse]; + const SegmentData &otherSegmentData = + other.segments_[otherSynapseData.segment]; - if (segmentData.cell != otherSegmentData.cell) - { + if (segmentData.cell != otherSegmentData.cell) { return false; } } @@ -730,7 +625,6 @@ bool Connections::operator==(const Connections &other) const return true; } -bool Connections::operator!=(const Connections &other) const -{ +bool Connections::operator!=(const Connections &other) const { return !(*this == other); } diff --git a/src/nupic/algorithms/Connections.hpp b/src/nupic/algorithms/Connections.hpp index 7e7c8d854a..0512b4ee9a 100644 --- a/src/nupic/algorithms/Connections.hpp +++ b/src/nupic/algorithms/Connections.hpp @@ -31,561 +31,549 @@ #include #include -#include -#include #include #include +#include +#include + +namespace nupic { -namespace nupic -{ - - namespace algorithms - { - - namespace connections - { - typedef UInt32 CellIdx; - typedef UInt16 SegmentIdx; - typedef UInt16 SynapseIdx; - typedef Real32 Permanence; - typedef UInt32 Segment; - - /** - * Synapse struct used by Connections consumers. - * - * The Synapse struct is used to refer to a synapse. It contains a path to - * a SynapseData. - * - * @param flatIdx This synapse's index in flattened lists of all synapses. - */ - struct Synapse - { - UInt32 flatIdx; - - // Use Synapses as vector indices. 
- operator unsigned long() const { return flatIdx; }; - - private: - // The flatIdx ordering is not meaningful. - bool operator<=(const Synapse &other) const; - bool operator<(const Synapse &other) const; - bool operator>=(const Synapse &other) const; - bool operator>(const Synapse &other) const; - }; - - /** - * SynapseData class used in Connections. - * - * @b Description - * The SynapseData contains the underlying data for a synapse. - * - * @param presynapticCellIdx - * Cell that this synapse gets input from. - * - * @param permanence - * Permanence of synapse. - */ - struct SynapseData - { - CellIdx presynapticCell; - Permanence permanence; - Segment segment; - }; - - /** - * SegmentData class used in Connections. - * - * @b Description - * The SegmentData contains the underlying data for a Segment. - * - * @param synapses - * Synapses on this segment. - * - * @param cell - * The cell that this segment is on. - */ - struct SegmentData - { - std::vector synapses; - CellIdx cell; - }; - - /** - * CellData class used in Connections. - * - * @b Description - * The CellData contains the underlying data for a Cell. - * - * @param segments - * Segments on this cell. - * - */ - struct CellData - { - std::vector segments; - }; - - /** - * A base class for Connections event handlers. - * - * @b Description - * This acts as a plug-in point for logging / visualizations. - */ - class ConnectionsEventHandler - { - public: - virtual ~ConnectionsEventHandler() {} - - /** - * Called after a segment is created. - */ - virtual void onCreateSegment(Segment segment) {} - - /** - * Called before a segment is destroyed. - */ - virtual void onDestroySegment(Segment segment) {} - - /** - * Called after a synapse is created. - */ - virtual void onCreateSynapse(Synapse synapse) {} - - /** - * Called before a synapse is destroyed. - */ - virtual void onDestroySynapse(Synapse synapse) {} - - /** - * Called before a synapse's permanence is changed. - */ - virtual void onUpdateSynapsePermanence(Synapse synapse, - Permanence permanence) {} - }; - - /** - * Connections implementation in C++. - * - * @b Description - * The Connections class is a data structure that represents the - * connections of a collection of cells. It is used in the HTM - * learning algorithms to store and access data related to the - * connectivity of cells. - * - * Its main utility is to provide a common, optimized data structure - * that all HTM learning algorithms can use. It is flexible enough to - * support any learning algorithm that operates on a collection of cells. - * - * Each type of connection (proximal, distal basal, apical) should be - * represented by a different instantiation of this class. This class - * will help compute the activity along those connections due to active - * input cells. The responsibility for what effect that activity has on - * the cells and connections lies in the user of this class. - * - * This class is optimized to store connections between cells, and - * compute the activity of cells due to input over the connections. - * - * This class assigns each segment a unique "flatIdx" so that it's - * possible to use a simple vector to associate segments with values. - * Create a vector of length `connections.segmentFlatListLength()`, - * iterate over segments and update the vector at index `segment`. - * - */ - class Connections : public Serializable - { - public: - static const UInt16 VERSION = 2; - - /** - * Connections empty constructor. - * (Does not call `initialize`.) 
- */ - Connections() {}; - - /** - * Connections constructor. - * - * @param numCells Number of cells. - */ - Connections(CellIdx numCells); - - virtual ~Connections() {} - - /** - * Initialize connections. - * - * @param numCells Number of cells. - */ - void initialize(CellIdx numCells); - - /** - * Creates a segment on the specified cell. - * - * @param cell Cell to create segment on. - * - * @retval Created segment. - */ - Segment createSegment(CellIdx cell); - - /** - * Creates a synapse on the specified segment. - * - * @param segment Segment to create synapse on. - * @param presynapticCell Cell to synapse on. - * @param permanence Initial permanence of new synapse. - * - * @reval Created synapse. - */ - Synapse createSynapse(Segment segment, - CellIdx presynapticCell, - Permanence permanence); - - /** - * Destroys segment. - * - * @param segment Segment to destroy. - */ - void destroySegment(Segment segment); - - /** - * Destroys synapse. - * - * @param synapse Synapse to destroy. - */ - void destroySynapse(Synapse synapse); - - /** - * Updates a synapse's permanence. - * - * @param synapse Synapse to update. - * @param permanence New permanence. - */ - void updateSynapsePermanence(Synapse synapse, - Permanence permanence); - - /** - * Gets the segments for a cell. - * - * @param cell Cell to get segments for. - * - * @retval Segments on cell. - */ - const std::vector& segmentsForCell(CellIdx cell) const; - - /** - * Gets the synapses for a segment. - * - * @param segment Segment to get synapses for. - * - * @retval Synapses on segment. - */ - const std::vector& synapsesForSegment(Segment segment) const; - - /** - * Gets the cell that this segment is on. - * - * @param segment Segment to get the cell for. - * - * @retval Cell that this segment is on. - */ - CellIdx cellForSegment(Segment segment) const; - - /** - * Gets the index of this segment on its respective cell. - * - * @param segment Segment to get the idx for. - * - * @retval Index of the segment. - */ - SegmentIdx idxOnCellForSegment(Segment segment) const; - - /** - * Get the cell for each provided segment. - * - * @param segments - * The segments to query - * - * @param cells - * Output array with the same length as 'segments' - */ - void mapSegmentsToCells( - const Segment* segments_begin, const Segment* segments_end, - CellIdx* cells_begin) const; - - /** - * Gets the segment that this synapse is on. - * - * @param synapse Synapse to get Segment for. - * - * @retval Segment that this synapse is on. - */ - Segment segmentForSynapse(Synapse synapse) const; - - /** - * Gets the data for a segment. - * - * @param segment Segment to get data for. - * - * @retval Segment data. - */ - const SegmentData& dataForSegment(Segment segment) const; - - /** - * Gets the data for a synapse. - * - * @param synapse Synapse to get data for. - * - * @retval Synapse data. - */ - const SynapseData& dataForSynapse(Synapse synapse) const; - - /** - * Get the segment at the specified cell and offset. - * - * @param cell The cell that the segment is on. - * @param idx The index of the segment on the cell. - * - * @retval Segment - */ - Segment getSegment(CellIdx cell, SegmentIdx idx) const; - - /** - * Get the vector length needed to use segments as indices. - * - * @retval A vector length - */ - UInt32 segmentFlatListLength() const; - - /** - * Compare two segments. Returns true if a < b. - * - * Segments are ordered first by cell, then by their order on the cell. 
- * - * @param a Left segment to compare - * @param b Right segment to compare - * - * @retval true if a < b, false otherwise. - */ - bool compareSegments(Segment a, Segment b) const; - - /** - * Returns the synapses for the source cell that they synapse on. - * - * @param presynapticCell(int) Source cell index - * - * @return Synapse indices - */ - std::vector synapsesForPresynapticCell(CellIdx presynapticCell) - const; - - /** - * Compute the segment excitations for a vector of active presynaptic - * cells. - * - * The output vectors aren't grown or cleared. They must be - * preinitialized with the length returned by - * getSegmentFlatVectorLength(). - * - * @param numActiveConnectedSynapsesForSegment - * An output vector for active connected synapse counts per segment. - * - * @param numActivePotentialSynapsesForSegment - * An output vector for active potential synapse counts per segment. - * - * @param activePresynapticCells - * Active cells in the input. - * - * @param connectedPermanence - * Minimum permanence for a synapse to be "connected". - */ - void computeActivity( - std::vector& numActiveConnectedSynapsesForSegment, - std::vector& numActivePotentialSynapsesForSegment, - const std::vector& activePresynapticCells, - Permanence connectedPermanence) const; - - /** - * Compute the segment excitations for a single active presynaptic cell. - * - * The output vectors aren't grown or cleared. They must be - * preinitialized with the length returned by - * getSegmentFlatVectorLength(). - * - * @param numActiveConnectedSynapsesForSegment - * An output vector for active connected synapse counts per segment. - * - * @param numActivePotentialSynapsesForSegment - * An output vector for active potential synapse counts per segment. - * - * @param activePresynapticCells - * Active cells in the input. - * - * @param connectedPermanence - * Minimum permanence for a synapse to be "connected". - */ - void computeActivity( - std::vector& numActiveConnectedSynapsesForSegment, - std::vector& numActivePotentialSynapsesForSegment, - CellIdx activePresynapticCell, - Permanence connectedPermanence) const; - - // Serialization - - /** - * Saves serialized data to output stream. - */ - virtual void save(std::ostream& outStream) const; - - /** - * Writes serialized data to output stream. - */ - using Serializable::write; - - /** - * Writes serialized data to proto object. - */ - virtual void write(ConnectionsProto::Builder& proto) const override; - - /** - * Loads serialized data from input stream. - */ - virtual void load(std::istream& inStream); - - /** - * Reads serialized data from input stream. - */ - using Serializable::read; - - /** - * Reads serialized data from proto object. - */ - virtual void read(ConnectionsProto::Reader& proto) override; - - // Debugging - - /** - * Gets the number of cells. - * - * @retval Number of cells. - */ - CellIdx numCells() const; - - /** - * Gets the number of segments. - * - * @retval Number of segments. - */ - UInt numSegments() const; - - /** - * Gets the number of segments on a cell. - * - * @retval Number of segments. - */ - UInt numSegments(CellIdx cell) const; - - /** - * Gets the number of synapses. - * - * @retval Number of synapses. - */ - UInt numSynapses() const; - - /** - * Gets the number of synapses on a segment. - * - * @retval Number of synapses. - */ - UInt numSynapses(Segment segment) const; - - /** - * Comparison operator. 
- */ - bool operator==(const Connections &other) const; - bool operator!=(const Connections &other) const; - - /** - * Add a connections events handler. - * - * The Connections instance takes ownership of the eventHandlers - * object. Don't delete it. When calling from Python, call - * eventHandlers.__disown__() to avoid garbage-collecting the object - * while this instance is still using it. It will be deleted on - * `unsubscribe`. - * - * @param handler - * An object implementing the ConnectionsEventHandler interface - * - * @retval Unsubscribe token - */ - UInt32 subscribe(ConnectionsEventHandler* handler); - - /** - * Remove an event handler. - * - * @param token - * The return value of `subscribe`. - */ - void unsubscribe(UInt32 token); - - protected: - - /** - * Gets the synapse with the lowest permanence on the segment. - * - * @param segment Segment whose synapses to consider. - * - * @retval Synapse with the lowest permanence. - */ - Synapse minPermanenceSynapse_(Segment segment) const; - - /** - * Check whether this segment still exists on its cell. - * - * @param Segment - * - * @retval True if it's still in its cell's segment list. - */ - bool segmentExists_(Segment segment) const; - - /** - * Check whether this synapse still exists on its segment. - * - * @param Synapse - * - * @retval True if it's still in its segment's synapse list. - */ - bool synapseExists_(Synapse synapse) const; - - /** - * Remove a synapse from synapsesForPresynapticCell_. - * - * @param Synapse - */ - void removeSynapseFromPresynapticMap_(Synapse synapse); - - private: - std::vector cells_; - std::vector segments_; - std::vector destroyedSegments_; - std::vector synapses_; - std::vector destroyedSynapses_; - - // Extra bookkeeping for faster computing of segment activity. - std::map > synapsesForPresynapticCell_; - - std::vector segmentOrdinals_; - std::vector synapseOrdinals_; - UInt64 nextSegmentOrdinal_; - UInt64 nextSynapseOrdinal_; - - UInt32 nextEventToken_; - std::map eventHandlers_; - }; // end class Connections - - } // end namespace connections - - } // end namespace algorithms +namespace algorithms { + +namespace connections { +typedef UInt32 CellIdx; +typedef UInt16 SegmentIdx; +typedef UInt16 SynapseIdx; +typedef Real32 Permanence; +typedef UInt32 Segment; + +/** + * Synapse struct used by Connections consumers. + * + * The Synapse struct is used to refer to a synapse. It contains a path to + * a SynapseData. + * + * @param flatIdx This synapse's index in flattened lists of all synapses. + */ +struct Synapse { + UInt32 flatIdx; + + // Use Synapses as vector indices. + operator unsigned long() const { return flatIdx; }; + +private: + // The flatIdx ordering is not meaningful. + bool operator<=(const Synapse &other) const; + bool operator<(const Synapse &other) const; + bool operator>=(const Synapse &other) const; + bool operator>(const Synapse &other) const; +}; + +/** + * SynapseData class used in Connections. + * + * @b Description + * The SynapseData contains the underlying data for a synapse. + * + * @param presynapticCellIdx + * Cell that this synapse gets input from. + * + * @param permanence + * Permanence of synapse. + */ +struct SynapseData { + CellIdx presynapticCell; + Permanence permanence; + Segment segment; +}; + +/** + * SegmentData class used in Connections. + * + * @b Description + * The SegmentData contains the underlying data for a Segment. + * + * @param synapses + * Synapses on this segment. + * + * @param cell + * The cell that this segment is on. 
+ */ +struct SegmentData { + std::vector synapses; + CellIdx cell; +}; + +/** + * CellData class used in Connections. + * + * @b Description + * The CellData contains the underlying data for a Cell. + * + * @param segments + * Segments on this cell. + * + */ +struct CellData { + std::vector segments; +}; + +/** + * A base class for Connections event handlers. + * + * @b Description + * This acts as a plug-in point for logging / visualizations. + */ +class ConnectionsEventHandler { +public: + virtual ~ConnectionsEventHandler() {} + + /** + * Called after a segment is created. + */ + virtual void onCreateSegment(Segment segment) {} + + /** + * Called before a segment is destroyed. + */ + virtual void onDestroySegment(Segment segment) {} + + /** + * Called after a synapse is created. + */ + virtual void onCreateSynapse(Synapse synapse) {} + + /** + * Called before a synapse is destroyed. + */ + virtual void onDestroySynapse(Synapse synapse) {} + + /** + * Called before a synapse's permanence is changed. + */ + virtual void onUpdateSynapsePermanence(Synapse synapse, + Permanence permanence) {} +}; + +/** + * Connections implementation in C++. + * + * @b Description + * The Connections class is a data structure that represents the + * connections of a collection of cells. It is used in the HTM + * learning algorithms to store and access data related to the + * connectivity of cells. + * + * Its main utility is to provide a common, optimized data structure + * that all HTM learning algorithms can use. It is flexible enough to + * support any learning algorithm that operates on a collection of cells. + * + * Each type of connection (proximal, distal basal, apical) should be + * represented by a different instantiation of this class. This class + * will help compute the activity along those connections due to active + * input cells. The responsibility for what effect that activity has on + * the cells and connections lies in the user of this class. + * + * This class is optimized to store connections between cells, and + * compute the activity of cells due to input over the connections. + * + * This class assigns each segment a unique "flatIdx" so that it's + * possible to use a simple vector to associate segments with values. + * Create a vector of length `connections.segmentFlatListLength()`, + * iterate over segments and update the vector at index `segment`. + * + */ +class Connections : public Serializable { +public: + static const UInt16 VERSION = 2; + + /** + * Connections empty constructor. + * (Does not call `initialize`.) + */ + Connections(){}; + + /** + * Connections constructor. + * + * @param numCells Number of cells. + */ + Connections(CellIdx numCells); + + virtual ~Connections() {} + + /** + * Initialize connections. + * + * @param numCells Number of cells. + */ + void initialize(CellIdx numCells); + + /** + * Creates a segment on the specified cell. + * + * @param cell Cell to create segment on. + * + * @retval Created segment. + */ + Segment createSegment(CellIdx cell); + + /** + * Creates a synapse on the specified segment. + * + * @param segment Segment to create synapse on. + * @param presynapticCell Cell to synapse on. + * @param permanence Initial permanence of new synapse. + * + * @reval Created synapse. + */ + Synapse createSynapse(Segment segment, CellIdx presynapticCell, + Permanence permanence); + + /** + * Destroys segment. + * + * @param segment Segment to destroy. + */ + void destroySegment(Segment segment); + + /** + * Destroys synapse. 
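Editorial aside (not part of the diff): the class comment above describes the flat-index pattern — Segment is a plain UInt32 flat index, so a vector sized with segmentFlatListLength() can be indexed directly by segment. A brief sketch of that pattern; the helper name is hypothetical:

```cpp
#include <vector>

#include "nupic/algorithms/Connections.hpp"

using namespace nupic::algorithms::connections;

// Hypothetical helper: flag every flat segment index that belongs to `cell`.
std::vector<bool> segmentMaskForCell(const Connections &connections,
                                     CellIdx cell) {
  // One slot per flat segment index, as the class comment recommends.
  std::vector<bool> mask(connections.segmentFlatListLength(), false);
  for (Segment segment : connections.segmentsForCell(cell)) {
    mask[segment] = true; // Segment is itself the flat index (UInt32)
  }
  return mask;
}
```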
+ * + * @param synapse Synapse to destroy. + */ + void destroySynapse(Synapse synapse); + + /** + * Updates a synapse's permanence. + * + * @param synapse Synapse to update. + * @param permanence New permanence. + */ + void updateSynapsePermanence(Synapse synapse, Permanence permanence); + + /** + * Gets the segments for a cell. + * + * @param cell Cell to get segments for. + * + * @retval Segments on cell. + */ + const std::vector &segmentsForCell(CellIdx cell) const; + + /** + * Gets the synapses for a segment. + * + * @param segment Segment to get synapses for. + * + * @retval Synapses on segment. + */ + const std::vector &synapsesForSegment(Segment segment) const; + + /** + * Gets the cell that this segment is on. + * + * @param segment Segment to get the cell for. + * + * @retval Cell that this segment is on. + */ + CellIdx cellForSegment(Segment segment) const; + + /** + * Gets the index of this segment on its respective cell. + * + * @param segment Segment to get the idx for. + * + * @retval Index of the segment. + */ + SegmentIdx idxOnCellForSegment(Segment segment) const; + + /** + * Get the cell for each provided segment. + * + * @param segments + * The segments to query + * + * @param cells + * Output array with the same length as 'segments' + */ + void mapSegmentsToCells(const Segment *segments_begin, + const Segment *segments_end, + CellIdx *cells_begin) const; + + /** + * Gets the segment that this synapse is on. + * + * @param synapse Synapse to get Segment for. + * + * @retval Segment that this synapse is on. + */ + Segment segmentForSynapse(Synapse synapse) const; + + /** + * Gets the data for a segment. + * + * @param segment Segment to get data for. + * + * @retval Segment data. + */ + const SegmentData &dataForSegment(Segment segment) const; + + /** + * Gets the data for a synapse. + * + * @param synapse Synapse to get data for. + * + * @retval Synapse data. + */ + const SynapseData &dataForSynapse(Synapse synapse) const; + + /** + * Get the segment at the specified cell and offset. + * + * @param cell The cell that the segment is on. + * @param idx The index of the segment on the cell. + * + * @retval Segment + */ + Segment getSegment(CellIdx cell, SegmentIdx idx) const; + + /** + * Get the vector length needed to use segments as indices. + * + * @retval A vector length + */ + UInt32 segmentFlatListLength() const; + + /** + * Compare two segments. Returns true if a < b. + * + * Segments are ordered first by cell, then by their order on the cell. + * + * @param a Left segment to compare + * @param b Right segment to compare + * + * @retval true if a < b, false otherwise. + */ + bool compareSegments(Segment a, Segment b) const; + + /** + * Returns the synapses for the source cell that they synapse on. + * + * @param presynapticCell(int) Source cell index + * + * @return Synapse indices + */ + std::vector + synapsesForPresynapticCell(CellIdx presynapticCell) const; + + /** + * Compute the segment excitations for a vector of active presynaptic + * cells. + * + * The output vectors aren't grown or cleared. They must be + * preinitialized with the length returned by + * getSegmentFlatVectorLength(). + * + * @param numActiveConnectedSynapsesForSegment + * An output vector for active connected synapse counts per segment. + * + * @param numActivePotentialSynapsesForSegment + * An output vector for active potential synapse counts per segment. + * + * @param activePresynapticCells + * Active cells in the input. 
+ * + * @param connectedPermanence + * Minimum permanence for a synapse to be "connected". + */ + void + computeActivity(std::vector &numActiveConnectedSynapsesForSegment, + std::vector &numActivePotentialSynapsesForSegment, + const std::vector &activePresynapticCells, + Permanence connectedPermanence) const; + + /** + * Compute the segment excitations for a single active presynaptic cell. + * + * The output vectors aren't grown or cleared. They must be + * preinitialized with the length returned by + * getSegmentFlatVectorLength(). + * + * @param numActiveConnectedSynapsesForSegment + * An output vector for active connected synapse counts per segment. + * + * @param numActivePotentialSynapsesForSegment + * An output vector for active potential synapse counts per segment. + * + * @param activePresynapticCells + * Active cells in the input. + * + * @param connectedPermanence + * Minimum permanence for a synapse to be "connected". + */ + void + computeActivity(std::vector &numActiveConnectedSynapsesForSegment, + std::vector &numActivePotentialSynapsesForSegment, + CellIdx activePresynapticCell, + Permanence connectedPermanence) const; + + // Serialization + + /** + * Saves serialized data to output stream. + */ + virtual void save(std::ostream &outStream) const; + + /** + * Writes serialized data to output stream. + */ + using Serializable::write; + + /** + * Writes serialized data to proto object. + */ + virtual void write(ConnectionsProto::Builder &proto) const override; + + /** + * Loads serialized data from input stream. + */ + virtual void load(std::istream &inStream); + + /** + * Reads serialized data from input stream. + */ + using Serializable::read; + + /** + * Reads serialized data from proto object. + */ + virtual void read(ConnectionsProto::Reader &proto) override; + + // Debugging + + /** + * Gets the number of cells. + * + * @retval Number of cells. + */ + CellIdx numCells() const; + + /** + * Gets the number of segments. + * + * @retval Number of segments. + */ + UInt numSegments() const; + + /** + * Gets the number of segments on a cell. + * + * @retval Number of segments. + */ + UInt numSegments(CellIdx cell) const; + + /** + * Gets the number of synapses. + * + * @retval Number of synapses. + */ + UInt numSynapses() const; + + /** + * Gets the number of synapses on a segment. + * + * @retval Number of synapses. + */ + UInt numSynapses(Segment segment) const; + + /** + * Comparison operator. + */ + bool operator==(const Connections &other) const; + bool operator!=(const Connections &other) const; + + /** + * Add a connections events handler. + * + * The Connections instance takes ownership of the eventHandlers + * object. Don't delete it. When calling from Python, call + * eventHandlers.__disown__() to avoid garbage-collecting the object + * while this instance is still using it. It will be deleted on + * `unsubscribe`. + * + * @param handler + * An object implementing the ConnectionsEventHandler interface + * + * @retval Unsubscribe token + */ + UInt32 subscribe(ConnectionsEventHandler *handler); + + /** + * Remove an event handler. + * + * @param token + * The return value of `subscribe`. + */ + void unsubscribe(UInt32 token); + +protected: + /** + * Gets the synapse with the lowest permanence on the segment. + * + * @param segment Segment whose synapses to consider. + * + * @retval Synapse with the lowest permanence. + */ + Synapse minPermanenceSynapse_(Segment segment) const; + + /** + * Check whether this segment still exists on its cell. 
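Editorial aside (not part of the diff): computeActivity(), declared above, requires both output vectors to be preallocated to the flat segment list length (the comment's "getSegmentFlatVectorLength()" appears to refer to segmentFlatListLength()). A hedged usage sketch; the UInt32 element type of the output vectors is an assumption, since the template arguments are not visible in this hunk:

```cpp
#include <vector>

#include "nupic/algorithms/Connections.hpp"

using namespace nupic;
using namespace nupic::algorithms::connections;

// Hypothetical caller: compute per-segment overlap counts for one input SDR.
void segmentOverlaps(const Connections &connections,
                     const std::vector<CellIdx> &activeInput,
                     Permanence connectedPermanence) {
  // Preallocate to the flat segment list length; computeActivity() neither
  // grows nor clears its output vectors. UInt32 element type is assumed here.
  std::vector<UInt32> activeConnected(connections.segmentFlatListLength(), 0);
  std::vector<UInt32> activePotential(connections.segmentFlatListLength(), 0);

  connections.computeActivity(activeConnected, activePotential, activeInput,
                              connectedPermanence);

  // activeConnected[segment]: active synapses at or above connectedPermanence.
  // activePotential[segment]: all active synapses on that segment.
}
```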
+ * + * @param Segment + * + * @retval True if it's still in its cell's segment list. + */ + bool segmentExists_(Segment segment) const; + + /** + * Check whether this synapse still exists on its segment. + * + * @param Synapse + * + * @retval True if it's still in its segment's synapse list. + */ + bool synapseExists_(Synapse synapse) const; + + /** + * Remove a synapse from synapsesForPresynapticCell_. + * + * @param Synapse + */ + void removeSynapseFromPresynapticMap_(Synapse synapse); + +private: + std::vector cells_; + std::vector segments_; + std::vector destroyedSegments_; + std::vector synapses_; + std::vector destroyedSynapses_; + + // Extra bookkeeping for faster computing of segment activity. + std::map> synapsesForPresynapticCell_; + + std::vector segmentOrdinals_; + std::vector synapseOrdinals_; + UInt64 nextSegmentOrdinal_; + UInt64 nextSynapseOrdinal_; + + UInt32 nextEventToken_; + std::map eventHandlers_; +}; // end class Connections + +} // end namespace connections + +} // end namespace algorithms } // end namespace nupic diff --git a/src/nupic/algorithms/GaborNode.cpp b/src/nupic/algorithms/GaborNode.cpp index 35aaf73a3b..a77522b786 100644 --- a/src/nupic/algorithms/GaborNode.cpp +++ b/src/nupic/algorithms/GaborNode.cpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * This module implements efficient 2D image convolution (with gabor * filtering as the intended use case.) * It exports a single C function: @@ -37,21 +37,19 @@ * * This exported C function is expected to be used in conjunction * with ctypes wrappers around numpy array objects. - */ + */ -// Includes the correct Python.h. Must be the first header. +// Includes the correct Python.h. Must be the first header. +#include #include #include -#include - // Enable debugging //#define DEBUG 1 #include "GaborNode.hpp" - // if INIT_FROM_PYTHON is defined, this module can initialize // logging from a python system reference. 
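Editorial aside (not part of the diff): referring back to the subscribe()/unsubscribe() API at the end of Connections.hpp above — the instance takes ownership of the handler and deletes it on unsubscribe, and the ConnectionsEventHandler hooks default to empty bodies. A hypothetical logging handler, assuming createSegment() fires the onCreateSegment hook as the documentation describes:

```cpp
#include <iostream>

#include "nupic/algorithms/Connections.hpp"

using namespace nupic;
using namespace nupic::algorithms::connections;

// Hypothetical handler: log segment creation. Only the hooks of interest need
// overriding; the base class provides empty default implementations.
class LoggingHandler : public ConnectionsEventHandler {
public:
  void onCreateSegment(Segment segment) override {
    std::cout << "created segment " << segment << std::endl;
  }
};

int main() {
  Connections connections(64);
  // Connections takes ownership of the handler; it is deleted on unsubscribe.
  UInt32 token = connections.subscribe(new LoggingHandler());
  connections.createSegment(3); // expected to fire onCreateSegment
  connections.unsubscribe(token);
  return 0;
}
```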
This introduces // a dependency on PythonSystem, which is not included in the @@ -60,95 +58,91 @@ #error "gaborNode should not depend on Python in NuPIC2" #endif // INIT_FROM_PYTHON - - #ifdef __cplusplus extern "C" { -#endif - +#endif // Macros for accessing dimensionalities: -#define IMAGESET_ELEM(array, k) (((long int*)(array->pnDimensions))[k]) -#define IMAGESET_PLANES(array) IMAGESET_ELEM(array, 0) -#define IMAGESET_ROWS(array) IMAGESET_ELEM(array, 1) -#define IMAGESET_COLS(array) IMAGESET_ELEM(array, 2) -#define IMAGESET_STRIDE(array, k) (((long int*)(array->pnStrides))[k]) -#define IMAGESET_PLANESTRIDE(array) IMAGESET_STRIDE(array, 0) -#define IMAGESET_ROWSTRIDE(array) IMAGESET_STRIDE(array, 1) - -#define IMAGE_ELEM(array, k) (((long int*)(array->pnDimensions))[k]) -#define IMAGE_ROWS(array) IMAGE_ELEM(array, 0) -#define IMAGE_COLS(array) IMAGE_ELEM(array, 1) -#define IMAGE_STRIDE(array, k) (((long int*)(array->pnStrides))[k]) -#define IMAGE_ROWSTRIDE(array) IMAGE_STRIDE(array, 0) - -#define GABORSET_ELEM(array, k) (((long int*)(array->pnDimensions))[k]) -#define GABORSET_PLANES(array) GABORSET_ELEM(array, 0) -#define GABORSET_ROWS(array) GABORSET_ELEM(array, 2) -#define GABORSET_COLS(array) GABORSET_ELEM(array, 3) -#define GABORSET_STRIDE(array, k) (((long int*)(array->pnStrides))[k]) -#define GABORSET_PLANESTRIDE(array) GABORSET_STRIDE(array, 0) -#define GABORSET_SHIFTSTRIDE(array) GABORSET_STRIDE(array, 1) -#define GABORSET_ROWSTRIDE(array) GABORSET_STRIDE(array, 2) - -#define BBOX_ELEM(bbox, k) (((int*)(bbox->pData))[k]) -#define BBOX_LEFT(bbox) BBOX_ELEM(bbox, 0) -#define BBOX_TOP(bbox) BBOX_ELEM(bbox, 1) -#define BBOX_RIGHT(bbox) BBOX_ELEM(bbox, 2) -#define BBOX_BOTTOM(bbox) BBOX_ELEM(bbox, 3) -#define BBOX_WIDTH(bbox) (BBOX_RIGHT(bbox) - BBOX_LEFT(bbox)) -#define BBOX_HEIGHT(bbox) (BBOX_BOTTOM(bbox) - BBOX_TOP(bbox)) - -#define VECTOR_ELEM(array, k) (((long int*)(array->pnDimensions))[k]) -#define VECTOR_PLANES(array) VECTOR_ELEM(array, 0) +#define IMAGESET_ELEM(array, k) (((long int *)(array->pnDimensions))[k]) +#define IMAGESET_PLANES(array) IMAGESET_ELEM(array, 0) +#define IMAGESET_ROWS(array) IMAGESET_ELEM(array, 1) +#define IMAGESET_COLS(array) IMAGESET_ELEM(array, 2) +#define IMAGESET_STRIDE(array, k) (((long int *)(array->pnStrides))[k]) +#define IMAGESET_PLANESTRIDE(array) IMAGESET_STRIDE(array, 0) +#define IMAGESET_ROWSTRIDE(array) IMAGESET_STRIDE(array, 1) + +#define IMAGE_ELEM(array, k) (((long int *)(array->pnDimensions))[k]) +#define IMAGE_ROWS(array) IMAGE_ELEM(array, 0) +#define IMAGE_COLS(array) IMAGE_ELEM(array, 1) +#define IMAGE_STRIDE(array, k) (((long int *)(array->pnStrides))[k]) +#define IMAGE_ROWSTRIDE(array) IMAGE_STRIDE(array, 0) + +#define GABORSET_ELEM(array, k) (((long int *)(array->pnDimensions))[k]) +#define GABORSET_PLANES(array) GABORSET_ELEM(array, 0) +#define GABORSET_ROWS(array) GABORSET_ELEM(array, 2) +#define GABORSET_COLS(array) GABORSET_ELEM(array, 3) +#define GABORSET_STRIDE(array, k) (((long int *)(array->pnStrides))[k]) +#define GABORSET_PLANESTRIDE(array) GABORSET_STRIDE(array, 0) +#define GABORSET_SHIFTSTRIDE(array) GABORSET_STRIDE(array, 1) +#define GABORSET_ROWSTRIDE(array) GABORSET_STRIDE(array, 2) + +#define BBOX_ELEM(bbox, k) (((int *)(bbox->pData))[k]) +#define BBOX_LEFT(bbox) BBOX_ELEM(bbox, 0) +#define BBOX_TOP(bbox) BBOX_ELEM(bbox, 1) +#define BBOX_RIGHT(bbox) BBOX_ELEM(bbox, 2) +#define BBOX_BOTTOM(bbox) BBOX_ELEM(bbox, 3) +#define BBOX_WIDTH(bbox) (BBOX_RIGHT(bbox) - BBOX_LEFT(bbox)) +#define BBOX_HEIGHT(bbox) 
(BBOX_BOTTOM(bbox) - BBOX_TOP(bbox)) + +#define VECTOR_ELEM(array, k) (((long int *)(array->pnDimensions))[k]) +#define VECTOR_PLANES(array) VECTOR_ELEM(array, 0) // Macros for clipping //#define MIN(x, y) ((x) <= (y) ? (x) : (y)) //#define MAX(x, y) ((x) <= (y) ? (y) : (x)) // Macro for fast integer abs() -#define IABS32(x) (((x) ^ ((x) >> 31)) - ((x) >> 31)) -#define IABS64(x) (((x) ^ ((x) >> 63)) - ((x) >> 63)) +#define IABS32(x) (((x) ^ ((x) >> 31)) - ((x) >> 31)) +#define IABS64(x) (((x) ^ ((x) >> 63)) - ((x) >> 63)) // Macros for aligning integers to even values -#define ALIGN_2_FLOOR(value) (((value)>>1)<<1) -#define ALIGN_2_CEIL(value) ALIGN_2_FLOOR((value)+1) -#define ALIGN_4_FLOOR(value) (((value)>>2)<<2) -#define ALIGN_4_CEIL(value) ALIGN_4_FLOOR(((value)+3)) -#define ALIGN_8_FLOOR(value) (((value)>>3)<<3) -#define ALIGN_8_CEIL(value) ALIGN_8_FLOOR((value)+7) - +#define ALIGN_2_FLOOR(value) (((value) >> 1) << 1) +#define ALIGN_2_CEIL(value) ALIGN_2_FLOOR((value) + 1) +#define ALIGN_4_FLOOR(value) (((value) >> 2) << 2) +#define ALIGN_4_CEIL(value) ALIGN_4_FLOOR(((value) + 3)) +#define ALIGN_8_FLOOR(value) (((value) >> 3) << 3) +#define ALIGN_8_CEIL(value) ALIGN_8_FLOOR((value) + 7) // FUNCTION: _prepareInput_sweepOff() // // 1. Convert input image from float to integer32. // 2. If EDGE_MODE is SWEEPOFF, then add "padding pixels" // around the edges of the integrized input plane. -void _prepareInput_sweepOff(const NUMPY_ARRAY * psInput, - const NUMPY_ARRAY * psBufferIn, - int nHalfFilterDim, - const NUMPY_ARRAY * psBBox, - const NUMPY_ARRAY * psImageBox, - float fOffImageFillValue) { +void _prepareInput_sweepOff(const NUMPY_ARRAY *psInput, + const NUMPY_ARRAY *psBufferIn, int nHalfFilterDim, + const NUMPY_ARRAY *psBBox, + const NUMPY_ARRAY *psImageBox, + float fOffImageFillValue) { int i, j; int nFilterDim = nHalfFilterDim << 1; // Locate the start of input plane - const float * pfInput = (const float *)psInput->pData; + const float *pfInput = (const float *)psInput->pData; // Compute stride needed to proceed to next input row (in DWORDS) int nInputRowStride = IMAGE_ROWSTRIDE(psInput) / sizeof(*pfInput); // Locate start of output buffers // Note: the 'psBuffer' numpy array is assumed to be format 'int32' - int * pnOutput = (int *)psBufferIn->pData; + int *pnOutput = (int *)psBufferIn->pData; // Compute stride needed to proceed to next output row (in DWORDS) int nOutputRowStride = IMAGE_ROWSTRIDE(psBufferIn) / sizeof(*pnOutput); // Guard against buffer over-runs #ifdef DEBUG // Start/end of memory - const char * pDebugOutputSOMB = (const char*)(psBufferIn->pData); - const char * pDebugOutputEOMB = pDebugOutputSOMB + IMAGE_ROWSTRIDE(psBufferIn) * IMAGE_ROWS(psBufferIn); + const char *pDebugOutputSOMB = (const char *)(psBufferIn->pData); + const char *pDebugOutputEOMB = + pDebugOutputSOMB + IMAGE_ROWSTRIDE(psBufferIn) * IMAGE_ROWS(psBufferIn); #endif // DEBUG // Both the bounding box and the filler box will be expressed @@ -158,27 +152,31 @@ void _prepareInput_sweepOff(const NUMPY_ARRAY * psInput, // Our convention is that the bounding box expresses the range // of locations in the input image which are valid. // So we only need to convert pixel values within the bounding box - // plus a narrow band around the outside of the bounding box of + // plus a narrow band around the outside of the bounding box of // width equal to half the width of the filter dimension. 
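Editorial aside (not part of the diff): the bit-twiddling macros above are branch-free — IABS32 relies on sign propagation (x >> 31 is all ones for negative x on a two's-complement 32-bit int), and the ALIGN_* macros clear low bits to round to a 2/4/8 boundary. A quick self-contained check; the macro bodies are copied from the definitions above so the snippet compiles on its own:

```cpp
#include <cassert>

// Copied from the macro definitions above, for a self-contained check.
#define IABS32(x) (((x) ^ ((x) >> 31)) - ((x) >> 31))
#define ALIGN_4_FLOOR(value) (((value) >> 2) << 2)
#define ALIGN_4_CEIL(value) ALIGN_4_FLOOR(((value) + 3))

int main() {
  // IABS32: branch-free |x| via sign propagation.
  assert(IABS32(-7) == 7);
  assert(IABS32(7) == 7);
  // ALIGN_4_FLOOR/CEIL: round down/up to a multiple of four.
  assert(ALIGN_4_FLOOR(10) == 8);
  assert(ALIGN_4_CEIL(10) == 12);
  assert(ALIGN_4_CEIL(12) == 12); // already-aligned values are unchanged
  return 0;
}
```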
// // We'll also expand our bounding box horizontally to make it line // up on a 16-byte boundary (four-pixel boundary) // We need to provide fill up to nFill* - int nFillLeft = BBOX_LEFT(psBBox); - int nFillTop = BBOX_TOP(psBBox); - int nFillRight = BBOX_RIGHT(psBBox) + nFilterDim; + int nFillLeft = BBOX_LEFT(psBBox); + int nFillTop = BBOX_TOP(psBBox); + int nFillRight = BBOX_RIGHT(psBBox) + nFilterDim; int nFillBottom = BBOX_BOTTOM(psBBox) + nFilterDim; // Shrink the pixel boxes to where we have actual pixels int nPixelLeft = MAX(nFillLeft, nHalfFilterDim); - int nPixelTop = MAX(nFillTop, nHalfFilterDim); - //int nPixelRight = MIN(nFillRight, BBOX_RIGHT(psBBox) + nHalfFilterDim); - //int nPixelBottom = MIN(nFillBottom, BBOX_BOTTOM(psBBox) + nHalfFilterDim); - //int nPixelRight = MIN(nFillRight, BBOX_RIGHT(psBBox) + nHalfFilterDim << 1); - //int nPixelBottom = MIN(nFillBottom, BBOX_BOTTOM(psBBox) + nHalfFilterDim << 1); - int nPixelRight = MIN(BBOX_RIGHT(psBBox) + nFilterDim, BBOX_RIGHT(psImageBox) + nHalfFilterDim); - int nPixelBottom = MIN(BBOX_BOTTOM(psBBox) + nFilterDim, BBOX_BOTTOM(psImageBox) + nHalfFilterDim); + int nPixelTop = MAX(nFillTop, nHalfFilterDim); + // int nPixelRight = MIN(nFillRight, BBOX_RIGHT(psBBox) + nHalfFilterDim); + // int nPixelBottom = MIN(nFillBottom, BBOX_BOTTOM(psBBox) + nHalfFilterDim); + // int nPixelRight = MIN(nFillRight, BBOX_RIGHT(psBBox) + nHalfFilterDim << + // 1); int nPixelBottom = MIN(nFillBottom, BBOX_BOTTOM(psBBox) + + // nHalfFilterDim + // << 1); + int nPixelRight = MIN(BBOX_RIGHT(psBBox) + nFilterDim, + BBOX_RIGHT(psImageBox) + nHalfFilterDim); + int nPixelBottom = MIN(BBOX_BOTTOM(psBBox) + nFilterDim, + BBOX_BOTTOM(psImageBox) + nHalfFilterDim); // If all of our assumptions have been met, then the following // conditions should hold (otherwise there is a bug somewhere): @@ -196,7 +194,7 @@ void _prepareInput_sweepOff(const NUMPY_ARRAY * psInput, // in our bounding boxes (like in 'constrained' mode); // but that is starting to get pretty complicated... - // Advance our output pointer to the beginning of the + // Advance our output pointer to the beginning of the // fill region. // Note: in a bid to get 16-byte alignment most of the time, // we are actually filling our "pure" fill rows at @@ -204,21 +202,21 @@ void _prepareInput_sweepOff(const NUMPY_ARRAY * psInput, // the right. pnOutput += nFillTop * nOutputRowStride; - // Compute numer of pure fill chunks of four + // Compute numer of pure fill chunks of four int nPureFillQuads = nOutputRowStride >> 2; int nOffImageFillValue = (int)fOffImageFillValue; // We need to pad the true input image with filler pixels. 
// We'll do the top rows of filler now: - for (j=nPixelTop - nFillTop; j; j--) { - // Fill each row - for (i=nPureFillQuads; i; i--) { + for (j = nPixelTop - nFillTop; j; j--) { + // Fill each row + for (i = nPureFillQuads; i; i--) { // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutput >= pDebugOutputSOMB); - NTA_ASSERT((const char*)&(pnOutput[3]) < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutput >= pDebugOutputSOMB); + NTA_ASSERT((const char *)&(pnOutput[3]) < pDebugOutputEOMB); #endif // DEBUG *pnOutput++ = nOffImageFillValue; @@ -232,14 +230,15 @@ void _prepareInput_sweepOff(const NUMPY_ARRAY * psInput, // advance to get from the end of row N to the beginning // of row N+1 int nPixelWidth = nPixelRight - nPixelLeft; - int nInputRowAdvance = nInputRowStride - nPixelWidth; + int nInputRowAdvance = nInputRowStride - nPixelWidth; int nOutputRowAdvance = nOutputRowStride - (nFillRight - nFillLeft); // Advance our pointers and row counts to skip past // the rows on top of the bounding box, and align // our pointers with the left edge of the bounding box. - pfInput += nInputRowStride * (nPixelTop - nHalfFilterDim) + (nPixelLeft - nHalfFilterDim); - //pnOutput += nOutputRowStride * nPixelTop + nPixelLeft; + pfInput += nInputRowStride * (nPixelTop - nHalfFilterDim) + + (nPixelLeft - nHalfFilterDim); + // pnOutput += nOutputRowStride * nPixelTop + nPixelLeft; // Decide how many rows to convert int nOutputRows = nPixelBottom - nPixelTop; @@ -256,7 +255,8 @@ void _prepareInput_sweepOff(const NUMPY_ARRAY * psInput, int nNumPixelsPerRow = nPixelRight - nPixelLeft - nNumPrepPixels; int nPixelQuadsPerRow = nNumPixelsPerRow >> 2; int nPixelLeftovers = nNumPixelsPerRow - (nPixelQuadsPerRow << 2); - NTA_ASSERT((nNumPrepPixels + nPixelLeftovers + (nPixelQuadsPerRow << 2)) == (nPixelRight - nPixelLeft)); + NTA_ASSERT((nNumPrepPixels + nPixelLeftovers + (nPixelQuadsPerRow << 2)) == + (nPixelRight - nPixelLeft)); // How many pixels to fill on the left and right sides int nNumPreFills = nPixelLeft - nFillLeft; @@ -266,39 +266,39 @@ void _prepareInput_sweepOff(const NUMPY_ARRAY * psInput, pnOutput += nFillLeft; // Process each output location - for (j=nOutputRows; j; j--) { + for (j = nOutputRows; j; j--) { // Do pre-filling (on the left side) - for (i=nNumPreFills; i; i--) { + for (i = nNumPreFills; i; i--) { // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutput >= pDebugOutputSOMB); - NTA_ASSERT((const char*)&(pnOutput[0]) < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutput >= pDebugOutputSOMB); + NTA_ASSERT((const char *)&(pnOutput[0]) < pDebugOutputEOMB); #endif // DEBUG *pnOutput++ = nOffImageFillValue; } // Do prep pixel conversions to get ourselves aligned - for (i=nNumPrepPixels; i; i--) { + for (i = nNumPrepPixels; i; i--) { // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutput >= pDebugOutputSOMB); - NTA_ASSERT((const char*)&(pnOutput[0]) < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutput >= pDebugOutputSOMB); + NTA_ASSERT((const char *)&(pnOutput[0]) < pDebugOutputEOMB); #endif // DEBUG *pnOutput++ = (int)(*pfInput++); } // Do pixel conversion - for (i=nPixelQuadsPerRow; i; i--) { + for (i = nPixelQuadsPerRow; i; i--) { // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutput >= pDebugOutputSOMB); - NTA_ASSERT((const char*)&(pnOutput[3]) < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutput >= pDebugOutputSOMB); + NTA_ASSERT((const char *)&(pnOutput[3]) < pDebugOutputEOMB); #endif // DEBUG // Do four 
pixel conversions @@ -309,52 +309,52 @@ void _prepareInput_sweepOff(const NUMPY_ARRAY * psInput, } // Do left-overs that didn't fit in a quad - for (i=nPixelLeftovers; i; i--) { + for (i = nPixelLeftovers; i; i--) { // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutput >= pDebugOutputSOMB); - NTA_ASSERT((const char*)&(pnOutput[0]) < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutput >= pDebugOutputSOMB); + NTA_ASSERT((const char *)&(pnOutput[0]) < pDebugOutputEOMB); #endif // DEBUG *pnOutput++ = (int)(*pfInput++); } // Do post-filling (on the right side) - for (i=nNumPostFills; i; i--) { + for (i = nNumPostFills; i; i--) { // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutput >= pDebugOutputSOMB); - NTA_ASSERT((const char*)&(pnOutput[0]) < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutput >= pDebugOutputSOMB); + NTA_ASSERT((const char *)&(pnOutput[0]) < pDebugOutputEOMB); #endif // DEBUG *pnOutput++ = nOffImageFillValue; } - + // Advance to next rows - pfInput += nInputRowAdvance; + pfInput += nInputRowAdvance; pnOutput += nOutputRowAdvance; } - // At this point, our output buffer pointer is + // At this point, our output buffer pointer is // on the correct row (the first row to be filled - // below the pixel bounding box), but it is advanced - // 'nFillLeft' pixels to the right (i.e., it is + // below the pixel bounding box), but it is advanced + // 'nFillLeft' pixels to the right (i.e., it is // on the 'nFillLeft'th column). So we need to // move it back to the beginning of the row // (i.e., to the 0th column) pnOutput -= nFillLeft; // Fill the bottom rows - for (j=nFillBottom - nPixelBottom; j; j--) { - // Fill each row - for (i=nPureFillQuads; i; i--) { + for (j = nFillBottom - nPixelBottom; j; j--) { + // Fill each row + for (i = nPureFillQuads; i; i--) { // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutput >= pDebugOutputSOMB); - NTA_ASSERT((const char*)&(pnOutput[3]) < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutput >= pDebugOutputSOMB); + NTA_ASSERT((const char *)&(pnOutput[3]) < pDebugOutputEOMB); #endif // DEBUG *pnOutput++ = nOffImageFillValue; @@ -367,31 +367,29 @@ void _prepareInput_sweepOff(const NUMPY_ARRAY * psInput, // At this point, our output buffer pointer should be // exactly positioned at the edge of our memory #ifdef DEBUG - NTA_ASSERT((const char*)pnOutput <= pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutput <= pDebugOutputEOMB); #endif // DEBUG } - // FUNCTION: _prepareInput_constrained() // // 1. Convert input image from float to integer32. // 2. If EDGE_MODE is SWEEPOFF, then add "padding pixels" // around the edges of the integrized input plane. 
-void _prepareInput_constrained(const NUMPY_ARRAY * psInput, - const NUMPY_ARRAY * psBufferIn, - int nHalfFilterDim, - const NUMPY_ARRAY * psBBox, - const NUMPY_ARRAY * psImageBox) { +void _prepareInput_constrained(const NUMPY_ARRAY *psInput, + const NUMPY_ARRAY *psBufferIn, + int nHalfFilterDim, const NUMPY_ARRAY *psBBox, + const NUMPY_ARRAY *psImageBox) { int i, j; // Locate the start of input plane - const float * pfInput = (const float *)psInput->pData; + const float *pfInput = (const float *)psInput->pData; // Compute stride needed to proceed to next input row (in DWORDS) int nInputRowStride = IMAGE_ROWSTRIDE(psInput) / sizeof(*pfInput); // Locate start of output buffers // Note: the 'psBuffer' numpy array is assumed to be format 'int32' - int * pnOutput = (int *)psBufferIn->pData; + int *pnOutput = (int *)psBufferIn->pData; // Compute stride needed to proceed to next output row (in DWORDS) int nOutputRowStride = IMAGE_ROWSTRIDE(psBufferIn) / sizeof(*pnOutput); @@ -399,15 +397,18 @@ void _prepareInput_constrained(const NUMPY_ARRAY * psInput, // Our convention is that the bounding box expresses the range // of locations in the input image which are valid. // So we only need to convert pixel values within the bounding box - // plus a narrow band around the outside of the bounding box of + // plus a narrow band around the outside of the bounding box of // width equal to half the width of the filter dimension. // // We'll also expand our bounding box horizontally to make it line // up on a 16-byte boundary (four-pixel boundary) - int nBoxLeft = MAX(ALIGN_4_FLOOR(BBOX_LEFT(psBBox) - nHalfFilterDim), BBOX_LEFT(psImageBox)); - int nBoxRight = MIN(BBOX_RIGHT(psBBox) + nHalfFilterDim, BBOX_RIGHT(psImageBox)); - int nBoxTop = MAX(BBOX_TOP(psBBox) - nHalfFilterDim, BBOX_TOP(psImageBox)); - int nBoxBottom = MIN(BBOX_BOTTOM(psBBox) + nHalfFilterDim, BBOX_BOTTOM(psImageBox)); + int nBoxLeft = MAX(ALIGN_4_FLOOR(BBOX_LEFT(psBBox) - nHalfFilterDim), + BBOX_LEFT(psImageBox)); + int nBoxRight = + MIN(BBOX_RIGHT(psBBox) + nHalfFilterDim, BBOX_RIGHT(psImageBox)); + int nBoxTop = MAX(BBOX_TOP(psBBox) - nHalfFilterDim, BBOX_TOP(psImageBox)); + int nBoxBottom = + MIN(BBOX_BOTTOM(psBBox) + nHalfFilterDim, BBOX_BOTTOM(psImageBox)); // If all of our assumptions have been met, then the following // conditions should hold (otherwise there is a bug somewhere): @@ -417,19 +418,19 @@ void _prepareInput_constrained(const NUMPY_ARRAY * psInput, NTA_ASSERT(nBoxBottom <= IMAGE_ROWS(psInput)); // Make sure we got the alignment we wanted - NTA_ASSERT(nBoxLeft % 4 == 0); + NTA_ASSERT(nBoxLeft % 4 == 0); // Compute the number of DWORDS (pixels) we'll need to // advance to get from the end of row N to the beginning // of row N+1 int nBoxWidth = nBoxRight - nBoxLeft; - int nInputRowAdvance = nInputRowStride - nBoxWidth; + int nInputRowAdvance = nInputRowStride - nBoxWidth; int nOutputRowAdvance = nOutputRowStride - nBoxWidth; // Advance our pointers and row counts to skip past // the rows on top of the bounding box, and align // our pointers with the left edge of the bounding box. 
- pfInput += nInputRowStride * nBoxTop + nBoxLeft; + pfInput += nInputRowStride * nBoxTop + nBoxLeft; pnOutput += nOutputRowStride * nBoxTop + nBoxLeft; // Decide how many rows to convert @@ -446,10 +447,10 @@ void _prepareInput_constrained(const NUMPY_ARRAY * psInput, int nLeftovers = nBoxWidth - (nQuadsPerRow << 2); // Process each output location - for (j=nOutputRows; j; j--) { + for (j = nOutputRows; j; j--) { // Do four pixel conversions - for (i=nQuadsPerRow; i; i--) { + for (i = nQuadsPerRow; i; i--) { *pnOutput++ = (int)(*pfInput++); *pnOutput++ = (int)(*pfInput++); *pnOutput++ = (int)(*pfInput++); @@ -457,45 +458,33 @@ void _prepareInput_constrained(const NUMPY_ARRAY * psInput, } // Convert any leftovers - for (i=nLeftovers; i; i--) + for (i = nLeftovers; i; i--) *pnOutput++ = (int)(*pfInput++); - + // Advance to next rows - pfInput += nInputRowAdvance; + pfInput += nInputRowAdvance; pnOutput += nOutputRowAdvance; } } - // FUNCTION: _prepareInput() // // 1. Convert input image from float to integer32. // 2. If EDGE_MODE is SWEEPOFF, then add "padding pixels" // around the edges of the integrized input plane. -void _prepareInput(const NUMPY_ARRAY * psInput, - const NUMPY_ARRAY * psBufferIn, - int nHalfFilterDim, - const NUMPY_ARRAY * psBBox, - const NUMPY_ARRAY * psImageBox, - EDGE_MODE eEdgeMode, +void _prepareInput(const NUMPY_ARRAY *psInput, const NUMPY_ARRAY *psBufferIn, + int nHalfFilterDim, const NUMPY_ARRAY *psBBox, + const NUMPY_ARRAY *psImageBox, EDGE_MODE eEdgeMode, float fOffImageFillValue) { if (eEdgeMode == EDGE_MODE_CONSTRAINED) - _prepareInput_constrained(psInput, - psBufferIn, - nHalfFilterDim, - psBBox, + _prepareInput_constrained(psInput, psBufferIn, nHalfFilterDim, psBBox, psImageBox); else - _prepareInput_sweepOff(psInput, - psBufferIn, - nHalfFilterDim, - psBBox, - psImageBox, - fOffImageFillValue); + _prepareInput_sweepOff(psInput, psBufferIn, nHalfFilterDim, psBBox, + psImageBox, fOffImageFillValue); } - // It is useful for debugging to set this to // a non-zero value so see where the null responses // were filled in. @@ -503,22 +492,19 @@ void _prepareInput(const NUMPY_ARRAY * psInput, //#define NULL_RESPONSE (64 << 12) // Flags indicating which statistics to keep on the fly. 
-#define STATS_NONE 0x00 -#define STATS_MAX_ABS 0x01 -#define STATS_MAX_MIN 0x02 -#define STATS_SUM_ABS 0x04 -#define STATS_SUM_POS_NEG 0x08 -#define STATS_MAX (STATS_MAX_ABS|STATS_MAX_MIN) -#define STATS_MEAN (STATS_SUM_ABS|STATS_SUM_POS_NEG) -#define STATS_SINGLE (STATS_MAX_ABS|STATS_SUM_ABS) -#define STATS_DUAL (STATS_MAX_MIN|STATS_SUM_POS_NEG) - - -void _computeNormalizers(int & nStatPosGrand, - int & nStatNegGrand, +#define STATS_NONE 0x00 +#define STATS_MAX_ABS 0x01 +#define STATS_MAX_MIN 0x02 +#define STATS_SUM_ABS 0x04 +#define STATS_SUM_POS_NEG 0x08 +#define STATS_MAX (STATS_MAX_ABS | STATS_MAX_MIN) +#define STATS_MEAN (STATS_SUM_ABS | STATS_SUM_POS_NEG) +#define STATS_SINGLE (STATS_MAX_ABS | STATS_SUM_ABS) +#define STATS_DUAL (STATS_MAX_MIN | STATS_SUM_POS_NEG) + +void _computeNormalizers(int &nStatPosGrand, int &nStatNegGrand, unsigned int nStatFlags, - NORMALIZE_METHOD eNormalizeMethod, - int nNumPixels) { + NORMALIZE_METHOD eNormalizeMethod, int nNumPixels) { // "Fixed" normalization mode just normalizes by the maximum // input value, which for 8-bit images is 255.0 @@ -527,7 +513,7 @@ void _computeNormalizers(int & nStatPosGrand, nStatPosGrand = 255; nStatNegGrand = -nStatPosGrand; } - // If we are performing 'mean' operations, then + // If we are performing 'mean' operations, then // we need to divide by the total number of pixels else if (nStatFlags & STATS_MEAN) { if (nNumPixels) { @@ -544,7 +530,6 @@ void _computeNormalizers(int & nStatPosGrand, } } - // FUNCTION: _doConvolution_alpha() // 1. Convolve integerized input image (in bufferIn) against // each filter in gabor filter bank, storing the result @@ -553,29 +538,24 @@ void _computeNormalizers(int & nStatPosGrand, // neccessary statistics for use in normalization // during Pass II. // For case where valid alpha is provided instead of valid box. -void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, - const NUMPY_ARRAY * psBufferOut, - const NUMPY_ARRAY * psGaborBank, - const NUMPY_ARRAY * psAlpha, - const BBOX * psInputBox, - const BBOX * psOutputBox, - PHASE_MODE ePhaseMode, - NORMALIZE_METHOD eNormalizeMethod, - NORMALIZE_MODE eNormalizeMode, - unsigned int anStatPosGrand[], - unsigned int anStatNegGrand[] ) { +void _doConvolution_alpha( + const NUMPY_ARRAY *psBufferIn, const NUMPY_ARRAY *psBufferOut, + const NUMPY_ARRAY *psGaborBank, const NUMPY_ARRAY *psAlpha, + const BBOX *psInputBox, const BBOX *psOutputBox, PHASE_MODE ePhaseMode, + NORMALIZE_METHOD eNormalizeMethod, NORMALIZE_MODE eNormalizeMode, + unsigned int anStatPosGrand[], unsigned int anStatNegGrand[]) { int i, j, jj; int nFilterIndex; int nResponse; int nNumPixels = 0; - const int * pnInput = nullptr; - const int * pnInputRow = nullptr; - const int * pnInputPtr = nullptr; - const int * pnGaborPtr = nullptr; - const float * pfAlpha = nullptr; - const float * pfAlphaRow = nullptr; - const float * pfAlphaRowPtr = nullptr; - int * pnOutputRow = nullptr; + const int *pnInput = nullptr; + const int *pnInputRow = nullptr; + const int *pnInputPtr = nullptr; + const int *pnGaborPtr = nullptr; + const float *pfAlpha = nullptr; + const float *pfAlphaRow = nullptr; + const float *pfAlphaRowPtr = nullptr; + int *pnOutputRow = nullptr; // Decide which statistics to keep. // There are four cases: @@ -589,7 +569,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // something crazy didn't happen, like the minimum // response was positive, or vice versa. // 3. Record mean response in single-phase mode. 
- // For this, we need to accummulate the sume of + // For this, we need to accummulate the sume of // the absolute values of the responses. // 4. Record mean response in dual-phase mode. // For this, we need to keep the sum of all @@ -626,8 +606,9 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, } // We'll need to know the total number of pixels nNumPixels = 0; - //nNumPixels = (psOutputBox->nRight - psOutputBox->nLeft) * (psOutputBox->nBottom - psOutputBox->nTop); - //if (eNormalizeMode == NORMALIZE_MODE_GLOBAL) + // nNumPixels = (psOutputBox->nRight - psOutputBox->nLeft) * + // (psOutputBox->nBottom - psOutputBox->nTop); if (eNormalizeMode == + // NORMALIZE_MODE_GLOBAL) // nNumPixels *= GABORSET_PLANES(psGaborBank); break; // No normalization needed @@ -654,26 +635,29 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, int nShrinkageY = (psInputBox->nBottom - psOutputBox->nBottom) >> 1; // Locate start of correct Gabor filter - const int * pnFilterBase = (const int *)psGaborBank->pData; + const int *pnFilterBase = (const int *)psGaborBank->pData; // Locate start of correct output plane - int * pnOutputBase = (int *)psBufferOut->pData; - int nOutputRowStride = IMAGESET_ROWSTRIDE(psBufferOut) / sizeof(*pnOutputBase); + int *pnOutputBase = (int *)psBufferOut->pData; + int nOutputRowStride = + IMAGESET_ROWSTRIDE(psBufferOut) / sizeof(*pnOutputBase); // Guard against buffer over-runs #ifdef DEBUG // Start/end of memory - const char * pDebugOutputSOMB = (const char*)(psBufferOut->pData); - const char * pDebugOutputEOMB = pDebugOutputSOMB + IMAGESET_PLANESTRIDE(psBufferOut) * IMAGESET_PLANES(psBufferOut); + const char *pDebugOutputSOMB = (const char *)(psBufferOut->pData); + const char *pDebugOutputEOMB = + pDebugOutputSOMB + + IMAGESET_PLANESTRIDE(psBufferOut) * IMAGESET_PLANES(psBufferOut); #endif // DEBUG // Locate the start of input plane - const int * pnInputBase = (const int *)psBufferIn->pData; + const int *pnInputBase = (const int *)psBufferIn->pData; int nInputRowStride = IMAGE_ROWSTRIDE(psBufferIn) / sizeof(*pnInputBase); int nInputRowAdvance = nInputRowStride - nFilterDim; // Locate the start of alpha plane - const float * pfAlphaBase = (const float *)psAlpha->pData; + const float *pfAlphaBase = (const float *)psAlpha->pData; int nAlphaRowStride = IMAGE_ROWSTRIDE(psAlpha) / sizeof(*pfAlphaBase); // Take into account bounding box suppression @@ -690,7 +674,8 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, nStatPosGrand = 0; nStatNegGrand = 0; - for (nFilterIndex=0; nFilterIndexnTop + psInputBox->nLeft; + pnInput = + pnInputBase + nInputRowStride * psInputBox->nTop + psInputBox->nLeft; // Process each plane of output - const int * pnFilter = pnFilterBase; - int * pnOutput = pnOutputBase; + const int *pnFilter = pnFilterBase; + int *pnOutput = pnOutputBase; // Zero out any rows above the bounding box pnOutput += nNumBlankTopRows * nOutputRowStride; @@ -712,7 +698,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, pfAlpha = pfAlphaBase + (nNumBlankTopRows + nShrinkageY) * nAlphaRowStride; // Process each row within the bounding box (vertically) - for (j=nOutputRows; j; j--) { + for (j = nOutputRows; j; j--) { // Initialize row-wise accummulator nStatPosRow = 0; @@ -724,7 +710,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // Skip any pixels on this row that lie on left side of bounding box pnOutputRow = pnOutput + psOutputBox->nLeft; - // Alpha + // Alpha pfAlphaRow = pfAlpha + (psOutputBox->nLeft + nShrinkageX); // 
We'll need to know the total number of pixels if we @@ -732,16 +718,16 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // determined on the first pass through the source // image (i.e., when generating the response to the // first gabor filter.) - if (nFilterIndex == 0 && - (eNormalizeMethod == NORMALIZE_METHOD_MEAN || - eNormalizeMethod == NORMALIZE_METHOD_MEANPOWER) && + if (nFilterIndex == 0 && + (eNormalizeMethod == NORMALIZE_METHOD_MEAN || + eNormalizeMethod == NORMALIZE_METHOD_MEANPOWER) && eNormalizeMode == NORMALIZE_MODE_GLOBAL) { // Quickly run across the alpha channel to check how // many positive pixels it has in this row. pfAlphaRowPtr = pfAlphaRow; - for (i=nOutputCols; i; i--) { + for (i = nOutputCols; i; i--) { if (*pfAlphaRowPtr++) - nNumPixels ++; + nNumPixels++; } } @@ -755,7 +741,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // Filter: 5x5 case 5: - for (i=nOutputCols; i; i--) { + for (i = nOutputCols; i; i--) { // Process each row in the filter mask pnGaborPtr = pnFilter; pnInputPtr = pnInputRow; @@ -767,7 +753,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // lies within our valid alpha channel. if (*pfAlphaRow++) { - for (jj=nFilterDim; jj; jj--) { + for (jj = nFilterDim; jj; jj--) { // First 128-bits nResponse += (*pnGaborPtr++) * (*pnInputPtr++); @@ -789,8 +775,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, nStatPosGrand = MAX(nStatPosGrand, nResponse); else nStatNegGrand = MIN(nStatNegGrand, nResponse); - } - else if (nStatFlags & STATS_SUM_ABS) + } else if (nStatFlags & STATS_SUM_ABS) nStatPosRow += IABS32(nResponse); else if (nStatFlags & STATS_SUM_POS_NEG) { if (nResponse >= 0) @@ -802,8 +787,8 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutputRow >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pnOutputRow < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutputRow >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pnOutputRow < pDebugOutputEOMB); #endif // DEBUG // Apply abs() and clipping, and then advance to next location @@ -819,7 +804,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // Filter: 7x7 case 7: - for (i=nOutputCols; i; i--) { + for (i = nOutputCols; i; i--) { // Process each row in the filter mask pnGaborPtr = pnFilter; pnInputPtr = pnInputRow; @@ -831,7 +816,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // lies within our valid alpha channel. 
if (*pfAlphaRow++) { - for (jj=nFilterDim; jj; jj--) { + for (jj = nFilterDim; jj; jj--) { // First 128-bits nResponse += (*pnGaborPtr++) * (*pnInputPtr++); @@ -855,8 +840,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, nStatPosGrand = MAX(nStatPosGrand, nResponse); else nStatNegGrand = MIN(nStatNegGrand, nResponse); - } - else if (nStatFlags & STATS_SUM_ABS) + } else if (nStatFlags & STATS_SUM_ABS) nStatPosRow += IABS32(nResponse); else if (nStatFlags & STATS_SUM_POS_NEG) { if (nResponse >= 0) @@ -868,8 +852,8 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutputRow >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pnOutputRow < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutputRow >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pnOutputRow < pDebugOutputEOMB); #endif // DEBUG // Apply abs() and clipping, and then advance to next location @@ -885,7 +869,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // Filter: 9x9 case 9: - for (i=nOutputCols; i; i--) { + for (i = nOutputCols; i; i--) { // Process each row in the filter mask pnGaborPtr = pnFilter; pnInputPtr = pnInputRow; @@ -897,7 +881,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // lies within our valid alpha channel. if (*pfAlphaRow++) { - for (jj=nFilterDim; jj; jj--) { + for (jj = nFilterDim; jj; jj--) { // First 128-bits nResponse += (*pnGaborPtr++) * (*pnInputPtr++); @@ -927,8 +911,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, nStatPosGrand = MAX(nStatPosGrand, nResponse); else nStatNegGrand = MIN(nStatNegGrand, nResponse); - } - else if (nStatFlags & STATS_SUM_ABS) + } else if (nStatFlags & STATS_SUM_ABS) nStatPosRow += IABS32(nResponse); else if (nStatFlags & STATS_SUM_POS_NEG) { if (nResponse >= 0) @@ -940,8 +923,8 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutputRow >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pnOutputRow < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutputRow >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pnOutputRow < pDebugOutputEOMB); #endif // DEBUG // Apply abs() and clipping, and then advance to next location @@ -957,7 +940,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // Filter: 11x11 case 11: - for (i=nOutputCols; i; i--) { + for (i = nOutputCols; i; i--) { // Process each row in the filter mask pnGaborPtr = pnFilter; pnInputPtr = pnInputRow; @@ -969,7 +952,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // lies within our valid alpha channel. 
if (*pfAlphaRow++) { - for (jj=nFilterDim; jj; jj--) { + for (jj = nFilterDim; jj; jj--) { // First 128-bits nResponse += (*pnGaborPtr++) * (*pnInputPtr++); @@ -1001,8 +984,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, nStatPosGrand = MAX(nStatPosGrand, nResponse); else nStatNegGrand = MIN(nStatNegGrand, nResponse); - } - else if (nStatFlags & STATS_SUM_ABS) + } else if (nStatFlags & STATS_SUM_ABS) nStatPosRow += IABS32(nResponse); else if (nStatFlags & STATS_SUM_POS_NEG) { if (nResponse >= 0) @@ -1014,8 +996,8 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutputRow >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pnOutputRow < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutputRow >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pnOutputRow < pDebugOutputEOMB); #endif // DEBUG // Apply abs() and clipping, and then advance to next location @@ -1031,7 +1013,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // Filter: 13x13 case 13: - for (i=nOutputCols; i; i--) { + for (i = nOutputCols; i; i--) { // Process each row in the filter mask pnGaborPtr = pnFilter; pnInputPtr = pnInputRow; @@ -1043,7 +1025,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // lies within our valid alpha channel. if (*pfAlphaRow++) { - for (jj=nFilterDim; jj; jj--) { + for (jj = nFilterDim; jj; jj--) { // First 128-bits nResponse += (*pnGaborPtr++) * (*pnInputPtr++); @@ -1079,8 +1061,7 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, nStatPosGrand = MAX(nStatPosGrand, nResponse); else nStatNegGrand = MIN(nStatNegGrand, nResponse); - } - else if (nStatFlags & STATS_SUM_ABS) + } else if (nStatFlags & STATS_SUM_ABS) nStatPosRow += IABS32(nResponse); else if (nStatFlags & STATS_SUM_POS_NEG) { if (nResponse >= 0) @@ -1092,8 +1073,8 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutputRow >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pnOutputRow < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutputRow >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pnOutputRow < pDebugOutputEOMB); #endif // DEBUG // Apply abs() and clipping, and then advance to next location @@ -1131,24 +1112,21 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, if (eNormalizeMode == NORMALIZE_MODE_PERORIENT) { // Compute final values of the normalizers to use - _computeNormalizers(nStatPosGrand, - nStatNegGrand, - nStatFlags, - eNormalizeMethod, - nNumPixels); + _computeNormalizers(nStatPosGrand, nStatNegGrand, nStatFlags, + eNormalizeMethod, nNumPixels); - NTA_ASSERT(nStatPosGrand >= 0); + NTA_ASSERT(nStatPosGrand >= 0); anStatPosGrand[nFilterIndex] = (unsigned int)(nStatPosGrand + 1); - // We also need to flip the sign of our negative + // We also need to flip the sign of our negative // max stat if we are in dual phase. - //if (nStatFlags & STATS_MAX_MIN) + // if (nStatFlags & STATS_MAX_MIN) // nStatNegGrand = -nStatNegGrand; - //if (nStatFlags & STATS_DUAL) { + // if (nStatFlags & STATS_DUAL) { if (ePhaseMode == PHASE_MODE_DUAL) { nStatNegGrand = -nStatNegGrand; - NTA_ASSERT(nStatNegGrand >= 0); + NTA_ASSERT(nStatNegGrand >= 0); // We add one to the statistical quantity because we want to - // round up in the case of integer arithmetic round off + // round up in the case of integer arithmetic round off // errors. That way (for example) our MAX statistic will // be guaranteed to be >= the largest actual value. 
anStatNegGrand[nFilterIndex] = (unsigned int)(nStatNegGrand + 1); @@ -1156,9 +1134,9 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // Debugging #ifdef DEBUG - for (int kk=0; kk<=nFilterIndex; kk++) { - fprintf(stdout, "[%d]: anStatPosGrand: %d\tanStatNegGrand: %d\n", - kk, anStatPosGrand[kk], anStatNegGrand[kk]); + for (int kk = 0; kk <= nFilterIndex; kk++) { + fprintf(stdout, "[%d]: anStatPosGrand: %d\tanStatNegGrand: %d\n", kk, + anStatPosGrand[kk], anStatNegGrand[kk]); } #endif // DEBUG } @@ -1168,9 +1146,9 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // determined on the first pass through the source // image (i.e., when generating the response to the // first gabor filter.) - if (nFilterIndex == 0 && - (eNormalizeMethod == NORMALIZE_METHOD_MEAN || - eNormalizeMethod == NORMALIZE_METHOD_MEANPOWER) && + if (nFilterIndex == 0 && + (eNormalizeMethod == NORMALIZE_METHOD_MEAN || + eNormalizeMethod == NORMALIZE_METHOD_MEANPOWER) && eNormalizeMode == NORMALIZE_MODE_GLOBAL) // If using global normalization, then we're summing the // responses over all planes, so there will a multiple @@ -1180,44 +1158,38 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, } // for each filter (output plane) - // If we are storing statistics globally (i.e., not on a - // per-filter basis), then we can finally dump our stats + // If we are storing statistics globally (i.e., not on a + // per-filter basis), then we can finally dump our stats // to the buffer. if (eNormalizeMode == NORMALIZE_MODE_GLOBAL) { // Compute final values of the normalizers to use - _computeNormalizers(nStatPosGrand, - nStatNegGrand, - nStatFlags, - eNormalizeMethod, - nNumPixels); + _computeNormalizers(nStatPosGrand, nStatNegGrand, nStatFlags, + eNormalizeMethod, nNumPixels); - NTA_ASSERT(nStatPosGrand >= 0); + NTA_ASSERT(nStatPosGrand >= 0); anStatPosGrand[0] = (unsigned int)(nStatPosGrand + 1); - // We also need to flip the sign of our negative + // We also need to flip the sign of our negative // max stat if we are in dual phase. - //if (nStatFlags & STATS_MAX_MIN) + // if (nStatFlags & STATS_MAX_MIN) // nStatNegGrand = -nStatNegGrand; - //if (nStatFlags & STATS_DUAL) { + // if (nStatFlags & STATS_DUAL) { if (ePhaseMode == PHASE_MODE_DUAL) { nStatNegGrand = -nStatNegGrand; - NTA_ASSERT(nStatNegGrand >= 0); + NTA_ASSERT(nStatNegGrand >= 0); anStatNegGrand[0] = (unsigned int)(nStatNegGrand + 1); } } - // Debug + // Debug #ifdef DEBUG - for (int kk=0; kk<1; kk++) { + for (int kk = 0; kk < 1; kk++) { fprintf(stdout, "anStatPosGrand[%d]: %d\n", kk, anStatPosGrand[kk]); fprintf(stdout, "anStatNegGrand[%d]: %d\n", kk, anStatNegGrand[kk]); } #endif // DEBUG } - - - // FUNCTION: _doConvolution_bbox() // 1. Convolve integerized input image (in bufferIn) against // each filter in gabor filter bank, storing the result @@ -1227,25 +1199,23 @@ void _doConvolution_alpha( const NUMPY_ARRAY * psBufferIn, // during Pass II. // For case where valid box is provided instead of valid // alpha channel. 
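For readers skimming the reformatted convolution routines: every `case` in the `switch (nFilterDim)` blocks above computes the same quantity, an integer dot product between one Gabor filter and the image window under the current output pixel, merely unrolled for a fixed filter size (5, 7, 9, 11 or 13). The standalone sketch below shows the non-unrolled form; it is illustrative only, the parameter names are invented here, and it omits the incremental pointer advancing and the NUMPY_ARRAY/IMAGESET stride macros used in the real code.

// Illustrative only -- not part of the patch. A generic (non-unrolled)
// version of the per-pixel response computed inside each switch case above,
// using plain pointers and strides instead of the NUMPY_ARRAY macros.
static inline int gaborResponseAt(const int *pnInputTopLeft, // window's top-left pixel
                                  int nInputRowStride,       // input row stride, in ints
                                  const int *pnFilter,       // nFilterDim x nFilterDim coefficients
                                  int nFilterDim) {
  int nResponse = 0;
  const int *pnInputRow = pnInputTopLeft;
  const int *pnGaborPtr = pnFilter;
  for (int jj = 0; jj < nFilterDim; ++jj) {
    for (int ii = 0; ii < nFilterDim; ++ii)
      nResponse += pnGaborPtr[ii] * pnInputRow[ii];
    pnGaborPtr += nFilterDim;      // next filter row
    pnInputRow += nInputRowStride; // next image row
  }
  return nResponse;                // caller applies abs()/clipping and statistics
}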
-void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, - const NUMPY_ARRAY * psBufferOut, - const NUMPY_ARRAY * psGaborBank, - const BBOX * psInputBox, - const BBOX * psOutputBox, - PHASE_MODE ePhaseMode, - NORMALIZE_METHOD eNormalizeMethod, - NORMALIZE_MODE eNormalizeMode, - unsigned int anStatPosGrand[], - unsigned int anStatNegGrand[] ) { +void _doConvolution_bbox(const NUMPY_ARRAY *psBufferIn, + const NUMPY_ARRAY *psBufferOut, + const NUMPY_ARRAY *psGaborBank, const BBOX *psInputBox, + const BBOX *psOutputBox, PHASE_MODE ePhaseMode, + NORMALIZE_METHOD eNormalizeMethod, + NORMALIZE_MODE eNormalizeMode, + unsigned int anStatPosGrand[], + unsigned int anStatNegGrand[]) { int i, j, jj; int nFilterIndex; int nResponse; int nNumPixels = 0; - const int * pnInput = nullptr; - const int * pnInputRow = nullptr; - const int * pnInputPtr = nullptr; - const int * pnGaborPtr = nullptr; - int * pnOutputRow = nullptr; + const int *pnInput = nullptr; + const int *pnInputRow = nullptr; + const int *pnInputPtr = nullptr; + const int *pnGaborPtr = nullptr; + int *pnOutputRow = nullptr; // Decide which statistics to keep. // There are four cases: @@ -1259,7 +1229,7 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, // something crazy didn't happen, like the minimum // response was positive, or vice versa. // 3. Record mean response in single-phase mode. - // For this, we need to accummulate the sume of + // For this, we need to accummulate the sume of // the absolute values of the responses. // 4. Record mean response in dual-phase mode. // For this, we need to keep the sum of all @@ -1295,7 +1265,8 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, nStatFlags |= STATS_SUM_POS_NEG; } // We'll need to know the total number of pixels - nNumPixels = (psOutputBox->nRight - psOutputBox->nLeft) * (psOutputBox->nBottom - psOutputBox->nTop); + nNumPixels = (psOutputBox->nRight - psOutputBox->nLeft) * + (psOutputBox->nBottom - psOutputBox->nTop); if (eNormalizeMode == NORMALIZE_MODE_GLOBAL) nNumPixels *= GABORSET_PLANES(psGaborBank); break; @@ -1312,21 +1283,24 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, int nFilterDim = IMAGESET_ROWS(psGaborBank); // Locate start of correct Gabor filter - const int * pnFilterBase = (const int *)psGaborBank->pData; + const int *pnFilterBase = (const int *)psGaborBank->pData; // Locate start of correct output plane - int * pnOutputBase = (int *)psBufferOut->pData; - int nOutputRowStride = IMAGESET_ROWSTRIDE(psBufferOut) / sizeof(*pnOutputBase); + int *pnOutputBase = (int *)psBufferOut->pData; + int nOutputRowStride = + IMAGESET_ROWSTRIDE(psBufferOut) / sizeof(*pnOutputBase); // Guard against buffer over-runs #ifdef DEBUG // Start/end of memory - const char * pDebugOutputSOMB = (const char*)(psBufferOut->pData); - const char * pDebugOutputEOMB = pDebugOutputSOMB + IMAGESET_PLANESTRIDE(psBufferOut) * IMAGESET_PLANES(psBufferOut); + const char *pDebugOutputSOMB = (const char *)(psBufferOut->pData); + const char *pDebugOutputEOMB = + pDebugOutputSOMB + + IMAGESET_PLANESTRIDE(psBufferOut) * IMAGESET_PLANES(psBufferOut); #endif // DEBUG // Locate the start of input plane - const int * pnInputBase = (const int *)psBufferIn->pData; + const int *pnInputBase = (const int *)psBufferIn->pData; int nInputRowStride = IMAGE_ROWSTRIDE(psBufferIn) / sizeof(*pnInputBase); int nInputRowAdvance = nInputRowStride - nFilterDim; @@ -1344,7 +1318,8 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, nStatPosGrand = 0; nStatNegGrand = 0; - for 
(nFilterIndex=0; nFilterIndexnTop + psInputBox->nLeft; + pnInput = + pnInputBase + nInputRowStride * psInputBox->nTop + psInputBox->nLeft; // Process each plane of output - const int * pnFilter = pnFilterBase; - int * pnOutput = pnOutputBase; + const int *pnFilter = pnFilterBase; + int *pnOutput = pnOutputBase; // Zero out any rows above the bounding box pnOutput += nNumBlankTopRows * nOutputRowStride; // Process each row within the bounding box (vertically) - for (j=nOutputRows; j; j--) { + for (j = nOutputRows; j; j--) { pnOutputRow = pnOutput; // Initialize row-wise accummulator @@ -1386,14 +1362,14 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, // Filter: 5x5 case 5: - for (i=nOutputCols; i; i--) { + for (i = nOutputCols; i; i--) { // Process each row in the filter mask pnGaborPtr = pnFilter; pnInputPtr = pnInputRow; // Compute gabor response for this location nResponse = 0; - for (jj=nFilterDim; jj; jj--) { + for (jj = nFilterDim; jj; jj--) { // First 128-bits nResponse += (*pnGaborPtr++) * (*pnInputPtr++); @@ -1415,8 +1391,7 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, nStatPosGrand = MAX(nStatPosGrand, nResponse); else nStatNegGrand = MIN(nStatNegGrand, nResponse); - } - else if (nStatFlags & STATS_SUM_ABS) + } else if (nStatFlags & STATS_SUM_ABS) nStatPosRow += IABS32(nResponse); else if (nStatFlags & STATS_SUM_POS_NEG) { if (nResponse >= 0) @@ -1427,8 +1402,8 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutputRow >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pnOutputRow < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutputRow >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pnOutputRow < pDebugOutputEOMB); #endif // DEBUG // Apply abs() and clipping, and then advance to next location @@ -1444,14 +1419,14 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, // Filter: 7x7 case 7: - for (i=nOutputCols; i; i--) { + for (i = nOutputCols; i; i--) { // Process each row in the filter mask pnGaborPtr = pnFilter; pnInputPtr = pnInputRow; // Compute gabor response for this location nResponse = 0; - for (jj=nFilterDim; jj; jj--) { + for (jj = nFilterDim; jj; jj--) { // First 128-bits nResponse += (*pnGaborPtr++) * (*pnInputPtr++); @@ -1475,8 +1450,7 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, nStatPosGrand = MAX(nStatPosGrand, nResponse); else nStatNegGrand = MIN(nStatNegGrand, nResponse); - } - else if (nStatFlags & STATS_SUM_ABS) + } else if (nStatFlags & STATS_SUM_ABS) nStatPosRow += IABS32(nResponse); else if (nStatFlags & STATS_SUM_POS_NEG) { if (nResponse >= 0) @@ -1487,8 +1461,8 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutputRow >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pnOutputRow < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutputRow >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pnOutputRow < pDebugOutputEOMB); #endif // DEBUG // Apply abs() and clipping, and then advance to next location @@ -1504,14 +1478,14 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, // Filter: 9x9 case 9: - for (i=nOutputCols; i; i--) { + for (i = nOutputCols; i; i--) { // Process each row in the filter mask pnGaborPtr = pnFilter; pnInputPtr = pnInputRow; // Compute gabor response for this location nResponse = 0; - for (jj=nFilterDim; jj; jj--) { + for (jj = nFilterDim; jj; jj--) { // First 128-bits nResponse += (*pnGaborPtr++) * (*pnInputPtr++); @@ -1541,8 
+1515,7 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, nStatPosGrand = MAX(nStatPosGrand, nResponse); else nStatNegGrand = MIN(nStatNegGrand, nResponse); - } - else if (nStatFlags & STATS_SUM_ABS) + } else if (nStatFlags & STATS_SUM_ABS) nStatPosRow += IABS32(nResponse); else if (nStatFlags & STATS_SUM_POS_NEG) { if (nResponse >= 0) @@ -1553,8 +1526,8 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutputRow >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pnOutputRow < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutputRow >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pnOutputRow < pDebugOutputEOMB); #endif // DEBUG // Apply abs() and clipping, and then advance to next location @@ -1570,14 +1543,14 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, // Filter: 11x11 case 11: - for (i=nOutputCols; i; i--) { + for (i = nOutputCols; i; i--) { // Process each row in the filter mask pnGaborPtr = pnFilter; pnInputPtr = pnInputRow; // Compute gabor response for this location nResponse = 0; - for (jj=nFilterDim; jj; jj--) { + for (jj = nFilterDim; jj; jj--) { // First 128-bits nResponse += (*pnGaborPtr++) * (*pnInputPtr++); @@ -1609,8 +1582,7 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, nStatPosGrand = MAX(nStatPosGrand, nResponse); else nStatNegGrand = MIN(nStatNegGrand, nResponse); - } - else if (nStatFlags & STATS_SUM_ABS) + } else if (nStatFlags & STATS_SUM_ABS) nStatPosRow += IABS32(nResponse); else if (nStatFlags & STATS_SUM_POS_NEG) { if (nResponse >= 0) @@ -1621,8 +1593,8 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutputRow >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pnOutputRow < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutputRow >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pnOutputRow < pDebugOutputEOMB); #endif // DEBUG // Apply abs() and clipping, and then advance to next location @@ -1638,14 +1610,14 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, // Filter: 13x13 case 13: - for (i=nOutputCols; i; i--) { + for (i = nOutputCols; i; i--) { // Process each row in the filter mask pnGaborPtr = pnFilter; pnInputPtr = pnInputRow; // Compute gabor response for this location nResponse = 0; - for (jj=nFilterDim; jj; jj--) { + for (jj = nFilterDim; jj; jj--) { // First 128-bits nResponse += (*pnGaborPtr++) * (*pnInputPtr++); @@ -1681,8 +1653,7 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, nStatPosGrand = MAX(nStatPosGrand, nResponse); else nStatNegGrand = MIN(nStatNegGrand, nResponse); - } - else if (nStatFlags & STATS_SUM_ABS) + } else if (nStatFlags & STATS_SUM_ABS) nStatPosRow += IABS32(nResponse); else if (nStatFlags & STATS_SUM_POS_NEG) { if (nResponse >= 0) @@ -1693,8 +1664,8 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, // Memory bounds checking #ifdef DEBUG - NTA_ASSERT((const char*)pnOutputRow >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pnOutputRow < pDebugOutputEOMB); + NTA_ASSERT((const char *)pnOutputRow >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pnOutputRow < pDebugOutputEOMB); #endif // DEBUG // Apply abs() and clipping, and then advance to next location @@ -1731,24 +1702,21 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, if (eNormalizeMode == NORMALIZE_MODE_PERORIENT) { // Compute final values of the normalizers to use - _computeNormalizers(nStatPosGrand, - nStatNegGrand, - nStatFlags, - 
eNormalizeMethod, - nNumPixels); + _computeNormalizers(nStatPosGrand, nStatNegGrand, nStatFlags, + eNormalizeMethod, nNumPixels); - NTA_ASSERT(nStatPosGrand >= 0); + NTA_ASSERT(nStatPosGrand >= 0); anStatPosGrand[nFilterIndex] = (unsigned int)(nStatPosGrand + 1); - // We also need to flip the sign of our negative + // We also need to flip the sign of our negative // max stat if we are in dual phase. - //if (nStatFlags & STATS_MAX_MIN) + // if (nStatFlags & STATS_MAX_MIN) // nStatNegGrand = -nStatNegGrand; - //if (nStatFlags & STATS_DUAL) { + // if (nStatFlags & STATS_DUAL) { if (ePhaseMode == PHASE_MODE_DUAL) { nStatNegGrand = -nStatNegGrand; - NTA_ASSERT(nStatNegGrand >= 0); + NTA_ASSERT(nStatNegGrand >= 0); // We add one to the statistical quantity because we want to - // round up in the case of integer arithmetic round off + // round up in the case of integer arithmetic round off // errors. That way (for example) our MAX statistic will // be guaranteed to be >= the largest actual value. anStatNegGrand[nFilterIndex] = (unsigned int)(nStatNegGrand + 1); @@ -1756,50 +1724,46 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, // Debugging #ifdef DEBUG - for (int kk=0; kk<=nFilterIndex; kk++) { - fprintf(stdout, "[%d]: anStatPosGrand: %d\tanStatNegGrand: %d\n", - kk, anStatPosGrand[kk], anStatNegGrand[kk]); + for (int kk = 0; kk <= nFilterIndex; kk++) { + fprintf(stdout, "[%d]: anStatPosGrand: %d\tanStatNegGrand: %d\n", kk, + anStatPosGrand[kk], anStatNegGrand[kk]); } #endif // DEBUG } } // for each filter (output plane) - // If we are storing statistics globally (i.e., not on a - // per-filter basis), then we can finally dump our stats + // If we are storing statistics globally (i.e., not on a + // per-filter basis), then we can finally dump our stats // to the buffer. if (eNormalizeMode == NORMALIZE_MODE_GLOBAL) { // Compute final values of the normalizers to use - _computeNormalizers(nStatPosGrand, - nStatNegGrand, - nStatFlags, - eNormalizeMethod, - nNumPixels); + _computeNormalizers(nStatPosGrand, nStatNegGrand, nStatFlags, + eNormalizeMethod, nNumPixels); - NTA_ASSERT(nStatPosGrand >= 0); + NTA_ASSERT(nStatPosGrand >= 0); anStatPosGrand[0] = (unsigned int)(nStatPosGrand + 1); - // We also need to flip the sign of our negative + // We also need to flip the sign of our negative // max stat if we are in dual phase. - //if (nStatFlags & STATS_MAX_MIN) + // if (nStatFlags & STATS_MAX_MIN) // nStatNegGrand = -nStatNegGrand; - //if (nStatFlags & STATS_DUAL) { + // if (nStatFlags & STATS_DUAL) { if (ePhaseMode == PHASE_MODE_DUAL) { nStatNegGrand = -nStatNegGrand; - NTA_ASSERT(nStatNegGrand >= 0); + NTA_ASSERT(nStatNegGrand >= 0); anStatNegGrand[0] = (unsigned int)(nStatNegGrand + 1); } } - // Debug + // Debug #ifdef DEBUG - for (int kk=0; kk<1; kk++) { + for (int kk = 0; kk < 1; kk++) { fprintf(stdout, "anStatPosGrand[%d]: %d\n", kk, anStatPosGrand[kk]); fprintf(stdout, "anStatNegGrand[%d]: %d\n", kk, anStatNegGrand[kk]); } #endif // DEBUG } - // FUNCTION: _doConvolution() // 1. Convolve integerized input image (in bufferIn) against // each filter in gabor filter bank, storing the result @@ -1807,55 +1771,33 @@ void _doConvolution_bbox( const NUMPY_ARRAY * psBufferIn, // 2. While performing convolution, keeps track of the // neccessary statistics for use in normalization // during Pass II. 
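As a companion to the statistics gathering above, here is a minimal sketch of what `_computeNormalizers()` is described as doing once a pass over the image is complete: 'fixed' normalization simply divides by the 8-bit input maximum, while 'mean' normalization divides the accumulated sums by the number of contributing pixels. This is not the patched code; the enum and the names are simplified stand-ins.

// Illustrative sketch (not part of the patch) of how the grand statistics
// gathered during convolution become normalizers.
enum class NormMethod { Fixed, Mean };

void computeNormalizersSketch(int &statPos, int &statNeg,
                              NormMethod method, int numPixels) {
  if (method == NormMethod::Fixed) {
    statPos = 255;            // fixed mode: normalize by the 8-bit maximum
    statNeg = -statPos;
  } else if (numPixels > 0) { // mean mode: convert accumulated sums into averages
    statPos /= numPixels;
    statNeg /= numPixels;
  }
  // The callers then store (stat + 1), so integer round-off can never leave
  // the stored normalizer smaller than the largest actual response.
}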
-void _doConvolution( const NUMPY_ARRAY * psBufferIn, - const NUMPY_ARRAY * psBufferOut, - const NUMPY_ARRAY * psGaborBank, - const NUMPY_ARRAY * psAlpha, - const BBOX * psInputBox, - const BBOX * psOutputBox, - PHASE_MODE ePhaseMode, - NORMALIZE_METHOD eNormalizeMethod, - NORMALIZE_MODE eNormalizeMode, - unsigned int anStatPosGrand[], - unsigned int anStatNegGrand[] ) { +void _doConvolution(const NUMPY_ARRAY *psBufferIn, + const NUMPY_ARRAY *psBufferOut, + const NUMPY_ARRAY *psGaborBank, const NUMPY_ARRAY *psAlpha, + const BBOX *psInputBox, const BBOX *psOutputBox, + PHASE_MODE ePhaseMode, NORMALIZE_METHOD eNormalizeMethod, + NORMALIZE_MODE eNormalizeMode, + unsigned int anStatPosGrand[], + unsigned int anStatNegGrand[]) { if (psAlpha) - _doConvolution_alpha(psBufferIn, - psBufferOut, - psGaborBank, - psAlpha, - psInputBox, - psOutputBox, - ePhaseMode, - eNormalizeMethod, - eNormalizeMode, - anStatPosGrand, - anStatNegGrand); + _doConvolution_alpha(psBufferIn, psBufferOut, psGaborBank, psAlpha, + psInputBox, psOutputBox, ePhaseMode, eNormalizeMethod, + eNormalizeMode, anStatPosGrand, anStatNegGrand); else - _doConvolution_bbox(psBufferIn, - psBufferOut, - psGaborBank, - psInputBox, - psOutputBox, - ePhaseMode, - eNormalizeMethod, - eNormalizeMode, - anStatPosGrand, - anStatNegGrand); + _doConvolution_bbox(psBufferIn, psBufferOut, psGaborBank, psInputBox, + psOutputBox, ePhaseMode, eNormalizeMethod, + eNormalizeMode, anStatPosGrand, anStatNegGrand); } - // FUNCTION: _computeGains() // PURPOSE: Compute the positive (and in the case of dual-phase // filter banks, the negative as well) gains to use by taking // into account the normalizing factor. -void _computeGains(float fGain, - unsigned int nStatPosGrand, - unsigned int nStatNegGrand, - PHASE_MODE ePhaseMode, - PHASENORM_MODE ePhaseNormMode, - float & fGainPos, - float & fGainNeg) { +void _computeGains(float fGain, unsigned int nStatPosGrand, + unsigned int nStatNegGrand, PHASE_MODE ePhaseMode, + PHASENORM_MODE ePhaseNormMode, float &fGainPos, + float &fGainNeg) { fGainPos = fGain; NTA_ASSERT(nStatPosGrand > 0); fGainPos /= float(nStatPosGrand); @@ -1869,7 +1811,7 @@ void _computeGains(float fGain, if (ePhaseNormMode == PHASENORM_MODE_INDIV) fGainNeg = -fGain / float(nStatNegGrand); - // Both phases are normalized using the same + // Both phases are normalized using the same // normalizing factor (max or mean) else { NTA_ASSERT(ePhaseNormMode == PHASENORM_MODE_COMBO); @@ -1878,41 +1820,30 @@ void _computeGains(float fGain, if (nStatNegGrand > nStatPosGrand) { fGainNeg = -fGain / float(nStatNegGrand); fGainPos = -fGainNeg; - } - else + } else fGainNeg = -fGainPos; NTA_ASSERT(fGainNeg == -fGainPos); } } } - // FUNCTION: _postProcess() // 1. Perform rectification // 2. Apply gain (in general, this will be different for each // image based on auto-normalization results); // 3. Apply post-processing method if any; // 4. Convert from integer 32 to float. 
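To make the dual-phase gain logic in `_computeGains()` above concrete, the short program below works through one made-up example: with individual phase normalization each output plane is scaled by its own grand statistic, whereas combined normalization scales both planes by the larger of the two statistics so the positive and negative planes stay on a common scale. The numbers are arbitrary and the program is illustrative only.

// Illustrative example (not part of the patch) of _computeGains()-style
// normalization for a dual-phase filter bank. Values are made up.
#include <algorithm>
#include <cassert>
#include <cstdio>

int main() {
  const float fGain = 1.0f;
  const unsigned int statPos = 800;  // grand max of positive responses (+1)
  const unsigned int statNeg = 1200; // grand max magnitude of negative responses (+1)

  // Individual normalization: each phase uses its own normalizer.
  float gainPosIndiv = fGain / float(statPos);  //  0.00125
  float gainNegIndiv = -fGain / float(statNeg); // -0.000833...

  // Combined normalization: both phases share the larger normalizer.
  float shared = float(std::max(statPos, statNeg));
  float gainPosCombo = fGain / shared;
  float gainNegCombo = -gainPosCombo;
  assert(gainNegCombo == -gainPosCombo);

  std::printf("indiv: %+g / %+g   combo: %+g / %+g\n",
              gainPosIndiv, gainNegIndiv, gainPosCombo, gainNegCombo);
  return 0;
}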
-void _postProcess( const NUMPY_ARRAY * psBufferIn, - const NUMPY_ARRAY * psOutput, - //const NUMPY_ARRAY * psBBox, - const BBOX * psBox, - PHASE_MODE ePhaseMode, - int nShrinkage, - EDGE_MODE eEdgeMode, - float fGainConstant, - NORMALIZE_METHOD eNormalizeMethod, - NORMALIZE_MODE eNormalizeMode, - PHASENORM_MODE ePhaseNormMode, - POSTPROC_METHOD ePostProcMethod, - float fPostProcSlope, - float fPostProcMidpoint, - float fPostProcMin, - float fPostProcMax, - const unsigned int anStatPosGrand[], - const unsigned int anStatNegGrand[], - const NUMPY_ARRAY * psPostProcLUT, - float fPostProcScalar) { +void _postProcess(const NUMPY_ARRAY *psBufferIn, const NUMPY_ARRAY *psOutput, + // const NUMPY_ARRAY * psBBox, + const BBOX *psBox, PHASE_MODE ePhaseMode, int nShrinkage, + EDGE_MODE eEdgeMode, float fGainConstant, + NORMALIZE_METHOD eNormalizeMethod, + NORMALIZE_MODE eNormalizeMode, PHASENORM_MODE ePhaseNormMode, + POSTPROC_METHOD ePostProcMethod, float fPostProcSlope, + float fPostProcMidpoint, float fPostProcMin, + float fPostProcMax, const unsigned int anStatPosGrand[], + const unsigned int anStatNegGrand[], + const NUMPY_ARRAY *psPostProcLUT, float fPostProcScalar) { int i, j; int nFilterIndex; int nResponse; @@ -1922,25 +1853,27 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, unsigned int nSingleBin; int nDualBin; float fGain, fGainPos = 0.0f, fGainNeg = 0.0f; - float * pfOutputRowPos = nullptr; - float * pfOutputRowNeg = nullptr; - float * pfOutputPos = nullptr; - float * pfOutputNeg = nullptr; - int * pnInputRow = nullptr; + float *pfOutputRowPos = nullptr; + float *pfOutputRowNeg = nullptr; + float *pfOutputPos = nullptr; + float *pfOutputNeg = nullptr; + int *pnInputRow = nullptr; // Locate start of first input plane - int * pnInputBase = (int *)psBufferIn->pData; + int *pnInputBase = (int *)psBufferIn->pData; int nInputRowStride = IMAGESET_ROWSTRIDE(psBufferIn) / sizeof(*pnInputBase); // Locate start of first output plane - float * pfOutputBase = (float *)psOutput->pData; + float *pfOutputBase = (float *)psOutput->pData; int nOutputRowStride = IMAGESET_ROWSTRIDE(psOutput) / sizeof(*pfOutputBase); // Guard against buffer over-runs #ifdef DEBUG // Start/end of memory - const char * pDebugOutputSOMB = (const char*)(psOutput->pData); - const char * pDebugOutputEOMB = pDebugOutputSOMB + IMAGESET_PLANESTRIDE(psOutput) * IMAGESET_PLANES(psOutput); + const char *pDebugOutputSOMB = (const char *)(psOutput->pData); + const char *pDebugOutputEOMB = + pDebugOutputSOMB + + IMAGESET_PLANESTRIDE(psOutput) * IMAGESET_PLANES(psOutput); #endif // DEBUG // Take into account bounding box suppression @@ -1962,7 +1895,7 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, int nTotalLeftovers = IMAGESET_COLS(psOutput) - (nTotalQuadsPerRow << 2); // Access the post-processing LUT - const float * pfPostProcLUT = nullptr; + const float *pfPostProcLUT = nullptr; int nNumLutBins = 0; int nMaxLutBin = 0; unsigned int nOverflowMask = 0x0; @@ -1970,7 +1903,7 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, pfPostProcLUT = (const float *)psPostProcLUT->pData; nNumLutBins = VECTOR_PLANES(psPostProcLUT); - // If we are in single-phase mode, then we'll + // If we are in single-phase mode, then we'll // just use the LUT as is // If we're in dual-phase mode, then we really // have two LUTs, each of which is bi-polar @@ -1979,7 +1912,7 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, // Generate a bit mask that can efficiently detect // whether a bin will overflow our LUT (so it can be clipped) - //unsigned 
int nOverflowMask = ~(nNumLutBins | nMaxLutBin); + // unsigned int nOverflowMask = ~(nNumLutBins | nMaxLutBin); // We'll use a trick to speed up the inner loop: // if we are using a normalization method other // than MEAN, then we'll be guaranteed to never @@ -2001,28 +1934,20 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, // If we are using global normalization (not per-orientation) then // we compute the final gain once if (eNormalizeMode == NORMALIZE_MODE_GLOBAL) { - _computeGains(fGain, - anStatPosGrand[0], - anStatNegGrand[0], - ePhaseMode, - ePhaseNormMode, - fGainPos, - fGainNeg); + _computeGains(fGain, anStatPosGrand[0], anStatNegGrand[0], ePhaseMode, + ePhaseNormMode, fGainPos, fGainNeg); } // Process each output plane - for (nFilterIndex=0; nFilterIndexnTop; //------------------------------------------------ // Process each row within the bounding box (vertically) - for (j=nOutputRows; j; j--) { + for (j = nOutputRows; j; j--) { // Set up row pointers pnInputRow = pnInput; @@ -2106,17 +2030,16 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, pfOutputRowNeg = pfOutputNeg; //------------------------------------------------ - // Fill in zeros outside the bounding box + // Fill in zeros outside the bounding box if (ePhaseMode == PHASE_MODE_SINGLE) { - for (i=psBox->nLeft; i; i--) + for (i = psBox->nLeft; i; i--) *pfOutputRowPos++ = NULL_RESPONSE; - } - else { + } else { NTA_ASSERT(ePhaseMode == PHASE_MODE_DUAL); - for (i=psBox->nLeft; i; i--) { + for (i = psBox->nLeft; i; i--) { *pfOutputRowPos++ = NULL_RESPONSE; *pfOutputRowNeg++ = NULL_RESPONSE; - } + } } // Skip any rows above the bounding box pnInputRow += psBox->nLeft; @@ -2132,7 +2055,7 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, if (ePostProcMethod == POSTPROC_METHOD_RAW) { // Process this ouput row, one output location at a time - for (i=nOutputQuadsPerRow; i; i--) { + for (i = nOutputQuadsPerRow; i; i--) { // Apply abs() and clipping, and then advance to next location // Note: our gabor filter masks were pre-scaled by shifting @@ -2147,14 +2070,14 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, // Advance pointers pnInputRow += 4; pfOutputRowPos += 4; - } // for (i=nOutputQuadsPerRow; i; i--) + } // for (i=nOutputQuadsPerRow; i; i--) // Handle leftovers - for (i=nNumLeftovers; i; i--) { + for (i = nNumLeftovers; i; i--) { *pfOutputRowPos++ = fGainPos * (float)(IABS32(*pnInputRow)); - pnInputRow ++; - } // for (i=nNumLeftovers; i; i--) - } // no post-processing + pnInputRow++; + } // for (i=nNumLeftovers; i; i--) + } // no post-processing // Post-processing needed else { @@ -2164,10 +2087,11 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, if (nOverflowMask) { // Process this ouput row, four locations at a time - for (i=nOutputQuadsPerRow; i; i--) { + for (i = nOutputQuadsPerRow; i; i--) { // Compute LUT bin to use for looking up post-processed value - nSingleBin = (unsigned int)(IABS32(*pnInputRow) / nDiscreteGainPos); + nSingleBin = + (unsigned int)(IABS32(*pnInputRow) / nDiscreteGainPos); // In 'mean' normalization, the maximum values are essentially // unbounded; so we have to clip them to make sure they // don't overflow our LUT @@ -2179,19 +2103,22 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, *pfOutputRowPos = pfPostProcLUT[nSingleBin]; // Second location - nSingleBin = (unsigned int)(IABS32(pnInputRow[1]) / nDiscreteGainPos); + nSingleBin = + (unsigned int)(IABS32(pnInputRow[1]) / nDiscreteGainPos); if (nSingleBin & nOverflowMask) nSingleBin = nMaxLutBin; pfOutputRowPos[1] = 
pfPostProcLUT[nSingleBin]; // Third location - nSingleBin = (unsigned int)(IABS32(pnInputRow[2]) / nDiscreteGainPos); + nSingleBin = + (unsigned int)(IABS32(pnInputRow[2]) / nDiscreteGainPos); if (nSingleBin & nOverflowMask) nSingleBin = nMaxLutBin; pfOutputRowPos[2] = pfPostProcLUT[nSingleBin]; // Fourth location - nSingleBin = (unsigned int)(IABS32(pnInputRow[3]) / nDiscreteGainPos); + nSingleBin = + (unsigned int)(IABS32(pnInputRow[3]) / nDiscreteGainPos); if (nSingleBin & nOverflowMask) nSingleBin = nMaxLutBin; pfOutputRowPos[3] = pfPostProcLUT[nSingleBin]; @@ -2209,10 +2136,10 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, // Advance pointers pnInputRow += 4; pfOutputRowPos += 4; - } // for (i=nOutputQuadsPerRow; i; i--) + } // for (i=nOutputQuadsPerRow; i; i--) // Handle leftovers - for (i=nNumLeftovers; i; i--) { + for (i = nNumLeftovers; i; i--) { // Compute LUT bin to use for looking up post-processed value nResponse = *pnInputRow++; nSingleBin = (unsigned int)(IABS32(nResponse) / nDiscreteGainPos); @@ -2225,55 +2152,59 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, nSingleBin = nMaxLutBin; // Apply LUT-based post-processing function *pfOutputRowPos++ = pfPostProcLUT[nSingleBin]; - } // for (i=nNumLeftovers; i; i--) - } // if (nOverflowMask) + } // for (i=nNumLeftovers; i; i--) + } // if (nOverflowMask) // If we don't have to worry about possibly overflowing // our LUT, then we'll go through this faster path: else { // Process this ouput row, four locations at a time - for (i=nOutputQuadsPerRow; i; i--) { + for (i = nOutputQuadsPerRow; i; i--) { // Memory protection #ifdef DEBUG - NTA_ASSERT((const char*)pfOutputRowPos >= pDebugOutputSOMB); - NTA_ASSERT((const char*)&(pfOutputRowPos[3]) < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowPos >= pDebugOutputSOMB); + NTA_ASSERT((const char *)&(pfOutputRowPos[3]) < pDebugOutputEOMB); #endif // DEBUG // Compute LUT bin to use for looking up post-processed value - nSingleBin = (unsigned int)(IABS32(*pnInputRow) / nDiscreteGainPos); + nSingleBin = + (unsigned int)(IABS32(*pnInputRow) / nDiscreteGainPos); NTA_ASSERT(nSingleBin <= (unsigned int)nMaxLutBin); // Apply LUT-based post-processing function *pfOutputRowPos = pfPostProcLUT[nSingleBin]; // Second location - nSingleBin = (unsigned int)(IABS32(pnInputRow[1]) / nDiscreteGainPos); + nSingleBin = + (unsigned int)(IABS32(pnInputRow[1]) / nDiscreteGainPos); NTA_ASSERT(nSingleBin <= (unsigned int)nMaxLutBin); pfOutputRowPos[1] = pfPostProcLUT[nSingleBin]; // Third location - nSingleBin = (unsigned int)(IABS32(pnInputRow[2]) / nDiscreteGainPos); + nSingleBin = + (unsigned int)(IABS32(pnInputRow[2]) / nDiscreteGainPos); NTA_ASSERT(nSingleBin <= (unsigned int)nMaxLutBin); pfOutputRowPos[2] = pfPostProcLUT[nSingleBin]; // Fourth location - nSingleBin = (unsigned int)(IABS32(pnInputRow[3]) / nDiscreteGainPos); + nSingleBin = + (unsigned int)(IABS32(pnInputRow[3]) / nDiscreteGainPos); NTA_ASSERT(nSingleBin <= (unsigned int)nMaxLutBin); pfOutputRowPos[3] = pfPostProcLUT[nSingleBin]; // Advance pointers pnInputRow += 4; pfOutputRowPos += 4; - } // for (i=nOutputQuadsPerRow; i; i--) + } // for (i=nOutputQuadsPerRow; i; i--) // Handle leftovers - for (i=nNumLeftovers; i; i--) { + for (i = nNumLeftovers; i; i--) { // Memory protection #ifdef DEBUG - NTA_ASSERT((const char*)pfOutputRowPos >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pfOutputRowPos < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowPos >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowPos < 
pDebugOutputEOMB); #endif // DEBUG // Compute LUT bin to use for looking up post-processed value @@ -2282,11 +2213,11 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, NTA_ASSERT(nSingleBin <= (unsigned int)nMaxLutBin); // Apply LUT-based post-processing function *pfOutputRowPos++ = pfPostProcLUT[nSingleBin]; - } // for (i=nNumLeftovers; i; i--) - } // if we don't have an overflow problem to worry about - } // Non-raw post-processing needed - } // Single-phase - + } // for (i=nNumLeftovers; i; i--) + } // if we don't have an overflow problem to worry about + } // Non-raw post-processing needed + } // Single-phase + //------------------------------------------------ // Dual-phases else { @@ -2296,14 +2227,14 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, if (ePostProcMethod == POSTPROC_METHOD_RAW) { // Process this ouput row, one output location at a time - for (i=nOutputQuadsPerRow; i; i--) { + for (i = nOutputQuadsPerRow; i; i--) { // Memory protection #ifdef DEBUG - NTA_ASSERT((const char*)pfOutputRowPos >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pfOutputRowNeg >= pDebugOutputSOMB); - NTA_ASSERT((const char*)&(pfOutputRowPos[3]) < pDebugOutputEOMB); - NTA_ASSERT((const char*)&(pfOutputRowNeg[3]) < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowPos >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowNeg >= pDebugOutputSOMB); + NTA_ASSERT((const char *)&(pfOutputRowPos[3]) < pDebugOutputEOMB); + NTA_ASSERT((const char *)&(pfOutputRowNeg[3]) < pDebugOutputEOMB); #endif // DEBUG // Generate two responses from one original convolution @@ -2311,8 +2242,7 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, if (fResponse >= 0.0f) { *pfOutputRowPos = fResponse * fGainPos; *pfOutputRowNeg = 0.0f; - } - else { + } else { *pfOutputRowPos = 0.0f; *pfOutputRowNeg = fResponse * fGainNeg; } @@ -2323,8 +2253,7 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, if (fResponse >= 0.0f) { pfOutputRowPos[1] = fResponse * fGainPos; pfOutputRowNeg[1] = 0.0f; - } - else { + } else { pfOutputRowPos[1] = 0.0f; pfOutputRowNeg[1] = fResponse * fGainNeg; } @@ -2333,8 +2262,7 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, if (fResponse >= 0.0f) { pfOutputRowPos[2] = fResponse * fGainPos; pfOutputRowNeg[2] = 0.0f; - } - else { + } else { pfOutputRowPos[2] = 0.0f; pfOutputRowNeg[2] = fResponse * fGainNeg; } @@ -2343,8 +2271,7 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, if (fResponse >= 0.0f) { pfOutputRowPos[3] = fResponse * fGainPos; pfOutputRowNeg[3] = 0.0f; - } - else { + } else { pfOutputRowPos[3] = 0.0f; pfOutputRowNeg[3] = fResponse * fGainNeg; } @@ -2356,27 +2283,26 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, } // Handle leftovers - for (i=nNumLeftovers; i; i--) { + for (i = nNumLeftovers; i; i--) { // Memory protection #ifdef DEBUG - NTA_ASSERT((const char*)pfOutputRowPos >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pfOutputRowPos < pDebugOutputEOMB); - NTA_ASSERT((const char*)pfOutputRowNeg >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pfOutputRowNeg < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowPos >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowPos < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowNeg >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowNeg < pDebugOutputEOMB); #endif // DEBUG fResponse = (float)(*pnInputRow++); if (fResponse >= 0.0f) { *pfOutputRowPos++ = fResponse * fGainPos; *pfOutputRowNeg++ = 0.0f; - } - else { + } else { *pfOutputRowPos++ = 0.0f; *pfOutputRowNeg++ = fResponse 
* fGainNeg; } - } // for (i=nNumLeftovers; i; i--) - } // if doing 'raw' post-processing + } // for (i=nNumLeftovers; i; i--) + } // if doing 'raw' post-processing // Non-trivial post-processing else { @@ -2386,19 +2312,19 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, if (nOverflowMask) { // Process one pixel at a time - for (i=nOutputCols; i; i--) { + for (i = nOutputCols; i; i--) { - // Compute discretized response (in terms of the + // Compute discretized response (in terms of the // LUT bin). This value 'nDualBin' could be positive // or negative. nDualBin = (*pnInputRow) / nDiscreteGainPos; // Memory protection #ifdef DEBUG - NTA_ASSERT((const char*)pfOutputRowPos >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pfOutputRowPos < pDebugOutputEOMB); - NTA_ASSERT((const char*)pfOutputRowNeg >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pfOutputRowNeg < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowPos >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowPos < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowNeg >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowNeg < pDebugOutputEOMB); #endif // DEBUG // If positive response @@ -2415,11 +2341,12 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, *pfOutputRowPos++ = pfPostProcLUT[nDualBin]; *pfOutputRowNeg++ = 0.0f; -//#ifdef DEBUG -// // TEMP TEMP TEMP -// if (pfOutputRowPos[-1] < 0.0f || pfOutputRowPos[-1] > 1.0f) -// NTA_ASSERT(false); -//#endif // DEBUG + //#ifdef DEBUG + // // TEMP TEMP TEMP + // if (pfOutputRowPos[-1] < 0.0f || + // pfOutputRowPos[-1] > 1.0f) + // NTA_ASSERT(false); + //#endif // DEBUG // Sanity checks NTA_ASSERT(pfOutputRowPos[-1] <= 1.0f); @@ -2436,11 +2363,12 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, NTA_ASSERT(nDualBin >= 0); *pfOutputRowNeg++ = pfPostProcLUT[nDualBin]; -//#ifdef DEBUG -// // TEMP TEMP TEMP -// if (pfOutputRowNeg[-1] < 0.0f || pfOutputRowNeg[-1] > 1.0f) -// NTA_ASSERT(false); -//#endif // DEBUG + //#ifdef DEBUG + // // TEMP TEMP TEMP + // if (pfOutputRowNeg[-1] < 0.0f || + // pfOutputRowNeg[-1] > 1.0f) + // NTA_ASSERT(false); + //#endif // DEBUG // Sanity checks NTA_ASSERT(pfOutputRowNeg[-1] <= 1.0f); @@ -2450,31 +2378,31 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, // Next input pnInputRow++; - } // for (i=nOutputQuadsPerRow; i; i--) - } // if (nOverflowMask) + } // for (i=nOutputQuadsPerRow; i; i--) + } // if (nOverflowMask) // If we don't have to worry about possibly overflowing // our LUT, then we'll go through this faster path: else { // Process one pixel at a time - for (i=nOutputCols; i; i--) { + for (i = nOutputCols; i; i--) { // Memory protection #ifdef DEBUG - NTA_ASSERT((const char*)pfOutputRowPos >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pfOutputRowPos < pDebugOutputEOMB); - NTA_ASSERT((const char*)pfOutputRowNeg >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pfOutputRowNeg < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowPos >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowPos < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowNeg >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowNeg < pDebugOutputEOMB); #endif // DEBUG - // Compute LUT bin to use for looking up final post-processed value + // Compute LUT bin to use for looking up final post-processed + // value nDualBin = (*pnInputRow) / nDiscreteGainPos; if (nDualBin >= 0) { NTA_ASSERT(nDualBin <= nMaxLutBin); *pfOutputRowPos++ = pfPostProcLUT[nDualBin]; *pfOutputRowNeg++ = 0.0f; - } - else { + } else { 
*pfOutputRowPos++ = 0.0f; nDualBin = (*pnInputRow) / nDiscreteGainNeg; NTA_ASSERT(nDualBin <= nMaxLutBin); @@ -2488,39 +2416,38 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, NTA_ASSERT(pfOutputRowNeg[-1] >= 0.0f); pnInputRow++; - } // for (i=nOutputQuadsPerRow; i; i--) - } // if we have no overflow mask to worry about - } // non-raw post-processing - } // Dual-phase + } // for (i=nOutputQuadsPerRow; i; i--) + } // if we have no overflow mask to worry about + } // non-raw post-processing + } // Dual-phase //------------------------------------------------ // Fill in zeros to the right of the bounding box. if (ePhaseMode == PHASE_MODE_SINGLE) { - for (i=nNumBlankRightCols; i; i--) { + for (i = nNumBlankRightCols; i; i--) { // Memory protection #ifdef DEBUG - NTA_ASSERT((const char*)pfOutputRowPos >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowPos >= pDebugOutputSOMB); #endif // DEBUG *pfOutputRowPos++ = NULL_RESPONSE; } - } - else { + } else { NTA_ASSERT(ePhaseMode == PHASE_MODE_DUAL); - for (i=nNumBlankRightCols; i; i--) { + for (i = nNumBlankRightCols; i; i--) { // Memory protection #ifdef DEBUG - NTA_ASSERT((const char*)pfOutputRowPos >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pfOutputRowPos < pDebugOutputEOMB); - NTA_ASSERT((const char*)pfOutputRowNeg >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pfOutputRowNeg < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowPos >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowPos < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowNeg >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowNeg < pDebugOutputEOMB); #endif // DEBUG *pfOutputRowPos++ = NULL_RESPONSE; *pfOutputRowNeg++ = NULL_RESPONSE; - } + } } // Advance to next rows @@ -2532,18 +2459,18 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, //------------------------------------------------ // Zero out any rows below the bounding box - for (j=nNumBlankBottomRows; j; j--) { + for (j = nNumBlankBottomRows; j; j--) { // Single phase if (ePhaseMode == PHASE_MODE_SINGLE) { pfOutputRowPos = pfOutputPos; // Hopefully the compiler will use SIMD for this: - for (i=nTotalQuadsPerRow; i; i--) { + for (i = nTotalQuadsPerRow; i; i--) { // Memory protection #ifdef DEBUG - NTA_ASSERT((const char*)pfOutputRowPos >= pDebugOutputSOMB); - NTA_ASSERT((const char*)&(pfOutputRowPos[3]) < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowPos >= pDebugOutputSOMB); + NTA_ASSERT((const char *)&(pfOutputRowPos[3]) < pDebugOutputEOMB); #endif // DEBUG pfOutputRowPos[0] = NULL_RESPONSE; @@ -2555,12 +2482,12 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, pfOutputRowPos += 4; } // Handle any leftovers that don't fit in a quad - for (i=nTotalLeftovers; i; i--) { + for (i = nTotalLeftovers; i; i--) { // Memory protection #ifdef DEBUG - NTA_ASSERT((const char*)pfOutputRowPos >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pfOutputRowPos < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowPos >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowPos < pDebugOutputEOMB); #endif // DEBUG *pfOutputRowPos++ = NULL_RESPONSE; @@ -2576,14 +2503,14 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, pfOutputRowPos = pfOutputPos; pfOutputRowNeg = pfOutputNeg; // Hopefully the compiler will use SIMD for this: - for (i=nTotalQuadsPerRow; i; i--) { + for (i = nTotalQuadsPerRow; i; i--) { // Memory protection #ifdef DEBUG - NTA_ASSERT((const char*)pfOutputRowPos >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pfOutputRowNeg >= pDebugOutputSOMB); - 
NTA_ASSERT((const char*)&(pfOutputRowPos[3]) < pDebugOutputEOMB); - NTA_ASSERT((const char*)&(pfOutputRowNeg[3]) < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowPos >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowNeg >= pDebugOutputSOMB); + NTA_ASSERT((const char *)&(pfOutputRowPos[3]) < pDebugOutputEOMB); + NTA_ASSERT((const char *)&(pfOutputRowNeg[3]) < pDebugOutputEOMB); #endif // DEBUG pfOutputRowPos[0] = NULL_RESPONSE; @@ -2600,14 +2527,14 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, pfOutputRowNeg += 4; } // Handle any leftovers that don't fit in a quad - for (i=nTotalLeftovers; i; i--) { + for (i = nTotalLeftovers; i; i--) { // Memory protection #ifdef DEBUG - NTA_ASSERT((const char*)pfOutputRowPos >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pfOutputRowPos < pDebugOutputEOMB); - NTA_ASSERT((const char*)pfOutputRowNeg >= pDebugOutputSOMB); - NTA_ASSERT((const char*)pfOutputRowNeg < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowPos >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowPos < pDebugOutputEOMB); + NTA_ASSERT((const char *)pfOutputRowNeg >= pDebugOutputSOMB); + NTA_ASSERT((const char *)pfOutputRowNeg < pDebugOutputEOMB); #endif // DEBUG *pfOutputRowPos++ = NULL_RESPONSE; @@ -2617,32 +2544,32 @@ void _postProcess( const NUMPY_ARRAY * psBufferIn, // Advance our row pointer(s) pfOutputPos += nOutputRowStride; pfOutputNeg += nOutputRowStride; - } // Dual phase - } // for (j=nNumBlankBottomRows; j; j--) + } // Dual phase + } // for (j=nNumBlankBottomRows; j; j--) // Advance to correct plane for gabor filter and output buffer - pnInputBase += IMAGESET_PLANESTRIDE(psBufferIn) / sizeof(*pnInput); + pnInputBase += IMAGESET_PLANESTRIDE(psBufferIn) / sizeof(*pnInput); pfOutputBase += IMAGESET_PLANESTRIDE(psOutput) / sizeof(*pfOutputBase); } // for each filter (output plane) } - // FUNCTION: _zeroOutputs() // PURPOSE: Special case for when the output planes // have to be uniformly zero response (e.g., when // there is not enough pixels in the input image to // compute a single response.) 
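The LUT-driven branch of `_postProcess()` shown above reduces each response to a table index: divide the absolute response by the per-bin gain, clip any index that would overflow the table, and read the post-processed value. A simplified sketch follows; it replaces the precomputed overflow bit mask of the real code with an explicit bounds check, and all names here are invented for illustration.

// Illustrative sketch (not part of the patch) of single-phase LUT
// post-processing. 'lut' must be non-empty.
#include <cstdlib>
#include <vector>

float lutPostProcess(int response,                 // raw convolution response
                     int discreteGain,             // response units per LUT bin (> 0)
                     const std::vector<float> &lut) {
  unsigned int bin = (unsigned int)(std::abs(response) / discreteGain);
  if (bin >= lut.size())                           // clip responses that overflow the LUT
    bin = (unsigned int)lut.size() - 1;
  return lut[bin];                                 // table lookup replaces the sigmoid math
}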
-void _zeroOutputs(const NUMPY_ARRAY * psOutput) { +void _zeroOutputs(const NUMPY_ARRAY *psOutput) { int k, j, i; - float * pfOutputRow = nullptr; - float * pfOutputPtr = nullptr; + float *pfOutputRow = nullptr; + float *pfOutputPtr = nullptr; // Locate start of first output plane - float * pfOutputBase = (float *)psOutput->pData; + float *pfOutputBase = (float *)psOutput->pData; int nOutputRowStride = IMAGESET_ROWSTRIDE(psOutput) / sizeof(*pfOutputBase); - int nOutputPlaneStride = IMAGESET_PLANESTRIDE(psOutput) / sizeof(*pfOutputBase); + int nOutputPlaneStride = + IMAGESET_PLANESTRIDE(psOutput) / sizeof(*pfOutputBase); // Take into account bounding box suppression int nOutputRows = IMAGESET_ROWS(psOutput); @@ -2654,15 +2581,15 @@ void _zeroOutputs(const NUMPY_ARRAY * psOutput) { int nLeftovers = nOutputCols - (nQuadsPerRow << 2); // Zero out each response plane: - for (k=nNumPlanes; k; k-- ) { + for (k = nNumPlanes; k; k--) { // Zero out each row pfOutputRow = pfOutputBase; - for (j=nOutputRows; j; j--) { + for (j = nOutputRows; j; j--) { // Process most of the row in quads pfOutputPtr = pfOutputRow; - for (i=nQuadsPerRow; i; i--) { + for (i = nQuadsPerRow; i; i--) { *pfOutputPtr++ = 0.0f; *pfOutputPtr++ = 0.0f; *pfOutputPtr++ = 0.0f; @@ -2670,7 +2597,7 @@ void _zeroOutputs(const NUMPY_ARRAY * psOutput) { } // Handle any leftovers - for (i=nLeftovers; i; i--) + for (i = nLeftovers; i; i--) *pfOutputPtr++ = 0.0f; // Move to next row @@ -2682,14 +2609,13 @@ void _zeroOutputs(const NUMPY_ARRAY * psOutput) { } } - // FUNCTION: initFromPython() // PURPOSE: Initialize logging data structures when we are -// being called from python via ctypes as a dynamically -// loaded library. +// being called from python via ctypes as a dynamically +// loaded library. #ifdef INIT_FROM_PYTHON -NTA_EXPORT +NTA_EXPORT void initFromPython(unsigned long long refP) { PythonSystem_initFromReferenceP(refP); } @@ -2698,37 +2624,25 @@ void initFromPython(unsigned long long refP) { // FUNCTION: gaborCompute() // PURPOSE: GaborNode implementation NTA_EXPORT -int gaborCompute(const NUMPY_ARRAY * psGaborBank, - const NUMPY_ARRAY * psInput, - const NUMPY_ARRAY * psAlpha, - const NUMPY_ARRAY * psBBox, - const NUMPY_ARRAY * psImageBox, - const NUMPY_ARRAY * psOutput, - float fGainConstant, - EDGE_MODE eEdgeMode, - float fOffImageFillValue, - PHASE_MODE ePhaseMode, - NORMALIZE_METHOD eNormalizeMethod, - NORMALIZE_MODE eNormalizeMode, - PHASENORM_MODE ePhaseNormMode, - POSTPROC_METHOD ePostProcMethod, - float fPostProcSlope, - float fPostProcMidpoint, - float fPostProcMin, - float fPostProcMax, - const NUMPY_ARRAY * psBufferIn, - const NUMPY_ARRAY * psBufferOut, - const NUMPY_ARRAY * psPostProcLUT, - float fPostProcScalar - ) { +int gaborCompute(const NUMPY_ARRAY *psGaborBank, const NUMPY_ARRAY *psInput, + const NUMPY_ARRAY *psAlpha, const NUMPY_ARRAY *psBBox, + const NUMPY_ARRAY *psImageBox, const NUMPY_ARRAY *psOutput, + float fGainConstant, EDGE_MODE eEdgeMode, + float fOffImageFillValue, PHASE_MODE ePhaseMode, + NORMALIZE_METHOD eNormalizeMethod, + NORMALIZE_MODE eNormalizeMode, PHASENORM_MODE ePhaseNormMode, + POSTPROC_METHOD ePostProcMethod, float fPostProcSlope, + float fPostProcMidpoint, float fPostProcMin, + float fPostProcMax, const NUMPY_ARRAY *psBufferIn, + const NUMPY_ARRAY *psBufferOut, + const NUMPY_ARRAY *psPostProcLUT, float fPostProcScalar) { // Allocate a big chunk of storage on the stack for a temporary buffer // to hold our accummulated response statistics. 
unsigned int anStatPosGrand[MAXNUM_FILTERS]; unsigned int anStatNegGrand[MAXNUM_FILTERS]; - try - { + try { //------------------------------------------- // Sanity checks @@ -2741,21 +2655,28 @@ int gaborCompute(const NUMPY_ARRAY * psGaborBank, // Sanity check: edge mode and input/output dimensionalities // must make sense if (eEdgeMode == EDGE_MODE_CONSTRAINED) { - //NTA_ASSERT(IMAGE_COLS(psBufferIn) == ALIGN_4_CEIL(IMAGE_COLS(psInput))); - //NTA_ASSERT(IMAGE_ROWS(psBufferIn) == IMAGE_ROWS(psInput)); - //NTA_ASSERT(IMAGESET_COLS(psBufferOut) == ALIGN_4_CEIL(IMAGE_COLS(psInput) - nFilterDim + 1)); - //NTA_ASSERT(IMAGESET_ROWS(psBufferOut) == (IMAGE_ROWS(psInput) - nFilterDim + 1)); - NTA_ASSERT(IMAGESET_COLS(psBufferOut) == ALIGN_4_CEIL(IMAGESET_COLS(psOutput))); + // NTA_ASSERT(IMAGE_COLS(psBufferIn) == + // ALIGN_4_CEIL(IMAGE_COLS(psInput))); NTA_ASSERT(IMAGE_ROWS(psBufferIn) + // == IMAGE_ROWS(psInput)); NTA_ASSERT(IMAGESET_COLS(psBufferOut) == + // ALIGN_4_CEIL(IMAGE_COLS(psInput) - nFilterDim + 1)); + // NTA_ASSERT(IMAGESET_ROWS(psBufferOut) == (IMAGE_ROWS(psInput) - + // nFilterDim + 1)); + NTA_ASSERT(IMAGESET_COLS(psBufferOut) == + ALIGN_4_CEIL(IMAGESET_COLS(psOutput))); NTA_ASSERT(IMAGESET_ROWS(psBufferOut) == IMAGESET_ROWS(psOutput)); - } - else { + } else { NTA_ASSERT(eEdgeMode == EDGE_MODE_SWEEPOFF); - //NTA_ASSERT(IMAGE_COLS(psBufferIn) == ALIGN_4_CEIL(IMAGE_COLS(psInput) + nFilterDim - 1)); - //NTA_ASSERT(IMAGE_ROWS(psBufferIn) == (IMAGE_ROWS(psInput) + nFilterDim - 1)); - //NTA_ASSERT(IMAGE_COLS(psBufferIn) == (ALIGN_4_CEIL(IMAGESET_COLS(psBufferOut)) + nFilterDim - 1)); - NTA_ASSERT(IMAGE_COLS(psBufferIn) <= ALIGN_4_CEIL(IMAGESET_COLS(psBufferOut) + nFilterDim - 1)); - NTA_ASSERT(IMAGE_ROWS(psBufferIn) == (IMAGESET_ROWS(psBufferOut) + nFilterDim - 1)); - NTA_ASSERT(IMAGESET_COLS(psBufferOut) == ALIGN_4_CEIL(IMAGESET_COLS(psOutput))); + // NTA_ASSERT(IMAGE_COLS(psBufferIn) == ALIGN_4_CEIL(IMAGE_COLS(psInput) + + // nFilterDim - 1)); NTA_ASSERT(IMAGE_ROWS(psBufferIn) == + // (IMAGE_ROWS(psInput) + nFilterDim - 1)); + // NTA_ASSERT(IMAGE_COLS(psBufferIn) == + // (ALIGN_4_CEIL(IMAGESET_COLS(psBufferOut)) + nFilterDim - 1)); + NTA_ASSERT(IMAGE_COLS(psBufferIn) <= + ALIGN_4_CEIL(IMAGESET_COLS(psBufferOut) + nFilterDim - 1)); + NTA_ASSERT(IMAGE_ROWS(psBufferIn) == + (IMAGESET_ROWS(psBufferOut) + nFilterDim - 1)); + NTA_ASSERT(IMAGESET_COLS(psBufferOut) == + ALIGN_4_CEIL(IMAGESET_COLS(psOutput))); NTA_ASSERT(IMAGESET_ROWS(psBufferOut) == IMAGESET_ROWS(psOutput)); } @@ -2774,12 +2695,12 @@ int gaborCompute(const NUMPY_ARRAY * psGaborBank, NTA_ASSERT(IMAGESET_COLS(psBufferOut) % 4 == 0); // Make sure out "image box" (the box that defines the actual - // image portion of the psInput array) encloses our + // image portion of the psInput array) encloses our // "bounding box" (the box that defines the portion of image // pixels over which gabor responses are to be computed.) 
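A note on the dimension checks above: in 'constrained' edge mode a DxD filter is only applied where it fits entirely on the image, so each output plane loses D-1 rows and columns, while in 'sweep-off' mode the input buffer is padded first and the output keeps the image dimensions. The sketch below states that relationship explicitly; the shrinkage of filterDim - 1 is an assumption inferred from the commented-out assertions in this patch, not something the patch itself declares.

// Illustrative sketch (not part of the patch) of the output geometry the
// sanity checks above encode.
struct PlaneSize { int rows, cols; };

PlaneSize outputSize(PlaneSize image, int filterDim, bool constrainedEdgeMode) {
  if (constrainedEdgeMode) // filter must fit entirely on the image
    return {image.rows - (filterDim - 1), image.cols - (filterDim - 1)};
  return image;            // sweep-off: padding preserves the image dimensions
}
// Example: a 9x9 filter over a 64x48 image in constrained mode yields 56x40.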
- NTA_ASSERT(BBOX_LEFT(psBBox) >= BBOX_LEFT(psImageBox)); - NTA_ASSERT(BBOX_RIGHT(psBBox) <= BBOX_RIGHT(psImageBox)); - NTA_ASSERT(BBOX_TOP(psBBox) >= BBOX_TOP(psImageBox)); + NTA_ASSERT(BBOX_LEFT(psBBox) >= BBOX_LEFT(psImageBox)); + NTA_ASSERT(BBOX_RIGHT(psBBox) <= BBOX_RIGHT(psImageBox)); + NTA_ASSERT(BBOX_TOP(psBBox) >= BBOX_TOP(psImageBox)); NTA_ASSERT(BBOX_BOTTOM(psBBox) <= BBOX_BOTTOM(psImageBox)); // The alpha mask is optional, but if it is provided, @@ -2790,7 +2711,7 @@ int gaborCompute(const NUMPY_ARRAY * psGaborBank, } //------------------------------------------- - // Set up a bounding box that specifies the + // Set up a bounding box that specifies the // range of pixels for which Gabor responses // are to be competed: // sBoxInput: the locations in the (padded?) @@ -2802,44 +2723,43 @@ int gaborCompute(const NUMPY_ARRAY * psGaborBank, BBOX sBoxInput, sBoxOutput; if (eEdgeMode == EDGE_MODE_CONSTRAINED) { // Input - sBoxInput.nLeft = BBOX_LEFT(psBBox); - sBoxInput.nTop = BBOX_TOP(psBBox); - sBoxInput.nRight = sBoxInput.nLeft + BBOX_WIDTH(psBBox); - sBoxInput.nBottom = sBoxInput.nTop + BBOX_HEIGHT(psBBox); + sBoxInput.nLeft = BBOX_LEFT(psBBox); + sBoxInput.nTop = BBOX_TOP(psBBox); + sBoxInput.nRight = sBoxInput.nLeft + BBOX_WIDTH(psBBox); + sBoxInput.nBottom = sBoxInput.nTop + BBOX_HEIGHT(psBBox); // Output - sBoxOutput.nLeft = sBoxInput.nLeft; - sBoxOutput.nTop = sBoxInput.nTop; - sBoxOutput.nRight = sBoxOutput.nLeft + BBOX_WIDTH(psBBox) - nShrinkage; - sBoxOutput.nBottom = sBoxOutput.nTop + BBOX_HEIGHT(psBBox) - nShrinkage; - } - else { + sBoxOutput.nLeft = sBoxInput.nLeft; + sBoxOutput.nTop = sBoxInput.nTop; + sBoxOutput.nRight = sBoxOutput.nLeft + BBOX_WIDTH(psBBox) - nShrinkage; + sBoxOutput.nBottom = sBoxOutput.nTop + BBOX_HEIGHT(psBBox) - nShrinkage; + } else { NTA_ASSERT(eEdgeMode == EDGE_MODE_SWEEPOFF); // Input sBoxInput.nLeft = BBOX_LEFT(psBBox); - sBoxInput.nTop = BBOX_TOP(psBBox); - sBoxInput.nRight = sBoxInput.nLeft + BBOX_WIDTH(psBBox); - sBoxInput.nBottom = sBoxInput.nTop + BBOX_HEIGHT(psBBox); + sBoxInput.nTop = BBOX_TOP(psBBox); + sBoxInput.nRight = sBoxInput.nLeft + BBOX_WIDTH(psBBox); + sBoxInput.nBottom = sBoxInput.nTop + BBOX_HEIGHT(psBBox); // Output - sBoxOutput.nLeft = sBoxInput.nLeft; - sBoxOutput.nTop = sBoxInput.nTop; - sBoxOutput.nRight = sBoxOutput.nLeft + BBOX_WIDTH(psBBox); - sBoxOutput.nBottom = sBoxOutput.nTop + BBOX_HEIGHT(psBBox); + sBoxOutput.nLeft = sBoxInput.nLeft; + sBoxOutput.nTop = sBoxInput.nTop; + sBoxOutput.nRight = sBoxOutput.nLeft + BBOX_WIDTH(psBBox); + sBoxOutput.nBottom = sBoxOutput.nTop + BBOX_HEIGHT(psBBox); } // Debugging #ifdef DEBUG fprintf(stdout, "sBoxInput: %d %d %d %d (%d x %d)\n", sBoxInput.nLeft, sBoxInput.nTop, sBoxInput.nRight, sBoxInput.nBottom, - (sBoxInput.nRight - sBoxInput.nLeft), + (sBoxInput.nRight - sBoxInput.nLeft), (sBoxInput.nBottom - sBoxInput.nTop)); fprintf(stdout, "sBoxOutput: %d %d %d %d (%d x %d)\n", sBoxOutput.nLeft, sBoxOutput.nTop, sBoxOutput.nRight, sBoxOutput.nBottom, - (sBoxOutput.nRight - sBoxOutput.nLeft), + (sBoxOutput.nRight - sBoxOutput.nLeft), (sBoxOutput.nBottom - sBoxOutput.nTop)); #endif // DEBUG //------------------------------------------- - // Handle case in which bounding box is smaller than + // Handle case in which bounding box is smaller than // our filter: if ((BBOX_RIGHT(psBBox) - BBOX_LEFT(psBBox) < nFilterDim) || (BBOX_BOTTOM(psBBox) - BBOX_TOP(psBBox) < nFilterDim)) { @@ -2852,13 +2772,8 @@ int gaborCompute(const NUMPY_ARRAY * psGaborBank, // 1. 
Convert input image from float to integer32. // 2. If EDGE_MODE is SWEEPOFF, then add "padding pixels" // around the edges of the integrized input plane. - _prepareInput(psInput, - psBufferIn, - nFilterDim >> 1, - psBBox, - psImageBox, - eEdgeMode, - fOffImageFillValue); + _prepareInput(psInput, psBufferIn, nFilterDim >> 1, psBBox, psImageBox, + eEdgeMode, fOffImageFillValue); //------------------------------------------- // Perform convolution: @@ -2868,17 +2783,9 @@ int gaborCompute(const NUMPY_ARRAY * psGaborBank, // 2. While performing convolution, keeps track of the // neccessary statistics for use in normalization // during Pass II. - _doConvolution(psBufferIn, - psBufferOut, - psGaborBank, - psAlpha, - &sBoxInput, - &sBoxOutput, - ePhaseMode, - eNormalizeMethod, - eNormalizeMode, - anStatPosGrand, - anStatNegGrand); + _doConvolution(psBufferIn, psBufferOut, psGaborBank, psAlpha, &sBoxInput, + &sBoxOutput, ePhaseMode, eNormalizeMethod, eNormalizeMode, + anStatPosGrand, anStatNegGrand); //------------------------------------------- // Perform normalization and post-processing @@ -2887,28 +2794,12 @@ int gaborCompute(const NUMPY_ARRAY * psGaborBank, // image based on auto-normalization results); // 3. Apply post-processing method if any; // 4. Convert from integer 32 to float. - _postProcess(psBufferOut, - psOutput, - &sBoxOutput, - ePhaseMode, - nShrinkage, - eEdgeMode, - fGainConstant, - eNormalizeMethod, - eNormalizeMode, - ePhaseNormMode, - ePostProcMethod, - fPostProcSlope, - fPostProcMidpoint, - fPostProcMin, - fPostProcMax, - anStatPosGrand, - anStatNegGrand, - psPostProcLUT, - fPostProcScalar); - } - catch(std::exception& e) - { + _postProcess(psBufferOut, psOutput, &sBoxOutput, ePhaseMode, nShrinkage, + eEdgeMode, fGainConstant, eNormalizeMethod, eNormalizeMode, + ePhaseNormMode, ePostProcMethod, fPostProcSlope, + fPostProcMidpoint, fPostProcMin, fPostProcMax, anStatPosGrand, + anStatNegGrand, psPostProcLUT, fPostProcScalar); + } catch (std::exception &e) { NTA_WARN << "gaborNode -- returning error: " << e.what(); return -1; } @@ -2917,4 +2808,4 @@ int gaborCompute(const NUMPY_ARRAY * psGaborBank, #ifdef __cplusplus } -#endif +#endif diff --git a/src/nupic/algorithms/GaborNode.hpp b/src/nupic/algorithms/GaborNode.hpp index 50b643905e..c25722f689 100644 --- a/src/nupic/algorithms/GaborNode.hpp +++ b/src/nupic/algorithms/GaborNode.hpp @@ -20,10 +20,10 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * This header file defines the API for performing efficient * Gabor processing. - */ + */ #ifndef NTA_GABOR_NODE_HPP #define NTA_GABOR_NODE_HPP @@ -32,22 +32,22 @@ #ifdef __cplusplus extern "C" { -#endif // __cplusplus +#endif // __cplusplus -#include #include "ArrayBuffer.hpp" +#include // Number of bits that our gabor filter coefficients are // shifted (to the left) for scaling purposes, when // using GABOR_METHOD_INTEGER8 -# define GABOR_SCALING_SHIFT 12 +#define GABOR_SCALING_SHIFT 12 // For reasons of efficiency and simplicity, we'll store our // responses statistics (used for automated normalization) in // a static buffer of fixed size. Because of this, we need to // impose a constraint on the maximum number of filters; // 64 should be enough for anyone... 
:-) -#define MAXNUM_FILTERS 64 +#define MAXNUM_FILTERS 64 // Enumeration that specifies how we handle boundary // effects @@ -99,36 +99,24 @@ typedef enum _POSTPROC_METHOD { POSTPROC_METHOD__LAST } POSTPROC_METHOD; - // FUNCTION: gaborCompute() // PURPOSE: Implements efficient Gabor filtering. NTA_EXPORT -int gaborCompute(const NUMPY_ARRAY * psGaborBank, - const NUMPY_ARRAY * psInput, - const NUMPY_ARRAY * psAlpha, - const NUMPY_ARRAY * psBBox, - const NUMPY_ARRAY * psImageBox, - const NUMPY_ARRAY * psOutput, - float fGainConstant, - EDGE_MODE eEdgeMode, - float fOffImageFillValue, - PHASE_MODE ePhaseMode, - NORMALIZE_METHOD eNormalizeMethod, - NORMALIZE_MODE eNormalizeMode, - PHASENORM_MODE ePhaseNormMode, - POSTPROC_METHOD ePostProcMethod, - float fPostProcSlope, - float fPostProcMidpoint, - float fPostProcMin, - float fPostProcMax, - const NUMPY_ARRAY * psBufferIn, - const NUMPY_ARRAY * psBufferOut, - const NUMPY_ARRAY * psPostProcLUT, - float fPostProcScalar); - +int gaborCompute(const NUMPY_ARRAY *psGaborBank, const NUMPY_ARRAY *psInput, + const NUMPY_ARRAY *psAlpha, const NUMPY_ARRAY *psBBox, + const NUMPY_ARRAY *psImageBox, const NUMPY_ARRAY *psOutput, + float fGainConstant, EDGE_MODE eEdgeMode, + float fOffImageFillValue, PHASE_MODE ePhaseMode, + NORMALIZE_METHOD eNormalizeMethod, + NORMALIZE_MODE eNormalizeMode, PHASENORM_MODE ePhaseNormMode, + POSTPROC_METHOD ePostProcMethod, float fPostProcSlope, + float fPostProcMidpoint, float fPostProcMin, + float fPostProcMax, const NUMPY_ARRAY *psBufferIn, + const NUMPY_ARRAY *psBufferOut, + const NUMPY_ARRAY *psPostProcLUT, float fPostProcScalar); #ifdef __cplusplus } -#endif // __cplusplus +#endif // __cplusplus #endif // NTA_GABOR_NODE_HPP diff --git a/src/nupic/algorithms/ImageSensorLite.cpp b/src/nupic/algorithms/ImageSensorLite.cpp index 8688ce7ed7..8ea7b9220f 100644 --- a/src/nupic/algorithms/ImageSensorLite.cpp +++ b/src/nupic/algorithms/ImageSensorLite.cpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * This module implements efficient video-related image extraction. * * The C NUMPY_ARRAY structure mirrors an ARRAY class in @@ -28,11 +28,10 @@ * * This exported C function is expected to be used in conjunction * with ctypes wrappers around numpy array objects. - */ + */ -#include #include - +#include // Enable debugging //#define DEBUG 1 @@ -46,12 +45,11 @@ // Visual C++ (Windows) does not come with roundf() in // the standard library #if defined(NTA_OS_WINDOWS) -#define ROUND(x) ((x-floor(x))>0.5 ? ceil(x) : floor(x)) +#define ROUND(x) ((x - floor(x)) > 0.5 ? ceil(x) : floor(x)) #else -#define ROUND(x) (roundf(x)) +#define ROUND(x) (roundf(x)) #endif - // if INIT_FROM_PYTHON is defined, this module can initialize // logging from a python system reference. 
This introduces // a dependency on PythonSystem, which is not included in the @@ -60,25 +58,25 @@ #error "Unexpected Python dependency for imageSensorLite in NuPIC 2" #endif - #ifdef __cplusplus extern "C" { -#endif +#endif -#define GET_CTLBUF_ELEM(ctlBufAddr, k) (((int*)(ctlBufAddr))[k]) +#define GET_CTLBUF_ELEM(ctlBufAddr, k) (((int *)(ctlBufAddr))[k]) -#define BOX_LEFT(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 0) -#define BOX_TOP(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 1) -#define BOX_RIGHT(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 2) -#define BOX_BOTTOM(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 3) -#define DATA_ADDRESS(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 4) -#define DATA_ALPHA_ADDRESS(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 8) -#define PARTITION_ID(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 5) -#define CATEGORY_ID(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 6) -#define VIDEO_ID(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 7) +#define BOX_LEFT(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 0) +#define BOX_TOP(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 1) +#define BOX_RIGHT(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 2) +#define BOX_BOTTOM(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 3) +#define DATA_ADDRESS(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 4) +#define DATA_ALPHA_ADDRESS(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 8) +#define PARTITION_ID(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 5) +#define CATEGORY_ID(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 6) +#define VIDEO_ID(ctlBufAddr) GET_CTLBUF_ELEM(ctlBufAddr, 7) /* -#define DST_BUF_LEN(psDstBuffer) (((long int *)(psDstBuffer->pnDimensions))[0]) +#define DST_BUF_LEN(psDstBuffer) (((long int +*)(psDstBuffer->pnDimensions))[0]) #define BBOX_ELEM(bbox, k) (((int*)(bbox->pData))[k]) #define BBOX_LEFT(bbox) BBOX_ELEM(bbox, 0) @@ -101,20 +99,19 @@ extern "C" { #define IMAGE_ROWSTRIDE(array) IMAGE_STRIDE(array, 0) */ - /* // Hanning window of window length 9 // @todo: don't both convolving against the first and // last window elements because they're known // to be zero. 
static const float afHanning[] = { - 0.03661165, - 0.12500000, - 0.21338835, + 0.03661165, + 0.12500000, + 0.21338835, 0.25000000, - 0.21338835, - 0.12500000, - 0.03661165, + 0.21338835, + 0.12500000, + 0.03661165, }; @@ -125,7 +122,7 @@ static const float afHanning[] = { // FUNCTION: _smooth() // PURPOSE: Smooth a 1D histogram -float _smoothHist1D(const float * pfHistogram, +float _smoothHist1D(const float * pfHistogram, float * pfReflHist, float * pfSmoothHist, int nHistWidth) { @@ -136,13 +133,13 @@ float _smoothHist1D(const float * pfHistogram, float fRef = 2.0 * (*pfHistogram); // Construct reflected (extended) histogrm - float * pfReflHistPtr = pfReflHist + HANNING_HALF_LEN; - const float * pfHistPtr = pfHistogram; + float * pfReflHistPtr = pfReflHist + HANNING_HALF_LEN; + const float * pfHistPtr = pfHistogram; // Reflect leading elements for (k=HANNING_HALF_LEN; k; k--) *--pfReflHistPtr = fRef - *++pfHistPtr; // Copy internal elemenents - pfHistPtr = pfHistogram; + pfHistPtr = pfHistogram; pfReflHistPtr += HANNING_HALF_LEN; for (k=nHistWidth; k; k--) *pfReflHistPtr++ = *pfHistPtr++; @@ -156,7 +153,7 @@ float _smoothHist1D(const float * pfHistogram, float * pfSmoothHistPtr = pfSmoothHist; for (k=nHistWidth; k; k--) { pfReflHistPtr = pfReflHist; - pfWindowPtr = afHanning; + pfWindowPtr = afHanning; fAccum = 0.0f; for (i=HANNING_LEN; i; i--) fAccum += (*pfWindowPtr++) * (*pfReflHistPtr++); @@ -191,13 +188,9 @@ void _formHistogramX(// Inputs: // shows, for each column, the number of pixels that // had non-zero SMotion. const float * pfSrc = (const float *)(psSrcImage->pData); - const float * pfSrcPtr = pfSrc + IMAGE_COLS(psSrcImage) * psBox->nTop + psBox->nLeft; - for (j=psBox->nBottom - psBox->nTop; j; j-- ) { - pfHistPtr = pfHist; - for (i=nBoxWidth; i; i-- ) - if (*pfSrcPtr++) - *pfHistPtr++ += 1.0f; - else + const float * pfSrcPtr = pfSrc + IMAGE_COLS(psSrcImage) * psBox->nTop + +psBox->nLeft; for (j=psBox->nBottom - psBox->nTop; j; j-- ) { pfHistPtr = +pfHist; for (i=nBoxWidth; i; i-- ) if (*pfSrcPtr++) *pfHistPtr++ += 1.0f; else pfHistPtr++; pfSrcPtr += nRowAdvance; } @@ -222,8 +215,8 @@ void _formHistogramY(// Inputs: // shows, for each column, the number of pixels that // had non-zero SMotion. 
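The commented-out helpers in this block build a one-dimensional histogram of non-zero motion pixels and smooth it with the short Hanning window declared above, extending the histogram past its borders before convolving. A compact sketch of that idea, using plain mirroring at the edges (the commented-out _smoothHist1D reflects the histogram about its end values instead, and works with preallocated scratch buffers); the function name below is illustrative only:

#include <cstdio>
#include <vector>

// Smooth a 1-D histogram with a small symmetric window; values past either
// border are taken from the mirrored histogram.
static std::vector<float> smoothHist1D(const std::vector<float> &hist,
                                       const std::vector<float> &window) {
  const int half = static_cast<int>(window.size()) / 2;
  const int n = static_cast<int>(hist.size());
  std::vector<float> out(n, 0.0f);
  for (int k = 0; k < n; ++k) {
    float accum = 0.0f;
    for (int i = -half; i <= half; ++i) {
      int idx = k + i;
      if (idx < 0)
        idx = -idx; // mirror the left edge
      if (idx >= n)
        idx = 2 * (n - 1) - idx; // mirror the right edge
      accum += window[i + half] * hist[idx];
    }
    out[k] = accum;
  }
  return out;
}

int main() {
  // The seven non-zero taps of the length-9 Hanning window shown above.
  const std::vector<float> hanning = {0.03661165f, 0.125f, 0.21338835f, 0.25f,
                                      0.21338835f, 0.125f, 0.03661165f};
  const std::vector<float> hist = {0.f, 0.f, 4.f, 9.f, 9.f,
                                   3.f, 0.f, 0.f, 1.f, 0.f};
  for (float v : smoothHist1D(hist, hanning))
    std::printf("%.3f ", v);
  std::printf("\n");
  return 0;
}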
const float * pfSrc = (const float *)(psSrcImage->pData); - const float * pfSrcPtr = pfSrc + IMAGE_COLS(psSrcImage) * psBox->nTop + psBox->nLeft; - NTA_ASSERT(psBox->nBottom - psBox->nTop >= 0); + const float * pfSrcPtr = pfSrc + IMAGE_COLS(psSrcImage) * psBox->nTop + +psBox->nLeft; NTA_ASSERT(psBox->nBottom - psBox->nTop >= 0); NTA_ASSERT(nBoxWidth >= 0); for (j=psBox->nBottom - psBox->nTop; j; j-- ) { fAccum = 0.0f; @@ -292,8 +285,9 @@ int adjustBox( // Inputs: BBOX boxExpanded; boxExpanded.nLeft = MAX(0, psBox->nLeft - psParams->nZonePreExpansionX); boxExpanded.nTop = MAX(0, psBox->nTop - psParams->nZonePreExpansionY); - boxExpanded.nRight = MIN(nImageWidth, psBox->nRight + psParams->nZonePreExpansionX); - boxExpanded.nBottom = MIN(nImageHeight, psBox->nBottom + psParams->nZonePreExpansionY); + boxExpanded.nRight = MIN(nImageWidth, psBox->nRight + +psParams->nZonePreExpansionX); boxExpanded.nBottom = MIN(nImageHeight, +psBox->nBottom + psParams->nZonePreExpansionY); // Sanity checks NTA_ASSERT(boxExpanded.nBottom >= boxExpanded.nTop); @@ -304,7 +298,7 @@ int adjustBox( // Inputs: float fExpandedWidth = (float)nExpandedWidth; float fExpandedHeight = (float)nExpandedHeight; - // Generate a horizontal histogram + // Generate a horizontal histogram _formHistogramX(psSrcImage, &boxExpanded, afHistogramX); // Smooth the horizontal histogram @@ -320,12 +314,16 @@ int adjustBox( // Inputs: // zone that we'll accept. // This is the max of an absolute length and a // minimum fraction of the original box. - //nMinZoneLen = MAX(psParams->nMinAbsZoneLenX, (int)roundf(psParams->fMinRelZoneLenX * fExpandedWidth)); - nMinZoneLen = MAX(psParams->nMinAbsZoneLenX, (int)ROUND(psParams->fMinRelZoneLenX * fExpandedWidth)); + //nMinZoneLen = MAX(psParams->nMinAbsZoneLenX, +(int)roundf(psParams->fMinRelZoneLenX * fExpandedWidth)); nMinZoneLen = +MAX(psParams->nMinAbsZoneLenX, (int)ROUND(psParams->fMinRelZoneLenX * +fExpandedWidth)); // Minimum length for a weak gap - //nMinWeakLen = MAX(psParams->nMinAbsWeakLenX, (int)roundf(psParams->fMinRelWeakLenX * fExpandedWidth)); - nMinWeakLen = MAX(psParams->nMinAbsWeakLenX, (int)ROUND(psParams->fMinRelWeakLenX * fExpandedWidth)); + //nMinWeakLen = MAX(psParams->nMinAbsWeakLenX, +(int)roundf(psParams->fMinRelWeakLenX * fExpandedWidth)); nMinWeakLen = +MAX(psParams->nMinAbsWeakLenX, (int)ROUND(psParams->fMinRelWeakLenX * +fExpandedWidth)); // For now, simple threshold fThreshX = psParams->fHeightThresh * fMaxX; @@ -340,7 +338,7 @@ int adjustBox( // Inputs: *pnStrongPtr++ = 0; } - // Pre-calculate the minimum peak strength for + // Pre-calculate the minimum peak strength for // each lobe to avoid being culled float fMinStrength = fMaxX * psParams->fSecondaryHeightThresh; @@ -357,7 +355,8 @@ int adjustBox( // Inputs: if (nAntiDelta < 0) { NTA_ASSERT(nCandidateBegin == -1); // Check if gap was too small - if (nNumStrongZonesX && ((k - anStrongEndX[nNumStrongZonesX-1]) <= nMinWeakLen)) { + if (nNumStrongZonesX && ((k - anStrongEndX[nNumStrongZonesX-1]) <= +nMinWeakLen)) { // Re-start the previous strong zone nCandidateBegin = anStrongBeginX[--nNumStrongZonesX]; } else { @@ -389,13 +388,13 @@ int adjustBox( // Inputs: // Last one if (nCandidateBegin >= 0) { - if (nNumToCheck - nCandidateBegin >= nMinZoneLen && fPeakStrength >= fMinStrength) { - anStrongBeginX[nNumStrongZonesX] = nCandidateBegin; + if (nNumToCheck - nCandidateBegin >= nMinZoneLen && fPeakStrength >= +fMinStrength) { anStrongBeginX[nNumStrongZonesX] = nCandidateBegin; anStrongEndX[nNumStrongZonesX++] = 
nNumToCheck; } } - + //----------------------------------------------------------------------- // Apply tightening/splitting in vertical direction (to each strong zone) for (k=0; knMinAbsZoneLenY, (int)roundf(psParams->fMinRelZoneLenY * fExpandedHeight)); - nMinZoneLen = MAX(psParams->nMinAbsZoneLenY, (int)ROUND(psParams->fMinRelZoneLenY * fExpandedHeight)); + //nMinZoneLen = MAX(psParams->nMinAbsZoneLenY, +(int)roundf(psParams->fMinRelZoneLenY * fExpandedHeight)); nMinZoneLen = +MAX(psParams->nMinAbsZoneLenY, (int)ROUND(psParams->fMinRelZoneLenY * +fExpandedHeight)); // Minimum length for a weak gap - //nMinWeakLen = MAX(psParams->nMinAbsWeakLenY, (int)roundf(psParams->fMinRelWeakLenY * fExpandedHeight)); - nMinWeakLen = MAX(psParams->nMinAbsWeakLenY, (int)ROUND(psParams->fMinRelWeakLenY * fExpandedHeight)); + //nMinWeakLen = MAX(psParams->nMinAbsWeakLenY, +(int)roundf(psParams->fMinRelWeakLenY * fExpandedHeight)); nMinWeakLen = +MAX(psParams->nMinAbsWeakLenY, (int)ROUND(psParams->fMinRelWeakLenY * +fExpandedHeight)); // For now, simple threshold fThreshY = psParams->fWidthThresh * fMaxY; @@ -439,7 +442,7 @@ int adjustBox( // Inputs: *pnStrongPtr++ = 0; } - // Pre-calculate the minimum peak strength for + // Pre-calculate the minimum peak strength for // each lobe to avoid being culled fMinStrength = fMaxY * psParams->fSecondaryWidthThresh; @@ -455,7 +458,8 @@ int adjustBox( // Inputs: if (nAntiDelta < 0) { NTA_ASSERT(nCandidateBegin == -1); // Check if gap was too small - if (nNumStrongZonesY && ((j - anStrongEndY[nNumStrongZonesY-1]) <= nMinWeakLen)) { + if (nNumStrongZonesY && ((j - anStrongEndY[nNumStrongZonesY-1]) <= +nMinWeakLen)) { // Re-start the previous strong zone nCandidateBegin = anStrongBeginY[--nNumStrongZonesY]; } else { @@ -467,8 +471,8 @@ int adjustBox( // Inputs: else if (nAntiDelta > 0) { NTA_ASSERT(nCandidateBegin >= 0); // Accept or cull the zone - if (j-nCandidateBegin >= nMinZoneLen && fPeakStrength >= fMinStrength) { - anStrongBeginY[nNumStrongZonesY] = nCandidateBegin; + if (j-nCandidateBegin >= nMinZoneLen && fPeakStrength >= fMinStrength) +{ anStrongBeginY[nNumStrongZonesY] = nCandidateBegin; anStrongEndY[nNumStrongZonesY++] = j; // Make sure we don't exceed our hard-coded limits // with a crazily fragmented pathological smotion image @@ -487,8 +491,8 @@ int adjustBox( // Inputs: // Last one if (nCandidateBegin >= 0) { - if (nNumToCheck - nCandidateBegin >= nMinZoneLen && fPeakStrength >= fMinStrength) { - anStrongBeginY[nNumStrongZonesY] = nCandidateBegin; + if (nNumToCheck - nCandidateBegin >= nMinZoneLen && fPeakStrength >= +fMinStrength) { anStrongBeginY[nNumStrongZonesY] = nCandidateBegin; anStrongEndY[nNumStrongZonesY++] = nNumToCheck; } } @@ -499,10 +503,12 @@ int adjustBox( // Inputs: // pathological smotion image) for (j=0; jnLeft = MAX(0, boxStrong.nLeft - psParams->nZonePostExpansionX); - pboxNew->nTop = MAX(0, boxStrong.nTop + anStrongBeginY[j] - psParams->nZonePostExpansionY); - pboxNew->nRight = MIN(nImageWidth, boxStrong.nRight + psParams->nZonePostExpansionX); - pboxNew->nBottom = MIN(nImageHeight, boxStrong.nTop + anStrongEndY[j] + psParams->nZonePostExpansionY); + pboxNew->nLeft = MAX(0, boxStrong.nLeft - +psParams->nZonePostExpansionX); pboxNew->nTop = MAX(0, boxStrong.nTop + +anStrongBeginY[j] - psParams->nZonePostExpansionY); pboxNew->nRight = +MIN(nImageWidth, boxStrong.nRight + psParams->nZonePostExpansionX); + pboxNew->nBottom = MIN(nImageHeight, boxStrong.nTop + anStrongEndY[j] + +psParams->nZonePostExpansionY); } } @@ -518,9 +524,8 @@ 
int adjustBox( // Inputs: int nBiggestArea = 0; for (k=0; knRight - pboxFinal->nLeft) * (pboxFinal->nBottom - pboxFinal->nTop); - if (nBoxArea > nBiggestArea) { - nBiggestArea = nBoxArea; + nBoxArea = (pboxFinal->nRight - pboxFinal->nLeft) * (pboxFinal->nBottom +- pboxFinal->nTop); if (nBoxArea > nBiggestArea) { nBiggestArea = nBoxArea; nBiggestIndex = k; } } @@ -568,27 +573,20 @@ int adjustBox( // Inputs: } */ - // FUNCTION: extractAuxInfo() // PURPOSE: Extract auxiliary information NTA_EXPORT -int extractAuxInfo(// Inputs: - //const NUMPY_ARRAY * psCtlBuf, - const char * pCtlBufAddr, - //const NUMPY_ARRAY * psBBox, - //const NUMPY_ARRAY * psCategoryBuf, - //const NUMPY_ARRAY * psPartitionBuf, - //const NUMPY_ARRAY * psAddressBuf, - BBOX * psBox, - int * pnAddress, - int * pnPartitionID, - int * pnCategoryID, - int * pnVideoID, - int * pnAlphaAddress - ) { +int extractAuxInfo( // Inputs: + // const NUMPY_ARRAY * psCtlBuf, + const char *pCtlBufAddr, + // const NUMPY_ARRAY * psBBox, + // const NUMPY_ARRAY * psCategoryBuf, + // const NUMPY_ARRAY * psPartitionBuf, + // const NUMPY_ARRAY * psAddressBuf, + BBOX *psBox, int *pnAddress, int *pnPartitionID, int *pnCategoryID, + int *pnVideoID, int *pnAlphaAddress) { - try - { + try { /* // Extract partition and category IDs if (psPartitionBuf) @@ -602,9 +600,9 @@ int extractAuxInfo(// Inputs: */ // Extract BBOX - psBox->nLeft = BOX_LEFT(pCtlBufAddr); - psBox->nTop = BOX_TOP(pCtlBufAddr); - psBox->nRight = BOX_RIGHT(pCtlBufAddr); + psBox->nLeft = BOX_LEFT(pCtlBufAddr); + psBox->nTop = BOX_TOP(pCtlBufAddr); + psBox->nRight = BOX_RIGHT(pCtlBufAddr); psBox->nBottom = BOX_BOTTOM(pCtlBufAddr); // Extract partition and category IDs @@ -638,15 +636,13 @@ int extractAuxInfo(// Inputs: */ } - catch(std::exception& e) - { + catch (std::exception &e) { NTA_WARN << "gaborNode -- returning error: " << e.what(); return -1; } return 0; } - /* // FUNCTION: accessPixels() // PURPOSE: Access pixels of a numpy array @@ -672,8 +668,7 @@ int accessPixels(// Inputs: } */ - -/* +/* // OBSOLETE FUNCTION: formHistogramX() // PURPOSE: Form a histogram of non-zero SMotion NTA_EXPORT @@ -681,7 +676,7 @@ int formHistogramX(// Inputs: const NUMPY_ARRAY * psSrcImage, const BBOX * psBox, // Outputs: - const NUMPY_ARRAY * psHistogram + const NUMPY_ARRAY * psHistogram ) { try { float * pfHist = (float *)(psHistogram->pData); @@ -703,7 +698,7 @@ int formHistogramY(// Inputs: const NUMPY_ARRAY * psSrcImage, const BBOX * psBox, // Outputs: - const NUMPY_ARRAY * psHistogram + const NUMPY_ARRAY * psHistogram ) { try { float * pfHist = (float *)(psHistogram->pData); @@ -718,7 +713,6 @@ int formHistogramY(// Inputs: } */ - #ifdef __cplusplus } -#endif +#endif diff --git a/src/nupic/algorithms/ImageSensorLite.hpp b/src/nupic/algorithms/ImageSensorLite.hpp index 1e6fa31a14..4d4ac01c1f 100644 --- a/src/nupic/algorithms/ImageSensorLite.hpp +++ b/src/nupic/algorithms/ImageSensorLite.hpp @@ -20,10 +20,10 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * This header file defines the API for performing efficient * VideoSensorNode processing. 
- */ + */ #ifndef NTA_VIDEO_SENSOR_NODE_HPP #define NTA_VIDEO_SENSOR_NODE_HPP @@ -32,36 +32,34 @@ #ifdef __cplusplus extern "C" { -#endif // __cplusplus +#endif // __cplusplus -#include #include "ArrayBuffer.hpp" - +#include // ImageSensorLite control buffer typedef struct _ISL_CTLBUF { // Bounding box - int nBoxLeft; - int nBoxTop; - int nBoxRight; - int nBoxBottom; + int nBoxLeft; + int nBoxTop; + int nBoxRight; + int nBoxBottom; // Address of buffer holding actual pixel data - void * pDataAddr; + void *pDataAddr; // Optional: partition ID - int nPartitionID; + int nPartitionID; // Optional: category ID - int nCategoryID; + int nCategoryID; // Optional: video ID - int nVideoID; + int nVideoID; // Optional: address of buffer holding alpha data - void * pAlphaAddr; + void *pAlphaAddr; } ISL_CTLBUF; - /* // Compile-time assertion to make sure that // the ISL_CTLBUF is 32-bytes long on this -// platform; if not, this code will +// platform; if not, this code will // generate a 'duplicate case value' error // at compile time. #define COMPILE_TIME_ASSERT(pred) \ @@ -73,16 +71,15 @@ void compile_time_assertions(void) { } */ - /* -// Structure that wraps the essential elements of +// Structure that wraps the essential elements of // a numpy array object. typedef struct _NUMPY_ARRAY { int nNumDims; const int * pnDimensions; const int * pnStrides; const char * pData; -} NUMPY_ARRAY; +} NUMPY_ARRAY; // Bounding box typedef struct _BBOX { @@ -94,29 +91,29 @@ typedef struct _BBOX { */ typedef struct _BOXFIXER_PARAMS { - int nZonePreExpansionX; - int nZonePreExpansionY; - int nZonePostExpansionX; - int nZonePostExpansionY; - int nWindowLenX; - int nWindowLenY; - int nMinAbsZoneLenX; - int nMinAbsZoneLenY; - float fMinRelZoneLenX; - float fMinRelZoneLenY; - int nMinAbsWeakLenX; - int nMinAbsWeakLenY; - float fMinRelWeakLenX; - float fMinRelWeakLenY; - float fHeightThresh; - float fWidthThresh; - float fSecondaryHeightThresh; - float fSecondaryWidthThresh; - int nTakeBiggest; + int nZonePreExpansionX; + int nZonePreExpansionY; + int nZonePostExpansionX; + int nZonePostExpansionY; + int nWindowLenX; + int nWindowLenY; + int nMinAbsZoneLenX; + int nMinAbsZoneLenY; + float fMinRelZoneLenX; + float fMinRelZoneLenY; + int nMinAbsWeakLenX; + int nMinAbsWeakLenY; + float fMinRelWeakLenX; + float fMinRelWeakLenY; + float fHeightThresh; + float fWidthThresh; + float fSecondaryHeightThresh; + float fSecondaryWidthThresh; + int nTakeBiggest; } BOXFIXER_PARAMS; -#define MAX_BBOX_WIDTH 640 -#define MAX_BBOX_HEIGHT 480 +#define MAX_BBOX_WIDTH 640 +#define MAX_BBOX_HEIGHT 480 /* // Structure that wraps specification of the @@ -125,48 +122,40 @@ typedef struct _BOXFIXER_PARAMS { typedef struct _TARGETSIZE { int nWidth; int nHeight; -} TARGETSIZE; +} TARGETSIZE; typedef struct _TARGETSIZES { int nNumScales; TARGETSIZE anScales[MAXNUM_SCALES]; -} TARGETSIZES; +} TARGETSIZES; */ // FUNCTION: adjustBox() // PURPOSE: Implements efficient adjustment of tracking box NTA_EXPORT int adjustBox( // Inputs: - const NUMPY_ARRAY * psSrcImage, - const BBOX * psBox, - // Parameters: - const BOXFIXER_PARAMS * psParams, - // Outputs: - BBOX * psFixedBox, - int * pnTotNumBoxes); - + const NUMPY_ARRAY *psSrcImage, const BBOX *psBox, + // Parameters: + const BOXFIXER_PARAMS *psParams, + // Outputs: + BBOX *psFixedBox, int *pnTotNumBoxes); // FUNCTION: accessPixels() // PURPOSE: Access pixels of a numpy array NTA_EXPORT -int accessPixels(// Inputs: - const NUMPY_ARRAY * psSrcImage, - // Outputs: - const NUMPY_ARRAY * psDstImage); - +int 
accessPixels( // Inputs: + const NUMPY_ARRAY *psSrcImage, + // Outputs: + const NUMPY_ARRAY *psDstImage); // FUNCTION: extractAuxInfo() // PURPOSE: Extract auxiliary information NTA_EXPORT -int extractAuxInfo(// Inputs: - const char * pCtlBufAddr, - //const NUMPY_ARRAY * psCtlBuf, - // Outputs: - BBOX * psBox, - int * pnAddress, - int * pnPartitionID, - int * pnCategoryID, - int * pnVideoID, - int * pnAlphaAddress ); +int extractAuxInfo( // Inputs: + const char *pCtlBufAddr, + // const NUMPY_ARRAY * psCtlBuf, + // Outputs: + BBOX *psBox, int *pnAddress, int *pnPartitionID, int *pnCategoryID, + int *pnVideoID, int *pnAlphaAddress); /* int extractAuxInfo(// Inputs: const NUMPY_ARRAY * psBBox, @@ -200,6 +189,6 @@ int formHistogramY(// Inputs: #ifdef __cplusplus } -#endif // __cplusplus +#endif // __cplusplus #endif // NTA_VIDEO_SENSOR_NODE_HPP diff --git a/src/nupic/algorithms/InSynapse.cpp b/src/nupic/algorithms/InSynapse.cpp index 90a41bef0e..f6ece31d56 100644 --- a/src/nupic/algorithms/InSynapse.cpp +++ b/src/nupic/algorithms/InSynapse.cpp @@ -25,22 +25,20 @@ using namespace nupic::algorithms::Cells4; -inline void InSynapse::print(std::ostream& outStream) const -{ +inline void InSynapse::print(std::ostream &outStream) const { outStream << _srcCellIdx << ',' << std::setprecision(4) << _permanence; } //-------------------------------------------------------------------------------- namespace nupic { - namespace algorithms { - namespace Cells4 { +namespace algorithms { +namespace Cells4 { - std::ostream& operator<<(std::ostream& outStream, const InSynapse& s) - { - s.print(outStream); - return outStream; - } - } - } +std::ostream &operator<<(std::ostream &outStream, const InSynapse &s) { + s.print(outStream); + return outStream; } +} // namespace Cells4 +} // namespace algorithms +} // namespace nupic diff --git a/src/nupic/algorithms/InSynapse.hpp b/src/nupic/algorithms/InSynapse.hpp index cd1c197e6b..dfd18f5e5f 100644 --- a/src/nupic/algorithms/InSynapse.hpp +++ b/src/nupic/algorithms/InSynapse.hpp @@ -25,69 +25,59 @@ #include -#include #include +#include using namespace nupic; //-------------------------------------------------------------------------------- namespace nupic { - namespace algorithms { - namespace Cells4 { - - - //-------------------------------------------------------------------------------- - //-------------------------------------------------------------------------------- - /** - * The type of synapse contained in a Segment. It has the source cell index - * of the synapse, and a permanence value. The source cell index is between - * 0 and nCols * nCellsPerCol. 
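The class comment just above says the source cell index lives in the flat range 0 .. nCols * nCellsPerCol. One common convention consistent with that range, assumed here purely for illustration, is colIdx * nCellsPerCol + cellInCol, which makes converting between a flat cell index and its column a one-liner; both helpers below are hypothetical:

#include <cassert>

// Hypothetical helpers for the flat-index convention described above.
inline unsigned flatCellIdx(unsigned colIdx, unsigned cellInCol,
                            unsigned nCellsPerCol) {
  return colIdx * nCellsPerCol + cellInCol;
}
inline unsigned colOfCell(unsigned cellIdx, unsigned nCellsPerCol) {
  return cellIdx / nCellsPerCol;
}

int main() {
  const unsigned nCols = 2048, nCellsPerCol = 10;
  const unsigned idx = flatCellIdx(7, 3, nCellsPerCol); // cell 3 of column 7
  assert(idx < nCols * nCellsPerCol);
  assert(colOfCell(idx, nCellsPerCol) == 7);
  return 0;
}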
- */ - class InSynapse - { - private: - UInt _srcCellIdx; - Real _permanence; - - public: - inline InSynapse() - : _srcCellIdx((UInt) -1), - _permanence(0) - {} - - inline InSynapse(UInt srcCellIdx, Real permanence) - : _srcCellIdx(srcCellIdx), - _permanence(permanence) - {} - - inline InSynapse(const InSynapse& o) - : _srcCellIdx(o._srcCellIdx), - _permanence(o._permanence) - {} - - inline InSynapse& operator=(const InSynapse& o) - { - _srcCellIdx = o._srcCellIdx; - _permanence = o._permanence; - return *this; - } - - inline UInt srcCellIdx() const { return _srcCellIdx; } const - inline Real& permanence() const { return _permanence; } - inline Real& permanence() { return _permanence; } - - inline void print(std::ostream& outStream) const; - }; - - //-------------------------------------------------------------------------------- +namespace algorithms { +namespace Cells4 { + +//-------------------------------------------------------------------------------- +//-------------------------------------------------------------------------------- +/** + * The type of synapse contained in a Segment. It has the source cell index + * of the synapse, and a permanence value. The source cell index is between + * 0 and nCols * nCellsPerCol. + */ +class InSynapse { +private: + UInt _srcCellIdx; + Real _permanence; + +public: + inline InSynapse() : _srcCellIdx((UInt)-1), _permanence(0) {} + + inline InSynapse(UInt srcCellIdx, Real permanence) + : _srcCellIdx(srcCellIdx), _permanence(permanence) {} + + inline InSynapse(const InSynapse &o) + : _srcCellIdx(o._srcCellIdx), _permanence(o._permanence) {} + + inline InSynapse &operator=(const InSynapse &o) { + _srcCellIdx = o._srcCellIdx; + _permanence = o._permanence; + return *this; + } + + inline UInt srcCellIdx() const { return _srcCellIdx; } + const inline Real &permanence() const { return _permanence; } + inline Real &permanence() { return _permanence; } + + inline void print(std::ostream &outStream) const; +}; + +//-------------------------------------------------------------------------------- #ifndef SWIG - std::ostream& operator<<(std::ostream& outStream, const InSynapse& s); +std::ostream &operator<<(std::ostream &outStream, const InSynapse &s); #endif - // end namespace - } - } -} +// end namespace +} // namespace Cells4 +} // namespace algorithms +} // namespace nupic #endif // NTA_INSYNAPSE_HPP diff --git a/src/nupic/algorithms/OutSynapse.cpp b/src/nupic/algorithms/OutSynapse.cpp index a19d1d7029..ec3f02542a 100644 --- a/src/nupic/algorithms/OutSynapse.cpp +++ b/src/nupic/algorithms/OutSynapse.cpp @@ -20,14 +20,12 @@ * --------------------------------------------------------------------- */ - -#include #include +#include using namespace nupic::algorithms::Cells4; using namespace nupic; -bool OutSynapse::invariants(Cells4* cells) const -{ +bool OutSynapse::invariants(Cells4 *cells) const { bool ok = true; if (cells) { ok &= _dstCellIdx < cells->nCells(); @@ -37,12 +35,11 @@ bool OutSynapse::invariants(Cells4* cells) const } namespace nupic { - namespace algorithms { - namespace Cells4 { - bool operator==(const OutSynapse& a, const OutSynapse& b) - { - return a.equals(b); - } - } - } +namespace algorithms { +namespace Cells4 { +bool operator==(const OutSynapse &a, const OutSynapse &b) { + return a.equals(b); } +} // namespace Cells4 +} // namespace algorithms +} // namespace nupic diff --git a/src/nupic/algorithms/OutSynapse.hpp b/src/nupic/algorithms/OutSynapse.hpp index a535eb8dc1..6d2aeb3302 100644 --- a/src/nupic/algorithms/OutSynapse.hpp +++ 
b/src/nupic/algorithms/OutSynapse.hpp @@ -29,85 +29,72 @@ using namespace nupic; namespace nupic { - namespace algorithms { - namespace Cells4 { - - - class Cells4; - //-------------------------------------------------------------------------------- - //-------------------------------------------------------------------------------- - /** - * The type of synapse we use to propagate activation forward. It contains - * indices for the *destination* cell, and the *destination* segment on that cell. - * The cell index is between 0 and nCols * nCellsPerCol. - */ - class OutSynapse - { - public: - - private: - UInt _dstCellIdx; - UInt _dstSegIdx; // index in _segActivity - - public: - OutSynapse(UInt dstCellIdx =(UInt) -1, - UInt dstSegIdx =(UInt) -1 - //Cells4* cells =NULL - ) - : _dstCellIdx(dstCellIdx), - _dstSegIdx(dstSegIdx) - { - // TODO: FIX this - //NTA_ASSERT(invariants(cells)); - } - - OutSynapse(const OutSynapse& o) - : _dstCellIdx(o._dstCellIdx), - _dstSegIdx(o._dstSegIdx) - {} +namespace algorithms { +namespace Cells4 { + +class Cells4; +//-------------------------------------------------------------------------------- +//-------------------------------------------------------------------------------- +/** + * The type of synapse we use to propagate activation forward. It contains + * indices for the *destination* cell, and the *destination* segment on that + * cell. The cell index is between 0 and nCols * nCellsPerCol. + */ +class OutSynapse { +public: +private: + UInt _dstCellIdx; + UInt _dstSegIdx; // index in _segActivity + +public: + OutSynapse(UInt dstCellIdx = (UInt)-1, UInt dstSegIdx = (UInt)-1 + // Cells4* cells =NULL + ) + : _dstCellIdx(dstCellIdx), _dstSegIdx(dstSegIdx) { + // TODO: FIX this + // NTA_ASSERT(invariants(cells)); + } - OutSynapse& operator=(const OutSynapse& o) - { - _dstCellIdx = o._dstCellIdx; - _dstSegIdx = o._dstSegIdx; - return *this; - } + OutSynapse(const OutSynapse &o) + : _dstCellIdx(o._dstCellIdx), _dstSegIdx(o._dstSegIdx) {} - UInt dstCellIdx() const { return _dstCellIdx; } - UInt dstSegIdx() const { return _dstSegIdx; } + OutSynapse &operator=(const OutSynapse &o) { + _dstCellIdx = o._dstCellIdx; + _dstSegIdx = o._dstSegIdx; + return *this; + } - /** - * Checks whether this outgoing synapses is going to given destination - * or not. - */ - bool goesTo(UInt dstCellIdx, UInt dstSegIdx) const - { - return _dstCellIdx == dstCellIdx && _dstSegIdx == dstSegIdx; - } + UInt dstCellIdx() const { return _dstCellIdx; } + UInt dstSegIdx() const { return _dstSegIdx; } - /** - * Need for is_in/not_in tests. - */ - bool equals(const OutSynapse& o) const - { - return _dstCellIdx == o._dstCellIdx && _dstSegIdx == o._dstSegIdx; - } + /** + * Checks whether this outgoing synapses is going to given destination + * or not. + */ + bool goesTo(UInt dstCellIdx, UInt dstSegIdx) const { + return _dstCellIdx == dstCellIdx && _dstSegIdx == dstSegIdx; + } - /** - * Checks that the destination cell index and destination segment index - * are in range. - */ - bool invariants(Cells4* cells =nullptr) const; - }; + /** + * Need for is_in/not_in tests. + */ + bool equals(const OutSynapse &o) const { + return _dstCellIdx == o._dstCellIdx && _dstSegIdx == o._dstSegIdx; + } - //-------------------------------------------------------------------------------- - bool operator==(const OutSynapse& a, const OutSynapse& b); + /** + * Checks that the destination cell index and destination segment index + * are in range. 
+ */ + bool invariants(Cells4 *cells = nullptr) const; +}; +//-------------------------------------------------------------------------------- +bool operator==(const OutSynapse &a, const OutSynapse &b); - // End namespace - } - } -} +// End namespace +} // namespace Cells4 +} // namespace algorithms +} // namespace nupic #endif // NTA_OUTSYNAPSE_HPP - diff --git a/src/nupic/algorithms/SDRClassifier.cpp b/src/nupic/algorithms/SDRClassifier.cpp index 0fe73869d1..30868b8930 100644 --- a/src/nupic/algorithms/SDRClassifier.cpp +++ b/src/nupic/algorithms/SDRClassifier.cpp @@ -25,10 +25,10 @@ #include #include #include -#include #include -#include #include +#include +#include #include #include @@ -36,681 +36,559 @@ #include #include -#include #include +#include #include using namespace std; -namespace nupic -{ - namespace algorithms - { - namespace sdr_classifier - { - - SDRClassifier::SDRClassifier( - const vector& steps, Real64 alpha, Real64 actValueAlpha, - UInt verbosity) : steps_(steps), alpha_(alpha), - actValueAlpha_(actValueAlpha), maxInputIdx_(0), maxBucketIdx_(0), - actualValues_({0.0}), actualValuesSet_({false}), - version_(sdrClassifierVersion), verbosity_(verbosity) - { - sort(steps_.begin(), steps_.end()); - if (steps_.size() > 0) - { - maxSteps_ = steps_.at(steps_.size() - 1) + 1; - } else { - maxSteps_ = 1; - } +namespace nupic { +namespace algorithms { +namespace sdr_classifier { + +SDRClassifier::SDRClassifier(const vector &steps, Real64 alpha, + Real64 actValueAlpha, UInt verbosity) + : steps_(steps), alpha_(alpha), actValueAlpha_(actValueAlpha), + maxInputIdx_(0), maxBucketIdx_(0), actualValues_({0.0}), + actualValuesSet_({false}), version_(sdrClassifierVersion), + verbosity_(verbosity) { + sort(steps_.begin(), steps_.end()); + if (steps_.size() > 0) { + maxSteps_ = steps_.at(steps_.size() - 1) + 1; + } else { + maxSteps_ = 1; + } - // TODO: insert maxBucketIdx / maxInputIdx hint as parameter? - // There can be great overhead reallocating the array every time a new - // input is seen, especially if we start at (0, 0). The client will - // usually know what is the final maxInputIdx (typically the number - // of columns?), and we can have heuristics using the encoder's - // settings to get an good approximate of the maxBucketIdx, thus having - // to reallocate this matrix only a few times, even never if we use - // lower bounds - for (const auto& step : steps_) - { - weightMatrix_.emplace(step, Matrix(maxInputIdx_ + 1, - maxBucketIdx_ + 1)); - } - } + // TODO: insert maxBucketIdx / maxInputIdx hint as parameter? + // There can be great overhead reallocating the array every time a new + // input is seen, especially if we start at (0, 0). 
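The constructor comment here (it continues just below) motivates growing the per-step weight matrix lazily as larger input or bucket indices show up, rather than allocating the final size up front. A minimal sketch of that grow-on-write pattern using a plain nested vector; the class below is purely illustrative, while the real code keeps one nupic Matrix per prediction step and resizes it to (maxInputIdx_ + 1) x (maxBucketIdx_ + 1):

#include <cstdio>
#include <vector>

// Grow-on-write 2-D table: touching an out-of-range cell enlarges the table
// (zero-filled) instead of requiring its final dimensions in advance.
class GrowableMatrix {
public:
  double &at(size_t row, size_t col) {
    if (row >= rows_.size())
      rows_.resize(row + 1);
    for (auto &r : rows_)
      if (col >= r.size())
        r.resize(col + 1, 0.0);
    return rows_[row][col];
  }
  size_t nRows() const { return rows_.size(); }

private:
  std::vector<std::vector<double>> rows_;
};

int main() {
  GrowableMatrix w;
  w.at(0, 0) = 1.0;   // starts tiny, as with a 1 x 1 matrix at (0, 0)
  w.at(57, 3) = 0.25; // grows the first time a larger input index is seen
  std::printf("rows now: %zu\n", w.nRows()); // 58
  return 0;
}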
The client will + // usually know what is the final maxInputIdx (typically the number + // of columns?), and we can have heuristics using the encoder's + // settings to get an good approximate of the maxBucketIdx, thus having + // to reallocate this matrix only a few times, even never if we use + // lower bounds + for (const auto &step : steps_) { + weightMatrix_.emplace(step, Matrix(maxInputIdx_ + 1, maxBucketIdx_ + 1)); + } +} - SDRClassifier::~SDRClassifier() - { - } +SDRClassifier::~SDRClassifier() {} + +void SDRClassifier::compute(UInt recordNum, const vector &patternNZ, + const vector &bucketIdxList, + const vector &actValueList, bool category, + bool learn, bool infer, ClassifierResult *result) { + // ensures that recordNum increases monotonically + UInt lastRecordNum = -1; + if (recordNumHistory_.size() > 0) { + lastRecordNum = recordNumHistory_[recordNumHistory_.size() - 1]; + if (recordNum < lastRecordNum) + NTA_THROW << "the record number has to increase monotonically"; + } - void SDRClassifier::compute( - UInt recordNum, const vector& patternNZ, const vector& bucketIdxList, - const vector& actValueList, bool category, bool learn, bool infer, - ClassifierResult* result) - { - // ensures that recordNum increases monotonically - UInt lastRecordNum = -1; - if (recordNumHistory_.size() > 0) - { - lastRecordNum = recordNumHistory_[recordNumHistory_.size()-1]; - if (recordNum < lastRecordNum) - NTA_THROW << "the record number has to increase monotonically"; - } + // update pattern history if this is a new record + if (recordNumHistory_.size() == 0 || recordNum > lastRecordNum) { + patternNZHistory_.emplace_back(patternNZ.begin(), patternNZ.end()); + recordNumHistory_.push_back(recordNum); + if (patternNZHistory_.size() > maxSteps_) { + patternNZHistory_.pop_front(); + recordNumHistory_.pop_front(); + } + } - // update pattern history if this is a new record - if (recordNumHistory_.size() == 0 || recordNum > lastRecordNum) - { - patternNZHistory_.emplace_back(patternNZ.begin(), patternNZ.end()); - recordNumHistory_.push_back(recordNum); - if (patternNZHistory_.size() > maxSteps_) - { - patternNZHistory_.pop_front(); - recordNumHistory_.pop_front(); - } - } + // if input pattern has greater index than previously seen, update + // maxInputIdx and augment weight matrix with zero padding + if (patternNZ.size() > 0) { + UInt maxInputIdx = *max_element(patternNZ.begin(), patternNZ.end()); + if (maxInputIdx > maxInputIdx_) { + maxInputIdx_ = maxInputIdx; + for (const auto &step : steps_) { + Matrix &weights = weightMatrix_.at(step); + weights.resize(maxInputIdx_ + 1, maxBucketIdx_ + 1); + } + } + } - // if input pattern has greater index than previously seen, update - // maxInputIdx and augment weight matrix with zero padding - if (patternNZ.size() > 0) - { - UInt maxInputIdx = *max_element(patternNZ.begin(), patternNZ.end()); - if (maxInputIdx > maxInputIdx_) - { - maxInputIdx_ = maxInputIdx; - for (const auto& step : steps_) - { - Matrix& weights = weightMatrix_.at(step); - weights.resize(maxInputIdx_ + 1, maxBucketIdx_ + 1); - } - } - } + // if in inference mode, compute likelihood and update return value + if (infer) { + infer_(patternNZ, actValueList, result); + } - // if in inference mode, compute likelihood and update return value - if (infer) - { - infer_(patternNZ, actValueList, result); + // update weights if in learning mode + if (learn) { + for (size_t categoryI = 0; categoryI < bucketIdxList.size(); categoryI++) { + UInt bucketIdx = bucketIdxList[categoryI]; + Real64 actValue = 
actValueList[categoryI]; + // if bucket is greater, update maxBucketIdx_ and augment weight + // matrix with zero-padding + if (bucketIdx > maxBucketIdx_) { + maxBucketIdx_ = bucketIdx; + for (const auto &step : steps_) { + Matrix &weights = weightMatrix_.at(step); + weights.resize(maxInputIdx_ + 1, maxBucketIdx_ + 1); } + } - // update weights if in learning mode - if (learn) - { - for(size_t categoryI=0; categoryI < bucketIdxList.size(); categoryI++) - { - UInt bucketIdx = bucketIdxList[categoryI]; - Real64 actValue = actValueList[categoryI]; - // if bucket is greater, update maxBucketIdx_ and augment weight - // matrix with zero-padding - if (bucketIdx > maxBucketIdx_) - { - maxBucketIdx_ = bucketIdx; - for (const auto& step : steps_) - { - Matrix& weights = weightMatrix_.at(step); - weights.resize(maxInputIdx_ + 1, maxBucketIdx_ + 1); - } - } - - // update rolling averages of bucket values - while (actualValues_.size() <= maxBucketIdx_) - { - actualValues_.push_back(0.0); - actualValuesSet_.push_back(false); - } - if (!actualValuesSet_[bucketIdx] || category) - { - actualValues_[bucketIdx] = actValue; - actualValuesSet_[bucketIdx] = true; - } else { - actualValues_[bucketIdx] = - ((1.0 - actValueAlpha_) * actualValues_[bucketIdx]) + - (actValueAlpha_ * actValue); - } - } - - // compute errors and update weights - auto patternIteration = patternNZHistory_.begin(); - for (auto learnRecord = recordNumHistory_.begin(); - learnRecord != recordNumHistory_.end(); - learnRecord++, patternIteration++) - { - const vector learnPatternNZ = *patternIteration; - const UInt nSteps = recordNum - *learnRecord; - - // update weights - if (binary_search(steps_.begin(), steps_.end(), nSteps)) - { - vector error = calculateError_(bucketIdxList, - learnPatternNZ, nSteps); - Matrix& weights = weightMatrix_.at(nSteps); - for (auto& bit : learnPatternNZ) - { - weights.axby(bit, 1.0, alpha_, error); - } - } - } + // update rolling averages of bucket values + while (actualValues_.size() <= maxBucketIdx_) { + actualValues_.push_back(0.0); + actualValuesSet_.push_back(false); + } + if (!actualValuesSet_[bucketIdx] || category) { + actualValues_[bucketIdx] = actValue; + actualValuesSet_[bucketIdx] = true; + } else { + actualValues_[bucketIdx] = + ((1.0 - actValueAlpha_) * actualValues_[bucketIdx]) + + (actValueAlpha_ * actValue); + } + } + + // compute errors and update weights + auto patternIteration = patternNZHistory_.begin(); + for (auto learnRecord = recordNumHistory_.begin(); + learnRecord != recordNumHistory_.end(); + learnRecord++, patternIteration++) { + const vector learnPatternNZ = *patternIteration; + const UInt nSteps = recordNum - *learnRecord; + + // update weights + if (binary_search(steps_.begin(), steps_.end(), nSteps)) { + vector error = + calculateError_(bucketIdxList, learnPatternNZ, nSteps); + Matrix &weights = weightMatrix_.at(nSteps); + for (auto &bit : learnPatternNZ) { + weights.axby(bit, 1.0, alpha_, error); } - } + } + } +} + +UInt SDRClassifier::persistentSize() const { + stringstream s; + s.flags(ios::scientific); + s.precision(numeric_limits::digits10 + 1); + save(s); + return s.str().size(); +} - UInt SDRClassifier::persistentSize() const - { - stringstream s; - s.flags(ios::scientific); - s.precision(numeric_limits::digits10 + 1); - save(s); - return s.str().size(); +void SDRClassifier::infer_(const vector &patternNZ, + const vector &actValue, + ClassifierResult *result) { + // add the actual values to the return value. 
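In the learning branch above, the stored actual value for a bucket is blended toward each new sample with an exponential moving average controlled by actValueAlpha. A one-function sketch of that update; the helper name is ours and the alpha value is hypothetical:

#include <cstdio>

// Exponential moving average, as used for actualValues_[bucketIdx] above:
// the stored value moves a fraction alpha of the way toward each new sample.
static double emaUpdate(double stored, double sample, double alpha) {
  return (1.0 - alpha) * stored + alpha * sample;
}

int main() {
  double actual = 10.0;
  const double alpha = 0.3; // hypothetical actValueAlpha
  const double samples[] = {12.0, 12.0, 12.0};
  for (double sample : samples)
    actual = emaUpdate(actual, sample, alpha);
  std::printf("%.4f\n", actual); // 11.3140 -- converging toward 12.0
  return 0;
}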
For buckets that haven't + // been seen yet, the actual value doesn't matter since it will have + // zero likelihood. + vector *actValueVector = + result->createVector(-1, actualValues_.size(), 0.0); + for (UInt i = 0; i < actualValues_.size(); ++i) { + if (actualValuesSet_[i]) { + (*actValueVector)[i] = actualValues_[i]; + } else { + // if doing 0-step ahead prediction, we shouldn't use any + // knowledge of the classification input during inference + if (steps_.at(0) == 0) { + (*actValueVector)[i] = 0; + } else { + (*actValueVector)[i] = actValue[0]; } + } + } - void SDRClassifier::infer_(const vector& patternNZ, - const vector& actValue, ClassifierResult* result) - { - // add the actual values to the return value. For buckets that haven't - // been seen yet, the actual value doesn't matter since it will have - // zero likelihood. - vector* actValueVector = result->createVector(-1, - actualValues_.size(), 0.0); - for (UInt i = 0; i < actualValues_.size(); ++i) - { - if (actualValuesSet_[i]) - { - (*actValueVector)[i] = actualValues_[i]; - } else { - // if doing 0-step ahead prediction, we shouldn't use any - // knowledge of the classification input during inference - if (steps_.at(0) == 0) - { - (*actValueVector)[i] = 0; - } else { - (*actValueVector)[i] = actValue[0]; - } - } - } + for (auto nSteps = steps_.begin(); nSteps != steps_.end(); ++nSteps) { + vector *likelihoods = + result->createVector(*nSteps, maxBucketIdx_ + 1, 0.0); + for (auto &bit : patternNZ) { + const Matrix &weights = weightMatrix_.at(*nSteps); + add(likelihoods->begin(), likelihoods->end(), weights.begin(bit), + weights.begin(bit + 1)); + } + // compute softmax of raw scores + // TODO: fix potential overflow problem by shifting scores by their + // maximal value across buckets + range_exp(1.0, *likelihoods); + normalize(*likelihoods, 1.0, 1.0); + } +} - for (auto nSteps = steps_.begin(); nSteps!=steps_.end(); ++nSteps) - { - vector* likelihoods = result->createVector(*nSteps, - maxBucketIdx_ + 1, 0.0); - for (auto& bit : patternNZ) - { - const Matrix& weights = weightMatrix_.at(*nSteps); - add(likelihoods->begin(), likelihoods->end(), weights.begin(bit), - weights.begin(bit + 1)); - } - // compute softmax of raw scores - // TODO: fix potential overflow problem by shifting scores by their - // maximal value across buckets - range_exp(1.0, *likelihoods); - normalize(*likelihoods, 1.0, 1.0); - } - } +vector SDRClassifier::calculateError_(const vector &bucketIdxList, + const vector patternNZ, + UInt step) { + // compute predicted likelihoods + vector likelihoods(maxBucketIdx_ + 1, 0); - vector SDRClassifier::calculateError_(const vector& bucketIdxList, - const vector patternNZ, UInt step) - { - // compute predicted likelihoods - vector likelihoods (maxBucketIdx_ + 1, 0); - - for (auto& bit : patternNZ) - { - const Matrix& weights = weightMatrix_.at(step); - add(likelihoods.begin(), likelihoods.end(), weights.begin(bit), - weights.begin(bit + 1)); - } - range_exp(1.0, likelihoods); - normalize(likelihoods, 1.0, 1.0); + for (auto &bit : patternNZ) { + const Matrix &weights = weightMatrix_.at(step); + add(likelihoods.begin(), likelihoods.end(), weights.begin(bit), + weights.begin(bit + 1)); + } + range_exp(1.0, likelihoods); + normalize(likelihoods, 1.0, 1.0); - // compute target likelihoods - vector targetDistribution (maxBucketIdx_ + 1, 0.0); - Real64 numCategories = (Real64)bucketIdxList.size(); - for(size_t i=0; i targetDistribution(maxBucketIdx_ + 1, 0.0); + Real64 numCategories = (Real64)bucketIdxList.size(); + for 
(size_t i = 0; i < bucketIdxList.size(); i++) + targetDistribution[bucketIdxList[i]] = 1.0 / numCategories; - axby(-1.0, likelihoods, 1.0, targetDistribution); - return likelihoods; - } + axby(-1.0, likelihoods, 1.0, targetDistribution); + return likelihoods; +} - UInt SDRClassifier::version() const - { - return version_; - } +UInt SDRClassifier::version() const { return version_; } - UInt SDRClassifier::getVerbosity() const - { - return verbosity_; - } +UInt SDRClassifier::getVerbosity() const { return verbosity_; } - void SDRClassifier::setVerbosity(UInt verbosity) - { - verbosity_ = verbosity; - } +void SDRClassifier::setVerbosity(UInt verbosity) { verbosity_ = verbosity; } - UInt SDRClassifier::getAlpha() const - { - return alpha_; - } +UInt SDRClassifier::getAlpha() const { return alpha_; } - void SDRClassifier::save(ostream& outStream) const - { - // Write a starting marker and version. - outStream << "SDRClassifier" << endl; - outStream << version_ << endl; - - // Store the simple variables first. - outStream << version() << " " - << alpha_ << " " - << actValueAlpha_ << " " - << maxSteps_ << " " - << maxBucketIdx_ << " " - << maxInputIdx_ << " " - << verbosity_ << " " - << endl; - - // V1 additions. - outStream << recordNumHistory_.size() << " "; - for (const auto& elem : recordNumHistory_) - { - outStream << elem << " "; - } - outStream << endl; +void SDRClassifier::save(ostream &outStream) const { + // Write a starting marker and version. + outStream << "SDRClassifier" << endl; + outStream << version_ << endl; - // Store the different prediction steps. - outStream << steps_.size() << " "; - for (auto& elem : steps_) - { - outStream << elem << " "; - } - outStream << endl; - - // Store the pattern history. - outStream << patternNZHistory_.size() << " "; - for (auto& pattern : patternNZHistory_) - { - outStream << pattern.size() << " "; - for (auto& pattern_j : pattern) - { - outStream << pattern_j << " "; - } - } - outStream << endl; - - // Store weight matrix - outStream << weightMatrix_.size() << " "; - for (const auto& elem : weightMatrix_) - { - outStream << elem.first << " "; - outStream << elem.second; - } - outStream << endl; - - // Store the actual values for each bucket. - outStream << actualValues_.size() << " "; - for (UInt i = 0; i < actualValues_.size(); ++i) - { - outStream << actualValues_[i] << " "; - outStream << actualValuesSet_[i] << " "; - } - outStream << endl; + // Store the simple variables first. + outStream << version() << " " << alpha_ << " " << actValueAlpha_ << " " + << maxSteps_ << " " << maxBucketIdx_ << " " << maxInputIdx_ << " " + << verbosity_ << " " << endl; - // Write an ending marker. - outStream << "~SDRClassifier" << endl; - } + // V1 additions. + outStream << recordNumHistory_.size() << " "; + for (const auto &elem : recordNumHistory_) { + outStream << elem << " "; + } + outStream << endl; - void SDRClassifier::load(istream& inStream) - { - // Clean up the existing data structures before loading - steps_.clear(); - recordNumHistory_.clear(); - patternNZHistory_.clear(); - actualValues_.clear(); - actualValuesSet_.clear(); - weightMatrix_.clear(); - - // Check the starting marker. - string marker; - inStream >> marker; - NTA_CHECK(marker == "SDRClassifier"); - - // Check the version. - UInt version; - inStream >> version; - NTA_CHECK(version <= 1); - - // Load the simple variables. 
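infer_ above notes a TODO about avoiding overflow by shifting the raw scores by their maximum before exponentiating, and calculateError_ uses the same exp-then-normalize pattern. A self-contained sketch of that numerically stable softmax, assuming range_exp and normalize amount to elementwise exp followed by L1 normalization:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax: subtracting the maximum before exp() leaves the
// result mathematically unchanged but keeps exp() from overflowing.
static std::vector<double> softmax(std::vector<double> scores) {
  const double maxScore = *std::max_element(scores.begin(), scores.end());
  double sum = 0.0;
  for (double &s : scores) {
    s = std::exp(s - maxScore);
    sum += s;
  }
  for (double &s : scores)
    s /= sum;
  return scores;
}

int main() {
  // Raw scores this large would overflow a naive exp(); the shifted form is fine.
  const std::vector<double> likelihoods = softmax({1000.0, 1001.0, 1002.0});
  for (double p : likelihoods)
    std::printf("%.4f ", p); // 0.0900 0.2447 0.6652
  std::printf("\n");
  return 0;
}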
- inStream >> version_ - >> alpha_ - >> actValueAlpha_ - >> maxSteps_ - >> maxBucketIdx_ - >> maxInputIdx_ - >> verbosity_; - - UInt recordNumHistory; - UInt curRecordNum; - if (version == 1) - { - inStream >> recordNumHistory; - for (UInt i = 0; i < recordNumHistory; ++i) - { - inStream >> curRecordNum; - recordNumHistory_.push_back(curRecordNum); - } - } + // Store the different prediction steps. + outStream << steps_.size() << " "; + for (auto &elem : steps_) { + outStream << elem << " "; + } + outStream << endl; + + // Store the pattern history. + outStream << patternNZHistory_.size() << " "; + for (auto &pattern : patternNZHistory_) { + outStream << pattern.size() << " "; + for (auto &pattern_j : pattern) { + outStream << pattern_j << " "; + } + } + outStream << endl; - // Load the prediction steps. - UInt size; - UInt step; - inStream >> size; - for (UInt i = 0; i < size; ++i) - { - inStream >> step; - steps_.push_back(step); - } + // Store weight matrix + outStream << weightMatrix_.size() << " "; + for (const auto &elem : weightMatrix_) { + outStream << elem.first << " "; + outStream << elem.second; + } + outStream << endl; - // Load the input pattern history. - inStream >> size; - UInt vSize; - for (UInt i = 0; i < size; ++i) - { - inStream >> vSize; - patternNZHistory_.emplace_back(vSize); - for (UInt j = 0; j < vSize; ++j) - { - inStream >> patternNZHistory_[i][j]; - } - } + // Store the actual values for each bucket. + outStream << actualValues_.size() << " "; + for (UInt i = 0; i < actualValues_.size(); ++i) { + outStream << actualValues_[i] << " "; + outStream << actualValuesSet_[i] << " "; + } + outStream << endl; - // Load weight matrix. - UInt numSteps; - inStream >> numSteps; - for (UInt s = 0; s < numSteps; ++s) - { - inStream >> step; - // Insert the step to initialize the weight matrix - weightMatrix_[step] = Matrix(maxInputIdx_ + 1, maxBucketIdx_ + 1); - for (UInt i = 0; i <= maxInputIdx_; ++i) - { - for (UInt j = 0; j <= maxBucketIdx_; ++j) - { - inStream >> weightMatrix_[step].at(i, j); - } - } - } + // Write an ending marker. + outStream << "~SDRClassifier" << endl; +} - // Load the actual values for each bucket. - UInt numBuckets; - Real64 actualValue; - bool actualValueSet; - inStream >> numBuckets; - for (UInt i = 0; i < numBuckets; ++i) - { - inStream >> actualValue; - actualValues_.push_back(actualValue); - inStream >> actualValueSet; - actualValuesSet_.push_back(actualValueSet); - } +void SDRClassifier::load(istream &inStream) { + // Clean up the existing data structures before loading + steps_.clear(); + recordNumHistory_.clear(); + patternNZHistory_.clear(); + actualValues_.clear(); + actualValuesSet_.clear(); + weightMatrix_.clear(); + + // Check the starting marker. + string marker; + inStream >> marker; + NTA_CHECK(marker == "SDRClassifier"); + + // Check the version. + UInt version; + inStream >> version; + NTA_CHECK(version <= 1); + + // Load the simple variables. + inStream >> version_ >> alpha_ >> actValueAlpha_ >> maxSteps_ >> + maxBucketIdx_ >> maxInputIdx_ >> verbosity_; + + UInt recordNumHistory; + UInt curRecordNum; + if (version == 1) { + inStream >> recordNumHistory; + for (UInt i = 0; i < recordNumHistory; ++i) { + inStream >> curRecordNum; + recordNumHistory_.push_back(curRecordNum); + } + } + + // Load the prediction steps. + UInt size; + UInt step; + inStream >> size; + for (UInt i = 0; i < size; ++i) { + inStream >> step; + steps_.push_back(step); + } - // Check for the end marker. 
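The save()/load() pair in this file streams every container as its size followed by its elements, space separated, and brackets the whole record with the "SDRClassifier" / "~SDRClassifier" markers. A tiny round-trip sketch of that size-prefixed convention using a stringstream (standard library only, no nupic types):

#include <cassert>
#include <sstream>
#include <vector>

int main() {
  // Write: size first, then the elements, space separated (as save() does).
  const std::vector<unsigned> steps = {1, 5, 10};
  std::stringstream ss;
  ss << steps.size() << " ";
  for (unsigned s : steps)
    ss << s << " ";

  // Read back: size first, then exactly that many elements (as load() does).
  size_t size = 0;
  ss >> size;
  std::vector<unsigned> loaded;
  for (size_t i = 0; i < size; ++i) {
    unsigned s = 0;
    ss >> s;
    loaded.push_back(s);
  }
  assert(loaded == steps);
  return 0;
}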
- inStream >> marker; - NTA_CHECK(marker == "~SDRClassifier"); + // Load the input pattern history. + inStream >> size; + UInt vSize; + for (UInt i = 0; i < size; ++i) { + inStream >> vSize; + patternNZHistory_.emplace_back(vSize); + for (UInt j = 0; j < vSize; ++j) { + inStream >> patternNZHistory_[i][j]; + } + } - // Update the version number. - version_ = sdrClassifierVersion; + // Load weight matrix. + UInt numSteps; + inStream >> numSteps; + for (UInt s = 0; s < numSteps; ++s) { + inStream >> step; + // Insert the step to initialize the weight matrix + weightMatrix_[step] = Matrix(maxInputIdx_ + 1, maxBucketIdx_ + 1); + for (UInt i = 0; i <= maxInputIdx_; ++i) { + for (UInt j = 0; j <= maxBucketIdx_; ++j) { + inStream >> weightMatrix_[step].at(i, j); } + } + } - void SDRClassifier::write(SdrClassifierProto::Builder& proto) const - { - auto stepsProto = proto.initSteps(steps_.size()); - for (UInt i = 0; i < steps_.size(); i++) - { - stepsProto.set(i, steps_[i]); - } + // Load the actual values for each bucket. + UInt numBuckets; + Real64 actualValue; + bool actualValueSet; + inStream >> numBuckets; + for (UInt i = 0; i < numBuckets; ++i) { + inStream >> actualValue; + actualValues_.push_back(actualValue); + inStream >> actualValueSet; + actualValuesSet_.push_back(actualValueSet); + } - proto.setAlpha(alpha_); - proto.setActValueAlpha(actValueAlpha_); - proto.setMaxSteps(maxSteps_); - - auto patternNZHistoryProto = - proto.initPatternNZHistory(patternNZHistory_.size()); - for (UInt i = 0; i < patternNZHistory_.size(); i++) - { - const auto& pattern = patternNZHistory_[i]; - auto patternProto = patternNZHistoryProto.init(i, pattern.size()); - for (UInt j = 0; j < pattern.size(); j++) - { - patternProto.set(j, pattern[j]); - } - } + // Check for the end marker. + inStream >> marker; + NTA_CHECK(marker == "~SDRClassifier"); - auto recordNumHistoryProto = - proto.initRecordNumHistory(recordNumHistory_.size()); - for (UInt i = 0; i < recordNumHistory_.size(); i++) - { - recordNumHistoryProto.set(i, recordNumHistory_[i]); - } + // Update the version number. 
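Taken together, save() and load() give a plain-text round trip through any std::iostream. A minimal usage sketch, assuming the steps argument is a std::vector<UInt> (the element type is elided in the header excerpt further below) and using illustrative constructor values; roundTripsCleanly is just an example wrapper.

#include <sstream>
#include <vector>
#include "nupic/algorithms/SDRClassifier.hpp"

using nupic::UInt;
using nupic::algorithms::sdr_classifier::SDRClassifier;

bool roundTripsCleanly() {
  const std::vector<UInt> steps = {1};        // predict one step ahead
  SDRClassifier original(steps, 0.1, 0.3, 0); // alpha, actValueAlpha, verbosity
  std::stringstream ss;
  original.save(ss);      // writes "SDRClassifier" ... "~SDRClassifier"
  SDRClassifier restored; // the default constructor exists for deserialization
  restored.load(ss);
  return restored == original; // operator== compares the serialized state
}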
+ version_ = sdrClassifierVersion; +} - proto.setMaxBucketIdx(maxBucketIdx_); - proto.setMaxInputIdx(maxInputIdx_); - - auto weightMatrixProtos = - proto.initWeightMatrix(weightMatrix_.size()); - UInt k = 0; - for (const auto& stepWeightMatrix : weightMatrix_) - { - auto stepWeightMatrixProto = weightMatrixProtos[k]; - stepWeightMatrixProto.setSteps(stepWeightMatrix.first); - auto weightProto = stepWeightMatrixProto.initWeight( - (maxInputIdx_ + 1) * (maxBucketIdx_ + 1) - ); - // flatten weight matrix, serialized as a list of floats - UInt idx = 0; - for (UInt i = 0; i <= maxInputIdx_; ++i) - { - for (UInt j = 0; j <= maxBucketIdx_; ++j) - { - weightProto.set(idx, stepWeightMatrix.second.at(i, j)); - idx++; - } - } - k++; - } +void SDRClassifier::write(SdrClassifierProto::Builder &proto) const { + auto stepsProto = proto.initSteps(steps_.size()); + for (UInt i = 0; i < steps_.size(); i++) { + stepsProto.set(i, steps_[i]); + } - auto actualValuesProto = proto.initActualValues(actualValues_.size()); - for (UInt i = 0; i < actualValues_.size(); i++) - { - actualValuesProto.set(i, actualValues_[i]); - } + proto.setAlpha(alpha_); + proto.setActValueAlpha(actValueAlpha_); + proto.setMaxSteps(maxSteps_); + + auto patternNZHistoryProto = + proto.initPatternNZHistory(patternNZHistory_.size()); + for (UInt i = 0; i < patternNZHistory_.size(); i++) { + const auto &pattern = patternNZHistory_[i]; + auto patternProto = patternNZHistoryProto.init(i, pattern.size()); + for (UInt j = 0; j < pattern.size(); j++) { + patternProto.set(j, pattern[j]); + } + } - auto actualValuesSetProto = - proto.initActualValuesSet(actualValuesSet_.size()); - for (UInt i = 0; i < actualValuesSet_.size(); i++) - { - actualValuesSetProto.set(i, actualValuesSet_[i]); - } + auto recordNumHistoryProto = + proto.initRecordNumHistory(recordNumHistory_.size()); + for (UInt i = 0; i < recordNumHistory_.size(); i++) { + recordNumHistoryProto.set(i, recordNumHistory_[i]); + } - proto.setVersion(version_); - proto.setVerbosity(verbosity_); + proto.setMaxBucketIdx(maxBucketIdx_); + proto.setMaxInputIdx(maxInputIdx_); + + auto weightMatrixProtos = proto.initWeightMatrix(weightMatrix_.size()); + UInt k = 0; + for (const auto &stepWeightMatrix : weightMatrix_) { + auto stepWeightMatrixProto = weightMatrixProtos[k]; + stepWeightMatrixProto.setSteps(stepWeightMatrix.first); + auto weightProto = stepWeightMatrixProto.initWeight((maxInputIdx_ + 1) * + (maxBucketIdx_ + 1)); + // flatten weight matrix, serialized as a list of floats + UInt idx = 0; + for (UInt i = 0; i <= maxInputIdx_; ++i) { + for (UInt j = 0; j <= maxBucketIdx_; ++j) { + weightProto.set(idx, stepWeightMatrix.second.at(i, j)); + idx++; } + } + k++; + } - void SDRClassifier::read(SdrClassifierProto::Reader& proto) - { - // Clean up the existing data structures before loading - steps_.clear(); - recordNumHistory_.clear(); - patternNZHistory_.clear(); - actualValues_.clear(); - actualValuesSet_.clear(); - weightMatrix_.clear(); - - for (auto step : proto.getSteps()) - { - steps_.push_back(step); - } + auto actualValuesProto = proto.initActualValues(actualValues_.size()); + for (UInt i = 0; i < actualValues_.size(); i++) { + actualValuesProto.set(i, actualValues_[i]); + } - alpha_ = proto.getAlpha(); - actValueAlpha_ = proto.getActValueAlpha(); - maxSteps_ = proto.getMaxSteps(); - - auto patternNZHistoryProto = proto.getPatternNZHistory(); - for (UInt i = 0; i < patternNZHistoryProto.size(); i++) - { - patternNZHistory_.emplace_back(patternNZHistoryProto[i].size()); - for (UInt j = 
0; j < patternNZHistoryProto[i].size(); j++) - { - patternNZHistory_[i][j] = patternNZHistoryProto[i][j]; - } - } + auto actualValuesSetProto = + proto.initActualValuesSet(actualValuesSet_.size()); + for (UInt i = 0; i < actualValuesSet_.size(); i++) { + actualValuesSetProto.set(i, actualValuesSet_[i]); + } - auto recordNumHistoryProto = proto.getRecordNumHistory(); - for (UInt i = 0; i < recordNumHistoryProto.size(); i++) - { - recordNumHistory_.push_back(recordNumHistoryProto[i]); - } + proto.setVersion(version_); + proto.setVerbosity(verbosity_); +} - maxBucketIdx_ = proto.getMaxBucketIdx(); - maxInputIdx_ = proto.getMaxInputIdx(); - - auto weightMatrixProto = proto.getWeightMatrix(); - for (UInt i = 0; i < weightMatrixProto.size(); ++i) - { - auto stepWeightMatrix = weightMatrixProto[i]; - UInt steps = stepWeightMatrix.getSteps(); - weightMatrix_[steps] = Matrix(maxInputIdx_ + 1, maxBucketIdx_ + 1); - auto weights = stepWeightMatrix.getWeight(); - UInt j = 0; - // un-flatten weight matrix, serialized as a list of floats - for (UInt row = 0; row <= maxInputIdx_; ++row) - { - for (UInt col = 0; col <= maxBucketIdx_; ++col) - { - weightMatrix_[steps].at(row, col) = weights[j]; - j++; - } - } - } +void SDRClassifier::read(SdrClassifierProto::Reader &proto) { + // Clean up the existing data structures before loading + steps_.clear(); + recordNumHistory_.clear(); + patternNZHistory_.clear(); + actualValues_.clear(); + actualValuesSet_.clear(); + weightMatrix_.clear(); + + for (auto step : proto.getSteps()) { + steps_.push_back(step); + } - for (auto actValue : proto.getActualValues()) - { - actualValues_.push_back(actValue); - } + alpha_ = proto.getAlpha(); + actValueAlpha_ = proto.getActValueAlpha(); + maxSteps_ = proto.getMaxSteps(); - for (auto actValueSet : proto.getActualValuesSet()) - { - actualValuesSet_.push_back(actValueSet); - } + auto patternNZHistoryProto = proto.getPatternNZHistory(); + for (UInt i = 0; i < patternNZHistoryProto.size(); i++) { + patternNZHistory_.emplace_back(patternNZHistoryProto[i].size()); + for (UInt j = 0; j < patternNZHistoryProto[i].size(); j++) { + patternNZHistory_[i][j] = patternNZHistoryProto[i][j]; + } + } - version_ = proto.getVersion(); - verbosity_ = proto.getVerbosity(); + auto recordNumHistoryProto = proto.getRecordNumHistory(); + for (UInt i = 0; i < recordNumHistoryProto.size(); i++) { + recordNumHistory_.push_back(recordNumHistoryProto[i]); + } + + maxBucketIdx_ = proto.getMaxBucketIdx(); + maxInputIdx_ = proto.getMaxInputIdx(); + + auto weightMatrixProto = proto.getWeightMatrix(); + for (UInt i = 0; i < weightMatrixProto.size(); ++i) { + auto stepWeightMatrix = weightMatrixProto[i]; + UInt steps = stepWeightMatrix.getSteps(); + weightMatrix_[steps] = Matrix(maxInputIdx_ + 1, maxBucketIdx_ + 1); + auto weights = stepWeightMatrix.getWeight(); + UInt j = 0; + // un-flatten weight matrix, serialized as a list of floats + for (UInt row = 0; row <= maxInputIdx_; ++row) { + for (UInt col = 0; col <= maxBucketIdx_; ++col) { + weightMatrix_[steps].at(row, col) = weights[j]; + j++; } + } + } - bool SDRClassifier::operator==(const SDRClassifier& other) const - { - if (steps_.size() != other.steps_.size()) - { - return false; - } - for (UInt i = 0; i < steps_.size(); i++) - { - if (steps_.at(i) != other.steps_.at(i)) - { - return false; - } - } + for (auto actValue : proto.getActualValues()) { + actualValues_.push_back(actValue); + } - if (fabs(alpha_ - other.alpha_) > 0.000001 || - fabs(actValueAlpha_ - other.actValueAlpha_) > 0.000001 || - maxSteps_ 
!= other.maxSteps_) - { - return false; - } + for (auto actValueSet : proto.getActualValuesSet()) { + actualValuesSet_.push_back(actValueSet); + } - if (patternNZHistory_.size() != other.patternNZHistory_.size()) - { - return false; - } - for (UInt i = 0; i < patternNZHistory_.size(); i++) - { - if (patternNZHistory_.at(i).size() != - other.patternNZHistory_.at(i).size()) - { - return false; - } - for (UInt j = 0; j < patternNZHistory_.at(i).size(); j++) - { - if (patternNZHistory_.at(i).at(j) != - other.patternNZHistory_.at(i).at(j)) - { - return false; - } - } - } + version_ = proto.getVersion(); + verbosity_ = proto.getVerbosity(); +} - if (recordNumHistory_.size() != - other.recordNumHistory_.size()) - { - return false; - } - for (UInt i = 0; i < recordNumHistory_.size(); i++) - { - if (recordNumHistory_.at(i) != - other.recordNumHistory_.at(i)) - { - return false; - } - } +bool SDRClassifier::operator==(const SDRClassifier &other) const { + if (steps_.size() != other.steps_.size()) { + return false; + } + for (UInt i = 0; i < steps_.size(); i++) { + if (steps_.at(i) != other.steps_.at(i)) { + return false; + } + } - if (maxBucketIdx_ != other.maxBucketIdx_) - { - return false; - } + if (fabs(alpha_ - other.alpha_) > 0.000001 || + fabs(actValueAlpha_ - other.actValueAlpha_) > 0.000001 || + maxSteps_ != other.maxSteps_) { + return false; + } - if (maxInputIdx_ != other.maxInputIdx_) - { - return false; - } + if (patternNZHistory_.size() != other.patternNZHistory_.size()) { + return false; + } + for (UInt i = 0; i < patternNZHistory_.size(); i++) { + if (patternNZHistory_.at(i).size() != + other.patternNZHistory_.at(i).size()) { + return false; + } + for (UInt j = 0; j < patternNZHistory_.at(i).size(); j++) { + if (patternNZHistory_.at(i).at(j) != + other.patternNZHistory_.at(i).at(j)) { + return false; + } + } + } - if (weightMatrix_.size() != other.weightMatrix_.size()) - { - return false; - } - for (auto it = weightMatrix_.begin(); it != weightMatrix_.end(); it++) - { - Matrix thisWeights = it->second; - Matrix otherWeights = other.weightMatrix_.at(it->first); - for (UInt i = 0; i <= maxInputIdx_; ++i) - { - for (UInt j = 0; j <= maxBucketIdx_; ++j) - { - if (thisWeights.at(i, j) != otherWeights.at(i, j)) - { - return false; - } - } - } - } + if (recordNumHistory_.size() != other.recordNumHistory_.size()) { + return false; + } + for (UInt i = 0; i < recordNumHistory_.size(); i++) { + if (recordNumHistory_.at(i) != other.recordNumHistory_.at(i)) { + return false; + } + } - if (actualValues_.size() != other.actualValues_.size() || - actualValuesSet_.size() != other.actualValuesSet_.size()) - { - return false; - } - for (UInt i = 0; i < actualValues_.size(); i++) - { - if (fabs(actualValues_[i] - other.actualValues_[i]) > 0.000001 || - actualValuesSet_[i] != other.actualValuesSet_[i]) - { - return false; - } - } + if (maxBucketIdx_ != other.maxBucketIdx_) { + return false; + } + + if (maxInputIdx_ != other.maxInputIdx_) { + return false; + } - if (version_ != other.version_ || - verbosity_ != other.verbosity_) - { + if (weightMatrix_.size() != other.weightMatrix_.size()) { + return false; + } + for (auto it = weightMatrix_.begin(); it != weightMatrix_.end(); it++) { + Matrix thisWeights = it->second; + Matrix otherWeights = other.weightMatrix_.at(it->first); + for (UInt i = 0; i <= maxInputIdx_; ++i) { + for (UInt j = 0; j <= maxBucketIdx_; ++j) { + if (thisWeights.at(i, j) != otherWeights.at(i, j)) { return false; } - - return true; } + } + } + + if (actualValues_.size() != 
other.actualValues_.size() || + actualValuesSet_.size() != other.actualValuesSet_.size()) { + return false; + } + for (UInt i = 0; i < actualValues_.size(); i++) { + if (fabs(actualValues_[i] - other.actualValues_[i]) > 0.000001 || + actualValuesSet_[i] != other.actualValuesSet_[i]) { + return false; + } + } - } + if (version_ != other.version_ || verbosity_ != other.verbosity_) { + return false; } + + return true; } + +} // namespace sdr_classifier +} // namespace algorithms +} // namespace nupic diff --git a/src/nupic/algorithms/SDRClassifier.hpp b/src/nupic/algorithms/SDRClassifier.hpp index 38ca068710..a731b55af1 100644 --- a/src/nupic/algorithms/SDRClassifier.hpp +++ b/src/nupic/algorithms/SDRClassifier.hpp @@ -33,182 +33,177 @@ #include #include +#include #include #include #include -#include -namespace nupic -{ - namespace algorithms - { - - typedef cla_classifier::ClassifierResult ClassifierResult; - - namespace sdr_classifier - { - - const UInt sdrClassifierVersion = 1; - - typedef Dense Matrix; - - class SDRClassifier : public Serializable - { - public: - /** - * Constructor for use when deserializing. - */ - SDRClassifier() {} - - /** - * Constructor. - * - * @param steps The different number of steps to learn and predict. - * @param alpha The alpha to use when decaying the duty cycles. - * @param actValueAlpha The alpha to use when decaying the actual - * values for each bucket. - * @param verbosity The logging verbosity. - */ - SDRClassifier( - const vector& steps, Real64 alpha, Real64 actValueAlpha, - UInt verbosity); - - /** - * Destructor. - */ - virtual ~SDRClassifier(); - - /** - * Compute the likelihoods for each bucket. - * - * @param recordNum An incrementing integer for each record. Gaps in - * numbers correspond to missing records. - * @param patternNZ The active input bit indices. - * @param bucketIdx The current value bucket index. - * @param actValue The current scalar value. - * @param category Whether the actual values represent categories. - * @param learn Whether or not to perform learning. - * @param infer Whether or not to perform inference. - * @param result A mapping from prediction step to a vector of - * likelihoods where the value at an index corresponds - * to the bucket with the same index. In addition, the - * values for key 0 correspond to the actual values to - * used when predicting each bucket. - */ - virtual void compute( - UInt recordNum, const vector& patternNZ, const vector& bucketIdxList, - const vector& actValueList, bool category, bool learn, bool infer, - ClassifierResult* result); - - /** - * Gets the version number - */ - UInt version() const; - - /** - * Getter and setter for verbosity level. - */ - UInt getVerbosity() const; - void setVerbosity(UInt verbosity); - - /** - * Gets the learning rate - */ - UInt getAlpha() const; - - /** - * Get the size of the string needed for the serialized state. - */ - UInt persistentSize() const; - - /** - * Save the state to the ostream. - */ - void save(std::ostream& outStream) const; - - /** - * Load state from istream. - */ - void load(std::istream& inStream); - - /** - * Save the state to the builder. - */ - void write(SdrClassifierProto::Builder& proto) const override; - - /** - * Save the state to the stream. - */ - using Serializable::write; - - /** - * Load state from reader. - */ - void read(SdrClassifierProto::Reader& proto) override; - - /** - * Load state from stream. - */ - using Serializable::read; - - /** - * Compare the other instance to this one. 
- * - * @param other Another instance of SDRClassifier to compare to. - * @returns true iff other is identical to this instance. - */ - virtual bool operator==(const SDRClassifier& other) const; - - private: - // Helper function for inference mode - void infer_(const vector& patternNZ, - const vector& actValue, ClassifierResult* result); - - // Helper function to compute the error signal in learning mode - vector calculateError_(const vector& bucketIdxList, - const vector patternNZ, UInt step); - - // The list of prediction steps to learn and infer. - vector steps_; - - // The alpha used to decay the duty cycles in the BitHistorys. - Real64 alpha_; - - // The alpha used to decay the actual values used for each bucket. - Real64 actValueAlpha_; - - // The maximum number of the prediction steps. - UInt maxSteps_; - - // Stores the input pattern history, starting with the previous input - // and containing _maxSteps total input patterns. - deque< vector > patternNZHistory_; - deque recordNumHistory_; - - // Weight matrices for the classifier (one per prediction step) - map weightMatrix_; - - // The highest input bit that the classifier has seen so far. - UInt maxInputIdx_; - - // The highest bucket index that the classifier has been seen so far. - UInt maxBucketIdx_; - - // The current actual values used for each bucket index. The index of - // the actual value matches the index of the bucket. - vector actualValues_; - - // A boolean that distinguishes between actual values that have been - // seen and those that have not. - vector actualValuesSet_; - - // Version and verbosity. - UInt version_; - UInt verbosity_; - }; // end of SDRClassifier class - - } // end of namespace sdr_classifier - } // end of namespace algorithms -} // end of name space nupic - -#endif +namespace nupic { +namespace algorithms { + +typedef cla_classifier::ClassifierResult ClassifierResult; + +namespace sdr_classifier { + +const UInt sdrClassifierVersion = 1; + +typedef Dense Matrix; + +class SDRClassifier : public Serializable { +public: + /** + * Constructor for use when deserializing. + */ + SDRClassifier() {} + + /** + * Constructor. + * + * @param steps The different number of steps to learn and predict. + * @param alpha The alpha to use when decaying the duty cycles. + * @param actValueAlpha The alpha to use when decaying the actual + * values for each bucket. + * @param verbosity The logging verbosity. + */ + SDRClassifier(const vector &steps, Real64 alpha, Real64 actValueAlpha, + UInt verbosity); + + /** + * Destructor. + */ + virtual ~SDRClassifier(); + + /** + * Compute the likelihoods for each bucket. + * + * @param recordNum An incrementing integer for each record. Gaps in + * numbers correspond to missing records. + * @param patternNZ The active input bit indices. + * @param bucketIdx The current value bucket index. + * @param actValue The current scalar value. + * @param category Whether the actual values represent categories. + * @param learn Whether or not to perform learning. + * @param infer Whether or not to perform inference. + * @param result A mapping from prediction step to a vector of + * likelihoods where the value at an index corresponds + * to the bucket with the same index. In addition, the + * values for key 0 correspond to the actual values to + * used when predicting each bucket. 
+ */ + virtual void compute(UInt recordNum, const vector &patternNZ, + const vector &bucketIdxList, + const vector &actValueList, bool category, + bool learn, bool infer, ClassifierResult *result); + + /** + * Gets the version number + */ + UInt version() const; + + /** + * Getter and setter for verbosity level. + */ + UInt getVerbosity() const; + void setVerbosity(UInt verbosity); + + /** + * Gets the learning rate + */ + UInt getAlpha() const; + + /** + * Get the size of the string needed for the serialized state. + */ + UInt persistentSize() const; + + /** + * Save the state to the ostream. + */ + void save(std::ostream &outStream) const; + + /** + * Load state from istream. + */ + void load(std::istream &inStream); + + /** + * Save the state to the builder. + */ + void write(SdrClassifierProto::Builder &proto) const override; + + /** + * Save the state to the stream. + */ + using Serializable::write; + + /** + * Load state from reader. + */ + void read(SdrClassifierProto::Reader &proto) override; + + /** + * Load state from stream. + */ + using Serializable::read; + + /** + * Compare the other instance to this one. + * + * @param other Another instance of SDRClassifier to compare to. + * @returns true iff other is identical to this instance. + */ + virtual bool operator==(const SDRClassifier &other) const; + +private: + // Helper function for inference mode + void infer_(const vector &patternNZ, const vector &actValue, + ClassifierResult *result); + + // Helper function to compute the error signal in learning mode + vector calculateError_(const vector &bucketIdxList, + const vector patternNZ, UInt step); + + // The list of prediction steps to learn and infer. + vector steps_; + + // The alpha used to decay the duty cycles in the BitHistorys. + Real64 alpha_; + + // The alpha used to decay the actual values used for each bucket. + Real64 actValueAlpha_; + + // The maximum number of the prediction steps. + UInt maxSteps_; + + // Stores the input pattern history, starting with the previous input + // and containing _maxSteps total input patterns. + deque> patternNZHistory_; + deque recordNumHistory_; + + // Weight matrices for the classifier (one per prediction step) + map weightMatrix_; + + // The highest input bit that the classifier has seen so far. + UInt maxInputIdx_; + + // The highest bucket index that the classifier has been seen so far. + UInt maxBucketIdx_; + + // The current actual values used for each bucket index. The index of + // the actual value matches the index of the bucket. + vector actualValues_; + + // A boolean that distinguishes between actual values that have been + // seen and those that have not. + vector actualValuesSet_; + + // Version and verbosity. 
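For reference, the capnp write()/read() pair earlier in SDRClassifier.cpp serializes each step's (maxInputIdx_ + 1) x (maxBucketIdx_ + 1) weight matrix as one flat list of floats, rows in the outer loop and columns in the inner loop. A standalone sketch of that row-major mapping; the names here are illustrative.

#include <cstddef>
#include <vector>

// Row-major index used when flattening an (nRows x nCols) matrix into a
// single list, matching the nested i/j loops in write()/read().
inline std::size_t flatIndex(std::size_t row, std::size_t col,
                             std::size_t nCols) {
  return row * nCols + col;
}

std::vector<float> flatten(const std::vector<std::vector<float>> &matrix) {
  const std::size_t nCols = matrix.empty() ? 0 : matrix[0].size();
  std::vector<float> flat(matrix.size() * nCols, 0.0f);
  for (std::size_t i = 0; i < matrix.size(); ++i)
    for (std::size_t j = 0; j < nCols; ++j)
      flat[flatIndex(i, j, nCols)] = matrix[i][j];
  return flat;
}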
+ UInt version_; + UInt verbosity_; +}; // end of SDRClassifier class + +} // end of namespace sdr_classifier +} // end of namespace algorithms +} // namespace nupic + +#endif diff --git a/src/nupic/algorithms/Scanning.hpp b/src/nupic/algorithms/Scanning.hpp index eb669e24b9..cf4dc0251a 100644 --- a/src/nupic/algorithms/Scanning.hpp +++ b/src/nupic/algorithms/Scanning.hpp @@ -24,17 +24,14 @@ #define NTA_SCANNING_HPP // Performs the time-intensive steps of ScanControlNode.getAlpha -void computeAlpha(size_t xstep, size_t ystep, - size_t widthS, size_t heightS, - size_t imageWidth, size_t imageHeight, - size_t xcount, size_t ycount, - size_t weightWidth, float sharpness, - float* data, float* values, float* counts, float* weights) -{ +void computeAlpha(size_t xstep, size_t ystep, size_t widthS, size_t heightS, + size_t imageWidth, size_t imageHeight, size_t xcount, + size_t ycount, size_t weightWidth, float sharpness, + float *data, float *values, float *counts, float *weights) { size_t y0, y1, x0, x1, i, j, m, n; float coefficient = 0, minval = 0, maxval = 0; float *d, *v, *c, *w; - + if (sharpness < 1) { // Calculate coefficient for sigmoid, used to scale values in range [0, 1] // (If sharpness is 1, the results are simply thresholded) @@ -82,10 +79,10 @@ void computeAlpha(size_t xstep, size_t ystep, *v >= 0.5 ? *v = 1 : *v = 0; } else if (coefficient != 0) { // Sigmoid (coefficient was calculated from value of "sharpness") - *v = (1 / (1 + exp(coefficient * (*v - 0.5f))) - minval) - / (maxval - minval); + *v = (1 / (1 + exp(coefficient * (*v - 0.5f))) - minval) / + (maxval - minval); } } } -#endif //NTA_SCANNING_HPP +#endif // NTA_SCANNING_HPP diff --git a/src/nupic/algorithms/Segment.cpp b/src/nupic/algorithms/Segment.cpp index 9d660d5a69..d907acde6f 100644 --- a/src/nupic/algorithms/Segment.cpp +++ b/src/nupic/algorithms/Segment.cpp @@ -20,20 +20,19 @@ * --------------------------------------------------------------------- */ - +#include // sort #include -#include #include #include #include -#include // sort +#include -#include -#include #include #include #include // is_in -#include // binary_save +#include // binary_save +#include +#include #include @@ -44,38 +43,29 @@ using namespace nupic::algorithms::Cells4; * Utility routine. 
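The per-element scaling at the end of computeAlpha in Scanning.hpp either hard-thresholds at 0.5 (sharpness of 1) or squashes the value through a sigmoid and rescales it into [0, 1]. A standalone sketch of that mapping; the function name is illustrative, and minVal/maxVal stand for the sigmoid extremes computed earlier in the routine, which the hunk above does not show.

#include <cmath>

// Mirror of the value scaling in computeAlpha(): hard threshold at 0.5, or
// sigmoid squashing rescaled to [0, 1] when a coefficient was derived from
// a "sharpness" below 1.
float scaleAlphaValue(float v, bool hardThreshold, float coefficient,
                      float minVal, float maxVal) {
  if (hardThreshold)
    return v >= 0.5f ? 1.0f : 0.0f;
  if (coefficient == 0.0f)
    return v; // no squashing requested; leave the averaged value as is
  const float s = 1.0f / (1.0f + std::exp(coefficient * (v - 0.5f)));
  return (s - minVal) / (maxVal - minVal);
}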
Given a src cell index, prints synapse as: * [column, cell within col] */ -void printSynapse(UInt srcCellIdx, UInt nCellsPerCol) -{ - UInt col = (UInt) (srcCellIdx / nCellsPerCol); - UInt cell = srcCellIdx - col*nCellsPerCol; +void printSynapse(UInt srcCellIdx, UInt nCellsPerCol) { + UInt col = (UInt)(srcCellIdx / nCellsPerCol); + UInt cell = srcCellIdx - col * nCellsPerCol; std::cout << "[" << col << "," << cell << "] "; } - //---------------------------------------------------------------------- -Segment::Segment(InSynapses _s, Real frequency, bool seqSegFlag, - Real permConnected, UInt iteration) - : _totalActivations(1), - _positiveActivations(1), - _lastActiveIteration(0), - _lastPosDutyCycle(1.0 / iteration), - _lastPosDutyCycleIteration(iteration), - _seqSegFlag(seqSegFlag), - _frequency(frequency), - _synapses(std::move(_s)), - _nConnected(0) -{ +Segment::Segment(InSynapses _s, Real frequency, bool seqSegFlag, + Real permConnected, UInt iteration) + : _totalActivations(1), _positiveActivations(1), _lastActiveIteration(0), + _lastPosDutyCycle(1.0 / iteration), _lastPosDutyCycleIteration(iteration), + _seqSegFlag(seqSegFlag), _frequency(frequency), _synapses(std::move(_s)), + _nConnected(0) { for (UInt i = 0; i != _synapses.size(); ++i) if (_synapses[i].permanence() >= permConnected) - ++ _nConnected; + ++_nConnected; std::sort(_synapses.begin(), _synapses.end(), InSynapseOrder()); NTA_ASSERT(invariants()); } //-------------------------------------------------------------------------------- -Segment& Segment::operator=(const Segment& o) -{ +Segment &Segment::operator=(const Segment &o) { if (&o != this) { _seqSegFlag = o._seqSegFlag; _frequency = o._frequency; @@ -91,31 +81,21 @@ Segment& Segment::operator=(const Segment& o) return *this; } - - //-------------------------------------------------------------------------------- -Segment::Segment(const Segment& o) - : _totalActivations(o._totalActivations), - _positiveActivations(o._positiveActivations), - _lastActiveIteration(o._lastActiveIteration), - _lastPosDutyCycle(o._lastPosDutyCycle), - _lastPosDutyCycleIteration(o._lastPosDutyCycleIteration), - _seqSegFlag(o._seqSegFlag), - _frequency(o._frequency), - _synapses(o._synapses), - _nConnected(o._nConnected) -{ +Segment::Segment(const Segment &o) + : _totalActivations(o._totalActivations), + _positiveActivations(o._positiveActivations), + _lastActiveIteration(o._lastActiveIteration), + _lastPosDutyCycle(o._lastPosDutyCycle), + _lastPosDutyCycleIteration(o._lastPosDutyCycleIteration), + _seqSegFlag(o._seqSegFlag), _frequency(o._frequency), + _synapses(o._synapses), _nConnected(o._nConnected) { NTA_ASSERT(invariants()); } - - -bool Segment::isActive(const CState& activities, - Real permConnected, UInt activationThreshold) const -{ - { - NTA_ASSERT(invariants()); - } +bool Segment::isActive(const CState &activities, Real permConnected, + UInt activationThreshold) const { + { NTA_ASSERT(invariants()); } UInt activity = 0; @@ -124,7 +104,8 @@ bool Segment::isActive(const CState& activities, // TODO: maintain nPermConnected incrementally?? for (UInt i = 0; i != size() && activity < activationThreshold; ++i) - if (_synapses[i].permanence() >= permConnected && activities.isSet(_synapses[i].srcCellIdx())) + if (_synapses[i].permanence() >= permConnected && + activities.isSet(_synapses[i].srcCellIdx())) activity++; return activity >= activationThreshold; @@ -137,17 +118,14 @@ bool Segment::isActive(const CState& activities, * providing good predictions. 
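Segment::isActive above counts synapses that are both connected (permanence >= permConnected) and attached to a currently active source cell, and declares the segment active once that count reaches the threshold. The same logic with the CState bitfield replaced by a plain std::vector<bool>, purely as an illustration; SynapseView stands in for InSynapse.

#include <vector>

struct SynapseView {
  unsigned srcCellIdx;
  float permanence;
};

// A segment is active if at least activationThreshold of its connected
// synapses point at cells that are currently on.
bool segmentIsActive(const std::vector<SynapseView> &synapses,
                     const std::vector<bool> &activeCells,
                     float permConnected, unsigned activationThreshold) {
  unsigned activity = 0;
  for (const auto &syn : synapses) {
    if (syn.permanence >= permConnected && activeCells[syn.srcCellIdx]) {
      if (++activity >= activationThreshold)
        return true;
    }
  }
  return activity >= activationThreshold;
}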
* */ -Real Segment::dutyCycle(UInt iteration, bool active, bool readOnly) -{ - { - NTA_ASSERT(iteration > 0); - } +Real Segment::dutyCycle(UInt iteration, bool active, bool readOnly) { + { NTA_ASSERT(iteration > 0); } Real dutyCycle = 0.0; // For tier 0, compute it from total number of positive activations seen if (iteration <= _dutyCycleTiers[1]) { - dutyCycle = ((Real) _positiveActivations) / iteration; + dutyCycle = ((Real)_positiveActivations) / iteration; if (!readOnly) { _lastPosDutyCycleIteration = iteration; _lastPosDutyCycle = dutyCycle; @@ -159,13 +137,12 @@ Real Segment::dutyCycle(UInt iteration, bool active, bool readOnly) UInt age = iteration - _lastPosDutyCycleIteration; // If it's already up to date we can return our cached value - if ( age == 0 && !active) + if (age == 0 && !active) return _lastPosDutyCycle; // Figure out which alpha we're using Real alpha = 0; - for (UInt tierIdx= _numTiers-1; tierIdx > 0; tierIdx--) - { + for (UInt tierIdx = _numTiers - 1; tierIdx > 0; tierIdx--) { if (iteration > _dutyCycleTiers[tierIdx]) { alpha = _dutyCycleAlphas[tierIdx]; break; @@ -173,7 +150,7 @@ Real Segment::dutyCycle(UInt iteration, bool active, bool readOnly) } // Update duty cycle - dutyCycle = pow((Real64) (1.0 - alpha), (Real64)age) * _lastPosDutyCycle; + dutyCycle = pow((Real64)(1.0 - alpha), (Real64)age) * _lastPosDutyCycle; if (active) dutyCycle += alpha; @@ -186,19 +163,18 @@ Real Segment::dutyCycle(UInt iteration, bool active, bool readOnly) return dutyCycle; } -UInt Segment::computeActivity(const CState& activities, Real permConnected, +UInt Segment::computeActivity(const CState &activities, Real permConnected, bool connectedSynapsesOnly) const { - { - NTA_ASSERT(invariants()); - } + { NTA_ASSERT(invariants()); } UInt activity = 0; if (connectedSynapsesOnly) { for (UInt i = 0; i != size(); ++i) - if (activities.isSet(_synapses[i].srcCellIdx()) && (_synapses[i].permanence() >= permConnected)) + if (activities.isSet(_synapses[i].srcCellIdx()) && + (_synapses[i].permanence() >= permConnected)) activity++; } else { for (UInt i = 0; i != size(); ++i) @@ -209,36 +185,33 @@ UInt Segment::computeActivity(const CState& activities, Real permConnected, return activity; } -void -Segment::addSynapses(const std::set& srcCells, Real initStrength, - Real permConnected) -{ +void Segment::addSynapses(const std::set &srcCells, Real initStrength, + Real permConnected) { auto srcCellIdx = srcCells.begin(); for (; srcCellIdx != srcCells.end(); ++srcCellIdx) { _synapses.push_back(InSynapse(*srcCellIdx, initStrength)); if (initStrength >= permConnected) - ++ _nConnected; + ++_nConnected; } sort(_synapses, InSynapseOrder()); NTA_ASSERT(invariants()); // will catch non-unique synapses } -void Segment::decaySynapses(Real decay, std::vector& removed, - Real permConnected, bool doDecay) -{ +void Segment::decaySynapses(Real decay, std::vector &removed, + Real permConnected, bool doDecay) { NTA_ASSERT(invariants()); if (_synapses.empty()) return; static std::vector del; - del.clear(); // purge residual data + del.clear(); // purge residual data for (UInt i = 0; i != _synapses.size(); ++i) { - int wasConnected = (int) (_synapses[i].permanence() >= permConnected); + int wasConnected = (int)(_synapses[i].permanence() >= permConnected); if (_synapses[i].permanence() < decay) { @@ -249,7 +222,7 @@ void Segment::decaySynapses(Real decay, std::vector& removed, _synapses[i].permanence() -= decay; } - int isConnected = (int) (_synapses[i].permanence() >= permConnected); + int isConnected = 
(int)(_synapses[i].permanence() >= permConnected); _nConnected += isConnected - wasConnected; } @@ -259,7 +232,6 @@ void Segment::decaySynapses(Real decay, std::vector& removed, NTA_ASSERT(invariants()); } - //-------------------------------------------------------------------------------- /** * Subtract decay from each synapses' permanence value. @@ -267,16 +239,15 @@ void Segment::decaySynapses(Real decay, std::vector& removed, * are inserted into the "removed" list. * */ -void Segment::decaySynapses2(Real decay, std::vector& removed, - Real permConnected) -{ +void Segment::decaySynapses2(Real decay, std::vector &removed, + Real permConnected) { NTA_ASSERT(invariants()); if (_synapses.empty()) return; static std::vector del; - del.clear(); // purge residual data + del.clear(); // purge residual data for (UInt i = 0; i != _synapses.size(); ++i) { @@ -296,11 +267,10 @@ void Segment::decaySynapses2(Real decay, std::vector& removed, _synapses[i].permanence() -= decay; // If it was connected and is now below permanence, reduce connected count - if ( (_synapses[i].permanence() + decay >= permConnected) - && (_synapses[i].permanence() < permConnected) ) + if ((_synapses[i].permanence() + decay >= permConnected) && + (_synapses[i].permanence() < permConnected)) _nConnected--; } - } _removeSynapses(del); @@ -314,10 +284,8 @@ void Segment::decaySynapses2(Real decay, std::vector& removed, * permanence. * */ -struct InPermanenceOrder -{ - inline bool operator()(const InSynapse& a, const InSynapse& b) const - { +struct InPermanenceOrder { + inline bool operator()(const InSynapse &a, const InSynapse &b) const { return a.permanence() < b.permanence(); } }; @@ -328,15 +296,10 @@ struct InPermanenceOrder * increasing source cell index. * */ -struct InSrcCellOrder -{ - inline bool operator()(const UInt a, const UInt b) const - { - return a < b; - } +struct InSrcCellOrder { + inline bool operator()(const UInt a, const UInt b) const { return a < b; } }; - //---------------------------------------------------------------------- /** * Free up some synapses in this segment. 
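Segment::dutyCycle earlier in this file ages its cached positive-activation duty cycle with a tier-dependent alpha: dutyCycle = (1 - alpha)^age * lastValue, plus alpha if the segment fired on this iteration. A compact sketch of that update with the tier tables abbreviated; the real code uses the nine-entry _dutyCycleTiers/_dutyCycleAlphas tables in Segment.hpp and a separate tier-0 branch (positive activations divided by iteration) that is omitted here.

#include <cmath>
#include <cstddef>

// Abbreviated tier tables; see _dutyCycleTiers / _dutyCycleAlphas for the
// full nine-tier versions.
static const unsigned kTiers[] = {0, 100, 320, 1000};
static const double kAlphas[] = {0.0, 0.0032, 0.0010, 0.00032};

double agedDutyCycle(double lastDutyCycle, unsigned lastUpdateIteration,
                     unsigned iteration, bool active) {
  const unsigned age = iteration - lastUpdateIteration;
  if (age == 0 && !active)
    return lastDutyCycle; // cached value is still current

  double alpha = 0.0;
  for (std::size_t t = sizeof(kTiers) / sizeof(kTiers[0]); t-- > 1;) {
    if (iteration > kTiers[t]) {
      alpha = kAlphas[t];
      break;
    }
  }

  double dutyCycle =
      std::pow(1.0 - alpha, static_cast<double>(age)) * lastDutyCycle;
  if (active)
    dutyCycle += alpha;
  return dutyCycle;
}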
We always free up inactive @@ -348,9 +311,8 @@ void Segment::freeNSynapses(UInt numToFree, std::vector &inactiveSegmentIndices, std::vector &activeSynapseIndices, std::vector &activeSegmentIndices, - std::vector& removed, UInt verbosity, - UInt nCellsPerCol, Real permMax) -{ + std::vector &removed, UInt verbosity, + UInt nCellsPerCol, Real permMax) { NTA_CHECK(inactiveSegmentIndices.size() == inactiveSynapseIndices.size()); NTA_CHECK(activeSegmentIndices.size() == activeSynapseIndices.size()); NTA_ASSERT(numToFree <= _synapses.size()); @@ -359,9 +321,8 @@ void Segment::freeNSynapses(UInt numToFree, if (verbosity >= 4) { std::cout << "\nIn CPP freeNSynapses with numToFree = " << numToFree - << ", inactiveSynapses = "; - for (auto & inactiveSynapseIndice : inactiveSynapseIndices) - { + << ", inactiveSynapses = "; + for (auto &inactiveSynapseIndice : inactiveSynapseIndices) { printSynapse(inactiveSynapseIndice, nCellsPerCol); } std::cout << "\n"; @@ -372,22 +333,22 @@ void Segment::freeNSynapses(UInt numToFree, // We first choose from inactive synapses, in order of increasing permanence InSynapses candidates; - for (UInt i = 0; i < inactiveSegmentIndices.size(); i++) - { + for (UInt i = 0; i < inactiveSegmentIndices.size(); i++) { // Put in *segment indices*, not source cell indices - candidates.push_back(InSynapse(inactiveSegmentIndices[i], - _synapses[inactiveSegmentIndices[i]].permanence())); + candidates.push_back( + InSynapse(inactiveSegmentIndices[i], + _synapses[inactiveSegmentIndices[i]].permanence())); } // If we need more, choose from active synapses in order of increasing // permanence values. This set has lower priority than inactive synapses // so we add a constant permanence value for sorting purposes if (candidates.size() < numToFree) { - for (UInt i = 0; i < activeSegmentIndices.size(); i++) - { + for (UInt i = 0; i < activeSegmentIndices.size(); i++) { // Put in *segment indices*, not source cell indices - candidates.push_back(InSynapse(activeSegmentIndices[i], - _synapses[activeSegmentIndices[i]].permanence() + permMax)); + candidates.push_back( + InSynapse(activeSegmentIndices[i], + _synapses[activeSegmentIndices[i]].permanence() + permMax)); } } @@ -397,9 +358,8 @@ void Segment::freeNSynapses(UInt numToFree, //---------------------------------------------------------------------- // Create the final list of synapses we will remove static std::vector del; - del.clear(); // purge residual data - for (UInt i = 0; i < numToFree; i++) - { + del.clear(); // purge residual data + for (UInt i = 0; i < numToFree; i++) { del.push_back(candidates[i].srcCellIdx()); UInt cellIdx = _synapses[candidates[i].srcCellIdx()].srcCellIdx(); removed.push_back(cellIdx); @@ -408,8 +368,7 @@ void Segment::freeNSynapses(UInt numToFree, // Debug statements if (verbosity >= 4) { std::cout << "Removing these synapses: "; - for (auto & elem : removed) - { + for (auto &elem : removed) { printSynapse(elem, nCellsPerCol); } std::cout << "\n"; @@ -427,61 +386,50 @@ void Segment::freeNSynapses(UInt numToFree, } // Debug statements - if (verbosity >= 4) - { + if (verbosity >= 4) { std::cout << "Segment AFTER remove synapses: "; print(std::cout, nCellsPerCol); std::cout << "\n"; } } - - - -void Segment::print(std::ostream& outStream, UInt nCellsPerCol) const -{ - outStream << (_seqSegFlag ? "True " : "False ") - << "dc" << std::setprecision(4) << _lastPosDutyCycle << " (" +void Segment::print(std::ostream &outStream, UInt nCellsPerCol) const { + outStream << (_seqSegFlag ? 
"True " : "False ") << "dc" + << std::setprecision(4) << _lastPosDutyCycle << " (" << _positiveActivations << "/" << _totalActivations << ") "; for (UInt i = 0; i != _synapses.size(); ++i) { if (nCellsPerCol > 0) { UInt cellIdx = _synapses[i].srcCellIdx(); - UInt col = (UInt) (cellIdx / nCellsPerCol); - UInt cell = cellIdx - col*nCellsPerCol; - outStream << "[" << col << "," << cell << "]" - << std::setprecision(4) << _synapses[i].permanence() - << " "; + UInt col = (UInt)(cellIdx / nCellsPerCol); + UInt cell = cellIdx - col * nCellsPerCol; + outStream << "[" << col << "," << cell << "]" << std::setprecision(4) + << _synapses[i].permanence() << " "; } else { outStream << _synapses[i]; } - if (i < _synapses.size() -1) + if (i < _synapses.size() - 1) std::cout << " "; } } -namespace nupic{ - namespace algorithms { - namespace Cells4 { - - std::ostream& operator<<(std::ostream& outStream, const Segment& seg) - { - seg.print(outStream); - return outStream; - } - - std::ostream& operator<<(std::ostream& outStream, const CState& cstate) - { - cstate.print(outStream); - return outStream; - } - - std::ostream& operator<<( - std::ostream& outStream, const CStateIndexed& cstate) - { - cstate.print(outStream); - return outStream; - } - } - } +namespace nupic { +namespace algorithms { +namespace Cells4 { + +std::ostream &operator<<(std::ostream &outStream, const Segment &seg) { + seg.print(outStream); + return outStream; } +std::ostream &operator<<(std::ostream &outStream, const CState &cstate) { + cstate.print(outStream); + return outStream; +} + +std::ostream &operator<<(std::ostream &outStream, const CStateIndexed &cstate) { + cstate.print(outStream); + return outStream; +} +} // namespace Cells4 +} // namespace algorithms +} // namespace nupic diff --git a/src/nupic/algorithms/Segment.hpp b/src/nupic/algorithms/Segment.hpp index 441a03172d..8f76410b55 100644 --- a/src/nupic/algorithms/Segment.hpp +++ b/src/nupic/algorithms/Segment.hpp @@ -31,12 +31,11 @@ #include #include +#include #include // is_sorted -#include // binary_save +#include // binary_save #include #include -#include - //----------------------------------------------------------------------- /** @@ -83,880 +82,788 @@ //----------------------------------------------------------------------- namespace nupic { - namespace algorithms { - namespace Cells4 { - - //----------------------------------------------------------------------- - /** - * Encapsulate the arrays used to maintain per-cell state. 
- */ - class CState : Serializable - { - public: - static const UInt VERSION = 1; - - CState() - { - _nCells = 0; - _pData = nullptr; - _fMemoryAllocatedByPython = false; - _version = VERSION; - } - ~CState() - { - if (_fMemoryAllocatedByPython == false && _pData != nullptr) - delete [] _pData; - } - CState& operator=(const CState& o) - { - NTA_ASSERT(_nCells == o._nCells); // _nCells should be static, since it is the same size for all CStates - memcpy(_pData, o._pData, _nCells); - return *this; - } - bool initialize(const UInt nCells) - { - if (_nCells != 0) // if already initialized - return false; // don't do it again - if (nCells == 0) // if a bogus value - return false; // bail out - _nCells = nCells; - _pData = new Byte[_nCells]; - memset(_pData, 0, _nCells); - return true; - } - void usePythonMemory(Byte* pData, const UInt nCells) - { - // delete a prior allocation - if (_fMemoryAllocatedByPython == false && _pData != nullptr) - delete [] _pData; - - // use the supplied memory and remember its size - _nCells = nCells; - _pData = pData; - _fMemoryAllocatedByPython = true; - } - bool isSet(const UInt cellIdx) const - { - return _pData[cellIdx] != 0; - } - void set(const UInt cellIdx) - { - _pData[cellIdx] = 1; - } - void resetAll() - { - memset(_pData, 0, _nCells); - } - Byte* arrayPtr() const - { - // We expose the data array to Python. For objects in derived - // class CStateIndexed, a Python script can wreak havoc by - // modifying the array, since the _cellsOn index will become - // inconsistent. - return _pData ; - } - void print(std::ostream& outStream) const - { - outStream << version() << " " - << _fMemoryAllocatedByPython << " " - << _nCells << std::endl; - for (UInt i = 0; i < _nCells; ++i) - { - outStream << _pData[i] << " "; - } - outStream << std::endl - << "end" << std::endl; - } - using Serializable::write; - virtual void write(CStateProto::Builder& proto) const override - { - proto.setVersion(VERSION); - proto.setFMemoryAllocatedByPython(_fMemoryAllocatedByPython); - auto pDataProto = proto.initPData(_nCells); - for (UInt i = 0; i < _nCells; ++i) - { - pDataProto[i] = _pData[i]; - } - } - using Serializable::read; - virtual void read(CStateProto::Reader& proto) override - { - NTA_CHECK(proto.getVersion() == 1); - _fMemoryAllocatedByPython = proto.getFMemoryAllocatedByPython(); - auto pDataProto = proto.getPData(); - _nCells = pDataProto.size(); - for (UInt i = 0; i < _nCells; ++i) - { - _pData[i] = pDataProto[i]; - } - } - void load(std::istream& inStream) - { - UInt version; - inStream >> version; - NTA_CHECK(version == 1); - inStream >> _fMemoryAllocatedByPython - >> _nCells; - for (UInt i = 0; i < _nCells; ++i) - { - inStream >> _pData[i]; - } - std::string token; - inStream >> token; - NTA_CHECK(token == "end"); - } - UInt version() const - { - return _version; - } - protected: - UInt _version; - UInt _nCells; // should be static, since same size for all CStates - Byte* _pData; // protected in C++, but exposed to the Python code - bool _fMemoryAllocatedByPython; - }; - /** - * Add an index to CState so that we can find all On cells without - * a sequential search of the entire array. 
- */ - class CStateIndexed : public CState - { - public: - static const UInt VERSION = 1; - - CStateIndexed() : CState() - { - _version = VERSION; - _countOn = 0; - _isSorted = true; - } - CStateIndexed& operator=(CStateIndexed& o) - { - NTA_ASSERT(_nCells == o._nCells); // _nCells should be static, since it is the same size for all CStates - // Is it faster to reset only the old nonzero indices and set only the new ones? - std::vector::iterator iterOn; - // reset the old On cells - for (iterOn = _cellsOn.begin(); iterOn != _cellsOn.end(); ++iterOn) - _pData[*iterOn] = 0; - // set the new On cells - for (iterOn = o._cellsOn.begin(); iterOn != o._cellsOn.end(); ++iterOn) - _pData[*iterOn] = 1; - // use the new On tracker - _cellsOn = o._cellsOn; - _countOn = o._countOn; - _isSorted = o._isSorted; - return *this; - } - std::vector cellsOn(bool fSorted = false) - { - // It's better for the caller to ask us to sort, rather than - // to sort himself, since we can optimize out the sort when we - // know the vector is already sorted. - if (fSorted && !_isSorted) { - std::sort(_cellsOn.begin(), _cellsOn.end()); - _isSorted = true; - } - return _cellsOn; // returns a copy that can be modified - } - void set(const UInt cellIdx) - { - if (!isSet(cellIdx)) { - CState::set(cellIdx); // call the base class function - if (_isSorted && _countOn > 0 && cellIdx < _cellsOn.back()) - _isSorted = false; - _cellsOn.push_back(cellIdx); // add to the list of On cells - _countOn++; // count the On cell; more efficient than .size()? - } - } - void resetAll() - { - // Is it faster just to zero the _cellsOn indices? - std::vector::iterator iterOn; - // reset the old On cells - for (iterOn = _cellsOn.begin(); iterOn != _cellsOn.end(); ++iterOn) - _pData[*iterOn] = 0; - _cellsOn.clear(); - _countOn = 0; - _isSorted = true; - } - void print(std::ostream& outStream) const - { - outStream << version() << " " - << _fMemoryAllocatedByPython << " " - << _nCells << std::endl; - for (UInt i = 0; i < _nCells; ++i) - { - outStream << _pData[i] << " "; - } - outStream << _countOn << " "; - outStream << _cellsOn.size() << " "; - for (auto & elem : _cellsOn) - { - outStream << elem << " "; - } - outStream << "end" << std::endl; - } - void write(CStateProto::Builder& proto) const override - { - CState::write(proto); - proto.setCountOn(_countOn); - auto cellsOnProto = proto.initCellsOn(_cellsOn.size()); - for (UInt i = 0; i < _cellsOn.size(); ++i) - { - cellsOnProto.set(i, _cellsOn[i]); - } - } - void read(CStateProto::Reader& proto) override - { - CState::read(proto); - _countOn = proto.getCountOn(); - auto cellsOnProto = proto.getCellsOn(); - _cellsOn.resize(cellsOnProto.size()); - for (UInt i = 0; i < cellsOnProto.size(); ++i) - { - _cellsOn[i] = cellsOnProto[i]; - } - } - void load(std::istream& inStream) - { - UInt version; - inStream >> version; - NTA_CHECK(version == 1); - inStream >> _fMemoryAllocatedByPython - >> _nCells; - for (UInt i = 0; i < _nCells; ++i) - { - inStream >> _pData[i]; - } - inStream >> _countOn; - UInt nCellsOn; - inStream >> nCellsOn; - UInt v; - for (UInt i = 0; i < nCellsOn; ++i) - { - inStream >> v; - _cellsOn.push_back(v); - } - std::string token; - inStream >> token; - NTA_CHECK(token == "end"); - } - UInt version() const - { - return _version; - } - private: - UInt _version; - std::vector _cellsOn; - UInt _countOn; // how many cells are On - bool _isSorted; // avoid unnecessary sorting - }; - - // These are iteration count tiers used when computing segment duty cycle - const UInt _numTiers = 9; - 
const UInt _dutyCycleTiers[] = {0, 100, 320, 1000, - 3200, 10000, 32000, 100000, - 320000}; - - // This is the alpha used in each tier. dutyCycleAlphas[n] is used when - /// iterationIdx > dutyCycleTiers[n] - const Real _dutyCycleAlphas[] = {0.0, 0.0032, 0.0010, 0.00032, - 0.00010, 0.000032, 0.000010, 0.0000032, - 0.0000010}; - - //----------------------------------------------------------------------- - // Forward declarations - class Segment; - - - //----------------------------------------------------------------------- - struct InSynapseOrder - { - inline bool operator()(const InSynapse& a, const InSynapse& b) const - { - return a.srcCellIdx() < b.srcCellIdx(); - } - }; - - - //----------------------------------------------------------------------- - class Segment : Serializable - { - public: - typedef std::vector< InSynapse > InSynapses; - - // Variables representing various metrics of segment activity - UInt _totalActivations; // Total number of times segment was active - UInt _positiveActivations; // Total number of times segment was - // positively reinforced - UInt _lastActiveIteration; // The last iteration on which the segment - // became active (used in learning only) - - Real _lastPosDutyCycle; - UInt _lastPosDutyCycleIteration; - - private: - bool _seqSegFlag; // sequence segment flag - Real _frequency; // frequency [UNUSED IN LATEST IMPLEMENTATION] - InSynapses _synapses; // incoming connections to this segment - UInt _nConnected; // number of current connected synapses - - - public: - //---------------------------------------------------------------------- - inline Segment() - : _totalActivations(1), - _positiveActivations(1), - _lastActiveIteration(0), - _lastPosDutyCycle(0.0), - _lastPosDutyCycleIteration(0), - _seqSegFlag(false), - _frequency(0), - _synapses(), - _nConnected(0) - {} - - //---------------------------------------------------------------------- - Segment(InSynapses _s, Real frequency, bool seqSegFlag, - Real permConnected, UInt iteration); - - //----------------------------------------------------------------------- - Segment(const Segment& o); - - //----------------------------------------------------------------------- - Segment& operator=(const Segment& o); - - //----------------------------------------------------------------------- - /** - Checks that the synapses are unique and sorted in order of increasing - src cell index. This is required by subsequent algorithms. Order - matters for _removeSynapses and updateSynapses, but it prevents from - partitioning the synapses in above/below permConnected, which test is - the bottleneck in activity() (which is the overall bottleneck). - - TODO: Maybe we can remove the sorted restriction? Check if - _removeSynapses and updateSynapses are major bottlenecks. - - */ - inline bool invariants() const - { - static std::vector indices; - static UInt highWaterSize = 0; - if (highWaterSize < _synapses.size()) { - highWaterSize = _synapses.size(); - indices.reserve(highWaterSize); - } - indices.clear(); // purge residual data - - for (UInt i = 0; i != _synapses.size(); ++i) - indices.push_back(_synapses[i].srcCellIdx()); +namespace algorithms { +namespace Cells4 { +//----------------------------------------------------------------------- +/** + * Encapsulate the arrays used to maintain per-cell state. 
+ */ +class CState : Serializable { +public: + static const UInt VERSION = 1; + + CState() { + _nCells = 0; + _pData = nullptr; + _fMemoryAllocatedByPython = false; + _version = VERSION; + } + ~CState() { + if (_fMemoryAllocatedByPython == false && _pData != nullptr) + delete[] _pData; + } + CState &operator=(const CState &o) { + NTA_ASSERT(_nCells == o._nCells); // _nCells should be static, since it is + // the same size for all CStates + memcpy(_pData, o._pData, _nCells); + return *this; + } + bool initialize(const UInt nCells) { + if (_nCells != 0) // if already initialized + return false; // don't do it again + if (nCells == 0) // if a bogus value + return false; // bail out + _nCells = nCells; + _pData = new Byte[_nCells]; + memset(_pData, 0, _nCells); + return true; + } + void usePythonMemory(Byte *pData, const UInt nCells) { + // delete a prior allocation + if (_fMemoryAllocatedByPython == false && _pData != nullptr) + delete[] _pData; + + // use the supplied memory and remember its size + _nCells = nCells; + _pData = pData; + _fMemoryAllocatedByPython = true; + } + bool isSet(const UInt cellIdx) const { return _pData[cellIdx] != 0; } + void set(const UInt cellIdx) { _pData[cellIdx] = 1; } + void resetAll() { memset(_pData, 0, _nCells); } + Byte *arrayPtr() const { + // We expose the data array to Python. For objects in derived + // class CStateIndexed, a Python script can wreak havoc by + // modifying the array, since the _cellsOn index will become + // inconsistent. + return _pData; + } + void print(std::ostream &outStream) const { + outStream << version() << " " << _fMemoryAllocatedByPython << " " << _nCells + << std::endl; + for (UInt i = 0; i < _nCells; ++i) { + outStream << _pData[i] << " "; + } + outStream << std::endl << "end" << std::endl; + } + using Serializable::write; + virtual void write(CStateProto::Builder &proto) const override { + proto.setVersion(VERSION); + proto.setFMemoryAllocatedByPython(_fMemoryAllocatedByPython); + auto pDataProto = proto.initPData(_nCells); + for (UInt i = 0; i < _nCells; ++i) { + pDataProto[i] = _pData[i]; + } + } + using Serializable::read; + virtual void read(CStateProto::Reader &proto) override { + NTA_CHECK(proto.getVersion() == 1); + _fMemoryAllocatedByPython = proto.getFMemoryAllocatedByPython(); + auto pDataProto = proto.getPData(); + _nCells = pDataProto.size(); + for (UInt i = 0; i < _nCells; ++i) { + _pData[i] = pDataProto[i]; + } + } + void load(std::istream &inStream) { + UInt version; + inStream >> version; + NTA_CHECK(version == 1); + inStream >> _fMemoryAllocatedByPython >> _nCells; + for (UInt i = 0; i < _nCells; ++i) { + inStream >> _pData[i]; + } + std::string token; + inStream >> token; + NTA_CHECK(token == "end"); + } + UInt version() const { return _version; } + +protected: + UInt _version; + UInt _nCells; // should be static, since same size for all CStates + Byte *_pData; // protected in C++, but exposed to the Python code + bool _fMemoryAllocatedByPython; +}; +/** + * Add an index to CState so that we can find all On cells without + * a sequential search of the entire array. + */ +class CStateIndexed : public CState { +public: + static const UInt VERSION = 1; + + CStateIndexed() : CState() { + _version = VERSION; + _countOn = 0; + _isSorted = true; + } + CStateIndexed &operator=(CStateIndexed &o) { + NTA_ASSERT(_nCells == o._nCells); // _nCells should be static, since it is + // the same size for all CStates + // Is it faster to reset only the old nonzero indices and set only the new + // ones? 
+ std::vector::iterator iterOn; + // reset the old On cells + for (iterOn = _cellsOn.begin(); iterOn != _cellsOn.end(); ++iterOn) + _pData[*iterOn] = 0; + // set the new On cells + for (iterOn = o._cellsOn.begin(); iterOn != o._cellsOn.end(); ++iterOn) + _pData[*iterOn] = 1; + // use the new On tracker + _cellsOn = o._cellsOn; + _countOn = o._countOn; + _isSorted = o._isSorted; + return *this; + } + std::vector cellsOn(bool fSorted = false) { + // It's better for the caller to ask us to sort, rather than + // to sort himself, since we can optimize out the sort when we + // know the vector is already sorted. + if (fSorted && !_isSorted) { + std::sort(_cellsOn.begin(), _cellsOn.end()); + _isSorted = true; + } + return _cellsOn; // returns a copy that can be modified + } + void set(const UInt cellIdx) { + if (!isSet(cellIdx)) { + CState::set(cellIdx); // call the base class function + if (_isSorted && _countOn > 0 && cellIdx < _cellsOn.back()) + _isSorted = false; + _cellsOn.push_back(cellIdx); // add to the list of On cells + _countOn++; // count the On cell; more efficient than .size()? + } + } + void resetAll() { + // Is it faster just to zero the _cellsOn indices? + std::vector::iterator iterOn; + // reset the old On cells + for (iterOn = _cellsOn.begin(); iterOn != _cellsOn.end(); ++iterOn) + _pData[*iterOn] = 0; + _cellsOn.clear(); + _countOn = 0; + _isSorted = true; + } + void print(std::ostream &outStream) const { + outStream << version() << " " << _fMemoryAllocatedByPython << " " << _nCells + << std::endl; + for (UInt i = 0; i < _nCells; ++i) { + outStream << _pData[i] << " "; + } + outStream << _countOn << " "; + outStream << _cellsOn.size() << " "; + for (auto &elem : _cellsOn) { + outStream << elem << " "; + } + outStream << "end" << std::endl; + } + void write(CStateProto::Builder &proto) const override { + CState::write(proto); + proto.setCountOn(_countOn); + auto cellsOnProto = proto.initCellsOn(_cellsOn.size()); + for (UInt i = 0; i < _cellsOn.size(); ++i) { + cellsOnProto.set(i, _cellsOn[i]); + } + } + void read(CStateProto::Reader &proto) override { + CState::read(proto); + _countOn = proto.getCountOn(); + auto cellsOnProto = proto.getCellsOn(); + _cellsOn.resize(cellsOnProto.size()); + for (UInt i = 0; i < cellsOnProto.size(); ++i) { + _cellsOn[i] = cellsOnProto[i]; + } + } + void load(std::istream &inStream) { + UInt version; + inStream >> version; + NTA_CHECK(version == 1); + inStream >> _fMemoryAllocatedByPython >> _nCells; + for (UInt i = 0; i < _nCells; ++i) { + inStream >> _pData[i]; + } + inStream >> _countOn; + UInt nCellsOn; + inStream >> nCellsOn; + UInt v; + for (UInt i = 0; i < nCellsOn; ++i) { + inStream >> v; + _cellsOn.push_back(v); + } + std::string token; + inStream >> token; + NTA_CHECK(token == "end"); + } + UInt version() const { return _version; } + +private: + UInt _version; + std::vector _cellsOn; + UInt _countOn; // how many cells are On + bool _isSorted; // avoid unnecessary sorting +}; + +// These are iteration count tiers used when computing segment duty cycle +const UInt _numTiers = 9; +const UInt _dutyCycleTiers[] = {0, 100, 320, 1000, 3200, + 10000, 32000, 100000, 320000}; + +// This is the alpha used in each tier. 
dutyCycleAlphas[n] is used when +/// iterationIdx > dutyCycleTiers[n] +const Real _dutyCycleAlphas[] = {0.0, 0.0032, 0.0010, + 0.00032, 0.00010, 0.000032, + 0.000010, 0.0000032, 0.0000010}; -#ifndef NDEBUG - if (indices.size() != _synapses.size()) - std::cout << "Indices are not unique" << std::endl; - - if (!is_sorted(indices, true, true)) - std::cout << "Indices are not sorted" << std::endl; +//----------------------------------------------------------------------- +// Forward declarations +class Segment; - if (_frequency < 0) - std::cout << "Frequency is less than zero" << std::endl; -#endif +//----------------------------------------------------------------------- +struct InSynapseOrder { + inline bool operator()(const InSynapse &a, const InSynapse &b) const { + return a.srcCellIdx() < b.srcCellIdx(); + } +}; - return _frequency >= 0 && is_sorted(indices, true, true); - } +//----------------------------------------------------------------------- +class Segment : Serializable { +public: + typedef std::vector InSynapses; + + // Variables representing various metrics of segment activity + UInt _totalActivations; // Total number of times segment was active + UInt _positiveActivations; // Total number of times segment was + // positively reinforced + UInt _lastActiveIteration; // The last iteration on which the segment + // became active (used in learning only) + + Real _lastPosDutyCycle; + UInt _lastPosDutyCycleIteration; + +private: + bool _seqSegFlag; // sequence segment flag + Real _frequency; // frequency [UNUSED IN LATEST IMPLEMENTATION] + InSynapses _synapses; // incoming connections to this segment + UInt _nConnected; // number of current connected synapses + +public: + //---------------------------------------------------------------------- + inline Segment() + : _totalActivations(1), _positiveActivations(1), _lastActiveIteration(0), + _lastPosDutyCycle(0.0), _lastPosDutyCycleIteration(0), + _seqSegFlag(false), _frequency(0), _synapses(), _nConnected(0) {} + + //---------------------------------------------------------------------- + Segment(InSynapses _s, Real frequency, bool seqSegFlag, Real permConnected, + UInt iteration); - //----------------------------------------------------------------------- - /** - * Check that _nConnected is equal to actual number of connected synapses - * - */ - inline bool checkConnected(Real permConnected) const { - // - UInt nc = 0; - for (UInt i = 0; i != _synapses.size(); ++i) - nc += (_synapses[i].permanence() >= permConnected); - - if (nc != _nConnected) { - std::cout << "\nConnected stats inconsistent. 
_nConnected=" - << _nConnected << ", computed nc=" << nc << std::endl; - } - - return nc == _nConnected; - } + //----------------------------------------------------------------------- + Segment(const Segment &o); - //---------------------------------------------------------------------- - /** - * Various accessors - */ - inline bool empty() const { return _synapses.empty(); } - inline UInt size() const { return _synapses.size(); } - inline bool isSequenceSegment() const { return _seqSegFlag; } - inline Real& frequency() { return _frequency; } - inline Real getFrequency() const { return _frequency; } - inline UInt nConnected() const { return _nConnected; } - inline UInt getTotalActivations() const { return _totalActivations;} - inline UInt getPositiveActivations() const { return _positiveActivations;} - inline UInt getLastActiveIteration() const { return _lastActiveIteration;} - inline Real getLastPosDutyCycle() const { return _lastPosDutyCycle;} - inline UInt getLastPosDutyCycleIteration() const - { return _lastPosDutyCycleIteration;} - - //----------------------------------------------------------------------- - /** - * Checks whether the given src cellIdx is already contained in this segment - * or not. - * TODO: optimize with at least a binary search - */ - inline bool has(UInt srcCellIdx) const - { - NTA_ASSERT(srcCellIdx != (UInt) -1); - - UInt lo = 0; - UInt hi = _synapses.size(); - while (lo < hi) { - const UInt test = (lo + hi)/2; - if (_synapses[test].srcCellIdx() < srcCellIdx) - lo = test + 1; - else if (_synapses[test].srcCellIdx() > srcCellIdx) - hi = test; - else - return true; - } - return false; - } + //----------------------------------------------------------------------- + Segment &operator=(const Segment &o); - //----------------------------------------------------------------------- - /** - * Returns the permanence of the idx-th synapse on this Segment. That idx is *not* - * a cell index, but just the index of the synapse on that segment, i.e. that - * index will change if synapses are deleted from this segment in synpase - * adaptation or global decay. - */ - inline void setPermanence(UInt idx, Real val) - { - NTA_ASSERT(idx < _synapses.size()); - - _synapses[idx].permanence() = val; - } + //----------------------------------------------------------------------- + /** + Checks that the synapses are unique and sorted in order of increasing + src cell index. This is required by subsequent algorithms. Order + matters for _removeSynapses and updateSynapses, but it prevents from + partitioning the synapses in above/below permConnected, which test is + the bottleneck in activity() (which is the overall bottleneck). + + TODO: Maybe we can remove the sorted restriction? Check if + _removeSynapses and updateSynapses are major bottlenecks. 
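The property this comment spells out, source cell indices that are unique and sorted in increasing order, can be stated with standard algorithms alone. A minimal standalone sketch of that check (plain STL types; illustrative, not the header's own code):

#include <algorithm>
#include <cassert>
#include <vector>

// True if the indices are strictly increasing, i.e. sorted with no
// duplicates -- the invariant the Segment keeps for its synapses'
// source cell indices.
static bool sortedAndUnique(const std::vector<unsigned> &indices) {
  return std::is_sorted(indices.begin(), indices.end()) &&
         std::adjacent_find(indices.begin(), indices.end()) == indices.end();
}

int main() {
  assert(sortedAndUnique({2, 5, 9, 40}));   // ok: strictly increasing
  assert(!sortedAndUnique({2, 5, 5, 40}));  // duplicate source cell
  assert(!sortedAndUnique({5, 2, 9}));      // not sorted
  return 0;
}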
+ + */ + inline bool invariants() const { + static std::vector indices; + static UInt highWaterSize = 0; + if (highWaterSize < _synapses.size()) { + highWaterSize = _synapses.size(); + indices.reserve(highWaterSize); + } + indices.clear(); // purge residual data + + for (UInt i = 0; i != _synapses.size(); ++i) + indices.push_back(_synapses[i].srcCellIdx()); - //----------------------------------------------------------------------- - /** - * Returns the permanence of the idx-th synapse on this Segment as a value - */ - inline Real getPermanence(UInt idx) const - { - NTA_ASSERT(idx < _synapses.size()); - NTA_ASSERT(0 <= _synapses[idx].permanence()); +#ifndef NDEBUG + if (indices.size() != _synapses.size()) + std::cout << "Indices are not unique" << std::endl; - return _synapses[idx].permanence(); - } + if (!is_sorted(indices, true, true)) + std::cout << "Indices are not sorted" << std::endl; - //----------------------------------------------------------------------- - /** - * Returns the source cell index of the synapse at index idx. - */ - inline UInt getSrcCellIdx(UInt idx) const - { - NTA_ASSERT(idx < _synapses.size()); - return _synapses[idx].srcCellIdx(); - } + if (_frequency < 0) + std::cout << "Frequency is less than zero" << std::endl; +#endif - //----------------------------------------------------------------------- - /** - * Returns the indices of all source cells in this segment. - * - * Parameter / return value: - * srcCells: an empty vector. The indices will be returned in - * this vector. - */ - inline void getSrcCellIndices(std::vector& srcCells) const - { - NTA_ASSERT(srcCells.size() == 0); - for (auto & elem : _synapses) { - srcCells.push_back(elem.srcCellIdx()); - } - } + return _frequency >= 0 && is_sorted(indices, true, true); + } - //----------------------------------------------------------------------- - /** - * Note that _seqSegFlag is set back to zero when the synapses are erased: when - * a segment is released, it's empty _AND_ it's no long a sequence segment. - * This simplifies further tests in the algorithm. - */ - inline void clear() - { - _synapses.clear(); - _synapses.resize(0); - _seqSegFlag = false; - _frequency = 0; - _nConnected = 0; - } + //----------------------------------------------------------------------- + /** + * Check that _nConnected is equal to actual number of connected synapses + * + */ + inline bool checkConnected(Real permConnected) const { + // + UInt nc = 0; + for (UInt i = 0; i != _synapses.size(); ++i) + nc += (_synapses[i].permanence() >= permConnected); + + if (nc != _nConnected) { + std::cout << "\nConnected stats inconsistent. 
_nConnected=" << _nConnected + << ", computed nc=" << nc << std::endl; + } + + return nc == _nConnected; + } + + //---------------------------------------------------------------------- + /** + * Various accessors + */ + inline bool empty() const { return _synapses.empty(); } + inline UInt size() const { return _synapses.size(); } + inline bool isSequenceSegment() const { return _seqSegFlag; } + inline Real &frequency() { return _frequency; } + inline Real getFrequency() const { return _frequency; } + inline UInt nConnected() const { return _nConnected; } + inline UInt getTotalActivations() const { return _totalActivations; } + inline UInt getPositiveActivations() const { return _positiveActivations; } + inline UInt getLastActiveIteration() const { return _lastActiveIteration; } + inline Real getLastPosDutyCycle() const { return _lastPosDutyCycle; } + inline UInt getLastPosDutyCycleIteration() const { + return _lastPosDutyCycleIteration; + } - //----------------------------------------------------------------------- - inline const InSynapse& operator[](UInt idx) const - { - NTA_ASSERT(idx < size()); - return _synapses[idx]; - } + //----------------------------------------------------------------------- + /** + * Checks whether the given src cellIdx is already contained in this segment + * or not. + * TODO: optimize with at least a binary search + */ + inline bool has(UInt srcCellIdx) const { + NTA_ASSERT(srcCellIdx != (UInt)-1); + + UInt lo = 0; + UInt hi = _synapses.size(); + while (lo < hi) { + const UInt test = (lo + hi) / 2; + if (_synapses[test].srcCellIdx() < srcCellIdx) + lo = test + 1; + else if (_synapses[test].srcCellIdx() > srcCellIdx) + hi = test; + else + return true; + } + return false; + } - //----------------------------------------------------------------------- - /** - * Adds synapses to this segment. - * - * Parameters: - * ========== - * - srcCells: a collection of source cell indices (the sources of the - * synapses). Source cell indices are unique on a segment, - * and are kept in increasing order. - * - initStrength: the initial strength to set for the new synapses - */ - void - addSynapses(const std::set& srcCells, Real initStrength, - Real permConnected); - - //----------------------------------------------------------------------- - /** - * Recompute _nConnected for this segment - * - * Parameters: - * ========== - * - permConnected: permanence values >= permConnected are considered - * connected. - * - */ - void recomputeConnected(Real permConnected) { - _nConnected = 0; - for (UInt i = 0; i != _synapses.size(); ++i) - if (_synapses[i].permanence() >= permConnected) - ++ _nConnected; - } + //----------------------------------------------------------------------- + /** + * Returns the permanence of the idx-th synapse on this Segment. That idx is + * *not* a cell index, but just the index of the synapse on that segment, i.e. + * that index will change if synapses are deleted from this segment in synpase + * adaptation or global decay. + */ + inline void setPermanence(UInt idx, Real val) { + NTA_ASSERT(idx < _synapses.size()); + + _synapses[idx].permanence() = val; + } - private: - - //----------------------------------------------------------------------- - /** - A private method invoked by this Segment when synapses need to be - removed. del contains the indices of the synapses to remove, as - indices of synapses on this segment (not source cell indices). 
This - method maintains the order of the synapses in the segment (they are - sorted in order of increasing source cell index). - */ - inline void _removeSynapses(const std::vector& del) - { - // TODO: check what happens if synapses doesn't exist anymore - // because of decay - UInt i = 0, idel = 0, j = 0; - - while (i < _synapses.size() && idel < del.size()) { - if (i == del[idel]) { - ++i; ++idel; - } else if (i < del[idel]) { - _synapses[j++] = _synapses[i++]; - } else if (del[idel] < i) { - NTA_CHECK(false); // Synapses have to be sorted! - } - } - - while (i < _synapses.size()) - _synapses[j++] = _synapses[i++]; - - _synapses.resize(j); - } + //----------------------------------------------------------------------- + /** + * Returns the permanence of the idx-th synapse on this Segment as a value + */ + inline Real getPermanence(UInt idx) const { + NTA_ASSERT(idx < _synapses.size()); + NTA_ASSERT(0 <= _synapses[idx].permanence()); - public: - //----------------------------------------------------------------------- - /** - * Updates synapses permanences, possibly removing synapses from the segment if their - * permanence drops to or below 0. - * - * Parameters: - * ========== - * - synapses: a collection of source cell indices to update (will be matched - * with source cell index of each synapse) - * - delta: the amount to add to the permanence value of the updated - * synapses - * - removed: collection of synapses that have been removed because their - * permanence dropped below 0 (srcCellIdx of synapses). - * - * TODO: have synapses be 2 pointers, to avoid copies in adaptSegments - */ - template // this blocks swig wrapping which doesn't happen right - inline void - updateSynapses(const std::vector& synapses, Real delta, - Real permMax, Real permConnected, - std::vector& removed) - { - { - NTA_ASSERT(invariants()); - NTA_ASSERT(is_sorted(synapses)); - } - - std::vector del; - - UInt i1 = 0, i2 = 0; - - while (i1 < size() && i2 < synapses.size()) { - - if (_synapses[i1].srcCellIdx() == synapses[i2]) { - - Real oldPerm = getPermanence(i1); - Real newPerm = std::min(oldPerm + delta, permMax); - - if (newPerm <= 0) { - removed.push_back(_synapses[i1].srcCellIdx()); - del.push_back(i1); - } - - setPermanence(i1, newPerm); - - int wasConnected = (int) (oldPerm >= permConnected); - int isConnected = (int) (newPerm >= permConnected); - - _nConnected += isConnected - wasConnected; - - ++i1; ++i2; - - } else if (_synapses[i1].srcCellIdx() < synapses[i2]) { - ++i1; - } else { - ++i2; - } - } - - // _removeSynapses maintains the order of the synapses - _removeSynapses(del); - - NTA_ASSERT(invariants()); - } + return _synapses[idx].permanence(); + } - //---------------------------------------------------------------------- - /** - * Subtract decay from each synapses' permanence value. - * Synapses whose permanence drops below 0 are removed and their source - * indices are inserted into the "removed" list. - * - * Parameters: - * ========== - * - decay: the amount to subtract from the permanence value the - * synapses - * - removed: srcCellIdx of the synapses that are removed - */ - void decaySynapses2(Real decay, std::vector& removed, - Real permConnected); - - //---------------------------------------------------------------------- - /** - * Decay synapses' permanence value. Synapses whose permanence drops - * below 0 are removed. 
- * - * Parameters: - * ========== - * - decay: the amount to subtract from the permanence value the - * synapses - * - removed: srcCellIdx of the synapses that are removed - */ - void decaySynapses(Real decay, std::vector& removed, - Real permConnected, bool doDecay=true); - - //---------------------------------------------------------------------- - /** - * Free up some synapses in this segment. We always free up inactive - * synapses (lowest permanence freed up first) before we start to free - * up active ones. - * - * Parameters: - * ========== - * numToFree: num synapses we have to remove - * inactiveSynapseIndices: list of inactive synapses (src cell indices) - * inactiveSegmentIndices: list of inactive synapses (index within segment) - * activeSynapseIndices: list of active synapses (src cell indices) - * activeSegmentIndices: list of active synapses (index within segment) - * removed: srcCellIdx of the synapses that are - * removed - * verbosity: verbosity level for debug printing - * nCellsPerCol: number of cells per column (for debug - * printing) - * permMax: maximum allowed permanence value - */ - void freeNSynapses(UInt numToFree, - std::vector &inactiveSynapseIndices, - std::vector &inactiveSegmentIndices, - std::vector &activeSynapseIndices, - std::vector &activeSegmentIndices, - std::vector& removed, UInt verbosity, - UInt nCellsPerCol, Real permMax); - - //---------------------------------------------------------------------- - /** - * Computes the activity level for a segment given permConnected and - * activationThreshold. A segment is active if it has more than - * activationThreshold connected synapses that are active due to - * activeState. - * - * Parameters: - * ========== - * - activities: pointer to activeStateT or activeStateT1 - * - * NOTE: called getSegmentActivityLevel in Python - */ - bool isActive(const CState& activities, - Real permConnected, UInt activationThreshold) const; - - //---------------------------------------------------------------------- - /** - * Compute the activity level of a segment using cell activity levels - * contain in activities. - * - * Parameters: - * ========== - * - activities: pointer to an array of cell activities. - * - * - permConnected: permanence values >= permConnected are considered - * connected. - * - * - connectedSynapsesOnly: if true, only consider synapses that are - * connected. - */ - UInt computeActivity(const CState& activities, Real permConnected, - bool connectedSynapsesOnly) const; - - //---------------------------------------------------------------------- - /** - * Compute/update and return the positive activations duty cycle of - * this segment. This is a measure of how often this segment is - * providing good predictions. - * - * Parameters: - * ========== - * iteration: Current compute iteration. Must be > 0! - * active: True if segment just provided a good prediction - * - */ - Real dutyCycle(UInt iteration, bool active, bool readOnly); - - - //---------------------------------------------------------------------- - /** - * Returns true if iteration is equal to one of the duty cycle tiers. - */ - static bool atDutyCycleTier(UInt iteration) - { - for (auto & _dutyCycleTier : _dutyCycleTiers) { - if (iteration == _dutyCycleTier) return true; - } - return false; - } + //----------------------------------------------------------------------- + /** + * Returns the source cell index of the synapse at index idx. 
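Related note: the membership test has(), shown a little earlier in this class, is a hand-rolled binary search over the same sorted source cell indices. An equivalent standalone check with the standard library (illustrative only, plain unsigned in place of UInt):

#include <algorithm>
#include <cassert>
#include <vector>

// Membership test on a sorted, duplicate-free list of source cell indices,
// mirroring what Segment::has() does with its explicit lo/hi halving loop.
static bool hasSrcCell(const std::vector<unsigned> &sortedSrcCells,
                       unsigned srcCellIdx) {
  return std::binary_search(sortedSrcCells.begin(), sortedSrcCells.end(),
                            srcCellIdx);
}

int main() {
  const std::vector<unsigned> srcCells = {3, 17, 42, 100};
  assert(hasSrcCell(srcCells, 42));
  assert(!hasSrcCell(srcCells, 41));
  return 0;
}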
+ */ + inline UInt getSrcCellIdx(UInt idx) const { + NTA_ASSERT(idx < _synapses.size()); + return _synapses[idx].srcCellIdx(); + } - //---------------------------------------------------------------------- - // PERSISTENCE - //---------------------------------------------------------------------- - inline UInt persistentSize() const - { - std::stringstream buff; - this->save(buff); - return buff.str().size(); - } + //----------------------------------------------------------------------- + /** + * Returns the indices of all source cells in this segment. + * + * Parameter / return value: + * srcCells: an empty vector. The indices will be returned in + * this vector. + */ + inline void getSrcCellIndices(std::vector &srcCells) const { + NTA_ASSERT(srcCells.size() == 0); + for (auto &elem : _synapses) { + srcCells.push_back(elem.srcCellIdx()); + } + } - //---------------------------------------------------------------------- - using Serializable::write; - void write(SegmentProto::Builder& proto) const override - { - NTA_ASSERT(invariants()); - proto.setSeqSegFlag(_seqSegFlag); - proto.setFrequency(_frequency); - proto.setNConnected(_nConnected); - proto.setTotalActivations(_totalActivations); - proto.setPositiveActivations(_positiveActivations); - proto.setLastActiveIteration(_lastActiveIteration); - proto.setLastPosDutyCycle(_lastPosDutyCycle); - proto.setLastPosDutyCycleIteration(_lastPosDutyCycleIteration); - auto synapsesProto = proto.initSynapses(size()); - for (UInt i = 0; i < size(); ++i) - { - auto inSynapseProto = synapsesProto[i]; - inSynapseProto.setSrcCellIdx(_synapses[i].srcCellIdx()); - inSynapseProto.setPermanence(_synapses[i].permanence()); - } - } + //----------------------------------------------------------------------- + /** + * Note that _seqSegFlag is set back to zero when the synapses are erased: + * when a segment is released, it's empty _AND_ it's no long a sequence + * segment. This simplifies further tests in the algorithm. 
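A small usage sketch of the contract described above, assuming Segment.hpp is reachable under its usual include path (illustrative, not code from the repository):

// Illustrative sketch; assumes this include path for Segment.hpp.
#include <cassert>

#include "nupic/algorithms/Segment.hpp"

using nupic::algorithms::Cells4::Segment;

// After clear(), a released segment is empty AND no longer a sequence
// segment, so callers may test either condition on its own.
void releaseSegment(Segment &seg) {
  seg.clear();
  assert(seg.empty());
  assert(!seg.isSequenceSegment());
  assert(seg.nConnected() == 0);
}

int main() {
  Segment seg;          // default-constructed, already empty
  releaseSegment(seg);  // clear() is harmless on an empty segment
  return 0;
}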
+ */ + inline void clear() { + _synapses.clear(); + _synapses.resize(0); + _seqSegFlag = false; + _frequency = 0; + _nConnected = 0; + } - //---------------------------------------------------------------------- - using Serializable::read; - void read(SegmentProto::Reader& proto) override - { - _seqSegFlag = proto.getSeqSegFlag(); - _frequency = proto.getFrequency(); - _nConnected = proto.getNConnected(); - _totalActivations = proto.getTotalActivations(); - _positiveActivations = proto.getPositiveActivations(); - _lastActiveIteration = proto.getLastActiveIteration(); - _lastPosDutyCycle = proto.getLastPosDutyCycle(); - _lastPosDutyCycleIteration = proto.getLastPosDutyCycleIteration(); - _synapses.clear(); - for (auto inSynapseProto : proto.getSynapses()) - { - _synapses.emplace_back(inSynapseProto.getSrcCellIdx(), - inSynapseProto.getPermanence()); - } - } + //----------------------------------------------------------------------- + inline const InSynapse &operator[](UInt idx) const { + NTA_ASSERT(idx < size()); + return _synapses[idx]; + } - //---------------------------------------------------------------------- - inline void save(std::ostream& outStream) const - { - NTA_ASSERT(invariants()); - outStream << size() << ' ' - << _seqSegFlag << ' ' - << _frequency << ' ' - << _nConnected << ' ' - << _totalActivations << ' ' - << _positiveActivations << ' ' - << _lastActiveIteration << ' ' - << _lastPosDutyCycle << ' ' - << _lastPosDutyCycleIteration << ' '; - binary_save(outStream, _synapses); - outStream << ' '; - } + //----------------------------------------------------------------------- + /** + * Adds synapses to this segment. + * + * Parameters: + * ========== + * - srcCells: a collection of source cell indices (the sources of the + * synapses). Source cell indices are unique on a segment, + * and are kept in increasing order. + * - initStrength: the initial strength to set for the new synapses + */ + void addSynapses(const std::set &srcCells, Real initStrength, + Real permConnected); - //---------------------------------------------------------------------- - inline void load(std::istream& inStream) - { - UInt n = 0; - inStream >> n - >> _seqSegFlag - >> _frequency - >> _nConnected - >> _totalActivations - >> _positiveActivations - >> _lastActiveIteration - >> _lastPosDutyCycle - >> _lastPosDutyCycleIteration; - _synapses.resize(n); - inStream.ignore(1); - binary_load(inStream, _synapses); - NTA_ASSERT(invariants()); - } + //----------------------------------------------------------------------- + /** + * Recompute _nConnected for this segment + * + * Parameters: + * ========== + * - permConnected: permanence values >= permConnected are considered + * connected. + * + */ + void recomputeConnected(Real permConnected) { + _nConnected = 0; + for (UInt i = 0; i != _synapses.size(); ++i) + if (_synapses[i].permanence() >= permConnected) + ++_nConnected; + } + +private: + //----------------------------------------------------------------------- + /** + A private method invoked by this Segment when synapses need to be + removed. del contains the indices of the synapses to remove, as + indices of synapses on this segment (not source cell indices). This + method maintains the order of the synapses in the segment (they are + sorted in order of increasing source cell index). 
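The compaction described here can be illustrated on a plain vector: walk the element positions and the sorted list of positions to delete in lock-step, copying survivors forward so their relative order is preserved. A standalone sketch of the same scheme (not the member function itself):

#include <cassert>
#include <vector>

// Remove the elements at the given positions (positions, not values; must be
// sorted), keeping the survivors in their original order -- the same scheme
// Segment::_removeSynapses uses on its synapse array.
static void removeAtPositions(std::vector<unsigned> &v,
                              const std::vector<size_t> &del) {
  size_t i = 0, idel = 0, j = 0;
  for (; i < v.size(); ++i) {
    if (idel < del.size() && i == del[idel]) {
      ++idel;         // skip: this position is being deleted
    } else {
      v[j++] = v[i];  // keep: copy the survivor forward
    }
  }
  v.resize(j);
}

int main() {
  std::vector<unsigned> cells = {3, 17, 42, 100, 250};
  removeAtPositions(cells, {1, 3});  // drop positions 1 and 3
  assert((cells == std::vector<unsigned>{3, 42, 250}));
  return 0;
}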
+ */ + inline void _removeSynapses(const std::vector &del) { + // TODO: check what happens if synapses doesn't exist anymore + // because of decay + UInt i = 0, idel = 0, j = 0; + + while (i < _synapses.size() && idel < del.size()) { + if (i == del[idel]) { + ++i; + ++idel; + } else if (i < del[idel]) { + _synapses[j++] = _synapses[i++]; + } else if (del[idel] < i) { + NTA_CHECK(false); // Synapses have to be sorted! + } + } + + while (i < _synapses.size()) + _synapses[j++] = _synapses[i++]; + + _synapses.resize(j); + } + +public: + //----------------------------------------------------------------------- + /** + * Updates synapses permanences, possibly removing synapses from the segment + * if their permanence drops to or below 0. + * + * Parameters: + * ========== + * - synapses: a collection of source cell indices to update (will be + * matched with source cell index of each synapse) + * - delta: the amount to add to the permanence value of the updated + * synapses + * - removed: collection of synapses that have been removed because + * their permanence dropped below 0 (srcCellIdx of synapses). + * + * TODO: have synapses be 2 pointers, to avoid copies in adaptSegments + */ + template // this blocks swig wrapping which doesn't happen right + inline void updateSynapses(const std::vector &synapses, Real delta, + Real permMax, Real permConnected, + std::vector &removed) { + { + NTA_ASSERT(invariants()); + NTA_ASSERT(is_sorted(synapses)); + } + + std::vector del; + + UInt i1 = 0, i2 = 0; + + while (i1 < size() && i2 < synapses.size()) { + + if (_synapses[i1].srcCellIdx() == synapses[i2]) { + + Real oldPerm = getPermanence(i1); + Real newPerm = std::min(oldPerm + delta, permMax); + + if (newPerm <= 0) { + removed.push_back(_synapses[i1].srcCellIdx()); + del.push_back(i1); + } + + setPermanence(i1, newPerm); + + int wasConnected = (int)(oldPerm >= permConnected); + int isConnected = (int)(newPerm >= permConnected); + + _nConnected += isConnected - wasConnected; + + ++i1; + ++i2; + + } else if (_synapses[i1].srcCellIdx() < synapses[i2]) { + ++i1; + } else { + ++i2; + } + } + + // _removeSynapses maintains the order of the synapses + _removeSynapses(del); + + NTA_ASSERT(invariants()); + } + + //---------------------------------------------------------------------- + /** + * Subtract decay from each synapses' permanence value. + * Synapses whose permanence drops below 0 are removed and their source + * indices are inserted into the "removed" list. + * + * Parameters: + * ========== + * - decay: the amount to subtract from the permanence value the + * synapses + * - removed: srcCellIdx of the synapses that are removed + */ + void decaySynapses2(Real decay, std::vector &removed, + Real permConnected); + + //---------------------------------------------------------------------- + /** + * Decay synapses' permanence value. Synapses whose permanence drops + * below 0 are removed. + * + * Parameters: + * ========== + * - decay: the amount to subtract from the permanence value the + * synapses + * - removed: srcCellIdx of the synapses that are removed + */ + void decaySynapses(Real decay, std::vector &removed, Real permConnected, + bool doDecay = true); + + //---------------------------------------------------------------------- + /** + * Free up some synapses in this segment. We always free up inactive + * synapses (lowest permanence freed up first) before we start to free + * up active ones. 
+ * + * Parameters: + * ========== + * numToFree: num synapses we have to remove + * inactiveSynapseIndices: list of inactive synapses (src cell indices) + * inactiveSegmentIndices: list of inactive synapses (index within segment) + * activeSynapseIndices: list of active synapses (src cell indices) + * activeSegmentIndices: list of active synapses (index within segment) + * removed: srcCellIdx of the synapses that are + * removed + * verbosity: verbosity level for debug printing + * nCellsPerCol: number of cells per column (for debug + * printing) + * permMax: maximum allowed permanence value + */ + void freeNSynapses(UInt numToFree, std::vector &inactiveSynapseIndices, + std::vector &inactiveSegmentIndices, + std::vector &activeSynapseIndices, + std::vector &activeSegmentIndices, + std::vector &removed, UInt verbosity, + UInt nCellsPerCol, Real permMax); + + //---------------------------------------------------------------------- + /** + * Computes the activity level for a segment given permConnected and + * activationThreshold. A segment is active if it has more than + * activationThreshold connected synapses that are active due to + * activeState. + * + * Parameters: + * ========== + * - activities: pointer to activeStateT or activeStateT1 + * + * NOTE: called getSegmentActivityLevel in Python + */ + bool isActive(const CState &activities, Real permConnected, + UInt activationThreshold) const; + + //---------------------------------------------------------------------- + /** + * Compute the activity level of a segment using cell activity levels + * contain in activities. + * + * Parameters: + * ========== + * - activities: pointer to an array of cell activities. + * + * - permConnected: permanence values >= permConnected are considered + * connected. + * + * - connectedSynapsesOnly: if true, only consider synapses that are + * connected. + */ + UInt computeActivity(const CState &activities, Real permConnected, + bool connectedSynapsesOnly) const; + + //---------------------------------------------------------------------- + /** + * Compute/update and return the positive activations duty cycle of + * this segment. This is a measure of how often this segment is + * providing good predictions. + * + * Parameters: + * ========== + * iteration: Current compute iteration. Must be > 0! + * active: True if segment just provided a good prediction + * + */ + Real dutyCycle(UInt iteration, bool active, bool readOnly); + + //---------------------------------------------------------------------- + /** + * Returns true if iteration is equal to one of the duty cycle tiers. 
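To make the tier tables near the top of this header concrete (the comment there states that dutyCycleAlphas[n] is used when iterationIdx > dutyCycleTiers[n]), here is a small standalone lookup. The table values are copied from those declarations; everything else is illustrative:

#include <cassert>
#include <cstddef>

// Tier boundaries and per-tier alphas, copied from the tables above.
static const unsigned kDutyCycleTiers[] = {0,     100,   320,    1000,  3200,
                                           10000, 32000, 100000, 320000};
static const double kDutyCycleAlphas[] = {0.0,      0.0032,    0.0010,
                                          0.00032,  0.00010,   0.000032,
                                          0.000010, 0.0000032, 0.0000010};

// Pick the alpha of the highest tier whose boundary the iteration has passed.
static double alphaForIteration(unsigned iteration) {
  const std::size_t numTiers =
      sizeof(kDutyCycleTiers) / sizeof(kDutyCycleTiers[0]);
  double alpha = kDutyCycleAlphas[0];
  for (std::size_t n = 0; n < numTiers; ++n) {
    if (iteration > kDutyCycleTiers[n])
      alpha = kDutyCycleAlphas[n];
  }
  return alpha;
}

int main() {
  assert(alphaForIteration(50) == 0.0);      // still below the 100-iteration tier
  assert(alphaForIteration(150) == 0.0032);  // past 100, not yet past 320
  assert(alphaForIteration(500000) == 0.0000010);
  return 0;
}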
+ */ + static bool atDutyCycleTier(UInt iteration) { + for (auto &_dutyCycleTier : _dutyCycleTiers) { + if (iteration == _dutyCycleTier) + return true; + } + return false; + } + + //---------------------------------------------------------------------- + // PERSISTENCE + //---------------------------------------------------------------------- + inline UInt persistentSize() const { + std::stringstream buff; + this->save(buff); + return buff.str().size(); + } + + //---------------------------------------------------------------------- + using Serializable::write; + void write(SegmentProto::Builder &proto) const override { + NTA_ASSERT(invariants()); + proto.setSeqSegFlag(_seqSegFlag); + proto.setFrequency(_frequency); + proto.setNConnected(_nConnected); + proto.setTotalActivations(_totalActivations); + proto.setPositiveActivations(_positiveActivations); + proto.setLastActiveIteration(_lastActiveIteration); + proto.setLastPosDutyCycle(_lastPosDutyCycle); + proto.setLastPosDutyCycleIteration(_lastPosDutyCycleIteration); + auto synapsesProto = proto.initSynapses(size()); + for (UInt i = 0; i < size(); ++i) { + auto inSynapseProto = synapsesProto[i]; + inSynapseProto.setSrcCellIdx(_synapses[i].srcCellIdx()); + inSynapseProto.setPermanence(_synapses[i].permanence()); + } + } + + //---------------------------------------------------------------------- + using Serializable::read; + void read(SegmentProto::Reader &proto) override { + _seqSegFlag = proto.getSeqSegFlag(); + _frequency = proto.getFrequency(); + _nConnected = proto.getNConnected(); + _totalActivations = proto.getTotalActivations(); + _positiveActivations = proto.getPositiveActivations(); + _lastActiveIteration = proto.getLastActiveIteration(); + _lastPosDutyCycle = proto.getLastPosDutyCycle(); + _lastPosDutyCycleIteration = proto.getLastPosDutyCycleIteration(); + _synapses.clear(); + for (auto inSynapseProto : proto.getSynapses()) { + _synapses.emplace_back(inSynapseProto.getSrcCellIdx(), + inSynapseProto.getPermanence()); + } + } + + //---------------------------------------------------------------------- + inline void save(std::ostream &outStream) const { + NTA_ASSERT(invariants()); + outStream << size() << ' ' << _seqSegFlag << ' ' << _frequency << ' ' + << _nConnected << ' ' << _totalActivations << ' ' + << _positiveActivations << ' ' << _lastActiveIteration << ' ' + << _lastPosDutyCycle << ' ' << _lastPosDutyCycleIteration << ' '; + binary_save(outStream, _synapses); + outStream << ' '; + } + + //---------------------------------------------------------------------- + inline void load(std::istream &inStream) { + UInt n = 0; + inStream >> n >> _seqSegFlag >> _frequency >> _nConnected >> + _totalActivations >> _positiveActivations >> _lastActiveIteration >> + _lastPosDutyCycle >> _lastPosDutyCycleIteration; + _synapses.resize(n); + inStream.ignore(1); + binary_load(inStream, _synapses); + NTA_ASSERT(invariants()); + } - //----------------------------------------------------------------------- - /** - * Print the segment in a human readable form. If nCellsPerCol is specified - * then the source col/cell for each synapse will be printed instead of - * cell index. - */ - void print(std::ostream& outStream, UInt nCellsPerCol = 0) const; - }; + //----------------------------------------------------------------------- + /** + * Print the segment in a human readable form. If nCellsPerCol is specified + * then the source col/cell for each synapse will be printed instead of + * cell index. 
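For the text persistence members a few lines up (save, load, persistentSize), a round trip through a stringstream looks roughly like this, assuming Segment.hpp is reachable under its usual include path (a sketch, not a test from the repository):

// Illustrative sketch; assumes this include path for Segment.hpp.
#include <cassert>
#include <sstream>

#include "nupic/algorithms/Segment.hpp"

using nupic::algorithms::Cells4::Segment;

int main() {
  Segment original;       // default-constructed segment

  std::stringstream buffer;
  original.save(buffer);  // plain-text form, the layout load() expects

  Segment restored;
  restored.load(buffer);

  // persistentSize() reports how many characters save() produces, so a
  // faithful round trip should leave it unchanged.
  assert(original.persistentSize() == restored.persistentSize());
  return 0;
}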
+ */ + void print(std::ostream &outStream, UInt nCellsPerCol = 0) const; +}; - //----------------------------------------------------------------------- +//----------------------------------------------------------------------- #ifndef SWIG - std::ostream& operator<<(std::ostream& outStream, const Segment& seg); - std::ostream& operator<<(std::ostream& outStream, const CState& cstate); - std::ostream& operator<<(std::ostream& outStream, - const CStateIndexed& cstate); +std::ostream &operator<<(std::ostream &outStream, const Segment &seg); +std::ostream &operator<<(std::ostream &outStream, const CState &cstate); +std::ostream &operator<<(std::ostream &outStream, const CStateIndexed &cstate); #endif - //----------------------------------------------------------------------- - } // end namespace Cells4 - } // end namespace algorithms +//----------------------------------------------------------------------- +} // end namespace Cells4 +} // end namespace algorithms } // end namespace nupic - //----------------------------------------------------------------------- +//----------------------------------------------------------------------- #endif // NTA_SEGMENT_HPP diff --git a/src/nupic/algorithms/SegmentUpdate.cpp b/src/nupic/algorithms/SegmentUpdate.cpp index d53d8ed3cd..7fd8e69bf3 100644 --- a/src/nupic/algorithms/SegmentUpdate.cpp +++ b/src/nupic/algorithms/SegmentUpdate.cpp @@ -20,44 +20,29 @@ * --------------------------------------------------------------------- */ -#include -#include #include +#include +#include using namespace nupic::algorithms::Cells4; - SegmentUpdate::SegmentUpdate() - : _sequenceSegment(false), - _cellIdx((UInt) -1), - _segIdx((UInt) -1), - _timeStamp((UInt) -1), - _synapses(), - _phase1Flag(false), - _weaklyPredicting(false) -{} - -SegmentUpdate::SegmentUpdate(UInt cellIdx, UInt segIdx, - bool sequenceSegment, UInt timeStamp, - std::vector synapses, - bool phase1Flag, - bool weaklyPredicting, - Cells4* cells) - : _sequenceSegment(sequenceSegment), - _cellIdx(cellIdx), - _segIdx(segIdx), - _timeStamp(timeStamp), - _synapses(std::move(synapses)), - _phase1Flag(phase1Flag), - _weaklyPredicting(weaklyPredicting) -{ + : _sequenceSegment(false), _cellIdx((UInt)-1), _segIdx((UInt)-1), + _timeStamp((UInt)-1), _synapses(), _phase1Flag(false), + _weaklyPredicting(false) {} + +SegmentUpdate::SegmentUpdate(UInt cellIdx, UInt segIdx, bool sequenceSegment, + UInt timeStamp, std::vector synapses, + bool phase1Flag, bool weaklyPredicting, + Cells4 *cells) + : _sequenceSegment(sequenceSegment), _cellIdx(cellIdx), _segIdx(segIdx), + _timeStamp(timeStamp), _synapses(std::move(synapses)), + _phase1Flag(phase1Flag), _weaklyPredicting(weaklyPredicting) { NTA_ASSERT(invariants(cells)); } - //-------------------------------------------------------------------------------- -SegmentUpdate::SegmentUpdate(const SegmentUpdate& o) -{ +SegmentUpdate::SegmentUpdate(const SegmentUpdate &o) { _cellIdx = o._cellIdx; _segIdx = o._segIdx; _sequenceSegment = o._sequenceSegment; @@ -68,18 +53,13 @@ SegmentUpdate::SegmentUpdate(const SegmentUpdate& o) NTA_ASSERT(invariants()); } - - - - -bool SegmentUpdate::invariants(Cells4* cells) const -{ +bool SegmentUpdate::invariants(Cells4 *cells) const { bool ok = true; if (cells) { ok &= _cellIdx < cells->nCells(); - if (_segIdx != (UInt) -1) + if (_segIdx != (UInt)-1) ok &= _segIdx < cells->__nSegmentsOnCell(_cellIdx); if (!_synapses.empty()) { diff --git a/src/nupic/algorithms/SegmentUpdate.hpp b/src/nupic/algorithms/SegmentUpdate.hpp index 
a3711a3eb4..f08f2160e5 100644 --- a/src/nupic/algorithms/SegmentUpdate.hpp +++ b/src/nupic/algorithms/SegmentUpdate.hpp @@ -29,199 +29,177 @@ #include using namespace nupic; - namespace nupic { - namespace algorithms { - namespace Cells4 { - - class Cells4; - - //------------------------------------------------------------------------ - //------------------------------------------------------------------------ - /** - * SegmentUpdate stores information to update segments by creating, removing - * or updating synapses. - * SegmentUpdates are applied to the segment they target on a different iteration - * than the iteration they were created in. SegmentUpdates have a timeStamp, - * and they are discarded without being applied if they become 'stale'. - */ - class SegmentUpdate : Serializable - { - public: - typedef std::vector::const_iterator const_iterator; - - private: - bool _sequenceSegment; // used when creating a new segment - UInt _cellIdx; // the index of the target cell - UInt _segIdx; // (UInt) -1 if creating new segment - UInt _timeStamp; // controls obsolescence of update - std::vector _synapses; // contains source cell indices - bool _phase1Flag; // If true, this update was created - // during Phase 1 of compute - bool _weaklyPredicting; // Set true if segment only reaches - // activationThreshold when including - // unconnected synapses. - - - public: - SegmentUpdate(); - - //---------------------------------------------------------------------- - SegmentUpdate(UInt cellIdx, UInt segIdx, - bool sequenceSegment, UInt timeStamp, - std::vector synapses = std::vector(), - bool phase1Flag = false, - bool weaklyPredicting = false, - Cells4* cells =nullptr); - - //---------------------------------------------------------------------- - SegmentUpdate(const SegmentUpdate& o); - - //---------------------------------------------------------------------- - SegmentUpdate& operator=(const SegmentUpdate& o) - { - _cellIdx = o._cellIdx; - _segIdx = o._segIdx; - _sequenceSegment = o._sequenceSegment; - _synapses = o._synapses; - _timeStamp = o._timeStamp; - _phase1Flag = o._phase1Flag; - _weaklyPredicting = o._weaklyPredicting; - NTA_ASSERT(invariants()); - return *this; - } - - //--------------------------------------------------------------------- - bool isSequenceSegment() const { return _sequenceSegment; } - UInt cellIdx() const { return _cellIdx; } - UInt segIdx() const { return _segIdx; } - UInt timeStamp() const { return _timeStamp; } - UInt operator[](UInt idx) const { return _synapses[idx]; } - const_iterator begin() const { return _synapses.begin(); } - const_iterator end() const { return _synapses.end(); } - UInt size() const { return _synapses.size(); } - bool empty() const { return _synapses.empty(); } - bool isNewSegment() const { return _segIdx == (UInt) -1; } - bool isPhase1Segment() const {return _phase1Flag;} - bool isWeaklyPredicting() const {return _weaklyPredicting;} - - //--------------------------------------------------------------------- - /** - * Checks that all indices are in range, and that the synapse src cell indices - * are unique and sorted. 
- */ - bool invariants(Cells4* cells =nullptr) const; - - //--------------------------------------------------------------------- - using Serializable::write; - void write(SegmentUpdateProto::Builder& proto) const override - { - proto.setSequenceSegment(_sequenceSegment); - proto.setCellIdx(_cellIdx); - proto.setSegIdx(_segIdx); - proto.setTimestamp(_timeStamp); - auto synapsesProto = proto.initSynapses(_synapses.size()); - for (UInt i = 0; i < _synapses.size(); ++i) - { - synapsesProto.set(i, _synapses[i]); - } - proto.setPhase1Flag(_phase1Flag); - proto.setWeaklyPredicting(_weaklyPredicting); - } - - //--------------------------------------------------------------------- - using Serializable::read; - void read(SegmentUpdateProto::Reader& proto) override - { - _sequenceSegment = proto.getSequenceSegment(); - _cellIdx = proto.getCellIdx(); - _segIdx = proto.getSegIdx(); - _timeStamp = proto.getTimestamp(); - auto synapsesProto = proto.getSynapses(); - _synapses.resize(synapsesProto.size()); - for (UInt i = 0; i < synapsesProto.size(); ++i) - { - _synapses[i] = synapsesProto[i]; - } - _phase1Flag = proto.getPhase1Flag(); - _weaklyPredicting = proto.getWeaklyPredicting(); - } - - //--------------------------------------------------------------------- - void save(std::ostream& outStream) const - { - outStream << _cellIdx << " " - << _segIdx << " " - << _phase1Flag << " " - << _sequenceSegment << " " - << _weaklyPredicting << " " - << _timeStamp << std::endl; - outStream << _synapses.size() << " "; - for (auto & elem : _synapses) - { - outStream << elem << " "; - } - } - - //--------------------------------------------------------------------- - void load(std::istream& inStream) - { - inStream >> _cellIdx - >> _segIdx - >> _phase1Flag - >> _sequenceSegment - >> _weaklyPredicting - >> _timeStamp; - UInt n, syn; - inStream >> n; - for (UInt i = 0; i < n; ++i) - { - inStream >> syn; - _synapses.push_back(syn); - } - } - - //--------------------------------------------------------------------- - void print(std::ostream& outStream, bool longFormat =false, - UInt nCellsPerCol = 0) const - { - if (!longFormat) { - - outStream << 'c' << _cellIdx << " s" << _segIdx - << (_phase1Flag ? " p1 " : " p2 ") - << (_sequenceSegment ? " ss" : " ") - << (_weaklyPredicting ? " wp" : " sp") - << " t" << _timeStamp << '/'; - - } else { - NTA_CHECK(nCellsPerCol > 0); - UInt col = (UInt) (_cellIdx / nCellsPerCol); - UInt cell = _cellIdx - col*nCellsPerCol; - outStream << "cell: " << "[" << col << "," << cell << "] "; - outStream << " seg: " << _segIdx - << (_sequenceSegment ? " seqSeg " : " ") - << "timeStamp: " << _timeStamp << " / src cells: "; - } - - // Print out list of source cell indices - for (UInt i = 0; i != _synapses.size(); ++i) - outStream << _synapses[i] << ' '; - } - }; - - //-------------------------------------------------------------------------------- -#ifndef SWIG - inline std::ostream& - operator<<(std::ostream& outStream, const SegmentUpdate& update) - { - update.print(outStream); - return outStream; - } -#endif +namespace algorithms { +namespace Cells4 { + +class Cells4; + +//------------------------------------------------------------------------ +//------------------------------------------------------------------------ +/** + * SegmentUpdate stores information to update segments by creating, removing + * or updating synapses. + * SegmentUpdates are applied to the segment they target on a different + * iteration than the iteration they were created in. 
SegmentUpdates have a + * timeStamp, and they are discarded without being applied if they become + * 'stale'. + */ +class SegmentUpdate : Serializable { +public: + typedef std::vector::const_iterator const_iterator; + +private: + bool _sequenceSegment; // used when creating a new segment + UInt _cellIdx; // the index of the target cell + UInt _segIdx; // (UInt) -1 if creating new segment + UInt _timeStamp; // controls obsolescence of update + std::vector _synapses; // contains source cell indices + bool _phase1Flag; // If true, this update was created + // during Phase 1 of compute + bool _weaklyPredicting; // Set true if segment only reaches + // activationThreshold when including + // unconnected synapses. + +public: + SegmentUpdate(); + + //---------------------------------------------------------------------- + SegmentUpdate(UInt cellIdx, UInt segIdx, bool sequenceSegment, UInt timeStamp, + std::vector synapses = std::vector(), + bool phase1Flag = false, bool weaklyPredicting = false, + Cells4 *cells = nullptr); + + //---------------------------------------------------------------------- + SegmentUpdate(const SegmentUpdate &o); + + //---------------------------------------------------------------------- + SegmentUpdate &operator=(const SegmentUpdate &o) { + _cellIdx = o._cellIdx; + _segIdx = o._segIdx; + _sequenceSegment = o._sequenceSegment; + _synapses = o._synapses; + _timeStamp = o._timeStamp; + _phase1Flag = o._phase1Flag; + _weaklyPredicting = o._weaklyPredicting; + NTA_ASSERT(invariants()); + return *this; + } + + //--------------------------------------------------------------------- + bool isSequenceSegment() const { return _sequenceSegment; } + UInt cellIdx() const { return _cellIdx; } + UInt segIdx() const { return _segIdx; } + UInt timeStamp() const { return _timeStamp; } + UInt operator[](UInt idx) const { return _synapses[idx]; } + const_iterator begin() const { return _synapses.begin(); } + const_iterator end() const { return _synapses.end(); } + UInt size() const { return _synapses.size(); } + bool empty() const { return _synapses.empty(); } + bool isNewSegment() const { return _segIdx == (UInt)-1; } + bool isPhase1Segment() const { return _phase1Flag; } + bool isWeaklyPredicting() const { return _weaklyPredicting; } + + //--------------------------------------------------------------------- + /** + * Checks that all indices are in range, and that the synapse src cell indices + * are unique and sorted. 
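A usage sketch that respects that requirement, assuming SegmentUpdate.hpp is reachable under its usual include path and that the synapse-list parameter is a std::vector<UInt> of source cell indices (illustrative only):

// Illustrative sketch; the include path and the std::vector<UInt> element
// type of the synapse list are assumptions.
#include <cassert>
#include <vector>

#include "nupic/algorithms/SegmentUpdate.hpp"

using nupic::UInt;
using nupic::algorithms::Cells4::SegmentUpdate;

int main() {
  const UInt cellIdx = 7;
  const UInt newSegment = (UInt)-1;          // segIdx of (UInt)-1 means "create"
  const UInt timeStamp = 42;
  std::vector<UInt> srcCells = {3, 17, 42};  // unique and sorted, as required

  SegmentUpdate update(cellIdx, newSegment, /*sequenceSegment*/ true,
                       timeStamp, srcCells);

  assert(update.isNewSegment());
  assert(update.isSequenceSegment());
  assert(update.size() == 3);
  return 0;
}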
+ */ + bool invariants(Cells4 *cells = nullptr) const; + + //--------------------------------------------------------------------- + using Serializable::write; + void write(SegmentUpdateProto::Builder &proto) const override { + proto.setSequenceSegment(_sequenceSegment); + proto.setCellIdx(_cellIdx); + proto.setSegIdx(_segIdx); + proto.setTimestamp(_timeStamp); + auto synapsesProto = proto.initSynapses(_synapses.size()); + for (UInt i = 0; i < _synapses.size(); ++i) { + synapsesProto.set(i, _synapses[i]); + } + proto.setPhase1Flag(_phase1Flag); + proto.setWeaklyPredicting(_weaklyPredicting); + } + + //--------------------------------------------------------------------- + using Serializable::read; + void read(SegmentUpdateProto::Reader &proto) override { + _sequenceSegment = proto.getSequenceSegment(); + _cellIdx = proto.getCellIdx(); + _segIdx = proto.getSegIdx(); + _timeStamp = proto.getTimestamp(); + auto synapsesProto = proto.getSynapses(); + _synapses.resize(synapsesProto.size()); + for (UInt i = 0; i < synapsesProto.size(); ++i) { + _synapses[i] = synapsesProto[i]; + } + _phase1Flag = proto.getPhase1Flag(); + _weaklyPredicting = proto.getWeaklyPredicting(); + } - // End namespace + //--------------------------------------------------------------------- + void save(std::ostream &outStream) const { + outStream << _cellIdx << " " << _segIdx << " " << _phase1Flag << " " + << _sequenceSegment << " " << _weaklyPredicting << " " + << _timeStamp << std::endl; + outStream << _synapses.size() << " "; + for (auto &elem : _synapses) { + outStream << elem << " "; } } + + //--------------------------------------------------------------------- + void load(std::istream &inStream) { + inStream >> _cellIdx >> _segIdx >> _phase1Flag >> _sequenceSegment >> + _weaklyPredicting >> _timeStamp; + UInt n, syn; + inStream >> n; + for (UInt i = 0; i < n; ++i) { + inStream >> syn; + _synapses.push_back(syn); + } + } + + //--------------------------------------------------------------------- + void print(std::ostream &outStream, bool longFormat = false, + UInt nCellsPerCol = 0) const { + if (!longFormat) { + + outStream << 'c' << _cellIdx << " s" << _segIdx + << (_phase1Flag ? " p1 " : " p2 ") + << (_sequenceSegment ? " ss" : " ") + << (_weaklyPredicting ? " wp" : " sp") << " t" << _timeStamp + << '/'; + + } else { + NTA_CHECK(nCellsPerCol > 0); + UInt col = (UInt)(_cellIdx / nCellsPerCol); + UInt cell = _cellIdx - col * nCellsPerCol; + outStream << "cell: " + << "[" << col << "," << cell << "] "; + outStream << " seg: " << _segIdx << (_sequenceSegment ? 
" seqSeg " : " ") + << "timeStamp: " << _timeStamp << " / src cells: "; + } + + // Print out list of source cell indices + for (UInt i = 0; i != _synapses.size(); ++i) + outStream << _synapses[i] << ' '; + } +}; + +//-------------------------------------------------------------------------------- +#ifndef SWIG +inline std::ostream &operator<<(std::ostream &outStream, + const SegmentUpdate &update) { + update.print(outStream); + return outStream; } +#endif + +// End namespace +} // namespace Cells4 +} // namespace algorithms +} // namespace nupic #endif // NTA_SEGMENTUPDATE_HPP diff --git a/src/nupic/algorithms/SpatialPooler.cpp b/src/nupic/algorithms/SpatialPooler.cpp index 20fbc6f255..e6c7812dfc 100644 --- a/src/nupic/algorithms/SpatialPooler.cpp +++ b/src/nupic/algorithms/SpatialPooler.cpp @@ -43,482 +43,327 @@ static const Real PERMANENCE_EPSILON = 0.000001; // MSVC doesn't provide round() which only became standard in C99 or C++11 #if defined(NTA_COMPILER_MSVC) - template - T round(T num) - { - return (num > 0.0) ? floor(num + 0.5) : ceil(num - 0.5); - } +template T round(T num) { + return (num > 0.0) ? floor(num + 0.5) : ceil(num - 0.5); +} #endif - // Round f to 5 digits of precision. This is used to set // permanence values and help avoid small amounts of drift between // platforms/implementations -static Real round5_(const Real f) -{ - Real p = ((Real) ((Int) (f * 100000))) / 100000.0; +static Real round5_(const Real f) { + Real p = ((Real)((Int)(f * 100000))) / 100000.0; return p; } +class CoordinateConverter2D { -class CoordinateConverter2D -{ +public: + CoordinateConverter2D(UInt nrows, UInt ncols) + : // TODO param nrows is unused + ncols_(ncols) {} + UInt toRow(UInt index) { return index / ncols_; }; + UInt toCol(UInt index) { return index % ncols_; }; + UInt toIndex(UInt row, UInt col) { return row * ncols_ + col; }; - public: - CoordinateConverter2D(UInt nrows, UInt ncols) : //TODO param nrows is unused - ncols_(ncols) {} - UInt toRow(UInt index) { return index / ncols_; }; - UInt toCol(UInt index) { return index % ncols_; }; - UInt toIndex(UInt row, UInt col) { return row * ncols_ + col; }; - - private: - UInt ncols_; +private: + UInt ncols_; }; +class CoordinateConverterND { -class CoordinateConverterND -{ +public: + CoordinateConverterND(vector &dimensions) { + dimensions_ = dimensions; + UInt b = 1; + for (Int i = (Int)dimensions.size() - 1; i >= 0; i--) { + bounds_.insert(bounds_.begin(), b); + b *= dimensions[i]; + } + } - public: - CoordinateConverterND(vector& dimensions) - { - dimensions_ = dimensions; - UInt b = 1; - for (Int i = (Int) dimensions.size()-1; i >= 0; i--) - { - bounds_.insert(bounds_.begin(), b); - b *= dimensions[i]; - } + void toCoord(UInt index, vector &coord) { + coord.clear(); + for (UInt i = 0; i < bounds_.size(); i++) { + coord.push_back((index / bounds_[i]) % dimensions_[i]); } + }; - void toCoord(UInt index, vector& coord) - { - coord.clear(); - for (UInt i = 0; i < bounds_.size(); i++) - { - coord.push_back((index / bounds_[i]) % dimensions_[i]); - } - }; - - UInt toIndex(vector& coord) - { - UInt index = 0; - for (UInt i = 0; i < coord.size(); i++) - { - index += coord[i] * bounds_[i]; - } - return index; - }; + UInt toIndex(vector &coord) { + UInt index = 0; + for (UInt i = 0; i < coord.size(); i++) { + index += coord[i] * bounds_[i]; + } + return index; + }; - private: - vector dimensions_; - vector bounds_; +private: + vector dimensions_; + vector bounds_; }; -SpatialPooler::SpatialPooler() -{ +SpatialPooler::SpatialPooler() { // The 
current version number. version_ = 2; } -SpatialPooler::SpatialPooler(vector inputDimensions, - vector columnDimensions, - UInt potentialRadius, - Real potentialPct, - bool globalInhibition, - Real localAreaDensity, - UInt numActiveColumnsPerInhArea, - UInt stimulusThreshold, - Real synPermInactiveDec, - Real synPermActiveInc, - Real synPermConnected, - Real minPctOverlapDutyCycles, - UInt dutyCyclePeriod, - Real boostStrength, - Int seed, - UInt spVerbosity, - bool wrapAround) : SpatialPooler::SpatialPooler() -{ - initialize(inputDimensions, - columnDimensions, - potentialRadius, - potentialPct, - globalInhibition, - localAreaDensity, - numActiveColumnsPerInhArea, - stimulusThreshold, - synPermInactiveDec, - synPermActiveInc, - synPermConnected, - minPctOverlapDutyCycles, - dutyCyclePeriod, - boostStrength, - seed, - spVerbosity, - wrapAround); -} - - -vector SpatialPooler::getColumnDimensions() const -{ +SpatialPooler::SpatialPooler( + vector inputDimensions, vector columnDimensions, + UInt potentialRadius, Real potentialPct, bool globalInhibition, + Real localAreaDensity, UInt numActiveColumnsPerInhArea, + UInt stimulusThreshold, Real synPermInactiveDec, Real synPermActiveInc, + Real synPermConnected, Real minPctOverlapDutyCycles, UInt dutyCyclePeriod, + Real boostStrength, Int seed, UInt spVerbosity, bool wrapAround) + : SpatialPooler::SpatialPooler() { + initialize(inputDimensions, columnDimensions, potentialRadius, potentialPct, + globalInhibition, localAreaDensity, numActiveColumnsPerInhArea, + stimulusThreshold, synPermInactiveDec, synPermActiveInc, + synPermConnected, minPctOverlapDutyCycles, dutyCyclePeriod, + boostStrength, seed, spVerbosity, wrapAround); +} + +vector SpatialPooler::getColumnDimensions() const { return columnDimensions_; } -vector SpatialPooler::getInputDimensions() const -{ +vector SpatialPooler::getInputDimensions() const { return inputDimensions_; } -UInt SpatialPooler::getNumColumns() const -{ - return numColumns_; -} +UInt SpatialPooler::getNumColumns() const { return numColumns_; } -UInt SpatialPooler::getNumInputs() const -{ - return numInputs_; -} +UInt SpatialPooler::getNumInputs() const { return numInputs_; } -UInt SpatialPooler::getPotentialRadius() const -{ - return potentialRadius_; -} +UInt SpatialPooler::getPotentialRadius() const { return potentialRadius_; } -void SpatialPooler::setPotentialRadius(UInt potentialRadius) -{ +void SpatialPooler::setPotentialRadius(UInt potentialRadius) { potentialRadius_ = potentialRadius; } -Real SpatialPooler::getPotentialPct() const -{ - return potentialPct_; -} +Real SpatialPooler::getPotentialPct() const { return potentialPct_; } -void SpatialPooler::setPotentialPct(Real potentialPct) -{ +void SpatialPooler::setPotentialPct(Real potentialPct) { potentialPct_ = potentialPct; } -bool SpatialPooler::getGlobalInhibition() const -{ - return globalInhibition_; -} +bool SpatialPooler::getGlobalInhibition() const { return globalInhibition_; } -void SpatialPooler::setGlobalInhibition(bool globalInhibition) -{ +void SpatialPooler::setGlobalInhibition(bool globalInhibition) { globalInhibition_ = globalInhibition; } -Int SpatialPooler::getNumActiveColumnsPerInhArea() const -{ +Int SpatialPooler::getNumActiveColumnsPerInhArea() const { return numActiveColumnsPerInhArea_; } void SpatialPooler::setNumActiveColumnsPerInhArea( - UInt numActiveColumnsPerInhArea) -{ + UInt numActiveColumnsPerInhArea) { NTA_ASSERT(numActiveColumnsPerInhArea > 0); numActiveColumnsPerInhArea_ = numActiveColumnsPerInhArea; localAreaDensity_ = 0; } 
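These two activation controls are mutually exclusive: selecting a fixed column count zeroes the density, and, as the next setter shows, selecting a density zeroes the fixed count. A small sketch of how a caller observes that, assuming SpatialPooler.hpp is reachable and using made-up example dimensions (not code from the repository):

// Illustrative sketch; include path and dimensions are example assumptions.
#include <cassert>

#include "nupic/algorithms/SpatialPooler.hpp"

using nupic::algorithms::spatial_pooler::SpatialPooler;

int main() {
  SpatialPooler sp(/*inputDimensions*/ {32}, /*columnDimensions*/ {64});

  sp.setNumActiveColumnsPerInhArea(40);             // fixed-count mode
  assert(sp.getLocalAreaDensity() == 0);            // density control disabled

  sp.setLocalAreaDensity(0.02f);                    // density mode
  assert(sp.getNumActiveColumnsPerInhArea() == 0);  // fixed count disabled
  return 0;
}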
-Real SpatialPooler::getLocalAreaDensity() const -{ - return localAreaDensity_; -} +Real SpatialPooler::getLocalAreaDensity() const { return localAreaDensity_; } -void SpatialPooler::setLocalAreaDensity(Real localAreaDensity) -{ +void SpatialPooler::setLocalAreaDensity(Real localAreaDensity) { NTA_ASSERT(localAreaDensity > 0 && localAreaDensity <= 1); localAreaDensity_ = localAreaDensity; numActiveColumnsPerInhArea_ = 0; } -UInt SpatialPooler::getStimulusThreshold() const -{ - return stimulusThreshold_; -} +UInt SpatialPooler::getStimulusThreshold() const { return stimulusThreshold_; } -void SpatialPooler::setStimulusThreshold(UInt stimulusThreshold) -{ +void SpatialPooler::setStimulusThreshold(UInt stimulusThreshold) { stimulusThreshold_ = stimulusThreshold; } -UInt SpatialPooler::getInhibitionRadius() const -{ - return inhibitionRadius_; -} +UInt SpatialPooler::getInhibitionRadius() const { return inhibitionRadius_; } -void SpatialPooler::setInhibitionRadius(UInt inhibitionRadius) -{ +void SpatialPooler::setInhibitionRadius(UInt inhibitionRadius) { inhibitionRadius_ = inhibitionRadius; } -UInt SpatialPooler::getDutyCyclePeriod() const -{ - return dutyCyclePeriod_; -} +UInt SpatialPooler::getDutyCyclePeriod() const { return dutyCyclePeriod_; } -void SpatialPooler::setDutyCyclePeriod(UInt dutyCyclePeriod) -{ +void SpatialPooler::setDutyCyclePeriod(UInt dutyCyclePeriod) { dutyCyclePeriod_ = dutyCyclePeriod; } -Real SpatialPooler::getBoostStrength() const -{ - return boostStrength_; -} +Real SpatialPooler::getBoostStrength() const { return boostStrength_; } -void SpatialPooler::setBoostStrength(Real boostStrength) -{ +void SpatialPooler::setBoostStrength(Real boostStrength) { boostStrength_ = boostStrength; } -UInt SpatialPooler::getIterationNum() const -{ - return iterationNum_; -} +UInt SpatialPooler::getIterationNum() const { return iterationNum_; } -void SpatialPooler::setIterationNum(UInt iterationNum) -{ +void SpatialPooler::setIterationNum(UInt iterationNum) { iterationNum_ = iterationNum; } -UInt SpatialPooler::getIterationLearnNum() const -{ - return iterationLearnNum_; -} +UInt SpatialPooler::getIterationLearnNum() const { return iterationLearnNum_; } -void SpatialPooler::setIterationLearnNum(UInt iterationLearnNum) -{ +void SpatialPooler::setIterationLearnNum(UInt iterationLearnNum) { iterationLearnNum_ = iterationLearnNum; } -UInt SpatialPooler::getSpVerbosity() const -{ - return spVerbosity_; -} +UInt SpatialPooler::getSpVerbosity() const { return spVerbosity_; } -void SpatialPooler::setSpVerbosity(UInt spVerbosity) -{ +void SpatialPooler::setSpVerbosity(UInt spVerbosity) { spVerbosity_ = spVerbosity; } -bool SpatialPooler::getWrapAround() const -{ - return wrapAround_; -} +bool SpatialPooler::getWrapAround() const { return wrapAround_; } -void SpatialPooler::setWrapAround(bool wrapAround) -{ - wrapAround_ = wrapAround; -} +void SpatialPooler::setWrapAround(bool wrapAround) { wrapAround_ = wrapAround; } -UInt SpatialPooler::getUpdatePeriod() const -{ - return updatePeriod_; -} +UInt SpatialPooler::getUpdatePeriod() const { return updatePeriod_; } -void SpatialPooler::setUpdatePeriod(UInt updatePeriod) -{ +void SpatialPooler::setUpdatePeriod(UInt updatePeriod) { updatePeriod_ = updatePeriod; } -Real SpatialPooler::getSynPermTrimThreshold() const -{ +Real SpatialPooler::getSynPermTrimThreshold() const { return synPermTrimThreshold_; } -void SpatialPooler::setSynPermTrimThreshold(Real synPermTrimThreshold) -{ +void SpatialPooler::setSynPermTrimThreshold(Real synPermTrimThreshold) { 
synPermTrimThreshold_ = synPermTrimThreshold; } -Real SpatialPooler::getSynPermActiveInc() const -{ - return synPermActiveInc_; -} +Real SpatialPooler::getSynPermActiveInc() const { return synPermActiveInc_; } -void SpatialPooler::setSynPermActiveInc(Real synPermActiveInc) -{ +void SpatialPooler::setSynPermActiveInc(Real synPermActiveInc) { synPermActiveInc_ = synPermActiveInc; } -Real SpatialPooler::getSynPermInactiveDec() const -{ +Real SpatialPooler::getSynPermInactiveDec() const { return synPermInactiveDec_; } -void SpatialPooler::setSynPermInactiveDec(Real synPermInactiveDec) -{ +void SpatialPooler::setSynPermInactiveDec(Real synPermInactiveDec) { synPermInactiveDec_ = synPermInactiveDec; } -Real SpatialPooler::getSynPermBelowStimulusInc() const -{ +Real SpatialPooler::getSynPermBelowStimulusInc() const { return synPermBelowStimulusInc_; } -void SpatialPooler::setSynPermBelowStimulusInc(Real synPermBelowStimulusInc) -{ +void SpatialPooler::setSynPermBelowStimulusInc(Real synPermBelowStimulusInc) { synPermBelowStimulusInc_ = synPermBelowStimulusInc; } -Real SpatialPooler::getSynPermConnected() const -{ - return synPermConnected_; -} +Real SpatialPooler::getSynPermConnected() const { return synPermConnected_; } -void SpatialPooler::setSynPermConnected(Real synPermConnected) -{ +void SpatialPooler::setSynPermConnected(Real synPermConnected) { synPermConnected_ = synPermConnected; } -Real SpatialPooler::getSynPermMax() const -{ - return synPermMax_; -} +Real SpatialPooler::getSynPermMax() const { return synPermMax_; } -void SpatialPooler::setSynPermMax(Real synPermMax) -{ - synPermMax_ = synPermMax; -} +void SpatialPooler::setSynPermMax(Real synPermMax) { synPermMax_ = synPermMax; } -Real SpatialPooler::getMinPctOverlapDutyCycles() const -{ +Real SpatialPooler::getMinPctOverlapDutyCycles() const { return minPctOverlapDutyCycles_; } -void SpatialPooler::setMinPctOverlapDutyCycles(Real minPctOverlapDutyCycles) -{ +void SpatialPooler::setMinPctOverlapDutyCycles(Real minPctOverlapDutyCycles) { minPctOverlapDutyCycles_ = minPctOverlapDutyCycles; } -void SpatialPooler::getBoostFactors(Real boostFactors[]) const -{ +void SpatialPooler::getBoostFactors(Real boostFactors[]) const { copy(boostFactors_.begin(), boostFactors_.end(), boostFactors); } -void SpatialPooler::setBoostFactors(Real boostFactors[]) -{ +void SpatialPooler::setBoostFactors(Real boostFactors[]) { boostFactors_.assign(&boostFactors[0], &boostFactors[numColumns_]); } -void SpatialPooler::getOverlapDutyCycles(Real overlapDutyCycles[]) const -{ - copy(overlapDutyCycles_.begin(), overlapDutyCycles_.end(), - overlapDutyCycles); +void SpatialPooler::getOverlapDutyCycles(Real overlapDutyCycles[]) const { + copy(overlapDutyCycles_.begin(), overlapDutyCycles_.end(), overlapDutyCycles); } -void SpatialPooler::setOverlapDutyCycles(Real overlapDutyCycles[]) -{ +void SpatialPooler::setOverlapDutyCycles(Real overlapDutyCycles[]) { overlapDutyCycles_.assign(&overlapDutyCycles[0], &overlapDutyCycles[numColumns_]); } -void SpatialPooler::getActiveDutyCycles(Real activeDutyCycles[]) const -{ +void SpatialPooler::getActiveDutyCycles(Real activeDutyCycles[]) const { copy(activeDutyCycles_.begin(), activeDutyCycles_.end(), activeDutyCycles); } -void SpatialPooler::setActiveDutyCycles(Real activeDutyCycles[]) -{ +void SpatialPooler::setActiveDutyCycles(Real activeDutyCycles[]) { activeDutyCycles_.assign(&activeDutyCycles[0], &activeDutyCycles[numColumns_]); } -void SpatialPooler::getMinOverlapDutyCycles(Real minOverlapDutyCycles[]) const -{ +void 
SpatialPooler::getMinOverlapDutyCycles(Real minOverlapDutyCycles[]) const { copy(minOverlapDutyCycles_.begin(), minOverlapDutyCycles_.end(), minOverlapDutyCycles); } -void SpatialPooler::setMinOverlapDutyCycles(Real minOverlapDutyCycles[]) -{ +void SpatialPooler::setMinOverlapDutyCycles(Real minOverlapDutyCycles[]) { minOverlapDutyCycles_.assign(&minOverlapDutyCycles[0], &minOverlapDutyCycles[numColumns_]); } -void SpatialPooler::getPotential(UInt column, UInt potential[]) const -{ +void SpatialPooler::getPotential(UInt column, UInt potential[]) const { NTA_ASSERT(column < numColumns_); potentialPools_.getRow(column, &potential[0], &potential[numInputs_]); } -void SpatialPooler::setPotential(UInt column, UInt potential[]) -{ +void SpatialPooler::setPotential(UInt column, UInt potential[]) { NTA_ASSERT(column < numColumns_); potentialPools_.rowFromDense(column, &potential[0], &potential[numInputs_]); } -void SpatialPooler::getPermanence(UInt column, Real permanences[]) const -{ +void SpatialPooler::getPermanence(UInt column, Real permanences[]) const { NTA_ASSERT(column < numColumns_); permanences_.getRowToDense(column, permanences); } -void SpatialPooler::setPermanence(UInt column, Real permanences[]) -{ +void SpatialPooler::setPermanence(UInt column, Real permanences[]) { NTA_ASSERT(column < numColumns_); vector perm; - perm.assign(&permanences[0],&permanences[numInputs_]); + perm.assign(&permanences[0], &permanences[numInputs_]); updatePermanencesForColumn_(perm, column, false); } -void SpatialPooler::getConnectedSynapses( - UInt column, UInt connectedSynapses[]) const -{ +void SpatialPooler::getConnectedSynapses(UInt column, + UInt connectedSynapses[]) const { NTA_ASSERT(column < numColumns_); - connectedSynapses_.getRow(column,&connectedSynapses[0], + connectedSynapses_.getRow(column, &connectedSynapses[0], &connectedSynapses[numInputs_]); } -void SpatialPooler::getConnectedCounts(UInt connectedCounts[]) const -{ +void SpatialPooler::getConnectedCounts(UInt connectedCounts[]) const { copy(connectedCounts_.begin(), connectedCounts_.end(), connectedCounts); } -const vector& SpatialPooler::getOverlaps() const -{ - return overlaps_; -} +const vector &SpatialPooler::getOverlaps() const { return overlaps_; } -const vector& SpatialPooler::getBoostedOverlaps() const -{ +const vector &SpatialPooler::getBoostedOverlaps() const { return boostedOverlaps_; } -void SpatialPooler::initialize(vector inputDimensions, - vector columnDimensions, - UInt potentialRadius, - Real potentialPct, - bool globalInhibition, - Real localAreaDensity, - UInt numActiveColumnsPerInhArea, - UInt stimulusThreshold, - Real synPermInactiveDec, - Real synPermActiveInc, - Real synPermConnected, - Real minPctOverlapDutyCycles, - UInt dutyCyclePeriod, - Real boostStrength, - Int seed, - UInt spVerbosity, - bool wrapAround) -{ +void SpatialPooler::initialize( + vector inputDimensions, vector columnDimensions, + UInt potentialRadius, Real potentialPct, bool globalInhibition, + Real localAreaDensity, UInt numActiveColumnsPerInhArea, + UInt stimulusThreshold, Real synPermInactiveDec, Real synPermActiveInc, + Real synPermConnected, Real minPctOverlapDutyCycles, UInt dutyCyclePeriod, + Real boostStrength, Int seed, UInt spVerbosity, bool wrapAround) { numInputs_ = 1; inputDimensions_.clear(); - for (auto & inputDimension : inputDimensions) - { + for (auto &inputDimension : inputDimensions) { numInputs_ *= inputDimension; inputDimensions_.push_back(inputDimension); } numColumns_ = 1; columnDimensions_.clear(); - for (auto & 
columnDimension : columnDimensions) - { + for (auto &columnDimension : columnDimensions) { numColumns_ *= columnDimension; columnDimensions_.push_back(columnDimension); } @@ -527,13 +372,13 @@ void SpatialPooler::initialize(vector inputDimensions, NTA_ASSERT(numInputs_ > 0); NTA_ASSERT(inputDimensions_.size() == columnDimensions_.size()); NTA_ASSERT(numActiveColumnsPerInhArea > 0 || - (localAreaDensity > 0 && localAreaDensity <= 0.5)); + (localAreaDensity > 0 && localAreaDensity <= 0.5)); NTA_ASSERT(potentialPct > 0 && potentialPct <= 1); - seed_( (UInt64)(seed < 0 ? rand() : seed) ); + seed_((UInt64)(seed < 0 ? rand() : seed)); - potentialRadius_ = potentialRadius > numInputs_ ? numInputs_ : - potentialRadius; + potentialRadius_ = + potentialRadius > numInputs_ ? numInputs_ : potentialRadius; potentialPct_ = potentialPct; globalInhibition_ = globalInhibition; numActiveColumnsPerInhArea_ = numActiveColumnsPerInhArea; @@ -558,8 +403,7 @@ void SpatialPooler::initialize(vector inputDimensions, iterationLearnNum_ = 0; tieBreaker_.resize(numColumns_); - for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { tieBreaker_[i] = 0.01 * rng_.getReal64(); } @@ -570,7 +414,7 @@ void SpatialPooler::initialize(vector inputDimensions, overlapDutyCycles_.assign(numColumns_, 0); activeDutyCycles_.assign(numColumns_, 0); - minOverlapDutyCycles_.assign(numColumns_, 0.0); + minOverlapDutyCycles_.assign(numColumns_, 0.0); boostFactors_.assign(numColumns_, 1); overlaps_.resize(numColumns_); overlapsPct_.resize(numColumns_); @@ -578,102 +422,80 @@ void SpatialPooler::initialize(vector inputDimensions, inhibitionRadius_ = 0; - for (UInt i = 0; i < numColumns_; ++i) - { + for (UInt i = 0; i < numColumns_; ++i) { vector potential = mapPotential_(i, wrapAround_); vector perm = initPermanence_(potential, initConnectedPct_); - potentialPools_.rowFromDense(i,potential.begin(),potential.end()); - updatePermanencesForColumn_(perm,i,true); + potentialPools_.rowFromDense(i, potential.begin(), potential.end()); + updatePermanencesForColumn_(perm, i, true); } updateInhibitionRadius_(); - if (spVerbosity_ > 0) - { + if (spVerbosity_ > 0) { printParameters(); std::cout << "CPP SP seed = " << seed << std::endl; } } -void SpatialPooler::compute(UInt inputArray[], bool learn, - UInt activeArray[]) -{ +void SpatialPooler::compute(UInt inputArray[], bool learn, UInt activeArray[]) { updateBookeepingVars_(learn); calculateOverlap_(inputArray, overlaps_); calculateOverlapPct_(overlaps_, overlapsPct_); - if (learn) - { + if (learn) { boostOverlaps_(overlaps_, boostedOverlaps_); - } - else - { + } else { boostedOverlaps_.assign(overlaps_.begin(), overlaps_.end()); } inhibitColumns_(boostedOverlaps_, activeColumns_); toDense_(activeColumns_, activeArray, numColumns_); - if (learn) - { + if (learn) { adaptSynapses_(inputArray, activeColumns_); updateDutyCycles_(overlaps_, activeArray); bumpUpWeakColumns_(); updateBoostFactors_(); - if (isUpdateRound_()) - { + if (isUpdateRound_()) { updateInhibitionRadius_(); updateMinDutyCycles_(); } } } -void SpatialPooler::stripUnlearnedColumns(UInt activeArray[]) const -{ - for (UInt i = 0; i < numColumns_; i++) - { - if (activeDutyCycles_[i] == 0) - { +void SpatialPooler::stripUnlearnedColumns(UInt activeArray[]) const { + for (UInt i = 0; i < numColumns_; i++) { + if (activeDutyCycles_[i] == 0) { activeArray[i] = 0; } } } - -void SpatialPooler::toDense_(vector& sparse, - UInt dense[], - UInt n) -{ - std::fill(dense,dense+n, 0); - for (auto & elem : sparse) - { +void 
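toDense_, reformatted above, converts the sparse list of winning column indices into the dense 0/1 activeArray returned by compute(). A standalone sketch of the same conversion (the helper name toDense is local to this example):

#include <algorithm>
#include <iostream>
#include <vector>

// Sparse-to-dense conversion in the style of SpatialPooler::toDense_:
// 'sparse' holds active column indices; 'dense' gets 1 at those positions.
static void toDense(const std::vector<unsigned> &sparse, unsigned dense[], unsigned n) {
  std::fill(dense, dense + n, 0u);
  for (unsigned index : sparse)
    dense[index] = 1u;
}

int main() {
  std::vector<unsigned> active = {1, 4, 7};
  unsigned dense[10];
  toDense(active, dense, 10);
  for (unsigned v : dense)
    std::cout << v << ' ';   // prints 0 1 0 0 1 0 0 1 0 0
  std::cout << '\n';
  return 0;
}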
SpatialPooler::toDense_(vector &sparse, UInt dense[], UInt n) { + std::fill(dense, dense + n, 0); + for (auto &elem : sparse) { UInt index = elem; dense[index] = 1; } } -void SpatialPooler::boostOverlaps_(vector& overlaps, - vector& boosted) -{ - for (UInt i = 0; i < numColumns_; i++) - { +void SpatialPooler::boostOverlaps_(vector &overlaps, + vector &boosted) { + for (UInt i = 0; i < numColumns_; i++) { boosted[i] = overlaps[i] * boostFactors_[i]; } } -UInt SpatialPooler::mapColumn_(UInt column) -{ +UInt SpatialPooler::mapColumn_(UInt column) { vector columnCoords; CoordinateConverterND columnConv(columnDimensions_); columnConv.toCoord(column, columnCoords); vector inputCoords; inputCoords.reserve(columnCoords.size()); - for (UInt i = 0; i < columnCoords.size(); i++) - { - const Real inputCoord = - ((Real)columnCoords[i] + 0.5) * - (inputDimensions_[i] / (Real)columnDimensions_[i]); + for (UInt i = 0; i < columnCoords.size(); i++) { + const Real inputCoord = ((Real)columnCoords[i] + 0.5) * + (inputDimensions_[i] / (Real)columnDimensions_[i]); inputCoords.push_back(floor(inputCoord)); } @@ -682,24 +504,18 @@ UInt SpatialPooler::mapColumn_(UInt column) return inputConv.toIndex(inputCoords); } -vector SpatialPooler::mapPotential_(UInt column, bool wrapAround) -{ +vector SpatialPooler::mapPotential_(UInt column, bool wrapAround) { const UInt centerInput = mapColumn_(column); vector columnInputs; - if (wrapAround) - { + if (wrapAround) { for (UInt input : WrappingNeighborhood(centerInput, potentialRadius_, - inputDimensions_)) - { + inputDimensions_)) { columnInputs.push_back(input); } - } - else - { - for (UInt input : Neighborhood(centerInput, potentialRadius_, - inputDimensions_)) - { + } else { + for (UInt input : + Neighborhood(centerInput, potentialRadius_, inputDimensions_)) { columnInputs.push_back(input); } } @@ -711,45 +527,36 @@ vector SpatialPooler::mapPotential_(UInt column, bool wrapAround) &selectedInputs.front(), numPotential); vector potential(numInputs_, 0); - for (UInt input : selectedInputs) - { + for (UInt input : selectedInputs) { potential[input] = 1; } return potential; } -Real SpatialPooler::initPermConnected_() -{ - Real p = synPermConnected_ + - (synPermMax_ - synPermConnected_)*rng_.getReal64(); +Real SpatialPooler::initPermConnected_() { + Real p = + synPermConnected_ + (synPermMax_ - synPermConnected_) * rng_.getReal64(); return round5_(p); } -Real SpatialPooler::initPermNonConnected_() -{ +Real SpatialPooler::initPermNonConnected_() { Real p = synPermConnected_ * rng_.getReal64(); return round5_(p); } -vector SpatialPooler::initPermanence_(vector& potential, - Real connectedPct) -{ +vector SpatialPooler::initPermanence_(vector &potential, + Real connectedPct) { vector perm(numInputs_, 0); - for (UInt i = 0; i < numInputs_; i++) - { - if (potential[i] < 1) - { + for (UInt i = 0; i < numInputs_; i++) { + if (potential[i] < 1) { continue; } - if (rng_.getReal64() <= connectedPct) - { + if (rng_.getReal64() <= connectedPct) { perm[i] = initPermConnected_(); - } - else - { + } else { perm[i] = initPermNonConnected_(); } perm[i] = perm[i] < synPermTrimThreshold_ ? 0 : perm[i]; @@ -758,37 +565,29 @@ vector SpatialPooler::initPermanence_(vector& potential, return perm; } -void SpatialPooler::clip_(vector& perm, bool trim=false) -{ +void SpatialPooler::clip_(vector &perm, bool trim = false) { Real minVal = trim ? synPermTrimThreshold_ : synPermMin_; - for (auto & elem : perm) - { + for (auto &elem : perm) { elem = elem > synPermMax_ ? 
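mapColumn_ above places each column's receptive-field center by scaling its coordinate into input space and flooring the midpoint, inputCoord = floor((columnCoord + 0.5) * inputDim / columnDim). A one-dimensional standalone sketch of that mapping:

#include <cmath>
#include <iostream>

// Center of a column's receptive field in one input dimension, in the style
// of SpatialPooler::mapColumn_.
static unsigned mapColumn1D(unsigned columnCoord, unsigned numColumns, unsigned numInputs) {
  const double inputCoord =
      (static_cast<double>(columnCoord) + 0.5) *
      (static_cast<double>(numInputs) / static_cast<double>(numColumns));
  return static_cast<unsigned>(std::floor(inputCoord));
}

int main() {
  // 4 columns over 100 inputs map to the centers 12, 37, 62, 87.
  for (unsigned c = 0; c < 4; ++c)
    std::cout << mapColumn1D(c, 4, 100) << ' ';
  std::cout << '\n';
  return 0;
}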
synPermMax_ : elem; elem = elem < minVal ? synPermMin_ : elem; } } - -void SpatialPooler::updatePermanencesForColumn_(vector& perm, - UInt column, - bool raisePerm) -{ +void SpatialPooler::updatePermanencesForColumn_(vector &perm, UInt column, + bool raisePerm) { vector connectedSparse; UInt numConnected; - if (raisePerm) - { + if (raisePerm) { vector potential; potential.resize(numInputs_); potential = potentialPools_.getSparseRow(column); - raisePermanencesToThreshold_(perm,potential); + raisePermanencesToThreshold_(perm, potential); } numConnected = 0; - for (UInt i = 0; i < perm.size(); ++i) - { - if (perm[i] >= synPermConnected_ - PERMANENCE_EPSILON) - { + for (UInt i = 0; i < perm.size(); ++i) { + if (perm[i] >= synPermConnected_ - PERMANENCE_EPSILON) { connectedSparse.push_back(i); ++numConnected; } @@ -801,32 +600,26 @@ void SpatialPooler::updatePermanencesForColumn_(vector& perm, connectedCounts_[column] = numConnected; } -UInt SpatialPooler::countConnected_(vector& perm) -{ +UInt SpatialPooler::countConnected_(vector &perm) { UInt numConnected = 0; - for (auto & elem : perm) - { - if (elem >= synPermConnected_ - PERMANENCE_EPSILON) - { - ++numConnected; - } - } + for (auto &elem : perm) { + if (elem >= synPermConnected_ - PERMANENCE_EPSILON) { + ++numConnected; + } + } return numConnected; } -UInt SpatialPooler::raisePermanencesToThreshold_(vector& perm, - vector& potential) -{ +UInt SpatialPooler::raisePermanencesToThreshold_(vector &perm, + vector &potential) { clip_(perm, false); UInt numConnected; - while (true) - { + while (true) { numConnected = countConnected_(perm); if (numConnected >= stimulusThreshold_) break; - for (auto & elem : potential) - { + for (auto &elem : potential) { UInt index = elem; perm[index] += synPermBelowStimulusInc_; } @@ -834,71 +627,58 @@ UInt SpatialPooler::raisePermanencesToThreshold_(vector& perm, return numConnected; } -void SpatialPooler::updateInhibitionRadius_() -{ - if (globalInhibition_) - { - inhibitionRadius_ = *max_element(columnDimensions_.begin(), - columnDimensions_.end()); +void SpatialPooler::updateInhibitionRadius_() { + if (globalInhibition_) { + inhibitionRadius_ = + *max_element(columnDimensions_.begin(), columnDimensions_.end()); return; } Real connectedSpan = 0; - for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { connectedSpan += avgConnectedSpanForColumnND_(i); } connectedSpan /= numColumns_; Real columnsPerInput = avgColumnsPerInput_(); Real diameter = connectedSpan * columnsPerInput; Real radius = (diameter - 1) / 2.0; - radius = max((Real) 1.0, radius); + radius = max((Real)1.0, radius); inhibitionRadius_ = UInt(round(radius)); } -void SpatialPooler::updateMinDutyCycles_() -{ - if (globalInhibition_ || inhibitionRadius_ > - *max_element(columnDimensions_.begin(), columnDimensions_.end())) - { +void SpatialPooler::updateMinDutyCycles_() { + if (globalInhibition_ || + inhibitionRadius_ > + *max_element(columnDimensions_.begin(), columnDimensions_.end())) { updateMinDutyCyclesGlobal_(); - } - else - { + } else { updateMinDutyCyclesLocal_(); } return; } -void SpatialPooler::updateMinDutyCyclesGlobal_() -{ - Real maxOverlapDutyCycles = *max_element(overlapDutyCycles_.begin(), - overlapDutyCycles_.end()); +void SpatialPooler::updateMinDutyCyclesGlobal_() { + Real maxOverlapDutyCycles = + *max_element(overlapDutyCycles_.begin(), overlapDutyCycles_.end()); fill(minOverlapDutyCycles_.begin(), minOverlapDutyCycles_.end(), minPctOverlapDutyCycles_ * maxOverlapDutyCycles); } -void 
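raisePermanencesToThreshold_ above repeatedly adds synPermBelowStimulusInc to every permanence in the potential pool until at least stimulusThreshold of them reach the connected threshold (compared with a small epsilon, as in the PERMANENCE_EPSILON checks). A standalone sketch with illustrative constants:

#include <iostream>
#include <vector>

int main() {
  const double synPermConnected = 0.10;
  const double synPermBelowStimulusInc = 0.01;
  const double permanenceEpsilon = 1e-6;   // stands in for PERMANENCE_EPSILON
  const unsigned stimulusThreshold = 3;

  // One column's permanences over its potential pool.
  std::vector<double> perm = {0.02, 0.05, 0.08, 0.09, 0.11};

  auto countConnected = [&] {
    unsigned n = 0;
    for (double p : perm)
      if (p >= synPermConnected - permanenceEpsilon)
        ++n;
    return n;
  };

  // Bump every potential synapse until enough of them are connected.
  while (countConnected() < stimulusThreshold)
    for (double &p : perm)
      p += synPermBelowStimulusInc;

  std::cout << countConnected() << " connected synapses\n";  // prints 3
  return 0;
}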
SpatialPooler::updateMinDutyCyclesLocal_() -{ - for (UInt i = 0; i < numColumns_; i++) - { +void SpatialPooler::updateMinDutyCyclesLocal_() { + for (UInt i = 0; i < numColumns_; i++) { Real maxActiveDuty = 0; Real maxOverlapDuty = 0; - if (wrapAround_) - { - for (UInt column : WrappingNeighborhood(i, inhibitionRadius_, - columnDimensions_)) - { + if (wrapAround_) { + for (UInt column : + WrappingNeighborhood(i, inhibitionRadius_, columnDimensions_)) { maxActiveDuty = max(maxActiveDuty, activeDutyCycles_[column]); maxOverlapDuty = max(maxOverlapDuty, overlapDutyCycles_[column]); } - } - else - { - for (UInt column : Neighborhood(i, inhibitionRadius_, columnDimensions_)) - { + } else { + for (UInt column : + Neighborhood(i, inhibitionRadius_, columnDimensions_)) { maxActiveDuty = max(maxActiveDuty, activeDutyCycles_[column]); maxOverlapDuty = max(maxOverlapDuty, overlapDutyCycles_[column]); } @@ -908,31 +688,27 @@ void SpatialPooler::updateMinDutyCyclesLocal_() } } -void SpatialPooler::updateDutyCycles_(vector& overlaps, - UInt activeArray[]) -{ +void SpatialPooler::updateDutyCycles_(vector &overlaps, + UInt activeArray[]) { vector newOverlapVal(numColumns_, 0); vector newActiveVal(numColumns_, 0); - for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { newOverlapVal[i] = overlaps[i] > 0 ? 1 : 0; newActiveVal[i] = activeArray[i] > 0 ? 1 : 0; } - UInt period = dutyCyclePeriod_ > iterationNum_ ? - iterationNum_ : dutyCyclePeriod_; + UInt period = + dutyCyclePeriod_ > iterationNum_ ? iterationNum_ : dutyCyclePeriod_; updateDutyCyclesHelper_(overlapDutyCycles_, newOverlapVal, period); updateDutyCyclesHelper_(activeDutyCycles_, newActiveVal, period); } -Real SpatialPooler::avgColumnsPerInput_() -{ +Real SpatialPooler::avgColumnsPerInput_() { UInt numDim = max(columnDimensions_.size(), inputDimensions_.size()); Real columnsPerInput = 0; - for (UInt i = 0; i < numDim; i++) - { + for (UInt i = 0; i < numDim; i++) { Real col = (i < columnDimensions_.size()) ? columnDimensions_[i] : 1; Real input = (i < inputDimensions_.size()) ? 
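updateDutyCycles_ above clamps the averaging period to min(dutyCyclePeriod, iterationNum) and hands 0/1 activity samples to updateDutyCyclesHelper_ (reformatted further below), which maintains the running average dutyCycle = (dutyCycle * (period - 1) + newValue) / period. A standalone sketch of that update:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Moving-average duty cycle update in the style of updateDutyCyclesHelper_.
static void updateDutyCycles(std::vector<double> &dutyCycles,
                             const std::vector<double> &newValues,
                             unsigned period) {
  for (std::size_t i = 0; i < dutyCycles.size(); ++i)
    dutyCycles[i] = (dutyCycles[i] * (period - 1) + newValues[i]) / period;
}

int main() {
  const unsigned dutyCyclePeriod = 1000;
  std::vector<double> active = {0.0, 0.0};
  for (unsigned iteration = 1; iteration <= 10; ++iteration) {
    unsigned period = std::min(dutyCyclePeriod, iteration);
    updateDutyCycles(active, {1.0, 0.0}, period);  // column 0 fires every step
  }
  std::cout << active[0] << ' ' << active[1] << '\n';  // prints 1 0
  return 0;
}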
inputDimensions_[i] : 1; columnsPerInput += col / input; @@ -940,39 +716,34 @@ Real SpatialPooler::avgColumnsPerInput_() return columnsPerInput / numDim; } -Real SpatialPooler::avgConnectedSpanForColumn1D_(UInt column) -{ +Real SpatialPooler::avgConnectedSpanForColumn1D_(UInt column) { NTA_ASSERT(inputDimensions_.size() == 1); vector connectedSparse = connectedSynapses_.getSparseRow(column); if (connectedSparse.empty()) return 0; - auto minmax = minmax_element(connectedSparse.begin(), - connectedSparse.end()); + auto minmax = minmax_element(connectedSparse.begin(), connectedSparse.end()); return *minmax.second /*max*/ - *minmax.first /*min*/ + 1; } -Real SpatialPooler::avgConnectedSpanForColumn2D_(UInt column) -{ +Real SpatialPooler::avgConnectedSpanForColumn2D_(UInt column) { NTA_ASSERT(inputDimensions_.size() == 2); UInt nrows = inputDimensions_[0]; UInt ncols = inputDimensions_[1]; - CoordinateConverter2D conv(nrows,ncols); + CoordinateConverter2D conv(nrows, ncols); vector connectedSparse = connectedSynapses_.getSparseRow(column); vector rows, cols; - for (auto & elem : connectedSparse) - { + for (auto &elem : connectedSparse) { UInt index = elem; rows.push_back(conv.toRow(index)); cols.push_back(conv.toCol(index)); } - if (rows.empty() && cols.empty()) - { + if (rows.empty() && cols.empty()) { return 0; } @@ -980,14 +751,12 @@ Real SpatialPooler::avgConnectedSpanForColumn2D_(UInt column) UInt rowSpan = *minmaxRows.second /*max*/ - *minmaxRows.first /*min*/ + 1; auto minmaxCols = minmax_element(cols.begin(), cols.end()); - UInt colSpan = *minmaxCols.second - *minmaxCols.first + 1; + UInt colSpan = *minmaxCols.second - *minmaxCols.first + 1; return (rowSpan + colSpan) / 2.0; - } -Real SpatialPooler::avgConnectedSpanForColumnND_(UInt column) -{ +Real SpatialPooler::avgConnectedSpanForColumnND_(UInt column) { UInt numDimensions = inputDimensions_.size(); vector connectedSparse = connectedSynapses_.getSparseRow(column); vector maxCoord(numDimensions, 0); @@ -996,67 +765,54 @@ Real SpatialPooler::avgConnectedSpanForColumnND_(UInt column) CoordinateConverterND conv(inputDimensions_); - if (connectedSparse.empty() ) - { + if (connectedSparse.empty()) { return 0; } vector columnCoord; - for (auto & elem : connectedSparse) - { - conv.toCoord(elem,columnCoord); - for (UInt j = 0; j < columnCoord.size(); j++) - { + for (auto &elem : connectedSparse) { + conv.toCoord(elem, columnCoord); + for (UInt j = 0; j < columnCoord.size(); j++) { maxCoord[j] = max(maxCoord[j], columnCoord[j]); minCoord[j] = min(minCoord[j], columnCoord[j]); } } UInt totalSpan = 0; - for (UInt j = 0; j < inputDimensions_.size(); j++) - { + for (UInt j = 0; j < inputDimensions_.size(); j++) { totalSpan += maxCoord[j] - minCoord[j] + 1; } - return (Real) totalSpan / inputDimensions_.size(); - + return (Real)totalSpan / inputDimensions_.size(); } void SpatialPooler::adaptSynapses_(UInt inputVector[], - vector& activeColumns) -{ + vector &activeColumns) { vector permChanges(numInputs_, -1 * synPermInactiveDec_); - for (UInt i = 0; i < numInputs_; i++) - { - if (inputVector[i] > 0) - { + for (UInt i = 0; i < numInputs_; i++) { + if (inputVector[i] > 0) { permChanges[i] = synPermActiveInc_; } } - for (UInt i = 0; i < activeColumns.size(); i++) - { + for (UInt i = 0; i < activeColumns.size(); i++) { UInt column = activeColumns[i]; vector potential; - vector perm(numInputs_, 0); + vector perm(numInputs_, 0); potential.resize(potentialPools_.nNonZerosOnRow(i)); potential = potentialPools_.getSparseRow(column); 
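adaptSynapses_, which starts in the hunk above and continues below, applies the learning rule: for each active column, permanences toward active inputs grow by synPermActiveInc and permanences toward inactive inputs shrink by synPermInactiveDec. The real code only touches indices in the column's potential pool; this standalone sketch updates a whole dense row for brevity:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const double synPermActiveInc = 0.05;
  const double synPermInactiveDec = 0.008;

  std::vector<unsigned> input = {1, 0, 1, 0};            // dense input vector
  std::vector<double> perm = {0.10, 0.10, 0.02, 0.30};   // one active column's permanences

  for (std::size_t i = 0; i < input.size(); ++i)
    perm[i] += input[i] ? synPermActiveInc : -synPermInactiveDec;

  for (double p : perm)
    std::cout << p << ' ';   // prints 0.15 0.092 0.07 0.292
  std::cout << '\n';
  return 0;
}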
permanences_.getRowToDense(column, perm); - for (auto & elem : potential) - { - UInt index = elem; - perm[index] += permChanges[index]; + for (auto &elem : potential) { + UInt index = elem; + perm[index] += permChanges[index]; } updatePermanencesForColumn_(perm, column, true); } } -void SpatialPooler::bumpUpWeakColumns_() -{ - for (UInt i = 0; i < numColumns_; i++) - { - if (overlapDutyCycles_[i] >= minOverlapDutyCycles_[i]) - { +void SpatialPooler::bumpUpWeakColumns_() { + for (UInt i = 0; i < numColumns_; i++) { + if (overlapDutyCycles_[i] >= minOverlapDutyCycles_[i]) { continue; } vector perm(numInputs_, 0); @@ -1064,8 +820,7 @@ void SpatialPooler::bumpUpWeakColumns_() potential.resize(potentialPools_.nNonZerosOnRow(i)); potential = potentialPools_.getSparseRow(i); permanences_.getRowToDense(i, perm); - for (auto & elem : potential) - { + for (auto &elem : potential) { UInt index = elem; perm[index] += synPermBelowStimulusInc_; } @@ -1073,115 +828,88 @@ void SpatialPooler::bumpUpWeakColumns_() } } -void SpatialPooler::updateDutyCyclesHelper_(vector& dutyCycles, - vector& newValues, - UInt period) -{ +void SpatialPooler::updateDutyCyclesHelper_(vector &dutyCycles, + vector &newValues, + UInt period) { NTA_ASSERT(period >= 1); NTA_ASSERT(dutyCycles.size() == newValues.size()); - for (UInt i = 0; i < dutyCycles.size(); i++) - { + for (UInt i = 0; i < dutyCycles.size(); i++) { dutyCycles[i] = (dutyCycles[i] * (period - 1) + newValues[i]) / period; } } -void SpatialPooler::updateBoostFactors_() -{ - if (globalInhibition_) - { +void SpatialPooler::updateBoostFactors_() { + if (globalInhibition_) { updateBoostFactorsGlobal_(); - } - else - { + } else { updateBoostFactorsLocal_(); } } -void SpatialPooler::updateBoostFactorsGlobal_() -{ +void SpatialPooler::updateBoostFactorsGlobal_() { Real targetDensity; - if (numActiveColumnsPerInhArea_ > 0) - { - UInt inhibitionArea = pow((Real) (2 * inhibitionRadius_ + 1), - (Real) columnDimensions_.size()); + if (numActiveColumnsPerInhArea_ > 0) { + UInt inhibitionArea = + pow((Real)(2 * inhibitionRadius_ + 1), (Real)columnDimensions_.size()); inhibitionArea = min(inhibitionArea, numColumns_); - targetDensity = ((Real) numActiveColumnsPerInhArea_) / inhibitionArea; - targetDensity = min(targetDensity, (Real) 0.5); - } - else - { + targetDensity = ((Real)numActiveColumnsPerInhArea_) / inhibitionArea; + targetDensity = min(targetDensity, (Real)0.5); + } else { targetDensity = localAreaDensity_; } - for (UInt i = 0; i < numColumns_; ++i) - { - boostFactors_[i] = exp((targetDensity - activeDutyCycles_[i]) - * boostStrength_); + for (UInt i = 0; i < numColumns_; ++i) { + boostFactors_[i] = + exp((targetDensity - activeDutyCycles_[i]) * boostStrength_); } } -void SpatialPooler::updateBoostFactorsLocal_() -{ - for (UInt i = 0; i < numColumns_; ++i) - { +void SpatialPooler::updateBoostFactorsLocal_() { + for (UInt i = 0; i < numColumns_; ++i) { UInt numNeighbors = 0; Real localActivityDensity = 0; - if (wrapAround_) - { - for (UInt neighbor : WrappingNeighborhood(i, inhibitionRadius_, - columnDimensions_)) - { + if (wrapAround_) { + for (UInt neighbor : + WrappingNeighborhood(i, inhibitionRadius_, columnDimensions_)) { localActivityDensity += activeDutyCycles_[neighbor]; numNeighbors += 1; } - } - else - { - for (UInt neighbor : Neighborhood(i, inhibitionRadius_, - columnDimensions_)) - { + } else { + for (UInt neighbor : + Neighborhood(i, inhibitionRadius_, columnDimensions_)) { localActivityDensity += activeDutyCycles_[neighbor]; numNeighbors += 1; } } Real 
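updateBoostFactorsGlobal_ and updateBoostFactorsLocal_ above both end in the same rule, boostFactor = exp((targetDensity - activeDutyCycle) * boostStrength). A standalone sketch showing how the factor behaves around the target density:

#include <cmath>
#include <iostream>

// Boosting rule: under-active columns get a factor > 1, over-active columns
// get a factor < 1, and boostStrength = 0 makes every factor exp(0) == 1,
// i.e. boosting is disabled.
static double boostFactor(double targetDensity, double activeDutyCycle,
                          double boostStrength) {
  return std::exp((targetDensity - activeDutyCycle) * boostStrength);
}

int main() {
  const double target = 0.02;  // e.g. roughly 40 of 2048 columns active
  std::cout << boostFactor(target, 0.00, 3.0) << '\n';  // starved column, > 1
  std::cout << boostFactor(target, 0.02, 3.0) << '\n';  // on target, == 1
  std::cout << boostFactor(target, 0.20, 3.0) << '\n';  // over-active, < 1
  std::cout << boostFactor(target, 0.20, 0.0) << '\n';  // boosting disabled, == 1
  return 0;
}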
targetDensity = localActivityDensity / numNeighbors; - boostFactors_[i] = exp((targetDensity - activeDutyCycles_[i]) - * boostStrength_); + boostFactors_[i] = + exp((targetDensity - activeDutyCycles_[i]) * boostStrength_); } - } -void SpatialPooler::updateBookeepingVars_(bool learn) -{ +void SpatialPooler::updateBookeepingVars_(bool learn) { iterationNum_++; - if (learn) - { + if (learn) { iterationLearnNum_++; } } void SpatialPooler::calculateOverlap_(UInt inputVector[], - vector& overlaps) -{ - overlaps.assign(numColumns_,0); - connectedSynapses_.rightVecSumAtNZ(inputVector,inputVector+numInputs_, - overlaps.begin(),overlaps.end()); -} - -void SpatialPooler::calculateOverlapPct_(vector& overlaps, - vector& overlapPct) -{ - overlapPct.assign(numColumns_,0); - for (UInt i = 0; i < numColumns_; i++) - { - if (connectedCounts_[i] != 0) - { - overlapPct[i] = ((Real) overlaps[i]) / connectedCounts_[i]; - } - else - { + vector &overlaps) { + overlaps.assign(numColumns_, 0); + connectedSynapses_.rightVecSumAtNZ(inputVector, inputVector + numInputs_, + overlaps.begin(), overlaps.end()); +} + +void SpatialPooler::calculateOverlapPct_(vector &overlaps, + vector &overlapPct) { + overlapPct.assign(numColumns_, 0); + for (UInt i = 0; i < numColumns_; i++) { + if (connectedCounts_[i] != 0) { + overlapPct[i] = ((Real)overlaps[i]) / connectedCounts_[i]; + } else { // The intent here is to see if a cell matches its input well. // Therefore if nothing is connected the overlapPct is set to 0. overlapPct[i] = 0; @@ -1189,47 +917,37 @@ void SpatialPooler::calculateOverlapPct_(vector& overlaps, } } -void SpatialPooler::inhibitColumns_( - const vector& overlaps, - vector& activeColumns) -{ +void SpatialPooler::inhibitColumns_(const vector &overlaps, + vector &activeColumns) { Real density = localAreaDensity_; - if (numActiveColumnsPerInhArea_ > 0) - { - UInt inhibitionArea = pow((Real) (2 * inhibitionRadius_ + 1), - (Real) columnDimensions_.size()); + if (numActiveColumnsPerInhArea_ > 0) { + UInt inhibitionArea = + pow((Real)(2 * inhibitionRadius_ + 1), (Real)columnDimensions_.size()); inhibitionArea = min(inhibitionArea, numColumns_); - density = ((Real) numActiveColumnsPerInhArea_) / inhibitionArea; - density = min(density, (Real) 0.5); + density = ((Real)numActiveColumnsPerInhArea_) / inhibitionArea; + density = min(density, (Real)0.5); } if (globalInhibition_ || - inhibitionRadius_ > *max_element(columnDimensions_.begin(), - columnDimensions_.end())) - { + inhibitionRadius_ > + *max_element(columnDimensions_.begin(), columnDimensions_.end())) { inhibitColumnsGlobal_(overlaps, density, activeColumns); - } - else - { + } else { inhibitColumnsLocal_(overlaps, density, activeColumns); } } -bool SpatialPooler::isWinner_(Real score, vector >& winners, - UInt numWinners) -{ - if (score < stimulusThreshold_) - { +bool SpatialPooler::isWinner_(Real score, vector> &winners, + UInt numWinners) { + if (score < stimulusThreshold_) { return false; } - if (winners.size() < numWinners) - { + if (winners.size() < numWinners) { return true; } - if (score >= winners[numWinners-1].second) - { + if (score >= winners[numWinners - 1].second) { return true; } @@ -1237,14 +955,10 @@ bool SpatialPooler::isWinner_(Real score, vector >& winners, } void SpatialPooler::addToWinners_(UInt index, Real score, - vector >& winners) -{ + vector> &winners) { pair val = make_pair(index, score); - for (auto it = winners.begin(); - it != winners.end(); it++) - { - if (score >= it->second) - { + for (auto it = winners.begin(); it != winners.end(); 
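inhibitColumns_ above derives the target density from numActiveColumnsPerInhArea when that mode is selected: the inhibition area is (2 * inhibitionRadius + 1) raised to the number of column dimensions, clipped to the total column count, and the resulting density is capped at 0.5. A standalone sketch of that computation:

#include <algorithm>
#include <cmath>
#include <iostream>

static double targetDensity(unsigned numActiveColumnsPerInhArea,
                            unsigned inhibitionRadius, unsigned numDims,
                            unsigned numColumns) {
  double inhibitionArea =
      std::pow(2.0 * inhibitionRadius + 1.0, static_cast<double>(numDims));
  inhibitionArea = std::min(inhibitionArea, static_cast<double>(numColumns));
  double density = numActiveColumnsPerInhArea / inhibitionArea;
  return std::min(density, 0.5);   // density is never allowed above 50%
}

int main() {
  // 10 winners per inhibition area in a 1-D region of 2048 columns.
  std::cout << targetDensity(10, 8, 1, 2048) << '\n';   // 10 / 17, capped at 0.5
  std::cout << targetDensity(10, 40, 1, 2048) << '\n';  // 10 / 81, about 0.123
  return 0;
}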
it++) { + if (score >= it->second) { winners.insert(it, val); return; } @@ -1252,92 +966,70 @@ void SpatialPooler::addToWinners_(UInt index, Real score, winners.push_back(val); } -void SpatialPooler::inhibitColumnsGlobal_( - const vector& overlaps, - Real density, - vector& activeColumns) -{ +void SpatialPooler::inhibitColumnsGlobal_(const vector &overlaps, + Real density, + vector &activeColumns) { activeColumns.clear(); - const UInt numDesired = (UInt) (density * numColumns_); - NTA_CHECK(numDesired > 0) - << "Not enough columns (" << numColumns_ << ") " - << "for desired density (" << density << ")."; - vector > winners; - for (UInt i = 0; i < numColumns_; i++) - { - if (isWinner_(overlaps[i], winners, numDesired)) - { - addToWinners_(i,overlaps[i], winners); + const UInt numDesired = (UInt)(density * numColumns_); + NTA_CHECK(numDesired > 0) << "Not enough columns (" << numColumns_ << ") " + << "for desired density (" << density << ")."; + vector> winners; + for (UInt i = 0; i < numColumns_; i++) { + if (isWinner_(overlaps[i], winners, numDesired)) { + addToWinners_(i, overlaps[i], winners); } } const UInt numActual = min(numDesired, (UInt)winners.size()); - for (UInt i = 0; i < numActual; i++) - { + for (UInt i = 0; i < numActual; i++) { activeColumns.push_back(winners[i].first); } - } -void SpatialPooler::inhibitColumnsLocal_( - const vector& overlaps, - Real density, - vector& activeColumns) -{ +void SpatialPooler::inhibitColumnsLocal_(const vector &overlaps, + Real density, + vector &activeColumns) { activeColumns.clear(); // Tie-breaking: when overlaps are equal, columns that have already been // selected are treated as "bigger". vector activeColumnsDense(numColumns_, false); - for (UInt column = 0; column < numColumns_; column++) - { - if (overlaps[column] >= stimulusThreshold_) - { + for (UInt column = 0; column < numColumns_; column++) { + if (overlaps[column] >= stimulusThreshold_) { UInt numNeighbors = 0; UInt numBigger = 0; - if (wrapAround_) - { + if (wrapAround_) { for (UInt neighbor : WrappingNeighborhood(column, inhibitionRadius_, - columnDimensions_)) - { - if (neighbor != column) - { + columnDimensions_)) { + if (neighbor != column) { numNeighbors++; const Real difference = overlaps[neighbor] - overlaps[column]; if (difference > 0 || - (difference == 0 && activeColumnsDense[neighbor])) - { + (difference == 0 && activeColumnsDense[neighbor])) { numBigger++; } } } - } - else - { - for (UInt neighbor : Neighborhood(column, inhibitionRadius_, - columnDimensions_)) - { - if (neighbor != column) - { + } else { + for (UInt neighbor : + Neighborhood(column, inhibitionRadius_, columnDimensions_)) { + if (neighbor != column) { numNeighbors++; const Real difference = overlaps[neighbor] - overlaps[column]; if (difference > 0 || - (difference == 0 && activeColumnsDense[neighbor])) - { + (difference == 0 && activeColumnsDense[neighbor])) { numBigger++; } } } } - - UInt numActive = (UInt) (0.5 + (density * (numNeighbors + 1))); - if (numBigger < numActive) - { + UInt numActive = (UInt)(0.5 + (density * (numNeighbors + 1))); + if (numBigger < numActive) { activeColumns.push_back(column); activeColumnsDense[column] = true; } @@ -1345,19 +1037,14 @@ void SpatialPooler::inhibitColumnsLocal_( } } -bool SpatialPooler::isUpdateRound_() -{ +bool SpatialPooler::isUpdateRound_() { return (iterationNum_ % updatePeriod_) == 0; } /* create a RNG with given seed */ -void SpatialPooler::seed_(UInt64 seed) -{ - rng_ = Random(seed); -} +void SpatialPooler::seed_(UInt64 seed) { rng_ = Random(seed); } 
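inhibitColumnsLocal_ above lets a column win when fewer than numActive = round(density * (numNeighbors + 1)) of its neighbors have a larger overlap, with ties going to neighbors that are already active. A standalone sketch of that decision for one column, with illustrative counts:

#include <iostream>

int main() {
  const double density = 0.04;
  const unsigned numNeighbors = 80;  // neighbors inside the inhibition radius
  const unsigned numBigger = 2;      // neighbors with a strictly larger overlap
                                     // (or equal overlap and already active)

  const unsigned numActive =
      static_cast<unsigned>(0.5 + density * (numNeighbors + 1));  // round to nearest
  std::cout << "allowed winners in the neighborhood: " << numActive << '\n';  // 3
  std::cout << "column becomes active: " << std::boolalpha
            << (numBigger < numActive) << '\n';                   // true
  return 0;
}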
-UInt SpatialPooler::persistentSize() const -{ +UInt SpatialPooler::persistentSize() const { // TODO: this won't scale! stringstream s; s.flags(ios::scientific); @@ -1366,43 +1053,35 @@ UInt SpatialPooler::persistentSize() const return s.str().size(); } -template -static void saveFloat_(ostream& outStream, FloatType v) -{ +template +static void saveFloat_(ostream &outStream, FloatType v) { outStream << std::setprecision(std::numeric_limits::max_digits10) - << v - << " "; + << v << " "; } -void SpatialPooler::save(ostream& outStream) const -{ +void SpatialPooler::save(ostream &outStream) const { // Write a starting marker and version. outStream << "SpatialPooler" << endl; outStream << version_ << endl; // Store the simple variables first. - outStream << numInputs_ << " " - << numColumns_ << " " - << potentialRadius_ << " "; + outStream << numInputs_ << " " << numColumns_ << " " << potentialRadius_ + << " "; saveFloat_(outStream, potentialPct_); saveFloat_(outStream, initConnectedPct_); - outStream << globalInhibition_ << " " - << numActiveColumnsPerInhArea_ << " "; + outStream << globalInhibition_ << " " << numActiveColumnsPerInhArea_ << " "; saveFloat_(outStream, localAreaDensity_); - outStream << stimulusThreshold_ << " " - << inhibitionRadius_ << " " + outStream << stimulusThreshold_ << " " << inhibitionRadius_ << " " << dutyCyclePeriod_ << " "; saveFloat_(outStream, boostStrength_); - outStream << iterationNum_ << " " - << iterationLearnNum_ << " " - << spVerbosity_ << " " - << updatePeriod_ << " "; + outStream << iterationNum_ << " " << iterationLearnNum_ << " " << spVerbosity_ + << " " << updatePeriod_ << " "; saveFloat_(outStream, synPermMin_); saveFloat_(outStream, synPermMax_); @@ -1413,78 +1092,65 @@ void SpatialPooler::save(ostream& outStream) const saveFloat_(outStream, synPermConnected_); saveFloat_(outStream, minPctOverlapDutyCycles_); - outStream << wrapAround_ << " " - << endl; + outStream << wrapAround_ << " " << endl; // Store vectors. outStream << inputDimensions_.size() << " "; - for (auto & elem : inputDimensions_) - { + for (auto &elem : inputDimensions_) { outStream << elem << " "; } outStream << endl; outStream << columnDimensions_.size() << " "; - for (auto & elem : columnDimensions_) - { + for (auto &elem : columnDimensions_) { outStream << elem << " "; } outStream << endl; - for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { saveFloat_(outStream, boostFactors_[i]); } outStream << endl; - for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { saveFloat_(outStream, overlapDutyCycles_[i]); } outStream << endl; - for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { saveFloat_(outStream, activeDutyCycles_[i]); } outStream << endl; - for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { saveFloat_(outStream, minOverlapDutyCycles_[i]); } outStream << endl; - for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { outStream << tieBreaker_[i] << " "; } outStream << endl; - // Store matrices. 
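saveFloat_ above writes each value with max_digits10 precision so the text produced by save() reads back to exactly the same floating-point value. A standalone sketch of why that precision matters:

#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>

int main() {
  const float original = 0.1f;   // not exactly representable in binary
  std::stringstream s;
  s << std::setprecision(std::numeric_limits<float>::max_digits10) << original << ' ';

  float restored = 0.0f;
  s >> restored;
  std::cout << std::boolalpha << (restored == original) << '\n';  // true: lossless round trip
  return 0;
}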
- for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { vector pot; pot.resize(potentialPools_.nNonZerosOnRow(i)); pot = potentialPools_.getSparseRow(i); outStream << pot.size() << endl; - for (auto & elem : pot) - { + for (auto &elem : pot) { outStream << elem << " "; } outStream << endl; } outStream << endl; - for (UInt i = 0; i < numColumns_; i++) - { - vector > perm; + for (UInt i = 0; i < numColumns_; i++) { + vector> perm; perm.resize(permanences_.nNonZerosOnRow(i)); outStream << perm.size() << endl; permanences_.getRowToSparse(i, perm.begin()); - for (auto & elem : perm) - { + for (auto &elem : perm) { outStream << elem.first << " "; saveFloat_(outStream, elem.second); } @@ -1500,8 +1166,7 @@ void SpatialPooler::save(ostream& outStream) const // Implementation note: this method sets up the instance using data from // inStream. This method does not call initialize. As such we have to be careful // that everything in initialize is handled properly here. -void SpatialPooler::load(istream& inStream) -{ +void SpatialPooler::load(istream &inStream) { // Current version version_ = 2; @@ -1515,39 +1180,19 @@ void SpatialPooler::load(istream& inStream) inStream >> version; NTA_CHECK(version <= version_); - // Retrieve simple variables - inStream >> numInputs_ - >> numColumns_ - >> potentialRadius_ - >> potentialPct_ - >> initConnectedPct_ - >> globalInhibition_ - >> numActiveColumnsPerInhArea_ - >> localAreaDensity_ - >> stimulusThreshold_ - >> inhibitionRadius_ - >> dutyCyclePeriod_ - >> boostStrength_ - >> iterationNum_ - >> iterationLearnNum_ - >> spVerbosity_ - >> updatePeriod_ - - >> synPermMin_ - >> synPermMax_ - >> synPermTrimThreshold_ - >> synPermInactiveDec_ - >> synPermActiveInc_ - >> synPermBelowStimulusInc_ - >> synPermConnected_ - >> minPctOverlapDutyCycles_; - if (version == 1) - { + inStream >> numInputs_ >> numColumns_ >> potentialRadius_ >> potentialPct_ >> + initConnectedPct_ >> globalInhibition_ >> numActiveColumnsPerInhArea_ >> + localAreaDensity_ >> stimulusThreshold_ >> inhibitionRadius_ >> + dutyCyclePeriod_ >> boostStrength_ >> iterationNum_ >> + iterationLearnNum_ >> spVerbosity_ >> updatePeriod_ + + >> synPermMin_ >> synPermMax_ >> synPermTrimThreshold_ >> + synPermInactiveDec_ >> synPermActiveInc_ >> synPermBelowStimulusInc_ >> + synPermConnected_ >> minPctOverlapDutyCycles_; + if (version == 1) { wrapAround_ = true; - } - else - { + } else { inStream >> wrapAround_; } @@ -1555,75 +1200,63 @@ void SpatialPooler::load(istream& inStream) UInt numInputDimensions; inStream >> numInputDimensions; inputDimensions_.resize(numInputDimensions); - for (UInt i = 0; i < numInputDimensions; i++) - { + for (UInt i = 0; i < numInputDimensions; i++) { inStream >> inputDimensions_[i]; } UInt numColumnDimensions; inStream >> numColumnDimensions; columnDimensions_.resize(numColumnDimensions); - for (UInt i = 0; i < numColumnDimensions; i++) - { + for (UInt i = 0; i < numColumnDimensions; i++) { inStream >> columnDimensions_[i]; } boostFactors_.resize(numColumns_); - for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { inStream >> boostFactors_[i]; } overlapDutyCycles_.resize(numColumns_); - for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { inStream >> overlapDutyCycles_[i]; } activeDutyCycles_.resize(numColumns_); - for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { inStream >> activeDutyCycles_[i]; } minOverlapDutyCycles_.resize(numColumns_); - for 
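save() and load(), reformatted above, stream the full pooler state; as the comment notes, load() does not call initialize(), so a default-constructed instance can be restored directly. A minimal round-trip sketch, assuming the nupic headers and library are available:

#include <sstream>
#include "nupic/algorithms/SpatialPooler.hpp"

using nupic::algorithms::spatial_pooler::SpatialPooler;

int main() {
  SpatialPooler sp({100}, {64});

  // Serialize to a stream, then restore into a fresh instance.
  std::stringstream state;
  sp.save(state);

  SpatialPooler restored;
  restored.load(state);
  return (restored.getNumColumns() == sp.getNumColumns()) ? 0 : 1;
}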
(UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { inStream >> minOverlapDutyCycles_[i]; } tieBreaker_.resize(numColumns_); - for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { inStream >> tieBreaker_[i]; } - // Retrieve matrices. potentialPools_.resize(numColumns_, numInputs_); - for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { UInt nNonZerosOnRow; inStream >> nNonZerosOnRow; vector pot(nNonZerosOnRow, 0); - for (UInt j = 0; j < nNonZerosOnRow; j++) - { + for (UInt j = 0; j < nNonZerosOnRow; j++) { inStream >> pot[j]; } - potentialPools_.replaceSparseRow(i,pot.begin(), pot.end()); + potentialPools_.replaceSparseRow(i, pot.begin(), pot.end()); } permanences_.resize(numColumns_, numInputs_); connectedSynapses_.resize(numColumns_, numInputs_); connectedCounts_.resize(numColumns_); - for (UInt i = 0; i < numColumns_; i++) - { + for (UInt i = 0; i < numColumns_; i++) { UInt nNonZerosOnRow; inStream >> nNonZerosOnRow; vector perm(numInputs_, 0); - for (UInt j = 0; j < nNonZerosOnRow; j++) - { + for (UInt j = 0; j < nNonZerosOnRow; j++) { UInt index; Real value; inStream >> index; @@ -1644,22 +1277,19 @@ void SpatialPooler::load(istream& inStream) boostedOverlaps_.resize(numColumns_); } -void SpatialPooler::write(SpatialPoolerProto::Builder& proto) const -{ +void SpatialPooler::write(SpatialPoolerProto::Builder &proto) const { auto random = proto.initRandom(); rng_.write(random); proto.setNumInputs(numInputs_); proto.setNumColumns(numColumns_); auto columnDims = proto.initColumnDimensions(columnDimensions_.size()); - for (UInt i = 0; i < columnDimensions_.size(); ++i) - { + for (UInt i = 0; i < columnDimensions_.size(); ++i) { columnDims.set(i, columnDimensions_[i]); } auto inputDims = proto.initInputDimensions(inputDimensions_.size()); - for (UInt i = 0; i < inputDimensions_.size(); ++i) - { + for (UInt i = 0; i < inputDimensions_.size(); ++i) { inputDims.set(i, inputDimensions_[i]); } @@ -1693,12 +1323,10 @@ void SpatialPooler::write(SpatialPoolerProto::Builder& proto) const potentialPools.setNumRows(numColumns_); potentialPools.setNumColumns(numInputs_); auto potentialPoolIndices = potentialPools.initIndices(numColumns_); - for (UInt i = 0; i < numColumns_; ++i) - { - auto & pot = potentialPools_.getSparseRow(i); + for (UInt i = 0; i < numColumns_; ++i) { + auto &pot = potentialPools_.getSparseRow(i); auto indices = potentialPoolIndices.init(i, pot.size()); - for (UInt j = 0; j < pot.size(); ++j) - { + for (UInt j = 0; j < pot.size(); ++j) { indices.set(j, pot[j]); } } @@ -1707,32 +1335,27 @@ void SpatialPooler::write(SpatialPoolerProto::Builder& proto) const permanences_.write(permanences); auto tieBreaker = proto.initTieBreaker(numColumns_); - for (UInt i = 0; i < numColumns_; ++i) - { + for (UInt i = 0; i < numColumns_; ++i) { tieBreaker.set(i, tieBreaker_[i]); } auto overlapDutyCycles = proto.initOverlapDutyCycles(numColumns_); - for (UInt i = 0; i < numColumns_; ++i) - { + for (UInt i = 0; i < numColumns_; ++i) { overlapDutyCycles.set(i, overlapDutyCycles_[i]); } auto activeDutyCycles = proto.initActiveDutyCycles(numColumns_); - for (UInt i = 0; i < numColumns_; ++i) - { + for (UInt i = 0; i < numColumns_; ++i) { activeDutyCycles.set(i, activeDutyCycles_[i]); } auto minOverlapDutyCycles = proto.initMinOverlapDutyCycles(numColumns_); - for (UInt i = 0; i < numColumns_; ++i) - { + for (UInt i = 0; i < numColumns_; ++i) { minOverlapDutyCycles.set(i, minOverlapDutyCycles_[i]); } auto 
boostFactors = proto.initBoostFactors(numColumns_); - for (UInt i = 0; i < numColumns_; ++i) - { + for (UInt i = 0; i < numColumns_; ++i) { boostFactors.set(i, boostFactors_[i]); } } @@ -1740,22 +1363,19 @@ void SpatialPooler::write(SpatialPoolerProto::Builder& proto) const // Implementation note: this method sets up the instance using data from // proto. This method does not call initialize. As such we have to be careful // that everything in initialize is handled properly here. -void SpatialPooler::read(SpatialPoolerProto::Reader& proto) -{ +void SpatialPooler::read(SpatialPoolerProto::Reader &proto) { auto randomProto = proto.getRandom(); rng_.read(randomProto); numInputs_ = proto.getNumInputs(); numColumns_ = proto.getNumColumns(); columnDimensions_.clear(); - for (UInt dimension : proto.getColumnDimensions()) - { + for (UInt dimension : proto.getColumnDimensions()) { columnDimensions_.push_back(dimension); } inputDimensions_.clear(); - for (UInt dimension : proto.getInputDimensions()) - { + for (UInt dimension : proto.getInputDimensions()) { inputDimensions_.push_back(dimension); } @@ -1798,43 +1418,36 @@ void SpatialPooler::read(SpatialPoolerProto::Reader& proto) auto permanences = proto.getPermanences(); permanences_.resize(permanences.getNumRows(), permanences.getNumColumns()); auto permanenceValues = permanences.getRows(); - for (UInt i = 0; i < numColumns_; ++i) - { + for (UInt i = 0; i < numColumns_; ++i) { vector colPerms(numInputs_, 0); - for (auto perm : permanenceValues[i].getValues()) - { + for (auto perm : permanenceValues[i].getValues()) { colPerms[perm.getIndex()] = perm.getValue(); } updatePermanencesForColumn_(colPerms, i, false); } tieBreaker_.clear(); - for (auto value : proto.getTieBreaker()) - { + for (auto value : proto.getTieBreaker()) { tieBreaker_.push_back(value); } overlapDutyCycles_.clear(); - for (auto value : proto.getOverlapDutyCycles()) - { + for (auto value : proto.getOverlapDutyCycles()) { overlapDutyCycles_.push_back(value); } activeDutyCycles_.clear(); - for (auto value : proto.getActiveDutyCycles()) - { + for (auto value : proto.getActiveDutyCycles()) { activeDutyCycles_.push_back(value); } minOverlapDutyCycles_.clear(); - for (auto value : proto.getMinOverlapDutyCycles()) - { + for (auto value : proto.getMinOverlapDutyCycles()) { minOverlapDutyCycles_.push_back(value); } boostFactors_.clear(); - for (auto value : proto.getBoostFactors()) - { + for (auto value : proto.getBoostFactors()) { boostFactors_.push_back(value); } @@ -1849,39 +1462,36 @@ void SpatialPooler::read(SpatialPoolerProto::Reader& proto) //---------------------------------------------------------------------- // Print the main SP creation parameters -void SpatialPooler::printParameters() const -{ +void SpatialPooler::printParameters() const { std::cout << "------------CPP SpatialPooler Parameters ------------------\n"; std::cout - << "iterationNum = " << getIterationNum() << std::endl - << "iterationLearnNum = " << getIterationLearnNum() << std::endl - << "numInputs = " << getNumInputs() << std::endl - << "numColumns = " << getNumColumns() << std::endl - << "numActiveColumnsPerInhArea = " - << getNumActiveColumnsPerInhArea() << std::endl - << "potentialPct = " << getPotentialPct() << std::endl - << "globalInhibition = " << getGlobalInhibition() << std::endl - << "localAreaDensity = " << getLocalAreaDensity() << std::endl - << "stimulusThreshold = " << getStimulusThreshold() << std::endl - << "synPermActiveInc = " << getSynPermActiveInc() << std::endl - << "synPermInactiveDec = " << 
getSynPermInactiveDec() << std::endl - << "synPermConnected = " << getSynPermConnected() << std::endl - << "minPctOverlapDutyCycles = " - << getMinPctOverlapDutyCycles() << std::endl - << "dutyCyclePeriod = " << getDutyCyclePeriod() << std::endl - << "boostStrength = " << getBoostStrength() << std::endl - << "spVerbosity = " << getSpVerbosity() << std::endl - << "wrapAround = " << getWrapAround() << std::endl - << "version = " << version() << std::endl; -} - -void SpatialPooler::printState(vector &state) -{ + << "iterationNum = " << getIterationNum() << std::endl + << "iterationLearnNum = " << getIterationLearnNum() << std::endl + << "numInputs = " << getNumInputs() << std::endl + << "numColumns = " << getNumColumns() << std::endl + << "numActiveColumnsPerInhArea = " << getNumActiveColumnsPerInhArea() + << std::endl + << "potentialPct = " << getPotentialPct() << std::endl + << "globalInhibition = " << getGlobalInhibition() << std::endl + << "localAreaDensity = " << getLocalAreaDensity() << std::endl + << "stimulusThreshold = " << getStimulusThreshold() << std::endl + << "synPermActiveInc = " << getSynPermActiveInc() << std::endl + << "synPermInactiveDec = " << getSynPermInactiveDec() + << std::endl + << "synPermConnected = " << getSynPermConnected() << std::endl + << "minPctOverlapDutyCycles = " << getMinPctOverlapDutyCycles() + << std::endl + << "dutyCyclePeriod = " << getDutyCyclePeriod() << std::endl + << "boostStrength = " << getBoostStrength() << std::endl + << "spVerbosity = " << getSpVerbosity() << std::endl + << "wrapAround = " << getWrapAround() << std::endl + << "version = " << version() << std::endl; +} + +void SpatialPooler::printState(vector &state) { std::cout << "[ "; - for (UInt i = 0; i != state.size(); ++i) - { - if (i > 0 && i % 10 == 0) - { + for (UInt i = 0; i != state.size(); ++i) { + if (i > 0 && i % 10 == 0) { std::cout << "\n "; } std::cout << state[i] << " "; @@ -1889,17 +1499,13 @@ void SpatialPooler::printState(vector &state) std::cout << "]\n"; } -void SpatialPooler::printState(vector &state) -{ +void SpatialPooler::printState(vector &state) { std::cout << "[ "; - for (UInt i = 0; i != state.size(); ++i) - { - if (i > 0 && i % 10 == 0) - { + for (UInt i = 0; i != state.size(); ++i) { + if (i > 0 && i % 10 == 0) { std::cout << "\n "; } std::printf("%6.3f ", state[i]); } std::cout << "]\n"; } - diff --git a/src/nupic/algorithms/SpatialPooler.hpp b/src/nupic/algorithms/SpatialPooler.hpp index ed9c7a4c02..59876f0904 100644 --- a/src/nupic/algorithms/SpatialPooler.hpp +++ b/src/nupic/algorithms/SpatialPooler.hpp @@ -27,1303 +27,1269 @@ #ifndef NTA_spatial_pooler_HPP #define NTA_spatial_pooler_HPP +#include #include #include -#include -#include -#include #include #include #include #include #include +#include +#include using namespace std; -namespace nupic -{ - namespace algorithms - { - namespace spatial_pooler - { - - /** - * CLA spatial pooler implementation in C++. - * - * ### Description - * The Spatial Pooler is responsible for creating a sparse distributed - * representation of the input. Given an input it computes a set of sparse - * active columns and simultaneously updates its permanences, duty cycles, - * etc. - * - * The primary public interfaces to this function are the "initialize" - * and "compute" methods. 
- * - * Example usage: - * - * SpatialPooler sp; - * sp.initialize(inputDimensions, columnDimensions, ); - * while (true) { - * - * sp.compute(inputVector, learn, activeColumns) - * - * } - * - */ - class SpatialPooler : public Serializable - { - public: - SpatialPooler(); - SpatialPooler(vector inputDimensions, - vector columnDimensions, - UInt potentialRadius=16, - Real potentialPct=0.5, - bool globalInhibition=true, - Real localAreaDensity=-1.0, - UInt numActiveColumnsPerInhArea=10, - UInt stimulusThreshold=0, - Real synPermInactiveDec=0.008, - Real synPermActiveInc=0.05, - Real synPermConnected=0.1, - Real minPctOverlapDutyCycles=0.001, - UInt dutyCyclePeriod=1000, - Real boostStrength=0.0, - Int seed=1, - UInt spVerbosity=0, - bool wrapAround=true); - - - virtual ~SpatialPooler() {} - - /** - Initialize the spatial pooler using the given parameters. - - @param inputDimensions A list of integers representing the - dimensions of the input vector. Format is [height, width, - depth, ...], where each value represents the size of the - dimension. For a topology of one dimesion with 100 inputs - use [100]. For a two dimensional topology of 10x5 - use [10,5]. - - @param columnDimensions A list of integers representing the - dimensions of the columns in the region. Format is [height, - width, depth, ...], where each value represents the size of - the dimension. For a topology of one dimesion with 2000 - columns use 2000, or [2000]. For a three dimensional - topology of 32x64x16 use [32, 64, 16]. - - @param potentialRadius This parameter deteremines the extent of the - input that each column can potentially be connected to. This - can be thought of as the input bits that are visible to each - column, or a 'receptive field' of the field of vision. A large - enough value will result in global coverage, meaning - that each column can potentially be connected to every input - bit. This parameter defines a square (or hyper square) area: a - column will have a max square potential pool with sides of - length (2 * potentialRadius + 1). - - @param potentialPct The percent of the inputs, within a column's - potential radius, that a column can be connected to. If set to - 1, the column will be connected to every input within its - potential radius. This parameter is used to give each column a - unique potential pool when a large potentialRadius causes - overlap between the columns. At initialization time we choose - ((2*potentialRadius + 1)^(# inputDimensions) * potentialPct) - input bits to comprise the column's potential pool. - - @param globalInhibition If true, then during inhibition phase the - winning columns are selected as the most active columns from the - region as a whole. Otherwise, the winning columns are selected - with resepct to their local neighborhoods. Global inhibition - boosts performance significantly but there is no topology at the - output. - - @param localAreaDensity The desired density of active columns within - a local inhibition area (the size of which is set by the - internally calculated inhibitionRadius, which is in turn - determined from the average size of the connected potential - pools of all columns). The inhibition logic will insure that at - most N columns remain ON within a local inhibition area, where - N = localAreaDensity * (total number of columns in inhibition - area). If localAreaDensity is set to a negative value output - sparsity will be determined by the numActivePerInhArea. 
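The schematic usage example in the class comment above can be fleshed out as follows; this is a hedged sketch that assumes the nupic headers are available, and the dimensions, epoch count, and random 0/1 input are arbitrary choices:

#include <algorithm>
#include <cstdlib>
#include <vector>
#include "nupic/algorithms/SpatialPooler.hpp"

using nupic::UInt;
using nupic::algorithms::spatial_pooler::SpatialPooler;

int main() {
  const UInt numInputs = 500;
  const UInt numColumns = 2048;

  SpatialPooler sp;
  sp.initialize({numInputs}, {numColumns});  // remaining parameters keep their defaults

  std::vector<UInt> input(numInputs);        // dense 0/1 input, length getNumInputs()
  std::vector<UInt> active(numColumns);      // dense 0/1 output, length getNumColumns()
  for (UInt epoch = 0; epoch < 100; ++epoch) {
    std::generate(input.begin(), input.end(), [] { return std::rand() % 2; });
    sp.compute(input.data(), /*learn=*/true, active.data());
  }
  return 0;
}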
- - @param numActiveColumnsPerInhArea An alternate way to control the sparsity of - active columns. If numActivePerInhArea is specified then - localAreaDensity must be less than 0, and vice versa. When - numActivePerInhArea > 0, the inhibition logic will insure that - at most 'numActivePerInhArea' columns remain ON within a local - inhibition area (the size of which is set by the internally - calculated inhibitionRadius). When using this method, as columns - learn and grow their effective receptive fields, the - inhibitionRadius will grow, and hence the net density of the - active columns will *decrease*. This is in contrast to the - localAreaDensity method, which keeps the density of active - columns the same regardless of the size of their receptive - fields. - - @param stimulusThreshold This is a number specifying the minimum - number of synapses that must be active in order for a column to - turn ON. The purpose of this is to prevent noisy input from - activating columns. - - @param synPermInactiveDec The amount by which the permanence of an - inactive synapse is decremented in each learning step. - - @param synPermActiveInc The amount by which the permanence of an - active synapse is incremented in each round. - - @param synPermConnected The default connected threshold. Any synapse - whose permanence value is above the connected threshold is - a "connected synapse", meaning it can contribute to - the cell's firing. - - @param minPctOverlapDutyCycles A number between 0 and 1.0, used to set - a floor on how often a column should have at least - stimulusThreshold active inputs. Periodically, each column looks - at the overlap duty cycle of all other column within its - inhibition radius and sets its own internal minimal acceptable - duty cycle to: minPctDutyCycleBeforeInh * max(other columns' - duty cycles). On each iteration, any column whose overlap duty - cycle falls below this computed value will get all of its - permanence values boosted up by synPermActiveInc. Raising all - permanences in response to a sub-par duty cycle before - inhibition allows a cell to search for new inputs when either - its previously learned inputs are no longer ever active, or when - the vast majority of them have been "hijacked" by other columns. - - @param dutyCyclePeriod The period used to calculate duty cycles. - Higher values make it take longer to respond to changes in - boost. Shorter values make it potentially more unstable and - likely to oscillate. - - @param boostStrength A number greater or equal than 0, used to - control boosting strength. No boosting is applied if it is set to 0. - The strength of boosting increases as a function of boostStrength. - Boosting encourages columns to have similar activeDutyCycles as their - neighbors, which will lead to more efficient use of columns. However, - too much boosting may also lead to instability of SP outputs. - - - @param seed Seed for our random number generator. If seed is < 0 - a randomly generated seed is used. The behavior of the spatial - pooler is deterministic once the seed is set. - - @param spVerbosity spVerbosity level: 0, 1, 2, or 3 - - @param wrapAround boolean value that determines whether or not inputs - at the beginning and end of an input dimension are considered - neighbors for the purpose of mapping inputs to columns. 
- - */ - virtual void initialize(vector inputDimensions, - vector columnDimensions, - UInt potentialRadius=16, - Real potentialPct=0.5, - bool globalInhibition=true, - Real localAreaDensity=-1.0, - UInt numActiveColumnsPerInhArea=10, - UInt stimulusThreshold=0, - Real synPermInactiveDec=0.01, - Real synPermActiveInc=0.1, - Real synPermConnected=0.1, - Real minPctOverlapDutyCycles=0.001, - UInt dutyCyclePeriod=1000, - Real boostStrength=0.0, - Int seed=1, - UInt spVerbosity=0, - bool wrapAround=true); - - /** - This is the main workshorse method of the SpatialPooler class. This - method takes an input vector and computes the set of output active - columns. If 'learn' is set to True, this method also performs - learning. - - @param inputVector An array of integer 0's and 1's that comprises - the input to the spatial pooler. The length of the - array must match the total number of input bits implied by - the constructor (also returned by the method getNumInputs). In - cases where the input is multi-dimensional, inputVector is a - flattened array of inputs. - - @param learn A boolean value indicating whether learning should be - performed. Learning entails updating the permanence values of - the synapses, duty cycles, etc. Learning is typically on but - setting learning to 'off' is useful for analyzing the current - state of the SP. For example, you might want to feed in various - inputs and examine the resulting SDR's. Note that if learning - is off, boosting is turned off and columns that have never won - will be removed from activeVector. TODO: we may want to keep - boosting on even when learning is off. - - @param activeVector An array representing the winning columns after - inhibition. The size of the array is equal to the number of - columns (also returned by the method getNumColumns). This array - will be populated with 1's at the indices of the active columns, - and 0's everywhere else. In the case where the output is - multi-dimensional, activeVector represents a flattened array - of outputs. - */ - virtual void compute(UInt inputVector[], bool learn, - UInt activeVector[]); - - /** - Removes the set of columns who have never been active from the set - of active columns selected in the inhibition round. Such columns - cannot represent learned pattern and are therefore meaningless if - only inference is required. - - @param activeArray An array of 1's and 0's representing winning - columns calculated by the 'compute' method after disabling - any columns that are not learned. - */ - void stripUnlearnedColumns(UInt activeArray[]) const; - - /** - * Get the version number of this spatial pooler. - - * @returns Integer version number. - */ - virtual UInt version() const { - return version_; - }; - - /** - Save (serialize) the current state of the spatial pooler to the - specified file. - - @param fd A valid file descriptor. - */ - virtual void save(ostream& outStream) const; - - using Serializable::write; - virtual void write(SpatialPoolerProto::Builder& proto) const override; - - /** - Load (deserialize) and initialize the spatial pooler from the - specified input stream. - - @param inStream A valid istream. - */ - virtual void load(istream& inStream); - - using Serializable::read; - virtual void read(SpatialPoolerProto::Reader& proto) override; - - /** - Returns the number of bytes that a save operation would result in. - Note: this method is currently somewhat inefficient as it just does - a full save into an ostream and counts the resulting size. 
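stripUnlearnedColumns(), documented above, clears any winning column whose active duty cycle is still zero, which is mainly useful when compute() is run with learn set to false. A standalone sketch of the effect on a small activeArray:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<unsigned> activeArray    = {1, 0, 1, 1};
  std::vector<double>   activeDutyCyc  = {0.02, 0.0, 0.0, 0.15};

  // Columns that have never won (duty cycle 0) cannot represent a learned
  // pattern, so they are stripped from the output.
  for (std::size_t i = 0; i < activeArray.size(); ++i)
    if (activeDutyCyc[i] == 0.0)
      activeArray[i] = 0;

  for (unsigned v : activeArray)
    std::cout << v << ' ';   // prints 1 0 0 1
  std::cout << '\n';
  return 0;
}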
- - @returns Integer number of bytes - */ - virtual UInt persistentSize() const; - - /** - Returns the dimensions of the columns in the region. - - @returns Integer number of column dimension. - */ - vector getColumnDimensions() const; - - /** - Returns the dimensions of the input vector. - - @returns Integer vector of input dimension. - */ - vector getInputDimensions() const; - - /** - Returns the total number of columns. - - @returns Integer number of column numbers. - */ - UInt getNumColumns() const; - - /** - Returns the total number of inputs. - - @returns Integer number of inputs. - */ - UInt getNumInputs() const; - - /** - Returns the potential radius. - - @returns Integer number of potential radius. - */ - UInt getPotentialRadius() const; - - /** - Sets the potential radius. - - @param potentialRadius integer number of potential raduis. - */ - void setPotentialRadius(UInt potentialRadius); - /** - Returns the potential percent. +namespace nupic { +namespace algorithms { +namespace spatial_pooler { - @returns real number of the potential percent. - */ - Real getPotentialPct() const; - - /** - Sets the potential percent. - - @param potentialPct real number of potential percent. - */ - void setPotentialPct(Real potentialPct); +/** + * CLA spatial pooler implementation in C++. + * + * ### Description + * The Spatial Pooler is responsible for creating a sparse distributed + * representation of the input. Given an input it computes a set of sparse + * active columns and simultaneously updates its permanences, duty cycles, + * etc. + * + * The primary public interfaces to this function are the "initialize" + * and "compute" methods. + * + * Example usage: + * + * SpatialPooler sp; + * sp.initialize(inputDimensions, columnDimensions, ); + * while (true) { + * + * sp.compute(inputVector, learn, activeColumns) + * + * } + * + */ +class SpatialPooler : public Serializable { +public: + SpatialPooler(); + SpatialPooler(vector inputDimensions, vector columnDimensions, + UInt potentialRadius = 16, Real potentialPct = 0.5, + bool globalInhibition = true, Real localAreaDensity = -1.0, + UInt numActiveColumnsPerInhArea = 10, + UInt stimulusThreshold = 0, Real synPermInactiveDec = 0.008, + Real synPermActiveInc = 0.05, Real synPermConnected = 0.1, + Real minPctOverlapDutyCycles = 0.001, + UInt dutyCyclePeriod = 1000, Real boostStrength = 0.0, + Int seed = 1, UInt spVerbosity = 0, bool wrapAround = true); + + virtual ~SpatialPooler() {} + + /** + Initialize the spatial pooler using the given parameters. + + @param inputDimensions A list of integers representing the + dimensions of the input vector. Format is [height, width, + depth, ...], where each value represents the size of the + dimension. For a topology of one dimesion with 100 inputs + use [100]. For a two dimensional topology of 10x5 + use [10,5]. + + @param columnDimensions A list of integers representing the + dimensions of the columns in the region. Format is [height, + width, depth, ...], where each value represents the size of + the dimension. For a topology of one dimesion with 2000 + columns use 2000, or [2000]. For a three dimensional + topology of 32x64x16 use [32, 64, 16]. + + @param potentialRadius This parameter deteremines the extent of the + input that each column can potentially be connected to. This + can be thought of as the input bits that are visible to each + column, or a 'receptive field' of the field of vision. 
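The "Example usage" block above lost its angle-bracket placeholders to formatting, so here is a compilable sketch with the missing pieces (input acquisition and output handling) filled in by stand-ins; the dimensions, step count, and the sliding bit pattern are invented purely for illustration.

// A compilable version of the usage pattern sketched in the class comment.
#include <algorithm>
#include <vector>
#include "nupic/algorithms/SpatialPooler.hpp"

using nupic::UInt;
using nupic::algorithms::spatial_pooler::SpatialPooler;

int main() {
  const UInt numInputs = 500;
  const UInt numColumns = 1024;

  SpatialPooler sp;
  sp.initialize({numInputs}, {numColumns});

  std::vector<UInt> input(numInputs, 0);
  std::vector<UInt> activeColumns(numColumns, 0);

  for (UInt step = 0; step < 100; step++) {
    // <get input vector>: here, a sliding block of ON bits.
    std::fill(input.begin(), input.end(), 0);
    for (UInt i = 0; i < 20; i++)
      input[(step + i) % numInputs] = 1;

    // compute() writes a 0/1 activation per column into activeColumns.
    sp.compute(input.data(), /*learn=*/true, activeColumns.data());

    // <do something with output>: e.g. hand activeColumns to a temporal
    // memory or a classifier (omitted).
  }
  return 0;
}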
A large + enough value will result in global coverage, meaning + that each column can potentially be connected to every input + bit. This parameter defines a square (or hyper square) area: a + column will have a max square potential pool with sides of + length (2 * potentialRadius + 1). + + @param potentialPct The percent of the inputs, within a column's + potential radius, that a column can be connected to. If set to + 1, the column will be connected to every input within its + potential radius. This parameter is used to give each column a + unique potential pool when a large potentialRadius causes + overlap between the columns. At initialization time we choose + ((2*potentialRadius + 1)^(# inputDimensions) * potentialPct) + input bits to comprise the column's potential pool. + + @param globalInhibition If true, then during inhibition phase the + winning columns are selected as the most active columns from the + region as a whole. Otherwise, the winning columns are selected + with respect to their local neighborhoods. Global inhibition + boosts performance significantly but there is no topology at the + output. + + @param localAreaDensity The desired density of active columns within + a local inhibition area (the size of which is set by the + internally calculated inhibitionRadius, which is in turn + determined from the average size of the connected potential + pools of all columns). The inhibition logic will ensure that at + most N columns remain ON within a local inhibition area, where + N = localAreaDensity * (total number of columns in inhibition + area). If localAreaDensity is set to a negative value, output + sparsity will be determined by numActivePerInhArea. + + @param numActiveColumnsPerInhArea An alternate way to control the sparsity of + active columns. If numActivePerInhArea is specified then + localAreaDensity must be less than 0, and vice versa. When + numActivePerInhArea > 0, the inhibition logic will ensure that + at most 'numActivePerInhArea' columns remain ON within a local + inhibition area (the size of which is set by the internally + calculated inhibitionRadius). When using this method, as columns + learn and grow their effective receptive fields, the + inhibitionRadius will grow, and hence the net density of the + active columns will *decrease*. This is in contrast to the + localAreaDensity method, which keeps the density of active + columns the same regardless of the size of their receptive + fields. + + @param stimulusThreshold This is a number specifying the minimum + number of synapses that must be active in order for a column to + turn ON. The purpose of this is to prevent noisy input from + activating columns. + + @param synPermInactiveDec The amount by which the permanence of an + inactive synapse is decremented in each learning step. + + @param synPermActiveInc The amount by which the permanence of an + active synapse is incremented in each round. + + @param synPermConnected The default connected threshold. Any synapse + whose permanence value is above the connected threshold is + a "connected synapse", meaning it can contribute to + the cell's firing. + + @param minPctOverlapDutyCycles A number between 0 and 1.0, used to set + a floor on how often a column should have at least + stimulusThreshold active inputs. Periodically, each column looks + at the overlap duty cycle of all other columns within its + inhibition radius and sets its own internal minimal acceptable + duty cycle to: minPctDutyCycleBeforeInh * max(other columns' + duty cycles).
On each iteration, any column whose overlap duty + cycle falls below this computed value will get all of its + permanence values boosted up by synPermActiveInc. Raising all + permanences in response to a sub-par duty cycle before + inhibition allows a cell to search for new inputs when either + its previously learned inputs are no longer ever active, or when + the vast majority of them have been "hijacked" by other columns. + + @param dutyCyclePeriod The period used to calculate duty cycles. + Higher values make it take longer to respond to changes in + boost. Shorter values make it potentially more unstable and + likely to oscillate. + + @param boostStrength A number greater than or equal to 0, used to + control boosting strength. No boosting is applied if it is set to 0. + The strength of boosting increases as a function of boostStrength. + Boosting encourages columns to have similar activeDutyCycles as their + neighbors, which will lead to more efficient use of columns. However, + too much boosting may also lead to instability of SP outputs. + + + @param seed Seed for our random number generator. If seed is < 0 + a randomly generated seed is used. The behavior of the spatial + pooler is deterministic once the seed is set. + + @param spVerbosity spVerbosity level: 0, 1, 2, or 3 + + @param wrapAround boolean value that determines whether or not inputs + at the beginning and end of an input dimension are considered + neighbors for the purpose of mapping inputs to columns. + + */ + virtual void + initialize(vector<UInt> inputDimensions, vector<UInt> columnDimensions, + UInt potentialRadius = 16, Real potentialPct = 0.5, + bool globalInhibition = true, Real localAreaDensity = -1.0, + UInt numActiveColumnsPerInhArea = 10, UInt stimulusThreshold = 0, + Real synPermInactiveDec = 0.01, Real synPermActiveInc = 0.1, + Real synPermConnected = 0.1, Real minPctOverlapDutyCycles = 0.001, + UInt dutyCyclePeriod = 1000, Real boostStrength = 0.0, + Int seed = 1, UInt spVerbosity = 0, bool wrapAround = true); + + /** + This is the main workhorse method of the SpatialPooler class. This + method takes an input vector and computes the set of output active + columns. If 'learn' is set to True, this method also performs + learning. + + @param inputVector An array of integer 0's and 1's that comprises + the input to the spatial pooler. The length of the + array must match the total number of input bits implied by + the constructor (also returned by the method getNumInputs). In + cases where the input is multi-dimensional, inputVector is a + flattened array of inputs. + + @param learn A boolean value indicating whether learning should be + performed. Learning entails updating the permanence values of + the synapses, duty cycles, etc. Learning is typically on but + setting learning to 'off' is useful for analyzing the current + state of the SP. For example, you might want to feed in various + inputs and examine the resulting SDRs. Note that if learning + is off, boosting is turned off and columns that have never won + will be removed from activeVector. TODO: we may want to keep + boosting on even when learning is off. + + @param activeVector An array representing the winning columns after + inhibition. The size of the array is equal to the number of + columns (also returned by the method getNumColumns). This array + will be populated with 1's at the indices of the active columns, + and 0's everywhere else. In the case where the output is + multi-dimensional, activeVector represents a flattened array + of outputs.
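A minimal sketch of the inference-only path described under the 'learn' parameter, assuming a pooler that has already been trained elsewhere; the pooler size and test pattern are arbitrary.

// Sketch: run compute() with learn=false, then drop never-active columns.
#include <vector>
#include "nupic/algorithms/SpatialPooler.hpp"

using nupic::UInt;
using nupic::algorithms::spatial_pooler::SpatialPooler;

int main() {
  SpatialPooler sp({256}, {1024});
  std::vector<UInt> input(sp.getNumInputs(), 0);
  std::vector<UInt> active(sp.getNumColumns(), 0);
  input[3] = input[17] = input[42] = 1;  // arbitrary test pattern

  // learn=false leaves permanences, duty cycles and boosting untouched, so
  // repeated calls with the same input yield the same SDR.
  sp.compute(input.data(), /*learn=*/false, active.data());

  // Remove columns that have never been active during training; they carry
  // no learned pattern and are noise for a downstream consumer.
  sp.stripUnlearnedColumns(active.data());
  return 0;
}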
+ */ + virtual void compute(UInt inputVector[], bool learn, UInt activeVector[]); + + /** + Removes the set of columns who have never been active from the set + of active columns selected in the inhibition round. Such columns + cannot represent learned pattern and are therefore meaningless if + only inference is required. + + @param activeArray An array of 1's and 0's representing winning + columns calculated by the 'compute' method after disabling + any columns that are not learned. + */ + void stripUnlearnedColumns(UInt activeArray[]) const; + + /** + * Get the version number of this spatial pooler. + + * @returns Integer version number. + */ + virtual UInt version() const { return version_; }; + + /** + Save (serialize) the current state of the spatial pooler to the + specified file. + + @param fd A valid file descriptor. + */ + virtual void save(ostream &outStream) const; + + using Serializable::write; + virtual void write(SpatialPoolerProto::Builder &proto) const override; + + /** + Load (deserialize) and initialize the spatial pooler from the + specified input stream. + + @param inStream A valid istream. + */ + virtual void load(istream &inStream); + + using Serializable::read; + virtual void read(SpatialPoolerProto::Reader &proto) override; + + /** + Returns the number of bytes that a save operation would result in. + Note: this method is currently somewhat inefficient as it just does + a full save into an ostream and counts the resulting size. + + @returns Integer number of bytes + */ + virtual UInt persistentSize() const; + + /** + Returns the dimensions of the columns in the region. + + @returns Integer number of column dimension. + */ + vector getColumnDimensions() const; + + /** + Returns the dimensions of the input vector. + + @returns Integer vector of input dimension. + */ + vector getInputDimensions() const; + + /** + Returns the total number of columns. + + @returns Integer number of column numbers. + */ + UInt getNumColumns() const; + + /** + Returns the total number of inputs. + + @returns Integer number of inputs. + */ + UInt getNumInputs() const; + + /** + Returns the potential radius. + + @returns Integer number of potential radius. + */ + UInt getPotentialRadius() const; + + /** + Sets the potential radius. - /** - @returns boolen value of whether global inhibition is enabled. - */ - bool getGlobalInhibition() const; + @param potentialRadius integer number of potential raduis. + */ + void setPotentialRadius(UInt potentialRadius); + /** + Returns the potential percent. - /** - Sets global inhibition. + @returns real number of the potential percent. + */ + Real getPotentialPct() const; - @param globalInhibition boolen varable of whether global inhibition is enabled. - */ - void setGlobalInhibition(bool globalInhibition); + /** + Sets the potential percent. - /** - Returns the number of active columns per inhibition area. + @param potentialPct real number of potential percent. + */ + void setPotentialPct(Real potentialPct); - @returns integer number of active columns per inhbition area, Returns a - value less than 0 if parameter is unuse. - */ - Int getNumActiveColumnsPerInhArea() const; + /** + @returns boolen value of whether global inhibition is enabled. + */ + bool getGlobalInhibition() const; + /** + Sets global inhibition. - /** - Sets the number of active columns per inhibition area. Invalidates the - 'localAreaDensity' parameter. + @param globalInhibition boolen varable of whether global inhibition is + enabled. 
+ */ + void setGlobalInhibition(bool globalInhibition); - @param numActiveColumnsPerInhArea integer number of active columns per inhibition area. - */ - void setNumActiveColumnsPerInhArea(UInt numActiveColumnsPerInhArea); + /** + Returns the number of active columns per inhibition area. - /** - Returns the local area density. Returns a value less than 0 if parameter - is unused". + @returns integer number of active columns per inhbition area, Returns a + value less than 0 if parameter is unuse. + */ + Int getNumActiveColumnsPerInhArea() const; - @returns real number of local area density. - */ - Real getLocalAreaDensity() const; + /** + Sets the number of active columns per inhibition area. Invalidates the + 'localAreaDensity' parameter. - /** - Sets the local area density. Invalidates the 'numActivePerInhArea' - parameter". + @param numActiveColumnsPerInhArea integer number of active columns per + inhibition area. + */ + void setNumActiveColumnsPerInhArea(UInt numActiveColumnsPerInhArea); - @param localAreaDensity real number of local area density. - */ - void setLocalAreaDensity(Real localAreaDensity); + /** + Returns the local area density. Returns a value less than 0 if parameter + is unused". - /** - Returns the stimulus threshold. + @returns real number of local area density. + */ + Real getLocalAreaDensity() const; - @returns integer number of stimulus threshold. - */ - UInt getStimulusThreshold() const; + /** + Sets the local area density. Invalidates the 'numActivePerInhArea' + parameter". - /** - Sets the stimulus threshold. + @param localAreaDensity real number of local area density. + */ + void setLocalAreaDensity(Real localAreaDensity); - @param stimulusThreshold (positive) integer number of stimulus threshold - */ - void setStimulusThreshold(UInt stimulusThreshold); + /** + Returns the stimulus threshold. - /** - Returns the inhibition radius. + @returns integer number of stimulus threshold. + */ + UInt getStimulusThreshold() const; - @returns (positive) integer of inhibition radius/ - */ - UInt getInhibitionRadius() const; - /** - Sets the inhibition radius. + /** + Sets the stimulus threshold. - @param inhibitionRadius integer of inhibition radius. - */ - void setInhibitionRadius(UInt inhibitionRadius); + @param stimulusThreshold (positive) integer number of stimulus threshold + */ + void setStimulusThreshold(UInt stimulusThreshold); - /** - Returns the duty cycle period. + /** + Returns the inhibition radius. - @returns integer of duty cycle period. - */ - UInt getDutyCyclePeriod() const; + @returns (positive) integer of inhibition radius/ + */ + UInt getInhibitionRadius() const; + /** + Sets the inhibition radius. - /** - Sets the duty cycle period. + @param inhibitionRadius integer of inhibition radius. + */ + void setInhibitionRadius(UInt inhibitionRadius); - @param dutyCyclePeriod integer number of duty cycle period. - */ - void setDutyCyclePeriod(UInt dutyCyclePeriod); + /** + Returns the duty cycle period. - /** - Returns the maximum boost value. + @returns integer of duty cycle period. + */ + UInt getDutyCyclePeriod() const; - @returns real number of the maximum boost value. - */ - Real getBoostStrength() const; + /** + Sets the duty cycle period. - /** - Sets the strength of boost. + @param dutyCyclePeriod integer number of duty cycle period. 
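The getters and setters above allow a live pooler to be inspected and re-tuned; a small sketch with illustrative values follows. Note the documented coupling: setting one sparsity control is meant to invalidate the other.

// Sketch: reading and adjusting runtime parameters. Values are illustrative.
#include <iostream>
#include "nupic/algorithms/SpatialPooler.hpp"

using nupic::algorithms::spatial_pooler::SpatialPooler;

int main() {
  SpatialPooler sp({200}, {1024});

  std::cout << "stimulusThreshold=" << sp.getStimulusThreshold()
            << " dutyCyclePeriod=" << sp.getDutyCyclePeriod() << std::endl;

  sp.setStimulusThreshold(2);  // ignore columns with < 2 active synapses
  sp.setDutyCyclePeriod(500);  // react faster to activity changes

  // Switching to density-based sparsity invalidates the per-area count.
  sp.setLocalAreaDensity(0.02);
  std::cout << "numActiveColumnsPerInhArea now "
            << sp.getNumActiveColumnsPerInhArea() << std::endl;
  return 0;
}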
+ */ + void setDutyCyclePeriod(UInt dutyCyclePeriod); - @param boostStrength real number of boosting strength, - must be larger than 0.0 - */ - void setBoostStrength(Real boostStrength); + /** + Returns the maximum boost value. - /** - Returns the iteration number. + @returns real number of the maximum boost value. + */ + Real getBoostStrength() const; - @returns integer number of iteration number. - */ - UInt getIterationNum() const; + /** + Sets the strength of boost. - /** - Sets the iteration number. + @param boostStrength real number of boosting strength, + must be larger than 0.0 + */ + void setBoostStrength(Real boostStrength); - @param iterationNum integer number of iteration number. - */ - void setIterationNum(UInt iterationNum); + /** + Returns the iteration number. - /** - Returns the learning iteration number. + @returns integer number of iteration number. + */ + UInt getIterationNum() const; - @returns integer of the learning iteration number. - */ - UInt getIterationLearnNum() const; + /** + Sets the iteration number. - /** - Sets the learning iteration number. + @param iterationNum integer number of iteration number. + */ + void setIterationNum(UInt iterationNum); - @param iterationLearnNum integer of learning iteration number. - */ - void setIterationLearnNum(UInt iterationLearnNum); + /** + Returns the learning iteration number. - /** - Returns the verbosity level. + @returns integer of the learning iteration number. + */ + UInt getIterationLearnNum() const; - @returns integer of the verbosity level. - */ - UInt getSpVerbosity() const; + /** + Sets the learning iteration number. + + @param iterationLearnNum integer of learning iteration number. + */ + void setIterationLearnNum(UInt iterationLearnNum); - /** - Sets the verbosity level. - - @param spVerbosity integer of verbosity level. - */ - void setSpVerbosity(UInt spVerbosity); - - /** - Returns boolean value of wrapAround which indicates if receptive - fields should wrap around from the beginning the input dimensions - to the end. - - @returns the boolean value of wrapAround. - */ - bool getWrapAround() const; - - /** - Sets wrapAround. - - @param wrapAround boolean value - */ - void setWrapAround(bool wrapAround); - - /** - Returns the update period. - - @returns integer of update period. - */ - UInt getUpdatePeriod() const; - /** - Sets the update period. - - @param updatePeriod integer of update period. - */ - void setUpdatePeriod(UInt updatePeriod); - - /** - Returns the permanence trim threshold. - - @returns real number of the permanence trim threshold. - */ - Real getSynPermTrimThreshold() const; - /** - Sets the permanence trim threshold. - - @param synPermTrimThreshold real number of the permanence trim threshold. - */ - void setSynPermTrimThreshold(Real synPermTrimThreshold); - - /** - Returns the permanence increment amount for active synapses - inputs. - - @returns real number of the permanence increment amount for active synapses - inputs. - */ - Real getSynPermActiveInc() const; - /** - Sets the permanence increment amount for active synapses - inputs. - - @param synPermActiveInc real number of the permanence increment amount - for active synapses inputs, must be >0. - */ - void setSynPermActiveInc(Real synPermActiveInc); - - /** - Returns the permanence decrement amount for inactive synapses. - - @returns real number of the permanence decrement amount for inactive synapses. - */ - Real getSynPermInactiveDec() const; - /** - Returns the permanence decrement amount for inactive synapses. 
- - @param synPermInactiveDec real number of the permanence decrement amount for inactive synapses. - */ - void setSynPermInactiveDec(Real synPermInactiveDec); - - /** - Returns the permanence increment amount for columns that have not been - recently active. - - @returns positive real number of the permanence increment amount for columns that have not been - recently active. - */ - Real getSynPermBelowStimulusInc() const; - /** - Sets the permanence increment amount for columns that have not been - recently active. - - @param synPermBelowStimulusInc real number of the permanence increment amount for columns that have not been - recently active, must be larger than 0. - */ - void setSynPermBelowStimulusInc(Real synPermBelowStimulusInc); - - /** - Returns the permanence amount that qualifies a synapse as - being connected. - - @returns real number of the permanence amount - that qualifies a synapse as being connected. - */ - Real getSynPermConnected() const; - /** - Sets the permanence amount that qualifies a synapse as - being connected. - - @param synPermConnected real number of the permanence amount that qualifies a synapse as - being connected. - */ - void setSynPermConnected(Real synPermConnected); - - /** - Returns the maximum permanence amount a synapse can - achieve. - - @returns real number of the max permanence amount. - */ - Real getSynPermMax() const; - /** - Sets the maximum permanence amount a synapse can - achieve. - - @param synPermCMax real number of the maximum permanence - amount that a synapse can achieve. - */ - void setSynPermMax(Real synPermMax); - - /** - Returns the minimum tolerated overlaps, given as percent of - neighbors overlap score. - - @returns real number of the minimum tolerated overlaps. - */ - Real getMinPctOverlapDutyCycles() const; - /** - Sets the minimum tolerated overlaps, given as percent of - neighbors overlap score. - - @param minPctOverlapDutyCycles real number of the minimum tolerated overlaps. - */ - void setMinPctOverlapDutyCycles(Real minPctOverlapDutyCycles); - - /** - Returns the boost factors for all columns. 'boostFactors' size must - match the number of columns. - - @param boostFactors real array to store boost factors of all columns. - */ - void getBoostFactors(Real boostFactors[]) const; - /** - Sets the boost factors for all columns. 'boostFactors' size must - match the number of columns. - - @param boostFactors real array of boost factors of all columns. - */ - void setBoostFactors(Real boostFactors[]); - - /** - Returns the overlap duty cycles for all columns. 'overlapDutyCycles' - size must match the number of columns. - - @param overlapDutyCycles real array to store overlap duty cycles for all columns. - */ - void getOverlapDutyCycles(Real overlapDutyCycles[]) const; - /** - Sets the overlap duty cycles for all columns. 'overlapDutyCycles' - size must match the number of columns. - - @param overlapDutyCycles real array of the overlap duty cycles for all columns. - */ - void setOverlapDutyCycles(Real overlapDutyCycles[]); - - /** - Returns the activity duty cycles for all columns. 'activeDutyCycles' - size must match the number of columns. - - @param activeDutyCycles real array to store activity duty cycles for all columns. - */ - void getActiveDutyCycles(Real activeDutyCycles[]) const; - /** - Sets the activity duty cycles for all columns. 'activeDutyCycles' - size must match the number of columns. - - @param activeDutyCycles real array of the activity duty cycles for all columns. 
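The bulk accessors above fill caller-provided arrays that must be sized to getNumColumns(); the sketch below reads the per-column diagnostics that way, and the "least active column" report at the end is just an example use.

// Sketch: reading per-column boost factors and duty cycles.
#include <iostream>
#include <vector>
#include "nupic/algorithms/SpatialPooler.hpp"

using nupic::Real;
using nupic::UInt;
using nupic::algorithms::spatial_pooler::SpatialPooler;

int main() {
  SpatialPooler sp({300}, {2048});
  const UInt n = sp.getNumColumns();

  std::vector<Real> boost(n), overlapDuty(n), activeDuty(n);
  sp.getBoostFactors(boost.data());
  sp.getOverlapDutyCycles(overlapDuty.data());
  sp.getActiveDutyCycles(activeDuty.data());

  // e.g. report the column that has been active least often so far
  UInt quietest = 0;
  for (UInt c = 1; c < n; c++)
    if (activeDuty[c] < activeDuty[quietest]) quietest = c;
  std::cout << "least active column: " << quietest << std::endl;
  return 0;
}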
- */ - void setActiveDutyCycles(Real activeDutyCycles[]); - - /** - Returns the minimum overlap duty cycles for all columns. - - @param minOverlapDutyCycles real arry to store mininum overlap duty cycles for all columns. - 'minOverlapDutyCycles' size must match the number of columns. - */ - void getMinOverlapDutyCycles(Real minOverlapDutyCycles[]) const; - /** - Sets the minimum overlap duty cycles for all columns. - '_minOverlapDutyCycles' size must match the number of columns. - - @param minOverlapDutyCycles real array of the minimum overlap duty cycles for all columns. - */ - void setMinOverlapDutyCycles(Real minOverlapDutyCycles[]); - - /** - Returns the potential mapping for a given column. 'potential' size - must match the number of inputs. - - @param column integer of column index. - - @param potential integer array of potential mapping for the selected column. - */ - void getPotential(UInt column, UInt potential[]) const; - /** - Sets the potential mapping for a given column. 'potential' size - must match the number of inputs. - - @param column integer of column index. - - @param potential integer array of potential mapping for the selected column. - */ - void setPotential(UInt column, UInt potential[]); - - /** - Returns the permanence values for a given column. 'permanence' size - must match the number of inputs. - - @param column integer of column index. - - @param permanence real array to store permanence values for the selected column. - */ - void getPermanence(UInt column, Real permanence[]) const; - /** - Sets the permanence values for a given column. 'permanence' size - must match the number of inputs. - - @param column integer of column index. - - @param permanence real array of permanence values for the selected column. - */ - void setPermanence(UInt column, Real permanence[]); - - /** - Returns the connected synapses for a given column. - 'connectedSynapses' size must match the number of inputs. - - @param column integer of column index. - - @param connectedSynapses integer array to store the connected synapses for a given column. - */ - void getConnectedSynapses(UInt column, UInt connectedSynapses[]) const; - - /** - Returns the number of connected synapses for all columns. - 'connectedCounts' size must match the number of columns. - - @param connectedCounts integer array to store the connected synapses for all columns. - */ - void getConnectedCounts(UInt connectedCounts[]) const; - - /** - Print the main SP creation parameters to stdout. - */ - void printParameters() const; - - /** - Returns the overlap score for each column. - */ - const vector& getOverlaps() const; - - /** - Returns the boosted overlap score for each column. - */ - const vector& getBoostedOverlaps() const; - - - /////////////////////////////////////////////////////////// - // - // Implementation methods. all methods below this line are - // NOT part of the public API - - void toDense_(vector& sparse, - UInt dense[], - UInt n); - - void boostOverlaps_(vector& overlaps, - vector& boostedOverlaps); - - /** - Maps a column to its respective input index, keeping to the topology of - the region. It takes the index of the column as an argument and determines - what is the index of the flattened input vector that is to be the center of - the column's potential pool. It distributes the columns over the inputs - uniformly. The return value is an integer representing the index of the - input bit. 
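A sketch of inspecting a single column with the accessors above; per the documentation, the per-column buffers are sized by the number of inputs while connectedCounts is sized by the number of columns. The final consistency check only illustrates how the outputs relate.

// Sketch: per-column potential pool, permanences and connected synapses.
#include <vector>
#include "nupic/algorithms/SpatialPooler.hpp"

using nupic::Real;
using nupic::UInt;
using nupic::algorithms::spatial_pooler::SpatialPooler;

int main() {
  SpatialPooler sp({400}, {1024});
  const UInt column = 7;  // arbitrary column to inspect

  std::vector<UInt> potential(sp.getNumInputs());
  std::vector<Real> permanence(sp.getNumInputs());
  std::vector<UInt> connected(sp.getNumInputs());
  std::vector<UInt> connectedCounts(sp.getNumColumns());

  sp.getPotential(column, potential.data());          // 0/1 mask of inputs
  sp.getPermanence(column, permanence.data());        // per-input permanence
  sp.getConnectedSynapses(column, connected.data());  // permanence above threshold
  sp.getConnectedCounts(connectedCounts.data());      // per-column totals

  // The expanded connected row should sum to the per-column count.
  UInt manual = 0;
  for (UInt i = 0; i < sp.getNumInputs(); i++) manual += connected[i];
  return manual == connectedCounts[column] ? 0 : 1;
}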
Examples of the expected output of this method: - * If the topology is one dimensional, and the column index is 0, this - method will return the input index 0. If the column index is 1, and there - are 3 columns over 7 inputs, this method will return the input index 3. - * If the topology is two dimensional, with column dimensions [3, 5] and - input dimensions [7, 11], and the column index is 3, the method - returns input index 8. - - ---------------------------- - @param index The index identifying a column in the permanence, potential - and connectivity matrices. - @param wrapAround A boolean value indicating that boundaries should be - ignored. - */ - UInt mapColumn_(UInt column); - - /** - Maps a column to its input bits. - - This method encapsultes the topology of - the region. It takes the index of the column as an argument and determines - what are the indices of the input vector that are located within the - column's potential pool. The return value is a list containing the indices - of the input bits. The current implementation of the base class only - supports a 1 dimensional topology of columns with a 1 dimensional topology - of inputs. To extend this class to support 2-D topology you will need to - override this method. Examples of the expected output of this method: - * If the potentialRadius is greater than or equal to the entire input - space, (global visibility), then this method returns an array filled with - all the indices - * If the topology is one dimensional, and the potentialRadius is 5, this - method will return an array containing 5 consecutive values centered on - the index of the column (wrapping around if necessary). - * If the topology is two dimensional (not implemented), and the - potentialRadius is 5, the method should return an array containing 25 - '1's, where the exact indices are to be determined by the mapping from - 1-D index to 2-D position. - - ---------------------------- - @param column An int index identifying a column in the permanence, potential - and connectivity matrices. - - @param wrapAround A boolean value indicating that boundaries should be - ignored. - */ - vector mapPotential_(UInt column, bool wrapAround); - - /** - Returns a randomly generated permanence value for a synapses that is - initialized in a connected state. - - The basic idea here is to initialize - permanence values very close to synPermConnected so that a small number of - learning steps could make it disconnected or connected. - - Note: experimentation was done a long time ago on the best way to initialize - permanence values, but the history for this particular scheme has been lost. - - @returns real number of a randomly generated permanence value for a synapses that is - initialized in a connected state. - */ - Real initPermConnected_(); - /** - Returns a randomly generated permanence value for a synapses that is to be - initialized in a non-connected state. - - @returns real number of a randomly generated permanence value for a synapses that is to be - initialized in a non-connected state. - */ - Real initPermNonConnected_(); - - - /** - Initializes the permanences of a column. The method - returns a 1-D array the size of the input, where each entry in the - array represents the initial permanence value between the input bit - at the particular index in the array, and the column represented by - the 'index' parameter. - - @param potential A int vector specifying the potential pool of the column. 
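The 1-D mapping examples above (column 0 maps to input 0; column 1 of 3 columns over 7 inputs maps to input 3) can be reproduced with a simple linear interpolation. The sketch below is a stand-in that matches those documented examples; it is not the library's protected mapColumn_() implementation.

// Toy 1-D column-to-input mapping: columns spread uniformly over inputs.
#include <cassert>
#include <cmath>

// Map a column index to the input index at the centre of its potential pool.
unsigned mapColumn1D(unsigned column, unsigned numColumns, unsigned numInputs) {
  if (numColumns <= 1) return (numInputs - 1) / 2;
  double ratio = static_cast<double>(numInputs - 1) / (numColumns - 1);
  return static_cast<unsigned>(std::lround(column * ratio));
}

int main() {
  // The documented example: 3 columns over 7 inputs.
  assert(mapColumn1D(0, 3, 7) == 0);
  assert(mapColumn1D(1, 3, 7) == 3);
  assert(mapColumn1D(2, 3, 7) == 6);
  return 0;
}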
- Permanence values will only be generated for input bits - - corresponding to indices for which the mask value is 1. - @param connectedPct A real value between 0 or 1 specifying the percent of the input - bits that will start off in a connected state. - */ - vector initPermanence_(vector& potential, - Real connectedPct); - void clip_(vector& perm, bool trim); - - /** - This method updates the permanence matrix with a column's new permanence - values. - - The column is identified by its index, which reflects the row in - the matrix, and the permanence is given in 'dense' form, i.e. a full - array containing all the zeros as well as the non-zero values. It is in - charge of implementing 'clipping' - ensuring that the permanence values are - always between 0 and 1 - and 'trimming' - enforcing sparsity by zeroing out - all permanence values below '_synPermTrimThreshold'. It also maintains - the consistency between 'self._permanences' (the matrix storing the - permanence values), 'self._connectedSynapses', (the matrix storing the bits - each column is connected to), and 'self._connectedCounts' (an array storing - the number of input bits each column is connected to). Every method wishing - to modify the permanence matrix should do so through this method. - - ---------------------------- - @param perm An int vector of permanence values for a column. The array is - "dense", i.e. it contains an entry for each input bit, even - if the permanence value is 0. - - @param column An int number identifying a column in the permanence, potential - and connectivity matrices. - - @param raisePerm a boolean value indicating whether the permanence values - should be raised until a minimum number are synapses are in - a connected state. Should be set to 'false' when a direct - assignment is required. - */ - void updatePermanencesForColumn_(vector& perm, UInt column, - bool raisePerm=true); - UInt countConnected_(vector& perm); - UInt raisePermanencesToThreshold_(vector& perm, - vector& potential); - - /** - This function determines each column's overlap with the current - input vector. - - The overlap of a column is the number of synapses for that column - that are connected (permanence value is greater than - '_synPermConnected') to input bits which are turned on. The - implementation takes advantage of the SparseBinaryMatrix class to - perform this calculation efficiently. - - @param inputVector - a int array of 0's and 1's that comprises the input to the spatial - pooler. - - @param overlap - an int vector containing the overlap score for each column. The - overlap score for a column is defined as the number of synapses in - a "connected state" (connected synapses) that are connected to - input bits which are turned on. - */ - void calculateOverlap_(UInt inputVector[], - vector& overlap); - void calculateOverlapPct_(vector& overlaps, - vector& overlapPct); - - - bool isWinner_(Real score, vector >& winners, - UInt numWinners); - - void addToWinners_(UInt index, Real score, - vector >& winners); - - /** - Performs inhibition. This method calculates the necessary values needed to - actually perform inhibition and then delegates the task of picking the - active columns to helper functions. - - - @param overlaps an array containing the overlap score for each column. - The overlap score for a column is defined as the number - of synapses in a "connected state" (connected synapses) - that are connected to input bits which are turned on. 
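The overlap computation described above reduces to counting, per column, the connected synapses whose input bit is ON. The toy version below uses a dense 0/1 connectivity matrix purely for clarity; the real class uses a SparseBinaryMatrix.

// Toy overlap calculation over a dense 0/1 connectivity matrix.
#include <vector>

using std::vector;

// connected[c][i] == 1 iff column c has a connected synapse to input i.
vector<unsigned> calculateOverlap(const vector<vector<unsigned>> &connected,
                                  const vector<unsigned> &input) {
  vector<unsigned> overlap(connected.size(), 0);
  for (size_t c = 0; c < connected.size(); c++)
    for (size_t i = 0; i < input.size(); i++)
      overlap[c] += connected[c][i] & input[i];  // count connected ON bits
  return overlap;
}

int main() {
  vector<vector<unsigned>> connected = {{1, 1, 0, 0},   // column 0
                                        {0, 1, 1, 1}};  // column 1
  vector<unsigned> input = {1, 0, 1, 1};
  vector<unsigned> overlap = calculateOverlap(connected, input);
  // overlap == {1, 2}: column 0 matches input 0; column 1 matches inputs 2, 3.
  return overlap[0] == 1 && overlap[1] == 2 ? 0 : 1;
}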
- - @param activeColumns an int array containing the indices of the active columns. - */ - void inhibitColumns_( - const vector& overlaps, - vector& activeColumns); - - /** - Perform global inhibition. - - Performing global inhibition entails picking the top 'numActive' - columns with the highest overlap score in the entire region. At - most half of the columns in a local neighborhood are allowed to be - active. Columns with an overlap score below the 'stimulusThreshold' - are always inhibited. - - @param overlaps - a real array containing the overlap score for each column. The - overlap score for a column is defined as the number of synapses in - a "connected state" (connected synapses) that are connected to - input bits which are turned on. - - @param density - a real number of the fraction of columns to survive inhibition. - - @param activeColumns - an int array containing the indices of the active columns. - */ - void inhibitColumnsGlobal_( - const vector& overlaps, - Real density, - vector& activeColumns); - - /** - Performs local inhibition. - - Local inhibition is performed on a column by column basis. Each - column observes the overlaps of its neighbors and is selected if - its overlap score is within the top 'numActive' in its local - neighborhood. At most half of the columns in a local neighborhood - are allowed to be active. Columns with an overlap score below the - 'stimulusThreshold' are always inhibited. - - ---------------------------- - @param overlaps - an array containing the overlap score for each column. The overlap - score for a column is defined as the number of synapses in a - "connected state" (connected synapses) that are connected to input - bits which are turned on. - - @param density - The fraction of columns to survive inhibition. This value is only - an intended target. Since the surviving columns are picked in a - local fashion, the exact fraction of surviving columns is likely to - vary. - - @param activeColumns - an int array containing the indices of the active columns. - */ - void inhibitColumnsLocal_( - const vector& overlaps, - Real density, - vector& activeColumns); - - /** - The primary method in charge of learning. - - Adapts the permanence values of - the synapses based on the input vector, and the chosen columns after - inhibition round. Permanence values are increased for synapses connected to - input bits that are turned on, and decreased for synapses connected to - inputs bits that are turned off. - - ---------------------------- - @param inputVector an int array of 0's and 1's that comprises the input to - the spatial pooler. There exists an entry in the array - for every input bit. - - @param activeColumns an int vector containing the indices of the columns that - survived inhibition. - */ - void adaptSynapses_(UInt inputVector[], - vector& activeColumns); - - /** - This method increases the permanence values of synapses of columns whose - activity level has been too low. Such columns are identified by having an - overlap duty cycle that drops too much below those of their peers. The - permanence values for such columns are increased. - */ - void bumpUpWeakColumns_(); - - /** - Update the inhibition radius. The inhibition radius is a meausre of the - square (or hypersquare) of columns that each a column is "connected to" - on average. 
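A toy version of the global-inhibition selection described above: keep the numActive highest-overlap columns and never admit one whose overlap is below stimulusThreshold. Tie handling here (lower index wins) is a simplification; the real code uses a tie-breaker.

// Toy global inhibition: top-N columns by overlap, thresholded.
#include <algorithm>
#include <numeric>
#include <vector>

using std::vector;

vector<unsigned> inhibitGlobal(const vector<float> &overlaps,
                               unsigned numActive, float stimulusThreshold) {
  vector<unsigned> order(overlaps.size());
  std::iota(order.begin(), order.end(), 0u);
  std::stable_sort(order.begin(), order.end(), [&](unsigned a, unsigned b) {
    return overlaps[a] > overlaps[b];
  });

  vector<unsigned> active;
  for (unsigned c : order) {
    if (active.size() >= numActive) break;
    if (overlaps[c] < stimulusThreshold) break;  // sorted: the rest are no stronger
    active.push_back(c);
  }
  return active;
}

int main() {
  vector<float> overlaps = {3, 9, 1, 7, 7, 0};
  vector<unsigned> winners = inhibitGlobal(overlaps, 3, 2.0f);
  // winners == {1, 3, 4}; column 5 (overlap 0) could never win.
  return winners.size() == 3 ? 0 : 1;
}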
Since columns are not connected to each other directly, we - determine this quantity by first figuring out how many *inputs* a column is - connected to, and then multiplying it by the total number of columns that - exist for each input. For multiple dimension the aforementioned - calculations are averaged over all dimensions of inputs and columns. This - value is meaningless if global inhibition is enabled. - */ - void updateInhibitionRadius_(); - - /** - REturns the average number of columns per input, taking into account the topology - of the inputs and columns. This value is used to calculate the inhibition - radius. This function supports an arbitrary number of dimensions. If the - number of column dimensions does not match the number of input dimensions, - we treat the missing, or phantom dimensions as 'ones'. - - @returns real number of the average number of columns per input. - */ - Real avgColumnsPerInput_(); - - /** - The range of connected synapses for column. This is used to - calculate the inhibition radius. This variation of the function only - supports a 1 dimensional column topology. - - @param column An int number identifying a column in the permanence, potential - and connectivity matrices. - */ - Real avgConnectedSpanForColumn1D_(UInt column); - - /** - The range of connectedSynapses per column, averaged for each dimension. - This vaule is used to calculate the inhibition radius. This variation of - the function only supports a 2 dimensional column topology. - - @param column An int number identifying a column in the permanence, potential - and connectivity matrices. - */ - Real avgConnectedSpanForColumn2D_(UInt column); - - - /** - The range of connectedSynapses per column, averaged for each dimension. - This vaule is used to calculate the inhibition radius. This variation of - the function supports arbitrary column dimensions. - - @param column An int number identifying a column in the permanence, potential - and connectivity matrices. - */ - Real avgConnectedSpanForColumnND_(UInt column); - - /** - Updates the minimum duty cycles defining normal activity for a column. A - column with activity duty cycle below this minimum threshold is boosted. - */ - void updateMinDutyCycles_(); - - /** - Updates the minimum duty cycles in a global fashion. Sets the minimum duty - cycles for the overlap and activation of all columns to be a percent of the - maximum in the region, specified by minPctOverlapDutyCycle and - minPctActiveDutyCycle respectively. Functionally it is equivalent to - _updateMinDutyCyclesLocal, but this function exploits the globalilty of the - computation to perform it in a straightforward, and more efficient manner. - */ - void updateMinDutyCyclesGlobal_(); - - /** - Updates the minimum duty cycles. The minimum duty cycles are determined - locally. Each column's minimum duty cycles are set to be a percent of the - maximum duty cycles in the column's neighborhood. Unlike - _updateMinDutyCycles - */ - void updateMinDutyCyclesLocal_(); - - - - /** - Updates a duty cycle estimate with a new value. This is a helper - function that is used to update several duty cycle variables in - the Column class, such as: overlapDutyCucle, activeDutyCycle, - minPctDutyCycleBeforeInh, minPctDutyCycleAfterInh, etc. returns - the updated duty cycle. 
Duty cycles are updated according to the following - formula: - @verbatim - - - (period - 1)*dutyCycle + newValue - dutyCycle := ---------------------------------- - period - @endverbatim - - ---------------------------- - @param dutyCycles A real array containing one or more duty cycle values that need - to be updated. - - @param newValues A int vector used to update the duty cycle. - - @param period A int number indicating the period of the duty cycle - */ - static void updateDutyCyclesHelper_(vector& dutyCycles, - vector& newValues, - UInt period); - - /** - Updates the duty cycles for each column. The OVERLAP duty cycle is a moving - average of the number of inputs which overlapped with the each column. The - ACTIVITY duty cycles is a moving average of the frequency of activation for - each column. - - @param overlaps an int vector containing the overlap score for each column. - The overlap score for a column is defined as the number - of synapses in a "connected state" (connected synapses) - that are connected to input bits which are turned on. - - @param activeArray An int array containing the indices of the active columns, - the sprase set of columns which survived inhibition - */ - void updateDutyCycles_(vector& overlaps, - UInt activeArray[]); - - /** - Update the boost factors for all columns. The boost factors are used to - increase the overlap of inactive columns to improve their chances of - becoming active, and hence encourage participation of more columns in the - learning process. The boosting function is a curve defined as: - boostFactors = exp[ - boostStrength * (dutyCycle - targetDensity)] - Intuitively this means that columns that have been active at the target - activation level have a boost factor of 1, meaning their overlap is not - boosted. Columns whose active duty cycle drops too much below that of their - neighbors are boosted depending on how infrequently they have been active. - Columns that has been active more than the target activation level have - a boost factor below 1, meaning their overlap is suppressed - - The boostFactor depends on the activeDutyCycle via an exponential function: - - boostFactor - ^ - | - |\ - | \ - 1 _ | \ - | _ - | _ _ - | _ _ _ _ - +--------------------> activeDutyCycle - | - targetDensity - @endverbatim + /** + Returns the verbosity level. + + @returns integer of the verbosity level. + */ + UInt getSpVerbosity() const; + + /** + Sets the verbosity level. + + @param spVerbosity integer of verbosity level. + */ + void setSpVerbosity(UInt spVerbosity); + + /** + Returns boolean value of wrapAround which indicates if receptive + fields should wrap around from the beginning the input dimensions + to the end. + + @returns the boolean value of wrapAround. + */ + bool getWrapAround() const; + + /** + Sets wrapAround. + + @param wrapAround boolean value + */ + void setWrapAround(bool wrapAround); + + /** + Returns the update period. + + @returns integer of update period. + */ + UInt getUpdatePeriod() const; + /** + Sets the update period. + + @param updatePeriod integer of update period. + */ + void setUpdatePeriod(UInt updatePeriod); + + /** + Returns the permanence trim threshold. + + @returns real number of the permanence trim threshold. + */ + Real getSynPermTrimThreshold() const; + /** + Sets the permanence trim threshold. + + @param synPermTrimThreshold real number of the permanence trim threshold. 
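The two formulas above, the moving-average duty-cycle update and the exponential boost curve, translate directly into code. The sketch below mirrors the formulas as written in the comments; it is not the library's internal implementation.

// Toy duty-cycle update and boost-factor curve.
#include <cmath>
#include <vector>

using std::vector;

// dutyCycle := ((period - 1) * dutyCycle + newValue) / period
void updateDutyCycles(vector<double> &dutyCycles,
                      const vector<unsigned> &newValues, unsigned period) {
  for (size_t c = 0; c < dutyCycles.size(); c++)
    dutyCycles[c] = ((period - 1) * dutyCycles[c] + newValues[c]) / period;
}

// boostFactor = exp(-boostStrength * (activeDutyCycle - targetDensity))
vector<double> boostFactors(const vector<double> &activeDutyCycles,
                            double targetDensity, double boostStrength) {
  vector<double> factors(activeDutyCycles.size());
  for (size_t c = 0; c < factors.size(); c++)
    factors[c] = std::exp(-boostStrength * (activeDutyCycles[c] - targetDensity));
  return factors;
}

int main() {
  vector<double> duty = {0.02, 0.10};
  updateDutyCycles(duty, {1, 0}, /*period=*/1000);

  // A column at exactly the target density gets a boost factor of 1;
  // quieter columns get a factor > 1, busier ones < 1.
  vector<double> b = boostFactors({0.02, 0.01, 0.04}, 0.02, 10.0);
  return b[1] > 1.0 && b[2] < 1.0 ? 0 : 1;
}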
+ */ + void setSynPermTrimThreshold(Real synPermTrimThreshold); + + /** + Returns the permanence increment amount for active synapses + inputs. + + @returns real number of the permanence increment amount for active synapses + inputs. + */ + Real getSynPermActiveInc() const; + /** + Sets the permanence increment amount for active synapses + inputs. + + @param synPermActiveInc real number of the permanence increment amount + for active synapses inputs, must be >0. + */ + void setSynPermActiveInc(Real synPermActiveInc); + + /** + Returns the permanence decrement amount for inactive synapses. + + @returns real number of the permanence decrement amount for inactive synapses. + */ + Real getSynPermInactiveDec() const; + /** + Returns the permanence decrement amount for inactive synapses. + + @param synPermInactiveDec real number of the permanence decrement amount for + inactive synapses. + */ + void setSynPermInactiveDec(Real synPermInactiveDec); + + /** + Returns the permanence increment amount for columns that have not been + recently active. + + @returns positive real number of the permanence increment amount for columns + that have not been recently active. + */ + Real getSynPermBelowStimulusInc() const; + /** + Sets the permanence increment amount for columns that have not been + recently active. + + @param synPermBelowStimulusInc real number of the permanence increment amount + for columns that have not been recently active, must be larger than 0. + */ + void setSynPermBelowStimulusInc(Real synPermBelowStimulusInc); + + /** + Returns the permanence amount that qualifies a synapse as + being connected. + + @returns real number of the permanence amount + that qualifies a synapse as being connected. + */ + Real getSynPermConnected() const; + /** + Sets the permanence amount that qualifies a synapse as + being connected. + + @param synPermConnected real number of the permanence amount that qualifies a + synapse as being connected. + */ + void setSynPermConnected(Real synPermConnected); + + /** + Returns the maximum permanence amount a synapse can + achieve. + + @returns real number of the max permanence amount. + */ + Real getSynPermMax() const; + /** + Sets the maximum permanence amount a synapse can + achieve. + + @param synPermCMax real number of the maximum permanence + amount that a synapse can achieve. + */ + void setSynPermMax(Real synPermMax); + + /** + Returns the minimum tolerated overlaps, given as percent of + neighbors overlap score. + + @returns real number of the minimum tolerated overlaps. + */ + Real getMinPctOverlapDutyCycles() const; + /** + Sets the minimum tolerated overlaps, given as percent of + neighbors overlap score. + + @param minPctOverlapDutyCycles real number of the minimum tolerated overlaps. + */ + void setMinPctOverlapDutyCycles(Real minPctOverlapDutyCycles); + + /** + Returns the boost factors for all columns. 'boostFactors' size must + match the number of columns. + + @param boostFactors real array to store boost factors of all columns. + */ + void getBoostFactors(Real boostFactors[]) const; + /** + Sets the boost factors for all columns. 'boostFactors' size must + match the number of columns. + + @param boostFactors real array of boost factors of all columns. + */ + void setBoostFactors(Real boostFactors[]); + + /** + Returns the overlap duty cycles for all columns. 'overlapDutyCycles' + size must match the number of columns. + + @param overlapDutyCycles real array to store overlap duty cycles for all + columns. 
+ */ + void getOverlapDutyCycles(Real overlapDutyCycles[]) const; + /** + Sets the overlap duty cycles for all columns. 'overlapDutyCycles' + size must match the number of columns. + + @param overlapDutyCycles real array of the overlap duty cycles for all + columns. + */ + void setOverlapDutyCycles(Real overlapDutyCycles[]); + + /** + Returns the activity duty cycles for all columns. 'activeDutyCycles' + size must match the number of columns. + + @param activeDutyCycles real array to store activity duty cycles for all + columns. + */ + void getActiveDutyCycles(Real activeDutyCycles[]) const; + /** + Sets the activity duty cycles for all columns. 'activeDutyCycles' + size must match the number of columns. + + @param activeDutyCycles real array of the activity duty cycles for all + columns. + */ + void setActiveDutyCycles(Real activeDutyCycles[]); + + /** + Returns the minimum overlap duty cycles for all columns. + + @param minOverlapDutyCycles real arry to store mininum overlap duty cycles for + all columns. 'minOverlapDutyCycles' size must match the number of columns. + */ + void getMinOverlapDutyCycles(Real minOverlapDutyCycles[]) const; + /** + Sets the minimum overlap duty cycles for all columns. + '_minOverlapDutyCycles' size must match the number of columns. + + @param minOverlapDutyCycles real array of the minimum overlap duty cycles for + all columns. + */ + void setMinOverlapDutyCycles(Real minOverlapDutyCycles[]); + + /** + Returns the potential mapping for a given column. 'potential' size + must match the number of inputs. + + @param column integer of column index. + + @param potential integer array of potential mapping for the selected column. + */ + void getPotential(UInt column, UInt potential[]) const; + /** + Sets the potential mapping for a given column. 'potential' size + must match the number of inputs. + + @param column integer of column index. + + @param potential integer array of potential mapping for the selected column. + */ + void setPotential(UInt column, UInt potential[]); + + /** + Returns the permanence values for a given column. 'permanence' size + must match the number of inputs. + + @param column integer of column index. + + @param permanence real array to store permanence values for the selected + column. + */ + void getPermanence(UInt column, Real permanence[]) const; + /** + Sets the permanence values for a given column. 'permanence' size + must match the number of inputs. + + @param column integer of column index. + + @param permanence real array of permanence values for the selected column. + */ + void setPermanence(UInt column, Real permanence[]); + + /** + Returns the connected synapses for a given column. + 'connectedSynapses' size must match the number of inputs. + + @param column integer of column index. + + @param connectedSynapses integer array to store the connected synapses for a + given column. + */ + void getConnectedSynapses(UInt column, UInt connectedSynapses[]) const; + + /** + Returns the number of connected synapses for all columns. + 'connectedCounts' size must match the number of columns. + + @param connectedCounts integer array to store the connected synapses for all + columns. + */ + void getConnectedCounts(UInt connectedCounts[]) const; + + /** + Print the main SP creation parameters to stdout. + */ + void printParameters() const; + + /** + Returns the overlap score for each column. + */ + const vector &getOverlaps() const; + + /** + Returns the boosted overlap score for each column. 
+ */ + const vector &getBoostedOverlaps() const; + + /////////////////////////////////////////////////////////// + // + // Implementation methods. all methods below this line are + // NOT part of the public API + + void toDense_(vector &sparse, UInt dense[], UInt n); + + void boostOverlaps_(vector &overlaps, vector &boostedOverlaps); + + /** + Maps a column to its respective input index, keeping to the topology of + the region. It takes the index of the column as an argument and determines + what is the index of the flattened input vector that is to be the center of + the column's potential pool. It distributes the columns over the inputs + uniformly. The return value is an integer representing the index of the + input bit. Examples of the expected output of this method: + * If the topology is one dimensional, and the column index is 0, this + method will return the input index 0. If the column index is 1, and there + are 3 columns over 7 inputs, this method will return the input index 3. + * If the topology is two dimensional, with column dimensions [3, 5] and + input dimensions [7, 11], and the column index is 3, the method + returns input index 8. + + ---------------------------- + @param index The index identifying a column in the permanence, + potential and connectivity matrices. + @param wrapAround A boolean value indicating that boundaries should be + ignored. + */ + UInt mapColumn_(UInt column); + + /** + Maps a column to its input bits. + + This method encapsultes the topology of + the region. It takes the index of the column as an argument and determines + what are the indices of the input vector that are located within the + column's potential pool. The return value is a list containing the indices + of the input bits. The current implementation of the base class only + supports a 1 dimensional topology of columns with a 1 dimensional topology + of inputs. To extend this class to support 2-D topology you will need to + override this method. Examples of the expected output of this method: + * If the potentialRadius is greater than or equal to the entire input + space, (global visibility), then this method returns an array filled with + all the indices + * If the topology is one dimensional, and the potentialRadius is 5, this + method will return an array containing 5 consecutive values centered on + the index of the column (wrapping around if necessary). + * If the topology is two dimensional (not implemented), and the + potentialRadius is 5, the method should return an array containing 25 + '1's, where the exact indices are to be determined by the mapping from + 1-D index to 2-D position. + + ---------------------------- + @param column An int index identifying a column in the permanence, + potential and connectivity matrices. + + @param wrapAround A boolean value indicating that boundaries should be + ignored. + */ + vector mapPotential_(UInt column, bool wrapAround); + + /** + Returns a randomly generated permanence value for a synapses that is + initialized in a connected state. + + The basic idea here is to initialize + permanence values very close to synPermConnected so that a small number of + learning steps could make it disconnected or connected. + + Note: experimentation was done a long time ago on the best way to initialize + permanence values, but the history for this particular scheme has been lost. + + @returns real number of a randomly generated permanence value for a synapses + that is initialized in a connected state. 
+ */ + Real initPermConnected_(); + /** + Returns a randomly generated permanence value for a synapses that is to be + initialized in a non-connected state. + + @returns real number of a randomly generated permanence value for a + synapses that is to be initialized in a non-connected state. + */ + Real initPermNonConnected_(); + + /** + Initializes the permanences of a column. The method + returns a 1-D array the size of the input, where each entry in the + array represents the initial permanence value between the input bit + at the particular index in the array, and the column represented by + the 'index' parameter. + + @param potential A int vector specifying the potential pool of the + column. Permanence values will only be generated for input bits + + corresponding to indices for which the mask value is 1. + @param connectedPct A real value between 0 or 1 specifying the percent of + the input bits that will start off in a connected state. + */ + vector initPermanence_(vector &potential, Real connectedPct); + void clip_(vector &perm, bool trim); + + /** + This method updates the permanence matrix with a column's new permanence + values. + + The column is identified by its index, which reflects the row in + the matrix, and the permanence is given in 'dense' form, i.e. a full + array containing all the zeros as well as the non-zero values. It is in + charge of implementing 'clipping' - ensuring that the permanence values + are always between 0 and 1 - and 'trimming' - enforcing sparsity by zeroing + out all permanence values below '_synPermTrimThreshold'. It also maintains + the consistency between 'self._permanences' (the matrix storing the + permanence values), 'self._connectedSynapses', (the matrix storing the + bits each column is connected to), and 'self._connectedCounts' (an array + storing the number of input bits each column is connected to). Every method + wishing to modify the permanence matrix should do so through this method. + + ---------------------------- + @param perm An int vector of permanence values for a column. The + array is "dense", i.e. it contains an entry for each input bit, even if the + permanence value is 0. + + @param column An int number identifying a column in the + permanence, potential and connectivity matrices. + + @param raisePerm a boolean value indicating whether the permanence + values should be raised until a minimum number are synapses are in a + connected state. Should be set to 'false' when a direct assignment is + required. + */ + void updatePermanencesForColumn_(vector &perm, UInt column, + bool raisePerm = true); + UInt countConnected_(vector &perm); + UInt raisePermanencesToThreshold_(vector &perm, + vector &potential); + + /** + This function determines each column's overlap with the current + input vector. + + The overlap of a column is the number of synapses for that column + that are connected (permanence value is greater than + '_synPermConnected') to input bits which are turned on. The + implementation takes advantage of the SparseBinaryMatrix class to + perform this calculation efficiently. + + @param inputVector + a int array of 0's and 1's that comprises the input to the spatial + pooler. + + @param overlap + an int vector containing the overlap score for each column. The + overlap score for a column is defined as the number of synapses in + a "connected state" (connected synapses) that are connected to + input bits which are turned on. 
+ */ + void calculateOverlap_(UInt inputVector[], vector &overlap); + void calculateOverlapPct_(vector &overlaps, vector &overlapPct); + + bool isWinner_(Real score, vector> &winners, + UInt numWinners); + + void addToWinners_(UInt index, Real score, vector> &winners); + + /** + Performs inhibition. This method calculates the necessary values needed to + actually perform inhibition and then delegates the task of picking the + active columns to helper functions. + + + @param overlaps an array containing the overlap score for each + column. The overlap score for a column is defined as the number of synapses + in a "connected state" (connected synapses) that are connected to input + bits which are turned on. + + @param activeColumns an int array containing the indices of the active + columns. + */ + void inhibitColumns_(const vector &overlaps, + vector &activeColumns); + + /** + Perform global inhibition. + + Performing global inhibition entails picking the top 'numActive' + columns with the highest overlap score in the entire region. At + most half of the columns in a local neighborhood are allowed to be + active. Columns with an overlap score below the 'stimulusThreshold' + are always inhibited. + + @param overlaps + a real array containing the overlap score for each column. The + overlap score for a column is defined as the number of synapses in + a "connected state" (connected synapses) that are connected to + input bits which are turned on. + + @param density + a real number of the fraction of columns to survive inhibition. + + @param activeColumns + an int array containing the indices of the active columns. + */ + void inhibitColumnsGlobal_(const vector &overlaps, Real density, + vector &activeColumns); + + /** + Performs local inhibition. + + Local inhibition is performed on a column by column basis. Each + column observes the overlaps of its neighbors and is selected if + its overlap score is within the top 'numActive' in its local + neighborhood. At most half of the columns in a local neighborhood + are allowed to be active. Columns with an overlap score below the + 'stimulusThreshold' are always inhibited. + + ---------------------------- + @param overlaps + an array containing the overlap score for each column. The overlap + score for a column is defined as the number of synapses in a + "connected state" (connected synapses) that are connected to input + bits which are turned on. + + @param density + The fraction of columns to survive inhibition. This value is only + an intended target. Since the surviving columns are picked in a + local fashion, the exact fraction of surviving columns is likely to + vary. + + @param activeColumns + an int array containing the indices of the active columns. + */ + void inhibitColumnsLocal_(const vector &overlaps, Real density, + vector &activeColumns); + + /** + The primary method in charge of learning. + + Adapts the permanence values of + the synapses based on the input vector, and the chosen columns after + inhibition round. Permanence values are increased for synapses connected + to input bits that are turned on, and decreased for synapses connected to + inputs bits that are turned off. + + ---------------------------- + @param inputVector an int array of 0's and 1's that comprises the input + to the spatial pooler. There exists an entry in the array for every input + bit. + + @param activeColumns an int vector containing the indices of the columns + that survived inhibition. 
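+
+  A rough sketch of the update rule described above (illustrative only;
+  'getPermanence' and 'potentialPool' are hypothetical accessors, while
+  synPermActiveInc_ and synPermInactiveDec_ are the members declared below):
+  @verbatim
+  for (UInt column : activeColumns) {
+    vector<Real> perm = getPermanence(column);
+    for (UInt i : potentialPool(column))
+      perm[i] += inputVector[i] ? synPermActiveInc_ : -synPermInactiveDec_;
+    updatePermanencesForColumn_(perm, column, true);
+  }
+  @endverbatim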
*/ - void updateBoostFactors_(); - - /** - Update boost factors when local inhibition is enabled. In this case, - the target activation level for each column is estimated as the - average activation level for columns in its neighborhood. - */ - void updateBoostFactorsLocal_(); - - /** - Update boost factors when global inhibition is enabled. All columns - share the same target activation level in this case, which is the - sparsity of spatial pooler. - */ - void updateBoostFactorsGlobal_(); - - /** - Updates counter instance variables each round. - - @param learn a boolean value indicating whether learning should be - performed. Learning entails updating the permanence - values of the synapses, and hence modifying the 'state' - of the model. setting learning to 'off' might be useful - for indicating separate training vs. testing sets. - */ - void updateBookeepingVars_(bool learn); - - /** - @returns boolean value indicating whether enough rounds have passed to warrant updates of - duty cycles - */ - bool isUpdateRound_(); - - /** - Initialize the random seed - - @param seed 64bit int of random seed - */ - void seed_(UInt64 seed); - - //------------------------------------------------------------------- - // Debugging helpers - //------------------------------------------------------------------- - - /** - Print the given UInt array in a nice format - */ - void printState(vector &state); - /** - Print the given Real array in a nice format - */ - void printState(vector &state); - - protected: - UInt numInputs_; - UInt numColumns_; - vector columnDimensions_; - vector inputDimensions_; - UInt potentialRadius_; - Real potentialPct_; - Real initConnectedPct_; - bool globalInhibition_; - Int numActiveColumnsPerInhArea_; - Real localAreaDensity_; - UInt stimulusThreshold_; - UInt inhibitionRadius_; - UInt dutyCyclePeriod_; - Real boostStrength_; - UInt iterationNum_; - UInt iterationLearnNum_; - UInt spVerbosity_; - bool wrapAround_; - UInt updatePeriod_; - - Real synPermMin_; - Real synPermMax_; - Real synPermTrimThreshold_; - Real synPermInactiveDec_; - Real synPermActiveInc_; - Real synPermBelowStimulusInc_; - Real synPermConnected_; - - vector boostFactors_; - vector overlapDutyCycles_; - vector activeDutyCycles_; - vector minOverlapDutyCycles_; - vector minActiveDutyCycles_; - - Real minPctOverlapDutyCycles_; - - SparseMatrix permanences_; - SparseBinaryMatrix potentialPools_; - SparseBinaryMatrix connectedSynapses_; - vector connectedCounts_; - - vector overlaps_; - vector overlapsPct_; - vector boostedOverlaps_; - vector activeColumns_; - vector tieBreaker_; - - UInt version_; - Random rng_; - - }; - - } // end namespace spatial_pooler - } // end namespace algorithms + void adaptSynapses_(UInt inputVector[], vector &activeColumns); + + /** + This method increases the permanence values of synapses of columns whose + activity level has been too low. Such columns are identified by having an + overlap duty cycle that drops too much below those of their peers. The + permanence values for such columns are increased. + */ + void bumpUpWeakColumns_(); + + /** + Update the inhibition radius. The inhibition radius is a meausre of the + square (or hypersquare) of columns that each a column is "connected to" + on average. Since columns are not connected to each other directly, we + determine this quantity by first figuring out how many *inputs* a column + is connected to, and then multiplying it by the total number of columns + that exist for each input. 
For multiple dimensions, the aforementioned
+  calculations are averaged over all dimensions of inputs and columns. This
+  value is meaningless if global inhibition is enabled.
+  */
+  void updateInhibitionRadius_();
+
+  /**
+  Returns the average number of columns per input, taking into account the
+  topology of the inputs and columns. This value is used to calculate the
+  inhibition radius. This function supports an arbitrary number of
+  dimensions. If the number of column dimensions does not match the number of
+  input dimensions, we treat the missing, or phantom, dimensions as 'ones'.
+
+  @returns real number of the average number of columns per input.
+  */
+  Real avgColumnsPerInput_();
+
+  /**
+  The range of connected synapses for a column. This is used to
+  calculate the inhibition radius. This variation of the function only
+  supports a 1 dimensional column topology.
+
+  @param column An int number identifying a column in the permanence,
+  potential and connectivity matrices.
+  */
+  Real avgConnectedSpanForColumn1D_(UInt column);
+
+  /**
+  The range of connectedSynapses per column, averaged for each dimension.
+  This value is used to calculate the inhibition radius. This variation of
+  the function only supports a 2 dimensional column topology.
+
+  @param column An int number identifying a column in the permanence,
+  potential and connectivity matrices.
+  */
+  Real avgConnectedSpanForColumn2D_(UInt column);
+
+  /**
+  The range of connectedSynapses per column, averaged for each dimension.
+  This value is used to calculate the inhibition radius. This variation of
+  the function supports arbitrary column dimensions.
+
+  @param column An int number identifying a column in the permanence,
+  potential and connectivity matrices.
+  */
+  Real avgConnectedSpanForColumnND_(UInt column);
+
+  /**
+  Updates the minimum duty cycles defining normal activity for a column. A
+  column with an activity duty cycle below this minimum threshold is boosted.
+  */
+  void updateMinDutyCycles_();
+
+  /**
+  Updates the minimum duty cycles in a global fashion. Sets the minimum duty
+  cycles for the overlap and activation of all columns to be a percent of
+  the maximum in the region, specified by minPctOverlapDutyCycle and
+  minPctActiveDutyCycle respectively. Functionally it is equivalent to
+  updateMinDutyCyclesLocal_, but this function exploits the globality of
+  the computation to perform it in a straightforward, and more efficient,
+  manner.
+  */
+  void updateMinDutyCyclesGlobal_();
+
+  /**
+  Updates the minimum duty cycles. The minimum duty cycles are determined
+  locally. Each column's minimum duty cycles are set to be a percent of the
+  maximum duty cycles in the column's neighborhood. Unlike
+  updateMinDutyCyclesGlobal_, the resulting minimums can differ from column
+  to column.
+  */
+  void updateMinDutyCyclesLocal_();
+
+  /**
+  Updates a duty cycle estimate with a new value. This is a helper
+  function that is used to update several duty cycle variables of the
+  spatial pooler, such as: overlapDutyCycle, activeDutyCycle,
+  minPctDutyCycleBeforeInh, minPctDutyCycleAfterInh, etc. It returns
+  the updated duty cycle. Duty cycles are updated according to the following
+  formula:
+  @verbatim
+
+                (period - 1)*dutyCycle + newValue
+  dutyCycle := ----------------------------------
+                              period
+  @endverbatim
+
+  ----------------------------
+  @param dutyCycles A real array containing one or more duty cycle
+  values that need to be updated.
+
+  @param newValues An int vector used to update the duty cycle.
+
+  @param period An int number indicating the period of the duty cycle.
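+
+  Worked example (illustrative): with period = 1000, an existing duty cycle
+  of 0.04 and a new value of 1, the updated duty cycle is
+  (999 * 0.04 + 1) / 1000 = 0.04096.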
+  */
+  static void updateDutyCyclesHelper_(vector<Real> &dutyCycles,
+                                      vector<UInt> &newValues, UInt period);
+
+  /**
+  Updates the duty cycles for each column. The OVERLAP duty cycle is a moving
+  average of the number of inputs which overlapped with each column. The
+  ACTIVITY duty cycle is a moving average of the frequency of activation for
+  each column.
+
+  @param overlaps an int vector containing the overlap score for each
+  column. The overlap score for a column is defined as the number of synapses
+  in a "connected state" (connected synapses) that are connected to input bits
+  which are turned on.
+
+  @param activeArray An int array containing the indices of the active columns,
+  the sparse set of columns that survived inhibition.
+  */
+  void updateDutyCycles_(vector<UInt> &overlaps, UInt activeArray[]);
+
+  /**
+  Update the boost factors for all columns. The boost factors are used to
+  increase the overlap of inactive columns to improve their chances of
+  becoming active, and hence encourage participation of more columns in the
+  learning process. The boosting function is a curve defined as:
+
+    boostFactors = exp[ - boostStrength * (dutyCycle - targetDensity) ]
+
+  Intuitively this means that columns that have been active at the target
+  activation level have a boost factor of 1, meaning their overlap is not
+  boosted. Columns whose active duty cycle drops too much below that of their
+  neighbors are boosted depending on how infrequently they have been active.
+  Columns that have been active more than the target activation level have
+  a boost factor below 1, meaning their overlap is suppressed.
+
+  The boostFactor depends on the activeDutyCycle via an exponential function:
+  @verbatim
+  boostFactor
+        ^
+        |
+        |\
+        | \
+  1  _  |  \
+        |    _
+        |      _ _
+        |          _ _ _ _
+        +--------------------> activeDutyCycle
+           |
+      targetDensity
+  @endverbatim
+  */
+  void updateBoostFactors_();
+
+  /**
+  Update boost factors when local inhibition is enabled. In this case,
+  the target activation level for each column is estimated as the
+  average activation level for columns in its neighborhood.
+  */
+  void updateBoostFactorsLocal_();
+
+  /**
+  Update boost factors when global inhibition is enabled. All columns
+  share the same target activation level in this case, which is the
+  sparsity of the spatial pooler.
+  */
+  void updateBoostFactorsGlobal_();
+
+  /**
+  Updates counter instance variables each round.
+
+  @param learn a boolean value indicating whether learning should be
+  performed. Learning entails updating the permanence
+  values of the synapses, and hence modifying the 'state'
+  of the model. Setting learning to 'off' might be useful
+  for indicating separate training vs. testing sets.
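+
+  Note (assumed behaviour, added for clarity): iterationNum_ is expected to
+  advance on every call, while iterationLearnNum_ advances only when 'learn'
+  is true.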
+ */ + void updateBookeepingVars_(bool learn); + + /** + @returns boolean value indicating whether enough rounds have passed to warrant + updates of duty cycles + */ + bool isUpdateRound_(); + + /** + Initialize the random seed + + @param seed 64bit int of random seed + */ + void seed_(UInt64 seed); + + //------------------------------------------------------------------- + // Debugging helpers + //------------------------------------------------------------------- + + /** + Print the given UInt array in a nice format + */ + void printState(vector &state); + /** + Print the given Real array in a nice format + */ + void printState(vector &state); + +protected: + UInt numInputs_; + UInt numColumns_; + vector columnDimensions_; + vector inputDimensions_; + UInt potentialRadius_; + Real potentialPct_; + Real initConnectedPct_; + bool globalInhibition_; + Int numActiveColumnsPerInhArea_; + Real localAreaDensity_; + UInt stimulusThreshold_; + UInt inhibitionRadius_; + UInt dutyCyclePeriod_; + Real boostStrength_; + UInt iterationNum_; + UInt iterationLearnNum_; + UInt spVerbosity_; + bool wrapAround_; + UInt updatePeriod_; + + Real synPermMin_; + Real synPermMax_; + Real synPermTrimThreshold_; + Real synPermInactiveDec_; + Real synPermActiveInc_; + Real synPermBelowStimulusInc_; + Real synPermConnected_; + + vector boostFactors_; + vector overlapDutyCycles_; + vector activeDutyCycles_; + vector minOverlapDutyCycles_; + vector minActiveDutyCycles_; + + Real minPctOverlapDutyCycles_; + + SparseMatrix permanences_; + SparseBinaryMatrix potentialPools_; + SparseBinaryMatrix connectedSynapses_; + vector connectedCounts_; + + vector overlaps_; + vector overlapsPct_; + vector boostedOverlaps_; + vector activeColumns_; + vector tieBreaker_; + + UInt version_; + Random rng_; +}; + +} // end namespace spatial_pooler +} // end namespace algorithms } // end namespace nupic #endif // NTA_spatial_pooler_HPP diff --git a/src/nupic/algorithms/SvmT.hpp b/src/nupic/algorithms/SvmT.hpp index 9acf505f63..17c59b78b6 100644 --- a/src/nupic/algorithms/SvmT.hpp +++ b/src/nupic/algorithms/SvmT.hpp @@ -60,36 +60,35 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
//-------------------------------------------------------------------------------- template -float Solver::solve(int l, TQ& Q, const signed char *y_, - float *alpha_, float C, float eps, int shrinking) -{ +float Solver::solve(int l, TQ &Q, const signed char *y_, float *alpha_, + float C, float eps, int shrinking) { this->l = l; this->Q = &Q; this->QD = Q.get_QD(); this->C = C; this->eps = eps; unshrinked = false; - - p = new float [l]; + + p = new float[l]; std::fill(p, p + l, float(-1.0)); - y = new signed char [l]; + y = new signed char[l]; std::copy(y_, y_ + l, y); - alpha = new float [l]; + alpha = new float[l]; std::copy(alpha_, alpha_ + l, alpha); // initialize alpha_status { alpha_status = new int[l]; - for(int i=0;i::solve(int l, TQ& Q, const signed char *y_, { G = new float[l]; G_bar = new float[l]; - for(int i=0;i 0); - float delta = (-G[i]-G[j])/quad_coef; - float diff = alpha[i] - alpha[j]; - alpha[i] += delta; - alpha[j] += delta; - - if(diff > 0) - { - if(alpha[j] < 0) - { - alpha[j] = 0; - alpha[i] = diff; - } - } - else - { - if(alpha[i] < 0) - { - alpha[i] = 0; - alpha[j] = -diff; - } - } - if(diff > C_i - C_j) - { - if(alpha[i] > C_i) - { - alpha[i] = C_i; - alpha[j] = C_i - diff; - } - } - else - { - if(alpha[j] > C_j) - { - alpha[j] = C_j; - alpha[i] = C_j + diff; - } - } - } + break; else - { - float quad_coef = Q_i[i]+Q_j[j]-2*Q_i[j]; - if (quad_coef <= 0) - quad_coef = TAU; - NTA_ASSERT(quad_coef > 0); - float delta = (G[i]-G[j])/quad_coef; - float sum = alpha[i] + alpha[j]; - alpha[i] -= delta; - alpha[j] += delta; - - if(sum > C_i) - { - if(alpha[i] > C_i) - { - alpha[i] = C_i; - alpha[j] = sum - C_i; - } - } - else - { - if(alpha[j] < 0) - { - alpha[j] = 0; - alpha[i] = sum; - } - } - if(sum > C_j) - { - if(alpha[j] > C_j) - { - alpha[j] = C_j; - alpha[i] = sum - C_j; - } - } - else - { - if(alpha[i] < 0) - { - alpha[i] = 0; - alpha[j] = sum; - } - } - } - - // update G - float delta_alpha_i = alpha[i] - old_alpha_i; - float delta_alpha_j = alpha[j] - old_alpha_j; - - for(int k=0;k 0); + float delta = (-G[i] - G[j]) / quad_coef; + float diff = alpha[i] - alpha[j]; + alpha[i] += delta; + alpha[j] += delta; + + if (diff > 0) { + if (alpha[j] < 0) { + alpha[j] = 0; + alpha[i] = diff; + } + } else { + if (alpha[i] < 0) { + alpha[i] = 0; + alpha[j] = -diff; + } + } + if (diff > C_i - C_j) { + if (alpha[i] > C_i) { + alpha[i] = C_i; + alpha[j] = C_i - diff; + } + } else { + if (alpha[j] > C_j) { + alpha[j] = C_j; + alpha[i] = C_j + diff; + } + } + } else { + float quad_coef = Q_i[i] + Q_j[j] - 2 * Q_i[j]; + if (quad_coef <= 0) + quad_coef = TAU; + NTA_ASSERT(quad_coef > 0); + float delta = (G[i] - G[j]) / quad_coef; + float sum = alpha[i] + alpha[j]; + alpha[i] -= delta; + alpha[j] += delta; + + if (sum > C_i) { + if (alpha[i] > C_i) { + alpha[i] = C_i; + alpha[j] = sum - C_i; + } + } else { + if (alpha[j] < 0) { + alpha[j] = 0; + alpha[i] = sum; + } + } + if (sum > C_j) { + if (alpha[j] > C_j) { + alpha[j] = C_j; + alpha[i] = sum - C_j; + } + } else { + if (alpha[i] < 0) { + alpha[i] = 0; + alpha[j] = sum; + } } + } - // update alpha_status and G_bar - { - bool ui = is_upper_bound(i); - bool uj = is_upper_bound(j); - update_alpha_status(i); - update_alpha_status(j); - - if(ui != is_upper_bound(i)) - { - Q_i = Q.get_Q(i,l); - if(ui) - for(int k=0;k::solve(int l, TQ& Q, const signed char *y_, } //-------------------------------------------------------------------------------- -template -void Solver::swap_index(int i, int j) -{ - Q->swap_index(i,j); - 
std::swap(y[i],y[j]); - std::swap(G[i],G[j]); - std::swap(alpha_status[i],alpha_status[j]); - std::swap(alpha[i],alpha[j]); - std::swap(p[i],p[j]); - std::swap(active_set[i],active_set[j]); - std::swap(G_bar[i],G_bar[j]); +template void Solver::swap_index(int i, int j) { + Q->swap_index(i, j); + std::swap(y[i], y[j]); + std::swap(G[i], G[j]); + std::swap(alpha_status[i], alpha_status[j]); + std::swap(alpha[i], alpha[j]); + std::swap(p[i], p[j]); + std::swap(active_set[i], active_set[j]); + std::swap(G_bar[i], G_bar[j]); } //-------------------------------------------------------------------------------- -template -void Solver::reconstruct_gradient() -{ +template void Solver::reconstruct_gradient() { // reconstruct inactive elements of G from G_bar and free variables - if(active_size == l) return; + if (active_size == l) + return; - for(int i=active_size;iget_Q(i,l); - float alpha_i = alpha[i]; - for(int j=active_size;jget_Q(i, l); + float alpha_i = alpha[i]; + for (int j = active_size; j < l; j++) + G[j] += alpha_i * Q_i[j]; + } } //-------------------------------------------------------------------------------- // return 1 if already optimal, return 0 otherwise template -int Solver::select_working_set(int &out_i, int &out_j) -{ +int Solver::select_working_set(int &out_i, int &out_j) { // return i,j such that // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha) // j: minimizes the decrease of obj value // (if quadratic coefficeint <= 0, replace it with tau) // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha) - - float Gmax = -HUGE_VAL; //std::numeric_limits::max(); - float Gmax2 = -HUGE_VAL; //std::numeric_limits::max(); + + float Gmax = -HUGE_VAL; // std::numeric_limits::max(); + float Gmax2 = -HUGE_VAL; // std::numeric_limits::max(); int Gmax_idx = -1; int Gmin_idx = -1; - float obj_diff_min = HUGE_VAL; //std::numeric_limits::max(); - - for (int t=0;t= Gmax) - { - Gmax = -G[t]; - Gmax_idx = t; - } - } - else - { - if (!is_lower_bound(t)) - if (G[t] >= Gmax) - { - Gmax = G[t]; - Gmax_idx = t; - } - } + float obj_diff_min = HUGE_VAL; // std::numeric_limits::max(); + + for (int t = 0; t < active_size; t++) { + + if (y[t] == +1) { + if (!is_upper_bound(t)) + if (-G[t] >= Gmax) { + Gmax = -G[t]; + Gmax_idx = t; + } + } else { + if (!is_lower_bound(t)) + if (G[t] >= Gmax) { + Gmax = G[t]; + Gmax_idx = t; + } + } } int i = Gmax_idx; const float *Q_i = nullptr; if (i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1 - Q_i = Q->get_Q(i,active_size); + Q_i = Q->get_Q(i, active_size); NTA_ASSERT(0 <= i); - for(int j=0;j= Gmax2) - Gmax2 = G[j]; - - if (grad_diff > 0) - { - float obj_diff; - float quad_coef=Q_i[i]+QD[j]-2*y[i]*Q_i[j]; - - if (quad_coef > 0) - obj_diff = -(grad_diff*grad_diff)/quad_coef; - else - obj_diff = -(grad_diff*grad_diff)/TAU; - - if (obj_diff <= obj_diff_min) - { - Gmin_idx=j; - obj_diff_min = obj_diff; - } - } - } - } - else - { - if (!is_upper_bound(j)) - { - float grad_diff= Gmax - G[j]; - - if (-G[j] >= Gmax2) - Gmax2 = -G[j]; - - if (grad_diff > 0) - { - float obj_diff; - float quad_coef = Q_i[i]+QD[j]+2*y[i]*Q_i[j]; - - if (quad_coef > 0) - obj_diff = -(grad_diff*grad_diff)/quad_coef; - else - obj_diff = -(grad_diff*grad_diff)/TAU; - - if (obj_diff <= obj_diff_min) - { - Gmin_idx = j; - obj_diff_min = obj_diff; - } - } - } - } + for (int j = 0; j < active_size; j++) { + if (y[j] == +1) { + if (!is_lower_bound(j)) { + float grad_diff = Gmax + G[j]; + + if (G[j] >= Gmax2) + Gmax2 = G[j]; + + if (grad_diff > 0) { + float obj_diff; + float quad_coef = Q_i[i] + QD[j] - 2 
* y[i] * Q_i[j]; + + if (quad_coef > 0) + obj_diff = -(grad_diff * grad_diff) / quad_coef; + else + obj_diff = -(grad_diff * grad_diff) / TAU; + + if (obj_diff <= obj_diff_min) { + Gmin_idx = j; + obj_diff_min = obj_diff; + } + } + } + } else { + if (!is_upper_bound(j)) { + float grad_diff = Gmax - G[j]; + + if (-G[j] >= Gmax2) + Gmax2 = -G[j]; + + if (grad_diff > 0) { + float obj_diff; + float quad_coef = Q_i[i] + QD[j] + 2 * y[i] * Q_i[j]; + + if (quad_coef > 0) + obj_diff = -(grad_diff * grad_diff) / quad_coef; + else + obj_diff = -(grad_diff * grad_diff) / TAU; + + if (obj_diff <= obj_diff_min) { + Gmin_idx = j; + obj_diff_min = obj_diff; + } + } + } } + } if (Gmax + Gmax2 < eps) return 1; @@ -461,140 +410,112 @@ int Solver::select_working_set(int &out_i, int &out_j) //-------------------------------------------------------------------------------- template -bool Solver::be_shrunken(int i, float Gmax1, float Gmax2) -{ - if(is_upper_bound(i)) - { - if(y[i]==+1) - return(-G[i] > Gmax1); - else - return(-G[i] > Gmax2); - } - else if(is_lower_bound(i)) - { - if(y[i]==+1) - return(G[i] > Gmax2); - else - return(G[i] > Gmax1); - } - else - return(false); +bool Solver::be_shrunken(int i, float Gmax1, float Gmax2) { + if (is_upper_bound(i)) { + if (y[i] == +1) + return (-G[i] > Gmax1); + else + return (-G[i] > Gmax2); + } else if (is_lower_bound(i)) { + if (y[i] == +1) + return (G[i] > Gmax2); + else + return (G[i] > Gmax1); + } else + return (false); } //-------------------------------------------------------------------------------- -template -void Solver::do_shrinking() -{ - float Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) } - float Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) } +template void Solver::do_shrinking() { + float Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) } + float Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) } // find maximal violating pair first - for(int i=0;i= Gmax1) - Gmax1 = -G[i]; - } - if(!is_lower_bound(i)) - { - if(G[i] >= Gmax2) - Gmax2 = G[i]; - } - } - else - { - if(!is_upper_bound(i)) - { - if(-G[i] >= Gmax2) - Gmax2 = -G[i]; - } - if(!is_lower_bound(i)) - { - if(G[i] >= Gmax1) - Gmax1 = G[i]; - } - } + for (int i = 0; i < active_size; i++) { + if (y[i] == +1) { + if (!is_upper_bound(i)) { + if (-G[i] >= Gmax1) + Gmax1 = -G[i]; + } + if (!is_lower_bound(i)) { + if (G[i] >= Gmax2) + Gmax2 = G[i]; + } + } else { + if (!is_upper_bound(i)) { + if (-G[i] >= Gmax2) + Gmax2 = -G[i]; + } + if (!is_lower_bound(i)) { + if (G[i] >= Gmax1) + Gmax1 = G[i]; + } } + } // shrink - for(int i=0;i i) - { - if (!be_shrunken(active_size, Gmax1, Gmax2)) - { - swap_index(i,active_size); - break; - } - active_size--; - } + for (int i = 0; i < active_size; i++) + if (be_shrunken(i, Gmax1, Gmax2)) { + active_size--; + while (active_size > i) { + if (!be_shrunken(active_size, Gmax1, Gmax2)) { + swap_index(i, active_size); + break; + } + active_size--; } + } // unshrink, check all variables again before final iterations - if(unshrinked || Gmax1 + Gmax2 > eps*10) return; - + if (unshrinked || Gmax1 + Gmax2 > eps * 10) + return; + unshrinked = true; reconstruct_gradient(); - for(int i=l-1;i>=active_size;i--) - if (!be_shrunken(i, Gmax1, Gmax2)) - { - while (active_size < i) - { - if (be_shrunken(active_size, Gmax1, Gmax2)) - { - swap_index(i,active_size); - break; - } - active_size++; - } - active_size++; + for (int i = l - 1; i >= active_size; i--) + if (!be_shrunken(i, Gmax1, Gmax2)) { + while (active_size < i) { + if 
(be_shrunken(active_size, Gmax1, Gmax2)) { + swap_index(i, active_size); + break; + } + active_size++; } + active_size++; + } } //-------------------------------------------------------------------------------- -template -float Solver::calculate_rho() -{ +template float Solver::calculate_rho() { float r; int nr_free = 0; float ub = INF, lb = -INF, sum_free = 0; - for(int i=0;i0) - r = sum_free/nr_free; + if (nr_free > 0) + r = sum_free / nr_free; else - r = (ub+lb)/2; + r = (ub + lb) / 2; return r; } @@ -604,133 +525,126 @@ float Solver::calculate_rho() //-------------------------------------------------------------------------------- // Platt's binary SVM Probablistic Output: an improvement from Lin et al. template -void svm::sigmoid_train(int l, - const Vector& dec_values, - const Vector& labels, - float& A, float& B) -{ - float prior1=0, prior0 = 0; - - for (int i=0;i 0) - prior1+=1; - else - prior0+=1; - - int max_iter=100; // Maximal number of iterations - float min_step=float(1e-10); // Minimal step taken in line search - float sigma=float(1e-3); // For numerically strict PD of Hessian - float eps=float(1e-5); - float hiTarget=(prior1+float(1.0))/(prior1+float(2.0)); - float loTarget=float(1.0)/(prior0+float(2.0)); +void svm::sigmoid_train(int l, const Vector &dec_values, + const Vector &labels, float &A, float &B) { + float prior1 = 0, prior0 = 0; + + for (int i = 0; i < l; i++) + if (labels[i] > 0) + prior1 += 1; + else + prior0 += 1; + + int max_iter = 100; // Maximal number of iterations + float min_step = float(1e-10); // Minimal step taken in line search + float sigma = float(1e-3); // For numerically strict PD of Hessian + float eps = float(1e-5); + float hiTarget = (prior1 + float(1.0)) / (prior1 + float(2.0)); + float loTarget = float(1.0) / (prior0 + float(2.0)); Vector t(l); - float fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize; - float newA,newB,newf,d1,d2; - + float fApB, p, q, h11, h22, h21, g1, g2, det, dA, dB, gd, stepsize; + float newA, newB, newf, d1, d2; + // Initial Point and Initial Fun Value - A=0.0; B=log((prior0+float(1.0))/(prior1+float(1.0))); + A = 0.0; + B = log((prior0 + float(1.0)) / (prior1 + float(1.0))); float fval = 0.0; - for (int i=0;i0) t[i]=hiTarget; - else t[i]=loTarget; - fApB = dec_values[i]*A+B; - if (fApB>=0) - fval += t[i]*fApB + log(float(1.0)+exp(-fApB)); - else - fval += (t[i] - float(1.0))*fApB +log(float(1.0)+exp(fApB)); + for (int i = 0; i < l; i++) { + if (labels[i] > 0) + t[i] = hiTarget; + else + t[i] = loTarget; + fApB = dec_values[i] * A + B; + if (fApB >= 0) + fval += t[i] * fApB + log(float(1.0) + exp(-fApB)); + else + fval += (t[i] - float(1.0)) * fApB + log(float(1.0) + exp(fApB)); + } + + for (int iter = 0; iter < max_iter; iter++) { + // Update Gradient and Hessian (use H' = H + sigma I) + h11 = sigma; // numerically ensures strict PD + h22 = sigma; + h21 = 0.0f; + g1 = 0.0f; + g2 = 0.0f; + for (int i = 0; i < l; i++) { + fApB = dec_values[i] * A + B; + if (fApB >= 0) { + p = exp(-fApB) / (1.0f + exp(-fApB)); + q = 1.0f / (1.0f + exp(-fApB)); + } else { + p = 1.0f / (1.0f + exp(fApB)); + q = exp(fApB) / (1.0f + exp(fApB)); + } + d2 = p * q; + h11 += dec_values[i] * dec_values[i] * d2; + h22 += d2; + h21 += dec_values[i] * d2; + d1 = t[i] - p; + g1 += dec_values[i] * d1; + g2 += d1; } - - for (int iter=0;iter= 0) - { - p=exp(-fApB)/(1.0f+exp(-fApB)); - q=1.0f/(1.0f+exp(-fApB)); - } - else - { - p=1.0f/(1.0f+exp(fApB)); - q=exp(fApB)/(1.0f+exp(fApB)); - } - d2=p*q; - h11+=dec_values[i]*dec_values[i]*d2; - h22+=d2; - 
h21+=dec_values[i]*d2; - d1=t[i]-p; - g1+=dec_values[i]*d1; - g2+=d1; - } - - // Stopping Criteria - if (fabs(g1)= min_step) - { - newA = A + stepsize * dA; - newB = B + stepsize * dB; - - // New function value - newf = 0.0; - for (int i=0;i= 0) - newf += t[i]*fApB + log(1.0f+exp(-fApB)); - else - newf += (t[i] - 1.0f)*fApB +log(1.0f+exp(fApB)); - } - // Check sufficient decrease - if (newf= min_step) { + newA = A + stepsize * dA; + newB = B + stepsize * dB; + + // New function value + newf = 0.0; + for (int i = 0; i < l; i++) { + fApB = dec_values[i] * newA + newB; + if (fApB >= 0) + newf += t[i] * fApB + log(1.0f + exp(-fApB)); + else + newf += (t[i] - 1.0f) * fApB + log(1.0f + exp(fApB)); + } + // Check sufficient decrease + if (newf < fval + 0.0001f * stepsize * gd) { + A = newA; + B = newB; + fval = newf; + break; + } else + stepsize = stepsize / 2.0f; } + + if (stepsize < min_step) { + break; + } + } } //-------------------------------------------------------------------------------- template -inline float svm::sigmoid_predict(float decision_value, float A, float B) -{ - float fApB = decision_value*A+B; +inline float svm::sigmoid_predict(float decision_value, float A, + float B) { + float fApB = decision_value * A + B; if (fApB >= 0) - return exp(-fApB)/(1.0f+exp(-fApB)); + return exp(-fApB) / (1.0f + exp(-fApB)); else - return 1.0f/(1.0f+exp(fApB)); + return 1.0f / (1.0f + exp(fApB)); } //-------------------------------------------------------------------------------- template -inline float svm::rbf_function(float* x, float* x_end, float* y) const -{ +inline float svm::rbf_function(float *x, float *x_end, float *y) const { float sum = 0; -#if defined(NTA_ASM) && defined(NTA_ARCH_32) && defined(NTA_OS_WINDOWS) && defined(NTA_COMPILER_MSVC) +#if defined(NTA_ASM) && defined(NTA_ARCH_32) && defined(NTA_OS_WINDOWS) && \ + defined(NTA_COMPILER_MSVC) if (with_sse) { @@ -763,57 +677,57 @@ inline float svm::rbf_function(float* x, float* x_end, float* y) const haddps xmm1, xmm4 haddps xmm1, xmm4 movss sum, xmm1 - } + } } else { // no sse while (x != x_end) { float d = *x - *y; sum += d * d; - ++x; ++y; + ++x; + ++y; } } #elif defined(NTA_ASM) && defined(NTA_ARCH_32) && defined(NTA_OS_DARWIN) if (with_sse) { - - asm( - "xorps %%xmm4,%%xmm4\n\t" // only contains zeros on purpose - "xorps %%xmm1,%%xmm1\n\t" - "xorps %%xmm3,%%xmm3\n\t" - - "0:\t\n" - "movaps (%%esi), %%xmm0\n\t" - "movaps 16(%%esi), %%xmm2\n\t" - "subps (%%edi), %%xmm0\n\t" - "subps 16(%%edi), %%xmm2\n\t" - "mulps %%xmm0, %%xmm0\n\t" - "mulps %%xmm2, %%xmm2\n\t" - "addps %%xmm0, %%xmm1\n\t" - "addps %%xmm2, %%xmm3\n\t" - - "addl $32, %%esi\n\t" - "addl $32, %%edi\n\t" - "cmpl %1, %%esi\n\t" - "jne 0b\n\t" - - "addps %%xmm3, %%xmm1\n\t" - "haddps %%xmm4, %%xmm1\n\t" - "haddps %%xmm4, %%xmm1\n\t" - "movss %%xmm1, %0\n\t" - - : "=m" (sum) - : "m" (x_end), "S" (x), "D" (y) - : - ); + + asm("xorps %%xmm4,%%xmm4\n\t" // only contains zeros on purpose + "xorps %%xmm1,%%xmm1\n\t" + "xorps %%xmm3,%%xmm3\n\t" + + "0:\t\n" + "movaps (%%esi), %%xmm0\n\t" + "movaps 16(%%esi), %%xmm2\n\t" + "subps (%%edi), %%xmm0\n\t" + "subps 16(%%edi), %%xmm2\n\t" + "mulps %%xmm0, %%xmm0\n\t" + "mulps %%xmm2, %%xmm2\n\t" + "addps %%xmm0, %%xmm1\n\t" + "addps %%xmm2, %%xmm3\n\t" + + "addl $32, %%esi\n\t" + "addl $32, %%edi\n\t" + "cmpl %1, %%esi\n\t" + "jne 0b\n\t" + + "addps %%xmm3, %%xmm1\n\t" + "haddps %%xmm4, %%xmm1\n\t" + "haddps %%xmm4, %%xmm1\n\t" + "movss %%xmm1, %0\n\t" + + : "=m"(sum) + : "m"(x_end), "S"(x), "D"(y) + :); } else { // no sse while (x 
!= x_end) { float d = *x - *y; sum += d * d; - ++x; ++y; + ++x; + ++y; } } @@ -822,190 +736,186 @@ inline float svm::rbf_function(float* x, float* x_end, float* y) const while (x != x_end) { float d = *x - *y; sum += d * d; - ++x; ++y; + ++x; + ++y; } #endif - - return exp(-param_.gamma*sum); + + return exp(-param_.gamma * sum); } //-------------------------------------------------------------------------------- template -inline float svm::linear_function(float* x, float* x_end, float* y) const -{ +inline float svm::linear_function(float *x, float *x_end, + float *y) const { float sum = 0; while (x != x_end) { sum += *x * *y; - ++x; ++y; + ++x; + ++y; } - + return sum; } //-------------------------------------------------------------------------------- template -inline void -svm::multiclass_probability(Matrix& pairwise_proba, Vector& prob_estimates) -{ +inline void svm::multiclass_probability(Matrix &pairwise_proba, + Vector &prob_estimates) { int n_class = pairwise_proba.nrows(), max_iter = std::max(100, n_class); Matrix Q(n_class, n_class); Vector Qp(n_class); - float pQp, eps = float(0.005)/float(n_class); + float pQp, eps = float(0.005) / float(n_class); for (int t = 0; t < n_class; ++t) { - prob_estimates[t] = float(1.0)/float(n_class); // Valid if n_class = 1 - Q(t,t)=0; + prob_estimates[t] = float(1.0) / float(n_class); // Valid if n_class = 1 + Q(t, t) = 0; - for (int j=0;jmax_error) - max_error=error; + float max_error = 0; + for (int t = 0; t < n_class; t++) { + float error = fabs(Qp[t] - pQp); + if (error > max_error) + max_error = error; } - if (max_error -void -svm::binary_probability(const problem_type& prob, float& probA, float& probB) -{ +void svm::binary_probability(const problem_type &prob, float &probA, + float &probB) { int nr_fold = 5, l = prob.size(), n_dims = prob.n_dims(); std::vector perm(l); Vector dec_values(l); // random shuffle - for(int i=0;i0) - p_count++; + int p_count = 0, n_count = 0; + for (int j = 0; j < k; j++) + if (sub_prob.y_[j] > 0) + p_count++; else - n_count++; - - if(p_count==0 && n_count==0) - for(int j=begin;j 0 && n_count == 0) - for(int j=begin;j 0) - for(int j=begin;j 0 && n_count == 0) + for (int j = begin; j < end; j++) + dec_values[perm[j]] = 1; + else if (p_count == 0 && n_count > 0) + for (int j = begin; j < end; j++) + dec_values[perm[j]] = -1; else { - svm_parameter sub_param(param_.kernel, - false, - param_.gamma, - 1.0, //param_.C, HERE - param_.eps, - param_.cache_size, - param_.shrinking); + svm_parameter sub_param(param_.kernel, false, param_.gamma, + 1.0, // param_.C, HERE + param_.eps, param_.cache_size, param_.shrinking); sub_param.weight_label.resize(2); sub_param.weight.resize(2); - sub_param.weight_label[0]=+1; - sub_param.weight_label[1]=-1; - sub_param.weight[0]=param_.C; - sub_param.weight[1]=param_.C; + sub_param.weight_label[0] = +1; + sub_param.weight_label[1] = -1; + sub_param.weight[0] = param_.C; + sub_param.weight[1] = param_.C; svm_model *sub_model = train(sub_prob, sub_param); #if defined(NTA_OS_WINDOWS) && defined(NTA_COMPILER_MSVC) - float* x_tmp = (float*) _aligned_malloc(4*prob.n_dims(), 16); + float *x_tmp = (float *)_aligned_malloc(4 * prob.n_dims(), 16); #else auto x_tmp = new float[prob.n_dims()]; #endif - for(int j=begin;jlabel[0]; - } + for (int j = begin; j < end; j++) { + prob.dense(perm[j], x_tmp); + float val; + predict_values(*sub_model, x_tmp, &val); + // ensure +1 -1 order; reason not using CV subroutine + dec_values[perm[j]] = val * sub_model->label[0]; + } #if defined(NTA_OS_WINDOWS) && 
defined(NTA_COMPILER_MSVC) _aligned_free(x_tmp); #else - delete [] x_tmp; + delete[] x_tmp; #endif delete sub_model; } - } - + } + sigmoid_train(l, dec_values, prob.y_, probA, probB); } //-------------------------------------------------------------------------------- template -void svm::group_classes(const problem_type& prob, - std::vector& label, - std::vector& start, - std::vector& count, - std::vector& perm) -{ +void svm::group_classes(const problem_type &prob, + std::vector &label, + std::vector &start, + std::vector &count, + std::vector &perm) { int l = prob.size(), n_class = 0; std::vector data_label(l); @@ -1016,10 +926,10 @@ void svm::group_classes(const problem_type& prob, for (j = 0; j < n_class; ++j) if (this_label == label[j]) { - ++count[j]; - break; + ++count[j]; + break; } - + data_label[i] = j; if (j == n_class) { @@ -1030,10 +940,10 @@ void svm::group_classes(const problem_type& prob, } start.resize(n_class); - + start[0] = 0; for (int i = 1; i < n_class; ++i) - start[i] = start[i-1]+count[i-1]; + start[i] = start[i - 1] + count[i - 1]; for (int i = 0; i < l; ++i) { perm[start[data_label[i]]] = i; @@ -1042,22 +952,22 @@ void svm::group_classes(const problem_type& prob, start[0] = 0; for (int i = 1; i < n_class; ++i) - start[i] = start[i-1]+count[i-1]; + start[i] = start[i - 1] + count[i - 1]; } //-------------------------------------------------------------------------------- template -svm_model* svm::train(const problem_type& prob, const svm_parameter& param) -{ +svm_model *svm::train(const problem_type &prob, + const svm_parameter ¶m) { int l = prob.size(), n_dims = prob.n_dims(); std::vector label, count, start, perm(l); // svm_train group_classes(prob, label, start, count, perm); - int n_class = (int) label.size(); + int n_class = (int)label.size(); // train k*(k-1)/2 models - size_t m = n_class*(n_class-1)/2; + size_t m = n_class * (n_class - 1) / 2; std::vector nonzero(l, false); std::vector f(m); @@ -1070,68 +980,68 @@ svm_model* svm::train(const problem_type& prob, const svm_parameter& par int p = 0; for (int i = 0; i < n_class; ++i) { - for (int j = i+1; j < n_class; ++j, ++p) { - + for (int j = i + 1; j < n_class; ++j, ++p) { + int si = start[i], sj = start[j]; int ci = count[i], cj = count[j]; - int sub_prob_size = ci+cj; - + int sub_prob_size = ci + cj; + problem_type sub_prob(n_dims, sub_prob_size, false); for (int k = 0; k < ci; ++k) { - sub_prob.set_sample(k, prob.get_sample(perm[si+k])); - sub_prob.y_[k] = +1; + sub_prob.set_sample(k, prob.get_sample(perm[si + k])); + sub_prob.y_[k] = +1; } - + for (int k = 0; k < cj; ++k) { - sub_prob.set_sample(ci+k, prob.get_sample(perm[sj+k])); - sub_prob.y_[ci+k] = -1; + sub_prob.set_sample(ci + k, prob.get_sample(perm[sj + k])); + sub_prob.y_[ci + k] = -1; } // binary svc probability if (param.probability) - binary_probability(sub_prob, model->probA[p], model->probB[p]); + binary_probability(sub_prob, model->probA[p], model->probB[p]); // solve_c_svc - auto alpha = new float [sub_prob_size]; + auto alpha = new float[sub_prob_size]; std::fill(alpha, alpha + sub_prob_size, float(0)); auto y = new signed char[l]; - for (int k = 0; k < sub_prob_size; ++k) - y[k] = sub_prob.y_[k] > 0 ? +1 : -1; + for (int k = 0; k < sub_prob_size; ++k) + y[k] = sub_prob.y_[k] > 0 ? 
+1 : -1; - q_matrix_type q(sub_prob, param.gamma, param.kernel, param.cache_size); + q_matrix_type q(sub_prob, param.gamma, param.kernel, param.cache_size); Solver s; - //param.print(); - //sub_prob.print(); - - float rho = - s.solve(sub_prob_size, q, y, alpha, param.C, param.eps, param.shrinking); - + // param.print(); + // sub_prob.print(); + + float rho = s.solve(sub_prob_size, q, y, alpha, param.C, param.eps, + param.shrinking); + for (int k = 0; k < sub_prob_size; ++k) - alpha[k] *= y[k]; - + alpha[k] *= y[k]; + f[p].alpha = alpha; f[p].rho = rho; - + for (int k = 0; k < ci; ++k) - if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0) - nonzero[si+k] = true; - + if (!nonzero[si + k] && fabs(f[p].alpha[k]) > 0) + nonzero[si + k] = true; + for (int k = 0; k < cj; ++k) - if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0) - nonzero[sj+k] = true; + if (!nonzero[sj + k] && fabs(f[p].alpha[ci + k]) > 0) + nonzero[sj + k] = true; - delete [] y; - } + delete[] y; + } } // finish building model model->label.resize(n_class); for (int i = 0; i < n_class; ++i) model->label[i] = label[i]; - + model->rho.resize(m); for (size_t i = 0; i < m; ++i) model->rho[i] = f[i].rho; @@ -1139,25 +1049,25 @@ svm_model* svm::train(const problem_type& prob, const svm_parameter& par int total_sv = 0; std::vector nz_count(n_class); model->n_sv.resize(n_class); - + for (int i = 0; i < n_class; ++i) { int n_sv = 0; - for(int j=0;jn_sv[i] = n_sv; nz_count[i] = n_sv; } - + model->n_dims_ = n_dims; for (int i = 0; i != l; ++i) if (nonzero[i]) { #if defined(NTA_OS_WINDOWS) && defined(NTA_COMPILER_MSVC) - float* new_sv = (float*) _aligned_malloc(4*n_dims, 16); + float *new_sv = (float *)_aligned_malloc(4 * n_dims, 16); #else auto new_sv = new float[n_dims]; #endif @@ -1169,15 +1079,15 @@ svm_model* svm::train(const problem_type& prob, const svm_parameter& par std::vector nz_start(n_class); nz_start[0] = 0; for (int i = 1; i < n_class; ++i) - nz_start[i] = nz_start[i-1] + nz_count[i-1]; + nz_start[i] = nz_start[i - 1] + nz_count[i - 1]; - model->sv_coef.resize(n_class-1); - for (int i = 0; i < n_class-1; ++i) - model->sv_coef[i] = new float [total_sv]; + model->sv_coef.resize(n_class - 1); + for (int i = 0; i < n_class - 1; ++i) + model->sv_coef[i] = new float[total_sv]; p = 0; for (int i = 0; i < n_class; ++i) { - for(int j = i+1; j < n_class; ++j, ++p) { + for (int j = i + 1; j < n_class; ++j, ++p) { // classifier (i,j): coefficients with // i are in sv_coef[j-1][nz_start[i]...], @@ -1185,16 +1095,16 @@ svm_model* svm::train(const problem_type& prob, const svm_parameter& par int si = start[i], sj = start[j]; int ci = count[i], cj = count[j]; - + int q = nz_start[i]; for (int k = 0; k < ci; ++k) - if (nonzero[si+k]) - model->sv_coef[j-1][q++] = f[p].alpha[k]; + if (nonzero[si + k]) + model->sv_coef[j - 1][q++] = f[p].alpha[k]; q = nz_start[j]; for (int k = 0; k < cj; ++k) - if (nonzero[sj+k]) - model->sv_coef[i][q++] = f[p].alpha[ci+k]; + if (nonzero[sj + k]) + model->sv_coef[i][q++] = f[p].alpha[ci + k]; } } @@ -1205,31 +1115,31 @@ svm_model* svm::train(const problem_type& prob, const svm_parameter& par if (param.kernel == 0) { // Linear kernel only model->w.resize(m); - for (size_t i = 0; i != m; ++i) + for (size_t i = 0; i != m; ++i) model->w[i].resize(n_dims); p = 0; - + for (int i = 0; i < n_class; ++i) { - for(int j = i+1; j < n_class; ++j, ++p) { - - int si = nz_start[i], sj = nz_start[j]; - int ci = model->n_sv[i], cj = model->n_sv[j]; - float *coef1 = model->sv_coef[j-1], *coef2 = model->sv_coef[i]; + for (int j = i + 1; j < 
n_class; ++j, ++p) { + + int si = nz_start[i], sj = nz_start[j]; + int ci = model->n_sv[i], cj = model->n_sv[j]; + float *coef1 = model->sv_coef[j - 1], *coef2 = model->sv_coef[i]; - for (int dim = 0; dim != n_dims; ++dim) { + for (int dim = 0; dim != n_dims; ++dim) { - float sum = 0; - for (int k = 0; k < ci; ++k) { - sum += coef1[si+k] * (model->sv[si+k])[dim]; - } + float sum = 0; + for (int k = 0; k < ci; ++k) { + sum += coef1[si + k] * (model->sv[si + k])[dim]; + } - for (int k = 0; k < cj; ++k) { - sum += coef2[sj+k] * (model->sv[sj+k])[dim]; - } + for (int k = 0; k < cj; ++k) { + sum += coef2[sj + k] * (model->sv[sj + k])[dim]; + } - model->w[p][dim] = sum; - } + model->w[p][dim] = sum; + } } } } @@ -1238,67 +1148,66 @@ svm_model* svm::train(const problem_type& prob, const svm_parameter& par } //-------------------------------------------------------------------------------- -template -void svm::predict_values(const svm_model& model, float* x, float* dec_values) -{ +template +void svm::predict_values(const svm_model &model, float *x, + float *dec_values) { int n_class = model.n_class(), l = model.size(); Vector kvalue(l); if (param_.kernel == 0) { - for (int i=0;i start(n_class); start[0] = 0; - for(int i=1;i template -float svm::predict(const svm_model& model, InIter x) -{ +template +template +float svm::predict(const svm_model &model, InIter x) { int n_class = model.n_class(), n_dims = model.n_dims(); if (dec_values_ == nullptr) { - dec_values_ = new float [n_class*(n_class-1)/2]; + dec_values_ = new float[n_class * (n_class - 1) / 2]; #if defined(NTA_OS_WINDOWS) && defined(NTA_COMPILER_MSVC) - x_tmp_ = (float*) _aligned_malloc(4*n_dims, 16); + x_tmp_ = (float *)_aligned_malloc(4 * n_dims, 16); #else - x_tmp_ = new float [n_dims]; + x_tmp_ = new float[n_dims]; #endif - } std::copy(x, x + n_dims, x_tmp_); @@ -1306,42 +1215,41 @@ float svm::predict(const svm_model& model, InIter x) predict_values(model, x_tmp_, dec_values_); std::vector vote(n_class, 0); - - int pos=0; - for(int i=0;i 0) - ++vote[i]; - else - ++vote[j]; - } - + + int pos = 0; + for (int i = 0; i < n_class; i++) + for (int j = i + 1; j < n_class; j++) { + if (dec_values_[pos++] > 0) + ++vote[i]; + else + ++vote[j]; + } + int vote_max_idx = 0; - for(int i=1;i vote[vote_max_idx]) + for (int i = 1; i < n_class; i++) + if (vote[i] > vote[vote_max_idx]) vote_max_idx = i; - - return (float) model.label[vote_max_idx]; + + return (float)model.label[vote_max_idx]; } //-------------------------------------------------------------------------------- // model is the same between predict and predict_probability // predict_values comes out the same in predict and predict_probability -template template -float svm::predict_probability(const svm_model& model, InIter x, OutIter proba) -{ +template +template +float svm::predict_probability(const svm_model &model, InIter x, + OutIter proba) { int n_class = model.n_class(), n_dims = model.n_dims(); if (dec_values_ == nullptr) { - dec_values_ = new float [n_class*(n_class-1)/2]; + dec_values_ = new float[n_class * (n_class - 1) / 2]; #if defined(NTA_OS_WINDOWS) && defined(NTA_COMPILER_MSVC) - x_tmp_ = (float*) _aligned_malloc(4*n_dims, 16); + x_tmp_ = (float *)_aligned_malloc(4 * n_dims, 16); #else - x_tmp_ = new float [n_dims]; + x_tmp_ = new float[n_dims]; #endif - } std::copy(x, x + n_dims, x_tmp_); @@ -1349,17 +1257,18 @@ float svm::predict_probability(const svm_model& model, InIter x, OutIter if (param_.probability) { predict_values(model, x_tmp_, dec_values_); - + float min_prob = 
float(1e-7); Matrix pairwise_proba(n_class, n_class); int k = 0; for (int i = 0; i < n_class; ++i) { - pairwise_proba(i,i) = 0; - for (int j = i+1; j < n_class; ++j, ++k) { - float v = sigmoid_predict(dec_values_[k], model.probA[k], model.probB[k]); - pairwise_proba(i,j) = std::min(std::max(v, min_prob), 1-min_prob); - pairwise_proba(j,i) = 1-pairwise_proba(i,j); + pairwise_proba(i, i) = 0; + for (int j = i + 1; j < n_class; ++j, ++k) { + float v = + sigmoid_predict(dec_values_[k], model.probA[k], model.probB[k]); + pairwise_proba(i, j) = std::min(std::max(v, min_prob), 1 - min_prob); + pairwise_proba(j, i) = 1 - pairwise_proba(i, j); } } @@ -1370,21 +1279,19 @@ float svm::predict_probability(const svm_model& model, InIter x, OutIter int prob_max_idx = 0; for (int i = 0; i < n_class; ++i) if (proba_estimates[i] > proba_estimates[prob_max_idx]) - prob_max_idx = i; - - return (float) model.label[prob_max_idx]; - + prob_max_idx = i; + + return (float)model.label[prob_max_idx]; + } else { return predict(model, x); } } //-------------------------------------------------------------------------------- -template -float svm::cross_validation(int nr_fold) -{ +template float svm::cross_validation(int nr_fold) { int l = problem_->size(); - std::vector fold_start(nr_fold+1), perm(l); + std::vector fold_start(nr_fold + 1), perm(l); // stratified cv may not give leave-one-out rate // Each class to l folds -> some folds may have zero elements @@ -1392,74 +1299,69 @@ float svm::cross_validation(int nr_fold) std::vector start, label, count; group_classes(*problem_, label, start, count, perm); - int n_class = (int) label.size(); + int n_class = (int)label.size(); // random shuffle and then data grouped by fold using the array perm std::vector fold_count(nr_fold), index(l); - - for(int i=0;in_dims(), false); - + if ((end - begin) != l) { - sub_prob.resize(l-(end-begin)); - - int k=0; - for(int j=0;jget_sample(perm[j])); - - for(int j=end;jget_sample(perm[j])); - + sub_prob.resize(l - (end - begin)); + + int k = 0; + for (int j = 0; j < begin; j++, ++k) + sub_prob.set_sample(k, problem_->get_sample(perm[j])); + + for (int j = end; j < l; j++, ++k) + sub_prob.set_sample(k, problem_->get_sample(perm[j])); + } else { sub_prob.resize(l); @@ -1467,44 +1369,43 @@ float svm::cross_validation(int nr_fold) // In the case where this only one fold, the sub problem // becomes the whole problem for (int j = 0; j < l; ++j) - sub_prob.set_sample(j, problem_->get_sample(perm[j])); + sub_prob.set_sample(j, problem_->get_sample(perm[j])); } - + svm_model *sub_model = train(sub_prob, param_); auto x_tmp = new float[problem_->n_dims()]; if (param_.probability) { - + std::vector proba_estimates(sub_model->n_class()); - for(int j=begin;jdense(perm[j], x_tmp); - float p = predict_probability(*sub_model, x_tmp, proba_estimates.begin()); - if (p == problem_->y_[perm[j]]) - success += 1.0; + for (int j = begin; j < end; j++) { + problem_->dense(perm[j], x_tmp); + float p = + predict_probability(*sub_model, x_tmp, proba_estimates.begin()); + if (p == problem_->y_[perm[j]]) + success += 1.0; } - + } else { - for(int j=begin;jdense(perm[j], x_tmp); - float p = predict(*sub_model, x_tmp); - if (p == problem_->y_[perm[j]]) - success += 1.0; + for (int j = begin; j < end; j++) { + problem_->dense(perm[j], x_tmp); + float p = predict(*sub_model, x_tmp); + if (p == problem_->y_[perm[j]]) + success += 1.0; } } - - delete [] x_tmp; + + delete[] x_tmp; delete sub_model; - } + } return success / float(problem_->size()); } 
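+//--------------------------------------------------------------------------------
+// Illustrative usage note (added for documentation; 'clf' and its construction
+// are hypothetical, only cross_validation itself is defined in this file):
+//
+//   float acc = clf.cross_validation(5); // stratified 5-fold CV; returns the
+//                                        // fraction of held-out samples that
+//                                        // were classified correctly
+//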
//-------------------------------------------------------------------------------- -template -int svm::persistent_size() const -{ +template int svm::persistent_size() const { int n = 6 + param_.persistent_size(); if (problem_) @@ -1518,8 +1419,7 @@ int svm::persistent_size() const //-------------------------------------------------------------------------------- template -void svm::save(std::ostream& outStream) const -{ +void svm::save(std::ostream &outStream) const { param_.save(outStream); if (problem_) { @@ -1532,21 +1432,19 @@ void svm::save(std::ostream& outStream) const if (model_) { outStream << " 1 "; model_->save(outStream); - } else { + } else { outStream << " 0 "; } } //-------------------------------------------------------------------------------- -template -void svm::load(std::istream& inStream) -{ +template void svm::load(std::istream &inStream) { param_.load(inStream); int problemSaved = 0, modelSaved = 0; inStream >> problemSaved; - + if (problemSaved == 1) { delete problem_; problem_ = new problem_type(inStream); @@ -1562,10 +1460,10 @@ void svm::load(std::istream& inStream) // Can't assert that, because problem might not get loaded, to save // space! - //NTA_ASSERT(model_->n_dims() == problem_->n_dims()); - + // NTA_ASSERT(model_->n_dims() == problem_->n_dims()); + // Recompute with_sse flag based on possibly new dims of the model - // just loaded (not the problem, since the problem might not be + // just loaded (not the problem, since the problem might not be // loaded to save space, but the model always will). with_sse = checkSSE(); } diff --git a/src/nupic/algorithms/TemporalMemory.cpp b/src/nupic/algorithms/TemporalMemory.cpp index 72eca4e1a3..868687dbc5 100644 --- a/src/nupic/algorithms/TemporalMemory.cpp +++ b/src/nupic/algorithms/TemporalMemory.cpp @@ -33,12 +33,12 @@ * 4. 
Model parameters (including "learn") */ -#include #include +#include #include #include -#include #include +#include #include #include @@ -57,19 +57,13 @@ using namespace nupic::algorithms::temporal_memory; static const Permanence EPSILON = 0.000001; static const UInt TM_VERSION = 2; - - -template -bool isSortedWithoutDuplicates(Iterator begin, Iterator end) -{ - if (std::distance(begin, end) >= 2) - { +template +bool isSortedWithoutDuplicates(Iterator begin, Iterator end) { + if (std::distance(begin, end) >= 2) { Iterator now = begin; Iterator next = begin + 1; - while (next != end) - { - if (*now >= *next) - { + while (next != end) { + if (*now >= *next) { return false; } @@ -80,73 +74,38 @@ bool isSortedWithoutDuplicates(Iterator begin, Iterator end) return true; } - -TemporalMemory::TemporalMemory() -{ -} +TemporalMemory::TemporalMemory() {} TemporalMemory::TemporalMemory( - vector columnDimensions, - UInt cellsPerColumn, - UInt activationThreshold, - Permanence initialPermanence, - Permanence connectedPermanence, - UInt minThreshold, - UInt maxNewSynapseCount, - Permanence permanenceIncrement, - Permanence permanenceDecrement, - Permanence predictedSegmentDecrement, - Int seed, - UInt maxSegmentsPerCell, - UInt maxSynapsesPerSegment, - bool checkInputs) -{ - initialize( - columnDimensions, - cellsPerColumn, - activationThreshold, - initialPermanence, - connectedPermanence, - minThreshold, - maxNewSynapseCount, - permanenceIncrement, - permanenceDecrement, - predictedSegmentDecrement, - seed, - maxSegmentsPerCell, - maxSynapsesPerSegment, - checkInputs); + vector columnDimensions, UInt cellsPerColumn, + UInt activationThreshold, Permanence initialPermanence, + Permanence connectedPermanence, UInt minThreshold, UInt maxNewSynapseCount, + Permanence permanenceIncrement, Permanence permanenceDecrement, + Permanence predictedSegmentDecrement, Int seed, UInt maxSegmentsPerCell, + UInt maxSynapsesPerSegment, bool checkInputs) { + initialize(columnDimensions, cellsPerColumn, activationThreshold, + initialPermanence, connectedPermanence, minThreshold, + maxNewSynapseCount, permanenceIncrement, permanenceDecrement, + predictedSegmentDecrement, seed, maxSegmentsPerCell, + maxSynapsesPerSegment, checkInputs); } -TemporalMemory::~TemporalMemory() -{ -} +TemporalMemory::~TemporalMemory() {} void TemporalMemory::initialize( - vector columnDimensions, - UInt cellsPerColumn, - UInt activationThreshold, - Permanence initialPermanence, - Permanence connectedPermanence, - UInt minThreshold, - UInt maxNewSynapseCount, - Permanence permanenceIncrement, - Permanence permanenceDecrement, - Permanence predictedSegmentDecrement, - Int seed, - UInt maxSegmentsPerCell, - UInt maxSynapsesPerSegment, - bool checkInputs) -{ + vector columnDimensions, UInt cellsPerColumn, + UInt activationThreshold, Permanence initialPermanence, + Permanence connectedPermanence, UInt minThreshold, UInt maxNewSynapseCount, + Permanence permanenceIncrement, Permanence permanenceDecrement, + Permanence predictedSegmentDecrement, Int seed, UInt maxSegmentsPerCell, + UInt maxSynapsesPerSegment, bool checkInputs) { // Validate all input parameters - if (columnDimensions.size() <= 0) - { + if (columnDimensions.size() <= 0) { NTA_THROW << "Number of column dimensions must be greater than 0"; } - if (cellsPerColumn <= 0) - { + if (cellsPerColumn <= 0) { NTA_THROW << "Number of cells per column must be greater than 0"; } @@ -160,8 +119,7 @@ void TemporalMemory::initialize( numColumns_ = 1; columnDimensions_.clear(); - for (auto & 
columnDimension : columnDimensions) - { + for (auto &columnDimension : columnDimensions) { numColumns_ *= columnDimension; columnDimensions_.push_back(columnDimension); } @@ -191,27 +149,20 @@ void TemporalMemory::initialize( matchingSegments_.clear(); } -static CellIdx getLeastUsedCell( - Random& rng, - UInt column, - const Connections& connections, - UInt cellsPerColumn) -{ +static CellIdx getLeastUsedCell(Random &rng, UInt column, + const Connections &connections, + UInt cellsPerColumn) { const CellIdx start = column * cellsPerColumn; const CellIdx end = start + cellsPerColumn; UInt32 minNumSegments = UINT_MAX; UInt32 numTiedCells = 0; - for (CellIdx cell = start; cell < end; cell++) - { + for (CellIdx cell = start; cell < end; cell++) { const UInt32 numSegments = connections.numSegments(cell); - if (numSegments < minNumSegments) - { + if (numSegments < minNumSegments) { minNumSegments = numSegments; numTiedCells = 1; - } - else if (numSegments == minNumSegments) - { + } else if (numSegments == minNumSegments) { numTiedCells++; } } @@ -219,16 +170,11 @@ static CellIdx getLeastUsedCell( const UInt32 tieWinnerIndex = rng.getUInt32(numTiedCells); UInt32 tieIndex = 0; - for (CellIdx cell = start; cell < end; cell++) - { - if (connections.numSegments(cell) == minNumSegments) - { - if (tieIndex == tieWinnerIndex) - { + for (CellIdx cell = start; cell < end; cell++) { + if (connections.numSegments(cell) == minNumSegments) { + if (tieIndex == tieWinnerIndex) { return cell; - } - else - { + } else { tieIndex++; } } @@ -237,91 +183,70 @@ static CellIdx getLeastUsedCell( NTA_THROW << "getLeastUsedCell failed to find a cell"; } -static void adaptSegment( - Connections& connections, - Segment segment, - const vector& prevActiveCellsDense, - Permanence permanenceIncrement, - Permanence permanenceDecrement) -{ - const vector& synapses = connections.synapsesForSegment(segment); +static void adaptSegment(Connections &connections, Segment segment, + const vector &prevActiveCellsDense, + Permanence permanenceIncrement, + Permanence permanenceDecrement) { + const vector &synapses = connections.synapsesForSegment(segment); - for (SynapseIdx i = 0; i < synapses.size();) - { - const SynapseData& synapseData = connections.dataForSynapse(synapses[i]); + for (SynapseIdx i = 0; i < synapses.size();) { + const SynapseData &synapseData = connections.dataForSynapse(synapses[i]); NTA_ASSERT(synapseData.presynapticCell < connections.numCells()); Permanence permanence = synapseData.permanence; - if (prevActiveCellsDense[synapseData.presynapticCell]) - { + if (prevActiveCellsDense[synapseData.presynapticCell]) { permanence += permanenceIncrement; - } - else - { + } else { permanence -= permanenceDecrement; } permanence = min(permanence, (Permanence)1.0); permanence = max(permanence, (Permanence)0.0); - if (permanence < EPSILON) - { + if (permanence < EPSILON) { connections.destroySynapse(synapses[i]); // Synapses vector is modified in-place, so don't update `i`. - } - else - { + } else { connections.updateSynapsePermanence(synapses[i], permanence); i++; } } - if (synapses.size() == 0) - { + if (synapses.size() == 0) { connections.destroySegment(segment); } } -static void destroyMinPermanenceSynapses( - Connections& connections, - Random& rng, - Segment segment, - Int nDestroy, - const vector& excludeCells) -{ +static void destroyMinPermanenceSynapses(Connections &connections, Random &rng, + Segment segment, Int nDestroy, + const vector &excludeCells) { // Don't destroy any cells that are in excludeCells. 
vector destroyCandidates; - for (Synapse synapse : connections.synapsesForSegment(segment)) - { + for (Synapse synapse : connections.synapsesForSegment(segment)) { const CellIdx presynapticCell = - connections.dataForSynapse(synapse).presynapticCell; + connections.dataForSynapse(synapse).presynapticCell; if (!std::binary_search(excludeCells.begin(), excludeCells.end(), - presynapticCell)) - { + presynapticCell)) { destroyCandidates.push_back(synapse); } } // Find cells one at a time. This is slow, but this code rarely runs, and it // needs to work around floating point differences between environments. - for (Int32 i = 0; i < nDestroy && !destroyCandidates.empty(); i++) - { + for (Int32 i = 0; i < nDestroy && !destroyCandidates.empty(); i++) { Permanence minPermanence = std::numeric_limits::max(); vector::iterator minSynapse = destroyCandidates.end(); for (auto synapse = destroyCandidates.begin(); - synapse != destroyCandidates.end(); - synapse++) - { + synapse != destroyCandidates.end(); synapse++) { const Permanence permanence = - connections.dataForSynapse(*synapse).permanence; + connections.dataForSynapse(*synapse).permanence; // Use special EPSILON logic to compensate for floating point // differences between C++ and other environments. - if (permanence < minPermanence - EPSILON) - { + if (permanence < minPermanence - EPSILON) { minSynapse = synapse; minPermanence = permanence; } @@ -332,15 +257,11 @@ static void destroyMinPermanenceSynapses( } } -static void growSynapses( - Connections& connections, - Random& rng, - Segment segment, - UInt32 nDesiredNewSynapses, - const vector& prevWinnerCells, - Permanence initialPermanence, - UInt maxSynapsesPerSegment) -{ +static void growSynapses(Connections &connections, Random &rng, Segment segment, + UInt32 nDesiredNewSynapses, + const vector &prevWinnerCells, + Permanence initialPermanence, + UInt maxSynapsesPerSegment) { // It's possible to optimize this, swapping candidates to the end as // they're used. But this is awkward to mimic in other // implementations, especially because it requires iterating over @@ -350,38 +271,33 @@ static void growSynapses( NTA_ASSERT(std::is_sorted(candidates.begin(), candidates.end())); // Remove cells that are already synapsed on by this segment - for (Synapse synapse : connections.synapsesForSegment(segment)) - { + for (Synapse synapse : connections.synapsesForSegment(segment)) { CellIdx presynapticCell = - connections.dataForSynapse(synapse).presynapticCell; - auto ineligible = std::lower_bound(candidates.begin(), candidates.end(), - presynapticCell); - if (ineligible != candidates.end() && *ineligible == presynapticCell) - { + connections.dataForSynapse(synapse).presynapticCell; + auto ineligible = + std::lower_bound(candidates.begin(), candidates.end(), presynapticCell); + if (ineligible != candidates.end() && *ineligible == presynapticCell) { candidates.erase(ineligible); } } - const UInt32 nActual = std::min(nDesiredNewSynapses, - (UInt32)candidates.size()); + const UInt32 nActual = + std::min(nDesiredNewSynapses, (UInt32)candidates.size()); // Check if we're going to surpass the maximum number of synapses. - const Int32 overrun = (connections.numSynapses(segment) + - nActual - maxSynapsesPerSegment); - if (overrun > 0) - { + const Int32 overrun = + (connections.numSynapses(segment) + nActual - maxSynapsesPerSegment); + if (overrun > 0) { destroyMinPermanenceSynapses(connections, rng, segment, overrun, prevWinnerCells); } // Recalculate in case we weren't able to destroy as many synapses as needed. 
- const UInt32 nActualWithMax = std::min(nActual, - maxSynapsesPerSegment - - connections.numSynapses(segment)); + const UInt32 nActualWithMax = std::min( + nActual, maxSynapsesPerSegment - connections.numSynapses(segment)); // Pick nActual cells randomly. - for (UInt32 c = 0; c < nActualWithMax; c++) - { + for (UInt32 c = 0; c < nActualWithMax; c++) { size_t i = rng.getUInt32(candidates.size()); connections.createSynapse(segment, candidates[i], initialPermanence); candidates.erase(candidates.begin() + i); @@ -389,47 +305,35 @@ static void growSynapses( } static void activatePredictedColumn( - vector& activeCells, - vector& winnerCells, - Connections& connections, - Random& rng, - vector::const_iterator columnActiveSegmentsBegin, - vector::const_iterator columnActiveSegmentsEnd, - const vector& prevActiveCellsDense, - const vector& prevWinnerCells, - const vector& numActivePotentialSynapsesForSegment, - UInt maxNewSynapseCount, - Permanence initialPermanence, - Permanence permanenceIncrement, - Permanence permanenceDecrement, - UInt maxSynapsesPerSegment, - bool learn) -{ + vector &activeCells, vector &winnerCells, + Connections &connections, Random &rng, + vector::const_iterator columnActiveSegmentsBegin, + vector::const_iterator columnActiveSegmentsEnd, + const vector &prevActiveCellsDense, + const vector &prevWinnerCells, + const vector &numActivePotentialSynapsesForSegment, + UInt maxNewSynapseCount, Permanence initialPermanence, + Permanence permanenceIncrement, Permanence permanenceDecrement, + UInt maxSynapsesPerSegment, bool learn) { auto activeSegment = columnActiveSegmentsBegin; - do - { + do { const CellIdx cell = connections.cellForSegment(*activeSegment); activeCells.push_back(cell); winnerCells.push_back(cell); // This cell might have multiple active segments. 
- do - { - if (learn) - { - adaptSegment(connections, - *activeSegment, - prevActiveCellsDense, + do { + if (learn) { + adaptSegment(connections, *activeSegment, prevActiveCellsDense, permanenceIncrement, permanenceDecrement); - const Int32 nGrowDesired = maxNewSynapseCount - - numActivePotentialSynapsesForSegment[*activeSegment]; - if (nGrowDesired > 0) - { - growSynapses(connections, rng, - *activeSegment, nGrowDesired, - prevWinnerCells, - initialPermanence, maxSynapsesPerSegment); + const Int32 nGrowDesired = + maxNewSynapseCount - + numActivePotentialSynapsesForSegment[*activeSegment]; + if (nGrowDesired > 0) { + growSynapses(connections, rng, *activeSegment, nGrowDesired, + prevWinnerCells, initialPermanence, + maxSynapsesPerSegment); } } } while (++activeSegment != columnActiveSegmentsEnd && @@ -437,25 +341,20 @@ static void activatePredictedColumn( } while (activeSegment != columnActiveSegmentsEnd); } -static Segment createSegment( - Connections& connections, - vector& lastUsedIterationForSegment, - CellIdx cell, - UInt64 iteration, - UInt maxSegmentsPerCell) -{ - while (connections.numSegments(cell) >= maxSegmentsPerCell) - { - const vector& destroyCandidates = - connections.segmentsForCell(cell); - - auto leastRecentlyUsedSegment = std::min_element( - destroyCandidates.begin(), destroyCandidates.end(), - [&](Segment a, Segment b) - { - return (lastUsedIterationForSegment[a] < - lastUsedIterationForSegment[b]); - }); +static Segment createSegment(Connections &connections, + vector &lastUsedIterationForSegment, + CellIdx cell, UInt64 iteration, + UInt maxSegmentsPerCell) { + while (connections.numSegments(cell) >= maxSegmentsPerCell) { + const vector &destroyCandidates = + connections.segmentsForCell(cell); + + auto leastRecentlyUsedSegment = + std::min_element(destroyCandidates.begin(), destroyCandidates.end(), + [&](Segment a, Segment b) { + return (lastUsedIterationForSegment[a] < + lastUsedIterationForSegment[b]); + }); connections.destroySegment(*leastRecentlyUsedSegment); } @@ -467,88 +366,67 @@ static Segment createSegment( return segment; } -static void burstColumn( - vector& activeCells, - vector& winnerCells, - Connections& connections, - Random& rng, - vector& lastUsedIterationForSegment, - UInt column, - vector::const_iterator columnMatchingSegmentsBegin, - vector::const_iterator columnMatchingSegmentsEnd, - const vector& prevActiveCellsDense, - const vector& prevWinnerCells, - const vector& numActivePotentialSynapsesForSegment, - UInt64 iteration, - UInt cellsPerColumn, - UInt maxNewSynapseCount, - Permanence initialPermanence, - Permanence permanenceIncrement, - Permanence permanenceDecrement, - UInt maxSegmentsPerCell, - UInt maxSynapsesPerSegment, - bool learn) -{ +static void +burstColumn(vector &activeCells, vector &winnerCells, + Connections &connections, Random &rng, + vector &lastUsedIterationForSegment, UInt column, + vector::const_iterator columnMatchingSegmentsBegin, + vector::const_iterator columnMatchingSegmentsEnd, + const vector &prevActiveCellsDense, + const vector &prevWinnerCells, + const vector &numActivePotentialSynapsesForSegment, + UInt64 iteration, UInt cellsPerColumn, UInt maxNewSynapseCount, + Permanence initialPermanence, Permanence permanenceIncrement, + Permanence permanenceDecrement, UInt maxSegmentsPerCell, + UInt maxSynapsesPerSegment, bool learn) { // Calculate the active cells. 
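The createSegment() helper just above evicts the least recently used segment whenever a cell is already at maxSegmentsPerCell. A minimal sketch of that selection with plain vectors; it assumes a non-empty segment list, and the names are illustrative rather than nupic API.

#include <algorithm>
#include <cstdint>
#include <vector>

// Return the segment id with the smallest last-used iteration.
// `segments` must be non-empty; `lastUsed` is indexed by segment id.
std::uint32_t leastRecentlyUsed(const std::vector<std::uint32_t> &segments,
                                const std::vector<std::uint64_t> &lastUsed) {
  return *std::min_element(segments.begin(), segments.end(),
                           [&](std::uint32_t a, std::uint32_t b) {
                             return lastUsed[a] < lastUsed[b];
                           });
}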
const CellIdx start = column * cellsPerColumn; const CellIdx end = start + cellsPerColumn; - for (CellIdx cell = start; cell < end; cell++) - { + for (CellIdx cell = start; cell < end; cell++) { activeCells.push_back(cell); } - const auto bestMatchingSegment = std::max_element( - columnMatchingSegmentsBegin, columnMatchingSegmentsEnd, - [&](Segment a, Segment b) - { - return (numActivePotentialSynapsesForSegment[a] < - numActivePotentialSynapsesForSegment[b]); - }); + const auto bestMatchingSegment = + std::max_element(columnMatchingSegmentsBegin, columnMatchingSegmentsEnd, + [&](Segment a, Segment b) { + return (numActivePotentialSynapsesForSegment[a] < + numActivePotentialSynapsesForSegment[b]); + }); - const CellIdx winnerCell = (bestMatchingSegment != columnMatchingSegmentsEnd) - ? connections.cellForSegment(*bestMatchingSegment) - : getLeastUsedCell(rng, column, connections, cellsPerColumn); + const CellIdx winnerCell = + (bestMatchingSegment != columnMatchingSegmentsEnd) + ? connections.cellForSegment(*bestMatchingSegment) + : getLeastUsedCell(rng, column, connections, cellsPerColumn); winnerCells.push_back(winnerCell); // Learn. - if (learn) - { - if (bestMatchingSegment != columnMatchingSegmentsEnd) - { + if (learn) { + if (bestMatchingSegment != columnMatchingSegmentsEnd) { // Learn on the best matching segment. - adaptSegment(connections, - *bestMatchingSegment, - prevActiveCellsDense, + adaptSegment(connections, *bestMatchingSegment, prevActiveCellsDense, permanenceIncrement, permanenceDecrement); - const Int32 nGrowDesired = maxNewSynapseCount - - numActivePotentialSynapsesForSegment[*bestMatchingSegment]; - if (nGrowDesired > 0) - { - growSynapses(connections, rng, - *bestMatchingSegment, nGrowDesired, - prevWinnerCells, - initialPermanence, maxSynapsesPerSegment); + const Int32 nGrowDesired = + maxNewSynapseCount - + numActivePotentialSynapsesForSegment[*bestMatchingSegment]; + if (nGrowDesired > 0) { + growSynapses(connections, rng, *bestMatchingSegment, nGrowDesired, + prevWinnerCells, initialPermanence, maxSynapsesPerSegment); } - } - else - { + } else { // No matching segments. // Grow a new segment and learn on it. // Don't grow a segment that will never match. 
- const UInt32 nGrowExact = std::min(maxNewSynapseCount, - (UInt32)prevWinnerCells.size()); - if (nGrowExact > 0) - { + const UInt32 nGrowExact = + std::min(maxNewSynapseCount, (UInt32)prevWinnerCells.size()); + if (nGrowExact > 0) { const Segment segment = - createSegment(connections, lastUsedIterationForSegment, - winnerCell, iteration, maxSegmentsPerCell); + createSegment(connections, lastUsedIterationForSegment, winnerCell, + iteration, maxSegmentsPerCell); - growSynapses(connections, rng, - segment, nGrowExact, - prevWinnerCells, + growSynapses(connections, rng, segment, nGrowExact, prevWinnerCells, initialPermanence, maxSynapsesPerSegment); NTA_ASSERT(connections.numSynapses(segment) == nGrowExact); } @@ -557,153 +435,122 @@ static void burstColumn( } static void punishPredictedColumn( - Connections& connections, - vector::const_iterator columnMatchingSegmentsBegin, - vector::const_iterator columnMatchingSegmentsEnd, - const vector& prevActiveCellsDense, - Permanence predictedSegmentDecrement) -{ - if (predictedSegmentDecrement > 0.0) - { + Connections &connections, + vector::const_iterator columnMatchingSegmentsBegin, + vector::const_iterator columnMatchingSegmentsEnd, + const vector &prevActiveCellsDense, + Permanence predictedSegmentDecrement) { + if (predictedSegmentDecrement > 0.0) { for (auto matchingSegment = columnMatchingSegmentsBegin; - matchingSegment != columnMatchingSegmentsEnd; matchingSegment++) - { + matchingSegment != columnMatchingSegmentsEnd; matchingSegment++) { adaptSegment(connections, *matchingSegment, prevActiveCellsDense, -predictedSegmentDecrement, 0.0); } } } -void TemporalMemory::activateCells( - size_t activeColumnsSize, - const UInt activeColumns[], - bool learn) -{ - if (checkInputs_) - { +void TemporalMemory::activateCells(size_t activeColumnsSize, + const UInt activeColumns[], bool learn) { + if (checkInputs_) { NTA_CHECK(isSortedWithoutDuplicates(activeColumns, activeColumns + activeColumnsSize)) - << "The activeColumns must be a sorted list of indices without duplicates."; + << "The activeColumns must be a sorted list of indices without " + "duplicates."; } vector prevActiveCellsDense(numberOfCells(), false); - for (CellIdx cell : activeCells_) - { + for (CellIdx cell : activeCells_) { prevActiveCellsDense[cell] = true; } activeCells_.clear(); const vector prevWinnerCells = std::move(winnerCells_); - const auto columnForSegment = [&](Segment segment) - { return connections.cellForSegment(segment) / cellsPerColumn_; }; + const auto columnForSegment = [&](Segment segment) { + return connections.cellForSegment(segment) / cellsPerColumn_; + }; - for (auto& columnData : iterGroupBy( - activeColumns, activeColumns + activeColumnsSize, identity, - activeSegments_.begin(), activeSegments_.end(), columnForSegment, - matchingSegments_.begin(), matchingSegments_.end(), columnForSegment)) - { + for (auto &columnData : iterGroupBy( + activeColumns, activeColumns + activeColumnsSize, identity, + activeSegments_.begin(), activeSegments_.end(), columnForSegment, + matchingSegments_.begin(), matchingSegments_.end(), + columnForSegment)) { UInt column; - const UInt* activeColumnsBegin; - const UInt* activeColumnsEnd; - vector::const_iterator - columnActiveSegmentsBegin, columnActiveSegmentsEnd, - columnMatchingSegmentsBegin, columnMatchingSegmentsEnd; - tie(column, - activeColumnsBegin, activeColumnsEnd, - columnActiveSegmentsBegin, columnActiveSegmentsEnd, - columnMatchingSegmentsBegin, columnMatchingSegmentsEnd) = columnData; + const UInt *activeColumnsBegin; + 
const UInt *activeColumnsEnd; + vector::const_iterator columnActiveSegmentsBegin, + columnActiveSegmentsEnd, columnMatchingSegmentsBegin, + columnMatchingSegmentsEnd; + tie(column, activeColumnsBegin, activeColumnsEnd, columnActiveSegmentsBegin, + columnActiveSegmentsEnd, columnMatchingSegmentsBegin, + columnMatchingSegmentsEnd) = columnData; const bool isActiveColumn = activeColumnsBegin != activeColumnsEnd; - if (isActiveColumn) - { - if (columnActiveSegmentsBegin != columnActiveSegmentsEnd) - { + if (isActiveColumn) { + if (columnActiveSegmentsBegin != columnActiveSegmentsEnd) { activatePredictedColumn( - activeCells_, winnerCells_, connections, rng_, - columnActiveSegmentsBegin, columnActiveSegmentsEnd, - prevActiveCellsDense, prevWinnerCells, - numActivePotentialSynapsesForSegment_, - maxNewSynapseCount_, - initialPermanence_, permanenceIncrement_, permanenceDecrement_, - maxSynapsesPerSegment_, learn); + activeCells_, winnerCells_, connections, rng_, + columnActiveSegmentsBegin, columnActiveSegmentsEnd, + prevActiveCellsDense, prevWinnerCells, + numActivePotentialSynapsesForSegment_, maxNewSynapseCount_, + initialPermanence_, permanenceIncrement_, permanenceDecrement_, + maxSynapsesPerSegment_, learn); + } else { + burstColumn(activeCells_, winnerCells_, connections, rng_, + lastUsedIterationForSegment_, column, + columnMatchingSegmentsBegin, columnMatchingSegmentsEnd, + prevActiveCellsDense, prevWinnerCells, + numActivePotentialSynapsesForSegment_, iteration_, + cellsPerColumn_, maxNewSynapseCount_, initialPermanence_, + permanenceIncrement_, permanenceDecrement_, + maxSegmentsPerCell_, maxSynapsesPerSegment_, learn); } - else - { - burstColumn( - activeCells_, winnerCells_, connections, rng_, - lastUsedIterationForSegment_, - column, columnMatchingSegmentsBegin, columnMatchingSegmentsEnd, - prevActiveCellsDense, prevWinnerCells, - numActivePotentialSynapsesForSegment_, iteration_, - cellsPerColumn_, maxNewSynapseCount_, - initialPermanence_, permanenceIncrement_, permanenceDecrement_, - maxSegmentsPerCell_, maxSynapsesPerSegment_, learn); - } - } - else - { - if (learn) - { - punishPredictedColumn( - connections, - columnMatchingSegmentsBegin, columnMatchingSegmentsEnd, - prevActiveCellsDense, - predictedSegmentDecrement_); + } else { + if (learn) { + punishPredictedColumn(connections, columnMatchingSegmentsBegin, + columnMatchingSegmentsEnd, prevActiveCellsDense, + predictedSegmentDecrement_); } } } } -void TemporalMemory::activateDendrites(bool learn) -{ +void TemporalMemory::activateDendrites(bool learn) { const UInt32 length = connections.segmentFlatListLength(); numActiveConnectedSynapsesForSegment_.assign(length, 0); numActivePotentialSynapsesForSegment_.assign(length, 0); connections.computeActivity(numActiveConnectedSynapsesForSegment_, numActivePotentialSynapsesForSegment_, - activeCells_, - connectedPermanence_); + activeCells_, connectedPermanence_); // Active segments, connected synapses. 
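The two loops that follow reduce to one operation: keep every segment whose synapse count from connections.computeActivity() reaches a threshold (activationThreshold_ for active segments, minThreshold_ for matching ones). A standalone sketch of that filter, with illustrative types:

#include <cstdint>
#include <vector>

// Collect the ids of segments whose overlap count reaches `threshold`.
std::vector<std::uint32_t>
segmentsAboveThreshold(const std::vector<std::uint32_t> &overlapCounts,
                       std::uint32_t threshold) {
  std::vector<std::uint32_t> result;
  for (std::uint32_t segment = 0; segment < overlapCounts.size(); segment++) {
    if (overlapCounts[segment] >= threshold) {
      result.push_back(segment);
    }
  }
  return result;
}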
activeSegments_.clear(); for (Segment segment = 0; - segment < numActiveConnectedSynapsesForSegment_.size(); - segment++) - { - if (numActiveConnectedSynapsesForSegment_[segment] >= activationThreshold_) - { + segment < numActiveConnectedSynapsesForSegment_.size(); segment++) { + if (numActiveConnectedSynapsesForSegment_[segment] >= + activationThreshold_) { activeSegments_.push_back(segment); } } - std::sort(activeSegments_.begin(), activeSegments_.end(), - [&](Segment a, Segment b) - { - return connections.compareSegments(a, b); - }); + std::sort( + activeSegments_.begin(), activeSegments_.end(), + [&](Segment a, Segment b) { return connections.compareSegments(a, b); }); // Matching segments, potential synapses. matchingSegments_.clear(); for (Segment segment = 0; - segment < numActivePotentialSynapsesForSegment_.size(); - segment++) - { - if (numActivePotentialSynapsesForSegment_[segment] >= minThreshold_) - { + segment < numActivePotentialSynapsesForSegment_.size(); segment++) { + if (numActivePotentialSynapsesForSegment_[segment] >= minThreshold_) { matchingSegments_.push_back(segment); } } - std::sort(matchingSegments_.begin(), matchingSegments_.end(), - [&](Segment a, Segment b) - { - return connections.compareSegments(a, b); - }); - - if (learn) - { - for (Segment segment : activeSegments_) - { + std::sort( + matchingSegments_.begin(), matchingSegments_.end(), + [&](Segment a, Segment b) { return connections.compareSegments(a, b); }); + + if (learn) { + for (Segment segment : activeSegments_) { lastUsedIterationForSegment_[segment] = iteration_; } @@ -711,17 +558,13 @@ void TemporalMemory::activateDendrites(bool learn) } } -void TemporalMemory::compute( - size_t activeColumnsSize, - const UInt activeColumns[], - bool learn) -{ +void TemporalMemory::compute(size_t activeColumnsSize, + const UInt activeColumns[], bool learn) { activateCells(activeColumnsSize, activeColumns, learn); activateDendrites(learn); } -void TemporalMemory::reset(void) -{ +void TemporalMemory::reset(void) { activeCells_.clear(); winnerCells_.clear(); activeSegments_.clear(); @@ -732,54 +575,40 @@ void TemporalMemory::reset(void) // Helper functions // ============================== -Segment TemporalMemory::createSegment(CellIdx cell) -{ - return ::createSegment(connections, lastUsedIterationForSegment_, - cell, iteration_, maxSegmentsPerCell_); +Segment TemporalMemory::createSegment(CellIdx cell) { + return ::createSegment(connections, lastUsedIterationForSegment_, cell, + iteration_, maxSegmentsPerCell_); } -Int TemporalMemory::columnForCell(CellIdx cell) -{ +Int TemporalMemory::columnForCell(CellIdx cell) { _validateCell(cell); return cell / cellsPerColumn_; } -vector TemporalMemory::cellsForColumn(Int column) -{ +vector TemporalMemory::cellsForColumn(Int column) { const CellIdx start = cellsPerColumn_ * column; const CellIdx end = start + cellsPerColumn_; vector cellsInColumn; - for (CellIdx i = start; i < end; i++) - { + for (CellIdx i = start; i < end; i++) { cellsInColumn.push_back(i); } return cellsInColumn; } -UInt TemporalMemory::numberOfCells(void) -{ - return connections.numCells(); -} +UInt TemporalMemory::numberOfCells(void) { return connections.numCells(); } -vector TemporalMemory::getActiveCells() const -{ - return activeCells_; -} +vector TemporalMemory::getActiveCells() const { return activeCells_; } -vector TemporalMemory::getPredictiveCells() const -{ +vector TemporalMemory::getPredictiveCells() const { vector predictiveCells; - for (auto segment = activeSegments_.begin(); - segment != 
activeSegments_.end(); segment++) - { + for (auto segment = activeSegments_.begin(); segment != activeSegments_.end(); + segment++) { CellIdx cell = connections.cellForSegment(*segment); - if (segment == activeSegments_.begin() || - cell != predictiveCells.back()) - { + if (segment == activeSegments_.begin() || cell != predictiveCells.back()) { predictiveCells.push_back(cell); } } @@ -787,28 +616,19 @@ vector TemporalMemory::getPredictiveCells() const return predictiveCells; } -vector TemporalMemory::getWinnerCells() const -{ - return winnerCells_; -} +vector TemporalMemory::getWinnerCells() const { return winnerCells_; } -vector TemporalMemory::getActiveSegments() const -{ +vector TemporalMemory::getActiveSegments() const { return activeSegments_; } -vector TemporalMemory::getMatchingSegments() const -{ +vector TemporalMemory::getMatchingSegments() const { return matchingSegments_; } -UInt TemporalMemory::numberOfColumns() const -{ - return numColumns_; -} +UInt TemporalMemory::numberOfColumns() const { return numColumns_; } -bool TemporalMemory::_validateCell(CellIdx cell) -{ +bool TemporalMemory::_validateCell(CellIdx cell) { if (cell < numberOfCells()) return true; @@ -816,131 +636,97 @@ bool TemporalMemory::_validateCell(CellIdx cell) return false; } -vector TemporalMemory::getColumnDimensions() const -{ +vector TemporalMemory::getColumnDimensions() const { return columnDimensions_; } -UInt TemporalMemory::getCellsPerColumn() const -{ - return cellsPerColumn_; -} +UInt TemporalMemory::getCellsPerColumn() const { return cellsPerColumn_; } -UInt TemporalMemory::getActivationThreshold() const -{ +UInt TemporalMemory::getActivationThreshold() const { return activationThreshold_; } -void TemporalMemory::setActivationThreshold(UInt activationThreshold) -{ +void TemporalMemory::setActivationThreshold(UInt activationThreshold) { activationThreshold_ = activationThreshold; } -Permanence TemporalMemory::getInitialPermanence() const -{ +Permanence TemporalMemory::getInitialPermanence() const { return initialPermanence_; } -void TemporalMemory::setInitialPermanence(Permanence initialPermanence) -{ +void TemporalMemory::setInitialPermanence(Permanence initialPermanence) { initialPermanence_ = initialPermanence; } -Permanence TemporalMemory::getConnectedPermanence() const -{ +Permanence TemporalMemory::getConnectedPermanence() const { return connectedPermanence_; } -void TemporalMemory::setConnectedPermanence(Permanence connectedPermanence) -{ +void TemporalMemory::setConnectedPermanence(Permanence connectedPermanence) { connectedPermanence_ = connectedPermanence; } -UInt TemporalMemory::getMinThreshold() const -{ - return minThreshold_; -} +UInt TemporalMemory::getMinThreshold() const { return minThreshold_; } -void TemporalMemory::setMinThreshold(UInt minThreshold) -{ +void TemporalMemory::setMinThreshold(UInt minThreshold) { minThreshold_ = minThreshold; } -UInt TemporalMemory::getMaxNewSynapseCount() const -{ +UInt TemporalMemory::getMaxNewSynapseCount() const { return maxNewSynapseCount_; } -void TemporalMemory::setMaxNewSynapseCount(UInt maxNewSynapseCount) -{ +void TemporalMemory::setMaxNewSynapseCount(UInt maxNewSynapseCount) { maxNewSynapseCount_ = maxNewSynapseCount; } -bool TemporalMemory::getCheckInputs() const -{ - return checkInputs_; -} +bool TemporalMemory::getCheckInputs() const { return checkInputs_; } -void TemporalMemory::setCheckInputs(bool checkInputs) -{ +void TemporalMemory::setCheckInputs(bool checkInputs) { checkInputs_ = checkInputs; } -Permanence 
TemporalMemory::getPermanenceIncrement() const -{ +Permanence TemporalMemory::getPermanenceIncrement() const { return permanenceIncrement_; } -void TemporalMemory::setPermanenceIncrement(Permanence permanenceIncrement) -{ +void TemporalMemory::setPermanenceIncrement(Permanence permanenceIncrement) { permanenceIncrement_ = permanenceIncrement; } -Permanence TemporalMemory::getPermanenceDecrement() const -{ +Permanence TemporalMemory::getPermanenceDecrement() const { return permanenceDecrement_; } -void TemporalMemory::setPermanenceDecrement(Permanence permanenceDecrement) -{ +void TemporalMemory::setPermanenceDecrement(Permanence permanenceDecrement) { permanenceDecrement_ = permanenceDecrement; } -Permanence TemporalMemory::getPredictedSegmentDecrement() const -{ +Permanence TemporalMemory::getPredictedSegmentDecrement() const { return predictedSegmentDecrement_; } -void TemporalMemory::setPredictedSegmentDecrement(Permanence predictedSegmentDecrement) -{ +void TemporalMemory::setPredictedSegmentDecrement( + Permanence predictedSegmentDecrement) { predictedSegmentDecrement_ = predictedSegmentDecrement; } -UInt TemporalMemory::getMaxSegmentsPerCell() const -{ +UInt TemporalMemory::getMaxSegmentsPerCell() const { return maxSegmentsPerCell_; } -UInt TemporalMemory::getMaxSynapsesPerSegment() const -{ +UInt TemporalMemory::getMaxSynapsesPerSegment() const { return maxSynapsesPerSegment_; } -UInt TemporalMemory::version() const -{ - return TM_VERSION; -} +UInt TemporalMemory::version() const { return TM_VERSION; } /** -* Create a RNG with given seed -*/ -void TemporalMemory::seed_(UInt64 seed) -{ - rng_ = Random(seed); -} + * Create a RNG with given seed + */ +void TemporalMemory::seed_(UInt64 seed) { rng_ = Random(seed); } -UInt TemporalMemory::persistentSize() const -{ +UInt TemporalMemory::persistentSize() const { stringstream s; s.flags(ios::scientific); s.precision(numeric_limits::digits10 + 1); @@ -948,37 +734,31 @@ UInt TemporalMemory::persistentSize() const return s.str().size(); } -template -static void saveFloat_(ostream& outStream, FloatType v) -{ +template +static void saveFloat_(ostream &outStream, FloatType v) { outStream << std::setprecision(std::numeric_limits::max_digits10) - << v - << " "; + << v << " "; } -void TemporalMemory::save(ostream& outStream) const -{ +void TemporalMemory::save(ostream &outStream) const { // Write a starting marker and version. 
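saveFloat_ above writes permanences at max_digits10 precision, which is enough for a finite floating-point value to survive a text round trip exactly. A small self-contained check of that property, using float as a stand-in for the Permanence typedef:

#include <iomanip>
#include <limits>
#include <sstream>

// True when a finite float written as text at max_digits10 reads back exactly.
bool roundTripsExactly(float v) {
  std::stringstream s;
  s << std::setprecision(std::numeric_limits<float>::max_digits10) << v;
  float back = 0.0f;
  s >> back;
  return back == v;
}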
outStream << "TemporalMemory" << endl; outStream << TM_VERSION << endl; - outStream << numColumns_ << " " - << cellsPerColumn_ << " " + outStream << numColumns_ << " " << cellsPerColumn_ << " " << activationThreshold_ << " "; saveFloat_(outStream, initialPermanence_); saveFloat_(outStream, connectedPermanence_); - outStream << minThreshold_ << " " - << maxNewSynapseCount_ << " " + outStream << minThreshold_ << " " << maxNewSynapseCount_ << " " << checkInputs_ << " "; saveFloat_(outStream, permanenceIncrement_); saveFloat_(outStream, permanenceDecrement_); saveFloat_(outStream, predictedSegmentDecrement_); - outStream << maxSegmentsPerCell_ << " " - << maxSynapsesPerSegment_ << " " + outStream << maxSegmentsPerCell_ << " " << maxSynapsesPerSegment_ << " " << iteration_ << " "; outStream << endl; @@ -989,35 +769,30 @@ void TemporalMemory::save(ostream& outStream) const outStream << rng_ << endl; outStream << columnDimensions_.size() << " "; - for (auto & elem : columnDimensions_) - { + for (auto &elem : columnDimensions_) { outStream << elem << " "; } outStream << endl; outStream << activeCells_.size() << " "; - for (CellIdx cell : activeCells_) - { + for (CellIdx cell : activeCells_) { outStream << cell << " "; } outStream << endl; outStream << winnerCells_.size() << " "; - for (CellIdx cell : winnerCells_) - { + for (CellIdx cell : winnerCells_) { outStream << cell << " "; } outStream << endl; outStream << activeSegments_.size() << " "; - for (Segment segment : activeSegments_) - { + for (Segment segment : activeSegments_) { const CellIdx cell = connections.cellForSegment(segment); - const vector& segments = connections.segmentsForCell(cell); + const vector &segments = connections.segmentsForCell(cell); SegmentIdx idx = std::distance( - segments.begin(), - std::find(segments.begin(), segments.end(), segment)); + segments.begin(), std::find(segments.begin(), segments.end(), segment)); outStream << idx << " "; outStream << cell << " "; @@ -1026,14 +801,12 @@ void TemporalMemory::save(ostream& outStream) const outStream << endl; outStream << matchingSegments_.size() << " "; - for (Segment segment : matchingSegments_) - { + for (Segment segment : matchingSegments_) { const CellIdx cell = connections.cellForSegment(segment); - const vector& segments = connections.segmentsForCell(cell); + const vector &segments = connections.segmentsForCell(cell); SegmentIdx idx = std::distance( - segments.begin(), - std::find(segments.begin(), segments.end(), segment)); + segments.begin(), std::find(segments.begin(), segments.end(), segment)); outStream << idx << " "; outStream << cell << " "; @@ -1044,11 +817,9 @@ void TemporalMemory::save(ostream& outStream) const outStream << "~TemporalMemory" << endl; } -void TemporalMemory::write(TemporalMemoryProto::Builder& proto) const -{ +void TemporalMemory::write(TemporalMemoryProto::Builder &proto) const { auto columnDims = proto.initColumnDimensions(columnDimensions_.size()); - for (UInt i = 0; i < columnDimensions_.size(); i++) - { + for (UInt i = 0; i < columnDimensions_.size(); i++) { columnDims.set(i, columnDimensions_[i]); } @@ -1074,77 +845,66 @@ void TemporalMemory::write(TemporalMemoryProto::Builder& proto) const auto activeCells = proto.initActiveCells(activeCells_.size()); UInt i = 0; - for (CellIdx cell : activeCells_) - { + for (CellIdx cell : activeCells_) { activeCells.set(i++, cell); } auto winnerCells = proto.initWinnerCells(winnerCells_.size()); i = 0; - for (CellIdx cell : winnerCells_) - { + for (CellIdx cell : winnerCells_) { 
winnerCells.set(i++, cell); } auto activeSegments = proto.initActiveSegments(activeSegments_.size()); - for (UInt i = 0; i < activeSegments_.size(); ++i) - { - activeSegments[i].setCell( - connections.cellForSegment(activeSegments_[i])); + for (UInt i = 0; i < activeSegments_.size(); ++i) { + activeSegments[i].setCell(connections.cellForSegment(activeSegments_[i])); activeSegments[i].setIdxOnCell( - connections.idxOnCellForSegment(activeSegments_[i])); + connections.idxOnCellForSegment(activeSegments_[i])); } auto matchingSegments = proto.initMatchingSegments(matchingSegments_.size()); - for (UInt i = 0; i < matchingSegments_.size(); ++i) - { + for (UInt i = 0; i < matchingSegments_.size(); ++i) { matchingSegments[i].setCell( - connections.cellForSegment(matchingSegments_[i])); + connections.cellForSegment(matchingSegments_[i])); matchingSegments[i].setIdxOnCell( - connections.idxOnCellForSegment(matchingSegments_[i])); + connections.idxOnCellForSegment(matchingSegments_[i])); } auto numActivePotentialSynapsesForSegment = - proto.initNumActivePotentialSynapsesForSegment( - numActivePotentialSynapsesForSegment_.size()); + proto.initNumActivePotentialSynapsesForSegment( + numActivePotentialSynapsesForSegment_.size()); for (Segment segment = 0; - segment < numActivePotentialSynapsesForSegment_.size(); - segment++) - { + segment < numActivePotentialSynapsesForSegment_.size(); segment++) { numActivePotentialSynapsesForSegment[segment].setCell( - connections.cellForSegment(segment)); + connections.cellForSegment(segment)); numActivePotentialSynapsesForSegment[segment].setIdxOnCell( - connections.idxOnCellForSegment(segment)); + connections.idxOnCellForSegment(segment)); numActivePotentialSynapsesForSegment[segment].setNumber( - numActivePotentialSynapsesForSegment_[segment]); + numActivePotentialSynapsesForSegment_[segment]); } proto.setIteration(iteration_); - auto lastUsedIterationForSegment = - proto.initLastUsedIterationForSegment(lastUsedIterationForSegment_.size()); - for (Segment segment = 0; - segment < lastUsedIterationForSegment_.size(); - ++segment) - { + auto lastUsedIterationForSegment = proto.initLastUsedIterationForSegment( + lastUsedIterationForSegment_.size()); + for (Segment segment = 0; segment < lastUsedIterationForSegment_.size(); + ++segment) { lastUsedIterationForSegment[segment].setCell( - connections.cellForSegment(segment)); + connections.cellForSegment(segment)); lastUsedIterationForSegment[segment].setIdxOnCell( - connections.idxOnCellForSegment(segment)); + connections.idxOnCellForSegment(segment)); lastUsedIterationForSegment[segment].setNumber( - lastUsedIterationForSegment_[segment]); + lastUsedIterationForSegment_[segment]); } } // Implementation note: this method sets up the instance using data from // proto. This method does not call initialize. As such we have to be careful // that everything in initialize is handled properly here. 
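Alongside the Cap'n Proto write()/read() pair, the stream-based save()/load() pair in this file supports a simple round trip. A minimal sketch, assuming the header path shown in this diff and a default-constructible TemporalMemory:

#include <sstream>

#include "nupic/algorithms/TemporalMemory.hpp"

using nupic::algorithms::temporal_memory::TemporalMemory;

// Serialize `tm` to a string stream and rebuild an equivalent instance.
bool streamRoundTrip(TemporalMemory &tm) {
  std::stringstream buffer;
  tm.save(buffer);       // writes parameters, cells, segments, RNG state

  TemporalMemory restored;
  restored.load(buffer); // reads them back in the same order
  return restored == tm; // operator== compares state and connections
}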
-void TemporalMemory::read(TemporalMemoryProto::Reader& proto) -{ +void TemporalMemory::read(TemporalMemoryProto::Reader &proto) { numColumns_ = 1; columnDimensions_.clear(); - for (UInt dimension : proto.getColumnDimensions()) - { + for (UInt dimension : proto.getColumnDimensions()) { numColumns_ *= dimension; columnDimensions_.push_back(dimension); } @@ -1167,48 +927,43 @@ void TemporalMemory::read(TemporalMemoryProto::Reader& proto) connections.read(_connections); numActiveConnectedSynapsesForSegment_.assign( - connections.segmentFlatListLength(), 0); + connections.segmentFlatListLength(), 0); numActivePotentialSynapsesForSegment_.assign( - connections.segmentFlatListLength(), 0); + connections.segmentFlatListLength(), 0); auto random = proto.getRandom(); rng_.read(random); activeCells_.clear(); - for (auto cell : proto.getActiveCells()) - { + for (auto cell : proto.getActiveCells()) { activeCells_.push_back(cell); } winnerCells_.clear(); - for (auto cell : proto.getWinnerCells()) - { + for (auto cell : proto.getWinnerCells()) { winnerCells_.push_back(cell); } activeSegments_.clear(); - for (auto value : proto.getActiveSegments()) - { - const Segment segment = connections.getSegment(value.getCell(), - value.getIdxOnCell()); + for (auto value : proto.getActiveSegments()) { + const Segment segment = + connections.getSegment(value.getCell(), value.getIdxOnCell()); activeSegments_.push_back(segment); } matchingSegments_.clear(); - for (auto value : proto.getMatchingSegments()) - { - const Segment segment = connections.getSegment(value.getCell(), - value.getIdxOnCell()); + for (auto value : proto.getMatchingSegments()) { + const Segment segment = + connections.getSegment(value.getCell(), value.getIdxOnCell()); matchingSegments_.push_back(segment); } numActivePotentialSynapsesForSegment_.clear(); numActivePotentialSynapsesForSegment_.resize( - connections.segmentFlatListLength()); - for (auto segmentNumPair : proto.getNumActivePotentialSynapsesForSegment()) - { + connections.segmentFlatListLength()); + for (auto segmentNumPair : proto.getNumActivePotentialSynapsesForSegment()) { const Segment segment = connections.getSegment( - segmentNumPair.getCell(), segmentNumPair.getIdxOnCell()); + segmentNumPair.getCell(), segmentNumPair.getIdxOnCell()); numActivePotentialSynapsesForSegment_[segment] = segmentNumPair.getNumber(); } @@ -1216,16 +971,14 @@ void TemporalMemory::read(TemporalMemoryProto::Reader& proto) lastUsedIterationForSegment_.clear(); lastUsedIterationForSegment_.resize(connections.segmentFlatListLength()); - for (auto segmentIterationPair : proto.getLastUsedIterationForSegment()) - { + for (auto segmentIterationPair : proto.getLastUsedIterationForSegment()) { const Segment segment = connections.getSegment( - segmentIterationPair.getCell(), segmentIterationPair.getIdxOnCell()); + segmentIterationPair.getCell(), segmentIterationPair.getIdxOnCell()); lastUsedIterationForSegment_[segment] = segmentIterationPair.getNumber(); } } -void TemporalMemory::load(istream& inStream) -{ +void TemporalMemory::load(istream &inStream) { // Check the marker string marker; inStream >> marker; @@ -1237,27 +990,18 @@ void TemporalMemory::load(istream& inStream) NTA_CHECK(version <= TM_VERSION); // Retrieve simple variables - inStream >> numColumns_ - >> cellsPerColumn_ - >> activationThreshold_ - >> initialPermanence_ - >> connectedPermanence_ - >> minThreshold_ - >> maxNewSynapseCount_ - >> checkInputs_ - >> permanenceIncrement_ - >> permanenceDecrement_ - >> predictedSegmentDecrement_ - >> 
maxSegmentsPerCell_ - >> maxSynapsesPerSegment_ - >> iteration_; + inStream >> numColumns_ >> cellsPerColumn_ >> activationThreshold_ >> + initialPermanence_ >> connectedPermanence_ >> minThreshold_ >> + maxNewSynapseCount_ >> checkInputs_ >> permanenceIncrement_ >> + permanenceDecrement_ >> predictedSegmentDecrement_ >> + maxSegmentsPerCell_ >> maxSynapsesPerSegment_ >> iteration_; connections.load(inStream); numActiveConnectedSynapsesForSegment_.assign( - connections.segmentFlatListLength(), 0); + connections.segmentFlatListLength(), 0); numActivePotentialSynapsesForSegment_.assign( - connections.segmentFlatListLength(), 0); + connections.segmentFlatListLength(), 0); inStream >> rng_; @@ -1265,26 +1009,22 @@ void TemporalMemory::load(istream& inStream) UInt numColumnDimensions; inStream >> numColumnDimensions; columnDimensions_.resize(numColumnDimensions); - for (UInt i = 0; i < numColumnDimensions; i++) - { + for (UInt i = 0; i < numColumnDimensions; i++) { inStream >> columnDimensions_[i]; } UInt numActiveCells; inStream >> numActiveCells; - for (UInt i = 0; i < numActiveCells; i++) - { + for (UInt i = 0; i < numActiveCells; i++) { CellIdx cell; inStream >> cell; activeCells_.push_back(cell); } - if (version < 2) - { + if (version < 2) { UInt numPredictiveCells; inStream >> numPredictiveCells; - for (UInt i = 0; i < numPredictiveCells; i++) - { + for (UInt i = 0; i < numPredictiveCells; i++) { CellIdx cell; inStream >> cell; // Ignore } @@ -1292,8 +1032,7 @@ void TemporalMemory::load(istream& inStream) UInt numWinnerCells; inStream >> numWinnerCells; - for (UInt i = 0; i < numWinnerCells; i++) - { + for (UInt i = 0; i < numWinnerCells; i++) { CellIdx cell; inStream >> cell; winnerCells_.push_back(cell); @@ -1302,8 +1041,7 @@ void TemporalMemory::load(istream& inStream) UInt numActiveSegments; inStream >> numActiveSegments; activeSegments_.resize(numActiveSegments); - for (UInt i = 0; i < numActiveSegments; i++) - { + for (UInt i = 0; i < numActiveSegments; i++) { SegmentIdx idx; inStream >> idx; @@ -1313,12 +1051,9 @@ void TemporalMemory::load(istream& inStream) Segment segment = connections.getSegment(cellIdx, idx); activeSegments_[i] = segment; - if (version < 2) - { + if (version < 2) { numActiveConnectedSynapsesForSegment_[segment] = 0; // Unknown - } - else - { + } else { inStream >> numActiveConnectedSynapsesForSegment_[segment]; } } @@ -1326,8 +1061,7 @@ void TemporalMemory::load(istream& inStream) UInt numMatchingSegments; inStream >> numMatchingSegments; matchingSegments_.resize(numMatchingSegments); - for (UInt i = 0; i < numMatchingSegments; i++) - { + for (UInt i = 0; i < numMatchingSegments; i++) { SegmentIdx idx; inStream >> idx; @@ -1337,22 +1071,17 @@ void TemporalMemory::load(istream& inStream) Segment segment = connections.getSegment(cellIdx, idx); matchingSegments_[i] = segment; - if (version < 2) - { + if (version < 2) { numActivePotentialSynapsesForSegment_[segment] = 0; // Unknown - } - else - { + } else { inStream >> numActivePotentialSynapsesForSegment_[segment]; } } - if (version < 2) - { + if (version < 2) { UInt numMatchingCells; inStream >> numMatchingCells; - for (UInt i = 0; i < numMatchingCells; i++) - { + for (UInt i = 0; i < numMatchingCells; i++) { CellIdx cell; inStream >> cell; // Ignore } @@ -1362,24 +1091,20 @@ void TemporalMemory::load(istream& inStream) inStream >> marker; NTA_CHECK(marker == "~TemporalMemory"); - } -static set< pair > -getComparableSegmentSet(const Connections& connections, - const vector& segments) -{ - set< pair > segmentSet; 
- for (Segment segment : segments) - { +static set> +getComparableSegmentSet(const Connections &connections, + const vector &segments) { + set> segmentSet; + for (Segment segment : segments) { segmentSet.emplace(connections.cellForSegment(segment), connections.idxOnCellForSegment(segment)); } return segmentSet; } -bool TemporalMemory::operator==(const TemporalMemory& other) -{ +bool TemporalMemory::operator==(const TemporalMemory &other) { if (numColumns_ != other.numColumns_ || columnDimensions_ != other.columnDimensions_ || cellsPerColumn_ != other.cellsPerColumn_ || @@ -1395,29 +1120,25 @@ bool TemporalMemory::operator==(const TemporalMemory& other) winnerCells_ != other.winnerCells_ || maxSegmentsPerCell_ != other.maxSegmentsPerCell_ || maxSynapsesPerSegment_ != other.maxSynapsesPerSegment_ || - iteration_ != other.iteration_) - { + iteration_ != other.iteration_) { return false; } - if (connections != other.connections) - { + if (connections != other.connections) { return false; } if (getComparableSegmentSet(connections, activeSegments_) != - getComparableSegmentSet(other.connections, other.activeSegments_) || + getComparableSegmentSet(other.connections, other.activeSegments_) || getComparableSegmentSet(connections, matchingSegments_) != - getComparableSegmentSet(other.connections, other.matchingSegments_)) - { + getComparableSegmentSet(other.connections, other.matchingSegments_)) { return false; } return true; } -bool TemporalMemory::operator!=(const TemporalMemory& other) -{ +bool TemporalMemory::operator!=(const TemporalMemory &other) { return !(*this == other); } @@ -1426,32 +1147,30 @@ bool TemporalMemory::operator!=(const TemporalMemory& other) //---------------------------------------------------------------------- // Print the main TM creation parameters -void TemporalMemory::printParameters() -{ +void TemporalMemory::printParameters() { std::cout << "------------CPP TemporalMemory Parameters ------------------\n"; std::cout - << "version = " << TM_VERSION << std::endl - << "numColumns = " << numberOfColumns() << std::endl - << "cellsPerColumn = " << getCellsPerColumn() << std::endl - << "activationThreshold = " << getActivationThreshold() << std::endl - << "initialPermanence = " << getInitialPermanence() << std::endl - << "connectedPermanence = " << getConnectedPermanence() << std::endl - << "minThreshold = " << getMinThreshold() << std::endl - << "maxNewSynapseCount = " << getMaxNewSynapseCount() << std::endl - << "permanenceIncrement = " << getPermanenceIncrement() << std::endl - << "permanenceDecrement = " << getPermanenceDecrement() << std::endl - << "predictedSegmentDecrement = " << getPredictedSegmentDecrement() << std::endl - << "maxSegmentsPerCell = " << getMaxSegmentsPerCell() << std::endl - << "maxSynapsesPerSegment = " << getMaxSynapsesPerSegment() << std::endl; -} - -void TemporalMemory::printState(vector &state) -{ + << "version = " << TM_VERSION << std::endl + << "numColumns = " << numberOfColumns() << std::endl + << "cellsPerColumn = " << getCellsPerColumn() << std::endl + << "activationThreshold = " << getActivationThreshold() << std::endl + << "initialPermanence = " << getInitialPermanence() << std::endl + << "connectedPermanence = " << getConnectedPermanence() << std::endl + << "minThreshold = " << getMinThreshold() << std::endl + << "maxNewSynapseCount = " << getMaxNewSynapseCount() << std::endl + << "permanenceIncrement = " << getPermanenceIncrement() << std::endl + << "permanenceDecrement = " << getPermanenceDecrement() << std::endl + << 
"predictedSegmentDecrement = " << getPredictedSegmentDecrement() + << std::endl + << "maxSegmentsPerCell = " << getMaxSegmentsPerCell() << std::endl + << "maxSynapsesPerSegment = " << getMaxSynapsesPerSegment() + << std::endl; +} + +void TemporalMemory::printState(vector &state) { std::cout << "[ "; - for (UInt i = 0; i != state.size(); ++i) - { - if (i > 0 && i % 10 == 0) - { + for (UInt i = 0; i != state.size(); ++i) { + if (i > 0 && i % 10 == 0) { std::cout << "\n "; } std::cout << state[i] << " "; @@ -1459,13 +1178,10 @@ void TemporalMemory::printState(vector &state) std::cout << "]\n"; } -void TemporalMemory::printState(vector &state) -{ +void TemporalMemory::printState(vector &state) { std::cout << "[ "; - for (UInt i = 0; i != state.size(); ++i) - { - if (i > 0 && i % 10 == 0) - { + for (UInt i = 0; i != state.size(); ++i) { + if (i > 0 && i % 10 == 0) { std::cout << "\n "; } std::printf("%6.3f ", state[i]); diff --git a/src/nupic/algorithms/TemporalMemory.hpp b/src/nupic/algorithms/TemporalMemory.hpp index e13ebcd4ea..1f40f90812 100644 --- a/src/nupic/algorithms/TemporalMemory.hpp +++ b/src/nupic/algorithms/TemporalMemory.hpp @@ -27,11 +27,11 @@ #ifndef NTA_TEMPORAL_MEMORY_HPP #define NTA_TEMPORAL_MEMORY_HPP -#include +#include #include #include #include -#include +#include #include @@ -40,469 +40,454 @@ using namespace nupic; using namespace nupic::algorithms::connections; namespace nupic { - namespace algorithms { - namespace temporal_memory { - - /** - * Temporal Memory implementation in C++. - * - * Example usage: - * - * SpatialPooler sp(inputDimensions, columnDimensions, ); - * TemporalMemory tm(columnDimensions, ); - * - * while (true) { - * - * sp.compute(inputVector, learn, activeColumns) - * tm.compute(number of activeColumns, activeColumns, learn) - * - * } - * - * The public API uses C arrays, not std::vectors, as inputs. C arrays are - * a good lowest common denominator. You can get a C array from a vector, - * but you can't get a vector from a C array without copying it. This is - * important, for example, when using numpy arrays. The only way to - * convert a numpy array into a std::vector is to copy it, but you can - * access a numpy array's internal C array directly. - */ - class TemporalMemory : public Serializable { - public: - TemporalMemory(); - - /** - * Initialize the temporal memory (TM) using the given parameters. - * - * @param columnDimensions - * Dimensions of the column space - * - * @param cellsPerColumn - * Number of cells per column - * - * @param activationThreshold - * If the number of active connected synapses on a segment is at least - * this threshold, the segment is said to be active. - * - * @param initialPermanence - * Initial permanence of a new synapse. - * - * @param connectedPermanence - * If the permanence value for a synapse is greater than this value, it - * is said to be connected. - * - * @param minThreshold - * If the number of potential synapses active on a segment is at least - * this threshold, it is said to be "matching" and is eligible for - * learning. - * - * @param maxNewSynapseCount - * The maximum number of synapses added to a segment during learning. - * - * @param permanenceIncrement - * Amount by which permanences of synapses are incremented during - * learning. - * - * @param permanenceDecrement - * Amount by which permanences of synapses are decremented during - * learning. - * - * @param predictedSegmentDecrement - * Amount by which segments are punished for incorrect predictions. 
- * - * @param seed - * Seed for the random number generator. - * - * @param maxSegmentsPerCell - * The maximum number of segments per cell. - * - * @param maxSynapsesPerSegment - * The maximum number of synapses per segment. - * - * @param checkInputs - * Whether to check that the activeColumns are sorted without - * duplicates. Disable this for a small speed boost. - * - * Notes: - * - * predictedSegmentDecrement: A good value is just a bit larger than - * (the column-level sparsity * permanenceIncrement). So, if column-level - * sparsity is 2% and permanenceIncrement is 0.01, this parameter should be - * something like 4% * 0.01 = 0.0004). - */ - TemporalMemory( - vector columnDimensions, - UInt cellsPerColumn = 32, - UInt activationThreshold = 13, - Permanence initialPermanence = 0.21, - Permanence connectedPermanence = 0.50, - UInt minThreshold = 10, - UInt maxNewSynapseCount = 20, - Permanence permanenceIncrement = 0.10, - Permanence permanenceDecrement = 0.10, - Permanence predictedSegmentDecrement = 0.0, - Int seed = 42, - UInt maxSegmentsPerCell=255, - UInt maxSynapsesPerSegment=255, - bool checkInputs=true); - - virtual void initialize( - vector columnDimensions = { 2048 }, - UInt cellsPerColumn = 32, - UInt activationThreshold = 13, - Permanence initialPermanence = 0.21, - Permanence connectedPermanence = 0.50, - UInt minThreshold = 10, - UInt maxNewSynapseCount = 20, - Permanence permanenceIncrement = 0.10, - Permanence permanenceDecrement = 0.10, - Permanence predictedSegmentDecrement = 0.0, - Int seed = 42, - UInt maxSegmentsPerCell=255, - UInt maxSynapsesPerSegment=255, - bool checkInputs=true); - - virtual ~TemporalMemory(); - - //---------------------------------------------------------------------- - // Main functions - //---------------------------------------------------------------------- - - /** - * Get the version number of for the TM implementation. - * - * @returns Integer version number. - */ - virtual UInt version() const; - - /** - * This *only* updates _rng to a new Random using seed. - * - * @returns Integer version number. - */ - void seed_(UInt64 seed); - - /** - * Indicates the start of a new sequence. - * Resets sequence state of the TM. - */ - virtual void reset(); - - /** - * Calculate the active cells, using the current active columns and - * dendrite segments. Grow and reinforce synapses. - * - * @param activeColumnsSize - * Size of activeColumns. - * - * @param activeColumns - * A sorted list of active column indices. - * - * @param learn - * If true, reinforce / punish / grow synapses. - */ - void activateCells( - size_t activeColumnsSize, - const UInt activeColumns[], - bool learn = true); - - /** - * Calculate dendrite segment activity, using the current active cells. - * - * @param learn - * If true, segment activations will be recorded. This information is - * used during segment cleanup. - */ - void activateDendrites(bool learn = true); - - /** - * Perform one time step of the Temporal Memory algorithm. - * - * This method calls activateCells, then calls activateDendrites. Using - * the TemporalMemory via its compute method ensures that you'll always - * be able to call getPredictiveCells to get predictions for the next - * time step. - * - * @param activeColumnsSize - * Number of active columns. - * - * @param activeColumns - * Sorted list of indices of active columns. - * - * @param learn - * Whether or not learning is enabled. 
- */ - virtual void compute( - size_t activeColumnsSize, - const UInt activeColumns[], - bool learn = true); - - - // ============================== - // Helper functions - // ============================== - - /** - * Create a segment on the specified cell. This method calls - * createSegment on the underlying connections, and it does some extra - * bookkeeping. Unit tests should call this method, and not - * connections.createSegment(). - * - * @param cell - * Cell to add a segment to. - * - * @return Segment - * The created segment. - */ - Segment createSegment(CellIdx cell); - - /** - * Returns the indices of cells that belong to a column. - * - * @param column Column index - * - * @return (vector) Cell indices - */ - vector cellsForColumn(Int column); - - /** - * Returns the number of cells in this layer. - * - * @return (int) Number of cells - */ - UInt numberOfCells(void); - - /** - * Returns the indices of the active cells. - * - * @returns (std::vector) Vector of indices of active cells. - */ - vector getActiveCells() const; - - /** - * Returns the indices of the predictive cells. - * - * @returns (std::vector) Vector of indices of predictive cells. - */ - vector getPredictiveCells() const; - - /** - * Returns the indices of the winner cells. - * - * @returns (std::vector) Vector of indices of winner cells. - */ - vector getWinnerCells() const; - - vector getActiveSegments() const; - vector getMatchingSegments() const; - - /** - * Returns the dimensions of the columns in the region. - * - * @returns Integer number of column dimension - */ - vector getColumnDimensions() const; - - /** - * Returns the total number of columns. - * - * @returns Integer number of column numbers - */ - UInt numberOfColumns() const; - - /** - * Returns the number of cells per column. - * - * @returns Integer number of cells per column - */ - UInt getCellsPerColumn() const; - - /** - * Returns the activation threshold. - * - * @returns Integer number of the activation threshold - */ - UInt getActivationThreshold() const; - void setActivationThreshold(UInt); - - /** - * Returns the initial permanence. - * - * @returns Initial permanence - */ - Permanence getInitialPermanence() const; - void setInitialPermanence(Permanence); - - /** - * Returns the connected permanance. - * - * @returns Returns the connected permanance - */ - Permanence getConnectedPermanence() const; - void setConnectedPermanence(Permanence); - - /** - * Returns the minimum threshold. - * - * @returns Integer number of minimum threshold - */ - UInt getMinThreshold() const; - void setMinThreshold(UInt); - - /** - * Returns the maximum number of synapses that can be added to a segment - * in a single time step. - * - * @returns Integer number of maximum new synapse count - */ - UInt getMaxNewSynapseCount() const; - void setMaxNewSynapseCount(UInt); - - /** - * Get and set the checkInputs parameter. - */ - bool getCheckInputs() const; - void setCheckInputs(bool); - - /** - * Returns the permanence increment. - * - * @returns Returns the Permanence increment - */ - Permanence getPermanenceIncrement() const; - void setPermanenceIncrement(Permanence); - - /** - * Returns the permanence decrement. - * - * @returns Returns the Permanence decrement - */ - Permanence getPermanenceDecrement() const; - void setPermanenceDecrement(Permanence); - - /** - * Returns the predicted Segment decrement. 
- * - * @returns Returns the segment decrement - */ - Permanence getPredictedSegmentDecrement() const; - void setPredictedSegmentDecrement(Permanence); - - /** - * Returns the maxSegmentsPerCell. - * - * @returns Max segments per cell - */ - UInt getMaxSegmentsPerCell() const; - - /** - * Returns the maxSynapsesPerSegment. - * - * @returns Max synapses per segment - */ - UInt getMaxSynapsesPerSegment() const; - - /** - * Raises an error if cell index is invalid. - * - * @param cell Cell index - */ - bool _validateCell(CellIdx cell); - - /** - * Save (serialize) the current state of the spatial pooler to the - * specified file. - * - * @param fd A valid file descriptor. - */ - virtual void save(ostream& outStream) const; - - using Serializable::write; - virtual void write(TemporalMemoryProto::Builder& proto) const override; - - /** - * Load (deserialize) and initialize the spatial pooler from the - * specified input stream. - * - * @param inStream A valid istream. - */ - virtual void load(istream& inStream); - - using Serializable::read; - virtual void read(TemporalMemoryProto::Reader& proto) override; - - /** - * Returns the number of bytes that a save operation would result in. - * Note: this method is currently somewhat inefficient as it just does - * a full save into an ostream and counts the resulting size. - * - * @returns Integer number of bytes - */ - virtual UInt persistentSize() const; - - bool operator==(const TemporalMemory& other); - bool operator!=(const TemporalMemory& other); - - //---------------------------------------------------------------------- - // Debugging helpers - //---------------------------------------------------------------------- - - /** - * Print the main TM creation parameters - */ - void printParameters(); - - /** - * Returns the index of the column that a cell belongs to. - * - * @param cell Cell index - * - * @return (int) Column index - */ - Int columnForCell(CellIdx cell); - - /** - * Print the given UInt array in a nice format - */ - void printState(vector &state); - - /** - * Print the given Real array in a nice format - */ - void printState(vector &state); - - protected: - UInt numColumns_; - vector columnDimensions_; - UInt cellsPerColumn_; - UInt activationThreshold_; - UInt minThreshold_; - UInt maxNewSynapseCount_; - bool checkInputs_; - Permanence initialPermanence_; - Permanence connectedPermanence_; - Permanence permanenceIncrement_; - Permanence permanenceDecrement_; - Permanence predictedSegmentDecrement_; - - vector activeCells_; - vector winnerCells_; - vector activeSegments_; - vector matchingSegments_; - vector numActiveConnectedSynapsesForSegment_; - vector numActivePotentialSynapsesForSegment_; - - UInt maxSegmentsPerCell_; - UInt maxSynapsesPerSegment_; - UInt64 iteration_; - vector lastUsedIterationForSegment_; - - Random rng_; - - public: - Connections connections; - }; - - } // end namespace temporal_memory - } // end namespace algorithms -} // end namespace nta +namespace algorithms { +namespace temporal_memory { + +/** + * Temporal Memory implementation in C++. + * + * Example usage: + * + * SpatialPooler sp(inputDimensions, columnDimensions, ); + * TemporalMemory tm(columnDimensions, ); + * + * while (true) { + * + * sp.compute(inputVector, learn, activeColumns) + * tm.compute(number of activeColumns, activeColumns, learn) + * + * } + * + * The public API uses C arrays, not std::vectors, as inputs. C arrays are + * a good lowest common denominator. 
You can get a C array from a vector, + * but you can't get a vector from a C array without copying it. This is + * important, for example, when using numpy arrays. The only way to + * convert a numpy array into a std::vector is to copy it, but you can + * access a numpy array's internal C array directly. + */ +class TemporalMemory : public Serializable { +public: + TemporalMemory(); + + /** + * Initialize the temporal memory (TM) using the given parameters. + * + * @param columnDimensions + * Dimensions of the column space + * + * @param cellsPerColumn + * Number of cells per column + * + * @param activationThreshold + * If the number of active connected synapses on a segment is at least + * this threshold, the segment is said to be active. + * + * @param initialPermanence + * Initial permanence of a new synapse. + * + * @param connectedPermanence + * If the permanence value for a synapse is greater than this value, it + * is said to be connected. + * + * @param minThreshold + * If the number of potential synapses active on a segment is at least + * this threshold, it is said to be "matching" and is eligible for + * learning. + * + * @param maxNewSynapseCount + * The maximum number of synapses added to a segment during learning. + * + * @param permanenceIncrement + * Amount by which permanences of synapses are incremented during + * learning. + * + * @param permanenceDecrement + * Amount by which permanences of synapses are decremented during + * learning. + * + * @param predictedSegmentDecrement + * Amount by which segments are punished for incorrect predictions. + * + * @param seed + * Seed for the random number generator. + * + * @param maxSegmentsPerCell + * The maximum number of segments per cell. + * + * @param maxSynapsesPerSegment + * The maximum number of synapses per segment. + * + * @param checkInputs + * Whether to check that the activeColumns are sorted without + * duplicates. Disable this for a small speed boost. + * + * Notes: + * + * predictedSegmentDecrement: A good value is just a bit larger than + * (the column-level sparsity * permanenceIncrement). So, if column-level + * sparsity is 2% and permanenceIncrement is 0.01, this parameter should be + * something like 4% * 0.01 = 0.0004). + */ + TemporalMemory(vector columnDimensions, UInt cellsPerColumn = 32, + UInt activationThreshold = 13, + Permanence initialPermanence = 0.21, + Permanence connectedPermanence = 0.50, UInt minThreshold = 10, + UInt maxNewSynapseCount = 20, + Permanence permanenceIncrement = 0.10, + Permanence permanenceDecrement = 0.10, + Permanence predictedSegmentDecrement = 0.0, Int seed = 42, + UInt maxSegmentsPerCell = 255, + UInt maxSynapsesPerSegment = 255, bool checkInputs = true); + + virtual void + initialize(vector columnDimensions = {2048}, UInt cellsPerColumn = 32, + UInt activationThreshold = 13, Permanence initialPermanence = 0.21, + Permanence connectedPermanence = 0.50, UInt minThreshold = 10, + UInt maxNewSynapseCount = 20, + Permanence permanenceIncrement = 0.10, + Permanence permanenceDecrement = 0.10, + Permanence predictedSegmentDecrement = 0.0, Int seed = 42, + UInt maxSegmentsPerCell = 255, UInt maxSynapsesPerSegment = 255, + bool checkInputs = true); + + virtual ~TemporalMemory(); + + //---------------------------------------------------------------------- + // Main functions + //---------------------------------------------------------------------- + + /** + * Get the version number of for the TM implementation. + * + * @returns Integer version number. 
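// Illustrative sketch (assumed include path, made-up parameter values):
// constructing a TemporalMemory with explicit arguments, using the rule of
// thumb from the predictedSegmentDecrement note above.
#include <nupic/algorithms/TemporalMemory.hpp>

using nupic::algorithms::temporal_memory::TemporalMemory;

void constructExampleTm() {
  const double columnSparsity = 0.02;      // assumed column-level sparsity (2%)
  const double permanenceIncrement = 0.10;
  // "a bit larger than (column-level sparsity * permanenceIncrement)";
  // doubling the sparsity, as the note's own example does, gives 0.004 here.
  const double predictedSegmentDecrement =
      2.0 * columnSparsity * permanenceIncrement;

  TemporalMemory tm({2048u},               // columnDimensions
                    32,                    // cellsPerColumn
                    13,                    // activationThreshold
                    0.21,                  // initialPermanence
                    0.50,                  // connectedPermanence
                    10,                    // minThreshold
                    20,                    // maxNewSynapseCount
                    permanenceIncrement,
                    0.10,                  // permanenceDecrement
                    predictedSegmentDecrement);
  (void)tm;                                // constructed for illustration only
}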
+ */ + virtual UInt version() const; + + /** + * This *only* updates _rng to a new Random using seed. + * + * @returns Integer version number. + */ + void seed_(UInt64 seed); + + /** + * Indicates the start of a new sequence. + * Resets sequence state of the TM. + */ + virtual void reset(); + + /** + * Calculate the active cells, using the current active columns and + * dendrite segments. Grow and reinforce synapses. + * + * @param activeColumnsSize + * Size of activeColumns. + * + * @param activeColumns + * A sorted list of active column indices. + * + * @param learn + * If true, reinforce / punish / grow synapses. + */ + void activateCells(size_t activeColumnsSize, const UInt activeColumns[], + bool learn = true); + + /** + * Calculate dendrite segment activity, using the current active cells. + * + * @param learn + * If true, segment activations will be recorded. This information is + * used during segment cleanup. + */ + void activateDendrites(bool learn = true); + + /** + * Perform one time step of the Temporal Memory algorithm. + * + * This method calls activateCells, then calls activateDendrites. Using + * the TemporalMemory via its compute method ensures that you'll always + * be able to call getPredictiveCells to get predictions for the next + * time step. + * + * @param activeColumnsSize + * Number of active columns. + * + * @param activeColumns + * Sorted list of indices of active columns. + * + * @param learn + * Whether or not learning is enabled. + */ + virtual void compute(size_t activeColumnsSize, const UInt activeColumns[], + bool learn = true); + + // ============================== + // Helper functions + // ============================== + + /** + * Create a segment on the specified cell. This method calls + * createSegment on the underlying connections, and it does some extra + * bookkeeping. Unit tests should call this method, and not + * connections.createSegment(). + * + * @param cell + * Cell to add a segment to. + * + * @return Segment + * The created segment. + */ + Segment createSegment(CellIdx cell); + + /** + * Returns the indices of cells that belong to a column. + * + * @param column Column index + * + * @return (vector) Cell indices + */ + vector cellsForColumn(Int column); + + /** + * Returns the number of cells in this layer. + * + * @return (int) Number of cells + */ + UInt numberOfCells(void); + + /** + * Returns the indices of the active cells. + * + * @returns (std::vector) Vector of indices of active cells. + */ + vector getActiveCells() const; + + /** + * Returns the indices of the predictive cells. + * + * @returns (std::vector) Vector of indices of predictive cells. + */ + vector getPredictiveCells() const; + + /** + * Returns the indices of the winner cells. + * + * @returns (std::vector) Vector of indices of winner cells. + */ + vector getWinnerCells() const; + + vector getActiveSegments() const; + vector getMatchingSegments() const; + + /** + * Returns the dimensions of the columns in the region. + * + * @returns Integer number of column dimension + */ + vector getColumnDimensions() const; + + /** + * Returns the total number of columns. + * + * @returns Integer number of column numbers + */ + UInt numberOfColumns() const; + + /** + * Returns the number of cells per column. + * + * @returns Integer number of cells per column + */ + UInt getCellsPerColumn() const; + + /** + * Returns the activation threshold. 
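// Illustrative sketch: one step of the SP -> TM loop from the class comment
// above, showing how std::vector supplies the C-array arguments compute()
// expects. The SpatialPooler::compute signature, the include paths and the
// 2048-column geometry are assumptions; only TemporalMemory::compute comes
// from this header.
#include <vector>
#include <nupic/algorithms/SpatialPooler.hpp>
#include <nupic/algorithms/TemporalMemory.hpp>

using namespace nupic;
using nupic::algorithms::spatial_pooler::SpatialPooler;
using nupic::algorithms::temporal_memory::TemporalMemory;

void exampleStep(SpatialPooler &sp, TemporalMemory &tm,
                 std::vector<UInt> &inputVector) {
  std::vector<UInt> activeColumnsDense(2048, 0);
  sp.compute(inputVector.data(), true, activeColumnsDense.data());

  // compute() wants a sorted list of active column *indices*; scanning the
  // dense SP output in order keeps the list sorted, which satisfies
  // checkInputs.
  std::vector<UInt> activeColumns;
  for (size_t i = 0; i < activeColumnsDense.size(); i++)
    if (activeColumnsDense[i] != 0)
      activeColumns.push_back((UInt)i);

  // data() exposes the vector's underlying C array without a copy.
  tm.compute(activeColumns.size(), activeColumns.data(), true);

  const auto predictiveCells = tm.getPredictiveCells();
  (void)predictiveCells; // predictions for the next time step
}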
+ * + * @returns Integer number of the activation threshold + */ + UInt getActivationThreshold() const; + void setActivationThreshold(UInt); + + /** + * Returns the initial permanence. + * + * @returns Initial permanence + */ + Permanence getInitialPermanence() const; + void setInitialPermanence(Permanence); + + /** + * Returns the connected permanance. + * + * @returns Returns the connected permanance + */ + Permanence getConnectedPermanence() const; + void setConnectedPermanence(Permanence); + + /** + * Returns the minimum threshold. + * + * @returns Integer number of minimum threshold + */ + UInt getMinThreshold() const; + void setMinThreshold(UInt); + + /** + * Returns the maximum number of synapses that can be added to a segment + * in a single time step. + * + * @returns Integer number of maximum new synapse count + */ + UInt getMaxNewSynapseCount() const; + void setMaxNewSynapseCount(UInt); + + /** + * Get and set the checkInputs parameter. + */ + bool getCheckInputs() const; + void setCheckInputs(bool); + + /** + * Returns the permanence increment. + * + * @returns Returns the Permanence increment + */ + Permanence getPermanenceIncrement() const; + void setPermanenceIncrement(Permanence); + + /** + * Returns the permanence decrement. + * + * @returns Returns the Permanence decrement + */ + Permanence getPermanenceDecrement() const; + void setPermanenceDecrement(Permanence); + + /** + * Returns the predicted Segment decrement. + * + * @returns Returns the segment decrement + */ + Permanence getPredictedSegmentDecrement() const; + void setPredictedSegmentDecrement(Permanence); + + /** + * Returns the maxSegmentsPerCell. + * + * @returns Max segments per cell + */ + UInt getMaxSegmentsPerCell() const; + + /** + * Returns the maxSynapsesPerSegment. + * + * @returns Max synapses per segment + */ + UInt getMaxSynapsesPerSegment() const; + + /** + * Raises an error if cell index is invalid. + * + * @param cell Cell index + */ + bool _validateCell(CellIdx cell); + + /** + * Save (serialize) the current state of the spatial pooler to the + * specified file. + * + * @param fd A valid file descriptor. + */ + virtual void save(ostream &outStream) const; + + using Serializable::write; + virtual void write(TemporalMemoryProto::Builder &proto) const override; + + /** + * Load (deserialize) and initialize the spatial pooler from the + * specified input stream. + * + * @param inStream A valid istream. + */ + virtual void load(istream &inStream); + + using Serializable::read; + virtual void read(TemporalMemoryProto::Reader &proto) override; + + /** + * Returns the number of bytes that a save operation would result in. + * Note: this method is currently somewhat inefficient as it just does + * a full save into an ostream and counts the resulting size. + * + * @returns Integer number of bytes + */ + virtual UInt persistentSize() const; + + bool operator==(const TemporalMemory &other); + bool operator!=(const TemporalMemory &other); + + //---------------------------------------------------------------------- + // Debugging helpers + //---------------------------------------------------------------------- + + /** + * Print the main TM creation parameters + */ + void printParameters(); + + /** + * Returns the index of the column that a cell belongs to. 
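// Sketch of the cell <-> column arithmetic that columnForCell() and
// cellsForColumn() expose, assuming the conventional layout in which the
// cells of column c occupy the contiguous range
// [c * cellsPerColumn, (c + 1) * cellsPerColumn). This mirrors, but is not
// copied from, the implementation.
#include <vector>

unsigned columnForCellSketch(unsigned cell, unsigned cellsPerColumn) {
  return cell / cellsPerColumn;
}

std::vector<unsigned> cellsForColumnSketch(unsigned column,
                                           unsigned cellsPerColumn) {
  std::vector<unsigned> cells;
  for (unsigned i = 0; i < cellsPerColumn; i++)
    cells.push_back(column * cellsPerColumn + i);
  return cells;
}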
+ * + * @param cell Cell index + * + * @return (int) Column index + */ + Int columnForCell(CellIdx cell); + + /** + * Print the given UInt array in a nice format + */ + void printState(vector &state); + + /** + * Print the given Real array in a nice format + */ + void printState(vector &state); + +protected: + UInt numColumns_; + vector columnDimensions_; + UInt cellsPerColumn_; + UInt activationThreshold_; + UInt minThreshold_; + UInt maxNewSynapseCount_; + bool checkInputs_; + Permanence initialPermanence_; + Permanence connectedPermanence_; + Permanence permanenceIncrement_; + Permanence permanenceDecrement_; + Permanence predictedSegmentDecrement_; + + vector activeCells_; + vector winnerCells_; + vector activeSegments_; + vector matchingSegments_; + vector numActiveConnectedSynapsesForSegment_; + vector numActivePotentialSynapsesForSegment_; + + UInt maxSegmentsPerCell_; + UInt maxSynapsesPerSegment_; + UInt64 iteration_; + vector lastUsedIterationForSegment_; + + Random rng_; + +public: + Connections connections; +}; + +} // end namespace temporal_memory +} // end namespace algorithms +} // namespace nupic #endif // NTA_TEMPORAL_MEMORY_HPP diff --git a/src/nupic/bindings/PySparseTensor.cpp b/src/nupic/bindings/PySparseTensor.cpp index ad79d18f3e..45c1952a86 100644 --- a/src/nupic/bindings/PySparseTensor.cpp +++ b/src/nupic/bindings/PySparseTensor.cpp @@ -20,13 +20,13 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file */ #include -#include #include +#include using namespace std; using namespace nupic; @@ -34,52 +34,49 @@ using namespace nupic; typedef nupic::SparseTensor STBase; PySparseTensor::PySparseTensor(PyObject *numpyArray) - // TODO: Switch to rank 0 (or at least dimension 0) default. - : tensor_(PyTensorIndex(1)) -{ + // TODO: Switch to rank 0 (or at least dimension 0) default. + : tensor_(PyTensorIndex(1)) { NumpyNDArray a(numpyArray); int rank = a.getRank(); - if(rank > PYSPARSETENSOR_MAX_RANK) - throw invalid_argument("Array rank exceeds max rank for SparseTensor bindings."); + if (rank > PYSPARSETENSOR_MAX_RANK) + throw invalid_argument( + "Array rank exceeds max rank for SparseTensor bindings."); int dims[PYSPARSETENSOR_MAX_RANK]; // Never larger than max ND array rank. 
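// Conceptual sketch of the dense -> sparse conversion that the fromDense()
// call just below performs: walk the row-major dense buffer and keep only the
// nonzero entries, keyed by their multi-dimensional index. std::map is a
// stand-in here, not the real SparseTensor storage.
#include <map>
#include <vector>

std::map<std::vector<int>, float>
denseToSparseSketch(const float *data, const std::vector<int> &dims) {
  std::map<std::vector<int>, float> nonZeros;
  std::vector<int> idx(dims.size(), 0);
  size_t total = 1;
  for (int d : dims)
    total *= d;
  for (size_t flat = 0; flat < total; ++flat) {
    if (data[flat] != 0.0f)
      nonZeros[idx] = data[flat];
    // advance the multi-index in row-major order (last dimension fastest)
    for (int d = (int)dims.size() - 1; d >= 0; --d) {
      if (++idx[d] < dims[d])
        break;
      idx[d] = 0;
    }
  }
  return nonZeros;
}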
a.getDims(dims); tensor_ = STBase(PyTensorIndex(rank, dims)); tensor_.fromDense(a.getData()); } -void PySparseTensor::set(const PyTensorIndex &i, PyObject *x) -{ +void PySparseTensor::set(const PyTensorIndex &i, PyObject *x) { PyObject *num = PyNumber_Float(x); - if(!num) throw std::invalid_argument("value is not a float."); - nupic::Real y = (nupic::Real) PyFloat_AsDouble(num); + if (!num) + throw std::invalid_argument("value is not a float."); + nupic::Real y = (nupic::Real)PyFloat_AsDouble(num); Py_CLEAR(num); set(i, y); } -PyObject *PySparseTensor::toDense() const -{ +PyObject *PySparseTensor::toDense() const { const PyTensorIndex &bounds = tensor_.getBounds(); int rank = bounds.size(); int dims[PYSPARSETENSOR_MAX_RANK]; - if(rank > PYSPARSETENSOR_MAX_RANK) + if (rank > PYSPARSETENSOR_MAX_RANK) throw std::logic_error("Rank exceeds max rank."); - for(int i=0; i> rank; }; - PyTensorIndex index(rank, (const size_t *) 0); - for(size_t i=0; i -#include #include +#include #include +#include -#include #include +#include //-------------------------------------------------------------------------------- typedef std::vector TIV; @@ -45,8 +44,7 @@ class PyTensorIndex; inline std::ostream &operator<<(std::ostream &o, const PyTensorIndex &j); //-------------------------------------------------------------------------------- -class PyTensorIndex -{ +class PyTensorIndex { enum { maxRank = PYSPARSETENSOR_MAX_RANK }; nupic::UInt32 index_[maxRank]; nupic::UInt32 rank_; @@ -58,39 +56,35 @@ class PyTensorIndex PyTensorIndex() : rank_(0) {} - PyTensorIndex(const PyTensorIndex &x) : rank_(x.rank_) - { - ::memcpy(index_, x.index_, rank_*sizeof(nupic::UInt32)); + PyTensorIndex(const PyTensorIndex &x) : rank_(x.rank_) { + ::memcpy(index_, x.index_, rank_ * sizeof(nupic::UInt32)); } - PyTensorIndex(nupic::UInt32 i) : rank_(1) - { - index_[0] = i; - } + PyTensorIndex(nupic::UInt32 i) : rank_(1) { index_[0] = i; } - PyTensorIndex(nupic::UInt32 i, nupic::UInt32 j) : rank_(2) - { + PyTensorIndex(nupic::UInt32 i, nupic::UInt32 j) : rank_(2) { index_[0] = i; index_[1] = j; } - PyTensorIndex(nupic::UInt32 i, nupic::UInt32 j, nupic::UInt32 k) : rank_(3) - { + PyTensorIndex(nupic::UInt32 i, nupic::UInt32 j, nupic::UInt32 k) : rank_(3) { index_[0] = i; index_[1] = j; index_[2] = k; } - PyTensorIndex(nupic::UInt32 i, nupic::UInt32 j, nupic::UInt32 k, nupic::UInt32 l) : rank_(4) - { + PyTensorIndex(nupic::UInt32 i, nupic::UInt32 j, nupic::UInt32 k, + nupic::UInt32 l) + : rank_(4) { index_[0] = i; index_[1] = j; index_[2] = k; index_[3] = l; } - PyTensorIndex(nupic::UInt32 i, nupic::UInt32 j, nupic::UInt32 k, nupic::UInt32 l, nupic::UInt32 m) : rank_(5) - { + PyTensorIndex(nupic::UInt32 i, nupic::UInt32 j, nupic::UInt32 k, + nupic::UInt32 l, nupic::UInt32 m) + : rank_(5) { index_[0] = i; index_[1] = j; index_[2] = k; @@ -98,8 +92,9 @@ class PyTensorIndex index_[4] = m; } - PyTensorIndex(nupic::UInt32 i, nupic::UInt32 j, nupic::UInt32 k, nupic::UInt32 l, nupic::UInt32 m, nupic::UInt32 n) : rank_(6) - { + PyTensorIndex(nupic::UInt32 i, nupic::UInt32 j, nupic::UInt32 k, + nupic::UInt32 l, nupic::UInt32 m, nupic::UInt32 n) + : rank_(6) { index_[0] = i; index_[1] = j; index_[2] = k; @@ -108,60 +103,60 @@ class PyTensorIndex index_[5] = n; } - PyTensorIndex(const TIV &i) : rank_(i.size()) - { + PyTensorIndex(const TIV &i) : rank_(i.size()) { if (rank_ > maxRank) { char errBuf[512]; - snprintf(errBuf, 512, - "Tensors may not be constructed of rank greater than %d.", maxRank); + snprintf(errBuf, 512, + "Tensors may not be 
constructed of rank greater than %d.", + maxRank); rank_ = 0; throw std::runtime_error(errBuf); } std::copy(i.begin(), i.end(), index_); } - template - PyTensorIndex(int nd, const T *d) : rank_(nd) - { + template PyTensorIndex(int nd, const T *d) : rank_(nd) { if (nd > maxRank) { char errBuf[512]; - snprintf(errBuf, 512, - "Tensors may not be constructed of rank greater than %d.", maxRank); + snprintf(errBuf, 512, + "Tensors may not be constructed of rank greater than %d.", + maxRank); rank_ = 0; throw std::runtime_error(errBuf); } - if(d) std::copy(d, d+nd, index_); - else std::fill(index_, index_+nd, 0); + if (d) + std::copy(d, d + nd, index_); + else + std::fill(index_, index_ + nd, 0); } - - PyTensorIndex(const PyTensorIndex& i1, const PyTensorIndex& i2) - : rank_(i1.rank_ + i2.rank_) - { + + PyTensorIndex(const PyTensorIndex &i1, const PyTensorIndex &i2) + : rank_(i1.rank_ + i2.rank_) { if (rank_ > maxRank) { char errBuf[512]; - snprintf(errBuf, 512, - "Tensors may not be constructed of rank greater than %d.", maxRank); + snprintf(errBuf, 512, + "Tensors may not be constructed of rank greater than %d.", + maxRank); rank_ = 0; throw std::runtime_error(errBuf); } - ::memcpy(index_, i1.index_, i1.rank_*sizeof(nupic::UInt32)); - ::memcpy(index_ + i1.rank_, i2.index_, i2.rank_*sizeof(nupic::UInt32)); + ::memcpy(index_, i1.index_, i1.rank_ * sizeof(nupic::UInt32)); + ::memcpy(index_ + i1.rank_, i2.index_, i2.rank_ * sizeof(nupic::UInt32)); } - - PyTensorIndex &operator=(const PyTensorIndex &x) - { + + PyTensorIndex &operator=(const PyTensorIndex &x) { rank_ = x.rank_; - ::memcpy(index_, x.index_, rank_*sizeof(nupic::UInt32)); + ::memcpy(index_, x.index_, rank_ * sizeof(nupic::UInt32)); return *this; } - PyTensorIndex &operator=(const TIV &i) - { - if(i.size() > maxRank) { + PyTensorIndex &operator=(const TIV &i) { + if (i.size() > maxRank) { char errBuf[512]; - snprintf(errBuf, 512, - "Tensors may not be constructed of rank greater than %d.", maxRank); + snprintf(errBuf, 512, + "Tensors may not be constructed of rank greater than %d.", + maxRank); rank_ = 0; throw std::runtime_error(errBuf); } @@ -172,19 +167,27 @@ class PyTensorIndex nupic::UInt32 size() const { return rank_; } - nupic::UInt32 operator[](nupic::UInt32 i) const - { - if(!(i < rank_)) throw std::invalid_argument("Index out of bounds."); + nupic::UInt32 operator[](nupic::UInt32 i) const { + if (!(i < rank_)) + throw std::invalid_argument("Index out of bounds."); return index_[i]; } - nupic::UInt32 &operator[](nupic::UInt32 i) - { - if(!(i < rank_)) throw std::invalid_argument("Index out of bounds."); + nupic::UInt32 &operator[](nupic::UInt32 i) { + if (!(i < rank_)) + throw std::invalid_argument("Index out of bounds."); return index_[i]; } - nupic::UInt32 __getitem__(int i) const { if(i < 0) i += rank_; return index_[i]; } - void __setitem__(int i, nupic::UInt32 d) { if(i < 0) i += rank_; index_[i] = d; } + nupic::UInt32 __getitem__(int i) const { + if (i < 0) + i += rank_; + return index_[i]; + } + void __setitem__(int i, nupic::UInt32 d) { + if (i < 0) + i += rank_; + index_[i] = d; + } nupic::UInt32 __len__() const { return rank_; } const nupic::UInt32 *begin() const { return index_; } @@ -192,16 +195,17 @@ class PyTensorIndex const nupic::UInt32 *end() const { return index_ + rank_; } nupic::UInt32 *end() { return index_ + rank_; } - bool operator==(const PyTensorIndex &j) const - { - if(rank_ != j.rank_) return false; - for(nupic::UInt32 i=0; i j.index_[k]) return false; - if(n < j.rank_) return true; - else return 
false; + if (n < j.rank_) + return true; + else + return false; return false; } bool __eq__(const PyTensorIndex &j) const { return (*this) == j; } @@ -218,55 +224,60 @@ class PyTensorIndex // bool __lt__(const PyTensorIndex &j) const { return (*this) < j; } bool __gt__(const PyTensorIndex &j) const { return j < (*this); } - bool operator==(const TIV &j) const - { - if(size() != j.size()) return false; - for(nupic::UInt32 i=0; i -{ +class PyDomain : public nupic::Domain { public: - PyDomain(const TIV &lowerHalfSpace) : nupic::Domain(lowerHalfSpace) {} + PyDomain(const TIV &lowerHalfSpace) + : nupic::Domain(lowerHalfSpace) {} PyDomain(const TIV &lower, const TIV &upper) - : nupic::Domain(lower, upper) {} + : nupic::Domain(lower, upper) {} - PyTensorIndex getLowerBound() const - { - PyTensorIndex bounds(rank(), (const nupic::UInt32 *) 0); + PyTensorIndex getLowerBound() const { + PyTensorIndex bounds(rank(), (const nupic::UInt32 *)0); getLB(bounds); return bounds; } - PyTensorIndex getUpperBound() const - { - PyTensorIndex bounds(rank(), (const nupic::UInt32 *) 0); + PyTensorIndex getUpperBound() const { + PyTensorIndex bounds(rank(), (const nupic::UInt32 *)0); getUB(bounds); return bounds; } - std::vector __getitem__(int i) const - { + std::vector __getitem__(int i) const { nupic::DimRange r = (*this)[i]; nupic::UInt32 v[3]; v[0] = r.getDim(); v[1] = r.getLB(); v[2] = r.getUB(); - return std::vector(v, v+3); + return std::vector(v, v + 3); } - PyTensorIndex getDimensions() const - { - PyTensorIndex bounds(rank(), (const nupic::UInt32 *) 0); + PyTensorIndex getDimensions() const { + PyTensorIndex bounds(rank(), (const nupic::UInt32 *)0); getDims(bounds); return bounds; } nupic::UInt32 getNumOpenDims() const { return getNOpenDims(); } - PyTensorIndex getOpenDimensions() const - { - PyTensorIndex bounds(getNumOpenDims(), (const nupic::UInt32 *) 0); + PyTensorIndex getOpenDimensions() const { + PyTensorIndex bounds(getNumOpenDims(), (const nupic::UInt32 *)0); getOpenDims(bounds); return bounds; } // PyTensorIndex getSliceBounds(const TIV &maxBounds) const - PyTensorIndex getSliceBounds() const - { - PyTensorIndex bounds(getNumOpenDims(), (const nupic::UInt32 *) 0); + PyTensorIndex getSliceBounds() const { + PyTensorIndex bounds(getNumOpenDims(), (const nupic::UInt32 *)0); nupic::UInt32 n = rank(); nupic::UInt32 cur = 0; - for(nupic::UInt32 i=0; i r = (*this)[i]; - if(!(r.getDim() == i)) throw std::invalid_argument("Out-of-order dims."); - if(r.empty()) {} - else { + if (!(r.getDim() == i)) + throw std::invalid_argument("Out-of-order dims."); + if (r.empty()) { + } else { bounds[cur++] = r.getUB() - r.getLB(); } } return bounds; } - bool doesInclude(const TIV &x) const - { return includes(x); } + bool doesInclude(const TIV &x) const { return includes(x); } - std::string __str__() const - { + std::string __str__() const { std::stringstream s; s << "("; nupic::UInt32 n = rank(); - for(nupic::UInt32 i=0; i r = (*this)[i]; s << "(" << r.getDim() << ", " << r.getLB() << ", " << r.getUB() << ")"; } @@ -366,10 +372,9 @@ class PyDomain : public nupic::Domain }; //-------------------------------------------------------------------------------- -class PySparseTensor -{ +class PySparseTensor { nupic::SparseTensor tensor_; - + public: PySparseTensor(const std::string &state); PySparseTensor(const TIV &bounds) : tensor_(PyTensorIndex(bounds)) {} @@ -379,7 +384,9 @@ class PySparseTensor nupic::UInt32 getRank() const { return tensor_.getRank(); } PyTensorIndex getBounds() const { return tensor_.getBounds(); } - 
nupic::UInt32 getBound(const nupic::UInt32 dim) const { return tensor_.getBound(dim); } + nupic::UInt32 getBound(const nupic::UInt32 dim) const { + return tensor_.getBound(dim); + } nupic::Real get(const TIV &i) const { return get(PyTensorIndex(i)); } nupic::Real get(const PyTensorIndex &i) const { return tensor_.get(i); } @@ -390,121 +397,117 @@ class PySparseTensor nupic::UInt32 getNNonZeros() const { return tensor_.getNNonZeros(); } nupic::UInt32 nNonZeros() const { return tensor_.getNNonZeros(); } - - PySparseTensor reshape(const TIV &dims) const - { + + PySparseTensor reshape(const TIV &dims) const { PySparseTensor t(dims); tensor_.reshape(t.tensor_); return t; } - void resize(const TIV &dims) - { tensor_.resize(PyTensorIndex(dims)); } + void resize(const TIV &dims) { tensor_.resize(PyTensorIndex(dims)); } - void resize(const PyTensorIndex &dims) - { tensor_.resize(dims); } + void resize(const PyTensorIndex &dims) { tensor_.resize(dims); } - PySparseTensor extract(nupic::UInt32 dim, const TIV &ind) const - { + PySparseTensor extract(nupic::UInt32 dim, const TIV &ind) const { std::set subset(ind.begin(), ind.end()); PySparseTensor t(tensor_.getBounds()); tensor_.extract(dim, subset, t.tensor_); return t; } - void reduce(nupic::UInt32 dim, const TIV &ind) - { + void reduce(nupic::UInt32 dim, const TIV &ind) { std::set subset(ind.begin(), ind.end()); tensor_.reduce(dim, subset); } - - PySparseTensor getSlice(const PyDomain &range) const - { + + PySparseTensor getSlice(const PyDomain &range) const { PyTensorIndex dims = range.getSliceBounds(); PySparseTensor t(dims); tensor_.getSlice(range, t.tensor_); return t; } - void setSlice(const PyDomain &range, const PySparseTensor &slice) - { + void setSlice(const PyDomain &range, const PySparseTensor &slice) { tensor_.setSlice(range, slice.tensor_); } - void setZero(const PyDomain& range) - { - tensor_.setZero(range); - } + void setZero(const PyDomain &range) { tensor_.setZero(range); } - void addSlice(nupic::UInt32 which, nupic::UInt32 src, nupic::UInt32 dst) - { + void addSlice(nupic::UInt32 which, nupic::UInt32 src, nupic::UInt32 dst) { tensor_.addSlice(which, src, dst); } - PySparseTensor factorMultiply(const TIV &dims, const PySparseTensor &B) const - { return factorMultiply(PyTensorIndex(dims), B); } + PySparseTensor factorMultiply(const TIV &dims, + const PySparseTensor &B) const { + return factorMultiply(PyTensorIndex(dims), B); + } - PySparseTensor factorMultiply(const PyTensorIndex &dims, const PySparseTensor &B) const - { + PySparseTensor factorMultiply(const PyTensorIndex &dims, + const PySparseTensor &B) const { PySparseTensor C(getBounds()); - tensor_.factor_apply_fast(dims, B.tensor_, C.tensor_, std::multiplies()); + tensor_.factor_apply_fast(dims, B.tensor_, C.tensor_, + std::multiplies()); return C; } - PySparseTensor outerProduct(const PySparseTensor& B) const - { + PySparseTensor outerProduct(const PySparseTensor &B) const { PySparseTensor C(PyTensorIndex(getBounds(), B.getBounds())); - tensor_.outer_product_nz(B.tensor_, C.tensor_, std::multiplies()); + tensor_.outer_product_nz(B.tensor_, C.tensor_, + std::multiplies()); return C; } - PySparseTensor innerProduct(const nupic::UInt32 dim1, const nupic::UInt32 dim2, const PySparseTensor& B) const - { + PySparseTensor innerProduct(const nupic::UInt32 dim1, + const nupic::UInt32 dim2, + const PySparseTensor &B) const { // Only works on rank 2 tensors right now - if((getRank() != 2) || (B.getRank() != 2)) - throw std::invalid_argument("innerProduct only works for rank 2 tensors."); 
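// Dense equivalent of the rank-2 innerProduct for intuition (a sketch, not
// the sparse implementation): with dim1 == 1 and dim2 == 0 this is the
// ordinary matrix product C[i][k] = sum_j A[i][j] * B[j][k], which is why the
// result bounds below are (getBound(1 - dim1), B.getBound(1 - dim2)).
#include <vector>

std::vector<std::vector<float>>
matmulSketch(const std::vector<std::vector<float>> &A,
             const std::vector<std::vector<float>> &B) {
  // assumes A's column count equals B's row count
  const size_t rows = A.size();
  const size_t inner = B.size();
  const size_t cols = B.empty() ? 0 : B[0].size();
  std::vector<std::vector<float>> C(rows, std::vector<float>(cols, 0.0f));
  for (size_t i = 0; i < rows; ++i)
    for (size_t j = 0; j < inner; ++j)
      for (size_t k = 0; k < cols; ++k)
        C[i][k] += A[i][j] * B[j][k];
  return C;
}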
- PySparseTensor C(PyTensorIndex(getBound(1-dim1),B.getBound(1-dim2))); - tensor_.inner_product_nz(dim1, dim2, B.tensor_, C.tensor_, std::multiplies(), std::plus()); + if ((getRank() != 2) || (B.getRank() != 2)) + throw std::invalid_argument( + "innerProduct only works for rank 2 tensors."); + PySparseTensor C(PyTensorIndex(getBound(1 - dim1), B.getBound(1 - dim2))); + tensor_.inner_product_nz(dim1, dim2, B.tensor_, C.tensor_, + std::multiplies(), + std::plus()); return C; } - PySparseTensor __add__(const PySparseTensor &B) const - { + PySparseTensor __add__(const PySparseTensor &B) const { PySparseTensor C(getBounds()); tensor_.axby(1.0, B.tensor_, 1.0, C.tensor_); return C; } - PySparseTensor __sub__(const PySparseTensor &B) const - { + PySparseTensor __sub__(const PySparseTensor &B) const { PySparseTensor C(getBounds()); tensor_.axby(1.0, B.tensor_, -1.0, C.tensor_); return C; } - PySparseTensor factorAdd(const TIV &dims, const PySparseTensor &B) const - { return factorAdd(PyTensorIndex(dims), B); } + PySparseTensor factorAdd(const TIV &dims, const PySparseTensor &B) const { + return factorAdd(PyTensorIndex(dims), B); + } - PySparseTensor factorAdd(const PyTensorIndex& dims, const PySparseTensor& B) const - { + PySparseTensor factorAdd(const PyTensorIndex &dims, + const PySparseTensor &B) const { PySparseTensor C(getBounds()); - tensor_.factor_apply_nz(dims, B.tensor_, C.tensor_, std::plus(), true); + tensor_.factor_apply_nz(dims, B.tensor_, C.tensor_, + std::plus(), true); return C; } - PySparseTensor getComplementBounds(const PyTensorIndex &dims) const - { + PySparseTensor getComplementBounds(const PyTensorIndex &dims) const { PyTensorIndex process(tensor_.getBounds()); nupic::UInt32 n = dims.size(); - for(nupic::UInt32 i=0; i(), nupic::Real(0.0f)); + tensor_.accumulate_nz(dims, B.tensor_, std::plus(), + nupic::Real(0.0f)); return B; } - + PyTensorIndex argmax() const; nupic::Real max() const; - PySparseTensor max(const TIV &dims) const - { return this->max(PyTensorIndex(dims)); } + PySparseTensor max(const TIV &dims) const { + return this->max(PyTensorIndex(dims)); + } - PySparseTensor max(const PyTensorIndex &dims) const - { + PySparseTensor max(const PyTensorIndex &dims) const { PySparseTensor B(getComplementBounds(dims)); tensor_.max(dims, B.tensor_); return B; } - PyObject* tolist() const - { + PyObject *tolist() const { const nupic::UInt32 rank = getRank(); const nupic::UInt32 nnz = getNNonZeros(); std::vector ind(nnz); nupic::NumpyVectorT val(nnz); tensor_.toList(ind.begin(), val.begin()); - PyObject* ind_list = PyTuple_New(nnz); + PyObject *ind_list = PyTuple_New(nnz); for (nupic::UInt32 i = 0; i != nnz; ++i) { - PyObject* idx = PyTuple_New(rank); + PyObject *idx = PyTuple_New(rank); for (nupic::UInt32 j = 0; j != rank; ++j) - PyTuple_SET_ITEM(idx, j, PyInt_FromLong(ind[i][j])); + PyTuple_SET_ITEM(idx, j, PyInt_FromLong(ind[i][j])); PyTuple_SET_ITEM(ind_list, i, idx); } - PyObject* toReturn = PyTuple_New(2); + PyObject *toReturn = PyTuple_New(2); PyTuple_SET_ITEM(toReturn, 0, ind_list); PyTuple_SET_ITEM(toReturn, 1, val.forPython()); return toReturn; } - bool __eq__(const PySparseTensor &B) const - { return nupic::operator==(tensor_, B.tensor_); } - bool __ne__(const PySparseTensor &B) const - { return nupic::operator==(tensor_, B.tensor_); } + bool __eq__(const PySparseTensor &B) const { + return nupic::operator==(tensor_, B.tensor_); + } + bool __ne__(const PySparseTensor &B) const { + return nupic::operator==(tensor_, B.tensor_); + } PyObject *toDense() const; @@ -574,4 
+579,3 @@ class PySparseTensor //-------------------------------------------------------------------------------- #endif // __nta_PySparseTensor_hpp__ - diff --git a/src/nupic/encoders/ScalarEncoder.cpp b/src/nupic/encoders/ScalarEncoder.cpp index ae701e19be..f4b695a1f4 100644 --- a/src/nupic/encoders/ScalarEncoder.cpp +++ b/src/nupic/encoders/ScalarEncoder.cpp @@ -24,187 +24,148 @@ * Implementations of the ScalarEncoder and PeriodicScalarEncoder */ -#include // memset #include +#include // memset #include #include -namespace nupic -{ - ScalarEncoder::ScalarEncoder( - int w, double minValue, double maxValue, int n, double radius, - double resolution, bool clipInput) - :w_(w), - minValue_(minValue), - maxValue_(maxValue), - clipInput_(clipInput) - { - if ((n != 0 && (radius != 0 || resolution != 0)) || - (radius != 0 && (n != 0 || resolution != 0)) || - (resolution != 0 && (n != 0 || radius != 0))) - { - NTA_THROW << - "Only one of n/radius/resolution can be specified for a ScalarEncoder."; - } +namespace nupic { +ScalarEncoder::ScalarEncoder(int w, double minValue, double maxValue, int n, + double radius, double resolution, bool clipInput) + : w_(w), minValue_(minValue), maxValue_(maxValue), clipInput_(clipInput) { + if ((n != 0 && (radius != 0 || resolution != 0)) || + (radius != 0 && (n != 0 || resolution != 0)) || + (resolution != 0 && (n != 0 || radius != 0))) { + NTA_THROW << "Only one of n/radius/resolution can be specified for a " + "ScalarEncoder."; + } - const double extentWidth = maxValue - minValue; - if (extentWidth <= 0) - { - NTA_THROW << "minValue must be < maxValue. minValue=" << minValue << - " maxValue=" << maxValue; - } + const double extentWidth = maxValue - minValue; + if (extentWidth <= 0) { + NTA_THROW << "minValue must be < maxValue. minValue=" << minValue + << " maxValue=" << maxValue; + } + + if (n != 0) { + n_ = n; - if (n != 0) - { - n_ = n; - - if (w_ < 1 || w_ >= n_) - { - NTA_THROW << "w must be within the range [1, n). w=" << w_ << " n=" << n_; - } - - // Distribute nBuckets points along the domain [minValue, maxValue], - // including the endpoints. The resolution is the width of each band - // between the points. - const int nBuckets = n - (w - 1); - const int nBands = nBuckets - 1; - bucketWidth_ = extentWidth / nBands; + if (w_ < 1 || w_ >= n_) { + NTA_THROW << "w must be within the range [1, n). w=" << w_ << " n=" << n_; } - else - { - bucketWidth_ = resolution || radius / w; - if (bucketWidth_ == 0) - { - NTA_THROW << "One of n/radius/resolution must be nonzero."; - } - - const int neededBands = ceil(extentWidth / bucketWidth_); - const int neededBuckets = neededBands + 1; - n_ = neededBuckets + (w - 1); + + // Distribute nBuckets points along the domain [minValue, maxValue], + // including the endpoints. The resolution is the width of each band + // between the points. 
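// Worked example of this constructor's sizing arithmetic (all numbers are
// made up). Note that the expression `resolution || radius / w` further down
// is a logical OR, so it evaluates to 0 or 1; selecting one of the two values
// would instead read `resolution != 0 ? resolution : radius / w`, which is
// what the second case below assumes.
//
//   n given:          minValue = 0, maxValue = 100, n = 100, w = 21
//                     nBuckets = 100 - (21 - 1) = 80 points
//                     nBands   = 79
//                     bucketWidth = 100 / 79 ~= 1.266
//
//   resolution given: minValue = 0, maxValue = 100, resolution = 1, w = 21
//                     neededBands   = ceil(100 / 1) = 100
//                     neededBuckets = 101
//                     n = 101 + (21 - 1) = 121 output bits
//
// An input is then mapped to its nearest point with the same rounding rule
// encodeIntoArray() uses below:
#include <cmath>

int bucketIndexSketch(double input, double minValue, double bucketWidth) {
  return (int)std::round((input - minValue) / bucketWidth);
}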
+ const int nBuckets = n - (w - 1); + const int nBands = nBuckets - 1; + bucketWidth_ = extentWidth / nBands; + } else { + bucketWidth_ = resolution || radius / w; + if (bucketWidth_ == 0) { + NTA_THROW << "One of n/radius/resolution must be nonzero."; } - } - ScalarEncoder::~ScalarEncoder() - { + const int neededBands = ceil(extentWidth / bucketWidth_); + const int neededBuckets = neededBands + 1; + n_ = neededBuckets + (w - 1); } +} - int ScalarEncoder::encodeIntoArray(Real64 input, Real32 output[]) - { - if (input < minValue_) - { - if (clipInput_) - { - input = minValue_; - } - else - { - NTA_THROW << "input (" << input << ") less than range [" << minValue_ << - ", " << maxValue_ << "]"; - } - } else if (input > maxValue_) { - if (clipInput_) - { - input = maxValue_; - } - else - { - NTA_THROW << "input (" << input << ") greater than range [" << minValue_ << - ", " << maxValue_ << "]"; - } - } +ScalarEncoder::~ScalarEncoder() {} - const int iBucket = round((input - minValue_) / bucketWidth_); +int ScalarEncoder::encodeIntoArray(Real64 input, Real32 output[]) { + if (input < minValue_) { + if (clipInput_) { + input = minValue_; + } else { + NTA_THROW << "input (" << input << ") less than range [" << minValue_ + << ", " << maxValue_ << "]"; + } + } else if (input > maxValue_) { + if (clipInput_) { + input = maxValue_; + } else { + NTA_THROW << "input (" << input << ") greater than range [" << minValue_ + << ", " << maxValue_ << "]"; + } + } - const int firstBit = iBucket; + const int iBucket = round((input - minValue_) / bucketWidth_); - memset(output, 0, n_*sizeof(output[0])); - for (int i = 0; i < w_; i++) - { - output[firstBit + i] = 1; - } + const int firstBit = iBucket; - return iBucket; + memset(output, 0, n_ * sizeof(output[0])); + for (int i = 0; i < w_; i++) { + output[firstBit + i] = 1; } - PeriodicScalarEncoder::PeriodicScalarEncoder( - int w, double minValue, double maxValue, int n, double radius, double resolution) - :w_(w), - minValue_(minValue), - maxValue_(maxValue) - { - if ((n != 0 && (radius != 0 || resolution != 0)) || - (radius != 0 && (n != 0 || resolution != 0)) || - (resolution != 0 && (n != 0 || radius != 0))) - { - NTA_THROW << - "Only one of n/radius/resolution can be specified for a ScalarEncoder."; - } - - const double extentWidth = maxValue - minValue; - if (extentWidth <= 0) - { - NTA_THROW << "minValue must be < maxValue. minValue=" << minValue << - " maxValue=" << maxValue; - } + return iBucket; +} + +PeriodicScalarEncoder::PeriodicScalarEncoder(int w, double minValue, + double maxValue, int n, + double radius, double resolution) + : w_(w), minValue_(minValue), maxValue_(maxValue) { + if ((n != 0 && (radius != 0 || resolution != 0)) || + (radius != 0 && (n != 0 || resolution != 0)) || + (resolution != 0 && (n != 0 || radius != 0))) { + NTA_THROW << "Only one of n/radius/resolution can be specified for a " + "ScalarEncoder."; + } - if (n != 0) - { - n_ = n; + const double extentWidth = maxValue - minValue; + if (extentWidth <= 0) { + NTA_THROW << "minValue must be < maxValue. minValue=" << minValue + << " maxValue=" << maxValue; + } - if (w_ < 1 || w_ >= n_) - { - NTA_THROW << "w must be within the range [1, n). w=" << w_ << " n=" << n_; - } + if (n != 0) { + n_ = n; - // Distribute nBuckets equal-width bands within the domain [minValue, maxValue]. - // The resolution is the width of each band. - const int nBuckets = n; - bucketWidth_ = extentWidth / nBuckets; + if (w_ < 1 || w_ >= n_) { + NTA_THROW << "w must be within the range [1, n). 
w=" << w_ << " n=" << n_; } - else - { - bucketWidth_ = resolution || radius / w; - if (bucketWidth_ == 0) - { - NTA_THROW << "One of n/radius/resolution must be nonzero."; - } - - const int neededBuckets = ceil((maxValue - minValue) / bucketWidth_); - n_ = (neededBuckets > w_) ? neededBuckets : w_ + 1; + + // Distribute nBuckets equal-width bands within the domain [minValue, + // maxValue]. The resolution is the width of each band. + const int nBuckets = n; + bucketWidth_ = extentWidth / nBuckets; + } else { + bucketWidth_ = resolution || radius / w; + if (bucketWidth_ == 0) { + NTA_THROW << "One of n/radius/resolution must be nonzero."; } - } - PeriodicScalarEncoder::~PeriodicScalarEncoder() - { + const int neededBuckets = ceil((maxValue - minValue) / bucketWidth_); + n_ = (neededBuckets > w_) ? neededBuckets : w_ + 1; } +} - int PeriodicScalarEncoder::encodeIntoArray(Real64 input, Real32 output[]) - { - if (input < minValue_ || input >= maxValue_) - { - NTA_THROW << "input " << input << " not within range [" << minValue_ << - ", " << maxValue_ << ")"; - } +PeriodicScalarEncoder::~PeriodicScalarEncoder() {} - const int iBucket = (int)((input - minValue_) / bucketWidth_); +int PeriodicScalarEncoder::encodeIntoArray(Real64 input, Real32 output[]) { + if (input < minValue_ || input >= maxValue_) { + NTA_THROW << "input " << input << " not within range [" << minValue_ << ", " + << maxValue_ << ")"; + } - const int middleBit = iBucket; - const double reach = (w_ - 1) / 2.0; - const int left = floor(reach); - const int right = ceil(reach); + const int iBucket = (int)((input - minValue_) / bucketWidth_); - memset(output, 0, n_*sizeof(output[0])); - output[middleBit] = 1; - for (int i = 1; i <= left; i++) - { - const int index = middleBit - i; - output[(index < 0) ? index + n_ : index] = 1; - } - for (int i = 1; i <= right; i++) - { - output[(middleBit + i) % n_] = 1; - } + const int middleBit = iBucket; + const double reach = (w_ - 1) / 2.0; + const int left = floor(reach); + const int right = ceil(reach); - return iBucket; + memset(output, 0, n_ * sizeof(output[0])); + output[middleBit] = 1; + for (int i = 1; i <= left; i++) { + const int index = middleBit - i; + output[(index < 0) ? index + n_ : index] = 1; + } + for (int i = 1; i <= right; i++) { + output[(middleBit + i) % n_] = 1; } + + return iBucket; +} } // end namespace nupic diff --git a/src/nupic/encoders/ScalarEncoder.hpp b/src/nupic/encoders/ScalarEncoder.hpp index bea623a28a..51071a84ef 100644 --- a/src/nupic/encoders/ScalarEncoder.hpp +++ b/src/nupic/encoders/ScalarEncoder.hpp @@ -29,151 +29,148 @@ #include -namespace nupic -{ +namespace nupic { +/** + * @b Description + * Base class for ScalarEncoders + */ +class ScalarEncoderBase { +public: + virtual ~ScalarEncoderBase() {} + /** - * @b Description - * Base class for ScalarEncoders + * Encodes input, puts the encoded value into output, and returns the a + * bucket number for the encoding. + * + * The bucket number is essentially the input encoded into an integer rather + * than an array. A bucket number is easier to "decode" or to use inside a + * classifier. + * + * @param input The value to encode + * @param output Should have length of at least getOutputWidth() */ - class ScalarEncoderBase - { - public: - virtual ~ScalarEncoderBase() - {} - - /** - * Encodes input, puts the encoded value into output, and returns the a - * bucket number for the encoding. - * - * The bucket number is essentially the input encoded into an integer rather - * than an array. 
A bucket number is easier to "decode" or to use inside a - * classifier. - * - * @param input The value to encode - * @param output Should have length of at least getOutputWidth() - */ - virtual int encodeIntoArray(Real64 input, Real32 output[]) = 0; + virtual int encodeIntoArray(Real64 input, Real32 output[]) = 0; - /** - * Returns the output width, in bits. - */ - virtual int getOutputWidth() const = 0; - }; + /** + * Returns the output width, in bits. + */ + virtual int getOutputWidth() const = 0; +}; - /** Encodes a floating point number as a contiguous block of 1s. +/** Encodes a floating point number as a contiguous block of 1s. + * + * @b Description + * A ScalarEncoder encodes a numeric (floating point) value into an array + * of bits. The output is 0's except for a contiguous block of 1's. The + * location of this contiguous block varies continuously with the input value. + * + * Conceptually, the set of possible outputs is a set of "buckets". If there + * are m buckets, the ScalarEncoder distributes m points along the domain + * [minValue, maxValue], including the endpoints. To figure out the bucket + * index of an input, it rounds the input to the nearest of these points. + * + * This approach is different from the PeriodicScalarEncoder because two + * buckets, the first and last, are half as wide as the rest, since fewer + * numbers in the input domain will round to these endpoints. This behavior + * makes sense because, for example, with the input space [1, 10] and 10 + * buckets, 1.49 is in the first bucket and 1.51 is in the second. + */ +class ScalarEncoder : public ScalarEncoderBase { +public: + /** + * Constructs a ScalarEncoder * - * @b Description - * A ScalarEncoder encodes a numeric (floating point) value into an array - * of bits. The output is 0's except for a contiguous block of 1's. The - * location of this contiguous block varies continuously with the input value. + * @param w The number of bits that are set to encode a single value -- the + * "width" of the output signal + * @param minValue The minimum value of the input signal, inclusive. + * @param maxValue The maximum value of the input signal, inclusive. + * @param clipInput Whether to allow input values outside the [minValue, + * maxValue] range. If true, the input will be clipped to minValue or + * maxValue. * - * Conceptually, the set of possible outputs is a set of "buckets". If there - * are m buckets, the ScalarEncoder distributes m points along the domain - * [minValue, maxValue], including the endpoints. To figure out the bucket - * index of an input, it rounds the input to the nearest of these points. + * There are three mutually exclusive parameters that determine the overall + * size of of the output. Only one of these should be nonzero: * - * This approach is different from the PeriodicScalarEncoder because two - * buckets, the first and last, are half as wide as the rest, since fewer - * numbers in the input domain will round to these endpoints. This behavior - * makes sense because, for example, with the input space [1, 10] and 10 - * buckets, 1.49 is in the first bucket and 1.51 is in the second. + * @param n The number of bits in the output. Must be greater than or equal to + * w. + * @param radius Two inputs separated by more than the radius have + * non-overlapping representations. Two inputs separated by less than the + * radius will in general overlap in at least some of their bits. You can + * think of this as the radius of the input. 
+ * @param resolution Two inputs separated by greater than, or equal to the + * resolution are guaranteed to have different representations. */ - class ScalarEncoder : public ScalarEncoderBase - { - public: - /** - * Constructs a ScalarEncoder - * - * @param w The number of bits that are set to encode a single value -- the - * "width" of the output signal - * @param minValue The minimum value of the input signal, inclusive. - * @param maxValue The maximum value of the input signal, inclusive. - * @param clipInput Whether to allow input values outside the [minValue, maxValue] - * range. If true, the input will be clipped to minValue or maxValue. - * - * There are three mutually exclusive parameters that determine the overall - * size of of the output. Only one of these should be nonzero: - * - * @param n The number of bits in the output. Must be greater than or equal to w. - * @param radius Two inputs separated by more than the radius have - * non-overlapping representations. Two inputs separated by less than the - * radius will in general overlap in at least some of their bits. You can - * think of this as the radius of the input. - * @param resolution Two inputs separated by greater than, or equal to the - * resolution are guaranteed to have different representations. - */ - ScalarEncoder(int w, double minValue, double maxValue, int n, double radius, - double resolution, bool clipInput); - ~ScalarEncoder() override; + ScalarEncoder(int w, double minValue, double maxValue, int n, double radius, + double resolution, bool clipInput); + ~ScalarEncoder() override; - virtual int encodeIntoArray(Real64 input, Real32 output[]) override; - virtual int getOutputWidth() const override { return n_; } + virtual int encodeIntoArray(Real64 input, Real32 output[]) override; + virtual int getOutputWidth() const override { return n_; } - private: - int w_; - int n_; - double minValue_; - double maxValue_; - double bucketWidth_; - bool clipInput_; - }; // end class ScalarEncoder +private: + int w_; + int n_; + double minValue_; + double maxValue_; + double bucketWidth_; + bool clipInput_; +}; // end class ScalarEncoder - /** Encodes a floating point number as a block of 1s that might wrap around. +/** Encodes a floating point number as a block of 1s that might wrap around. + * + * @b Description + * A PeriodicScalarEncoder encodes a numeric (floating point) value into an + * array of bits. The output is 0's except for a contiguous block of 1's that + * may wrap around the edge. The location of this contiguous block varies + * continuously with the input value. + * + * Conceptually, the set of possible outputs is a set of "buckets". If there + * are m buckets, the PeriodicScalarEncoder plots m equal-width bands along + * the domain [minValue, maxValue]. The bucket index of an input is simply its + * band index. + * + * Because of the equal-width buckets, the rounding differs from the + * ScalarEncoder. In cases where the ScalarEncoder would put 1.49 in the first + * bucket and 1.51 in the second, the PeriodicScalarEncoder will put 1.99 in + * the first bucket and 2.0 in the second. + */ +class PeriodicScalarEncoder : public ScalarEncoderBase { +public: + /** + * Constructs a PeriodicScalarEncoder * - * @b Description - * A PeriodicScalarEncoder encodes a numeric (floating point) value into an - * array of bits. The output is 0's except for a contiguous block of 1's that - * may wrap around the edge. The location of this contiguous block varies - * continuously with the input value. 
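// Illustrative sketch (made-up parameter values): encoding one value with the
// ScalarEncoder declared above and checking the contiguous block of 1s. That
// the first active bit equals the bucket index follows from
// ScalarEncoder::encodeIntoArray in ScalarEncoder.cpp above.
#include <stdexcept>
#include <vector>
#include <nupic/encoders/ScalarEncoder.hpp>

void inspectEncoding() {
  // w = 21 active bits, domain [0, 100], n = 100 output bits, clipInput on.
  nupic::ScalarEncoder enc(21, 0.0, 100.0, 100, 0.0, 0.0, true);

  std::vector<nupic::Real32> bits(enc.getOutputWidth(), 0);
  const int bucket = enc.encodeIntoArray(37.5, bits.data());

  // Exactly w consecutive bits, [bucket, bucket + 21), should be set.
  for (int i = 0; i < enc.getOutputWidth(); i++)
    if (bits[i] != 0 && (i < bucket || i >= bucket + 21))
      throw std::logic_error("unexpected active bit");
}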
+ * @param w The number of bits that are set to encode a single value -- the + * "width" of the output signal + * @param minValue The minimum value of the input signal, inclusive. + * @param maxValue The maximum value of the input signal, exclusive. All + * inputs will be strictly less than this value. * - * Conceptually, the set of possible outputs is a set of "buckets". If there - * are m buckets, the PeriodicScalarEncoder plots m equal-width bands along - * the domain [minValue, maxValue]. The bucket index of an input is simply its - * band index. + * There are three mutually exclusive parameters that determine the overall + * size of the output. Only one of these should be nonzero: * - * Because of the equal-width buckets, the rounding differs from the - * ScalarEncoder. In cases where the ScalarEncoder would put 1.49 in the first - * bucket and 1.51 in the second, the PeriodicScalarEncoder will put 1.99 in - * the first bucket and 2.0 in the second. + * @param n The number of bits in the output. Must be greater than or equal + * to w. + * @param radius Two inputs separated by more than the radius have + * non-overlapping representations. Two inputs separated by less than the + * radius will in general overlap in at least some of their bits. You can + * think of this as the radius of the input. + * @param resolution Two inputs separated by greater than, or equal to the + * resolution are guaranteed to have different representations. */ - class PeriodicScalarEncoder : public ScalarEncoderBase - { - public: - /** - * Constructs a PeriodicScalarEncoder - * - * @param w The number of bits that are set to encode a single value -- the - * "width" of the output signal - * @param minValue The minimum value of the input signal, inclusive. - * @param maxValue The maximum value of the input signal, exclusive. All - * inputs will be strictly less than this value. - * - * There are three mutually exclusive parameters that determine the overall - * size of the output. Only one of these should be nonzero: - * - * @param n The number of bits in the output. Must be greater than or equal - * to w. - * @param radius Two inputs separated by more than the radius have - * non-overlapping representations. Two inputs separated by less than the - * radius will in general overlap in at least some of their bits. You can - * think of this as the radius of the input. - * @param resolution Two inputs separated by greater than, or equal to the - * resolution are guaranteed to have different representations. 
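// Illustrative sketch (made-up parameter values): the periodic encoder's block
// of 1s can wrap around the end of the output array, unlike the plain
// ScalarEncoder above.
#include <vector>
#include <nupic/encoders/ScalarEncoder.hpp>

void encodeNearTheEdge() {
  // w = 5, domain [0, 10), n = 10 buckets, each of width 1.
  nupic::PeriodicScalarEncoder enc(5, 0.0, 10.0, 10, 0.0, 0.0);

  std::vector<nupic::Real32> bits(enc.getOutputWidth(), 0);
  const int bucket = enc.encodeIntoArray(9.5, bits.data());
  // bucket == 9; the 5-bit block is centered on bit 9 and wraps, so bits
  // 7, 8, 9, 0 and 1 end up set.
  (void)bucket;
}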
- */ - PeriodicScalarEncoder(int w, double minValue, double maxValue, int n, - double radius, double resolution); - virtual ~PeriodicScalarEncoder() override; + PeriodicScalarEncoder(int w, double minValue, double maxValue, int n, + double radius, double resolution); + virtual ~PeriodicScalarEncoder() override; - virtual int encodeIntoArray(Real64 input, Real32 output[]) override; - virtual int getOutputWidth() const override { return n_; } + virtual int encodeIntoArray(Real64 input, Real32 output[]) override; + virtual int getOutputWidth() const override { return n_; } - private: - int w_; - int n_; - double minValue_; - double maxValue_; - double bucketWidth_; - }; // end class PeriodicScalarEncoder +private: + int w_; + int n_; + double minValue_; + double maxValue_; + double bucketWidth_; +}; // end class PeriodicScalarEncoder } // end namespace nupic #endif // NTA_ENCODERS_SCALAR diff --git a/src/nupic/encoders/ScalarSensor.cpp b/src/nupic/encoders/ScalarSensor.cpp index 2364ac5cc4..f3694dbd23 100644 --- a/src/nupic/encoders/ScalarSensor.cpp +++ b/src/nupic/encoders/ScalarSensor.cpp @@ -32,279 +32,207 @@ #include #include -#include +#include +#include #include -#include -#include // IWrite/ReadBuffer +#include #include -#include #include -#include -#include - +#include // IWrite/ReadBuffer +#include +#include using capnp::AnyPointer; -namespace nupic -{ - ScalarSensor::ScalarSensor(const ValueMap& params, Region *region) - : RegionImpl(region) - { - const UInt32 n = params.getScalarT("n"); - const UInt32 w = params.getScalarT("w"); - const Real64 resolution = params.getScalarT("resolution"); - const Real64 radius = params.getScalarT("radius"); - const Real64 minValue = params.getScalarT("minValue"); - const Real64 maxValue = params.getScalarT("maxValue"); - const bool periodic = params.getScalarT("periodic"); - const bool clipInput = params.getScalarT("clipInput"); - if (periodic) - { - encoder_ = new PeriodicScalarEncoder(w, minValue, maxValue, n, radius, - resolution); - } - else - { - encoder_ = new ScalarEncoder(w, minValue, maxValue, n, radius, resolution, - clipInput); - } - - sensedValue_ = params.getScalarT("sensedValue"); - } - - ScalarSensor::ScalarSensor(BundleIO& bundle, Region* region) : - RegionImpl(region) - { - deserialize(bundle); - } - - - ScalarSensor::ScalarSensor(AnyPointer::Reader& proto, Region* region) : - RegionImpl(region) - { - read(proto); +namespace nupic { +ScalarSensor::ScalarSensor(const ValueMap ¶ms, Region *region) + : RegionImpl(region) { + const UInt32 n = params.getScalarT("n"); + const UInt32 w = params.getScalarT("w"); + const Real64 resolution = params.getScalarT("resolution"); + const Real64 radius = params.getScalarT("radius"); + const Real64 minValue = params.getScalarT("minValue"); + const Real64 maxValue = params.getScalarT("maxValue"); + const bool periodic = params.getScalarT("periodic"); + const bool clipInput = params.getScalarT("clipInput"); + if (periodic) { + encoder_ = + new PeriodicScalarEncoder(w, minValue, maxValue, n, radius, resolution); + } else { + encoder_ = new ScalarEncoder(w, minValue, maxValue, n, radius, resolution, + clipInput); } + sensedValue_ = params.getScalarT("sensedValue"); +} - ScalarSensor::~ScalarSensor() - { - delete encoder_; - } +ScalarSensor::ScalarSensor(BundleIO &bundle, Region *region) + : RegionImpl(region) { + deserialize(bundle); +} - void ScalarSensor::compute() - { - Real32* array = (Real32*)encodedOutput_->getData().getBuffer(); - const Int32 iBucket = encoder_->encodeIntoArray(sensedValue_, 
array); - ((Int32*)bucketOutput_->getData().getBuffer())[0] = iBucket; - } +ScalarSensor::ScalarSensor(AnyPointer::Reader &proto, Region *region) + : RegionImpl(region) { + read(proto); +} - /* static */ Spec* - ScalarSensor::createSpec() - { - auto ns = new Spec; +ScalarSensor::~ScalarSensor() { delete encoder_; } - ns->singleNodeOnly = true; +void ScalarSensor::compute() { + Real32 *array = (Real32 *)encodedOutput_->getData().getBuffer(); + const Int32 iBucket = encoder_->encodeIntoArray(sensedValue_, array); + ((Int32 *)bucketOutput_->getData().getBuffer())[0] = iBucket; +} - /* ----- parameters ----- */ - ns->parameters.add( - "sensedValue", - ParameterSpec( - "Scalar input", - NTA_BasicType_Real64, - 1, // elementCount - "", // constraints - "-1", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "n", - ParameterSpec( - "The length of the encoding", - NTA_BasicType_UInt32, - 1, // elementCount - "", // constraints - "0", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "w", - ParameterSpec( - "The number of active bits in the encoding", - NTA_BasicType_UInt32, - 1, // elementCount - "", // constraints - "0", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "resolution", - ParameterSpec( - "The resolution for the encoder", - NTA_BasicType_Real64, - 1, // elementCount - "", // constraints - "0", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "radius", - ParameterSpec( - "The radius for the encoder", - NTA_BasicType_Real64, - 1, // elementCount - "", // constraints - "0", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "minValue", - ParameterSpec( - "The minimum value for the input", - NTA_BasicType_Real64, - 1, // elementCount - "", // constraints - "-1", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "maxValue", - ParameterSpec( - "The maximum value for the input", - NTA_BasicType_Real64, - 1, // elementCount - "", // constraints - "-1", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "periodic", - ParameterSpec( - "Whether the encoder is periodic", - NTA_BasicType_Bool, - 1, // elementCount - "", // constraints - "false", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( +/* static */ Spec *ScalarSensor::createSpec() { + auto ns = new Spec; + + ns->singleNodeOnly = true; + + /* ----- parameters ----- */ + ns->parameters.add("sensedValue", + ParameterSpec("Scalar input", NTA_BasicType_Real64, + 1, // elementCount + "", // constraints + "-1", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("n", ParameterSpec("The length of the encoding", + NTA_BasicType_UInt32, + 1, // elementCount + "", // constraints + "0", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("w", + ParameterSpec("The number of active bits in the encoding", + NTA_BasicType_UInt32, + 1, // elementCount + "", // constraints + "0", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("resolution", + ParameterSpec("The resolution for the encoder", + NTA_BasicType_Real64, + 1, // elementCount + "", // constraints + "0", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("radius", ParameterSpec("The radius for the encoder", + NTA_BasicType_Real64, + 1, // elementCount + "", // constraints + "0", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("minValue", + ParameterSpec("The 
minimum value for the input", + NTA_BasicType_Real64, + 1, // elementCount + "", // constraints + "-1", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("maxValue", + ParameterSpec("The maximum value for the input", + NTA_BasicType_Real64, + 1, // elementCount + "", // constraints + "-1", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("periodic", + ParameterSpec("Whether the encoder is periodic", + NTA_BasicType_Bool, + 1, // elementCount + "", // constraints + "false", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add( "clipInput", ParameterSpec( - "Whether to clip inputs if they're outside [minValue, maxValue]", - NTA_BasicType_Bool, - 1, // elementCount - "", // constraints - "false", // defaultValue - ParameterSpec::ReadWriteAccess)); - - /* ----- outputs ----- */ - - ns->outputs.add( - "encoded", - OutputSpec( - "Encoded value", - NTA_BasicType_Real32, - 0, // elementCount - true, // isRegionLevel - true // isDefaultOutput - )); - - ns->outputs.add( - "bucket", - OutputSpec( - "Bucket number for this sensedValue", - NTA_BasicType_Int32, - 0, // elementCount - true, // isRegionLevel - false // isDefaultOutput - )); - - return ns; - } - - void - ScalarSensor::getParameterFromBuffer(const std::string& name, - Int64 index, - IWriteBuffer& value) - { - if (name == "sensedValue") - { - value.write(sensedValue_); - } - else if (name == "n") - { - // Cast to UInt32 to avoid call resolution ambiguity on the write() method - value.write((UInt32)encoder_->getOutputWidth()); - } - else - { - NTA_THROW << "ScalarSensor::getParameter -- Unknown parameter " << name; - } - } - - void - ScalarSensor::setParameterFromBuffer(const std::string& name, - Int64 index, - IReadBuffer& value) - { - if (name == "sensedValue") - { - value.read(sensedValue_); - } - else - { - NTA_THROW << "ScalarSensor::setParameter -- Unknown parameter " << name; - } - } + "Whether to clip inputs if they're outside [minValue, maxValue]", + NTA_BasicType_Bool, + 1, // elementCount + "", // constraints + "false", // defaultValue + ParameterSpec::ReadWriteAccess)); + + /* ----- outputs ----- */ + + ns->outputs.add("encoded", OutputSpec("Encoded value", NTA_BasicType_Real32, + 0, // elementCount + true, // isRegionLevel + true // isDefaultOutput + )); + + ns->outputs.add("bucket", OutputSpec("Bucket number for this sensedValue", + NTA_BasicType_Int32, + 0, // elementCount + true, // isRegionLevel + false // isDefaultOutput + )); + + return ns; +} - void - ScalarSensor::initialize() - { - encodedOutput_ = getOutput("encoded"); - bucketOutput_ = getOutput("bucket"); +void ScalarSensor::getParameterFromBuffer(const std::string &name, Int64 index, + IWriteBuffer &value) { + if (name == "sensedValue") { + value.write(sensedValue_); + } else if (name == "n") { + // Cast to UInt32 to avoid call resolution ambiguity on the write() method + value.write((UInt32)encoder_->getOutputWidth()); + } else { + NTA_THROW << "ScalarSensor::getParameter -- Unknown parameter " << name; } +} - size_t - ScalarSensor::getNodeOutputElementCount(const std::string& outputName) - { - if (outputName == "encoded") - { - return encoder_->getOutputWidth(); - } - else if (outputName == "bucket") - { - return 1; - } - else - { - NTA_THROW << "ScalarSensor::getOutputSize -- unknown output " << outputName; - } +void ScalarSensor::setParameterFromBuffer(const std::string &name, Int64 index, + IReadBuffer &value) { + if (name == "sensedValue") { + value.read(sensedValue_); + } else { + 
NTA_THROW << "ScalarSensor::setParameter -- Unknown parameter " << name; } +} - std::string ScalarSensor::executeCommand(const std::vector& args, Int64 index) - { - NTA_THROW << "ScalarSensor::executeCommand -- commands not supported"; - } +void ScalarSensor::initialize() { + encodedOutput_ = getOutput("encoded"); + bucketOutput_ = getOutput("bucket"); +} - void ScalarSensor::serialize(BundleIO& bundle) - { - NTA_THROW << "ScalarSensor::serialize -- Not implemented"; +size_t ScalarSensor::getNodeOutputElementCount(const std::string &outputName) { + if (outputName == "encoded") { + return encoder_->getOutputWidth(); + } else if (outputName == "bucket") { + return 1; + } else { + NTA_THROW << "ScalarSensor::getOutputSize -- unknown output " << outputName; } +} +std::string ScalarSensor::executeCommand(const std::vector &args, + Int64 index) { + NTA_THROW << "ScalarSensor::executeCommand -- commands not supported"; +} - void ScalarSensor::deserialize(BundleIO& bundle) - { - NTA_THROW << "ScalarSensor::deserialize -- Not implemented"; - } - +void ScalarSensor::serialize(BundleIO &bundle) { + NTA_THROW << "ScalarSensor::serialize -- Not implemented"; +} - void ScalarSensor::write(AnyPointer::Builder& anyProto) const - { - NTA_THROW << "ScalarSensor::write -- Not implemented"; - } +void ScalarSensor::deserialize(BundleIO &bundle) { + NTA_THROW << "ScalarSensor::deserialize -- Not implemented"; +} +void ScalarSensor::write(AnyPointer::Builder &anyProto) const { + NTA_THROW << "ScalarSensor::write -- Not implemented"; +} - void ScalarSensor::read(AnyPointer::Reader& anyProto) - { - NTA_THROW << "ScalarSensor::read -- Not implemented"; - } +void ScalarSensor::read(AnyPointer::Reader &anyProto) { + NTA_THROW << "ScalarSensor::read -- Not implemented"; } +} // namespace nupic diff --git a/src/nupic/encoders/ScalarSensor.hpp b/src/nupic/encoders/ScalarSensor.hpp index ee9ed1dbb6..1547b0a8cb 100644 --- a/src/nupic/encoders/ScalarSensor.hpp +++ b/src/nupic/encoders/ScalarSensor.hpp @@ -39,55 +39,53 @@ #include #include -namespace nupic -{ - /** - * A network region that encapsulates the ScalarEncoder. - * - * @b Description - * A ScalarSensor encapsulates ScalarEncoders, connecting them to the Network - * API. As a network runs, the client will specify new encoder inputs by - * setting the "sensedValue" parameter. On each compute, the ScalarSensor will - * encode its "sensedValue" to output. - */ - class ScalarSensor : public RegionImpl - { - public: - ScalarSensor(const ValueMap& params, Region *region); - ScalarSensor(BundleIO& bundle, Region* region); - ScalarSensor(capnp::AnyPointer::Reader& proto, Region* region); - ScalarSensor(); - virtual ~ScalarSensor() override; +namespace nupic { +/** + * A network region that encapsulates the ScalarEncoder. + * + * @b Description + * A ScalarSensor encapsulates ScalarEncoders, connecting them to the Network + * API. As a network runs, the client will specify new encoder inputs by + * setting the "sensedValue" parameter. On each compute, the ScalarSensor will + * encode its "sensedValue" to output. 
+ */ +class ScalarSensor : public RegionImpl { +public: + ScalarSensor(const ValueMap ¶ms, Region *region); + ScalarSensor(BundleIO &bundle, Region *region); + ScalarSensor(capnp::AnyPointer::Reader &proto, Region *region); + ScalarSensor(); + virtual ~ScalarSensor() override; + + static Spec *createSpec(); - static Spec* createSpec(); + virtual void getParameterFromBuffer(const std::string &name, Int64 index, + IWriteBuffer &value) override; + virtual void setParameterFromBuffer(const std::string &name, Int64 index, + IReadBuffer &value) override; + virtual void initialize() override; - virtual void getParameterFromBuffer(const std::string& name, - Int64 index, - IWriteBuffer& value) override; - virtual void setParameterFromBuffer(const std::string& name, - Int64 index, - IReadBuffer& value) override; - virtual void initialize() override; + virtual void serialize(BundleIO &bundle) override; + virtual void deserialize(BundleIO &bundle) override; - virtual void serialize(BundleIO& bundle) override; - virtual void deserialize(BundleIO& bundle) override; + using Serializable::write; + virtual void write(capnp::AnyPointer::Builder &anyProto) const override; + using Serializable::read; + virtual void read(capnp::AnyPointer::Reader &anyProto) override; - using Serializable::write; - virtual void write(capnp::AnyPointer::Builder& anyProto) const override; - using Serializable::read; - virtual void read(capnp::AnyPointer::Reader& anyProto) override; + void compute() override; + virtual std::string executeCommand(const std::vector &args, + Int64 index) override; - void compute() override; - virtual std::string executeCommand(const std::vector& args, - Int64 index) override; + virtual size_t + getNodeOutputElementCount(const std::string &outputName) override; - virtual size_t getNodeOutputElementCount(const std::string& outputName) override; - private: - Real64 sensedValue_; - ScalarEncoderBase* encoder_; - const Output* encodedOutput_; - const Output* bucketOutput_; - }; -} +private: + Real64 sensedValue_; + ScalarEncoderBase *encoder_; + const Output *encodedOutput_; + const Output *bucketOutput_; +}; +} // namespace nupic #endif // NTA_SCALAR_SENSOR_HPP diff --git a/src/nupic/engine/Collections.cpp b/src/nupic/engine/Collections.cpp index 4dded93f09..f50e4c652c 100644 --- a/src/nupic/engine/Collections.cpp +++ b/src/nupic/engine/Collections.cpp @@ -20,27 +20,25 @@ * --------------------------------------------------------------------- */ - #include /* - * We need to import the code from Collection.cpp + * We need to import the code from Collection.cpp * in order to instantiate all the methods in the classes - * instantiated below. + * instantiated below. 
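For orientation, a rough usage sketch of the ScalarSensor region whose Spec is assembled above, driven through the C++ Network API. The node type name "ScalarSensor", the YAML parameter string, and the chosen values are assumptions for illustration; what the region's own documentation prescribes is only setting "sensedValue" and running the network.

#include <nupic/engine/Network.hpp>
#include <nupic/engine/Region.hpp>

using namespace nupic;

int main() {
  Network net;

  // Parameter keys mirror the Spec above: n/w pick the encoding width and
  // active-bit count, minValue/maxValue bound the input range. The exact
  // YAML string and values here are illustrative assumptions.
  Region *sensor = net.addRegion(
      "sensor", "ScalarSensor",
      "{n: 200, w: 21, minValue: 0.0, maxValue: 100.0}");

  net.initialize();

  // The client drives the region by setting "sensedValue" before a run;
  // compute() then writes the dense encoding to the "encoded" output and
  // the winning bucket index to the "bucket" output.
  sensor->setParameterReal64("sensedValue", 42.0);
  net.run(1);

  return 0;
}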
*/ -#include -#include -#include #include +#include +#include +#include using namespace nupic; - // Explicit instantiations of the collection classes used by Spec template class nupic::Collection; template class nupic::Collection; template class nupic::Collection; template class nupic::Collection; -template class nupic::Collection; -template class nupic::Collection; +template class nupic::Collection; +template class nupic::Collection; template class nupic::Collection; diff --git a/src/nupic/engine/Input.cpp b/src/nupic/engine/Input.cpp index c25063b024..82b17dce68 100644 --- a/src/nupic/engine/Input.cpp +++ b/src/nupic/engine/Input.cpp @@ -26,51 +26,41 @@ */ #include // memset -#include -#include #include -#include #include +#include #include +#include +#include #include -namespace nupic -{ +namespace nupic { -Input::Input(Region& region, NTA_BasicType dataType, bool isRegionLevel) : - region_(region), isRegionLevel_(isRegionLevel), - initialized_(false), data_(dataType), name_("Unnamed") -{ -} +Input::Input(Region ®ion, NTA_BasicType dataType, bool isRegionLevel) + : region_(region), isRegionLevel_(isRegionLevel), initialized_(false), + data_(dataType), name_("Unnamed") {} -Input::~Input() -{ +Input::~Input() { uninitialize(); - std::vector linkscopy = links_; - for (auto & elem : linkscopy) - { + std::vector linkscopy = links_; + for (auto &elem : linkscopy) { removeLink(elem); } } - -void -Input::addLink(Link* link, Output* srcOutput) -{ +void Input::addLink(Link *link, Output *srcOutput) { if (initialized_) - NTA_THROW << "Attempt to add link to input " << name_ - << " on region " << region_.getName() - << " when input is already initialized"; + NTA_THROW << "Attempt to add link to input " << name_ << " on region " + << region_.getName() << " when input is already initialized"; // Make sure we don't already have a link to the same output - for (std::vector::const_iterator link = links_.begin(); - link != links_.end(); link++) - { - if (srcOutput == &((*link)->getSrc())) - { - NTA_THROW << "addLink -- link from region " << srcOutput->getRegion().getName() - << " output " << srcOutput->getName() << " to region " - << region_.getName() << " input " << getName() << " already exists"; + for (std::vector::const_iterator link = links_.begin(); + link != links_.end(); link++) { + if (srcOutput == &((*link)->getSrc())) { + NTA_THROW << "addLink -- link from region " + << srcOutput->getRegion().getName() << " output " + << srcOutput->getName() << " to region " << region_.getName() + << " input " << getName() << " already exists"; } } @@ -81,16 +71,12 @@ Input::addLink(Link* link, Output* srcOutput) // is calculated at initialization time } - -void -Input::removeLink(Link*& link) -{ +void Input::removeLink(Link *&link) { // removeLink should only be called internally -- if it // does not exist, it is a logic error auto linkiter = links_.begin(); - for(; linkiter!= links_.end(); linkiter++) - { + for (; linkiter != links_.end(); linkiter++) { if (*linkiter == link) break; } @@ -111,16 +97,13 @@ Input::removeLink(Link*& link) link = nullptr; } -Link* Input::findLink(const std::string& srcRegionName, - const std::string& srcOutputName) -{ - std::vector::const_iterator linkiter = links_.begin(); - for (; linkiter != links_.end(); linkiter++) - { - Output& output = (*linkiter)->getSrc(); +Link *Input::findLink(const std::string &srcRegionName, + const std::string &srcOutputName) { + std::vector::const_iterator linkiter = links_.begin(); + for (; linkiter != links_.end(); linkiter++) { + Output &output 
= (*linkiter)->getSrc(); if (output.getName() == srcOutputName && - output.getRegion().getName() == srcRegionName) - { + output.getRegion().getName() == srcRegionName) { return *linkiter; } } @@ -128,47 +111,27 @@ Link* Input::findLink(const std::string& srcRegionName, return nullptr; } -void -Input::prepare() -{ +void Input::prepare() { // Each link copies data into its section of the overall input // TODO: initialization check? - for (auto & elem : links_) - { + for (auto &elem : links_) { (elem)->compute(); } } -const Array & -Input::getData() const -{ +const Array &Input::getData() const { NTA_CHECK(initialized_); return data_; } -Region& -Input::getRegion() -{ - return region_; -} +Region &Input::getRegion() { return region_; } -const std::vector& -Input::getLinks() -{ - return links_; -} - -bool -Input::isRegionLevel() -{ - return isRegionLevel_; -} +const std::vector &Input::getLinks() { return links_; } +bool Input::isRegionLevel() { return isRegionLevel_; } // See header file for documentation -size_t -Input::evaluateLinks() -{ +size_t Input::evaluateLinks() { /** * It is not an error to call evaluateLinks() on an initialized * input -- just report that no links remain to be evaluated. @@ -180,11 +143,10 @@ Input::evaluateLinks() return 0; size_t nIncompleteLinks = 0; - std::vector::iterator l; - for (l = links_.begin(); l != links_.end(); l++) - { - Region& srcRegion = (*l)->getSrc().getRegion(); - Region& destRegion = (*l)->getDest().getRegion(); + std::vector::iterator l; + for (l = links_.begin(); l != links_.end(); l++) { + Region &srcRegion = (*l)->getSrc().getRegion(); + Region &destRegion = (*l)->getDest().getRegion(); /** * The link and region need to be consistent at both @@ -202,14 +164,11 @@ Input::evaluateLinks() Dimensions srcLinkDims = (*l)->getSrcDimensions(); /* source region dimensions are unspecified */ - if (srcRegionDims.isUnspecified()) - { - if (srcLinkDims.isUnspecified()) - { + if (srcRegionDims.isUnspecified()) { + if (srcLinkDims.isUnspecified()) { // 1. link cares about src dimensions but they aren't set // link is incomplete; - } else if (srcLinkDims.isDontcare()) - { + } else if (srcLinkDims.isDontcare()) { // 2. Link doesn't care. We don't need to do anything. } else { // 3. Link specifies src dimensions but src region dimensions @@ -218,8 +177,7 @@ Input::evaluateLinks() // If source region is initialized, this is a logic error NTA_CHECK(!srcRegion.isInitialized()); - if(!((*l)->getSrc().isRegionLevel())) - { + if (!((*l)->getSrc().isRegionLevel())) { // 3.1 Only set the dimensions if the link source is not region // level @@ -228,39 +186,30 @@ Input::evaluateLinks() srcRegionDims = srcRegion.getDimensions(); std::stringstream ss; - ss << "Specified by source dimensions on link " - << (*l)->toString(); + ss << "Specified by source dimensions on link " << (*l)->toString(); srcRegion.setDimensionInfo(ss.str()); - } - else - { + } else { // 3.2 Link is incomplete } } } else { /* source region dimensions are specified */ - if (srcLinkDims.isDontcare()) - { + if (srcLinkDims.isDontcare()) { // 4. Link doesn't care. We don't need to do anything. - } else if (srcLinkDims.isUnspecified()) - { + } else if (srcLinkDims.isUnspecified()) { // 5. 
srcRegion dims set link dims - if((*l)->getSrc().isRegionLevel()) - { + if ((*l)->getSrc().isRegionLevel()) { // 5.1 link source is region level, so use dimensions of [1] Dimensions d; - for(size_t i = 0; i < srcRegionDims.size(); i++) - { + for (size_t i = 0; i < srcRegionDims.size(); i++) { d.push_back(1); } (*l)->setSrcDimensions(d); srcLinkDims = d; - } - else - { + } else { // 5.2 apply region dimensions to link (*l)->setSrcDimensions(srcRegionDims); @@ -270,56 +219,49 @@ Input::evaluateLinks() // 6. Both region dims and link dims are specified. // Verify that srcRegion dims are the same as // link dims - if (srcRegionDims != srcLinkDims) - { + if (srcRegionDims != srcLinkDims) { Dimensions oneD(1); bool inconsistentDimensions = false; - if((*l)->getSrc().isRegionLevel()) - { + if ((*l)->getSrc().isRegionLevel()) { Dimensions d; - for(size_t i = 0; i < srcRegionDims.size(); i++) - { + for (size_t i = 0; i < srcRegionDims.size(); i++) { d.push_back(1); } - if(srcLinkDims != d) - { + if (srcLinkDims != d) { NTA_THROW << "Internal error while processing Region " << srcRegion.getName() << ". The link " - << (*l)->toString() << " has a region level source " + << (*l)->toString() + << " has a region level source " "output, but the link dimensions are " << srcLinkDims.toString() << " instead of [1]"; } - } - else if(srcRegionDims == oneD) - { + } else if (srcRegionDims == oneD) { Dimensions d; - for(size_t i = 0; i < srcLinkDims.size(); i++) - { + for (size_t i = 0; i < srcLinkDims.size(); i++) { d.push_back(1); } - if(srcLinkDims != d) - { + if (srcLinkDims != d) { inconsistentDimensions = true; } - } - else - { + } else { inconsistentDimensions = true; } - if(inconsistentDimensions) - { - NTA_THROW << "Inconsistent dimension specification encountered. Region " - << srcRegion.getName() << " has dimensions " - << srcRegionDims.toString() << " but link " - << (*l)->toString() << " requires dimensions " - << srcLinkDims.toString() << ". Additional information on " - << "region dimensions: " - << (srcRegion.getDimensionInfo() == "" ? "(none)" : srcRegion.getDimensionInfo()); + if (inconsistentDimensions) { + NTA_THROW + << "Inconsistent dimension specification encountered. Region " + << srcRegion.getName() << " has dimensions " + << srcRegionDims.toString() << " but link " << (*l)->toString() + << " requires dimensions " << srcLinkDims.toString() + << ". Additional information on " + << "region dimensions: " + << (srcRegion.getDimensionInfo() == "" + ? "(none)" + : srcRegion.getDimensionInfo()); } } } @@ -330,18 +272,15 @@ Input::evaluateLinks() Dimensions destRegionDims = destRegion.getDimensions(); // The logic here is similar to the logic for the source side - // except for the case where the destination region dims are specified and the - // link dims are unspecified -- see comment below. + // except for the case where the destination region dims are specified and + // the link dims are unspecified -- see comment below. /* dest region dimensions are unspecified */ - if (destRegionDims.isUnspecified()) - { - if (destLinkDims.isUnspecified()) - { + if (destRegionDims.isUnspecified()) { + if (destLinkDims.isUnspecified()) { // 1. link cares about dest dimensions but they aren't set // link is incomplete; Nothing we can do. - } else if (destLinkDims.isDontcare()) - { + } else if (destLinkDims.isDontcare()) { // 2. Link doesn't care. We don't need to do anything. } else { // 3. 
Link specifies dest dimensions but region dimensions @@ -350,8 +289,7 @@ Input::evaluateLinks() // If dest region is initialized, this is a logic error NTA_CHECK(!destRegion.isInitialized()); - if(!((*l)->getDest().isRegionLevel())) - { + if (!((*l)->getDest().isRegionLevel())) { // 3.1 Only set the dimensions if the link destination is not region // level @@ -359,37 +297,31 @@ Input::evaluateLinks() destRegion.setDimensions(destLinkDims); destRegionDims = destRegion.getDimensions(); std::stringstream ss; - ss << "Specified by destination dimensions on link " << (*l)->toString(); + ss << "Specified by destination dimensions on link " + << (*l)->toString(); destRegion.setDimensionInfo(ss.str()); - } - else - { + } else { // 3.2 Link is incomplete } } } else { /* dest region dimensions are specified but src region dims are not */ - if (destLinkDims.isDontcare()) - { + if (destLinkDims.isDontcare()) { // 4. Link doesn't care. We don't need to do anything. - } else if (destLinkDims.isUnspecified()) { + } else if (destLinkDims.isUnspecified()) { // 5. Region has dimensions -- set them on the link. - if((*l)->getDest().isRegionLevel()) - { + if ((*l)->getDest().isRegionLevel()) { // 5.1 link source is region level, so use dimensions of [1] Dimensions d; - for(size_t i = 0; i < destRegionDims.size(); i++) - { + for (size_t i = 0; i < destRegionDims.size(); i++) { d.push_back(1); } (*l)->setDestDimensions(d); destLinkDims = d; - } - else - { + } else { // 5.2 apply region dimensions to link (*l)->setDestDimensions(destRegionDims); @@ -398,11 +330,9 @@ Input::evaluateLinks() // Setting the link dest dimensions may set the src // dimensions. Since we have already evaluated the source // side of the link, we need to re-evaluate here - if (srcRegionDims.isUnspecified()) - { + if (srcRegionDims.isUnspecified()) { srcLinkDims = (*l)->getSrcDimensions(); - if (!srcLinkDims.isUnspecified() && !srcLinkDims.isDontcare()) - { + if (!srcLinkDims.isUnspecified() && !srcLinkDims.isDontcare()) { // Induce. TODO: code is the same as on source side -- refactor? // If source region is initialized, this is a logic error NTA_CHECK(!srcRegion.isInitialized()); @@ -420,17 +350,18 @@ Input::evaluateLinks() } else { // src region dims were already specified. Make sure they // are compatible with the link dims. - if (srcLinkDims != srcRegionDims) - { - NTA_THROW << "Inconsistent dimension specification encountered. Region " - << srcRegion.getName() << " has dimensions " - << srcRegionDims.toString() << " but link " - << (*l)->toString() << " requires dimensions " - << srcLinkDims.toString() << ". Additional information on " - << "region dimensions: " - << (srcRegion.getDimensionInfo() == "" ? "(none)" : srcRegion.getDimensionInfo()); + if (srcLinkDims != srcRegionDims) { + NTA_THROW + << "Inconsistent dimension specification encountered. Region " + << srcRegion.getName() << " has dimensions " + << srcRegionDims.toString() << " but link " + << (*l)->toString() << " requires dimensions " + << srcLinkDims.toString() << ". Additional information on " + << "region dimensions: " + << (srcRegion.getDimensionInfo() == "" + ? "(none)" + : srcRegion.getDimensionInfo()); } - } } @@ -442,72 +373,65 @@ Input::evaluateLinks() bool inconsistentDimensions = false; - if (destRegionDims != destLinkDims) - { + if (destRegionDims != destLinkDims) { Dimensions oneD; oneD.push_back(1); - if((*l)->getDest().isRegionLevel()) - { - if (! 
destLinkDims.isOnes()) + if ((*l)->getDest().isRegionLevel()) { + if (!destLinkDims.isOnes()) NTA_THROW << "Internal error while processing Region " << destRegion.getName() << ". The link " - << (*l)->toString() << " has a region level destination " + << (*l)->toString() + << " has a region level destination " << "input, but the link dimensions are " << destLinkDims.toString() << " instead of [1]"; - } - else if(destRegionDims == oneD) - { + } else if (destRegionDims == oneD) { Dimensions d; - for(size_t i = 0; i < destLinkDims.size(); i++) - { + for (size_t i = 0; i < destLinkDims.size(); i++) { d.push_back(1); } - if(destLinkDims != d) - { + if (destLinkDims != d) { inconsistentDimensions = true; } - } - else - { + } else { inconsistentDimensions = true; } - if(inconsistentDimensions) - { - NTA_THROW << "Inconsistent dimension specification encountered. Region " - << destRegion.getName() << " has dimensions " - << destRegionDims.toString() << " but link " - << (*l)->toString() << " requires dimensions " - << destLinkDims.toString() << ". Additional information on " - << "region dimensions: " - << (destRegion.getDimensionInfo() == "" ? "(none)" : destRegion.getDimensionInfo()); + if (inconsistentDimensions) { + NTA_THROW + << "Inconsistent dimension specification encountered. Region " + << destRegion.getName() << " has dimensions " + << destRegionDims.toString() << " but link " << (*l)->toString() + << " requires dimensions " << destLinkDims.toString() + << ". Additional information on " + << "region dimensions: " + << (destRegion.getDimensionInfo() == "" + ? "(none)" + : destRegion.getDimensionInfo()); } } } } bool linkIsIncomplete = true; - if (srcRegionDims.isSpecified() && destRegionDims.isSpecified()) - { + if (srcRegionDims.isSpecified() && destRegionDims.isSpecified()) { linkIsIncomplete = false; // link dims may be specified or dontcare (!isUnspecified) NTA_CHECK(srcLinkDims.isSpecified() || srcLinkDims.isDontcare()) - << "link: " << (*l)->toString() - << " src: " << srcRegionDims.toString() - << " dest: " << destRegionDims.toString() - << " srclinkdims: " << srcLinkDims.toString(); + << "link: " << (*l)->toString() + << " src: " << srcRegionDims.toString() + << " dest: " << destRegionDims.toString() + << " srclinkdims: " << srcLinkDims.toString(); NTA_CHECK(destLinkDims.isSpecified() || destLinkDims.isDontcare()) - << "link: " << (*l)->toString() - << " src: " << srcRegionDims.toString() - << " dest: " << destRegionDims.toString() - << " destlinkdims: " << destLinkDims.toString(); + << "link: " << (*l)->toString() + << " src: " << srcRegionDims.toString() + << " dest: " << destRegionDims.toString() + << " destlinkdims: " << destLinkDims.toString(); } - if (linkIsIncomplete) - { + if (linkIsIncomplete) { nIncompleteLinks++; } @@ -521,21 +445,20 @@ Input::evaluateLinks() // our size and set up any data structures needed // for copying data over a link. -void Input::initialize() -{ +void Input::initialize() { if (initialized_) return; - if(region_.getDimensions().isUnspecified()) - { - NTA_THROW << "Input region's dimensions are unspecified when Input::initialize() " - << "was called. Region's dimensions must be specified."; + if (region_.getDimensions().isUnspecified()) { + NTA_THROW + << "Input region's dimensions are unspecified when Input::initialize() " + << "was called. 
Region's dimensions must be specified."; } // Calculate our size and the offset of each link size_t count = 0; - for (std::vector::const_iterator l = links_.begin(); l != links_.end(); l++) - { + for (std::vector::const_iterator l = links_.begin(); + l != links_.end(); l++) { linkOffsets_.push_back(count); // Setting the destination offset makes the link usable. // TODO: change @@ -548,40 +471,31 @@ void Input::initialize() data_.allocateBuffer(count); // Zero the inputs (required for inspectors) - if (count != 0) - { - void * buffer = data_.getBuffer(); + if (count != 0) { + void *buffer = data_.getBuffer(); size_t byteCount = count * BasicType::getSize(data_.getType()); ::memset(buffer, 0, byteCount); } - NTA_CHECK(splitterMap_.size() == 0); // create the splitter map by getting the contributions // from each link. - if(isRegionLevel_) - { + if (isRegionLevel_) { splitterMap_.resize(1); - } - else - { + } else { splitterMap_.resize(region_.getDimensions().getCount()); } - for (std::vector::const_iterator link = links_.begin(); - link != links_.end(); link++) - { + link != links_.end(); link++) { (*link)->buildSplitterMap(splitterMap_); } - initialized_ = true; } -void Input::uninitialize() -{ +void Input::uninitialize() { if (!initialized_) return; @@ -592,23 +506,13 @@ void Input::uninitialize() splitterMap_.clear(); } -bool Input::isInitialized() -{ - return(initialized_); -} +bool Input::isInitialized() { return (initialized_); } -void Input::setName(const std::string& name) -{ - name_ = name; -} +void Input::setName(const std::string &name) { name_ = name; } -const std::string& Input::getName() const -{ - return name_; -} +const std::string &Input::getName() const { return name_; } -const std::vector< std::vector >& Input::getSplitterMap() const -{ +const std::vector> &Input::getSplitterMap() const { NTA_CHECK(initialized_); // Originally the splitter map was created on demand in this method. 
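The splitter map handled here can be hard to picture from the code alone. A standalone sketch (plain C++, no NuPIC types, invented names): the map holds, for every destination node, the indices into the flat concatenated input buffer that belong to that node, and getInputForNode() simply gathers those elements.

#include <cstddef>
#include <iostream>
#include <vector>

// One entry per destination node; each entry lists offsets into the flat,
// concatenated input buffer (same shape as Input::SplitterMap).
using SplitterMap = std::vector<std::vector<size_t>>;

// Gather the elements belonging to one node, analogous to
// Input::getInputForNode().
template <typename T>
std::vector<T> inputForNode(const SplitterMap &sm, size_t nodeIndex,
                            const std::vector<T> &flatInput) {
  std::vector<T> out;
  out.reserve(sm[nodeIndex].size());
  for (size_t offset : sm[nodeIndex])
    out.push_back(flatInput[offset]);
  return out;
}

int main() {
  // Flat buffer holding the concatenated contributions of all links.
  std::vector<float> flat = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};

  // Node 0 reads elements 0-2, node 1 reads elements 3-5.
  SplitterMap sm = {{0, 1, 2}, {3, 4, 5}};

  for (float v : inputForNode(sm, 1, flat))
    std::cout << v << " ";   // prints 0.4 0.5 0.6
  std::cout << std::endl;
  return 0;
}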
// For now we have moved splitter map creation to initialize() because @@ -618,29 +522,34 @@ const std::vector< std::vector >& Input::getSplitterMap() const return splitterMap_; } - -template void Input::getInputForNode(size_t nodeIndex, std::vector& input) const -{ +template +void Input::getInputForNode(size_t nodeIndex, std::vector &input) const { NTA_CHECK(initialized_); - const SplitterMap& sm = getSplitterMap(); + const SplitterMap &sm = getSplitterMap(); NTA_CHECK(nodeIndex < sm.size()); - const std::vector& map = sm[nodeIndex]; - //NTA_CHECK(map.size() > 0); + const std::vector &map = sm[nodeIndex]; + // NTA_CHECK(map.size() > 0); input.resize(map.size()); - T* fullInput = (T*)(data_.getBuffer()); + T *fullInput = (T *)(data_.getBuffer()); for (size_t i = 0; i < map.size(); i++) input[i] = fullInput[map[i]]; } -template void Input::getInputForNode(size_t nodeIndex, std::vector& input) const; -template void Input::getInputForNode(size_t nodeIndex, std::vector& input) const; -template void Input::getInputForNode(size_t nodeIndex, std::vector& input) const; -template void Input::getInputForNode(size_t nodeIndex, std::vector& input) const; -template void Input::getInputForNode(size_t nodeIndex, std::vector& input) const; -template void Input::getInputForNode(size_t nodeIndex, std::vector& input) const; -template void Input::getInputForNode(size_t nodeIndex, std::vector& input) const; - -} - +template void Input::getInputForNode(size_t nodeIndex, + std::vector &input) const; +template void Input::getInputForNode(size_t nodeIndex, + std::vector &input) const; +template void Input::getInputForNode(size_t nodeIndex, + std::vector &input) const; +template void Input::getInputForNode(size_t nodeIndex, + std::vector &input) const; +template void Input::getInputForNode(size_t nodeIndex, + std::vector &input) const; +template void Input::getInputForNode(size_t nodeIndex, + std::vector &input) const; +template void Input::getInputForNode(size_t nodeIndex, + std::vector &input) const; + +} // namespace nupic diff --git a/src/nupic/engine/Input.hpp b/src/nupic/engine/Input.hpp index a10fd9d711..c661c82bb8 100644 --- a/src/nupic/engine/Input.hpp +++ b/src/nupic/engine/Input.hpp @@ -33,274 +33,254 @@ #error "Input class should not be wrapped" #endif - -#include -#include #include +#include +#include + +namespace nupic { +class Link; +class Region; +class Output; + +/** + * Represents a named input to a Region. (e.g. bottomUpIn) + * + * @note Input is not available in the public API, but is visible by + * the RegionImpl. + * + * @todo identify methods that may be called by RegionImpl -- this + * is the internal "public interface" + */ +class Input { +public: + /** + * Constructor. + * + * @param region + * The region that the input belongs to. + * @param type + * The type of the input, i.e. TODO + * @param isRegionLevel + * Whether the input is region level, i.e. TODO + */ + Input(Region ®ion, NTA_BasicType type, bool isRegionLevel); + + /** + * + * Destructor. + * + */ + ~Input(); + + /** + * + * Set the name for the input. + * + * Inputs need to know their own name for error messages. + * + * @param name + * The name of the input + * + */ + void setName(const std::string &name); + + /** + * Get the name of the input. + * + * @return + * The name of the input + */ + const std::string &getName() const; + + /** + * Add the given link to this input and to the list of links on the output + * + * @param link + * The link to add. 
+ * @param srcOutput + * The output of previous Region, which is also the source of the input + */ + void addLink(Link *link, Output *srcOutput); + + /** + * Locate an existing Link to the input. + * + * It's called by Network.removeLink() and internally when adding a link + * + * @param srcRegionName + * The name of the source Region + * @param srcOutputName + * The name of the source Output + * + * @returns + * The link if found or @c NULL if no such link exists + */ + Link *findLink(const std::string &srcRegionName, + const std::string &srcOutputName); + + /** + * Removing an existing link from the input. + * + * It's called in four cases: + * + * 1. Network.removeLink() + * 2. Network.removeRegion() when given @a srcRegion + * 3. Network.removeRegion() when given @a destRegion + * 4. Network.~Network() + * + * It is an error to call this if our containing region + * is uninitialized. + * + * @note This method will set the Link pointer to NULL on return (to avoid + * a dangling reference) + * + * @param link + * The Link to remove, possibly retrieved by findLink(), note that + * it is a reference to the pointer, not the pointer itself. + */ + void removeLink(Link *&link); + + /** + * Make input data available. + * + * Called by Region.prepareInputs() + */ + void prepare(); -namespace nupic -{ - class Link; - class Region; - class Output; + /** + * + * Get the data of the input. + * + * @returns + * A mutable reference to the data of the input as an @c Array + */ + const Array &getData() const; + + /** + * + * Get the Region that the input belongs to. + * + * @returns + * The mutable reference to the Region that the input belongs to + */ + Region &getRegion(); + + /** + * + * Get all the Link objects added to the input. + * + * @returns + * All the Link objects added to the input + */ + const std::vector &getLinks(); /** - * Represents a named input to a Region. (e.g. bottomUpIn) * - * @note Input is not available in the public API, but is visible by - * the RegionImpl. + * Tells whether the input is region level. * - * @todo identify methods that may be called by RegionImpl -- this - * is the internal "public interface" + * @returns + * Whether the input is region level, i.e. TODO */ - class Input - { - public: - - /** - * Constructor. - * - * @param region - * The region that the input belongs to. - * @param type - * The type of the input, i.e. TODO - * @param isRegionLevel - * Whether the input is region level, i.e. TODO - */ - Input(Region& region, NTA_BasicType type, bool isRegionLevel); - - /** - * - * Destructor. - * - */ - ~Input(); - - /** - * - * Set the name for the input. - * - * Inputs need to know their own name for error messages. - * - * @param name - * The name of the input - * - */ - void setName(const std::string& name); - - /** - * Get the name of the input. - * - * @return - * The name of the input - */ - const std::string& getName() const; - - /** - * Add the given link to this input and to the list of links on the output - * - * @param link - * The link to add. - * @param srcOutput - * The output of previous Region, which is also the source of the input - */ - void - addLink(Link* link, Output* srcOutput); - - /** - * Locate an existing Link to the input. 
- * - * It's called by Network.removeLink() and internally when adding a link - * - * @param srcRegionName - * The name of the source Region - * @param srcOutputName - * The name of the source Output - * - * @returns - * The link if found or @c NULL if no such link exists - */ - Link* - findLink(const std::string& srcRegionName, - const std::string& srcOutputName); - - /** - * Removing an existing link from the input. - * - * It's called in four cases: - * - * 1. Network.removeLink() - * 2. Network.removeRegion() when given @a srcRegion - * 3. Network.removeRegion() when given @a destRegion - * 4. Network.~Network() - * - * It is an error to call this if our containing region - * is uninitialized. - * - * @note This method will set the Link pointer to NULL on return (to avoid - * a dangling reference) - * - * @param link - * The Link to remove, possibly retrieved by findLink(), note that - * it is a reference to the pointer, not the pointer itself. - */ - void - removeLink(Link*& link); - - /** - * Make input data available. - * - * Called by Region.prepareInputs() - */ - void - prepare(); - - /** - * - * Get the data of the input. - * - * @returns - * A mutable reference to the data of the input as an @c Array - */ - const Array & - getData() const; - - /** - * - * Get the Region that the input belongs to. - * - * @returns - * The mutable reference to the Region that the input belongs to - */ - Region& - getRegion(); - - /** - * - * Get all the Link objects added to the input. - * - * @returns - * All the Link objects added to the input - */ - const std::vector& - getLinks(); - - /** - * - * Tells whether the input is region level. - * - * @returns - * Whether the input is region level, i.e. TODO - */ - bool - isRegionLevel(); - - /** - * Called by Region.evaluateLinks() as part - * of network initialization. - * - * 1. Tries to make sure that dimensions at both ends - * of a link are specified by calling setSourceDimensions() - * if possible, and then calling getDestDimensions() - * 2. Ensures that region dimensions are consistent with - * either by setting destination region dimensions (this is - * where links "induce" dimensions) or by raising an exception - * if they are inconsistent. - * - * @returns - * Number of links that could not be fully evaluated, i.e. incomplete - */ - size_t - evaluateLinks(); - - /** - * Initialize the Input . - * - * After the input has all the information it needs, it is initialized by - * this method. Volatile data structures (e.g. the input buffer) are set up。 - */ - void - initialize(); - - /** - * Tells whether the Input is initialized. - * - * @returns - * Whether the Input is initialized - */ - bool - isInitialized(); - - /* ------------ Methods normally called by the RegionImpl ------------- */ - - /** - * - * @see Link.buildSplitterMap() - * - */ - typedef std::vector< std::vector > SplitterMap; - - /** - * - * Get splitter map from an initialized input - * - * @returns - * The splitter map - */ - const SplitterMap& getSplitterMap() const; - - /** explicitly instantiated for various types */ - template void getInputForNode(size_t nodeIndex, std::vector& input) const; - - private: - Region& region_; - // buffer is concatenation of input buffers (after prepare), or, - // if zeroCopyEnabled it points to the connected output - bool isRegionLevel_; - - - // Use a vector of links because order is important. 
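The dimension bookkeeping that evaluateLinks() documents boils down to a simple induction rule. A minimal standalone sketch, assuming plain vectors in place of the Dimensions class and ignoring the "dontcare" and region-level special cases; reconcile() is an invented name.

#include <stdexcept>
#include <vector>

using Dims = std::vector<size_t>;   // empty == unspecified, for this sketch

// Mimics the evaluateLinks() cases: copy dimensions across the link when
// only one side is known, complain when both are known but disagree.
void reconcile(Dims &regionDims, Dims &linkDims) {
  if (regionDims.empty() && !linkDims.empty()) {
    regionDims = linkDims;           // link "induces" the region dimensions
  } else if (!regionDims.empty() && linkDims.empty()) {
    linkDims = regionDims;           // region dimensions flow onto the link
  } else if (!regionDims.empty() && regionDims != linkDims) {
    throw std::runtime_error("Inconsistent dimension specification");
  }                                  // both empty: link stays incomplete
}

int main() {
  Dims region;                       // unspecified
  Dims link = {4, 8};                // specified on the link
  reconcile(region, link);           // region becomes {4, 8}
  return region == link ? 0 : 1;
}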
- std::vector links_; - - // volatile (non-serialized) state - bool initialized_; - Array data_; - - /* - * cached splitter map -- only created if requested - * mutable because getSplitterMap() is const and logically - * getting the splitter map doesn't change the Input - */ - mutable SplitterMap splitterMap_; - - - /* - * Cache of information about link offsets so we can - * easily copy data from each link. - * the first link starts at offset 0 - * the next link starts at offset 0 + size(link[0]) - */ - std::vector linkOffsets_; - - // Useful for us to know our own name - std::string name_; - - // Internal methods - - /* - * uninitialize is called by removeLink - * and in our destructor. It is an error - * to call it if our region is initialized. - * It frees the input buffer and the splitter map - * but does not affect the links. - */ - void - uninitialize(); - - - - }; - -} + bool isRegionLevel(); + + /** + * Called by Region.evaluateLinks() as part + * of network initialization. + * + * 1. Tries to make sure that dimensions at both ends + * of a link are specified by calling setSourceDimensions() + * if possible, and then calling getDestDimensions() + * 2. Ensures that region dimensions are consistent with + * either by setting destination region dimensions (this is + * where links "induce" dimensions) or by raising an exception + * if they are inconsistent. + * + * @returns + * Number of links that could not be fully evaluated, i.e. incomplete + */ + size_t evaluateLinks(); + + /** + * Initialize the Input . + * + * After the input has all the information it needs, it is initialized by + * this method. Volatile data structures (e.g. the input buffer) are set up。 + */ + void initialize(); + + /** + * Tells whether the Input is initialized. + * + * @returns + * Whether the Input is initialized + */ + bool isInitialized(); + + /* ------------ Methods normally called by the RegionImpl ------------- */ + + /** + * + * @see Link.buildSplitterMap() + * + */ + typedef std::vector> SplitterMap; + + /** + * + * Get splitter map from an initialized input + * + * @returns + * The splitter map + */ + const SplitterMap &getSplitterMap() const; + + /** explicitly instantiated for various types */ + template + void getInputForNode(size_t nodeIndex, std::vector &input) const; + +private: + Region ®ion_; + // buffer is concatenation of input buffers (after prepare), or, + // if zeroCopyEnabled it points to the connected output + bool isRegionLevel_; + + // Use a vector of links because order is important. + std::vector links_; + + // volatile (non-serialized) state + bool initialized_; + Array data_; + + /* + * cached splitter map -- only created if requested + * mutable because getSplitterMap() is const and logically + * getting the splitter map doesn't change the Input + */ + mutable SplitterMap splitterMap_; + + /* + * Cache of information about link offsets so we can + * easily copy data from each link. + * the first link starts at offset 0 + * the next link starts at offset 0 + size(link[0]) + */ + std::vector linkOffsets_; + + // Useful for us to know our own name + std::string name_; + + // Internal methods + + /* + * uninitialize is called by removeLink + * and in our destructor. It is an error + * to call it if our region is initialized. + * It frees the input buffer and the splitter map + * but does not affect the links. 
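The linkOffsets_ comment above (first link at offset 0, next at 0 + size(link[0])) is the whole contract. A standalone sketch of that bookkeeping, with invented names: offsets are a running sum of link sizes, and each link copies into its own slice of the flat destination buffer, the same arithmetic Link::compute() applies through destOffset_.

#include <cstddef>
#include <cstring>
#include <vector>

// Compute per-link offsets: link 0 starts at 0, link i starts at the sum of
// the sizes of links 0..i-1 (cf. Input::initialize()).
std::vector<size_t> computeOffsets(const std::vector<std::vector<float>> &srcs) {
  std::vector<size_t> offsets;
  size_t count = 0;
  for (const auto &src : srcs) {
    offsets.push_back(count);
    count += src.size();
  }
  return offsets;
}

int main() {
  std::vector<std::vector<float>> srcs = {{1, 2, 3}, {4, 5}};
  std::vector<size_t> offsets = computeOffsets(srcs);   // {0, 3}

  // Flat destination buffer, sized to hold every contribution.
  std::vector<float> dest(5, 0.0f);

  // Each "link" copies its data to its own offset, as Link::compute() does
  // with memcpy into the destination input buffer.
  for (size_t i = 0; i < srcs.size(); ++i)
    std::memcpy(dest.data() + offsets[i], srcs[i].data(),
                srcs[i].size() * sizeof(float));

  return dest[3] == 4.0f ? 0 : 1;    // element 3 came from the second link
}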
+ */ + void uninitialize(); +}; + +} // namespace nupic #endif // NTA_INPUT_HPP diff --git a/src/nupic/engine/Link.cpp b/src/nupic/engine/Link.cpp index 7b2109c1cd..8404feed4b 100644 --- a/src/nupic/engine/Link.cpp +++ b/src/nupic/engine/Link.cpp @@ -24,66 +24,53 @@ * Implementation of the Link class */ #include // memcpy,memset +#include #include -#include -#include -#include #include -#include -#include +#include #include +#include #include #include - +#include +#include // Set this to true when debugging to enable handy debug-level logging of data // moving through the links, including the delayed link transitions. #define _LINK_DEBUG false +namespace nupic { -namespace nupic -{ - -Link::Link(const std::string& linkType, const std::string& linkParams, - const std::string& srcRegionName, const std::string& destRegionName, - const std::string& srcOutputName, const std::string& destInputName, - const size_t propagationDelay): - srcBuffer_(0) -{ - commonConstructorInit_(linkType, linkParams, - srcRegionName, destRegionName, - srcOutputName, destInputName, - propagationDelay); - +Link::Link(const std::string &linkType, const std::string &linkParams, + const std::string &srcRegionName, const std::string &destRegionName, + const std::string &srcOutputName, const std::string &destInputName, + const size_t propagationDelay) + : srcBuffer_(0) { + commonConstructorInit_(linkType, linkParams, srcRegionName, destRegionName, + srcOutputName, destInputName, propagationDelay); } -Link::Link(const std::string& linkType, const std::string& linkParams, - Output* srcOutput, Input* destInput, const size_t propagationDelay): - srcBuffer_(0) -{ - commonConstructorInit_(linkType, linkParams, - srcOutput->getRegion().getName(), - destInput->getRegion().getName(), - srcOutput->getName(), - destInput->getName(), - propagationDelay); +Link::Link(const std::string &linkType, const std::string &linkParams, + Output *srcOutput, Input *destInput, const size_t propagationDelay) + : srcBuffer_(0) { + commonConstructorInit_(linkType, linkParams, srcOutput->getRegion().getName(), + destInput->getRegion().getName(), srcOutput->getName(), + destInput->getName(), propagationDelay); connectToNetwork(srcOutput, destInput); - // Note -- link is not usable until we set the destOffset, which happens at initialization time -} - - -Link::Link(): - srcBuffer_(0) -{ + // Note -- link is not usable until we set the destOffset, which happens at + // initialization time } +Link::Link() : srcBuffer_(0) {} -void Link::commonConstructorInit_(const std::string& linkType, const std::string& linkParams, - const std::string& srcRegionName, const std::string& destRegionName, - const std::string& srcOutputName, const std::string& destInputName, - const size_t propagationDelay) -{ +void Link::commonConstructorInit_(const std::string &linkType, + const std::string &linkParams, + const std::string &srcRegionName, + const std::string &destRegionName, + const std::string &srcOutputName, + const std::string &destInputName, + const size_t propagationDelay) { linkType_ = linkType; linkParams_ = linkParams; srcRegionName_ = srcRegionName; @@ -98,22 +85,15 @@ void Link::commonConstructorInit_(const std::string& linkType, const std::string dest_ = nullptr; initialized_ = false; - impl_ = LinkPolicyFactory().createLinkPolicy(linkType, linkParams, this); } -Link::~Link() -{ - delete impl_; -} - +Link::~Link() { delete impl_; } void Link::initPropagationDelayBuffer_(size_t propagationDelay, NTA_BasicType dataElementType, - size_t dataElementCount) -{ - if 
(srcBuffer_.capacity() != 0 || !propagationDelay) - { + size_t dataElementCount) { + if (srcBuffer_.capacity() != 0 || !propagationDelay) { // Already initialized(e.g., as result of deserialization); or a 0-delay // link, which doesn't use buffering. return; @@ -123,11 +103,10 @@ void Link::initPropagationDelayBuffer_(size_t propagationDelay, srcBuffer_.set_capacity(propagationDelay); // Initialize delay data elements - size_t dataBufferSize = dataElementCount * - BasicType::getSize(dataElementType); + size_t dataBufferSize = + dataElementCount * BasicType::getSize(dataElementType); - for(size_t i=0; i < propagationDelay; i++) - { + for (size_t i = 0; i < propagationDelay; i++) { Array arrayTemplate(dataElementType); srcBuffer_.push_back(arrayTemplate); @@ -138,9 +117,7 @@ void Link::initPropagationDelayBuffer_(size_t propagationDelay, } } - -void Link::initialize(size_t destinationOffset) -{ +void Link::initialize(size_t destinationOffset) { // Make sure all information is specified and // consistent. Unless there is a NuPIC implementation // error, all these checks are guaranteed to pass @@ -153,60 +130,47 @@ void Link::initialize(size_t destinationOffset) // Confirm that our dimensions are consistent with the // dimensions of the regions we're connecting. - const Dimensions& srcD = getSrcDimensions(); - const Dimensions& destD = getDestDimensions(); - NTA_CHECK(! srcD.isUnspecified()); - NTA_CHECK(! destD.isUnspecified()); + const Dimensions &srcD = getSrcDimensions(); + const Dimensions &destD = getDestDimensions(); + NTA_CHECK(!srcD.isUnspecified()); + NTA_CHECK(!destD.isUnspecified()); Dimensions oneD; oneD.push_back(1); - if(src_->isRegionLevel()) - { + if (src_->isRegionLevel()) { Dimensions d; - for(size_t i = 0; i < src_->getRegion().getDimensions().size(); i++) - { + for (size_t i = 0; i < src_->getRegion().getDimensions().size(); i++) { d.push_back(1); } NTA_CHECK(srcD.isDontcare() || srcD == d); - } - else if(src_->getRegion().getDimensions() == oneD) - { + } else if (src_->getRegion().getDimensions() == oneD) { Dimensions d; - for(size_t i = 0; i < srcD.size(); i++) - { + for (size_t i = 0; i < srcD.size(); i++) { d.push_back(1); } NTA_CHECK(srcD.isDontcare() || srcD == d); - } - else - { + } else { NTA_CHECK(srcD.isDontcare() || srcD == src_->getRegion().getDimensions()); } - if(dest_->isRegionLevel()) - { + if (dest_->isRegionLevel()) { Dimensions d; - for(size_t i = 0; i < dest_->getRegion().getDimensions().size(); i++) - { + for (size_t i = 0; i < dest_->getRegion().getDimensions().size(); i++) { d.push_back(1); } NTA_CHECK(destD.isDontcare() || destD.isOnes()); - } - else if(dest_->getRegion().getDimensions() == oneD) - { + } else if (dest_->getRegion().getDimensions() == oneD) { Dimensions d; - for(size_t i = 0; i < destD.size(); i++) - { + for (size_t i = 0; i < destD.size(); i++) { d.push_back(1); } NTA_CHECK(destD.isDontcare() || destD == d); - } - else - { - NTA_CHECK(destD.isDontcare() || destD == dest_->getRegion().getDimensions()); + } else { + NTA_CHECK(destD.isDontcare() || + destD == dest_->getRegion().getDimensions()); } destOffset_ = destinationOffset; @@ -215,114 +179,85 @@ void Link::initialize(size_t destinationOffset) // --- // Initialize the propagation delay buffer // --- - initPropagationDelayBuffer_(propagationDelay_, - src_->getData().getType(), + initPropagationDelayBuffer_(propagationDelay_, src_->getData().getType(), src_->getData().getCount()); initialized_ = true; - } -void Link::setSrcDimensions(Dimensions& dims) -{ +void 
Link::setSrcDimensions(Dimensions &dims) { NTA_CHECK(src_ != nullptr && dest_ != nullptr) - << "Link::setSrcDimensions() can only be called on a connected link"; + << "Link::setSrcDimensions() can only be called on a connected link"; size_t nodeElementCount = src_->getNodeOutputElementCount(); - if(nodeElementCount == 0) - { + if (nodeElementCount == 0) { nodeElementCount = - src_->getRegion().getNodeOutputElementCount(src_->getName()); + src_->getRegion().getNodeOutputElementCount(src_->getName()); } impl_->setNodeOutputElementCount(nodeElementCount); impl_->setSrcDimensions(dims); } -void Link::setDestDimensions(Dimensions& dims) -{ +void Link::setDestDimensions(Dimensions &dims) { NTA_CHECK(src_ != nullptr && dest_ != nullptr) - << "Link::setDestDimensions() can only be called on a connected link"; + << "Link::setDestDimensions() can only be called on a connected link"; size_t nodeElementCount = src_->getNodeOutputElementCount(); - if(nodeElementCount == 0) - { + if (nodeElementCount == 0) { nodeElementCount = - src_->getRegion().getNodeOutputElementCount(src_->getName()); + src_->getRegion().getNodeOutputElementCount(src_->getName()); } impl_->setNodeOutputElementCount(nodeElementCount); impl_->setDestDimensions(dims); } -const Dimensions& Link::getSrcDimensions() const -{ +const Dimensions &Link::getSrcDimensions() const { return impl_->getSrcDimensions(); }; -const Dimensions& Link::getDestDimensions() const -{ +const Dimensions &Link::getDestDimensions() const { return impl_->getDestDimensions(); }; // Return constructor params -const std::string& Link::getLinkType() const -{ - return linkType_; -} +const std::string &Link::getLinkType() const { return linkType_; } -const std::string& Link::getLinkParams() const -{ - return linkParams_; -} +const std::string &Link::getLinkParams() const { return linkParams_; } -const std::string& Link::getSrcRegionName() const -{ - return srcRegionName_; -} +const std::string &Link::getSrcRegionName() const { return srcRegionName_; } -const std::string& Link::getSrcOutputName() const -{ - return srcOutputName_; -} +const std::string &Link::getSrcOutputName() const { return srcOutputName_; } -const std::string& Link::getDestRegionName() const -{ - return destRegionName_; -} +const std::string &Link::getDestRegionName() const { return destRegionName_; } -const std::string& Link::getDestInputName() const -{ - return destInputName_; -} +const std::string &Link::getDestInputName() const { return destInputName_; } -std::string Link::getMoniker() const -{ +std::string Link::getMoniker() const { std::stringstream ss; ss << getSrcRegionName() << "." << getSrcOutputName() << "-->" << getDestRegionName() << "." << getDestInputName(); return ss.str(); } -const std::string Link::toString() const -{ +const std::string Link::toString() const { std::stringstream ss; ss << "[" << getSrcRegionName() << "." << getSrcOutputName(); - if (src_) - { - ss << " (region dims: " << src_->getRegion().getDimensions().toString() << ") "; + if (src_) { + ss << " (region dims: " << src_->getRegion().getDimensions().toString() + << ") "; } - ss << " to " << getDestRegionName() << "." << getDestInputName() ; - if (dest_) - { - ss << " (region dims: " << dest_->getRegion().getDimensions().toString() << ") "; + ss << " to " << getDestRegionName() << "." 
<< getDestInputName(); + if (dest_) { + ss << " (region dims: " << dest_->getRegion().getDimensions().toString() + << ") "; } ss << " type: " << linkType_ << "]"; return ss.str(); } -void Link::connectToNetwork(Output *src, Input *dest) -{ +void Link::connectToNetwork(Output *src, Input *dest) { NTA_CHECK(src != nullptr); NTA_CHECK(dest != nullptr); @@ -330,26 +265,22 @@ void Link::connectToNetwork(Output *src, Input *dest) dest_ = dest; } - // The methods below only work on connected links. -Output& Link::getSrc() const +Output &Link::getSrc() const { NTA_CHECK(src_ != nullptr) - << "Link::getSrc() can only be called on a connected link"; + << "Link::getSrc() can only be called on a connected link"; return *src_; } -Input& Link::getDest() const -{ +Input &Link::getDest() const { NTA_CHECK(dest_ != nullptr) - << "Link::getDest() can only be called on a connected link"; + << "Link::getDest() can only be called on a connected link"; return *dest_; } -void -Link::buildSplitterMap(Input::SplitterMap& splitter) -{ +void Link::buildSplitterMap(Input::SplitterMap &splitter) { // The link policy generates a splitter map // at the element level. Here we convert it // to a full splitter map @@ -363,56 +294,46 @@ Link::buildSplitterMap(Input::SplitterMap& splitter) impl_->setNodeOutputElementCount(nodeElementCount); impl_->buildProtoSplitterMap(protoSplitter); - for (size_t destNode = 0; destNode < splitter.size(); destNode++) - { + for (size_t destNode = 0; destNode < splitter.size(); destNode++) { // convert proto-splitter values into real // splitter values; - for (auto & elem : protoSplitter[destNode]) - { + for (auto &elem : protoSplitter[destNode]) { size_t srcElement = elem; size_t elementOffset = srcElement + destOffset_; splitter[destNode].push_back(elementOffset); } - } } -void -Link::compute() -{ +void Link::compute() { NTA_CHECK(initialized_); - if (propagationDelay_) - { + if (propagationDelay_) { NTA_CHECK(!srcBuffer_.empty()); } // Copy data from source to destination. For delayed links, will copy from // head of circular queue; otherwise directly from source. - const Array & src = propagationDelay_ ? srcBuffer_[0] : src_->getData(); + const Array &src = propagationDelay_ ? 
srcBuffer_[0] : src_->getData(); - const Array & dest = dest_->getData(); + const Array &dest = dest_->getData(); size_t typeSize = BasicType::getSize(src.getType()); size_t srcSize = src.getCount() * typeSize; size_t destByteOffset = destOffset_ * typeSize; - if (_LINK_DEBUG) - { - NTA_DEBUG << "Link::compute: " << getMoniker() - << "; copying to dest input" - << "; delay=" << propagationDelay_ << "; " - << src.getCount() << " elements=" << src; + if (_LINK_DEBUG) { + NTA_DEBUG << "Link::compute: " << getMoniker() << "; copying to dest input" + << "; delay=" << propagationDelay_ << "; " << src.getCount() + << " elements=" << src; } - ::memcpy((char*)(dest.getBuffer()) + destByteOffset, src.getBuffer(), srcSize); + ::memcpy((char *)(dest.getBuffer()) + destByteOffset, src.getBuffer(), + srcSize); } - -void Link::shiftBufferedData() -{ - if (!propagationDelay_) - { +void Link::shiftBufferedData() { + if (!propagationDelay_) { // Source buffering is not used in 0-delay links return; } @@ -424,29 +345,24 @@ void Link::shiftBufferedData() // Pop head of circular queue - if (_LINK_DEBUG) - { + if (_LINK_DEBUG) { NTA_DEBUG << "Link::shiftBufferedData: " << getMoniker() - << "; popping head; " - << srcBuffer_[0].getCount() << " elements=" - << srcBuffer_[0]; + << "; popping head; " << srcBuffer_[0].getCount() + << " elements=" << srcBuffer_[0]; } srcBuffer_.pop_front(); - // Append the current src value to circular queue - const Array & srcArray = src_->getData(); + const Array &srcArray = src_->getData(); size_t elementCount = srcArray.getCount(); auto elementType = srcArray.getType(); - if (_LINK_DEBUG) - { + if (_LINK_DEBUG) { NTA_DEBUG << "Link::shiftBufferedData: " << getMoniker() - << "; appending src to circular buffer; " - << elementCount << " elements=" - << srcArray; + << "; appending src to circular buffer; " << elementCount + << " elements=" << srcArray; NTA_DEBUG << "Link::shiftBufferedData: " << getMoniker() << "; num arrays in circular buffer before append; " @@ -456,23 +372,19 @@ void Link::shiftBufferedData() Array array(elementType); srcBuffer_.push_back(array); - auto & lastElement = srcBuffer_.back(); + auto &lastElement = srcBuffer_.back(); lastElement.allocateBuffer(elementCount); ::memcpy(lastElement.getBuffer(), srcArray.getBuffer(), elementCount * BasicType::getSize(elementType)); - if (_LINK_DEBUG) - { + if (_LINK_DEBUG) { NTA_DEBUG << "Link::shiftBufferedData: " << getMoniker() << "; circular buffer head after append is: " - << srcBuffer_[0].getCount() << " elements=" - << srcBuffer_[0]; + << srcBuffer_[0].getCount() << " elements=" << srcBuffer_[0]; } } - -void Link::write(LinkProto::Builder& proto) const -{ +void Link::write(LinkProto::Builder &proto) const { proto.setType(linkType_.c_str()); proto.setParams(linkParams_.c_str()); proto.setSrcRegion(srcRegionName_.c_str()); @@ -482,60 +394,52 @@ void Link::write(LinkProto::Builder& proto) const // Save delayed outputs auto delayedOutputsBuilder = proto.initDelayedOutputs(propagationDelay_); - for (size_t i=0; i < propagationDelay_; ++i) - { + for (size_t i = 0; i < propagationDelay_; ++i) { ArrayProtoUtils::copyArrayToArrayProto(srcBuffer_[i], delayedOutputsBuilder[i]); } } - -void Link::read(LinkProto::Reader& proto) -{ +void Link::read(LinkProto::Reader &proto) { const auto delayedOutputsReader = proto.getDelayedOutputs(); commonConstructorInit_( proto.getType().cStr(), proto.getParams().cStr(), proto.getSrcRegion().cStr(), proto.getDestRegion().cStr(), proto.getSrcOutput().cStr(), proto.getDestInput().cStr(), - 
delayedOutputsReader.size()/*propagationDelay*/); + delayedOutputsReader.size() /*propagationDelay*/); - if (delayedOutputsReader.size()) - { + if (delayedOutputsReader.size()) { // Initialize the propagation delay buffer with delay arrays having 0 // elements that deserialization logic will replace with appropriately-sized // buffers. initPropagationDelayBuffer_( - propagationDelay_, - ArrayProtoUtils::getArrayTypeFromArrayProtoReader(delayedOutputsReader[0]), - 0); + propagationDelay_, + ArrayProtoUtils::getArrayTypeFromArrayProtoReader( + delayedOutputsReader[0]), + 0); // Populate delayed outputs - for (size_t i=0; i < propagationDelay_; ++i) - { - ArrayProtoUtils::copyArrayProtoToArray(delayedOutputsReader[i], - srcBuffer_[i], - true/*allocArrayBuffer*/); + for (size_t i = 0; i < propagationDelay_; ++i) { + ArrayProtoUtils::copyArrayProtoToArray( + delayedOutputsReader[i], srcBuffer_[i], true /*allocArrayBuffer*/); } } } - -namespace nupic -{ - std::ostream& operator<<(std::ostream& f, const Link& link) - { - f << "\n"; - f << " " << link.getLinkType() << "\n"; - f << " " << link.getLinkParams() << "\n"; - f << " " << link.getSrcRegionName() << "\n"; - f << " " << link.getDestRegionName() << "\n"; - f << " " << link.getSrcOutputName() << "\n"; - f << " " << link.getDestInputName() << "\n"; - f << "\n"; - return f; - } +namespace nupic { +std::ostream &operator<<(std::ostream &f, const Link &link) { + f << "\n"; + f << " " << link.getLinkType() << "\n"; + f << " " << link.getLinkParams() << "\n"; + f << " " << link.getSrcRegionName() << "\n"; + f << " " << link.getDestRegionName() << "\n"; + f << " " << link.getSrcOutputName() << "\n"; + f << " " << link.getDestInputName() << "\n"; + f << "\n"; + return f; } +} // namespace nupic -} +} // namespace nupic diff --git a/src/nupic/engine/Link.hpp b/src/nupic/engine/Link.hpp index db8b356bf4..d59144c9da 100644 --- a/src/nupic/engine/Link.hpp +++ b/src/nupic/engine/Link.hpp @@ -31,444 +31,436 @@ #include -#include #include // needed for splitter map +#include #include #include #include #include #include -namespace nupic -{ +namespace nupic { + +class Output; +class Input; - class Output; - class Input; +/** + * + * Represents a link between regions in a Network. + * + * @nosubgrouping + * + */ +class Link : public Serializable { +public: + /** + * @name Initialization + * + * @{ + * + * Links have four-phase initialization. + * + * 1. construct with link type, params, names of regions and inputs/outputs + * 2. wire in to network (setting src and dest Output/Input pointers) + * 3. set source and destination dimensions + * 4. initialize -- sets the offset in the destination Input (not known + * earlier) + * + * De-serializing is the same as phase 1. + * + * In phase 3, NuPIC will set and/or get source and/or destination + * dimensions until both are set. Normally we will only set the src + * dimensions, and the dest dimensions will be induced. It is possible to go + * the other way, though. + * + * The @a linkType and @a linkParams parameters are given to + * the LinkPolicyFactory to create a link policy + * + * @todo Should LinkPolicyFactory be documented? + * + */ /** + * Initialization Phase 1: setting parameters of the link. 
+ * + * @param linkType + * The type of the link + * @param linkParams + * The parameters of the link + * @param srcRegionName + * The name of the source Region + * @param destRegionName + * The name of the destination Region + * @param srcOutputName + * The name of the source Output + * @param destInputName + * The name of the destination Input + * @param propagationDelay + * Propagation delay of the link as number of network run + * iterations involving the link as input; the delay vectors, if + * any, are initially populated with 0's. Defaults to 0=no delay. + * Per design, data on no-delay links is to become available to + * destination inputs within the same time step, while data on + * delayed links (propagationDelay > 0) is to be updated + * "atomically" between time steps. * - * Represents a link between regions in a Network. + * @internal * - * @nosubgrouping + * @todo It seems this constructor should be deprecated in favor of the other, + * which is less redundant. This constructor is being used for unit testing + * and unit testing links and for deserializing networks. + * + * See comments below commonConstructorInit_() + * + * @endinternal * */ - class Link : public Serializable - { - public: - - /** - * @name Initialization - * - * @{ - * - * Links have four-phase initialization. - * - * 1. construct with link type, params, names of regions and inputs/outputs - * 2. wire in to network (setting src and dest Output/Input pointers) - * 3. set source and destination dimensions - * 4. initialize -- sets the offset in the destination Input (not known earlier) - * - * De-serializing is the same as phase 1. - * - * In phase 3, NuPIC will set and/or get source and/or destination - * dimensions until both are set. Normally we will only set the src dimensions, - * and the dest dimensions will be induced. It is possible to go the other - * way, though. - * - * The @a linkType and @a linkParams parameters are given to - * the LinkPolicyFactory to create a link policy - * - * @todo Should LinkPolicyFactory be documented? - * - */ - - /** - * Initialization Phase 1: setting parameters of the link. - * - * @param linkType - * The type of the link - * @param linkParams - * The parameters of the link - * @param srcRegionName - * The name of the source Region - * @param destRegionName - * The name of the destination Region - * @param srcOutputName - * The name of the source Output - * @param destInputName - * The name of the destination Input - * @param propagationDelay - * Propagation delay of the link as number of network run - * iterations involving the link as input; the delay vectors, if - * any, are initially populated with 0's. Defaults to 0=no delay. - * Per design, data on no-delay links is to become available to - * destination inputs within the same time step, while data on - * delayed links (propagationDelay > 0) is to be updated - * "atomically" between time steps. - * - * @internal - * - * @todo It seems this constructor should be deprecated in favor of the other, - * which is less redundant. This constructor is being used for unit testing - * and unit testing links and for deserializing networks. - * - * See comments below commonConstructorInit_() - * - * @endinternal - * - */ - Link(const std::string& linkType, const std::string& linkParams, - const std::string& srcRegionName, const std::string& destRegionName, - const std::string& srcOutputName="", - const std::string& destInputName="", - const size_t propagationDelay=0); - - /** - * De-serialization use case. 
Creates a "blank" link. The caller must follow - * up with Link::read and Link::connectToNetwork - * - * @param proto - * LinkProto::Reader - */ - Link(); - - - /** - * Initialization Phase 2: connecting inputs/outputs to - * the Network. - * - * @param src - * The source Output of the link - * @param dest - * The destination Input of the link - */ - void connectToNetwork(Output* src, Input* dest); - - /* - * Initialization Phase 1 and 2. - * - * @param linkType - * The type of the link - * @param linkParams - * The parameters of the link - * @param srcOutput - * The source Output of the link - * @param destInput - * The destination Input of the link - * @param propagationDelay - * Propagation delay of the link as number of network run - * iterations involving the link as input; the delay vectors, if - * any, are initially populated with 0's. Defaults to 0=no delay - */ - Link(const std::string& linkType, const std::string& linkParams, - Output* srcOutput, Input* destInput, size_t propagationDelay=0); - - /** - * Initialization Phase 3: set the Dimensions for the source Output, and - * induce the Dimensions for the destination Input . - * - * - * @param dims - * The Dimensions for the source Output - */ - void setSrcDimensions(Dimensions& dims); - - /** - * Initialization Phase 3: Set the Dimensions for the destination Input, and - * induce the Dimensions for the source Output . - * - * @param dims - * The Dimensions for the destination Input - */ - void setDestDimensions(Dimensions& dims); - - /** - * Initialization Phase 4: sets the offset in the destination Input . - * - * @param destinationOffset - * The offset in the destination Input, i.e. TODO - * - */ - void initialize(size_t destinationOffset); - - /** - * Destructor - */ - ~Link(); - - /** - * @} - * - * @name Parameter getters of the link - * - * @{ - * - */ - - /** - * Get the Dimensions for the source Output . - * - * @returns - * The Dimensions for the source Output - */ - const Dimensions& getSrcDimensions() const; - - /** - * Get the Dimensions for the destination Input . - * - * @returns - * The Dimensions for the destination Input - */ - const Dimensions& getDestDimensions() const; - - /** - * Get the type of the link. - * - * @returns - * The type of the link - */ - const std::string& getLinkType() const; - - /** - * Get the parameters of the link. - * - * @returns - * The parameters of the link - */ - const std::string& getLinkParams() const; - - /** - * Get the name of the source Region - * - * @returns - * The name of the source Region - */ - const std::string& getSrcRegionName() const; - - /** - * Get the name of the source Output. - * - * @returns - * The name of the source Output - */ - const std::string& getSrcOutputName() const; - - /** - * Get the name of the destination Region. - * - * @returns - * The name of the destination Region - * - */ - const std::string& getDestRegionName() const; - - /** - * Get the name of the destination Input. - * - * @returns - * The name of the destination Input - */ - const std::string& getDestInputName() const; - - /** - * @} - * - * @name Misc - * - * @{ - */ - - // The methods below only work on connected links (after phase 2) - - /** - * - * Get a generated name of the link in the form - * RegName.outName --> RegName.inName for debug logging purposes only. - */ - std::string getMoniker() const; - - /** - * - * Get the source Output of the link. 
- * - * @returns - * The source Output of the link - */ - Output& getSrc() const; - - /** - * - * Get the destination Input of the link. - * - * @returns - * The destination Input of the link - */ - Input& getDest() const; - - /** - * Copy data from source to destination. - * - * Nodes request input data from their input objects. The input objects, - * in turn, request links to copy data into the inputs. - * - * @note This method must be called on a fully initialized link(all 4 phases). - * - */ - void - compute(); - - /** - * Build a splitter map from the link. - * - * @param[out] splitter - * The built SplitterMap - * - * A splitter map is a matrix that maps the full input - * of a region to the inputs of individual nodes within - * the region. - * A splitter map "sm" is declared as: - * - * vector< vector > sm; - * - * sm.length() == number of nodes - * - * `sm[i]` is a "sparse vector" used to gather the input - * for node i. `sm[i].size()` is the size (in elements) of - * the input for node i. - * - * `sm[i]` gathers the inputs as follows: - * - * T *regionInput; // input buffer for the whole region - * T *nodeInput; // pre-allocated - * for (size_t elem = 0; elem < sm[i].size; elem++) - * nodeInput[elem] = regionInput[sm[i][elem]]; - * - * The offset specified by `sm[i][j]` is in units of elements. - * To get byte offsets, you'd multiply by the size of an input/output - * element. - * - * An input to a region may come from several links. - * Each link contributes a contiguous block of the region input - * starting from a certain offset. The splitter map indices are - * with respect to the full region input, not the partial region - * input contributed by this link, so the destinationOffset for this - * link is included in each of the splitter map entries. - * - * Finally, the API is designed so that each link associated with - * an input can contribute its portion to a full splitter map. - * Thus the splitter map is an input-output parameter. This method - * appends data to each row of the splitter map, assuming that - * existing data in the splitter map comes from other links. - * - * For region-level inputs, a splitter map has just a single row. - * - * ### Splitter map ownership - * - * The splitter map is owned by the containing Input. Each Link - * in the input contributes a portion to the splitter map, through - * the buildSplitterMap method. - * - */ - void - buildSplitterMap(Input::SplitterMap& splitter); - - /* - * No-op for links without delay; for delayed links, remove head element of - * the propagation delay buffer and push back the current value from source. - * - * NOTE It's intended that this method be called exactly once on all links - * within a network at the end of every time step. Network::run calls it - * automatically on all links at the end of each time step. - */ - void shiftBufferedData(); - - /** - * Convert the Link to a human-readable string. - * - * @returns - * The human-readable string describing the Link - */ - const std::string toString() const; - - /** - * Serialize the link. - * - * @param f - * The output stream being serialized to - * @param link - * The Link being serialized - */ - friend std::ostream& operator<<(std::ostream& f, const Link& link); - - using Serializable::write; - void write(LinkProto::Builder& proto) const; - - using Serializable::read; - void read(LinkProto::Reader& proto); - - private: - // common initialization for the two constructors. 
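The shiftBufferedData() contract documented above (compute() reads the head of the delay buffer during a time step; the buffer is rotated exactly once between time steps) can be pictured with a small self-contained sketch. It uses std::deque<std::vector<float>> as a stand-in for the boost::circular_buffer<Array> member, so it only illustrates the rotation and is not the library's implementation:

#include <cstddef>
#include <deque>
#include <vector>

// Stand-in for a delayed link's source buffer (propagationDelay > 0);
// 0-delay links bypass the buffer and read the source output directly.
struct DelayBufferSketch {
  DelayBufferSketch(std::size_t delay, std::size_t elementCount)
      : slots(delay, std::vector<float>(elementCount, 0.0f)) {}

  // What compute() copies into the destination input this time step.
  const std::vector<float> &head() const { return slots.front(); }

  // What shiftBufferedData() does once per time step: pop the head and
  // append a copy of the current source output at the tail.
  void shift(const std::vector<float> &currentSrcOutput) {
    slots.pop_front();
    slots.push_back(currentSrcOutput);
  }

  std::deque<std::vector<float>> slots;
};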
- void commonConstructorInit_(const std::string& linkType, - const std::string& linkParams, - const std::string& srcRegionName, - const std::string& destRegionName, - const std::string& srcOutputName, - const std::string& destInputName, - const size_t propagationDelay); - - void initPropagationDelayBuffer_(size_t propagationDelay, - NTA_BasicType dataElementType, - size_t dataElementCount); - - // TODO: The strings with src/dest names are redundant with - // the src_ and dest_ objects. For unit testing links, - // and for deserializing networks, we need to be able to create - // a link object without a network. and for deserializing, we - // need to be able to instantiate a link before we have instantiated - // all the regions. (Maybe this isn't true? Re-evaluate when - // more infrastructure is in place). - - std::string srcRegionName_; - std::string destRegionName_; - std::string srcOutputName_; - std::string destInputName_; - - // We store the values given to use. Use these for - // serialization instead of serializing the LinkPolicy - // itself. - std::string linkType_; - std::string linkParams_; - - LinkPolicy *impl_; - - Output *src_; - Input *dest_; - - // Each link contributes a contiguous chunk of the destination - // input. The link needs to know its offset within the destination - // input. This value is set at initialization time. - size_t destOffset_; - - // TODO: These are currently unused. Situations where we need them - // are rare. Would they make more sense as link policy params? - // Will also need a link getDestinationSize method since - // the amount of data contributed by this link to the destination input - // may not equal the size of the source output. - size_t srcOffset_; - size_t srcSize_; - - // Circular buffer for delayed source data buffering - boost::circular_buffer srcBuffer_; - // Number of delay slots - size_t propagationDelay_; - - // link must be initialized before it can compute() - bool initialized_; - - }; + Link(const std::string &linkType, const std::string &linkParams, + const std::string &srcRegionName, const std::string &destRegionName, + const std::string &srcOutputName = "", + const std::string &destInputName = "", + const size_t propagationDelay = 0); + /** + * De-serialization use case. Creates a "blank" link. The caller must follow + * up with Link::read and Link::connectToNetwork + * + * @param proto + * LinkProto::Reader + */ + Link(); -} // namespace nupic + /** + * Initialization Phase 2: connecting inputs/outputs to + * the Network. + * + * @param src + * The source Output of the link + * @param dest + * The destination Input of the link + */ + void connectToNetwork(Output *src, Input *dest); + /* + * Initialization Phase 1 and 2. + * + * @param linkType + * The type of the link + * @param linkParams + * The parameters of the link + * @param srcOutput + * The source Output of the link + * @param destInput + * The destination Input of the link + * @param propagationDelay + * Propagation delay of the link as number of network run + * iterations involving the link as input; the delay vectors, if + * any, are initially populated with 0's. Defaults to 0=no delay + */ + Link(const std::string &linkType, const std::string &linkParams, + Output *srcOutput, Input *destInput, size_t propagationDelay = 0); + + /** + * Initialization Phase 3: set the Dimensions for the source Output, and + * induce the Dimensions for the destination Input . 
+ * + * + * @param dims + * The Dimensions for the source Output + */ + void setSrcDimensions(Dimensions &dims); + + /** + * Initialization Phase 3: Set the Dimensions for the destination Input, and + * induce the Dimensions for the source Output . + * + * @param dims + * The Dimensions for the destination Input + */ + void setDestDimensions(Dimensions &dims); + + /** + * Initialization Phase 4: sets the offset in the destination Input . + * + * @param destinationOffset + * The offset in the destination Input, i.e. TODO + * + */ + void initialize(size_t destinationOffset); + + /** + * Destructor + */ + ~Link(); + + /** + * @} + * + * @name Parameter getters of the link + * + * @{ + * + */ + + /** + * Get the Dimensions for the source Output . + * + * @returns + * The Dimensions for the source Output + */ + const Dimensions &getSrcDimensions() const; + + /** + * Get the Dimensions for the destination Input . + * + * @returns + * The Dimensions for the destination Input + */ + const Dimensions &getDestDimensions() const; + + /** + * Get the type of the link. + * + * @returns + * The type of the link + */ + const std::string &getLinkType() const; + + /** + * Get the parameters of the link. + * + * @returns + * The parameters of the link + */ + const std::string &getLinkParams() const; + + /** + * Get the name of the source Region + * + * @returns + * The name of the source Region + */ + const std::string &getSrcRegionName() const; + + /** + * Get the name of the source Output. + * + * @returns + * The name of the source Output + */ + const std::string &getSrcOutputName() const; + + /** + * Get the name of the destination Region. + * + * @returns + * The name of the destination Region + * + */ + const std::string &getDestRegionName() const; + + /** + * Get the name of the destination Input. + * + * @returns + * The name of the destination Input + */ + const std::string &getDestInputName() const; + + /** + * @} + * + * @name Misc + * + * @{ + */ + + // The methods below only work on connected links (after phase 2) + + /** + * + * Get a generated name of the link in the form + * RegName.outName --> RegName.inName for debug logging purposes only. + */ + std::string getMoniker() const; + + /** + * + * Get the source Output of the link. + * + * @returns + * The source Output of the link + */ + Output &getSrc() const; + + /** + * + * Get the destination Input of the link. + * + * @returns + * The destination Input of the link + */ + Input &getDest() const; + + /** + * Copy data from source to destination. + * + * Nodes request input data from their input objects. The input objects, + * in turn, request links to copy data into the inputs. + * + * @note This method must be called on a fully initialized link(all 4 phases). + * + */ + void compute(); + + /** + * Build a splitter map from the link. + * + * @param[out] splitter + * The built SplitterMap + * + * A splitter map is a matrix that maps the full input + * of a region to the inputs of individual nodes within + * the region. + * A splitter map "sm" is declared as: + * + * vector< vector > sm; + * + * sm.length() == number of nodes + * + * `sm[i]` is a "sparse vector" used to gather the input + * for node i. `sm[i].size()` is the size (in elements) of + * the input for node i. 
+ * + * `sm[i]` gathers the inputs as follows: + * + * T *regionInput; // input buffer for the whole region + * T *nodeInput; // pre-allocated + * for (size_t elem = 0; elem < sm[i].size; elem++) + * nodeInput[elem] = regionInput[sm[i][elem]]; + * + * The offset specified by `sm[i][j]` is in units of elements. + * To get byte offsets, you'd multiply by the size of an input/output + * element. + * + * An input to a region may come from several links. + * Each link contributes a contiguous block of the region input + * starting from a certain offset. The splitter map indices are + * with respect to the full region input, not the partial region + * input contributed by this link, so the destinationOffset for this + * link is included in each of the splitter map entries. + * + * Finally, the API is designed so that each link associated with + * an input can contribute its portion to a full splitter map. + * Thus the splitter map is an input-output parameter. This method + * appends data to each row of the splitter map, assuming that + * existing data in the splitter map comes from other links. + * + * For region-level inputs, a splitter map has just a single row. + * + * ### Splitter map ownership + * + * The splitter map is owned by the containing Input. Each Link + * in the input contributes a portion to the splitter map, through + * the buildSplitterMap method. + * + */ + void buildSplitterMap(Input::SplitterMap &splitter); + + /* + * No-op for links without delay; for delayed links, remove head element of + * the propagation delay buffer and push back the current value from source. + * + * NOTE It's intended that this method be called exactly once on all links + * within a network at the end of every time step. Network::run calls it + * automatically on all links at the end of each time step. + */ + void shiftBufferedData(); + + /** + * Convert the Link to a human-readable string. + * + * @returns + * The human-readable string describing the Link + */ + const std::string toString() const; + + /** + * Serialize the link. + * + * @param f + * The output stream being serialized to + * @param link + * The Link being serialized + */ + friend std::ostream &operator<<(std::ostream &f, const Link &link); + + using Serializable::write; + void write(LinkProto::Builder &proto) const; + + using Serializable::read; + void read(LinkProto::Reader &proto); + +private: + // common initialization for the two constructors. + void commonConstructorInit_(const std::string &linkType, + const std::string &linkParams, + const std::string &srcRegionName, + const std::string &destRegionName, + const std::string &srcOutputName, + const std::string &destInputName, + const size_t propagationDelay); + + void initPropagationDelayBuffer_(size_t propagationDelay, + NTA_BasicType dataElementType, + size_t dataElementCount); + + // TODO: The strings with src/dest names are redundant with + // the src_ and dest_ objects. For unit testing links, + // and for deserializing networks, we need to be able to create + // a link object without a network. and for deserializing, we + // need to be able to instantiate a link before we have instantiated + // all the regions. (Maybe this isn't true? Re-evaluate when + // more infrastructure is in place). + + std::string srcRegionName_; + std::string destRegionName_; + std::string srcOutputName_; + std::string destInputName_; + + // We store the values given to use. Use these for + // serialization instead of serializing the LinkPolicy + // itself. 
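The gather loop in the buildSplitterMap() documentation above compiles essentially as written once the types are pinned down. A minimal sketch, assuming the splitter map is a vector of per-node element-offset vectors (the concrete Input::SplitterMap typedef is not shown in this diff):

#include <cstddef>
#include <vector>

// sm[node][j] is the offset, in elements, into the full region input that
// supplies element j of that node's input.
using SplitterMapSketch = std::vector<std::vector<std::size_t>>;

template <typename T>
std::vector<T> gatherNodeInput(const std::vector<T> &regionInput,
                               const SplitterMapSketch &sm, std::size_t node) {
  std::vector<T> nodeInput(sm[node].size());
  for (std::size_t elem = 0; elem < sm[node].size(); ++elem)
    nodeInput[elem] = regionInput[sm[node][elem]];
  return nodeInput;
}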
+ std::string linkType_; + std::string linkParams_; + + LinkPolicy *impl_; + + Output *src_; + Input *dest_; + + // Each link contributes a contiguous chunk of the destination + // input. The link needs to know its offset within the destination + // input. This value is set at initialization time. + size_t destOffset_; + + // TODO: These are currently unused. Situations where we need them + // are rare. Would they make more sense as link policy params? + // Will also need a link getDestinationSize method since + // the amount of data contributed by this link to the destination input + // may not equal the size of the source output. + size_t srcOffset_; + size_t srcSize_; + + // Circular buffer for delayed source data buffering + boost::circular_buffer srcBuffer_; + // Number of delay slots + size_t propagationDelay_; + + // link must be initialized before it can compute() + bool initialized_; +}; + +} // namespace nupic #endif // NTA_LINK_HPP diff --git a/src/nupic/engine/LinkPolicy.hpp b/src/nupic/engine/LinkPolicy.hpp index a6f710b4c1..fe2e7ffd14 100644 --- a/src/nupic/engine/LinkPolicy.hpp +++ b/src/nupic/engine/LinkPolicy.hpp @@ -27,43 +27,38 @@ #ifndef NTA_LINKPOLICY_HPP #define NTA_LINKPOLICY_HPP -#include #include // SplitterMap definition +#include // LinkPolicy is an interface class subclassed by all link policies -namespace nupic -{ - - - class Dimensions; - - class LinkPolicy - { - // Subclasses implement this constructor: - // LinkPolicy(const std::string params, const Dimensions& srcDimensions, - // const Dimensions& destDimensions); +namespace nupic { - public: - virtual ~LinkPolicy() {}; - virtual void setSrcDimensions(Dimensions& dims) = 0; - virtual void setDestDimensions(Dimensions& dims) = 0; - virtual const Dimensions& getSrcDimensions() const = 0; - virtual const Dimensions& getDestDimensions() const = 0; - // initialization is probably unnecessary, but it lets - // us do a sanity check before generating the splitter map. - virtual void initialize() = 0; - virtual bool isInitialized() const = 0; - virtual void setNodeOutputElementCount(size_t elementCount) = 0; +class Dimensions; +class LinkPolicy { + // Subclasses implement this constructor: + // LinkPolicy(const std::string params, const Dimensions& srcDimensions, + // const Dimensions& destDimensions); - // A "protoSplitterMap" specifies which source output nodes send - // data to which dest input nodes. - // if protoSplitter[destNode][x] == srcNode for some x, then - // srcNode sends its output to destNode. - // - virtual void buildProtoSplitterMap(Input::SplitterMap& splitter) const = 0; - }; +public: + virtual ~LinkPolicy(){}; + virtual void setSrcDimensions(Dimensions &dims) = 0; + virtual void setDestDimensions(Dimensions &dims) = 0; + virtual const Dimensions &getSrcDimensions() const = 0; + virtual const Dimensions &getDestDimensions() const = 0; + // initialization is probably unnecessary, but it lets + // us do a sanity check before generating the splitter map. + virtual void initialize() = 0; + virtual bool isInitialized() const = 0; + virtual void setNodeOutputElementCount(size_t elementCount) = 0; + // A "protoSplitterMap" specifies which source output nodes send + // data to which dest input nodes. + // if protoSplitter[destNode][x] == srcNode for some x, then + // srcNode sends its output to destNode. 
+ // + virtual void buildProtoSplitterMap(Input::SplitterMap &splitter) const = 0; +}; } // namespace nupic diff --git a/src/nupic/engine/LinkPolicyFactory.cpp b/src/nupic/engine/LinkPolicyFactory.cpp index 319dae9613..e0ffb95097 100644 --- a/src/nupic/engine/LinkPolicyFactory.cpp +++ b/src/nupic/engine/LinkPolicyFactory.cpp @@ -20,30 +20,23 @@ * --------------------------------------------------------------------- */ - #include #include #include #include #include -namespace nupic -{ - +namespace nupic { -LinkPolicy* LinkPolicyFactory::createLinkPolicy(const std::string policyType, - const std::string policyParams, - Link* link) -{ +LinkPolicy *LinkPolicyFactory::createLinkPolicy(const std::string policyType, + const std::string policyParams, + Link *link) { LinkPolicy *lp = nullptr; - if (policyType == "TestFanIn2") - { + if (policyType == "TestFanIn2") { lp = new TestFanIn2LinkPolicy(policyParams, link); - } else if (policyType == "UniformLink") - { + } else if (policyType == "UniformLink") { lp = new UniformLinkPolicy(policyParams, link); - } else if (policyType == "UnitTestLink") - { + } else if (policyType == "UnitTestLink") { // When unit testing a link policy, a valid Link* is required to be passed // to the link policy's constructor. If you pass NULL, other portions of // NuPIC may try to dereference it (e.g. operator<< from NTA_THROW). So we @@ -54,11 +47,9 @@ LinkPolicy* LinkPolicyFactory::createLinkPolicy(const std::string policyType, // // and pass this dummy link to the constructor of the real link policy // you wish to unit test. - } else if (policyType == "TestSplit") - { + } else if (policyType == "TestSplit") { NTA_THROW << "TestSplit not implemented yet"; - } else if (policyType == "TestOneToOne") - { + } else if (policyType == "TestOneToOne") { NTA_THROW << "TestOneToOne not implemented yet"; } else { NTA_THROW << "Unknown link policy '" << policyType << "'"; @@ -66,7 +57,4 @@ LinkPolicy* LinkPolicyFactory::createLinkPolicy(const std::string policyType, return lp; } - - -} - +} // namespace nupic diff --git a/src/nupic/engine/LinkPolicyFactory.hpp b/src/nupic/engine/LinkPolicyFactory.hpp index c60ac0c71f..d9cbb8b694 100644 --- a/src/nupic/engine/LinkPolicyFactory.hpp +++ b/src/nupic/engine/LinkPolicyFactory.hpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definition of the LinkPolicyFactory API */ @@ -29,32 +29,25 @@ #include -namespace nupic -{ +namespace nupic { - class LinkPolicy; - class Link; - class Region; +class LinkPolicy; +class Link; +class Region; - class LinkPolicyFactory - { - public: +class LinkPolicyFactory { +public: + // LinkPolicyFactory is a lightweight object + LinkPolicyFactory(){}; + ~LinkPolicyFactory(){}; + // Create a LinkPolicy of a specific type; caller gets ownership. + LinkPolicy *createLinkPolicy(const std::string policyType, + const std::string policyParams, Link *link); - // LinkPolicyFactory is a lightweight object - LinkPolicyFactory() {}; - ~LinkPolicyFactory() {}; - - // Create a LinkPolicy of a specific type; caller gets ownership. 
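LinkPolicyFactory::createLinkPolicy() above dispatches on the policy-type string and returns a raw pointer whose ownership passes to the caller. A simplified, self-contained analogue of that pattern with the ownership made explicit through std::unique_ptr; the policy classes and type strings below are placeholders, not nupic's real link policies:

#include <memory>
#include <stdexcept>
#include <string>

struct PolicySketch {
  virtual ~PolicySketch() {}
};
struct UniformPolicySketch : PolicySketch {};
struct FanIn2PolicySketch : PolicySketch {};

// String-dispatched creation; the returned object is owned by the caller.
std::unique_ptr<PolicySketch> createPolicySketch(const std::string &type) {
  if (type == "UniformLink")
    return std::unique_ptr<PolicySketch>(new UniformPolicySketch());
  if (type == "TestFanIn2")
    return std::unique_ptr<PolicySketch>(new FanIn2PolicySketch());
  throw std::runtime_error("Unknown link policy '" + type + "'");
}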
- LinkPolicy* createLinkPolicy(const std::string policyType, - const std::string policyParams, - Link* link); - - private: - - }; +private: +}; } // namespace nupic - #endif // NTA_LINKPOLICY_FACTORY_HPP diff --git a/src/nupic/engine/Network.cpp b/src/nupic/engine/Network.cpp index 36e6babef7..08b79cf1ee 100644 --- a/src/nupic/engine/Network.cpp +++ b/src/nupic/engine/Network.cpp @@ -24,48 +24,43 @@ Implementation of the Network class */ -#include #include +#include #include #include +#include +#include #include +#include // for register/unregister #include #include -#include -#include +#include +#include +#include +#include #include #include #include #include -#include // for register/unregister -#include -#include -#include -#include #include - -namespace nupic -{ +namespace nupic { class GenericRegisteredRegionImpl; -Network::Network() -{ +Network::Network() { commonInit(); NuPIC::registerNetwork(this); } -Network::Network(const std::string& path) -{ +Network::Network(const std::string &path) { commonInit(); load(path); NuPIC::registerNetwork(this); } -void Network::commonInit() -{ +void Network::commonInit() { initialized_ = false; iteration_ = 0; minEnabledPhase_ = 0; @@ -75,46 +70,38 @@ void Network::commonInit() NuPIC::init(); } - - -Network::~Network() -{ +Network::~Network() { NuPIC::unregisterNetwork(this); /** * Teardown choreography: - * - unitialize all regions because otherwise we won't be able to disconnect them + * - unitialize all regions because otherwise we won't be able to disconnect + * them * - remove all links, because we can't delete connected regions * - delete the regions themselves. */ // 1. uninitialize - for (size_t i = 0; i < regions_.getCount(); i++) - { + for (size_t i = 0; i < regions_.getCount(); i++) { Region *r = regions_.getByIndex(i).second; r->uninitialize(); } // 2. remove all links - for (size_t i = 0; i < regions_.getCount(); i++) - { + for (size_t i = 0; i < regions_.getCount(); i++) { Region *r = regions_.getByIndex(i).second; r->removeAllIncomingLinks(); } // 3. delete the regions - for (size_t i = 0; i < regions_.getCount(); i++) - { - std::pair& item = regions_.getByIndex(i); + for (size_t i = 0; i < regions_.getCount(); i++) { + std::pair &item = regions_.getByIndex(i); delete item.second; item.second = nullptr; } } - -Region* Network::addRegion(const std::string& name, - const std::string& nodeType, - const std::string& nodeParams) -{ +Region *Network::addRegion(const std::string &name, const std::string &nodeType, + const std::string &nodeParams) { if (regions_.contains(name)) NTA_THROW << "Region with name '" << name << "' already exists in network"; @@ -126,31 +113,27 @@ Region* Network::addRegion(const std::string& name, return r; } -void Network::setDefaultPhase_(Region* region) -{ +void Network::setDefaultPhase_(Region *region) { UInt32 newphase = phaseInfo_.size(); std::set phases; phases.insert(newphase); setPhases_(region, phases); } - - - -Region* Network::addRegionFromBundle(const std::string& name, - const std::string& nodeType, - const Dimensions& dimensions, - const std::string& bundlePath, - const std::string& label) -{ +Region *Network::addRegionFromBundle(const std::string &name, + const std::string &nodeType, + const Dimensions &dimensions, + const std::string &bundlePath, + const std::string &label) { if (regions_.contains(name)) - NTA_THROW << "Invalid saved network: two or more instance of region '" << name << "'"; + NTA_THROW << "Invalid saved network: two or more instance of region '" + << name << "'"; - if (! 
Path::exists(bundlePath)) + if (!Path::exists(bundlePath)) NTA_THROW << "addRegionFromBundle -- bundle '" << bundlePath << " does not exist"; - BundleIO bundle(bundlePath, label, name, /* isInput: */ true ); + BundleIO bundle(bundlePath, label, name, /* isInput: */ true); auto r = new Region(name, nodeType, dimensions, bundle, this); regions_.add(name, r); initialized_ = false; @@ -163,12 +146,9 @@ Region* Network::addRegionFromBundle(const std::string& name, return r; } - -Region* Network::addRegionFromProto(const std::string& name, - RegionProto::Reader& proto) -{ - if (regions_.contains(name)) - { +Region *Network::addRegionFromProto(const std::string &name, + RegionProto::Reader &proto) { + if (regions_.contains(name)) { NTA_THROW << "Cannot add region with name '" << name << "' that is already in used."; } @@ -185,17 +165,13 @@ Region* Network::addRegionFromProto(const std::string& name, return region; } - -void -Network::setPhases_(Region *r, std::set& phases) -{ +void Network::setPhases_(Region *r, std::set &phases) { if (phases.empty()) NTA_THROW << "Attempt to set empty phase list for region " << r->getName(); UInt32 maxNewPhase = *(phases.rbegin()); UInt32 nextPhase = phaseInfo_.size(); - if (maxNewPhase >= nextPhase) - { + if (maxNewPhase >= nextPhase) { // It is very unlikely that someone would add a region // with a phase much greater than the phase of any other // region. This sanity check catches such problems, @@ -205,82 +181,68 @@ Network::setPhases_(Region *r, std::set& phases) << " when expected next phase is " << nextPhase << " -- this is probably an error."; - phaseInfo_.resize(maxNewPhase+1); + phaseInfo_.resize(maxNewPhase + 1); } - for (UInt i = 0; i < phaseInfo_.size(); i++) - { + for (UInt i = 0; i < phaseInfo_.size(); i++) { bool insertPhase = false; if (phases.find(i) != phases.end()) insertPhase = true; // remove previous settings for this region - std::set::iterator item; + std::set::iterator item; item = phaseInfo_[i].find(r); - if (item != phaseInfo_[i].end() && !insertPhase) - { + if (item != phaseInfo_[i].end() && !insertPhase) { phaseInfo_[i].erase(item); - } else if (insertPhase) - { + } else if (insertPhase) { phaseInfo_[i].insert(r); } } - // keep track (redundantly) of phases inside the Region also, for serialization + // keep track (redundantly) of phases inside the Region also, for + // serialization r->setPhases(phases); resetEnabledPhases_(); - } -void -Network::resetEnabledPhases_() -{ +void Network::resetEnabledPhases_() { // min/max enabled phases based on what is in the network minEnabledPhase_ = getMinPhase(); maxEnabledPhase_ = getMaxPhase(); } - -void -Network::setPhases(const std::string& name, std::set& phases) -{ - if (! regions_.contains(name)) +void Network::setPhases(const std::string &name, std::set &phases) { + if (!regions_.contains(name)) NTA_THROW << "setPhases -- no region exists with name '" << name << "'"; Region *r = regions_.getByName(name); setPhases_(r, phases); } -std::set -Network::getPhases(const std::string& name) const -{ - if (! 
regions_.contains(name)) +std::set Network::getPhases(const std::string &name) const { + if (!regions_.contains(name)) NTA_THROW << "setPhases -- no region exists with name '" << name << "'"; Region *r = regions_.getByName(name); std::set phases; // construct the set of phases enabled for this region - for (UInt32 i = 0; i < phaseInfo_.size(); i++) - { - if (phaseInfo_[i].find(r) != phaseInfo_[i].end()) - { + for (UInt32 i = 0; i < phaseInfo_.size(); i++) { + if (phaseInfo_[i].find(r) != phaseInfo_[i].end()) { phases.insert(i); } } return phases; } - -void -Network::removeRegion(const std::string& name) -{ - if (! regions_.contains(name)) +void Network::removeRegion(const std::string &name) { + if (!regions_.contains(name)) NTA_THROW << "removeRegion: no region named '" << name << "'"; Region *r = regions_.getByName(name); if (r->hasOutgoingLinks()) - NTA_THROW << "Unable to remove region '" << name << "' because it has one or more outgoing links"; + NTA_THROW << "Unable to remove region '" << name + << "' because it has one or more outgoing links"; // Network does not have to be uninitialized -- removing a region // has no effect on the network as long as it has no outgoing links, @@ -292,16 +254,14 @@ Network::removeRegion(const std::string& name) regions_.remove(name); auto phase = phaseInfo_.begin(); - for (; phase != phaseInfo_.end(); phase++) - { + for (; phase != phaseInfo_.end(); phase++) { auto toremove = phase->find(r); if (toremove != phase->end()) phase->erase(toremove); } // Trim phaseinfo as we may have no more regions at the highest phase(s) - for (size_t i = phaseInfo_.size() - 1; i > 0; i--) - { + for (size_t i = phaseInfo_.size() - 1; i > 0; i--) { if (phaseInfo_[i].empty()) phaseInfo_.resize(i); else @@ -315,35 +275,35 @@ Network::removeRegion(const std::string& name) return; } - -void -Network::link(const std::string& srcRegionName, const std::string& destRegionName, - const std::string& linkType, const std::string& linkParams, - const std::string& srcOutputName, const std::string& destInputName, - const size_t propagationDelay) -{ +void Network::link(const std::string &srcRegionName, + const std::string &destRegionName, + const std::string &linkType, const std::string &linkParams, + const std::string &srcOutputName, + const std::string &destInputName, + const size_t propagationDelay) { // Find the regions - if (! regions_.contains(srcRegionName)) - NTA_THROW << "Network::link -- source region '" << srcRegionName << "' does not exist"; - Region* srcRegion = regions_.getByName(srcRegionName); + if (!regions_.contains(srcRegionName)) + NTA_THROW << "Network::link -- source region '" << srcRegionName + << "' does not exist"; + Region *srcRegion = regions_.getByName(srcRegionName); - if (! 
regions_.contains(destRegionName)) - NTA_THROW << "Network::link -- dest region '" << destRegionName << "' does not exist"; - Region* destRegion = regions_.getByName(destRegionName); + if (!regions_.contains(destRegionName)) + NTA_THROW << "Network::link -- dest region '" << destRegionName + << "' does not exist"; + Region *destRegion = regions_.getByName(destRegionName); // Find the inputs/outputs - const Spec* srcSpec = srcRegion->getSpec(); + const Spec *srcSpec = srcRegion->getSpec(); std::string outputName = srcOutputName; if (outputName == "") outputName = srcSpec->getDefaultOutputName(); - Output* srcOutput = srcRegion->getOutput(outputName); + Output *srcOutput = srcRegion->getOutput(outputName); if (srcOutput == nullptr) NTA_THROW << "Network::link -- output " << outputName << " does not exist on region " << srcRegionName; - const Spec *destSpec = destRegion->getSpec(); std::string inputName; if (destInputName == "") @@ -351,33 +311,32 @@ Network::link(const std::string& srcRegionName, const std::string& destRegionNam else inputName = destInputName; - Input* destInput = destRegion->getInput(inputName); - if (destInput == nullptr) - { + Input *destInput = destRegion->getInput(inputName); + if (destInput == nullptr) { NTA_THROW << "Network::link -- input '" << inputName << " does not exist on region " << destRegionName; } // Create the link itself - auto link = new Link(linkType, linkParams, srcOutput, destInput, - propagationDelay); + auto link = + new Link(linkType, linkParams, srcOutput, destInput, propagationDelay); destInput->addLink(link, srcOutput); - } - -void -Network::removeLink(const std::string& srcRegionName, const std::string& destRegionName, - const std::string& srcOutputName, const std::string& destInputName) -{ +void Network::removeLink(const std::string &srcRegionName, + const std::string &destRegionName, + const std::string &srcOutputName, + const std::string &destInputName) { // Find the regions - if (! regions_.contains(srcRegionName)) - NTA_THROW << "Network::unlink -- source region '" << srcRegionName << "' does not exist"; - Region* srcRegion = regions_.getByName(srcRegionName); + if (!regions_.contains(srcRegionName)) + NTA_THROW << "Network::unlink -- source region '" << srcRegionName + << "' does not exist"; + Region *srcRegion = regions_.getByName(srcRegionName); - if (! 
regions_.contains(destRegionName)) - NTA_THROW << "Network::unlink -- dest region '" << destRegionName << "' does not exist"; - Region* destRegion = regions_.getByName(destRegionName); + if (!regions_.contains(destRegionName)) + NTA_THROW << "Network::unlink -- dest region '" << destRegionName + << "' does not exist"; + Region *destRegion = regions_.getByName(destRegionName); // Find the inputs const Spec *srcSpec = srcRegion->getSpec(); @@ -388,9 +347,8 @@ Network::removeLink(const std::string& srcRegionName, const std::string& destReg else inputName = destInputName; - Input* destInput = destRegion->getInput(inputName); - if (destInput == nullptr) - { + Input *destInput = destRegion->getInput(inputName); + if (destInput == nullptr) { NTA_THROW << "Network::unlink -- input '" << inputName << " does not exist on region " << destRegionName; } @@ -398,62 +356,52 @@ Network::removeLink(const std::string& srcRegionName, const std::string& destReg std::string outputName = srcOutputName; if (outputName == "") outputName = srcSpec->getDefaultOutputName(); - Link* link = destInput->findLink(srcRegionName, outputName); + Link *link = destInput->findLink(srcRegionName, outputName); if (link == nullptr) - NTA_THROW << "Network::unlink -- no link exists from region " << srcRegionName - << " output " << outputName << " to region " << destRegionName - << " input " << destInput->getName(); + NTA_THROW << "Network::unlink -- no link exists from region " + << srcRegionName << " output " << outputName << " to region " + << destRegionName << " input " << destInput->getName(); // Finally, remove the link destInput->removeLink(link); - } -void -Network::run(int n) -{ - if (!initialized_) - { +void Network::run(int n) { + if (!initialized_) { initialize(); } if (phaseInfo_.empty()) return; - NTA_CHECK(maxEnabledPhase_ < phaseInfo_.size()) << "maxphase: " << maxEnabledPhase_ << " size: " << phaseInfo_.size(); + NTA_CHECK(maxEnabledPhase_ < phaseInfo_.size()) + << "maxphase: " << maxEnabledPhase_ << " size: " << phaseInfo_.size(); - for(int iter = 0; iter < n; iter++) - { + for (int iter = 0; iter < n; iter++) { iteration_++; // compute on all enabled regions in phase order - for (UInt32 phase = minEnabledPhase_; phase <= maxEnabledPhase_; phase++) - { - for (auto r : phaseInfo_[phase]) - { + for (UInt32 phase = minEnabledPhase_; phase <= maxEnabledPhase_; phase++) { + for (auto r : phaseInfo_[phase]) { r->prepareInputs(); r->compute(); } } // invoke callbacks - for (UInt32 i = 0; i < callbacks_.getCount(); i++) - { - std::pair& callback = callbacks_.getByIndex(i); + for (UInt32 i = 0; i < callbacks_.getCount(); i++) { + std::pair &callback = callbacks_.getByIndex(i); callback.second.first(this, iteration_, callback.second.second); } // Refresh all links in the network at the end of every timestamp so that // data in delayed links appears to change atomically between iterations - for (size_t i = 0; i < regions_.getCount(); i++) - { + for (size_t i = 0; i < regions_.getCount(); i++) { const Region *r = regions_.getByIndex(i).second; - for (const auto & inputTuple: r->getInputs()) - { - for (const auto pLink: inputTuple.second->getLinks()) - { + for (const auto &inputTuple : r->getInputs()) { + for (const auto pLink : inputTuple.second->getLinks()) { pLink->shiftBufferedData(); } } @@ -464,10 +412,7 @@ Network::run(int n) return; } - -void -Network::initialize() -{ +void Network::initialize() { /* * Do not reinitialize if already initialized. @@ -485,7 +430,6 @@ Network::initialize() * region dimensions. 
*/ - // Iterate until all regions have finished // evaluating their links. If network is // incompletely specified, we'll never finish, @@ -495,14 +439,12 @@ Network::initialize() size_t nLinksRemainingPrev = std::numeric_limits::max(); size_t nLinksRemaining = nLinksRemainingPrev - 1; - std::vector::iterator r; - while(nLinksRemaining > 0 && nLinksRemainingPrev > nLinksRemaining) - { + std::vector::iterator r; + while (nLinksRemaining > 0 && nLinksRemainingPrev > nLinksRemaining) { nLinksRemainingPrev = nLinksRemaining; nLinksRemaining = 0; - for (size_t i = 0; i < regions_.getCount(); i++) - { + for (size_t i = 0; i < regions_.getCount(); i++) { // evaluateLinks returns the number // of links which still need to be // evaluated. @@ -511,15 +453,13 @@ Network::initialize() } } - if (nLinksRemaining > 0) - { + if (nLinksRemaining > 0) { // Try to give complete information to the user std::stringstream ss; ss << "Network::initialize() -- unable to evaluate all links\n" << "The following links could not be evaluated:\n"; - for (size_t i = 0; i < regions_.getCount(); i++) - { - Region*r = regions_.getByIndex(i).second; + for (size_t i = 0; i < regions_.getCount(); i++) { + Region *r = regions_.getByIndex(i).second; std::string errors = r->getLinkErrors(); if (errors.size() == 0) continue; @@ -528,34 +468,28 @@ Network::initialize() NTA_THROW << ss.str(); } - // Make sure all regions now have dimensions - for (size_t i = 0; i < regions_.getCount(); i++) - { - Region* r = regions_.getByIndex(i).second; - const Dimensions& d = r->getDimensions(); - if (d.isUnspecified()) - { + for (size_t i = 0; i < regions_.getCount(); i++) { + Region *r = regions_.getByIndex(i).second; + const Dimensions &d = r->getDimensions(); + if (d.isUnspecified()) { NTA_THROW << "Network::initialize() -- unable to complete initialization " << "because region '" << r->getName() << "' has unspecified " << "dimensions. You must either specify dimensions directly or " - << "link to the region in a way that induces dimensions on the region."; + << "link to the region in a way that induces dimensions on the " + "region."; } - if (!d.isValid()) - { - NTA_THROW << "Network::initialize() -- invalid dimensions " << d.toString() - << " for Region " << r->getName(); + if (!d.isValid()) { + NTA_THROW << "Network::initialize() -- invalid dimensions " + << d.toString() << " for Region " << r->getName(); } - } - /* * 2. initialize outputs: * - . Delegated to regions */ - for (size_t i = 0; i < regions_.getCount(); i++) - { + for (size_t i = 0; i < regions_.getCount(); i++) { Region *r = regions_.getByIndex(i).second; r->initOutputs(); } @@ -564,8 +498,7 @@ Network::initialize() * 3. initialize inputs * - Delegated to regions */ - for (size_t i = 0; i < regions_.getCount(); i++) - { + for (size_t i = 0; i < regions_.getCount(); i++) { Region *r = regions_.getByIndex(i).second; r->initInputs(); } @@ -573,8 +506,7 @@ Network::initialize() /* * 4. initialize region/impl */ - for (size_t i = 0; i < regions_.getCount(); i++) - { + for (size_t i = 0; i < regions_.getCount(); i++) { Region *r = regions_.getByIndex(i).second; r->initialize(); } @@ -584,35 +516,21 @@ Network::initialize() */ resetEnabledPhases_(); - /* * Mark network as initialized. 
*/ initialized_ = true; - -} - - -const Collection& -Network::getRegions() const -{ - return regions_; } +const Collection &Network::getRegions() const { return regions_; } -Collection -Network::getLinks() -{ +Collection Network::getLinks() { Collection links; - for (UInt32 phase = minEnabledPhase_; phase <= maxEnabledPhase_; phase++) - { - for (auto r : phaseInfo_[phase]) - { - for (auto & input : r->getInputs()) - { - for (auto & link: input.second->getLinks()) - { + for (UInt32 phase = minEnabledPhase_; phase <= maxEnabledPhase_; phase++) { + for (auto r : phaseInfo_[phase]) { + for (auto &input : r->getInputs()) { + for (auto &link : input.second->getLinks()) { links.add(link->toString(), link); } } @@ -622,28 +540,20 @@ Network::getLinks() return links; } -Collection& Network::getCallbacks() -{ +Collection &Network::getCallbacks() { return callbacks_; } - -UInt32 -Network::getMinPhase() const -{ +UInt32 Network::getMinPhase() const { UInt32 i = 0; - for (; i < phaseInfo_.size(); i++) - { + for (; i < phaseInfo_.size(); i++) { if (!phaseInfo_[i].empty()) break; } return i; } - -UInt32 -Network::getMaxPhase() const -{ +UInt32 Network::getMaxPhase() const { /* * phaseInfo_ is always trimmed, so the max phase is * phaseInfo_.size()-1 @@ -655,10 +565,7 @@ Network::getMaxPhase() const return phaseInfo_.size() - 1; } - -void -Network::setMinEnabledPhase(UInt32 minPhase) -{ +void Network::setMinEnabledPhase(UInt32 minPhase) { if (minPhase >= phaseInfo_.size()) NTA_THROW << "Attempt to set min enabled phase " << minPhase << " which is larger than the highest phase in the network - " @@ -666,9 +573,7 @@ Network::setMinEnabledPhase(UInt32 minPhase) minEnabledPhase_ = minPhase; } -void -Network::setMaxEnabledPhase(UInt32 maxPhase) -{ +void Network::setMaxEnabledPhase(UInt32 maxPhase) { if (maxPhase >= phaseInfo_.size()) NTA_THROW << "Attempt to set max enabled phase " << maxPhase << " which is larger than the highest phase in the network - " @@ -676,28 +581,15 @@ Network::setMaxEnabledPhase(UInt32 maxPhase) maxEnabledPhase_ = maxPhase; } -UInt32 -Network::getMinEnabledPhase() const -{ - return minEnabledPhase_; -} - - -UInt32 -Network::getMaxEnabledPhase() const -{ - return maxEnabledPhase_; -} +UInt32 Network::getMinEnabledPhase() const { return minEnabledPhase_; } +UInt32 Network::getMaxEnabledPhase() const { return maxEnabledPhase_; } -void Network::save(const std::string& name) -{ +void Network::save(const std::string &name) { - if (StringUtils::endsWith(name, ".tgz")) - { + if (StringUtils::endsWith(name, ".tgz")) { NTA_THROW << "Gzipped tar archives (" << name << ") not yet supported"; - } else if (StringUtils::endsWith(name, ".nta")) - { + } else if (StringUtils::endsWith(name, ".nta")) { saveToBundle(name); } else { NTA_THROW << "Network::save -- unknown file extension for '" << name @@ -709,26 +601,22 @@ void Network::save(const std::string& name) // This name may not be usable as part of a filesystem path, so // bundle files associated with a region use the region "label" // that can always be stored in the filesystem -static std::string getLabel(size_t index) -{ +static std::string getLabel(size_t index) { return std::string("R") + StringUtils::fromInt(index); } // save does the real work with saveToBundle -void Network::saveToBundle(const std::string& name) -{ - if (! 
StringUtils::endsWith(name, ".nta")) +void Network::saveToBundle(const std::string &name) { + if (!StringUtils::endsWith(name, ".nta")) NTA_THROW << "saveToBundle: bundle extension must be \".nta\""; std::string fullPath = Path::normalize(Path::makeAbsolute(name)); std::string networkStructureFilename = Path::join(fullPath, "network.yaml"); - // Only overwrite an existing path if it appears to be a network bundle - if (Path::exists(fullPath)) - { - if (! Path::isDirectory(fullPath) || ! Path::exists(networkStructureFilename)) - { + if (Path::exists(fullPath)) { + if (!Path::isDirectory(fullPath) || + !Path::exists(networkStructureFilename)) { NTA_THROW << "Existing filesystem entry " << fullPath << " is not a network bundle -- refusing to delete"; } @@ -743,12 +631,13 @@ void Network::saveToBundle(const std::string& name) out << YAML::BeginMap; out << YAML::Key << "Version" << YAML::Value << 2; out << YAML::Key << "Regions" << YAML::Value << YAML::BeginSeq; - for (size_t regionIndex = 0; regionIndex < regions_.getCount(); regionIndex++) - { - std::pair& info = regions_.getByIndex(regionIndex); + for (size_t regionIndex = 0; regionIndex < regions_.getCount(); + regionIndex++) { + std::pair &info = regions_.getByIndex(regionIndex); Region *r = info.second; // Network serializes the region directly because it is actually easier - // to do here than inside the region, and we don't have the RegionImpl data yet. + // to do here than inside the region, and we don't have the RegionImpl + // data yet. out << YAML::BeginMap; out << YAML::Key << "name" << YAML::Value << info.first; out << YAML::Key << "nodeType" << YAML::Value << r->getType(); @@ -758,8 +647,7 @@ void Network::saveToBundle(const std::string& name) // implement as a sequence by hand. out << YAML::Key << "phases" << YAML::Value << YAML::BeginSeq; std::set phases = r->getPhases(); - for (const auto & phases_phase : phases) - { + for (const auto &phases_phase : phases) { out << phases_phase; } out << YAML::EndSeq; @@ -772,26 +660,27 @@ void Network::saveToBundle(const std::string& name) out << YAML::Key << "Links" << YAML::Value << YAML::BeginSeq; - for (size_t regionIndex = 0; regionIndex < regions_.getCount(); regionIndex++) - { + for (size_t regionIndex = 0; regionIndex < regions_.getCount(); + regionIndex++) { Region *r = regions_.getByIndex(regionIndex).second; - const std::map inputs = r->getInputs(); - for (const auto & inputs_input : inputs) - { - const std::vector& links = inputs_input.second->getLinks(); - for (const auto & links_link : links) - { - Link& l = *(links_link); + const std::map inputs = r->getInputs(); + for (const auto &inputs_input : inputs) { + const std::vector &links = inputs_input.second->getLinks(); + for (const auto &links_link : links) { + Link &l = *(links_link); out << YAML::BeginMap; out << YAML::Key << "type" << YAML::Value << l.getLinkType(); out << YAML::Key << "params" << YAML::Value << l.getLinkParams(); - out << YAML::Key << "srcRegion" << YAML::Value << l.getSrcRegionName(); - out << YAML::Key << "srcOutput" << YAML::Value << l.getSrcOutputName(); - out << YAML::Key << "destRegion" << YAML::Value << l.getDestRegionName(); - out << YAML::Key << "destInput" << YAML::Value << l.getDestInputName(); + out << YAML::Key << "srcRegion" << YAML::Value + << l.getSrcRegionName(); + out << YAML::Key << "srcOutput" << YAML::Value + << l.getSrcOutputName(); + out << YAML::Key << "destRegion" << YAML::Value + << l.getDestRegionName(); + out << YAML::Key << "destInput" << YAML::Value + << l.getDestInputName(); 
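        // For orientation: each "Links" entry emitted by this loop ends up
        // shaped like the YAML below; the region and output names are
        // illustrative only, not taken from a real network.
        //
        //   - type: UniformLink
        //     params: ""
        //     srcRegion: sensor
        //     srcOutput: dataOut
        //     destRegion: level1
        //     destInput: bottomUpIn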
out << YAML::EndMap; } - } } out << YAML::EndSeq; // end of links @@ -805,9 +694,9 @@ void Network::saveToBundle(const std::string& name) } // Now save RegionImpl data - for (size_t regionIndex = 0; regionIndex < regions_.getCount(); regionIndex++) - { - std::pair& info = regions_.getByIndex(regionIndex); + for (size_t regionIndex = 0; regionIndex < regions_.getCount(); + regionIndex++) { + std::pair &info = regions_.getByIndex(regionIndex); Region *r = info.second; std::string label = getLabel(regionIndex); BundleIO bundle(fullPath, label, info.first, /* isInput: */ false); @@ -815,29 +704,24 @@ void Network::saveToBundle(const std::string& name) } } -void Network::load(const std::string& path) -{ - if (StringUtils::endsWith(path, ".tgz")) - { +void Network::load(const std::string &path) { + if (StringUtils::endsWith(path, ".tgz")) { NTA_THROW << "Gzipped tar archives (" << path << ") not yet supported"; - } else if (StringUtils::endsWith(path, ".nta")) - { + } else if (StringUtils::endsWith(path, ".nta")) { loadFromBundle(path); } else { NTA_THROW << "Network::save -- unknown file extension for '" << path << "'. Supported extensions are .tgz and .nta"; } - } -void Network::loadFromBundle(const std::string& name) -{ - if (! StringUtils::endsWith(name, ".nta")) +void Network::loadFromBundle(const std::string &name) { + if (!StringUtils::endsWith(name, ".nta")) NTA_THROW << "loadFromBundle: bundle extension must be \".nta\""; std::string fullPath = Path::normalize(Path::makeAbsolute(name)); - if (! Path::exists(fullPath)) + if (!Path::exists(fullPath)) NTA_THROW << "Path " << fullPath << " does not exist"; std::string networkStructureFilename = Path::join(fullPath, "network.yaml"); @@ -854,8 +738,8 @@ void Network::loadFromBundle(const std::string& name) // Should contain Version, Regions, Links if (doc.size() != 3) - NTA_THROW << "Invalid network structure file -- contains " - << doc.size() << " elements"; + NTA_THROW << "Invalid network structure file -- contains " << doc.size() + << " elements"; // Extra version const YAML::Node *node = doc.FindValue("Version"); @@ -873,10 +757,11 @@ void Network::loadFromBundle(const std::string& name) NTA_THROW << "Invalid network structure file -- no regions"; if (regions->Type() != YAML::NodeType::Sequence) - NTA_THROW << "Invalid network structure file -- regions element is not a list"; + NTA_THROW + << "Invalid network structure file -- regions element is not a list"; - for (YAML::Iterator region = regions->begin(); region != regions->end(); region++) - { + for (YAML::Iterator region = regions->begin(); region != regions->end(); + region++) { // Each region is a map -- extract the 5 values in the map if ((*region).Type() != YAML::NodeType::Map) NTA_THROW << "Invalid network structure file -- bad region (not a map)"; @@ -894,22 +779,22 @@ void Network::loadFromBundle(const std::string& name) // 2. nodeType node = (*region).FindValue("nodeType"); if (node == nullptr) - NTA_THROW << "Invalid network structure file -- region " - << name << " has no node type"; + NTA_THROW << "Invalid network structure file -- region " << name + << " has no node type"; std::string nodeType; *node >> nodeType; // 3. 
dimensions node = (*region).FindValue("dimensions"); if (node == nullptr) - NTA_THROW << "Invalid network structure file -- region " - << name << " has no dimensions"; + NTA_THROW << "Invalid network structure file -- region " << name + << " has no dimensions"; if ((*node).Type() != YAML::NodeType::Sequence) - NTA_THROW << "Invalid network structure file -- region " - << name << " dimensions specified incorrectly"; + NTA_THROW << "Invalid network structure file -- region " << name + << " dimensions specified incorrectly"; Dimensions dimensions; - for (YAML::Iterator valiter = (*node).begin(); valiter != (*node).end(); valiter++) - { + for (YAML::Iterator valiter = (*node).begin(); valiter != (*node).end(); + valiter++) { size_t val; (*valiter) >> val; dimensions.push_back(val); @@ -918,15 +803,15 @@ void Network::loadFromBundle(const std::string& name) // 4. phases node = (*region).FindValue("phases"); if (node == nullptr) - NTA_THROW << "Invalid network structure file -- region" - << name << "has no phases"; + NTA_THROW << "Invalid network structure file -- region" << name + << "has no phases"; if ((*node).Type() != YAML::NodeType::Sequence) - NTA_THROW << "Invalid network structure file -- region " - << name << " phases specified incorrectly"; + NTA_THROW << "Invalid network structure file -- region " << name + << " phases specified incorrectly"; std::set phases; - for (YAML::Iterator valiter = (*node).begin(); valiter != (*node).end(); valiter++) - { + for (YAML::Iterator valiter = (*node).begin(); valiter != (*node).end(); + valiter++) { UInt32 val; (*valiter) >> val; phases.insert(val); @@ -935,15 +820,14 @@ void Network::loadFromBundle(const std::string& name) // 5. label node = (*region).FindValue("label"); if (node == nullptr) - NTA_THROW << "Invalid network structure file -- region" - << name << "has no label"; + NTA_THROW << "Invalid network structure file -- region" << name + << "has no label"; std::string label; *node >> label; - Region *r = addRegionFromBundle(name, nodeType, dimensions, fullPath, label); + Region *r = + addRegionFromBundle(name, nodeType, dimensions, fullPath, label); setPhases_(r, phases); - - } const YAML::Node *links = doc.FindValue("Links"); @@ -951,10 +835,10 @@ void Network::loadFromBundle(const std::string& name) NTA_THROW << "Invalid network structure file -- no links"; if (links->Type() != YAML::NodeType::Sequence) - NTA_THROW << "Invalid network structure file -- links element is not a list"; + NTA_THROW + << "Invalid network structure file -- links element is not a list"; - for (YAML::Iterator link = links->begin(); link != links->end(); link++) - { + for (YAML::Iterator link = links->begin(); link != links->end(); link++) { // Each link is a map -- extract the 5 values in the map if ((*link).Type() != YAML::NodeType::Map) NTA_THROW << "Invalid network structure file -- bad link (not a map)"; @@ -965,78 +849,88 @@ void Network::loadFromBundle(const std::string& name) // 1. type node = (*link).FindValue("type"); if (node == nullptr) - NTA_THROW << "Invalid network structure file -- link does not have a type"; + NTA_THROW + << "Invalid network structure file -- link does not have a type"; std::string linkType; *node >> linkType; // 2. params node = (*link).FindValue("params"); if (node == nullptr) - NTA_THROW << "Invalid network structure file -- link does not have params"; + NTA_THROW + << "Invalid network structure file -- link does not have params"; std::string params; *node >> params; // 3. 
srcRegion (name) node = (*link).FindValue("srcRegion"); if (node == nullptr) - NTA_THROW << "Invalid network structure file -- link does not have a srcRegion"; + NTA_THROW + << "Invalid network structure file -- link does not have a srcRegion"; std::string srcRegionName; *node >> srcRegionName; - // 4. srcOutput node = (*link).FindValue("srcOutput"); if (node == nullptr) - NTA_THROW << "Invalid network structure file -- link does not have a srcOutput"; + NTA_THROW + << "Invalid network structure file -- link does not have a srcOutput"; std::string srcOutputName; *node >> srcOutputName; // 5. destRegion node = (*link).FindValue("destRegion"); if (node == nullptr) - NTA_THROW << "Invalid network structure file -- link does not have a destRegion"; + NTA_THROW << "Invalid network structure file -- link does not have a " + "destRegion"; std::string destRegionName; *node >> destRegionName; // 6. destInput node = (*link).FindValue("destInput"); if (node == nullptr) - NTA_THROW << "Invalid network structure file -- link does not have a destInput"; + NTA_THROW + << "Invalid network structure file -- link does not have a destInput"; std::string destInputName; *node >> destInputName; if (!regions_.contains(srcRegionName)) - NTA_THROW << "Invalid network structure file -- link specifies source region '" << srcRegionName << "' but no such region exists"; - Region* srcRegion = regions_.getByName(srcRegionName); + NTA_THROW + << "Invalid network structure file -- link specifies source region '" + << srcRegionName << "' but no such region exists"; + Region *srcRegion = regions_.getByName(srcRegionName); if (!regions_.contains(destRegionName)) - NTA_THROW << "Invalid network structure file -- link specifies destination region '" << destRegionName << "' but no such region exists"; - Region* destRegion = regions_.getByName(destRegionName); + NTA_THROW << "Invalid network structure file -- link specifies " + "destination region '" + << destRegionName << "' but no such region exists"; + Region *destRegion = regions_.getByName(destRegionName); - Output* srcOutput = srcRegion->getOutput(srcOutputName); + Output *srcOutput = srcRegion->getOutput(srcOutputName); if (srcOutput == nullptr) - NTA_THROW << "Invalid network structure file -- link specifies source output '" << srcOutputName << "' but no such name exists"; + NTA_THROW + << "Invalid network structure file -- link specifies source output '" + << srcOutputName << "' but no such name exists"; - Input* destInput = destRegion->getInput(destInputName); + Input *destInput = destRegion->getInput(destInputName); if (destInput == nullptr) - NTA_THROW << "Invalid network structure file -- link specifies destination input '" << destInputName << "' but no such name exists"; + NTA_THROW << "Invalid network structure file -- link specifies " + "destination input '" + << destInputName << "' but no such name exists"; // Create the link itself auto newLink = new Link(linkType, params, srcOutput, destInput, - 0/*propagationDelay*/); + 0 /*propagationDelay*/); destInput->addLink(newLink, srcOutput); } // links - } -void Network::write(NetworkProto::Builder& proto) const -{ +void Network::write(NetworkProto::Builder &proto) const { // Aggregate links from all of the regions - std::vector links; + std::vector links; auto entriesProto = proto.initRegions().initEntries(regions_.getCount()); - for (UInt i = 0; i < regions_.getCount(); i++) - { + for (UInt i = 0; i < regions_.getCount(); i++) { auto entry = entriesProto[i]; auto regionPair = regions_.getByIndex(i); auto 
regionProto = entry.initValue(); @@ -1044,74 +938,63 @@ void Network::write(NetworkProto::Builder& proto) const regionPair.second->write(regionProto); // Aggregate this regions links in a vector to store at end - for (auto inputPair : regionPair.second->getInputs()) - { - auto& newLinks = inputPair.second->getLinks(); + for (auto inputPair : regionPair.second->getInputs()) { + auto &newLinks = inputPair.second->getLinks(); links.insert(links.end(), newLinks.begin(), newLinks.end()); } } // Store the aggregated links auto linksListProto = proto.initLinks(links.size()); - for (UInt i = 0; i < links.size(); ++i) - { + for (UInt i = 0; i < links.size(); ++i) { auto linkProto = linksListProto[i]; links[i]->write(linkProto); } } -void Network::read(NetworkProto::Reader& proto) -{ +void Network::read(NetworkProto::Reader &proto) { // Clear any previous regions - while (regions_.getCount() > 0) - { + while (regions_.getCount() > 0) { auto pair = regions_.getByIndex(0); delete pair.second; regions_.remove(pair.first); } // Add regions - for (auto entry : proto.getRegions().getEntries()) - { + for (auto entry : proto.getRegions().getEntries()) { auto regionProto = entry.getValue(); auto region = addRegionFromProto(entry.getKey().cStr(), regionProto); // Initialize the phases for the region std::set phases; - for (auto phase : regionProto.getPhases()) - { + for (auto phase : regionProto.getPhases()) { phases.insert(phase); } setPhases_(region, phases); } - for (auto linkProto : proto.getLinks()) - { + for (auto linkProto : proto.getLinks()) { auto link = new Link(); link->read(linkProto); - if (!regions_.contains(link->getSrcRegionName())) - { + if (!regions_.contains(link->getSrcRegionName())) { NTA_THROW << "Link references unknown region: " << link->getSrcRegionName(); } - Region* srcRegion = regions_.getByName(link->getSrcRegionName()); - Output* srcOutput = srcRegion->getOutput(link->getSrcOutputName()); - if (srcOutput == nullptr) - { + Region *srcRegion = regions_.getByName(link->getSrcRegionName()); + Output *srcOutput = srcRegion->getOutput(link->getSrcOutputName()); + if (srcOutput == nullptr) { NTA_THROW << "Link references unknown source output: " << link->getSrcOutputName(); } - if (!regions_.contains(link->getDestRegionName())) - { + if (!regions_.contains(link->getDestRegionName())) { NTA_THROW << "Link references unknown region: " << link->getDestRegionName(); } - Region* destRegion = regions_.getByName(link->getDestRegionName()); - Input* destInput = destRegion->getInput(link->getDestInputName()); - if (destInput == nullptr) - { + Region *destRegion = regions_.getByName(link->getDestRegionName()); + Input *destInput = destRegion->getInput(link->getDestInputName()); + if (destInput == nullptr) { NTA_THROW << "Link references unknown destination input: " << link->getDestInputName(); } @@ -1125,43 +1008,37 @@ void Network::read(NetworkProto::Reader& proto) initialized_ = false; } -void Network::enableProfiling() -{ +void Network::enableProfiling() { for (size_t i = 0; i < regions_.getCount(); i++) regions_.getByIndex(i).second->enableProfiling(); } -void Network::disableProfiling() -{ +void Network::disableProfiling() { for (size_t i = 0; i < regions_.getCount(); i++) regions_.getByIndex(i).second->disableProfiling(); } -void Network::resetProfiling() -{ +void Network::resetProfiling() { for (size_t i = 0; i < regions_.getCount(); i++) regions_.getByIndex(i).second->resetProfiling(); } -void Network::registerPyRegion(const std::string module, const std::string className) -{ +void 
Network::registerPyRegion(const std::string module, + const std::string className) { Region::registerPyRegion(module, className); } -void Network::registerCPPRegion(const std::string name, GenericRegisteredRegionImpl* wrapper) -{ +void Network::registerCPPRegion(const std::string name, + GenericRegisteredRegionImpl *wrapper) { Region::registerCPPRegion(name, wrapper); } -void Network::unregisterPyRegion(const std::string className) -{ +void Network::unregisterPyRegion(const std::string className) { Region::unregisterPyRegion(className); } -void Network::unregisterCPPRegion(const std::string name) -{ +void Network::unregisterCPPRegion(const std::string name) { Region::unregisterCPPRegion(name); } - } // namespace nupic diff --git a/src/nupic/engine/Network.hpp b/src/nupic/engine/Network.hpp index 5e0cd88b17..d99b2a7c0a 100644 --- a/src/nupic/engine/Network.hpp +++ b/src/nupic/engine/Network.hpp @@ -27,7 +27,6 @@ #ifndef NTA_NETWORK_HPP #define NTA_NETWORK_HPP - #include #include #include @@ -38,443 +37,418 @@ #include #include -#include #include +#include + +namespace nupic { + +class Region; +class Dimensions; +class GenericRegisteredRegionImpl; +class Link; + +/** + * Represents an HTM network. A network is a collection of regions. + * + * @nosubgrouping + */ +class Network : public Serializable { +public: + /** + * @name Construction and destruction + * @{ + */ + + /** + * + * Create an new Network and register it to NuPIC. + * + * @note Creating a Network will auto-initialize NuPIC. + */ + Network(); + + /** + * Create a Network by loading previously saved bundle, + * and register it to NuPIC. + * + * @param path The path to the previously saved bundle file, currently only + * support files with `.nta` extension. + * + * @note Creating a Network will auto-initialize NuPIC. + */ + Network(const std::string &path); + + /** + * Destructor. + * + * Destruct the network and unregister it from NuPIC: + * + * - Uninitialize all regions + * - Remove all links + * - Delete the regions themselves + * + * @todo Should we document the tear down steps above? + */ + ~Network(); + + /** + * Initialize all elements of a network so that it can run. + * + * @note This can be called after the Network structure has been set and + * before Network.run(). However, if you don't call it, Network.run() will + * call it for you. Also sets up various memory buffers etc. once the Network + * structure has been finalized. + */ + void initialize(); + + /** + * @} + * + * @name Serialization + * + * @{ + */ + + /** + * Save the network to a network bundle (extension `.nta`). + * + * @param name + * Name of the bundle + */ + void save(const std::string &name); + + /** + * @} + * + * @name Region and Link operations + * + * @{ + */ + + /** + * Create a new region in a network. + * + * @param name + * Name of the region, Must be unique in the network + * @param nodeType + * Type of node in the region, e.g. "FDRNode" + * @param nodeParams + * A JSON-encoded string specifying writable params + * + * @returns A pointer to the newly created Region + */ + Region *addRegion(const std::string &name, const std::string &nodeType, + const std::string &nodeParams); + + /** + * Create a new region from saved state. + * + * @param name + * Name of the region, Must be unique in the network + * @param nodeType + * Type of node in the region, e.g. 
"FDRNode" + * @param dimensions + * Dimensions of the region + * @param bundlePath + * The path to the bundle + * @param label + * The label of the bundle + * + * @todo @a label is the prefix of filename of the saved bundle, should this + * be documented? + * + * @returns A pointer to the newly created Region + */ + Region *addRegionFromBundle(const std::string &name, + const std::string &nodeType, + const Dimensions &dimensions, + const std::string &bundlePath, + const std::string &label); + + /** + * Create a new region from saved Cap'n Proto state. + * + * @param name + * Name of the region, Must be unique in the network + * @param proto + * The capnp proto reader + * + * @returns A pointer to the newly created Region + */ + Region *addRegionFromProto(const std::string &name, + RegionProto::Reader &proto); + + /** + * Removes an existing region from the network. + * + * @param name + * Name of the Region + */ + void removeRegion(const std::string &name); + + /** + * Create a link and add it to the network. + * + * @param srcName + * Name of the source region + * @param destName + * Name of the destination region + * @param linkType + * Type of the link + * @param linkParams + * Parameters of the link + * @param srcOutput + * Name of the source output + * @param destInput + * Name of the destination input + * @param propagationDelay + * Propagation delay of the link as number of network run + * iterations involving the link as input; the delay vectors, if + * any, are initially populated with 0's. Defaults to 0=no delay + */ + void link(const std::string &srcName, const std::string &destName, + const std::string &linkType, const std::string &linkParams, + const std::string &srcOutput = "", + const std::string &destInput = "", + const size_t propagationDelay = 0); + + /** + * Removes a link. + * + * @param srcName + * Name of the source region + * @param destName + * Name of the destination region + * @param srcOutputName + * Name of the source output + * @param destInputName + * Name of the destination input + */ + void removeLink(const std::string &srcName, const std::string &destName, + const std::string &srcOutputName = "", + const std::string &destInputName = ""); + + /** + * @} + * + * @name Access to components + * + * @{ + */ + + /** + * Get all regions. + * + * @returns A Collection of Region objects in the network + */ + const Collection &getRegions() const; + + /** + * Get all links between regions + * + * @returns A Collection of Link objects in the network + */ + Collection getLinks(); + + /** + * Set phases for a region. + * + * @param name + * Name of the region + * @param phases + * A tuple of phases (must be positive integers) + */ + void setPhases(const std::string &name, std::set &phases); + + /** + * Get phases for a region. + * + * @param name + * Name of the region + * + * @returns Set of phases for the region + */ + std::set getPhases(const std::string &name) const; + + /** + * Get minimum phase for regions in this network. If no regions, then min = 0. + * + * @returns Minimum phase + */ + UInt32 getMinPhase() const; + + /** + * Get maximum phase for regions in this network. If no regions, then max = 0. + * + * @returns Maximum phase + */ + UInt32 getMaxPhase() const; + + /** + * Set the minimum enabled phase for this network. + * + * @param minPhase Minimum enabled phase + */ + void setMinEnabledPhase(UInt32 minPhase); + + /** + * Set the maximum enabled phase for this network. 
+ * + * @param minPhase Maximum enabled phase + */ + void setMaxEnabledPhase(UInt32 minPhase); + + /** + * Get the minimum enabled phase for this network. + * + * @returns Minimum enabled phase for this network + */ + UInt32 getMinEnabledPhase() const; + + /** + * Get the maximum enabled phase for this network. + * + * @returns Maximum enabled phase for this network + */ + UInt32 getMaxEnabledPhase() const; + + /** + * @} + * + * @name Running + * + * @{ + */ + + /** + * Run the network for the given number of iterations of compute for each + * Region in the correct order. + * + * For each iteration, Region.compute() is called. + * + * @param n Number of iterations + */ + void run(int n); + + /** + * The type of run callback function. + * + * You can attach a callback function to a network, and the callback function + * is called after every iteration of run(). + * + * To attach a callback, just get a reference to the callback + * collection with getCallbacks() , and add a callback. + */ + typedef void (*runCallbackFunction)(Network *, UInt64 iteration, void *); + + /** + * Type definition for a callback item, combines a @c runCallbackFunction and + * a `void*` pointer to the associated data. + */ + typedef std::pair callbackItem; + + /** + * Get reference to callback Collection. + * + * @returns Reference to callback Collection + */ + Collection &getCallbacks(); + + /** + * @} + * + * @name Profiling + * + * @{ + */ + + /** + * Start profiling for all regions of this network. + */ + void enableProfiling(); + + /** + * Stop profiling for all regions of this network. + */ + void disableProfiling(); + + /** + * Reset profiling timers for all regions of this network. + */ + void resetProfiling(); + + // Capnp serialization methods + using Serializable::write; + virtual void write(NetworkProto::Builder &proto) const override; + using Serializable::read; + virtual void read(NetworkProto::Reader &proto) override; + + /** + * @} + */ + + /* + * Adds user built region to list of regions + */ + static void registerPyRegion(const std::string module, + const std::string className); + + /* + * Adds a c++ region to the RegionImplFactory's packages + */ + static void registerCPPRegion(const std::string name, + GenericRegisteredRegionImpl *wrapper); + + /* + * Removes a region from RegionImplFactory's packages + */ + static void unregisterPyRegion(const std::string className); + + /* + * Removes a c++ region from RegionImplFactory's packages + */ + static void unregisterCPPRegion(const std::string name); + +private: + // Both constructors use this common initialization method + void commonInit(); + + // Used by the path-based constructor + void load(const std::string &path); + + void loadFromBundle(const std::string &path); + + // save() always calls this internal method, which creates + // a .nta bundle + void saveToBundle(const std::string &bundleName); + + // internal method using region pointer instead of name + void setPhases_(Region *r, std::set &phases); + + // default phase assignment for a new region + void setDefaultPhase_(Region *region); + + // whenever we modify a network or change phase + // information, we set enabled phases to min/max for + // the network + void resetEnabledPhases_(); + + bool initialized_; + Collection regions_; + + UInt32 minEnabledPhase_; + UInt32 maxEnabledPhase_; + + // This is main data structure used to choreograph + // network computation + std::vector> phaseInfo_; + + // we invoke these callbacks at every iteration + Collection callbacks_; -namespace nupic -{ - - 
class Region; - class Dimensions; - class GenericRegisteredRegionImpl; - class Link; - - /** - * Represents an HTM network. A network is a collection of regions. - * - * @nosubgrouping - */ - class Network : public Serializable - { - public: - - /** - * @name Construction and destruction - * @{ - */ - - /** - * - * Create an new Network and register it to NuPIC. - * - * @note Creating a Network will auto-initialize NuPIC. - */ - Network(); - - /** - * Create a Network by loading previously saved bundle, - * and register it to NuPIC. - * - * @param path The path to the previously saved bundle file, currently only - * support files with `.nta` extension. - * - * @note Creating a Network will auto-initialize NuPIC. - */ - Network(const std::string& path); - - /** - * Destructor. - * - * Destruct the network and unregister it from NuPIC: - * - * - Uninitialize all regions - * - Remove all links - * - Delete the regions themselves - * - * @todo Should we document the tear down steps above? - */ - ~Network(); - - /** - * Initialize all elements of a network so that it can run. - * - * @note This can be called after the Network structure has been set and - * before Network.run(). However, if you don't call it, Network.run() will - * call it for you. Also sets up various memory buffers etc. once the Network - * structure has been finalized. - */ - void - initialize(); - - /** - * @} - * - * @name Serialization - * - * @{ - */ - - /** - * Save the network to a network bundle (extension `.nta`). - * - * @param name - * Name of the bundle - */ - void save(const std::string& name); - - /** - * @} - * - * @name Region and Link operations - * - * @{ - */ - - /** - * Create a new region in a network. - * - * @param name - * Name of the region, Must be unique in the network - * @param nodeType - * Type of node in the region, e.g. "FDRNode" - * @param nodeParams - * A JSON-encoded string specifying writable params - * - * @returns A pointer to the newly created Region - */ - Region* - addRegion(const std::string& name, - const std::string& nodeType, - const std::string& nodeParams); - - /** - * Create a new region from saved state. - * - * @param name - * Name of the region, Must be unique in the network - * @param nodeType - * Type of node in the region, e.g. "FDRNode" - * @param dimensions - * Dimensions of the region - * @param bundlePath - * The path to the bundle - * @param label - * The label of the bundle - * - * @todo @a label is the prefix of filename of the saved bundle, should this - * be documented? - * - * @returns A pointer to the newly created Region - */ - Region* - addRegionFromBundle(const std::string& name, - const std::string& nodeType, - const Dimensions& dimensions, - const std::string& bundlePath, - const std::string& label); - - /** - * Create a new region from saved Cap'n Proto state. - * - * @param name - * Name of the region, Must be unique in the network - * @param proto - * The capnp proto reader - * - * @returns A pointer to the newly created Region - */ - Region* - addRegionFromProto(const std::string& name, - RegionProto::Reader& proto); - - /** - * Removes an existing region from the network. - * - * @param name - * Name of the Region - */ - void - removeRegion(const std::string& name); - - /** - * Create a link and add it to the network. 
- * - * @param srcName - * Name of the source region - * @param destName - * Name of the destination region - * @param linkType - * Type of the link - * @param linkParams - * Parameters of the link - * @param srcOutput - * Name of the source output - * @param destInput - * Name of the destination input - * @param propagationDelay - * Propagation delay of the link as number of network run - * iterations involving the link as input; the delay vectors, if - * any, are initially populated with 0's. Defaults to 0=no delay - */ - void - link(const std::string& srcName, const std::string& destName, - const std::string& linkType, const std::string& linkParams, - const std::string& srcOutput="", const std::string& destInput="", - const size_t propagationDelay=0); - - - /** - * Removes a link. - * - * @param srcName - * Name of the source region - * @param destName - * Name of the destination region - * @param srcOutputName - * Name of the source output - * @param destInputName - * Name of the destination input - */ - void - removeLink(const std::string& srcName, const std::string& destName, - const std::string& srcOutputName="", const std::string& destInputName=""); - - /** - * @} - * - * @name Access to components - * - * @{ - */ - - /** - * Get all regions. - * - * @returns A Collection of Region objects in the network - */ - const Collection& - getRegions() const; - - /** - * Get all links between regions - * - * @returns A Collection of Link objects in the network - */ - Collection - getLinks(); - - /** - * Set phases for a region. - * - * @param name - * Name of the region - * @param phases - * A tuple of phases (must be positive integers) - */ - void - setPhases(const std::string& name, std::set& phases); - - /** - * Get phases for a region. - * - * @param name - * Name of the region - * - * @returns Set of phases for the region - */ - std::set - getPhases(const std::string& name) const; - - /** - * Get minimum phase for regions in this network. If no regions, then min = 0. - * - * @returns Minimum phase - */ - UInt32 getMinPhase() const; - - /** - * Get maximum phase for regions in this network. If no regions, then max = 0. - * - * @returns Maximum phase - */ - UInt32 getMaxPhase() const; - - /** - * Set the minimum enabled phase for this network. - * - * @param minPhase Minimum enabled phase - */ - void - setMinEnabledPhase(UInt32 minPhase); - - /** - * Set the maximum enabled phase for this network. - * - * @param minPhase Maximum enabled phase - */ - void - setMaxEnabledPhase(UInt32 minPhase); - - /** - * Get the minimum enabled phase for this network. - * - * @returns Minimum enabled phase for this network - */ - UInt32 - getMinEnabledPhase() const; - - /** - * Get the maximum enabled phase for this network. - * - * @returns Maximum enabled phase for this network - */ - UInt32 - getMaxEnabledPhase() const; - - /** - * @} - * - * @name Running - * - * @{ - */ - - /** - * Run the network for the given number of iterations of compute for each - * Region in the correct order. - * - * For each iteration, Region.compute() is called. - * - * @param n Number of iterations - */ - void - run(int n); - - /** - * The type of run callback function. - * - * You can attach a callback function to a network, and the callback function - * is called after every iteration of run(). - * - * To attach a callback, just get a reference to the callback - * collection with getCallbacks() , and add a callback. 
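An illustrative usage sketch, not part of the patch: the comment above describes attaching a run callback via getCallbacks(). The sketch below assumes Collection exposes an add(name, item) method and that callbackItem pairs the function pointer with a user-data pointer; the names printIteration and attachIterationCallback are hypothetical.

#include <iostream>
#include "nupic/engine/Network.hpp"

// Matches runCallbackFunction: called after every iteration of Network::run().
static void printIteration(nupic::Network *net, nupic::UInt64 iteration,
                           void * /*userData*/) {
  std::cout << "completed iteration " << iteration << std::endl;
}

void attachIterationCallback(nupic::Network &net) {
  // callbackItem combines the callback with its user data (none here).
  nupic::Network::callbackItem item(printIteration, nullptr);
  net.getCallbacks().add("printIteration", item); // assumed Collection::add(name, item)
  net.run(10); // printIteration fires after each of the 10 iterations
}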
- */ - typedef void (*runCallbackFunction)(Network*, UInt64 iteration, void*); - - /** - * Type definition for a callback item, combines a @c runCallbackFunction and - * a `void*` pointer to the associated data. - */ - typedef std::pair callbackItem; - - /** - * Get reference to callback Collection. - * - * @returns Reference to callback Collection - */ - Collection& getCallbacks(); - - /** - * @} - * - * @name Profiling - * - * @{ - */ - - /** - * Start profiling for all regions of this network. - */ - void - enableProfiling(); - - /** - * Stop profiling for all regions of this network. - */ - void - disableProfiling(); - - /** - * Reset profiling timers for all regions of this network. - */ - void - resetProfiling(); - - // Capnp serialization methods - using Serializable::write; - virtual void write(NetworkProto::Builder& proto) const override; - using Serializable::read; - virtual void read(NetworkProto::Reader& proto) override; - - /** - * @} - */ - - /* - * Adds user built region to list of regions - */ - static void registerPyRegion(const std::string module, - const std::string className); - - /* - * Adds a c++ region to the RegionImplFactory's packages - */ - static void registerCPPRegion(const std::string name, - GenericRegisteredRegionImpl* wrapper); - - /* - * Removes a region from RegionImplFactory's packages - */ - static void unregisterPyRegion(const std::string className); - - /* - * Removes a c++ region from RegionImplFactory's packages - */ - static void unregisterCPPRegion(const std::string name); - - private: - - - // Both constructors use this common initialization method - void commonInit(); - - // Used by the path-based constructor - void load(const std::string& path); - - void loadFromBundle(const std::string& path); - - // save() always calls this internal method, which creates - // a .nta bundle - void saveToBundle(const std::string& bundleName); - - // internal method using region pointer instead of name - void - setPhases_(Region *r, std::set& phases); - - // default phase assignment for a new region - void setDefaultPhase_(Region* region); - - // whenever we modify a network or change phase - // information, we set enabled phases to min/max for - // the network - void resetEnabledPhases_(); - - bool initialized_; - Collection regions_; - - UInt32 minEnabledPhase_; - UInt32 maxEnabledPhase_; - - // This is main data structure used to choreograph - // network computation - std::vector< std::set > phaseInfo_; - - // we invoke these callbacks at every iteration - Collection callbacks_; - - //number of elapsed iterations - UInt64 iteration_; - }; + // number of elapsed iterations + UInt64 iteration_; +}; } // namespace nupic diff --git a/src/nupic/engine/NuPIC.cpp b/src/nupic/engine/NuPIC.cpp index d71c7f8d6a..44dea6b6df 100644 --- a/src/nupic/engine/NuPIC.cpp +++ b/src/nupic/engine/NuPIC.cpp @@ -24,19 +24,17 @@ // TODO -- thread safety +#include #include #include #include -#include -namespace nupic -{ +namespace nupic { -std::set NuPIC::networks_; +std::set NuPIC::networks_; bool NuPIC::initialized_ = false; -void NuPIC::init() -{ +void NuPIC::init() { if (isInitialized()) return; @@ -52,16 +50,12 @@ void NuPIC::init() initialized_ = true; } - -void NuPIC::shutdown() -{ - if (!isInitialized()) - { +void NuPIC::shutdown() { + if (!isInitialized()) { NTA_THROW << "NuPIC::shutdown -- NuPIC has not been initialized"; } - if (!networks_.empty()) - { + if (!networks_.empty()) { NTA_THROW << "NuPIC::shutdown -- cannot shut down NuPIC because " << networks_.size() << " 
networks still exist."; } @@ -70,32 +64,26 @@ void NuPIC::shutdown() initialized_ = false; } - -void NuPIC::registerNetwork(Network* net) -{ - if (!isInitialized()) - { - NTA_THROW << "Attempt to create a network before NuPIC has been initialized -- call NuPIC::init() before creating any networks"; +void NuPIC::registerNetwork(Network *net) { + if (!isInitialized()) { + NTA_THROW + << "Attempt to create a network before NuPIC has been initialized -- " + "call NuPIC::init() before creating any networks"; } auto n = networks_.find(net); // This should not be possible - NTA_CHECK(n == networks_.end()) << "Internal error -- double registration of network"; + NTA_CHECK(n == networks_.end()) + << "Internal error -- double registration of network"; networks_.insert(net); - } -void NuPIC::unregisterNetwork(Network* net) -{ +void NuPIC::unregisterNetwork(Network *net) { auto n = networks_.find(net); NTA_CHECK(n != networks_.end()) << "Internal error -- network not registered"; networks_.erase(n); } -bool NuPIC::isInitialized() -{ - return initialized_; -} - -} +bool NuPIC::isInitialized() { return initialized_; } +} // namespace nupic diff --git a/src/nupic/engine/NuPIC.hpp b/src/nupic/engine/NuPIC.hpp index 893b8102ea..ffc8b3515a 100644 --- a/src/nupic/engine/NuPIC.hpp +++ b/src/nupic/engine/NuPIC.hpp @@ -31,50 +31,48 @@ * * Contains the primary NuPIC API. */ -namespace nupic -{ - class Network; +namespace nupic { +class Network; +/** + * Initialization and shutdown operations for NuPIC engine. + */ +class NuPIC { +public: /** - * Initialization and shutdown operations for NuPIC engine. + * Initialize NuPIC. + * + * @note It's safe to reinitialize an initialized NuPIC. + * @note Creating a Network will auto-initialize NuPIC. */ - class NuPIC - { - public: - /** - * Initialize NuPIC. - * - * @note It's safe to reinitialize an initialized NuPIC. - * @note Creating a Network will auto-initialize NuPIC. - */ - static void init(); + static void init(); - /** - * Shutdown NuPIC. - * - * @note As a safety measure, NuPIC with any Network still registered to it - * is not allowed to be shut down. - */ - static void shutdown(); + /** + * Shutdown NuPIC. + * + * @note As a safety measure, NuPIC with any Network still registered to it + * is not allowed to be shut down. + */ + static void shutdown(); - /** - * - * @return Whether NuPIC is initialized successfully. - */ - static bool isInitialized(); - private: + /** + * + * @return Whether NuPIC is initialized successfully. + */ + static bool isInitialized(); - /** - * Having Network as friend class to allow Networks register/unregister - * themselves at creation and destruction time by calling non-public methods of NuPIC. - * - */ - friend class Network; +private: + /** + * Having Network as friend class to allow Networks register/unregister + * themselves at creation and destruction time by calling non-public methods + * of NuPIC. 
+ * + */ + friend class Network; - static void registerNetwork(Network* net); - static void unregisterNetwork(Network* net); - static std::set networks_; - static bool initialized_; - }; + static void registerNetwork(Network *net); + static void unregisterNetwork(Network *net); + static std::set networks_; + static bool initialized_; +}; } // namespace nupic - diff --git a/src/nupic/engine/Output.cpp b/src/nupic/engine/Output.cpp index 98e809fe24..e2536dd5ba 100644 --- a/src/nupic/engine/Output.cpp +++ b/src/nupic/engine/Output.cpp @@ -23,27 +23,24 @@ /** @file * Implementation of Output class * -*/ + */ -#include // memset -#include -#include +#include // memset +#include // temporary #include #include -#include // temporary - +#include +#include -namespace nupic -{ +namespace nupic { -Output::Output(Region& region, NTA_BasicType type, bool isRegionLevel) : - region_(region), isRegionLevel_(isRegionLevel), name_("Unnamed"), nodeOutputElementCount_(0) -{ +Output::Output(Region ®ion, NTA_BasicType type, bool isRegionLevel) + : region_(region), isRegionLevel_(isRegionLevel), name_("Unnamed"), + nodeOutputElementCount_(0) { data_ = new Array(type); } -Output::~Output() -{ +Output::~Output() { // If we have any outgoing links, then there has been an // error in the shutdown process. Not good to thow an exception // from a destructor, but we need to catch this error, and it @@ -53,9 +50,7 @@ Output::~Output() } // allocate buffer -void -Output::initialize(size_t count) -{ +void Output::initialize(size_t count) { // reinitialization is ok // might happen if initial initialization failed with an // exception (elsewhere) and was retried. @@ -68,8 +63,7 @@ Output::initialize(size_t count) dataCount = count; else dataCount = count * region_.getDimensions().getCount(); - if (dataCount != 0) - { + if (dataCount != 0) { data_->allocateBuffer(dataCount); // Zero the buffer because unitialized outputs can screw up inspectors, // which look at the output before compute(). NPC-60 @@ -79,9 +73,7 @@ Output::initialize(size_t count) } } -void -Output::addLink(Link* link) -{ +void Output::addLink(Link *link) { // Make sure we don't add the same link twice // It is a logic error if we add the same link twice here, since // this method should only be called from Input::addLink @@ -91,9 +83,7 @@ Output::addLink(Link* link) links_.insert(link); } -void -Output::removeLink(Link* link) -{ +void Output::removeLink(Link *link) { auto linkIter = links_.find(link); // Should only be called internally. 
Logic error if link not found NTA_CHECK(linkIter != links_.end()); @@ -102,48 +92,20 @@ Output::removeLink(Link* link) links_.erase(linkIter); } -const Array & -Output::getData() const -{ - return *data_; -} +const Array &Output::getData() const { return *data_; } -bool -Output::isRegionLevel() const -{ - return isRegionLevel_; -} +bool Output::isRegionLevel() const { return isRegionLevel_; } +Region &Output::getRegion() const { return region_; } -Region& -Output::getRegion() const -{ - return region_; -} - +void Output::setName(const std::string &name) { name_ = name; } -void Output::setName(const std::string& name) -{ - name_ = name; -} - -const std::string& Output::getName() const -{ - return name_; -} +const std::string &Output::getName() const { return name_; } - -size_t -Output::getNodeOutputElementCount() const -{ +size_t Output::getNodeOutputElementCount() const { return nodeOutputElementCount_; } -bool -Output::hasOutgoingLinks() -{ - return (!links_.empty()); -} - -} +bool Output::hasOutgoingLinks() { return (!links_.empty()); } +} // namespace nupic diff --git a/src/nupic/engine/Output.hpp b/src/nupic/engine/Output.hpp index 1f3cea32d0..0dd8c8e4c0 100644 --- a/src/nupic/engine/Output.hpp +++ b/src/nupic/engine/Output.hpp @@ -20,170 +20,158 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Interface for the internal Output class. */ #ifndef NTA_OUTPUT_HPP #define NTA_OUTPUT_HPP -#include #include #include // temporary, while impl is in this file -namespace nupic -{ +#include +namespace nupic { + +class Link; +class Region; +class Array; + +/** + * Represents a named output to a Region. + */ +class Output { +public: + /** + * Constructor. + * + * @param region + * The region that the output belongs to. + * @param type + * The type of the output, TODO + * @param isRegionLevel + * Whether the output is region level, i.e. TODO + */ + Output(Region ®ion, NTA_BasicType type, bool isRegionLevel); + + /** + * Destructor + */ + ~Output(); - class Link; - class Region; - class Array; + /** + * Set the name for the output. + * + * Output need to know their own name for error messages. + * + * @param name + * The name of the output + */ + void setName(const std::string &name); + + /** + * Get the name of the output. + * + * @return + * The name of the output + */ + const std::string &getName() const; + + /** + * Initialize the Output . + * + * @param size + * The count of node output element, i.e. TODO + * + * @note It's safe to reinitialize an initialized Output with the same + * parameters. + * + */ + void initialize(size_t size); /** - * Represents a named output to a Region. + * + * Add a Link to the Output . + * + * @note The Output does NOT take ownership of @a link, it's created and + * owned by an Input Object. + * + * Called by Input.addLink() + * + * @param link + * The Link to add */ - class Output - { - public: - - /** - * Constructor. - * - * @param region - * The region that the output belongs to. - * @param type - * The type of the output, TODO - * @param isRegionLevel - * Whether the output is region level, i.e. TODO - */ - Output(Region& region, NTA_BasicType type, bool isRegionLevel); - - /** - * Destructor - */ - ~Output(); - - /** - * Set the name for the output. - * - * Output need to know their own name for error messages. - * - * @param name - * The name of the output - */ - void setName(const std::string& name); - - /** - * Get the name of the output. 
- * - * @return - * The name of the output - */ - const std::string& getName() const; - - /** - * Initialize the Output . - * - * @param size - * The count of node output element, i.e. TODO - * - * @note It's safe to reinitialize an initialized Output with the same - * parameters. - * - */ - void initialize(size_t size); - - /** - * - * Add a Link to the Output . - * - * @note The Output does NOT take ownership of @a link, it's created and - * owned by an Input Object. - * - * Called by Input.addLink() - * - * @param link - * The Link to add - */ - void - addLink(Link* link); - - /** - * Removing an existing link from the output. - * - * @note Called only by Input.removeLink() even if triggered by - * Network.removeRegion() while removing the region that contains us. - * - * @param link - * The Link to remove - */ - void - removeLink(Link* link); - - /** - * Tells whether the output has outgoing links. - * - * @note We cannot delete a region if there are any outgoing links - * This allows us to check in Network.removeRegion() and Network.~Network(). - * @returns - * Whether the output has outgoing links - */ - bool - hasOutgoingLinks(); - - /** - * - * Get the data of the output. - * - * @returns - * A constant reference to the data of the output as an @c Array - * - * @note It's mportant to return a const array so caller can't - * reallocate the buffer. - */ - const Array & - getData() const; - - /** - * - * Tells whether the output is region level. - * - * @returns - * Whether the output is region level, i.e. TODO - */ - bool - isRegionLevel() const; - - /** - * - * Get the Region that the output belongs to. - * - * @returns - * The mutable reference to the Region that the output belongs to - */ - Region& - getRegion() const; - - /** - * Get the count of node output element. - * - * @returns - * The count of node output element, previously set by initialize(). - */ - size_t - getNodeOutputElementCount() const; - - private: - - Region& region_; // needed for number of nodes - Array * data_; - bool isRegionLevel_; - // order of links never matters, so store as a set - // this is different from Input, where they do matter - std::set links_; - std::string name_; - size_t nodeOutputElementCount_; - }; - -} + void addLink(Link *link); + /** + * Removing an existing link from the output. + * + * @note Called only by Input.removeLink() even if triggered by + * Network.removeRegion() while removing the region that contains us. + * + * @param link + * The Link to remove + */ + void removeLink(Link *link); + + /** + * Tells whether the output has outgoing links. + * + * @note We cannot delete a region if there are any outgoing links + * This allows us to check in Network.removeRegion() and Network.~Network(). + * @returns + * Whether the output has outgoing links + */ + bool hasOutgoingLinks(); + + /** + * + * Get the data of the output. + * + * @returns + * A constant reference to the data of the output as an @c Array + * + * @note It's mportant to return a const array so caller can't + * reallocate the buffer. + */ + const Array &getData() const; + + /** + * + * Tells whether the output is region level. + * + * @returns + * Whether the output is region level, i.e. TODO + */ + bool isRegionLevel() const; + + /** + * + * Get the Region that the output belongs to. + * + * @returns + * The mutable reference to the Region that the output belongs to + */ + Region &getRegion() const; + + /** + * Get the count of node output element. 
+ * + * @returns + * The count of node output element, previously set by initialize(). + */ + size_t getNodeOutputElementCount() const; + +private: + Region ®ion_; // needed for number of nodes + Array *data_; + bool isRegionLevel_; + // order of links never matters, so store as a set + // this is different from Input, where they do matter + std::set links_; + std::string name_; + size_t nodeOutputElementCount_; +}; + +} // namespace nupic #endif // NTA_OUTPUT_HPP diff --git a/src/nupic/engine/Region.cpp b/src/nupic/engine/Region.cpp index f8f7eb0976..dbdd3f4c62 100644 --- a/src/nupic/engine/Region.cpp +++ b/src/nupic/engine/Region.cpp @@ -29,565 +29,409 @@ Methods related to inputs and outputs are in Region_io.cpp */ #include -#include -#include -#include +#include +#include +#include #include #include #include #include -#include -#include -#include -#include #include #include #include +#include +#include +#include +#include -namespace nupic -{ - - class GenericRegisteredRegionImpl; - - // Create region from parameter spec - Region::Region(std::string name, - const std::string& nodeType, - const std::string& nodeParams, - Network * network) : - name_(std::move(name)), - type_(nodeType), - initialized_(false), - enabledNodes_(nullptr), - network_(network), - profilingEnabled_(false) - { - // Set region info before creating the RegionImpl so that the - // Impl has access to the region info in its constructor. - RegionImplFactory & factory = RegionImplFactory::getInstance(); - spec_ = factory.getSpec(nodeType); - - // Dimensions start off as unspecified, but if - // the RegionImpl only supports a single node, we - // can immediately set the dimensions. - if (spec_->singleNodeOnly) - dims_.push_back(1); - // else dims_ = [] - - impl_ = factory.createRegionImpl(nodeType, nodeParams, this); - createInputsAndOutputs_(); - - } +namespace nupic { + +class GenericRegisteredRegionImpl; + +// Create region from parameter spec +Region::Region(std::string name, const std::string &nodeType, + const std::string &nodeParams, Network *network) + : name_(std::move(name)), type_(nodeType), initialized_(false), + enabledNodes_(nullptr), network_(network), profilingEnabled_(false) { + // Set region info before creating the RegionImpl so that the + // Impl has access to the region info in its constructor. + RegionImplFactory &factory = RegionImplFactory::getInstance(); + spec_ = factory.getSpec(nodeType); + + // Dimensions start off as unspecified, but if + // the RegionImpl only supports a single node, we + // can immediately set the dimensions. + if (spec_->singleNodeOnly) + dims_.push_back(1); + // else dims_ = [] + + impl_ = factory.createRegionImpl(nodeType, nodeParams, this); + createInputsAndOutputs_(); +} - // Deserialize region - Region::Region(std::string name, - const std::string& nodeType, - const Dimensions& dimensions, - BundleIO& bundle, - Network * network) : - name_(std::move(name)), - type_(nodeType), - initialized_(false), - enabledNodes_(nullptr), - network_(network), - profilingEnabled_(false) - { - // Set region info before creating the RegionImpl so that the - // Impl has access to the region info in its constructor. - RegionImplFactory & factory = RegionImplFactory::getInstance(); - spec_ = factory.getSpec(nodeType); - - // Dimensions start off as unspecified, but if - // the RegionImpl only supports a single node, we - // can immediately set the dimensions. 
- if (spec_->singleNodeOnly) - if (!dimensions.isDontcare() && !dimensions.isUnspecified() && - !dimensions.isOnes()) - NTA_THROW << "Attempt to deserialize region of type " << nodeType - << " with dimensions " << dimensions - << " but region supports exactly one node."; - - dims_ = dimensions; - - impl_ = factory.deserializeRegionImpl(nodeType, bundle, this); - createInputsAndOutputs_(); - } +// Deserialize region +Region::Region(std::string name, const std::string &nodeType, + const Dimensions &dimensions, BundleIO &bundle, Network *network) + : name_(std::move(name)), type_(nodeType), initialized_(false), + enabledNodes_(nullptr), network_(network), profilingEnabled_(false) { + // Set region info before creating the RegionImpl so that the + // Impl has access to the region info in its constructor. + RegionImplFactory &factory = RegionImplFactory::getInstance(); + spec_ = factory.getSpec(nodeType); + + // Dimensions start off as unspecified, but if + // the RegionImpl only supports a single node, we + // can immediately set the dimensions. + if (spec_->singleNodeOnly) + if (!dimensions.isDontcare() && !dimensions.isUnspecified() && + !dimensions.isOnes()) + NTA_THROW << "Attempt to deserialize region of type " << nodeType + << " with dimensions " << dimensions + << " but region supports exactly one node."; + + dims_ = dimensions; + + impl_ = factory.deserializeRegionImpl(nodeType, bundle, this); + createInputsAndOutputs_(); +} - Region::Region(std::string name, RegionProto::Reader& proto, - Network* network) : - name_(std::move(name)), - type_(proto.getNodeType().cStr()), - initialized_(false), - enabledNodes_(nullptr), - network_(network), - profilingEnabled_(false) - { - read(proto); - createInputsAndOutputs_(); - } +Region::Region(std::string name, RegionProto::Reader &proto, Network *network) + : name_(std::move(name)), type_(proto.getNodeType().cStr()), + initialized_(false), enabledNodes_(nullptr), network_(network), + profilingEnabled_(false) { + read(proto); + createInputsAndOutputs_(); +} - Network * Region::getNetwork() - { - return network_; - } +Network *Region::getNetwork() { return network_; } - void Region::createInputsAndOutputs_() - { - - // Create all the outputs for this node type. By default outputs are zero size - for (size_t i = 0; i < spec_->outputs.getCount(); ++i) - { - const std::pair & p = spec_->outputs.getByIndex(i); - std::string outputName = p.first; - const OutputSpec & os = p.second; - auto output = new Output(*this, os.dataType, os.regionLevel); - outputs_[outputName] = output; - // keep track of name in the output also -- see note in Region.hpp - output->setName(outputName); - } +void Region::createInputsAndOutputs_() { - // Create all the inputs for this node type. - for (size_t i = 0; i < spec_->inputs.getCount(); ++i) - { - const std::pair & p = spec_->inputs.getByIndex(i); - std::string inputName = p.first; - const InputSpec &is = p.second; - - auto input = new Input(*this, is.dataType, is.regionLevel); - inputs_[inputName] = input; - // keep track of name in the input also -- see note in Region.hpp - input->setName(inputName); - } + // Create all the outputs for this node type. 
By default outputs are zero size + for (size_t i = 0; i < spec_->outputs.getCount(); ++i) { + const std::pair &p = spec_->outputs.getByIndex(i); + std::string outputName = p.first; + const OutputSpec &os = p.second; + auto output = new Output(*this, os.dataType, os.regionLevel); + outputs_[outputName] = output; + // keep track of name in the output also -- see note in Region.hpp + output->setName(outputName); } + // Create all the inputs for this node type. + for (size_t i = 0; i < spec_->inputs.getCount(); ++i) { + const std::pair &p = spec_->inputs.getByIndex(i); + std::string inputName = p.first; + const InputSpec &is = p.second; - bool Region::hasOutgoingLinks() const - { - for (const auto & elem : outputs_) - { - if (elem.second->hasOutgoingLinks()) - { - return true; - } - } - return false; + auto input = new Input(*this, is.dataType, is.regionLevel); + inputs_[inputName] = input; + // keep track of name in the input also -- see note in Region.hpp + input->setName(inputName); } +} - Region::~Region() - { - // If there are any links connected to our outputs, this will fail. - // We should catch this error in the Network class and give the - // user a good error message (regions may be removed either in - // Network::removeRegion or Network::~Network()) - for (auto & elem : outputs_) - { - delete elem.second; - elem.second = nullptr; - } - - for (auto & elem : inputs_) - { - delete elem.second; - elem.second = nullptr; +bool Region::hasOutgoingLinks() const { + for (const auto &elem : outputs_) { + if (elem.second->hasOutgoingLinks()) { + return true; } + } + return false; +} - delete impl_; - delete enabledNodes_; - +Region::~Region() { + // If there are any links connected to our outputs, this will fail. + // We should catch this error in the Network class and give the + // user a good error message (regions may be removed either in + // Network::removeRegion or Network::~Network()) + for (auto &elem : outputs_) { + delete elem.second; + elem.second = nullptr; } + for (auto &elem : inputs_) { + delete elem.second; + elem.second = nullptr; + } + delete impl_; + delete enabledNodes_; +} - void - Region::initialize() - { +void Region::initialize() { - if (initialized_) - return; + if (initialized_) + return; - impl_->initialize(); - initialized_ = true; - } + impl_->initialize(); + initialized_ = true; +} - bool - Region::isInitialized() const - { - return initialized_; - } +bool Region::isInitialized() const { return initialized_; } - const std::string& - Region::getName() const - { - return name_; - } +const std::string &Region::getName() const { return name_; } - const std::string& - Region::getType() const - { - return type_; - } +const std::string &Region::getType() const { return type_; } - const Spec* - Region::getSpec() const - { - return spec_; - } +const Spec *Region::getSpec() const { return spec_; } - const Spec* - Region::getSpecFromType(const std::string& nodeType) - { - RegionImplFactory & factory = RegionImplFactory::getInstance(); - return factory.getSpec(nodeType); - } +const Spec *Region::getSpecFromType(const std::string &nodeType) { + RegionImplFactory &factory = RegionImplFactory::getInstance(); + return factory.getSpec(nodeType); +} - void - Region::registerPyRegion(const std::string module, const std::string className) - { - RegionImplFactory::registerPyRegion(module, className); - } +void Region::registerPyRegion(const std::string module, + const std::string className) { + RegionImplFactory::registerPyRegion(module, className); +} - void - 
Region::registerCPPRegion(const std::string name, GenericRegisteredRegionImpl* wrapper) - { - RegionImplFactory::registerCPPRegion(name, wrapper); - } +void Region::registerCPPRegion(const std::string name, + GenericRegisteredRegionImpl *wrapper) { + RegionImplFactory::registerCPPRegion(name, wrapper); +} - void - Region::unregisterPyRegion(const std::string className) - { - RegionImplFactory::unregisterPyRegion(className); - } +void Region::unregisterPyRegion(const std::string className) { + RegionImplFactory::unregisterPyRegion(className); +} - void - Region::unregisterCPPRegion(const std::string name) - { - RegionImplFactory::unregisterCPPRegion(name); - } +void Region::unregisterCPPRegion(const std::string name) { + RegionImplFactory::unregisterCPPRegion(name); +} - const Dimensions& - Region::getDimensions() const - { - return dims_; - } +const Dimensions &Region::getDimensions() const { return dims_; } - void - Region::enable() - { - NTA_THROW << "Region::enable not implemented (region name: " << getName() << ")"; - } +void Region::enable() { + NTA_THROW << "Region::enable not implemented (region name: " << getName() + << ")"; +} +void Region::disable() { + NTA_THROW << "Region::disable not implemented (region name: " << getName() + << ")"; +} - void - Region::disable() - { - NTA_THROW << "Region::disable not implemented (region name: " << getName() << ")"; +std::string Region::executeCommand(const std::vector &args) { + std::string retVal; + if (args.size() < 1) { + NTA_THROW << "Invalid empty command specified"; } - std::string - Region::executeCommand(const std::vector& args) - { - std::string retVal; - if (args.size() < 1) - { - NTA_THROW << "Invalid empty command specified"; - } + if (profilingEnabled_) + executeTimer_.start(); + retVal = impl_->executeCommand(args, (UInt64)(-1)); - if (profilingEnabled_) - executeTimer_.start(); + if (profilingEnabled_) + executeTimer_.stop(); - retVal = impl_->executeCommand(args, (UInt64)(-1)); + return retVal; +} - if (profilingEnabled_) - executeTimer_.stop(); +void Region::compute() { + if (!initialized_) + NTA_THROW << "Region " << getName() + << " unable to compute because not initialized"; - return retVal; - } + if (profilingEnabled_) + computeTimer_.start(); + impl_->compute(); - void - Region::compute() - { - if (!initialized_) - NTA_THROW << "Region " << getName() << " unable to compute because not initialized"; + if (profilingEnabled_) + computeTimer_.stop(); - if (profilingEnabled_) - computeTimer_.start(); - - impl_->compute(); + return; +} - if (profilingEnabled_) - computeTimer_.stop(); +/** + * These internal methods are called by Network as + * part of initialization. + */ - return; +size_t Region::evaluateLinks() { + int nIncompleteLinks = 0; + for (auto &elem : inputs_) { + nIncompleteLinks += (elem.second)->evaluateLinks(); } + return nIncompleteLinks; +} +std::string Region::getLinkErrors() const { - /** - * These internal methods are called by Network as - * part of initialization. 
- */ - - size_t - Region::evaluateLinks() - { - int nIncompleteLinks = 0; - for (auto & elem : inputs_) - { - nIncompleteLinks += (elem.second)->evaluateLinks(); - } - return nIncompleteLinks; - } - - std::string - Region::getLinkErrors() const - { - - std::stringstream ss; - for (const auto & elem : inputs_) - { - const std::vector& links = elem.second->getLinks(); - for (const auto & link : links) - { - if ( (link)->getSrcDimensions().isUnspecified() || - (link)->getDestDimensions().isUnspecified()) - { - ss << (link)->toString() << "\n"; - } + std::stringstream ss; + for (const auto &elem : inputs_) { + const std::vector &links = elem.second->getLinks(); + for (const auto &link : links) { + if ((link)->getSrcDimensions().isUnspecified() || + (link)->getDestDimensions().isUnspecified()) { + ss << (link)->toString() << "\n"; } } - - return ss.str(); } - size_t Region::getNodeOutputElementCount(const std::string& name) - { - // Use output count if specified in nodespec, otherwise - // ask the Impl - NTA_CHECK(spec_->outputs.contains(name)); - size_t count = spec_->outputs.getByName(name).count; - if(count == 0) - { - try - { - count = impl_->getNodeOutputElementCount(name); - } catch(Exception& e) { - NTA_THROW << "Internal error -- the size for the output " << name << - "is unknown. : " << e.what(); - } - } + return ss.str(); +} - return count; +size_t Region::getNodeOutputElementCount(const std::string &name) { + // Use output count if specified in nodespec, otherwise + // ask the Impl + NTA_CHECK(spec_->outputs.contains(name)); + size_t count = spec_->outputs.getByName(name).count; + if (count == 0) { + try { + count = impl_->getNodeOutputElementCount(name); + } catch (Exception &e) { + NTA_THROW << "Internal error -- the size for the output " << name + << "is unknown. : " << e.what(); + } } - void Region::initOutputs() - { - // Some outputs are optional. These outputs will have 0 elementCount in the node - // spec and also return 0 from impl->getNodeOutputElementCount(). These outputs still - // appear in the output map, but with an array size of 0. + return count; +} +void Region::initOutputs() { + // Some outputs are optional. These outputs will have 0 elementCount in the + // node spec and also return 0 from impl->getNodeOutputElementCount(). These + // outputs still appear in the output map, but with an array size of 0. 
- for (auto & elem : outputs_) - { - const std::string& name = elem.first; + for (auto &elem : outputs_) { + const std::string &name = elem.first; - size_t count = 0; - try - { - count = getNodeOutputElementCount(name); - } catch (nupic::Exception& e) { - NTA_THROW << "Internal error -- unable to get size of output " - << name << " : " << e.what(); - } - elem.second->initialize(count); + size_t count = 0; + try { + count = getNodeOutputElementCount(name); + } catch (nupic::Exception &e) { + NTA_THROW << "Internal error -- unable to get size of output " << name + << " : " << e.what(); } + elem.second->initialize(count); } +} - void Region::initInputs() const - { - auto i = inputs_.begin(); - for (; i != inputs_.end(); i++) - { - i->second->initialize(); - } +void Region::initInputs() const { + auto i = inputs_.begin(); + for (; i != inputs_.end(); i++) { + i->second->initialize(); } +} +void Region::setDimensions(Dimensions &newDims) { + // Can only set dimensions one time + if (dims_ == newDims) + return; - void - Region::setDimensions(Dimensions& newDims) - { - // Can only set dimensions one time - if (dims_ == newDims) - return; - - if (dims_.isUnspecified()) - { - if (newDims.isDontcare()) - { - NTA_THROW << "Invalid attempt to set region dimensions to dontcare value"; - } - - if (! newDims.isValid()) - { - NTA_THROW << "Attempt to set region dimensions to invalid value:" - << newDims.toString(); - } - - dims_ = newDims; - dimensionInfo_ = "Specified explicitly in setDimensions()"; - } else { - NTA_THROW << "Attempt to set dimensions of region " << getName() - << " to " << newDims.toString() - << " but region already has dimensions " << dims_.toString(); + if (dims_.isUnspecified()) { + if (newDims.isDontcare()) { + NTA_THROW << "Invalid attempt to set region dimensions to dontcare value"; } - // can only create the enabled node set after we know the number of dimensions - setupEnabledNodeSet(); + if (!newDims.isValid()) { + NTA_THROW << "Attempt to set region dimensions to invalid value:" + << newDims.toString(); + } + dims_ = newDims; + dimensionInfo_ = "Specified explicitly in setDimensions()"; + } else { + NTA_THROW << "Attempt to set dimensions of region " << getName() << " to " + << newDims.toString() << " but region already has dimensions " + << dims_.toString(); } - void Region::setupEnabledNodeSet() - { - NTA_CHECK(dims_.isValid()); - - if (enabledNodes_ != nullptr) - { - delete enabledNodes_; - } + // can only create the enabled node set after we know the number of dimensions + setupEnabledNodeSet(); +} - size_t nnodes = dims_.getCount(); - enabledNodes_ = new NodeSet(nnodes); +void Region::setupEnabledNodeSet() { + NTA_CHECK(dims_.isValid()); - enabledNodes_->allOn(); + if (enabledNodes_ != nullptr) { + delete enabledNodes_; } - const NodeSet& Region::getEnabledNodes() const - { - if (enabledNodes_ == nullptr) - { - NTA_THROW << "Attempt to access enabled nodes set before region has been initialized"; - } - return *enabledNodes_; - } + size_t nnodes = dims_.getCount(); + enabledNodes_ = new NodeSet(nnodes); + enabledNodes_->allOn(); +} - void - Region::setDimensionInfo(const std::string& info) - { - dimensionInfo_ = info; +const NodeSet &Region::getEnabledNodes() const { + if (enabledNodes_ == nullptr) { + NTA_THROW << "Attempt to access enabled nodes set before region has been " + "initialized"; } + return *enabledNodes_; +} - const std::string& - Region::getDimensionInfo() const - { - return dimensionInfo_; - } +void Region::setDimensionInfo(const std::string &info) { + 
dimensionInfo_ = info; +} - void - Region::removeAllIncomingLinks() - { - InputMap::const_iterator i = inputs_.begin(); - for (; i != inputs_.end(); i++) - { - std::vector links = i->second->getLinks(); - for (auto & links_link : links) - { - i->second->removeLink(links_link); +const std::string &Region::getDimensionInfo() const { return dimensionInfo_; } - } +void Region::removeAllIncomingLinks() { + InputMap::const_iterator i = inputs_.begin(); + for (; i != inputs_.end(); i++) { + std::vector links = i->second->getLinks(); + for (auto &links_link : links) { + i->second->removeLink(links_link); } - } +} - void - Region::uninitialize() - { - initialized_ = false; - } +void Region::uninitialize() { initialized_ = false; } - void - Region::setPhases(std::set& phases) - { - phases_ = phases; - } +void Region::setPhases(std::set &phases) { phases_ = phases; } - std::set& - Region::getPhases() - { - return phases_; - } +std::set &Region::getPhases() { return phases_; } - void - Region::serializeImpl(BundleIO& bundle) - { - impl_->serialize(bundle); - } +void Region::serializeImpl(BundleIO &bundle) { impl_->serialize(bundle); } - void Region::write(RegionProto::Builder& proto) const - { - auto dimensionsProto = proto.initDimensions(dims_.size()); - for (UInt i = 0; i < dims_.size(); ++i) - { - dimensionsProto.set(i, dims_[i]); - } - auto phasesProto = proto.initPhases(phases_.size()); - UInt i = 0; - for (auto elem : phases_) - { - phasesProto.set(i++, elem); - } - proto.setNodeType(type_.c_str()); - auto implProto = proto.getRegionImpl(); - impl_->write(implProto); +void Region::write(RegionProto::Builder &proto) const { + auto dimensionsProto = proto.initDimensions(dims_.size()); + for (UInt i = 0; i < dims_.size(); ++i) { + dimensionsProto.set(i, dims_[i]); } - - void Region::read(RegionProto::Reader& proto) - { - dims_.clear(); - for (auto elem : proto.getDimensions()) - { - dims_.push_back(elem); - } - - phases_.clear(); - for (auto elem : proto.getPhases()) - { - phases_.insert(elem); - } - - auto implProto = proto.getRegionImpl(); - RegionImplFactory& factory = RegionImplFactory::getInstance(); - spec_ = factory.getSpec(type_); - impl_ = factory.deserializeRegionImpl( - proto.getNodeType().cStr(), implProto, this); + auto phasesProto = proto.initPhases(phases_.size()); + UInt i = 0; + for (auto elem : phases_) { + phasesProto.set(i++, elem); } + proto.setNodeType(type_.c_str()); + auto implProto = proto.getRegionImpl(); + impl_->write(implProto); +} - void - Region::enableProfiling() - { - profilingEnabled_ = true; +void Region::read(RegionProto::Reader &proto) { + dims_.clear(); + for (auto elem : proto.getDimensions()) { + dims_.push_back(elem); } - void - Region::disableProfiling() - { - profilingEnabled_ = false; + phases_.clear(); + for (auto elem : proto.getPhases()) { + phases_.insert(elem); } - void - Region::resetProfiling() - { - computeTimer_.reset(); - executeTimer_.reset(); - } + auto implProto = proto.getRegionImpl(); + RegionImplFactory &factory = RegionImplFactory::getInstance(); + spec_ = factory.getSpec(type_); + impl_ = factory.deserializeRegionImpl(proto.getNodeType().cStr(), implProto, + this); +} - const Timer& Region::getComputeTimer() const - { - return computeTimer_; - } +void Region::enableProfiling() { profilingEnabled_ = true; } - const Timer& Region::getExecuteTimer() const - { - return executeTimer_; - } +void Region::disableProfiling() { profilingEnabled_ = false; } +void Region::resetProfiling() { + computeTimer_.reset(); + executeTimer_.reset(); } + 
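The profiling hooks above are meant to be driven from outside the region. A rough usage sketch, assuming r points at a Region that already belongs to an initialized Network (the function name and iteration count are illustrative):

    #include <iostream>
    #include <nupic/engine/Region.hpp>

    void profileRegion(nupic::Region *r) {
      r->resetProfiling();
      r->enableProfiling();
      for (int i = 0; i < 100; i++)
        r->compute(); // presumably accumulated into computeTimer_ while enabled
      std::cout << "compute: " << r->getComputeTimer().getElapsed()
                << " s over 100 iterations" << std::endl;
      r->disableProfiling();
    }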
+const Timer &Region::getComputeTimer() const { return computeTimer_; } + +const Timer &Region::getExecuteTimer() const { return executeTimer_; } + +} // namespace nupic diff --git a/src/nupic/engine/Region.hpp b/src/nupic/engine/Region.hpp index a8739652c9..59161b47cb 100644 --- a/src/nupic/engine/Region.hpp +++ b/src/nupic/engine/Region.hpp @@ -25,15 +25,15 @@ * * A region is a set of one or more "identical" nodes, implemented by a * RegionImpl"plugin". A region contains nodes. -*/ + */ #ifndef NTA_REGION_HPP #define NTA_REGION_HPP -#include -#include #include #include +#include +#include // We need the full definitions because these // objects are returned by value. @@ -43,773 +43,704 @@ #include #include -namespace nupic -{ - - class RegionImpl; - class Output; - class Input; - class ArrayRef; - class Array; - struct Spec; - class NodeSet; - class BundleIO; - class Timer; - class Network; - class GenericRegisteredRegionImpl; - - /** - * Represents a set of one or more "identical" nodes in a Network. - * - * @nosubgrouping - * - * ### Constructors - * - * @note Region constructors are not available in the public API. - * Internally regions are created and owned by Network. - * - */ - class Region : public Serializable - { - public: - - /** - * @name Region information - * - * @{ - */ - - /** - * Get the network containing this region. - * - * @returns The network containing this region - */ - Network * - getNetwork(); - - /** - * Get the name of the region. - * - * @returns The region's name - */ - const std::string& - getName() const; - - - /** - * Get the dimensions of the region. - * - * @returns The region's dimensions - */ - const Dimensions& - getDimensions() const; - - /** - * Assign width and height to the region. - * - * @param dimensions - * A Dimensions object that describes the width and height - */ - void - setDimensions(Dimensions & dimensions); - - /** - * @} - * - * @name Element interface methods - * - * @todo What does "Element interface methods" mean here? - * - * @{ - * - */ - - /** - * Get the type of the region. - * - * @returns The node type as a string - */ - const std::string& - getType() const; - - /** - * Get the spec of the region. - * - * @returns The spec that describes this region - */ - const Spec* - getSpec() const; - - /** - * Get the Spec of a region type without an instance. - * - * @param nodeType - * A region type as a string - * - * @returns The Spec that describes this region type - */ - static const Spec* - getSpecFromType(const std::string& nodeType); - - /* - * Adds a Python module and class to the RegionImplFactory's regions - */ - static void registerPyRegion(const std::string module, const std::string className); - - /* - * Adds a cpp region to the RegionImplFactory's packages - */ - static void registerCPPRegion(const std::string name, GenericRegisteredRegionImpl* wrapper); - - /* - * Removes a Python module and class from the RegionImplFactory's regions - */ - static void unregisterPyRegion(const std::string className); - - /* - * Removes a cpp region from the RegionImplFactory's packages - */ - static void unregisterCPPRegion(const std::string name); - - - /** - * @} - * - * @name Parameter getters and setters - * - * @{ - * - */ - - /** - * Get the parameter as an @c Int32 value. - * - * @param name - * The name of the parameter - * - * @returns The value of the parameter - */ - Int32 - getParameterInt32(const std::string& name) const; - - /** - * Get the parameter as an @c UInt32 value. 
- * - * @param name - * The name of the parameter - * - * @returns The value of the parameter - */ - UInt32 - getParameterUInt32(const std::string& name) const; - - /** - * Get the parameter as an @c Int64 value. - * - * @param name - * The name of the parameter - * - * @returns The value of the parameter - */ - Int64 - getParameterInt64(const std::string& name) const; - - /** - * Get the parameter as an @c UInt64 value. - * - * @param name - * The name of the parameter - * - * @returns The value of the parameter - */ - UInt64 - getParameterUInt64(const std::string& name) const; - - /** - * Get the parameter as an @c Real32 value. - * - * @param name - * The name of the parameter - * - * @returns The value of the parameter - */ - Real32 - getParameterReal32(const std::string& name) const; - - /** - * Get the parameter as an @c Real64 value. - * - * @param name - * The name of the parameter - * - * @returns The value of the parameter - */ - Real64 - getParameterReal64(const std::string& name) const; - - /** - * Get the parameter as an @c Handle value. - * - * @param name - * The name of the parameter - * - * @returns The value of the parameter - */ - Handle - getParameterHandle(const std::string& name) const; - - /** - * Get a bool parameter. - * - * @param name - * The name of the parameter - * - * @returns The value of the parameter - */ - bool - getParameterBool(const std::string& name) const; - - /** - * Set the parameter to an Int32 value. - * - * @param name - * The name of the parameter - * - * @param value - * The value of the parameter - */ - void - setParameterInt32(const std::string& name, Int32 value); - - /** - * Set the parameter to an UInt32 value. - * - * @param name - * The name of the parameter - * - * @param value - * The value of the parameter - */ - void - setParameterUInt32(const std::string& name, UInt32 value); - - /** - * Set the parameter to an Int64 value. - * - * @param name - * The name of the parameter - * - * @param value - * The value of the parameter - */ - void - setParameterInt64(const std::string& name, Int64 value); - - /** - * Set the parameter to an UInt64 value. - * - * @param name - * The name of the parameter - * - * @param value - * The value of the parameter - */ - void - setParameterUInt64(const std::string& name, UInt64 value); - - /** - * Set the parameter to a Real32 value. - * - * @param name - * The name of the parameter - * - * @param value - * The value of the parameter - */ - void - setParameterReal32(const std::string& name, Real32 value); - - /** - * Set the parameter to a Real64 value. - * - * @param name - * The name of the parameter - * - * @param value - * The value of the parameter - */ - void - setParameterReal64(const std::string& name, Real64 value); - - /** - * Set the parameter to a Handle value. - * - * @param name - * The name of the parameter - * - * @param value - * The value of the parameter - */ - void - setParameterHandle(const std::string& name, Handle value); - - /** - * Set the parameter to a bool value. - * - * @param name - * The name of the parameter - * - * @param value - * The value of the parameter - */ - void - setParameterBool(const std::string& name, bool value); - - /** - * Get the parameter as an @c Array value. - * - * @param name - * The name of the parameter - * - * @param[out] array - * The value of the parameter - * - * @a array is a memory buffer. 
If the buffer is allocated, - * the value is copied into the supplied buffer; otherwise - * @a array would be asked to allocate the buffer and copy into it. - * - * A typical use might be that the caller would supply an - * unallocated buffer on the first call and then reuse the memory - * buffer on subsequent calls, i.e. - * - * @code{.cpp} - * - * { - * // no buffer allocated - * Array buffer(NTA_BasicTypeInt64); - * - * // buffer is allocated, and owned by Array object - * getParameterArray("foo", buffer); - * - * // uses already-allocated buffer - * getParameterArray("foo", buffer); - * - * } // Array destructor called -- frees the buffer - * @endcode - * - * Throws an exception if the supplied @a array is not big enough. - * - */ - void - getParameterArray(const std::string& name, Array & array) const; - - /** - * Set the parameter to an @c Array value. - * - * @param name - * The name of the parameter - * - * @param array - * The value of the parameter - * - * - * @note @a array must be initialized before calling setParameterArray(). - * - */ - void - setParameterArray(const std::string& name, const Array & array); - - /** - * Set the parameter to a @c std::string value. - * - * @param name - * The name of the parameter - * - * @param s - * The value of the parameter - * - * Strings are handled internally as Byte Arrays, but this interface - * is clumsy. setParameterString() and getParameterString() internally use - * byte arrays but converts to/from strings. - * - * setParameterString() is implemented with one copy (from the string into - * the node) but getParameterString() requires a second copy so that there - * are temporarily three copies of the data in memory (in the node, - * in an internal Array object, and in the string returned to the user) - * - */ - void - setParameterString(const std::string& name, const std::string& s); - - /** - * Get the parameter as a @c std::string value. - * - * @param name - * The name of the parameter - * - * @returns - * The value of the parameter - * - * @see setParameterString() - */ - std::string - getParameterString(const std::string& name); - - /** - * Tells whether the parameter is shared. - * - * @param name - * The name of the parameter - * - * @returns - * Whether the parameter is shared - * - * @todo figure out what "shared" means here - * - * @note This method must be overridden by subclasses. - * - * Throws an exception if it's not overridden - */ - bool - isParameterShared(const std::string& name) const; - - /** - * @} - * - * @name Inputs and outputs - * - * @{ - * - */ - - /** - * Copies data into the inputs of this region, using - * the links that are attached to each input. - */ - void - prepareInputs(); - - /** - * Get the input data. - * - * - * @param inputName - * The name of the target input - * - * @returns An @c ArrayRef that contains the input data. - * - * @internal - * - * @note The data is either stored in the - * the @c ArrayRef or point to the internal stored data, - * the actual behavior is controlled by the 'copy' argument (see below). - * - * @todo what's the copy' argument mentioned here? - * - * @endinternal - * - */ - virtual ArrayRef - getInputData(const std::string& inputName) const; - - /** - * Get the output data. - * - * @param outputName - * The name of the target output - * - * @returns - * An @c ArrayRef that contains the output data. 
- * - * @internal - * - * @note The data is either stored in the - * the @c ArrayRef or point to the internal stored data, - * the actual behavior is controlled by the 'copy' argument (see below). - * - * @todo what's the copy' argument mentioned here? - * - * @endinternal - * - */ - virtual ArrayRef - getOutputData(const std::string& outputName) const; - - /** - * Get the count of input data. - * - * @param inputName - * The name of the target input - * - * @returns - * The count of input data - * - * @todo are getOutput/InputCount needed? count can be obtained from the array objects. - * - */ - virtual size_t - getInputCount(const std::string& inputName) const; - - /** - * Get the count of output data. - * - * @param outputName - * The name of the target output - * - * @returns - * The count of output data - * - * @todo are getOutput/InputCount needed? count can be obtained from the array objects. - * - */ - virtual size_t - getOutputCount(const std::string& outputName) const; - - /** - * @} - * - * @name Operations - * - * @{ - * - */ - - /** - * @todo Region::enable() not implemented, should it be part of API at all? - */ - virtual void - enable(); - - /** - * @todo Region::disable() not implemented, should it be part of API at all? - */ - virtual void - disable(); - - /** - * Request the underlying region to execute a command. - * - * @param args - * A list of strings that the actual region will interpret. - * The first string is the command name. The other arguments are optional. - * - * @returns - * The result value of command execution is a string determined - * by the underlying region. - */ - virtual std::string - executeCommand(const std::vector& args); - - /** - * Perform one step of the region computation. - */ - void - compute(); - - /** - * @} - * - * @name Profiling - * - * @{ - * - */ - - /** - * Enable profiling of the compute and execute operations - */ - void - enableProfiling(); - - /** - * Disable profiling of the compute and execute operations - */ - void - disableProfiling(); - - /** - * Reset the compute and execute timers - */ - void - resetProfiling(); - - /** - * Get the timer used to profile the compute operation. - * - * @returns - * The Timer object used to profile the compute operation - */ - const Timer& getComputeTimer() const; - - /** - * Get the timer used to profile the execute operation. - * - * @returns - * The Timer object used to profile the execute operation - */ - const Timer& getExecuteTimer() const; - - /** - * @} - */ +namespace nupic { + +class RegionImpl; +class Output; +class Input; +class ArrayRef; +class Array; +struct Spec; +class NodeSet; +class BundleIO; +class Timer; +class Network; +class GenericRegisteredRegionImpl; + +/** + * Represents a set of one or more "identical" nodes in a Network. + * + * @nosubgrouping + * + * ### Constructors + * + * @note Region constructors are not available in the public API. + * Internally regions are created and owned by Network. + * + */ +class Region : public Serializable { +public: + /** + * @name Region information + * + * @{ + */ -#ifdef NTA_INTERNAL - // Internal methods. + /** + * Get the network containing this region. + * + * @returns The network containing this region + */ + Network *getNetwork(); - // New region from parameter spec - Region(std::string name, - const std::string& type, - const std::string& nodeParams, - Network * network = nullptr); + /** + * Get the name of the region. 
+ * + * @returns The region's name + */ + const std::string &getName() const; - // New region from serialized state - Region(std::string name, - const std::string& type, - const Dimensions& dimensions, - BundleIO& bundle, - Network * network = nullptr); + /** + * Get the dimensions of the region. + * + * @returns The region's dimensions + */ + const Dimensions &getDimensions() const; - // New region from capnp struct - Region(std::string name, RegionProto::Reader& proto, - Network* network=nullptr); + /** + * Assign width and height to the region. + * + * @param dimensions + * A Dimensions object that describes the width and height + */ + void setDimensions(Dimensions &dimensions); - virtual ~Region(); + /** + * @} + * + * @name Element interface methods + * + * @todo What does "Element interface methods" mean here? + * + * @{ + * + */ - void - initialize(); + /** + * Get the type of the region. + * + * @returns The node type as a string + */ + const std::string &getType() const; - bool - isInitialized() const; + /** + * Get the spec of the region. + * + * @returns The spec that describes this region + */ + const Spec *getSpec() const; + /** + * Get the Spec of a region type without an instance. + * + * @param nodeType + * A region type as a string + * + * @returns The Spec that describes this region type + */ + static const Spec *getSpecFromType(const std::string &nodeType); + /* + * Adds a Python module and class to the RegionImplFactory's regions + */ + static void registerPyRegion(const std::string module, + const std::string className); - // Used by RegionImpl to get inputs/outputs - Output* - getOutput(const std::string& name) const; + /* + * Adds a cpp region to the RegionImplFactory's packages + */ + static void registerCPPRegion(const std::string name, + GenericRegisteredRegionImpl *wrapper); - Input* - getInput(const std::string& name) const; + /* + * Removes a Python module and class from the RegionImplFactory's regions + */ + static void unregisterPyRegion(const std::string className); - // These are used only for serialization - const std::map& - getInputs() const; + /* + * Removes a cpp region from the RegionImplFactory's packages + */ + static void unregisterCPPRegion(const std::string name); - const std::map& - getOutputs() const; + /** + * @} + * + * @name Parameter getters and setters + * + * @{ + * + */ - // The following methods are called by Network in initialization + /** + * Get the parameter as an @c Int32 value. + * + * @param name + * The name of the parameter + * + * @returns The value of the parameter + */ + Int32 getParameterInt32(const std::string &name) const; - // Returns number of links that could not be fully evaluated - size_t - evaluateLinks(); + /** + * Get the parameter as an @c UInt32 value. + * + * @param name + * The name of the parameter + * + * @returns The value of the parameter + */ + UInt32 getParameterUInt32(const std::string &name) const; - std::string - getLinkErrors() const; + /** + * Get the parameter as an @c Int64 value. + * + * @param name + * The name of the parameter + * + * @returns The value of the parameter + */ + Int64 getParameterInt64(const std::string &name) const; - size_t - getNodeOutputElementCount(const std::string& name); + /** + * Get the parameter as an @c UInt64 value. + * + * @param name + * The name of the parameter + * + * @returns The value of the parameter + */ + UInt64 getParameterUInt64(const std::string &name) const; - void - initOutputs(); + /** + * Get the parameter as an @c Real32 value. 
+ * + * @param name + * The name of the parameter + * + * @returns The value of the parameter + */ + Real32 getParameterReal32(const std::string &name) const; - void - initInputs() const; + /** + * Get the parameter as an @c Real64 value. + * + * @param name + * The name of the parameter + * + * @returns The value of the parameter + */ + Real64 getParameterReal64(const std::string &name) const; - void - intialize(); + /** + * Get the parameter as an @c Handle value. + * + * @param name + * The name of the parameter + * + * @returns The value of the parameter + */ + Handle getParameterHandle(const std::string &name) const; - // Internal -- for link debugging - void - setDimensionInfo(const std::string& info); + /** + * Get a bool parameter. + * + * @param name + * The name of the parameter + * + * @returns The value of the parameter + */ + bool getParameterBool(const std::string &name) const; - const std::string& - getDimensionInfo() const; + /** + * Set the parameter to an Int32 value. + * + * @param name + * The name of the parameter + * + * @param value + * The value of the parameter + */ + void setParameterInt32(const std::string &name, Int32 value); - bool - hasOutgoingLinks() const; + /** + * Set the parameter to an UInt32 value. + * + * @param name + * The name of the parameter + * + * @param value + * The value of the parameter + */ + void setParameterUInt32(const std::string &name, UInt32 value); - // These methods are needed for teardown choreography - // in Network::~Network() - // It is an error to call any region methods after uninitialize() - // except removeAllIncomingLinks and ~Region - void - uninitialize(); + /** + * Set the parameter to an Int64 value. + * + * @param name + * The name of the parameter + * + * @param value + * The value of the parameter + */ + void setParameterInt64(const std::string &name, Int64 value); - void - removeAllIncomingLinks(); + /** + * Set the parameter to an UInt64 value. + * + * @param name + * The name of the parameter + * + * @param value + * The value of the parameter + */ + void setParameterUInt64(const std::string &name, UInt64 value); - const NodeSet& - getEnabledNodes() const; + /** + * Set the parameter to a Real32 value. + * + * @param name + * The name of the parameter + * + * @param value + * The value of the parameter + */ + void setParameterReal32(const std::string &name, Real32 value); - // TODO: sort our phases api. Users should never call Region::setPhases - // and it is here for serialization only. - void - setPhases(std::set& phases); + /** + * Set the parameter to a Real64 value. + * + * @param name + * The name of the parameter + * + * @param value + * The value of the parameter + */ + void setParameterReal64(const std::string &name, Real64 value); - std::set& - getPhases(); + /** + * Set the parameter to a Handle value. + * + * @param name + * The name of the parameter + * + * @param value + * The value of the parameter + */ + void setParameterHandle(const std::string &name, Handle value); - // Called by Network for serialization - void - serializeImpl(BundleIO& bundle); + /** + * Set the parameter to a bool value. + * + * @param name + * The name of the parameter + * + * @param value + * The value of the parameter + */ + void setParameterBool(const std::string &name, bool value); - using Serializable::write; - void write(RegionProto::Builder& proto) const; + /** + * Get the parameter as an @c Array value. 
+ * + * @param name + * The name of the parameter + * + * @param[out] array + * The value of the parameter + * + * @a array is a memory buffer. If the buffer is allocated, + * the value is copied into the supplied buffer; otherwise + * @a array would be asked to allocate the buffer and copy into it. + * + * A typical use might be that the caller would supply an + * unallocated buffer on the first call and then reuse the memory + * buffer on subsequent calls, i.e. + * + * @code{.cpp} + * + * { + * // no buffer allocated + * Array buffer(NTA_BasicTypeInt64); + * + * // buffer is allocated, and owned by Array object + * getParameterArray("foo", buffer); + * + * // uses already-allocated buffer + * getParameterArray("foo", buffer); + * + * } // Array destructor called -- frees the buffer + * @endcode + * + * Throws an exception if the supplied @a array is not big enough. + * + */ + void getParameterArray(const std::string &name, Array &array) const; - using Serializable::read; - void read(RegionProto::Reader& proto); + /** + * Set the parameter to an @c Array value. + * + * @param name + * The name of the parameter + * + * @param array + * The value of the parameter + * + * + * @note @a array must be initialized before calling setParameterArray(). + * + */ + void setParameterArray(const std::string &name, const Array &array); + /** + * Set the parameter to a @c std::string value. + * + * @param name + * The name of the parameter + * + * @param s + * The value of the parameter + * + * Strings are handled internally as Byte Arrays, but this interface + * is clumsy. setParameterString() and getParameterString() internally use + * byte arrays but converts to/from strings. + * + * setParameterString() is implemented with one copy (from the string into + * the node) but getParameterString() requires a second copy so that there + * are temporarily three copies of the data in memory (in the node, + * in an internal Array object, and in the string returned to the user) + * + */ + void setParameterString(const std::string &name, const std::string &s); -#endif // NTA_INTERNAL + /** + * Get the parameter as a @c std::string value. + * + * @param name + * The name of the parameter + * + * @returns + * The value of the parameter + * + * @see setParameterString() + */ + std::string getParameterString(const std::string &name); - private: - // verboten - Region(); - Region(Region&); + /** + * Tells whether the parameter is shared. + * + * @param name + * The name of the parameter + * + * @returns + * Whether the parameter is shared + * + * @todo figure out what "shared" means here + * + * @note This method must be overridden by subclasses. + * + * Throws an exception if it's not overridden + */ + bool isParameterShared(const std::string &name) const; - // common method used by both constructors - // Can be called after nodespec_ has been set. - void createInputsAndOutputs_(); + /** + * @} + * + * @name Inputs and outputs + * + * @{ + * + */ - const std::string name_; + /** + * Copies data into the inputs of this region, using + * the links that are attached to each input. + */ + void prepareInputs(); - // pointer to the "plugin"; owned by Region - RegionImpl* impl_; - const std::string type_; - Spec* spec_; + /** + * Get the input data. + * + * + * @param inputName + * The name of the target input + * + * @returns An @c ArrayRef that contains the input data. 
+ * + * @internal + * + * @note The data is either stored in the + * the @c ArrayRef or point to the internal stored data, + * the actual behavior is controlled by the 'copy' argument (see below). + * + * @todo what's the copy' argument mentioned here? + * + * @endinternal + * + */ + virtual ArrayRef getInputData(const std::string &inputName) const; - typedef std::map OutputMap; - typedef std::map InputMap; + /** + * Get the output data. + * + * @param outputName + * The name of the target output + * + * @returns + * An @c ArrayRef that contains the output data. + * + * @internal + * + * @note The data is either stored in the + * the @c ArrayRef or point to the internal stored data, + * the actual behavior is controlled by the 'copy' argument (see below). + * + * @todo what's the copy' argument mentioned here? + * + * @endinternal + * + */ + virtual ArrayRef getOutputData(const std::string &outputName) const; - OutputMap outputs_; - InputMap inputs_; - // used for serialization only - std::set phases_; - Dimensions dims_; // topology of nodes; starts as [] - bool initialized_; + /** + * Get the count of input data. + * + * @param inputName + * The name of the target input + * + * @returns + * The count of input data + * + * @todo are getOutput/InputCount needed? count can be obtained from the array + * objects. + * + */ + virtual size_t getInputCount(const std::string &inputName) const; - NodeSet* enabledNodes_; + /** + * Get the count of output data. + * + * @param outputName + * The name of the target output + * + * @returns + * The count of output data + * + * @todo are getOutput/InputCount needed? count can be obtained from the array + * objects. + * + */ + virtual size_t getOutputCount(const std::string &outputName) const; - // Region contains a backpointer to network_ only to be able - // to retrieve the containing network via getNetwork() for inspectors. - // The implementation should not use network_ in any other methods. - Network* network_; + /** + * @} + * + * @name Operations + * + * @{ + * + */ - // Figuring out how a region's dimensions were set - // can be difficult because any link can induce - // dimensions. This field says how a region's dimensions - // were set. - std::string dimensionInfo_; + /** + * @todo Region::enable() not implemented, should it be part of API at all? + */ + virtual void enable(); - // private helper methods - void setupEnabledNodeSet(); + /** + * @todo Region::disable() not implemented, should it be part of API at all? + */ + virtual void disable(); + + /** + * Request the underlying region to execute a command. + * + * @param args + * A list of strings that the actual region will interpret. + * The first string is the command name. The other arguments are + * optional. + * + * @returns + * The result value of command execution is a string determined + * by the underlying region. + */ + virtual std::string executeCommand(const std::vector &args); + + /** + * Perform one step of the region computation. + */ + void compute(); + + /** + * @} + * + * @name Profiling + * + * @{ + * + */ + + /** + * Enable profiling of the compute and execute operations + */ + void enableProfiling(); + + /** + * Disable profiling of the compute and execute operations + */ + void disableProfiling(); + + /** + * Reset the compute and execute timers + */ + void resetProfiling(); + + /** + * Get the timer used to profile the compute operation. 
+ * + * @returns + * The Timer object used to profile the compute operation + */ + const Timer &getComputeTimer() const; + + /** + * Get the timer used to profile the execute operation. + * + * @returns + * The Timer object used to profile the execute operation + */ + const Timer &getExecuteTimer() const; + + /** + * @} + */ + +#ifdef NTA_INTERNAL + // Internal methods. + + // New region from parameter spec + Region(std::string name, const std::string &type, + const std::string &nodeParams, Network *network = nullptr); + + // New region from serialized state + Region(std::string name, const std::string &type, + const Dimensions &dimensions, BundleIO &bundle, + Network *network = nullptr); + + // New region from capnp struct + Region(std::string name, RegionProto::Reader &proto, + Network *network = nullptr); + + virtual ~Region(); + + void initialize(); + + bool isInitialized() const; + + // Used by RegionImpl to get inputs/outputs + Output *getOutput(const std::string &name) const; + + Input *getInput(const std::string &name) const; + + // These are used only for serialization + const std::map &getInputs() const; + + const std::map &getOutputs() const; + + // The following methods are called by Network in initialization + + // Returns number of links that could not be fully evaluated + size_t evaluateLinks(); + + std::string getLinkErrors() const; + size_t getNodeOutputElementCount(const std::string &name); + + void initOutputs(); + + void initInputs() const; + + void intialize(); + + // Internal -- for link debugging + void setDimensionInfo(const std::string &info); + + const std::string &getDimensionInfo() const; + + bool hasOutgoingLinks() const; + + // These methods are needed for teardown choreography + // in Network::~Network() + // It is an error to call any region methods after uninitialize() + // except removeAllIncomingLinks and ~Region + void uninitialize(); + + void removeAllIncomingLinks(); + + const NodeSet &getEnabledNodes() const; + + // TODO: sort our phases api. Users should never call Region::setPhases + // and it is here for serialization only. + void setPhases(std::set &phases); + + std::set &getPhases(); + + // Called by Network for serialization + void serializeImpl(BundleIO &bundle); + + using Serializable::write; + void write(RegionProto::Builder &proto) const; + + using Serializable::read; + void read(RegionProto::Reader &proto); + +#endif // NTA_INTERNAL - // Profiling related methods and variables. - bool profilingEnabled_; - Timer computeTimer_; - Timer executeTimer_; - }; +private: + // verboten + Region(); + Region(Region &); + + // common method used by both constructors + // Can be called after nodespec_ has been set. + void createInputsAndOutputs_(); + + const std::string name_; + + // pointer to the "plugin"; owned by Region + RegionImpl *impl_; + const std::string type_; + Spec *spec_; + + typedef std::map OutputMap; + typedef std::map InputMap; + + OutputMap outputs_; + InputMap inputs_; + // used for serialization only + std::set phases_; + Dimensions dims_; // topology of nodes; starts as [] + bool initialized_; + + NodeSet *enabledNodes_; + + // Region contains a backpointer to network_ only to be able + // to retrieve the containing network via getNetwork() for inspectors. + // The implementation should not use network_ in any other methods. + Network *network_; + + // Figuring out how a region's dimensions were set + // can be difficult because any link can induce + // dimensions. This field says how a region's dimensions + // were set. 
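When no link induces them, dimensions can be supplied explicitly before initialization, and dimensionInfo_ then records that fact. A small sketch, assuming r is a Region* whose dimensions are still unspecified and that the Region and Dimensions headers are already included (the axis sizes are arbitrary):

    nupic::Dimensions dims;
    dims.push_back(4);      // nodes along the first axis
    dims.push_back(2);      // nodes along the second axis
    r->setDimensions(dims); // throws if the region already has dimensions
    // In internal builds, getDimensionInfo() should now report
    // "Specified explicitly in setDimensions()".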
+ std::string dimensionInfo_; + + // private helper methods + void setupEnabledNodeSet(); + + // Profiling related methods and variables. + bool profilingEnabled_; + Timer computeTimer_; + Timer executeTimer_; +}; } // namespace nupic diff --git a/src/nupic/engine/RegionImpl.cpp b/src/nupic/engine/RegionImpl.cpp index b38057bd3d..7c6ae892f3 100644 --- a/src/nupic/engine/RegionImpl.cpp +++ b/src/nupic/engine/RegionImpl.cpp @@ -37,35 +37,21 @@ #include #include -namespace nupic -{ +namespace nupic { -RegionImpl::RegionImpl(Region *region) : - region_(region) -{ -} +RegionImpl::RegionImpl(Region *region) : region_(region) {} -RegionImpl::~RegionImpl() -{ -} +RegionImpl::~RegionImpl() {} // convenience method -const std::string& RegionImpl::getType() const -{ - return region_->getType(); -} +const std::string &RegionImpl::getType() const { return region_->getType(); } -const std::string& RegionImpl::getName() const -{ - return region_->getName(); -} +const std::string &RegionImpl::getName() const { return region_->getName(); } -const NodeSet& RegionImpl::getEnabledNodes() const -{ +const NodeSet &RegionImpl::getEnabledNodes() const { return region_->getEnabledNodes(); } - /* ------------- Parameter support --------------- */ // By default, all typed getParameter calls forward to the // untyped getParameter that serializes to a buffer @@ -75,165 +61,155 @@ const NodeSet& RegionImpl::getEnabledNodes() const // templated methods can't be virtual and thus can't be // overridden by subclasses. -#define getParameterInternalT(MethodT,Type) \ -Type RegionImpl::getParameter##MethodT(const std::string& name, Int64 index) \ -{\ - if (! region_->getSpec()->parameters.contains(name)) \ - NTA_THROW << "getParameter" #Type ": parameter " << name << " does not exist in nodespec"; \ - ParameterSpec p = region_->getSpec()->parameters.getByName(name); \ - if (p.dataType != NTA_BasicType_ ## MethodT) \ - NTA_THROW << "getParameter" #Type ": parameter " << name << " is of type " \ - << BasicType::getName(p.dataType) << " not " #Type; \ - WriteBuffer wb; \ - getParameterFromBuffer(name, index, wb); \ - ReadBuffer rb(wb.getData(), wb.getSize(), false /* copy */); \ - Type val; \ - int rc = rb.read(val); \ - if (rc != 0) \ - { \ - NTA_THROW << "getParameter" #Type " -- failure to get parameter '" \ - << name << "' on node of type " << getType(); \ - } \ - return val; \ -} +#define getParameterInternalT(MethodT, Type) \ + Type RegionImpl::getParameter##MethodT(const std::string &name, \ + Int64 index) { \ + if (!region_->getSpec()->parameters.contains(name)) \ + NTA_THROW << "getParameter" #Type ": parameter " << name \ + << " does not exist in nodespec"; \ + ParameterSpec p = region_->getSpec()->parameters.getByName(name); \ + if (p.dataType != NTA_BasicType_##MethodT) \ + NTA_THROW << "getParameter" #Type ": parameter " << name \ + << " is of type " << BasicType::getName(p.dataType) \ + << " not " #Type; \ + WriteBuffer wb; \ + getParameterFromBuffer(name, index, wb); \ + ReadBuffer rb(wb.getData(), wb.getSize(), false /* copy */); \ + Type val; \ + int rc = rb.read(val); \ + if (rc != 0) { \ + NTA_THROW << "getParameter" #Type " -- failure to get parameter '" \ + << name << "' on node of type " << getType(); \ + } \ + return val; \ + } -#define getParameterT(Type) getParameterInternalT(Type,Type) +#define getParameterT(Type) getParameterInternalT(Type, Type) getParameterT(Int32); getParameterT(UInt32); getParameterT(Int64); -getParameterT(UInt64) -getParameterT(Real32); +getParameterT(UInt64) getParameterT(Real32); 
getParameterT(Real64); getParameterInternalT(Bool, bool); +#define setParameterInternalT(MethodT, Type) \ + void RegionImpl::setParameter##MethodT(const std::string &name, Int64 index, \ + Type value) { \ + WriteBuffer wb; \ + wb.write((Type)value); \ + ReadBuffer rb(wb.getData(), wb.getSize(), false /* copy */); \ + setParameterFromBuffer(name, index, rb); \ + } - -#define setParameterInternalT(MethodT, Type) \ -void RegionImpl::setParameter##MethodT(const std::string& name, Int64 index, Type value) \ -{ \ - WriteBuffer wb; \ - wb.write((Type)value); \ - ReadBuffer rb(wb.getData(), wb.getSize(), false /* copy */); \ - setParameterFromBuffer(name, index, rb); \ -} - -#define setParameterT(Type) setParameterInternalT(Type,Type) +#define setParameterT(Type) setParameterInternalT(Type, Type) setParameterT(Int32); setParameterT(UInt32); setParameterT(Int64); -setParameterT(UInt64) -setParameterT(Real32); +setParameterT(UInt64) setParameterT(Real32); setParameterT(Real64); setParameterInternalT(Bool, bool); -// buffer mechanism can't handle Handles. RegionImpl must override these methods. -Handle RegionImpl::getParameterHandle(const std::string& name, Int64 index) -{ - NTA_THROW << "Unknown parameter '" << name << "' of type Handle."; +// buffer mechanism can't handle Handles. RegionImpl must override these +// methods. +Handle RegionImpl::getParameterHandle(const std::string &name, Int64 index) { + NTA_THROW << "Unknown parameter '" << name << "' of type Handle."; } -void RegionImpl::setParameterHandle(const std::string& name, Int64 index, Handle h) -{ - NTA_THROW << "Unknown parameter '" << name << "' of type Handle."; +void RegionImpl::setParameterHandle(const std::string &name, Int64 index, + Handle h) { + NTA_THROW << "Unknown parameter '" << name << "' of type Handle."; } - - -void RegionImpl::getParameterArray(const std::string& name, Int64 index, Array & array) -{ +void RegionImpl::getParameterArray(const std::string &name, Int64 index, + Array &array) { WriteBuffer wb; getParameterFromBuffer(name, index, wb); ReadBuffer rb(wb.getData(), wb.getSize(), false /* copy */); size_t count = array.getCount(); void *buffer = array.getBuffer(); - for (size_t i = 0; i < count; i++) - { + for (size_t i = 0; i < count; i++) { int rc; - switch (array.getType()) - { + switch (array.getType()) { case NTA_BasicType_Byte: - rc = rb.read(((Byte*)buffer)[i]); + rc = rb.read(((Byte *)buffer)[i]); break; case NTA_BasicType_Int32: - rc = rb.read(((Int32*)buffer)[i]); + rc = rb.read(((Int32 *)buffer)[i]); break; case NTA_BasicType_UInt32: - rc = rb.read(((UInt32*)buffer)[i]); + rc = rb.read(((UInt32 *)buffer)[i]); break; case NTA_BasicType_Int64: - rc = rb.read(((Int64*)buffer)[i]); + rc = rb.read(((Int64 *)buffer)[i]); break; case NTA_BasicType_UInt64: - rc = rb.read(((UInt64*)buffer)[i]); + rc = rb.read(((UInt64 *)buffer)[i]); break; case NTA_BasicType_Real32: - rc = rb.read(((Real32*)buffer)[i]); + rc = rb.read(((Real32 *)buffer)[i]); break; case NTA_BasicType_Real64: - rc = rb.read(((Real64*)buffer)[i]); + rc = rb.read(((Real64 *)buffer)[i]); break; default: - NTA_THROW << "Unsupported basic type " << BasicType::getName(array.getType()) + NTA_THROW << "Unsupported basic type " + << BasicType::getName(array.getType()) << " in getParameterArray for parameter " << name; break; } - if (rc != 0) - { - NTA_THROW << "getParameterArray -- failure to get parameter '" - << name << "' on node of type " << getType(); + if (rc != 0) { + NTA_THROW << "getParameterArray -- failure to get parameter '" << name + << "' on 
node of type " << getType(); } } return; } - -void RegionImpl::setParameterArray(const std::string& name, Int64 index,const Array & array) -{ +void RegionImpl::setParameterArray(const std::string &name, Int64 index, + const Array &array) { WriteBuffer wb; size_t count = array.getCount(); void *buffer = array.getBuffer(); - for (size_t i = 0; i < count; i++) - { + for (size_t i = 0; i < count; i++) { int rc; - switch (array.getType()) - { + switch (array.getType()) { case NTA_BasicType_Byte: - rc = wb.write(((Byte*)buffer)[i]); + rc = wb.write(((Byte *)buffer)[i]); break; case NTA_BasicType_Int32: - rc = wb.write(((Int32*)buffer)[i]); + rc = wb.write(((Int32 *)buffer)[i]); break; case NTA_BasicType_UInt32: - rc = wb.write(((UInt32*)buffer)[i]); + rc = wb.write(((UInt32 *)buffer)[i]); break; case NTA_BasicType_Int64: - rc = wb.write(((Int64*)buffer)[i]); + rc = wb.write(((Int64 *)buffer)[i]); break; case NTA_BasicType_UInt64: - rc = wb.write(((UInt64*)buffer)[i]); + rc = wb.write(((UInt64 *)buffer)[i]); break; case NTA_BasicType_Real32: - rc = wb.write(((Real32*)buffer)[i]); + rc = wb.write(((Real32 *)buffer)[i]); break; case NTA_BasicType_Real64: - rc = wb.write(((Real64*)buffer)[i]); + rc = wb.write(((Real64 *)buffer)[i]); break; default: - NTA_THROW << "Unsupported basic type " << BasicType::getName(array.getType()) + NTA_THROW << "Unsupported basic type " + << BasicType::getName(array.getType()) << " in setParameterArray for parameter " << name; break; } - if (rc != 0) - { - NTA_THROW << "setParameterArray - failure to set parameter '" << name << - "' on node of type " << getType(); + if (rc != 0) { + NTA_THROW << "setParameterArray - failure to set parameter '" << name + << "' on node of type " << getType(); } } @@ -241,59 +217,50 @@ void RegionImpl::setParameterArray(const std::string& name, Int64 index,const Ar setParameterFromBuffer(name, index, rb); } - -void RegionImpl::setParameterString(const std::string& name, Int64 index, const std::string& s) -{ +void RegionImpl::setParameterString(const std::string &name, Int64 index, + const std::string &s) { ReadBuffer rb(s.c_str(), s.size(), false); setParameterFromBuffer(name, index, rb); } -std::string RegionImpl::getParameterString(const std::string& name, Int64 index) -{ +std::string RegionImpl::getParameterString(const std::string &name, + Int64 index) { WriteBuffer wb; getParameterFromBuffer(name, index, wb); return std::string(wb.getData(), wb.getSize()); } - // Must be overridden by subclasses -bool RegionImpl::isParameterShared(const std::string& name) -{ - NTA_THROW << "RegionImpl::isParameterShared was not overridden in node type " << getType(); +bool RegionImpl::isParameterShared(const std::string &name) { + NTA_THROW << "RegionImpl::isParameterShared was not overridden in node type " + << getType(); } -void RegionImpl::getParameterFromBuffer(const std::string& name, - Int64 index, - IWriteBuffer& value) -{ - NTA_THROW << "RegionImpl::getParameterFromBuffer must be overridden by subclasses"; +void RegionImpl::getParameterFromBuffer(const std::string &name, Int64 index, + IWriteBuffer &value) { + NTA_THROW + << "RegionImpl::getParameterFromBuffer must be overridden by subclasses"; } -void RegionImpl::setParameterFromBuffer(const std::string& name, - Int64 index, - IReadBuffer& value) -{ - NTA_THROW << "RegionImpl::setParameterFromBuffer must be overridden by subclasses"; +void RegionImpl::setParameterFromBuffer(const std::string &name, Int64 index, + IReadBuffer &value) { + NTA_THROW + << "RegionImpl::setParameterFromBuffer must 
be overridden by subclasses"; } - - -size_t RegionImpl::getParameterArrayCount(const std::string& name, Int64 index) -{ +size_t RegionImpl::getParameterArrayCount(const std::string &name, + Int64 index) { // Default implementation for RegionImpls with no array parameters // that have a dynamic length. - //std::map::iterator i = nodespec_->parameters.find(name); - //if (i == nodespec_->parameters.end()) - + // std::map::iterator i = + // nodespec_->parameters.find(name); if (i == nodespec_->parameters.end()) - if (!region_->getSpec()->parameters.contains(name)) - { - NTA_THROW << "getParameterArrayCount -- no parameter named '" - << name << "' in node of type " << getType(); + if (!region_->getSpec()->parameters.contains(name)) { + NTA_THROW << "getParameterArrayCount -- no parameter named '" << name + << "' in node of type " << getType(); } UInt32 count = region_->getSpec()->parameters.getByName(name).count; - if (count == 0) - { + if (count == 0) { NTA_THROW << "Internal Error -- unknown element count for " << "node type " << getType() << ". The RegionImpl " << "implementation should override this method."; @@ -302,22 +269,18 @@ size_t RegionImpl::getParameterArrayCount(const std::string& name, Int64 index) return count; } - // Provide data access for subclasses -const Input* RegionImpl::getInput(const std::string& name) -{ +const Input *RegionImpl::getInput(const std::string &name) { return region_->getInput(name); } -const Output* RegionImpl::getOutput(const std::string& name) -{ +const Output *RegionImpl::getOutput(const std::string &name) { return region_->getOutput(name); } -const Dimensions& RegionImpl::getDimensions() -{ +const Dimensions &RegionImpl::getDimensions() { return region_->getDimensions(); } -} +} // namespace nupic diff --git a/src/nupic/engine/RegionImpl.hpp b/src/nupic/engine/RegionImpl.hpp index c684cc71e7..5a8f32a5bc 100644 --- a/src/nupic/engine/RegionImpl.hpp +++ b/src/nupic/engine/RegionImpl.hpp @@ -43,193 +43,187 @@ #include // IWriteBuffer #include -namespace nupic -{ - - struct Spec; - class Region; - class Dimensions; - class Input; - class Output; - class Array; - class ArrayRef; - class NodeSet; - class BundleIO; - - class RegionImpl : public Serializable - { - public: - - - // All subclasses must call this constructor from their regular constructor - RegionImpl(Region* region); - - virtual ~RegionImpl(); - - - - /* ------- Convenience methods that access region data -------- */ - - const std::string& getType() const; - - const std::string& getName() const; - - const NodeSet& getEnabledNodes() const; - - /* ------- Parameter support in the base class. ---------*/ - // The default implementation of all of these methods goes through - // set/getParameterFromBuffer, which is compatible with NuPIC 1. - // RegionImpl subclasses may override for higher performance. 
- - virtual Int32 getParameterInt32(const std::string& name, Int64 index); - virtual UInt32 getParameterUInt32(const std::string& name, Int64 index); - virtual Int64 getParameterInt64(const std::string& name, Int64 index); - virtual UInt64 getParameterUInt64(const std::string& name, Int64 index); - virtual Real32 getParameterReal32(const std::string& name, Int64 index); - virtual Real64 getParameterReal64(const std::string& name, Int64 index); - virtual Handle getParameterHandle(const std::string& name, Int64 index); - virtual bool getParameterBool(const std::string& name, Int64 index); - - virtual void setParameterInt32(const std::string& name, Int64 index, Int32 value); - virtual void setParameterUInt32(const std::string& name, Int64 index, UInt32 value); - virtual void setParameterInt64(const std::string& name, Int64 index, Int64 value); - virtual void setParameterUInt64(const std::string& name, Int64 index, UInt64 value); - virtual void setParameterReal32(const std::string& name, Int64 index, Real32 value); - virtual void setParameterReal64(const std::string& name, Int64 index, Real64 value); - virtual void setParameterHandle(const std::string& name, Int64 index, Handle value); - virtual void setParameterBool(const std::string& name, Int64 index, bool value); - - virtual void getParameterArray(const std::string& name, Int64 index, Array & array); - virtual void setParameterArray(const std::string& name, Int64 index, const Array & array); - - virtual void setParameterString(const std::string& name, Int64 index, const std::string& s); - virtual std::string getParameterString(const std::string& name, Int64 index); - - - /* -------- Methods that must be implemented by subclasses -------- */ - - /** - * Can't declare a static method in an interface. But RegionFactory - * expects to find this method. Caller gets ownership. - */ - - //static Spec* createSpec(); - - // Serialize state. - virtual void serialize(BundleIO& bundle) = 0; - - // De-serialize state. Must be called from deserializing constructor - virtual void deserialize(BundleIO& bundle) = 0; - - // Serialize state with capnp - using Serializable::write; - virtual void write(capnp::AnyPointer::Builder& anyProto) const = 0; - - // Deserialize state from capnp. Must be called from deserializing - // constructor. - using Serializable::read; - virtual void read(capnp::AnyPointer::Reader& anyProto) = 0; - - /** - * Inputs/Outputs are made available in initialize() - * It is always called after the constructor (or load from serialized state) - */ - virtual void initialize() = 0; - - // Compute outputs from inputs and internal state - virtual void compute() = 0; - - - - // Execute a command - virtual std::string executeCommand(const std::vector& args, Int64 index) = 0; - - - // Per-node size (in elements) of the given output. - // For per-region outputs, it is the total element count. - // This method is called only for outputs whose size is not - // specified in the nodespec. - virtual size_t getNodeOutputElementCount(const std::string& outputName) = 0; - - /** - * Get a parameter from a write buffer. - * This method is called only by the typed getParameter* - * methods in the RegionImpl base class - * - * Must be implemented by all subclasses. - * - * @param index A node index. (-1) indicates a region-level parameter - * - */ - virtual void getParameterFromBuffer(const std::string& name, - Int64 index, - IWriteBuffer& value) = 0; - - /** - * Set a parameter from a read buffer. 
- * This method is called only by the RegionImpl base class - * type-specific setParameter* methods - * Must be implemented by all subclasses. - * - * @param index A node index. (-1) indicates a region-level parameter - */ - virtual void setParameterFromBuffer(const std::string& name, - Int64 index, - IReadBuffer& value) = 0; - - - - /* -------- Methods that may be overridden by subclasses -------- */ - - - /** - * Array-valued parameters may have a size determined at runtime. - * This method returns the number of elements in the named parameter. - * If parameter is not an array type, may throw an exception or return 1. - * - * Must be implemented only if the node has one or more array - * parameters with a dynamically-determined length. - */ - virtual size_t getParameterArrayCount(const std::string& name, Int64 index); - - - /** - * isParameterShared must be available after construction - * Default implementation -- all parameters are shared - * Tests whether a parameter is node or region level - */ - virtual bool isParameterShared(const std::string& name); - - - - protected: - Region* region_; - - /* -------- Methods provided by the base class for use by subclasses -------- */ - - // --- - /// Callback for subclasses to get an output stream during serialize() - /// (for output) and the deserializing constructor (for input) - /// It is invalid to call this method except inside serialize() in a subclass. - /// - /// Only one serialization stream may be open at a time. Calling - /// getSerializationXStream a second time automatically closes the - /// first stream. Any open stream is closed when serialize() returns. - // --- - std::ostream& getSerializationOutputStream(const std::string& name); - std::istream& getSerializationInputStream(const std::string& name); - std::string getSerializationPath(const std::string& name); - - // These methods provide access to inputs and outputs - // They raise an exception if the named input or output is - // not found. - const Input* getInput(const std::string& name); - const Output* getOutput(const std::string& name); - - const Dimensions& getDimensions(); - - }; - -} +namespace nupic { + +struct Spec; +class Region; +class Dimensions; +class Input; +class Output; +class Array; +class ArrayRef; +class NodeSet; +class BundleIO; + +class RegionImpl : public Serializable { +public: + // All subclasses must call this constructor from their regular constructor + RegionImpl(Region *region); + + virtual ~RegionImpl(); + + /* ------- Convenience methods that access region data -------- */ + + const std::string &getType() const; + + const std::string &getName() const; + + const NodeSet &getEnabledNodes() const; + + /* ------- Parameter support in the base class. ---------*/ + // The default implementation of all of these methods goes through + // set/getParameterFromBuffer, which is compatible with NuPIC 1. + // RegionImpl subclasses may override for higher performance. 
+ + virtual Int32 getParameterInt32(const std::string &name, Int64 index); + virtual UInt32 getParameterUInt32(const std::string &name, Int64 index); + virtual Int64 getParameterInt64(const std::string &name, Int64 index); + virtual UInt64 getParameterUInt64(const std::string &name, Int64 index); + virtual Real32 getParameterReal32(const std::string &name, Int64 index); + virtual Real64 getParameterReal64(const std::string &name, Int64 index); + virtual Handle getParameterHandle(const std::string &name, Int64 index); + virtual bool getParameterBool(const std::string &name, Int64 index); + + virtual void setParameterInt32(const std::string &name, Int64 index, + Int32 value); + virtual void setParameterUInt32(const std::string &name, Int64 index, + UInt32 value); + virtual void setParameterInt64(const std::string &name, Int64 index, + Int64 value); + virtual void setParameterUInt64(const std::string &name, Int64 index, + UInt64 value); + virtual void setParameterReal32(const std::string &name, Int64 index, + Real32 value); + virtual void setParameterReal64(const std::string &name, Int64 index, + Real64 value); + virtual void setParameterHandle(const std::string &name, Int64 index, + Handle value); + virtual void setParameterBool(const std::string &name, Int64 index, + bool value); + + virtual void getParameterArray(const std::string &name, Int64 index, + Array &array); + virtual void setParameterArray(const std::string &name, Int64 index, + const Array &array); + + virtual void setParameterString(const std::string &name, Int64 index, + const std::string &s); + virtual std::string getParameterString(const std::string &name, Int64 index); + + /* -------- Methods that must be implemented by subclasses -------- */ + + /** + * Can't declare a static method in an interface. But RegionFactory + * expects to find this method. Caller gets ownership. + */ + + // static Spec* createSpec(); + + // Serialize state. + virtual void serialize(BundleIO &bundle) = 0; + + // De-serialize state. Must be called from deserializing constructor + virtual void deserialize(BundleIO &bundle) = 0; + + // Serialize state with capnp + using Serializable::write; + virtual void write(capnp::AnyPointer::Builder &anyProto) const = 0; + + // Deserialize state from capnp. Must be called from deserializing + // constructor. + using Serializable::read; + virtual void read(capnp::AnyPointer::Reader &anyProto) = 0; + + /** + * Inputs/Outputs are made available in initialize() + * It is always called after the constructor (or load from serialized state) + */ + virtual void initialize() = 0; + + // Compute outputs from inputs and internal state + virtual void compute() = 0; + + // Execute a command + virtual std::string executeCommand(const std::vector &args, + Int64 index) = 0; + + // Per-node size (in elements) of the given output. + // For per-region outputs, it is the total element count. + // This method is called only for outputs whose size is not + // specified in the nodespec. + virtual size_t getNodeOutputElementCount(const std::string &outputName) = 0; + + /** + * Get a parameter from a write buffer. + * This method is called only by the typed getParameter* + * methods in the RegionImpl base class + * + * Must be implemented by all subclasses. + * + * @param index A node index. (-1) indicates a region-level parameter + * + */ + virtual void getParameterFromBuffer(const std::string &name, Int64 index, + IWriteBuffer &value) = 0; + + /** + * Set a parameter from a read buffer. 
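When the buffer round trip behind these defaults shows up in profiles, a plugin can short-circuit it by overriding the typed accessor directly. A sketch, assuming a hypothetical DemoRegion subclass with a Real32 member gain_ backing a parameter named "gain" (all of these names are illustrative):

    // Answers "gain" straight from a member and defers every other parameter
    // to the buffer-based default implementation in RegionImpl.
    nupic::Real32 DemoRegion::getParameterReal32(const std::string &name,
                                                 nupic::Int64 index) {
      if (name == "gain")
        return gain_;
      return nupic::RegionImpl::getParameterReal32(name, index);
    }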
+ * This method is called only by the RegionImpl base class + * type-specific setParameter* methods + * Must be implemented by all subclasses. + * + * @param index A node index. (-1) indicates a region-level parameter + */ + virtual void setParameterFromBuffer(const std::string &name, Int64 index, + IReadBuffer &value) = 0; + + /* -------- Methods that may be overridden by subclasses -------- */ + + /** + * Array-valued parameters may have a size determined at runtime. + * This method returns the number of elements in the named parameter. + * If parameter is not an array type, may throw an exception or return 1. + * + * Must be implemented only if the node has one or more array + * parameters with a dynamically-determined length. + */ + virtual size_t getParameterArrayCount(const std::string &name, Int64 index); + + /** + * isParameterShared must be available after construction + * Default implementation -- all parameters are shared + * Tests whether a parameter is node or region level + */ + virtual bool isParameterShared(const std::string &name); + +protected: + Region *region_; + + /* -------- Methods provided by the base class for use by subclasses -------- + */ + + // --- + /// Callback for subclasses to get an output stream during serialize() + /// (for output) and the deserializing constructor (for input) + /// It is invalid to call this method except inside serialize() in a subclass. + /// + /// Only one serialization stream may be open at a time. Calling + /// getSerializationXStream a second time automatically closes the + /// first stream. Any open stream is closed when serialize() returns. + // --- + std::ostream &getSerializationOutputStream(const std::string &name); + std::istream &getSerializationInputStream(const std::string &name); + std::string getSerializationPath(const std::string &name); + + // These methods provide access to inputs and outputs + // They raise an exception if the named input or output is + // not found. + const Input *getInput(const std::string &name); + const Output *getOutput(const std::string &name); + + const Dimensions &getDimensions(); +}; + +} // namespace nupic #endif // NTA_REGION_IMPL_HPP diff --git a/src/nupic/engine/RegionImplFactory.cpp b/src/nupic/engine/RegionImplFactory.cpp index 6d1b104da5..87fd9b4c41 100644 --- a/src/nupic/engine/RegionImplFactory.cpp +++ b/src/nupic/engine/RegionImplFactory.cpp @@ -24,20 +24,20 @@ #include -#include -#include +#include #include +#include +#include #include #include +#include +#include +#include +#include #include -#include -#include #include -#include -#include -#include -#include -#include +#include +#include #include #include #include @@ -45,199 +45,168 @@ #include // from http://stackoverflow.com/a/9096509/1781435 -#define stringify(x) #x +#define stringify(x) #x #define expand_and_stringify(x) stringify(x) -namespace nupic -{ - // Keys are Python modules and the values are sets of class names for the - // regions that have been registered in the corresponding module. E.g. 
- // pyRegions["nupic.regions.sample_region"] = "SampleRegion" - static std::map> pyRegions; - - // Mappings for C++ regions - static std::map cppRegions; - - bool initializedRegions = false; - - // Allows the user to add custom regions - void RegionImplFactory::registerPyRegion(const std::string module, const std::string className) - { - // Verify that no regions exist with the same className in any module - for (auto pyRegion : pyRegions) - { - if (pyRegion.second.find(className) != pyRegion.second.end()) - { - if (pyRegion.first != module) - { - // New region class name conflicts with previously registered region - NTA_THROW << "A pyRegion with name '" << className << "' already exists. " - << "Unregister the existing region or register the new region using a " - << "different name."; - } else { - // Same region registered again, ignore - return; - } +namespace nupic { +// Keys are Python modules and the values are sets of class names for the +// regions that have been registered in the corresponding module. E.g. +// pyRegions["nupic.regions.sample_region"] = "SampleRegion" +static std::map> pyRegions; + +// Mappings for C++ regions +static std::map cppRegions; + +bool initializedRegions = false; + +// Allows the user to add custom regions +void RegionImplFactory::registerPyRegion(const std::string module, + const std::string className) { + // Verify that no regions exist with the same className in any module + for (auto pyRegion : pyRegions) { + if (pyRegion.second.find(className) != pyRegion.second.end()) { + if (pyRegion.first != module) { + // New region class name conflicts with previously registered region + NTA_THROW << "A pyRegion with name '" << className + << "' already exists. " + << "Unregister the existing region or register the new " + "region using a " + << "different name."; + } else { + // Same region registered again, ignore + return; } } - - // Module hasn't been added yet - if (pyRegions.find(module) == pyRegions.end()) - { - pyRegions[module] = std::set(); - } - - pyRegions[module].insert(className); } - void RegionImplFactory::registerCPPRegion(const std::string name, - GenericRegisteredRegionImpl * wrapper) - { - if (cppRegions.find(name) != cppRegions.end()) - { - NTA_WARN << "A CPPRegion already exists with the name '" - << name << "'. Overwriting it..."; - } - cppRegions[name] = wrapper; + // Module hasn't been added yet + if (pyRegions.find(module) == pyRegions.end()) { + pyRegions[module] = std::set(); } - void RegionImplFactory::unregisterPyRegion(const std::string className) - { - for (auto pyRegion : pyRegions) - { - if (pyRegion.second.find(className) != pyRegion.second.end()) - { - pyRegions.erase(pyRegion.first); - return; - } - } - NTA_WARN << "A pyRegion with name '" << className << - "' doesn't exist. Nothing to unregister..."; + pyRegions[module].insert(className); +} + +void RegionImplFactory::registerCPPRegion( + const std::string name, GenericRegisteredRegionImpl *wrapper) { + if (cppRegions.find(name) != cppRegions.end()) { + NTA_WARN << "A CPPRegion already exists with the name '" << name + << "'. 
Overwriting it..."; } + cppRegions[name] = wrapper; +} - void RegionImplFactory::unregisterCPPRegion(const std::string name) - { - if (cppRegions.find(name) != cppRegions.end()) - { - cppRegions.erase(name); +void RegionImplFactory::unregisterPyRegion(const std::string className) { + for (auto pyRegion : pyRegions) { + if (pyRegion.second.find(className) != pyRegion.second.end()) { + pyRegions.erase(pyRegion.first); return; } } + NTA_WARN << "A pyRegion with name '" << className + << "' doesn't exist. Nothing to unregister..."; +} - class DynamicPythonLibrary - { - typedef void (*initPythonFunc)(); - typedef void (*finalizePythonFunc)(); - typedef void * (*createSpecFunc)(const char *, void **, const char *); - typedef int (*destroySpecFunc)(const char *, const char *); - typedef void * (*createPyNodeFunc)(const char *, void *, void *, void **, const char *); - typedef void * (*deserializePyNodeFunc)(const char *, void *, void *, void *, const char *); - typedef void * (*deserializePyNodeProtoFunc)(const char *, void *, void *, void *, const char *); - public: - DynamicPythonLibrary() : - initPython_(nullptr), - finalizePython_(nullptr), - createSpec_(nullptr), - destroySpec_(nullptr), - createPyNode_(nullptr) - { - initPython_ = (initPythonFunc) PyRegion::NTA_initPython; - finalizePython_ = (finalizePythonFunc) PyRegion::NTA_finalizePython; - createPyNode_ = (createPyNodeFunc) PyRegion::NTA_createPyNode; - deserializePyNode_ = (deserializePyNodeFunc) PyRegion::NTA_deserializePyNode; - deserializePyNodeProto_ = (deserializePyNodeProtoFunc) PyRegion::NTA_deserializePyNodeProto; - createSpec_ = (createSpecFunc) PyRegion::NTA_createSpec; - destroySpec_ = (destroySpecFunc) PyRegion::NTA_destroySpec; - - (*initPython_)(); - } - - ~DynamicPythonLibrary() - { - if (finalizePython_) - finalizePython_(); - } +void RegionImplFactory::unregisterCPPRegion(const std::string name) { + if (cppRegions.find(name) != cppRegions.end()) { + cppRegions.erase(name); + return; + } +} - void * createSpec(std::string nodeType, void ** exception, std::string className) - { - return (*createSpec_)(nodeType.c_str(), exception, className.c_str()); - } +class DynamicPythonLibrary { + typedef void (*initPythonFunc)(); + typedef void (*finalizePythonFunc)(); + typedef void *(*createSpecFunc)(const char *, void **, const char *); + typedef int (*destroySpecFunc)(const char *, const char *); + typedef void *(*createPyNodeFunc)(const char *, void *, void *, void **, + const char *); + typedef void *(*deserializePyNodeFunc)(const char *, void *, void *, void *, + const char *); + typedef void *(*deserializePyNodeProtoFunc)(const char *, void *, void *, + void *, const char *); + +public: + DynamicPythonLibrary() + : initPython_(nullptr), finalizePython_(nullptr), createSpec_(nullptr), + destroySpec_(nullptr), createPyNode_(nullptr) { + initPython_ = (initPythonFunc)PyRegion::NTA_initPython; + finalizePython_ = (finalizePythonFunc)PyRegion::NTA_finalizePython; + createPyNode_ = (createPyNodeFunc)PyRegion::NTA_createPyNode; + deserializePyNode_ = (deserializePyNodeFunc)PyRegion::NTA_deserializePyNode; + deserializePyNodeProto_ = + (deserializePyNodeProtoFunc)PyRegion::NTA_deserializePyNodeProto; + createSpec_ = (createSpecFunc)PyRegion::NTA_createSpec; + destroySpec_ = (destroySpecFunc)PyRegion::NTA_destroySpec; + + (*initPython_)(); + } - int destroySpec(std::string nodeType, std::string& className) - { - NTA_INFO << "destroySpec(" << nodeType << ")"; - return (*destroySpec_)(nodeType.c_str(), className.c_str()); - } + 
~DynamicPythonLibrary() { + if (finalizePython_) + finalizePython_(); + } - void * createPyNode(const std::string& nodeType, - ValueMap * nodeParams, - Region * region, - void ** exception, - const std::string& className) - { - return (*createPyNode_)(nodeType.c_str(), - reinterpret_cast(nodeParams), - reinterpret_cast(region), - exception, - className.c_str()); + void *createSpec(std::string nodeType, void **exception, + std::string className) { + return (*createSpec_)(nodeType.c_str(), exception, className.c_str()); + } - } + int destroySpec(std::string nodeType, std::string &className) { + NTA_INFO << "destroySpec(" << nodeType << ")"; + return (*destroySpec_)(nodeType.c_str(), className.c_str()); + } - void * deserializePyNode(const std::string& nodeType, - BundleIO* bundle, - Region * region, - void ** exception, - const std::string& className) - { - return (*deserializePyNode_)(nodeType.c_str(), - reinterpret_cast(bundle), - reinterpret_cast(region), - exception, - className.c_str()); - } + void *createPyNode(const std::string &nodeType, ValueMap *nodeParams, + Region *region, void **exception, + const std::string &className) { + return (*createPyNode_)( + nodeType.c_str(), reinterpret_cast(nodeParams), + reinterpret_cast(region), exception, className.c_str()); + } - void * deserializePyNodeProto(const std::string& nodeType, - capnp::AnyPointer::Reader* proto, - Region * region, - void ** exception, - const std::string& className) - { - return (*deserializePyNodeProto_)(nodeType.c_str(), - reinterpret_cast(proto), - reinterpret_cast(region), - exception, - className.c_str()); - } + void *deserializePyNode(const std::string &nodeType, BundleIO *bundle, + Region *region, void **exception, + const std::string &className) { + return (*deserializePyNode_)( + nodeType.c_str(), reinterpret_cast(bundle), + reinterpret_cast(region), exception, className.c_str()); + } - const std::string& getRootDir() const - { - return rootDir_; - } + void *deserializePyNodeProto(const std::string &nodeType, + capnp::AnyPointer::Reader *proto, Region *region, + void **exception, const std::string &className) { + return (*deserializePyNodeProto_)( + nodeType.c_str(), reinterpret_cast(proto), + reinterpret_cast(region), exception, className.c_str()); + } - private: - std::string rootDir_; - boost::shared_ptr pynodeLibrary_; - initPythonFunc initPython_; - finalizePythonFunc finalizePython_; - createSpecFunc createSpec_; - destroySpecFunc destroySpec_; - createPyNodeFunc createPyNode_; - deserializePyNodeFunc deserializePyNode_; - deserializePyNodeProtoFunc deserializePyNodeProto_; - }; - -RegionImplFactory & RegionImplFactory::getInstance() -{ + const std::string &getRootDir() const { return rootDir_; } + +private: + std::string rootDir_; + boost::shared_ptr pynodeLibrary_; + initPythonFunc initPython_; + finalizePythonFunc finalizePython_; + createSpecFunc createSpec_; + destroySpecFunc destroySpec_; + createPyNodeFunc createPyNode_; + deserializePyNodeFunc deserializePyNode_; + deserializePyNodeProtoFunc deserializePyNodeProto_; +}; + +RegionImplFactory &RegionImplFactory::getInstance() { static RegionImplFactory instance; // Initialize Regions - if (!initializedRegions) - { + if (!initializedRegions) { // Create C++ regions cppRegions["ScalarSensor"] = new RegisteredRegionImpl(); cppRegions["TestNode"] = new RegisteredRegionImpl(); - cppRegions["VectorFileEffector"] = new RegisteredRegionImpl(); - cppRegions["VectorFileSensor"] = new RegisteredRegionImpl(); + cppRegions["VectorFileEffector"] = + new 
RegisteredRegionImpl(); + cppRegions["VectorFileSensor"] = + new RegisteredRegionImpl(); initializedRegions = true; } @@ -246,196 +215,160 @@ RegionImplFactory & RegionImplFactory::getInstance() } // This function creates either a NuPIC 2 or NuPIC 1 Python node -static RegionImpl * createPyNode(DynamicPythonLibrary * pyLib, - const std::string & nodeType, - ValueMap * nodeParams, - Region * region) -{ +static RegionImpl *createPyNode(DynamicPythonLibrary *pyLib, + const std::string &nodeType, + ValueMap *nodeParams, Region *region) { std::string className(nodeType.c_str() + 3); - for (auto pyr=pyRegions.begin(); pyr!=pyRegions.end(); pyr++) - { + for (auto pyr = pyRegions.begin(); pyr != pyRegions.end(); pyr++) { const std::string module = pyr->first; std::set classes = pyr->second; // This module contains the class - if (classes.find(className) != classes.end()) - { - void * exception = nullptr; - void * node = pyLib->createPyNode(module, nodeParams, region, &exception, className); - if (node) - { - return static_cast(node); + if (classes.find(className) != classes.end()) { + void *exception = nullptr; + void *node = pyLib->createPyNode(module, nodeParams, region, &exception, + className); + if (node) { + return static_cast(node); } } } - NTA_THROW << "Unable to create region " << region->getName() << " of type " << className; + NTA_THROW << "Unable to create region " << region->getName() << " of type " + << className; return nullptr; } // This function deserializes either a NuPIC 2 or NuPIC 1 Python node -static RegionImpl * deserializePyNode(DynamicPythonLibrary * pyLib, - const std::string & nodeType, - BundleIO & bundle, - Region * region) -{ +static RegionImpl *deserializePyNode(DynamicPythonLibrary *pyLib, + const std::string &nodeType, + BundleIO &bundle, Region *region) { std::string className(nodeType.c_str() + 3); - for (auto pyr=pyRegions.begin(); pyr!=pyRegions.end(); pyr++) - { + for (auto pyr = pyRegions.begin(); pyr != pyRegions.end(); pyr++) { const std::string module = pyr->first; std::set classes = pyr->second; // This module contains the class - if (classes.find(className) != classes.end()) - { - void * exception = nullptr; - void * node = pyLib->deserializePyNode(module, &bundle, region, &exception, className); - if (node) - { - return static_cast(node); + if (classes.find(className) != classes.end()) { + void *exception = nullptr; + void *node = pyLib->deserializePyNode(module, &bundle, region, &exception, + className); + if (node) { + return static_cast(node); } } } - NTA_THROW << "Unable to deserialize region " << region->getName() << " of type " << className; + NTA_THROW << "Unable to deserialize region " << region->getName() + << " of type " << className; return nullptr; - - - } -static RegionImpl * deserializePyNode(DynamicPythonLibrary * pyLib, - const std::string & nodeType, - capnp::AnyPointer::Reader& proto, - Region * region) -{ +static RegionImpl *deserializePyNode(DynamicPythonLibrary *pyLib, + const std::string &nodeType, + capnp::AnyPointer::Reader &proto, + Region *region) { std::string className(nodeType.c_str() + 3); - for (auto pyr=pyRegions.begin(); pyr!=pyRegions.end(); pyr++) - { + for (auto pyr = pyRegions.begin(); pyr != pyRegions.end(); pyr++) { const std::string module = pyr->first; std::set classes = pyr->second; // This module contains the class - if (classes.find(className) != classes.end()) - { - void * exception = nullptr; - void * node = pyLib->deserializePyNodeProto(module, &proto, region, - &exception, className); - if (node) - { - 
return static_cast(node); + if (classes.find(className) != classes.end()) { + void *exception = nullptr; + void *node = pyLib->deserializePyNodeProto(module, &proto, region, + &exception, className); + if (node) { + return static_cast(node); } } } - NTA_THROW << "Unable to deserialize region " << region->getName() << - " of type " << className; + NTA_THROW << "Unable to deserialize region " << region->getName() + << " of type " << className; return nullptr; } -RegionImpl* RegionImplFactory::createRegionImpl(const std::string nodeType, +RegionImpl *RegionImplFactory::createRegionImpl(const std::string nodeType, const std::string nodeParams, - Region* region) -{ + Region *region) { RegionImpl *impl = nullptr; Spec *ns = getSpec(nodeType); - ValueMap vm = YAMLUtils::toValueMap( - nodeParams.c_str(), - ns->parameters, - nodeType, - region->getName()); - - if (cppRegions.find(nodeType) != cppRegions.end()) - { + ValueMap vm = YAMLUtils::toValueMap(nodeParams.c_str(), ns->parameters, + nodeType, region->getName()); + + if (cppRegions.find(nodeType) != cppRegions.end()) { impl = cppRegions[nodeType]->createRegionImpl(vm, region); - } - else if ((nodeType.find(std::string("py.")) == 0)) - { + } else if ((nodeType.find(std::string("py.")) == 0)) { if (!pyLib_) - pyLib_ = boost::shared_ptr(new DynamicPythonLibrary()); + pyLib_ = + boost::shared_ptr(new DynamicPythonLibrary()); impl = createPyNode(pyLib_.get(), nodeType, &vm, region); - } else - { + } else { NTA_THROW << "Unsupported node type '" << nodeType << "'"; } return impl; - } -RegionImpl* RegionImplFactory::deserializeRegionImpl(const std::string nodeType, - BundleIO& bundle, - Region* region) -{ +RegionImpl *RegionImplFactory::deserializeRegionImpl(const std::string nodeType, + BundleIO &bundle, + Region *region) { RegionImpl *impl = nullptr; - if (cppRegions.find(nodeType) != cppRegions.end()) - { + if (cppRegions.find(nodeType) != cppRegions.end()) { impl = cppRegions[nodeType]->deserializeRegionImpl(bundle, region); - } - else if (StringUtils::startsWith(nodeType, "py.")) - { + } else if (StringUtils::startsWith(nodeType, "py.")) { if (!pyLib_) - pyLib_ = boost::shared_ptr(new DynamicPythonLibrary()); + pyLib_ = + boost::shared_ptr(new DynamicPythonLibrary()); impl = deserializePyNode(pyLib_.get(), nodeType, bundle, region); - } else - { + } else { NTA_THROW << "Unsupported node type '" << nodeType << "'"; } return impl; - } -RegionImpl* RegionImplFactory::deserializeRegionImpl( - const std::string nodeType, - capnp::AnyPointer::Reader& proto, - Region* region) -{ +RegionImpl * +RegionImplFactory::deserializeRegionImpl(const std::string nodeType, + capnp::AnyPointer::Reader &proto, + Region *region) { RegionImpl *impl = nullptr; - if (cppRegions.find(nodeType) != cppRegions.end()) - { + if (cppRegions.find(nodeType) != cppRegions.end()) { impl = cppRegions[nodeType]->deserializeRegionImpl(proto, region); - } - else if (StringUtils::startsWith(nodeType, "py.")) - { + } else if (StringUtils::startsWith(nodeType, "py.")) { if (!pyLib_) - pyLib_ = boost::shared_ptr(new DynamicPythonLibrary()); + pyLib_ = + boost::shared_ptr(new DynamicPythonLibrary()); impl = deserializePyNode(pyLib_.get(), nodeType, proto, region); - } - else - { + } else { NTA_THROW << "Unsupported node type '" << nodeType << "'"; } return impl; } // This function returns the node spec of a NuPIC 2 or NuPIC 1 Python node -static Spec * getPySpec(DynamicPythonLibrary * pyLib, - const std::string & nodeType) -{ +static Spec *getPySpec(DynamicPythonLibrary *pyLib, + const 
std::string &nodeType) { std::string className(nodeType.c_str() + 3); - for (auto pyr=pyRegions.begin(); pyr!=pyRegions.end(); pyr++) - { + for (auto pyr = pyRegions.begin(); pyr != pyRegions.end(); pyr++) { const std::string module = pyr->first; std::set classes = pyr->second; // This module contains the class - if (classes.find(className) != classes.end()) - { - void * exception = nullptr; - void * ns = pyLib->createSpec(module, &exception, className); - if (exception != nullptr) - { - throw *((Exception*)exception); + if (classes.find(className) != classes.end()) { + void *exception = nullptr; + void *ns = pyLib->createSpec(module, &exception, className); + if (exception != nullptr) { + throw *((Exception *)exception); } - if (ns) - { + if (ns) { return (Spec *)ns; } } @@ -444,9 +377,8 @@ static Spec * getPySpec(DynamicPythonLibrary * pyLib, NTA_THROW << "Matching Python module for " << className << " not found."; } -Spec * RegionImplFactory::getSpec(const std::string nodeType) -{ - std::map::iterator it; +Spec *RegionImplFactory::getSpec(const std::string nodeType) { + std::map::iterator it; // return from cache if we already have it it = nodespecCache_.find(nodeType); if (it != nodespecCache_.end()) @@ -454,20 +386,16 @@ Spec * RegionImplFactory::getSpec(const std::string nodeType) // grab the nodespec and cache it // one entry per supported node type - Spec * ns = nullptr; - if (cppRegions.find(nodeType) != cppRegions.end()) - { + Spec *ns = nullptr; + if (cppRegions.find(nodeType) != cppRegions.end()) { ns = cppRegions[nodeType]->createSpec(); - } - else if (nodeType.find(std::string("py.")) == 0) - { + } else if (nodeType.find(std::string("py.")) == 0) { if (!pyLib_) - pyLib_ = boost::shared_ptr(new DynamicPythonLibrary()); + pyLib_ = + boost::shared_ptr(new DynamicPythonLibrary()); ns = getPySpec(pyLib_.get(), nodeType); - } - else - { + } else { NTA_THROW << "getSpec() -- Unsupported node type '" << nodeType << "'"; } @@ -478,21 +406,16 @@ Spec * RegionImplFactory::getSpec(const std::string nodeType) return ns; } -void RegionImplFactory::cleanup() -{ - std::map::iterator ns; +void RegionImplFactory::cleanup() { + std::map::iterator ns; // destroy all nodespecs - for (ns = nodespecCache_.begin(); ns != nodespecCache_.end(); ns++) - { + for (ns = nodespecCache_.begin(); ns != nodespecCache_.end(); ns++) { assert(ns->second != nullptr); // PyNode node specs are destroyed by the C++ PyNode - if (ns->first.substr(0, 3) == "py.") - { + if (ns->first.substr(0, 3) == "py.") { std::string noClass = ""; pyLib_->destroySpec(ns->first, noClass); - } - else - { + } else { delete ns->second; } @@ -502,8 +425,7 @@ void RegionImplFactory::cleanup() nodespecCache_.clear(); // destroy all RegisteredRegionImpls - for (auto rri = cppRegions.begin(); rri != cppRegions.end(); rri++) - { + for (auto rri = cppRegions.begin(); rri != cppRegions.end(); rri++) { NTA_ASSERT(rri->second != nullptr); delete rri->second; rri->second = nullptr; @@ -515,7 +437,7 @@ void RegionImplFactory::cleanup() // Never release the Python dynamic library! 
// This is due to cleanup issues of Python itself // See: http://docs.python.org/c-api/init.html#Py_Finalize - //pyLib_.reset(); + // pyLib_.reset(); } -} +} // namespace nupic diff --git a/src/nupic/engine/RegionImplFactory.hpp b/src/nupic/engine/RegionImplFactory.hpp index ae61cb5bf1..caf2764862 100644 --- a/src/nupic/engine/RegionImplFactory.hpp +++ b/src/nupic/engine/RegionImplFactory.hpp @@ -37,81 +37,83 @@ #include #include + +// Workaround windows.h collision: +// https://github.com/sandstorm-io/capnproto/issues/213 +#undef VOID #include -namespace nupic -{ +namespace nupic { - class RegionImpl; - class Region; - class DynamicPythonLibrary; - struct Spec; - class BundleIO; - class ValueMap; - class GenericRegisteredRegionImpl; +class RegionImpl; +class Region; +class DynamicPythonLibrary; +struct Spec; +class BundleIO; +class ValueMap; +class GenericRegisteredRegionImpl; - class RegionImplFactory - { - public: - static RegionImplFactory & getInstance(); +class RegionImplFactory { +public: + static RegionImplFactory &getInstance(); - // RegionImplFactory is a lightweight object - ~RegionImplFactory() {}; + // RegionImplFactory is a lightweight object + ~RegionImplFactory(){}; - // Create a RegionImpl of a specific type; caller gets ownership. - RegionImpl* createRegionImpl(const std::string nodeType, - const std::string nodeParams, - Region* region); + // Create a RegionImpl of a specific type; caller gets ownership. + RegionImpl *createRegionImpl(const std::string nodeType, + const std::string nodeParams, Region *region); - // Create a RegionImpl from serialized state; caller gets ownership. - RegionImpl* deserializeRegionImpl(const std::string nodeType, - BundleIO& bundle, - Region* region); + // Create a RegionImpl from serialized state; caller gets ownership. + RegionImpl *deserializeRegionImpl(const std::string nodeType, + BundleIO &bundle, Region *region); - // Create a RegionImpl from capnp proto; caller gets ownership. - RegionImpl* deserializeRegionImpl(const std::string nodeType, - capnp::AnyPointer::Reader& proto, - Region* region); + // Create a RegionImpl from capnp proto; caller gets ownership. + RegionImpl *deserializeRegionImpl(const std::string nodeType, + capnp::AnyPointer::Reader &proto, + Region *region); - // Returns nodespec for a specific node type; Factory retains ownership. - Spec* getSpec(const std::string nodeType); + // Returns nodespec for a specific node type; Factory retains ownership. + Spec *getSpec(const std::string nodeType); - // RegionImplFactory caches nodespecs and the dynamic library reference - // This frees up the cached information. - // Should be called only if there are no outstanding - // nodespec references (e.g. in NuPIC shutdown) or pynodes. - void cleanup(); + // RegionImplFactory caches nodespecs and the dynamic library reference + // This frees up the cached information. + // Should be called only if there are no outstanding + // nodespec references (e.g. in NuPIC shutdown) or pynodes. 
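  // -----------------------------------------------------------------------
  // Editor's illustration -- not part of this header. A rough sketch of how
  // a custom C++ region type might be registered and then inspected through
  // the factory singleton declared in this class. "MyRegion" is a
  // hypothetical RegionImpl subclass; the template argument of
  // RegisteredRegionImpl (the concrete region class) follows its declaration
  // later in this patch, although the angle brackets appear stripped in this
  // rendering of the diff.
  //
  //   #include <nupic/engine/RegionImplFactory.hpp>
  //   #include <nupic/engine/RegisteredRegionImpl.hpp>
  //
  //   void registerAndInspect() {
  //     using namespace nupic;
  //
  //     // Make "MyRegion" creatable by node type name; the factory keeps
  //     // the wrapper and deletes it in cleanup().
  //     RegionImplFactory::registerCPPRegion(
  //         "MyRegion", new RegisteredRegionImpl<MyRegion>());
  //
  //     // Python regions are registered by module and class name instead:
  //     // RegionImplFactory::registerPyRegion("my_pkg.my_region", "MyPyRegion");
  //
  //     RegionImplFactory &factory = RegionImplFactory::getInstance();
  //     Spec *spec = factory.getSpec("MyRegion"); // factory retains ownership
  //     std::cout << spec->toString() << std::endl;
  //   }
  // -----------------------------------------------------------------------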
+ void cleanup(); - static void registerPyRegionPackage(const char * path); + static void registerPyRegionPackage(const char *path); - // Allows the user to load custom Python regions - static void registerPyRegion(const std::string module, const std::string className); + // Allows the user to load custom Python regions + static void registerPyRegion(const std::string module, + const std::string className); - // Allows the user to load custom C++ regions - static void registerCPPRegion(const std::string name, GenericRegisteredRegionImpl * wrapper); + // Allows the user to load custom C++ regions + static void registerCPPRegion(const std::string name, + GenericRegisteredRegionImpl *wrapper); - // Allows the user to unregister Python regions - static void unregisterPyRegion(const std::string className); + // Allows the user to unregister Python regions + static void unregisterPyRegion(const std::string className); - // Allows the user to unregister C++ regions - static void unregisterCPPRegion(const std::string name); + // Allows the user to unregister C++ regions + static void unregisterCPPRegion(const std::string name); - private: - RegionImplFactory() {}; - RegionImplFactory(const RegionImplFactory &); +private: + RegionImplFactory(){}; + RegionImplFactory(const RegionImplFactory &); - // TODO: implement locking for thread safety for this global data structure - // TODO: implement cleanup + // TODO: implement locking for thread safety for this global data structure + // TODO: implement cleanup - // getSpec returns references to nodespecs in this cache. - // should not be cleaned up until those references have disappeared. - std::map nodespecCache_; + // getSpec returns references to nodespecs in this cache. + // should not be cleaned up until those references have disappeared. + std::map nodespecCache_; - // Using shared_ptr here to ensure the dynamic python library object - // is deleted when the factory goes away. Can't use scoped_ptr - // because it is not initialized in the constructor. - boost::shared_ptr pyLib_; - }; -} + // Using shared_ptr here to ensure the dynamic python library object + // is deleted when the factory goes away. Can't use scoped_ptr + // because it is not initialized in the constructor. + boost::shared_ptr pyLib_; +}; +} // namespace nupic #endif // NTA_REGION_IMPL_FACTORY_HPP diff --git a/src/nupic/engine/RegionIo.cpp b/src/nupic/engine/RegionIo.cpp index d48e836df3..92b5845ffb 100644 --- a/src/nupic/engine/RegionIo.cpp +++ b/src/nupic/engine/RegionIo.cpp @@ -20,117 +20,91 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Implementation of Region methods related to inputs and outputs */ +#include +#include #include #include -#include -#include -#include #include #include #include +#include -namespace nupic -{ - - - +namespace nupic { -// Internal methods called by RegionImpl. +// Internal methods called by RegionImpl. 
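// ---------------------------------------------------------------------------
// Editor's illustration -- not part of this patch. Sketch of how the
// accessors defined below are typically used from caller code once a
// Network has been built. The region pointer and the "bottomUpOut" output
// name are assumptions for the example (TestNode later in this patch does
// expose a bottomUpOut_ output of Real64 elements).
//
//   void printFirstOutputElement(nupic::Region *r) {
//     // getOutputData() returns an ArrayRef that aliases the output buffer,
//     // so no copy is made; getOutputCount() reports the element count.
//     nupic::ArrayRef out = r->getOutputData("bottomUpOut");
//     size_t n = r->getOutputCount("bottomUpOut");
//     if (n > 0) {
//       const nupic::Real64 *buf =
//           static_cast<nupic::Real64 *>(out.getBuffer());
//       std::cout << "bottomUpOut[0] = " << buf[0] << std::endl;
//     }
//   }
// ---------------------------------------------------------------------------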
-Output* Region::getOutput(const std::string& name) const -{ +Output *Region::getOutput(const std::string &name) const { auto o = outputs_.find(name); if (o == outputs_.end()) return nullptr; return o->second; } - -Input* Region::getInput(const std::string& name) const -{ +Input *Region::getInput(const std::string &name) const { auto i = inputs_.find(name); if (i == inputs_.end()) return nullptr; return i->second; } - // Called by Network during serialization -const std::map& -Region::getInputs() const -{ +const std::map &Region::getInputs() const { return inputs_; } -const std::map& -Region::getOutputs() const -{ +const std::map &Region::getOutputs() const { return outputs_; } -size_t -Region::getOutputCount(const std::string& outputName) const -{ +size_t Region::getOutputCount(const std::string &outputName) const { auto oi = outputs_.find(outputName); if (oi == outputs_.end()) - NTA_THROW << "getOutputSize -- unknown output '" << outputName << "' on region " << getName(); + NTA_THROW << "getOutputSize -- unknown output '" << outputName + << "' on region " << getName(); return oi->second->getData().getCount(); } - -size_t -Region::getInputCount(const std::string& inputName) const -{ +size_t Region::getInputCount(const std::string &inputName) const { auto ii = inputs_.find(inputName); if (ii == inputs_.end()) - NTA_THROW << "getInputSize -- unknown input '" << inputName << "' on region " << getName(); + NTA_THROW << "getInputSize -- unknown input '" << inputName + << "' on region " << getName(); return ii->second->getData().getCount(); } - -ArrayRef -Region::getOutputData(const std::string& outputName) const -{ +ArrayRef Region::getOutputData(const std::string &outputName) const { auto oi = outputs_.find(outputName); if (oi == outputs_.end()) - NTA_THROW << "getOutputData -- unknown output '" << outputName << "' on region " << getName(); + NTA_THROW << "getOutputData -- unknown output '" << outputName + << "' on region " << getName(); - const Array & data = oi->second->getData(); + const Array &data = oi->second->getData(); ArrayRef a(data.getType()); a.setBuffer(data.getBuffer(), data.getCount()); return a; } -ArrayRef -Region::getInputData(const std::string& inputName) const -{ +ArrayRef Region::getInputData(const std::string &inputName) const { auto ii = inputs_.find(inputName); if (ii == inputs_.end()) - NTA_THROW << "getInput -- unknown input '" << inputName << "' on region " << getName(); + NTA_THROW << "getInput -- unknown input '" << inputName << "' on region " + << getName(); - const Array & data = ii->second->getData(); + const Array &data = ii->second->getData(); ArrayRef a(data.getType()); a.setBuffer(data.getBuffer(), data.getCount()); return a; } -void -Region::prepareInputs() -{ +void Region::prepareInputs() { // Ask each input to prepare itself - for (InputMap::const_iterator i = inputs_.begin(); - i != inputs_.end(); i++) - { + for (InputMap::const_iterator i = inputs_.begin(); i != inputs_.end(); i++) { i->second->prepare(); } - } - -} - - +} // namespace nupic diff --git a/src/nupic/engine/RegionParameters.cpp b/src/nupic/engine/RegionParameters.cpp index 3a3e2596a9..1a0f760f04 100644 --- a/src/nupic/engine/RegionParameters.cpp +++ b/src/nupic/engine/RegionParameters.cpp @@ -20,160 +20,123 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Implementation of Region methods related to parameters */ #include #include #include -#include #include #include +#include -namespace nupic -{ - +namespace nupic { // setParameter 
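// ---------------------------------------------------------------------------
// Editor's illustration -- not part of this patch. Typical caller-side use
// of the forwarding methods below; region-level access passes index -1 down
// to the RegionImpl. The region pointer is assumed to refer to a "TestNode"
// region, whose "int32Param" and "int64ArrayParam" parameters are defined
// later in this patch.
//
//   void exerciseParameters(nupic::Region *r) {
//     r->setParameterInt32("int32Param", 42);
//     nupic::Int32 v = r->getParameterInt32("int32Param"); // 42
//
//     // For array parameters, getParameterArray() allocates the buffer
//     // when the caller passes an Array with no buffer attached.
//     nupic::Array a(NTA_BasicType_Int64);
//     r->getParameterArray("int64ArrayParam", a);
//     size_t n = a.getCount(); // element count reported by the impl
//   }
// ---------------------------------------------------------------------------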
-void Region::setParameterInt32(const std::string& name, Int32 value) -{ +void Region::setParameterInt32(const std::string &name, Int32 value) { impl_->setParameterInt32(name, (Int64)-1, value); } -void Region::setParameterUInt32(const std::string& name, UInt32 value) -{ +void Region::setParameterUInt32(const std::string &name, UInt32 value) { impl_->setParameterUInt32(name, (Int64)-1, value); } -void Region::setParameterInt64(const std::string& name, Int64 value) -{ +void Region::setParameterInt64(const std::string &name, Int64 value) { impl_->setParameterInt64(name, (Int64)-1, value); } -void Region::setParameterUInt64(const std::string& name, UInt64 value) -{ +void Region::setParameterUInt64(const std::string &name, UInt64 value) { impl_->setParameterUInt64(name, (Int64)-1, value); } -void Region::setParameterReal32(const std::string& name, Real32 value) -{ +void Region::setParameterReal32(const std::string &name, Real32 value) { impl_->setParameterReal32(name, (Int64)-1, value); } -void Region::setParameterReal64(const std::string& name, Real64 value) -{ +void Region::setParameterReal64(const std::string &name, Real64 value) { impl_->setParameterReal64(name, (Int64)-1, value); } -void Region::setParameterHandle(const std::string& name, Handle value) -{ +void Region::setParameterHandle(const std::string &name, Handle value) { impl_->setParameterHandle(name, (Int64)-1, value); } -void Region::setParameterBool(const std::string& name, bool value) -{ +void Region::setParameterBool(const std::string &name, bool value) { impl_->setParameterBool(name, (Int64)-1, value); } - // getParameter -Int32 Region::getParameterInt32(const std::string& name) const -{ +Int32 Region::getParameterInt32(const std::string &name) const { return impl_->getParameterInt32(name, (Int64)-1); } -Int64 Region::getParameterInt64(const std::string& name) const -{ +Int64 Region::getParameterInt64(const std::string &name) const { return impl_->getParameterInt64(name, (Int64)-1); } -UInt32 Region::getParameterUInt32(const std::string& name) const -{ +UInt32 Region::getParameterUInt32(const std::string &name) const { return impl_->getParameterUInt32(name, (Int64)-1); } - -UInt64 Region::getParameterUInt64(const std::string& name) const -{ +UInt64 Region::getParameterUInt64(const std::string &name) const { return impl_->getParameterUInt64(name, (Int64)-1); } -Real32 Region::getParameterReal32(const std::string& name) const -{ +Real32 Region::getParameterReal32(const std::string &name) const { return impl_->getParameterReal32(name, (Int64)-1); } -Real64 Region::getParameterReal64(const std::string& name) const -{ +Real64 Region::getParameterReal64(const std::string &name) const { return impl_->getParameterReal64(name, (Int64)-1); } -Handle Region::getParameterHandle(const std::string& name) const -{ +Handle Region::getParameterHandle(const std::string &name) const { return impl_->getParameterHandle(name, (Int64)-1); } -bool Region::getParameterBool(const std::string& name) const -{ +bool Region::getParameterBool(const std::string &name) const { return impl_->getParameterBool(name, (Int64)-1); } - // array parameters - -void -Region::getParameterArray(const std::string& name, Array & array) const -{ +void Region::getParameterArray(const std::string &name, Array &array) const { size_t count = impl_->getParameterArrayCount(name, (Int64)(-1)); // Make sure we have a buffer to put the data in - if (array.getBuffer() != nullptr) - { + if (array.getBuffer() != nullptr) { // Buffer has already been allocated. 
Make sure it is big enough if (array.getCount() > count) NTA_THROW << "getParameterArray -- supplied buffer for parameter " << name - << " can hold " << array.getCount() - << " elements but parameter count is " - << count; + << " can hold " << array.getCount() + << " elements but parameter count is " << count; } else { array.allocateBuffer(count); } impl_->getParameterArray(name, (Int64)-1, array); - } - -void -Region::setParameterArray(const std::string& name, const Array & array) -{ +void Region::setParameterArray(const std::string &name, const Array &array) { // We do not check the array size here because it would be - // expensive -- involving a check against the nodespec, + // expensive -- involving a check against the nodespec, // and only usable in the rare case that the nodespec specified - // a fixed size. Instead, the implementation can check the size. + // a fixed size. Instead, the implementation can check the size. impl_->setParameterArray(name, (Int64)-1, array); } -void -Region::setParameterString(const std::string& name, const std::string& s) -{ +void Region::setParameterString(const std::string &name, const std::string &s) { impl_->setParameterString(name, (Int64)-1, s); } - -std::string -Region::getParameterString(const std::string& name) -{ + +std::string Region::getParameterString(const std::string &name) { return impl_->getParameterString(name, (Int64)-1); } -bool -Region::isParameterShared(const std::string& name) const -{ +bool Region::isParameterShared(const std::string &name) const { return impl_->isParameterShared(name); } -} - +} // namespace nupic diff --git a/src/nupic/engine/RegisteredRegionImpl.hpp b/src/nupic/engine/RegisteredRegionImpl.hpp index d24abf0362..31c3bb27fe 100644 --- a/src/nupic/engine/RegisteredRegionImpl.hpp +++ b/src/nupic/engine/RegisteredRegionImpl.hpp @@ -32,63 +32,56 @@ #include -namespace nupic -{ - struct Spec; - class BundleIO; - class RegionImpl; - class Region; - class ValueMap; - - class GenericRegisteredRegionImpl { - public: - GenericRegisteredRegionImpl() {} - - virtual ~GenericRegisteredRegionImpl() {} - - virtual RegionImpl* createRegionImpl( - const ValueMap& params, Region *region) = 0; - - virtual RegionImpl* deserializeRegionImpl( - BundleIO& params, Region *region) = 0; - - virtual RegionImpl* deserializeRegionImpl( - capnp::AnyPointer::Reader& proto, Region *region) = 0; - - virtual Spec* createSpec() = 0; - }; - - template - class RegisteredRegionImpl: public GenericRegisteredRegionImpl { - public: - RegisteredRegionImpl() {} - - ~RegisteredRegionImpl() {} - - virtual RegionImpl* createRegionImpl( - const ValueMap& params, Region *region) override - { - return new T(params, region); - } - - virtual RegionImpl* deserializeRegionImpl( - BundleIO& params, Region *region) override - { - return new T(params, region); - } - - virtual RegionImpl* deserializeRegionImpl( - capnp::AnyPointer::Reader& proto, Region *region) override - { - return new T(proto, region); - } - - virtual Spec* createSpec() override - { - return T::createSpec(); - } - }; - -} +namespace nupic { +struct Spec; +class BundleIO; +class RegionImpl; +class Region; +class ValueMap; + +class GenericRegisteredRegionImpl { +public: + GenericRegisteredRegionImpl() {} + + virtual ~GenericRegisteredRegionImpl() {} + + virtual RegionImpl *createRegionImpl(const ValueMap ¶ms, + Region *region) = 0; + + virtual RegionImpl *deserializeRegionImpl(BundleIO ¶ms, + Region *region) = 0; + + virtual RegionImpl *deserializeRegionImpl(capnp::AnyPointer::Reader &proto, + Region 
*region) = 0; + + virtual Spec *createSpec() = 0; +}; + +template +class RegisteredRegionImpl : public GenericRegisteredRegionImpl { +public: + RegisteredRegionImpl() {} + + ~RegisteredRegionImpl() {} + + virtual RegionImpl *createRegionImpl(const ValueMap ¶ms, + Region *region) override { + return new T(params, region); + } + + virtual RegionImpl *deserializeRegionImpl(BundleIO ¶ms, + Region *region) override { + return new T(params, region); + } + + virtual RegionImpl *deserializeRegionImpl(capnp::AnyPointer::Reader &proto, + Region *region) override { + return new T(proto, region); + } + + virtual Spec *createSpec() override { return T::createSpec(); } +}; + +} // namespace nupic #endif // NTA_REGISTERED_REGION_IMPL_HPP diff --git a/src/nupic/engine/Spec.cpp b/src/nupic/engine/Spec.cpp index 89499381a9..7c2d03127c 100644 --- a/src/nupic/engine/Spec.cpp +++ b/src/nupic/engine/Spec.cpp @@ -20,27 +20,19 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file Implementation of Spec API */ #include -#include #include +#include -namespace nupic -{ - - - +namespace nupic { -Spec::Spec() : singleNodeOnly(false), description("") -{ -} +Spec::Spec() : singleNodeOnly(false), description("") {} - -std::string Spec::getDefaultInputName() const -{ +std::string Spec::getDefaultInputName() const { if (inputs.getCount() == 0) return ""; if (inputs.getCount() == 1) @@ -50,22 +42,21 @@ std::string Spec::getDefaultInputName() const bool found = false; std::string name; - for (size_t i = 0; i < inputs.getCount(); ++i) - { - const std::pair & p = inputs.getByIndex(i); - if (p.second.isDefaultInput) - { - NTA_CHECK(!found) << "Internal error -- multiply-defined default inputs in Spec"; + for (size_t i = 0; i < inputs.getCount(); ++i) { + const std::pair &p = inputs.getByIndex(i); + if (p.second.isDefaultInput) { + NTA_CHECK(!found) + << "Internal error -- multiply-defined default inputs in Spec"; found = true; name = p.first; } } - NTA_CHECK(found) << "Internal error -- multiple inputs in Spec but no default"; + NTA_CHECK(found) + << "Internal error -- multiple inputs in Spec but no default"; return name; } -std::string Spec::getDefaultOutputName() const -{ +std::string Spec::getDefaultOutputName() const { if (outputs.getCount() == 0) return ""; if (outputs.getCount() == 1) @@ -75,112 +66,88 @@ std::string Spec::getDefaultOutputName() const bool found = false; std::string name; - for (size_t i = 0; i < outputs.getCount(); ++i) - { - const std::pair & p = outputs.getByIndex(i); - if (p.second.isDefaultOutput) - { - NTA_CHECK(!found) << "Internal error -- multiply-defined default outputs in Spec"; + for (size_t i = 0; i < outputs.getCount(); ++i) { + const std::pair &p = outputs.getByIndex(i); + if (p.second.isDefaultOutput) { + NTA_CHECK(!found) + << "Internal error -- multiply-defined default outputs in Spec"; found = true; name = p.first; } } - NTA_CHECK(found) << "Internal error -- multiple outputs in Spec but no default"; + NTA_CHECK(found) + << "Internal error -- multiple outputs in Spec but no default"; return name; } -InputSpec::InputSpec(std::string description, - NTA_BasicType dataType, - UInt32 count, - bool required, - bool regionLevel, - bool isDefaultInput, - bool requireSplitterMap) : - description(std::move(description)), - dataType(dataType), - count(count), - required(required), - regionLevel(regionLevel), - isDefaultInput(isDefaultInput), - requireSplitterMap(requireSplitterMap) -{ -} - -OutputSpec::OutputSpec(std::string description, - 
NTA_BasicType dataType, - size_t count, - bool regionLevel, - bool isDefaultOutput) : - description(std::move(description)), dataType(dataType), count(count), - regionLevel(regionLevel), isDefaultOutput(isDefaultOutput) -{ -} - - -CommandSpec::CommandSpec(std::string description) : - description(std::move(description)) -{ -} - - -ParameterSpec::ParameterSpec(std::string description, - NTA_BasicType dataType, - size_t count, - std::string constraints, - std::string defaultValue, - AccessMode accessMode) : - description(std::move(description)), dataType(dataType), count(count), - constraints(std::move(constraints)), defaultValue(std::move(defaultValue)), - accessMode(accessMode) -{ +InputSpec::InputSpec(std::string description, NTA_BasicType dataType, + UInt32 count, bool required, bool regionLevel, + bool isDefaultInput, bool requireSplitterMap) + : description(std::move(description)), dataType(dataType), count(count), + required(required), regionLevel(regionLevel), + isDefaultInput(isDefaultInput), requireSplitterMap(requireSplitterMap) {} + +OutputSpec::OutputSpec(std::string description, NTA_BasicType dataType, + size_t count, bool regionLevel, bool isDefaultOutput) + : description(std::move(description)), dataType(dataType), count(count), + regionLevel(regionLevel), isDefaultOutput(isDefaultOutput) {} + +CommandSpec::CommandSpec(std::string description) + : description(std::move(description)) {} + +ParameterSpec::ParameterSpec(std::string description, NTA_BasicType dataType, + size_t count, std::string constraints, + std::string defaultValue, AccessMode accessMode) + : description(std::move(description)), dataType(dataType), count(count), + constraints(std::move(constraints)), + defaultValue(std::move(defaultValue)), accessMode(accessMode) { // Parameter of type byte is not supported; // Strings are specified as type byte, length = 0 if (dataType == NTA_BasicType_Byte && count > 0) NTA_THROW << "Parameters of type 'byte' are not supported"; } - - -std::string Spec::toString() const -{ - // TODO -- minimal information here; fill out with the rest of +std::string Spec::toString() const { + // TODO -- minimal information here; fill out with the rest of // the parameter spec information std::stringstream ss; - ss << "Spec:" << "\n"; - ss << "Description:" << "\n" - << this->description << "\n" << "\n"; - - ss << "Parameters:" << "\n"; - for (size_t i = 0; i < parameters.getCount(); ++i) - { - const std::pair& item = parameters.getByIndex(i); + ss << "Spec:" + << "\n"; + ss << "Description:" + << "\n" + << this->description << "\n" + << "\n"; + + ss << "Parameters:" + << "\n"; + for (size_t i = 0; i < parameters.getCount(); ++i) { + const std::pair &item = + parameters.getByIndex(i); ss << " " << item.first << "\n" << " description: " << item.second.description << "\n" << " type: " << BasicType::getName(item.second.dataType) << "\n" - << " count: " << item.second.count << "\n"; + << " count: " << item.second.count << "\n"; } - ss << "Inputs:" << "\n"; - for (size_t i = 0; i < inputs.getCount(); ++i) - { + ss << "Inputs:" + << "\n"; + for (size_t i = 0; i < inputs.getCount(); ++i) { ss << " " << inputs.getByIndex(i).first << "\n"; } - ss << "Outputs:" << "\n"; - for (size_t i = 0; i < outputs.getCount(); ++i) - { + ss << "Outputs:" + << "\n"; + for (size_t i = 0; i < outputs.getCount(); ++i) { ss << " " << outputs.getByIndex(i).first << "\n"; } - ss << "Commands:" << "\n"; - for (size_t i = 0; i < commands.getCount(); ++i) - { + ss << "Commands:" + << "\n"; + for (size_t i = 0; i < 
commands.getCount(); ++i) { ss << " " << commands.getByIndex(i).first << "\n"; } - - return ss.str(); -} - + return ss.str(); } +} // namespace nupic diff --git a/src/nupic/engine/Spec.hpp b/src/nupic/engine/Spec.hpp index 4cdb6d7b92..8620c85a23 100644 --- a/src/nupic/engine/Spec.hpp +++ b/src/nupic/engine/Spec.hpp @@ -20,135 +20,119 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file Definition of Spec data structures */ #ifndef NTA_SPEC_HPP #define NTA_SPEC_HPP -#include +#include #include +#include #include -#include -namespace nupic -{ - class InputSpec - { - public: - InputSpec() {} - InputSpec( - std::string description, - NTA_BasicType dataType, - UInt32 count, - bool required, - bool regionLevel, - bool isDefaultInput, - bool requireSplitterMap = true); - - std::string description; - NTA_BasicType dataType; - // TBD: Omit? isn't it always of unknown size? - // 1 = scalar; > 1 = array of fixed sized; 0 = array of unknown size - UInt32 count; - // TBD. Omit? what is "required"? Is it ok to be zero length? - bool required; - bool regionLevel; - bool isDefaultInput; - bool requireSplitterMap; - }; - - class OutputSpec - { - public: - OutputSpec() {} - OutputSpec(std::string description, const - NTA_BasicType dataType, size_t count, bool regionLevel, bool isDefaultOutput); - - std::string description; - NTA_BasicType dataType; - // Size, in number of elements. If size is fixed, specify it here. - // Value of 0 means it is determined dynamically - size_t count; - bool regionLevel; - bool isDefaultOutput; - }; - - class CommandSpec - { - public: - CommandSpec() {} - CommandSpec(std::string description); - - std::string description; - - }; - - class ParameterSpec - { - public: - typedef enum { CreateAccess, ReadOnlyAccess, ReadWriteAccess } AccessMode; - - ParameterSpec() {} - /** - * @param defaultValue -- a JSON-encoded value - */ - ParameterSpec(std::string description, - NTA_BasicType dataType, size_t count, - std::string constraints, std::string defaultValue, - AccessMode accessMode); - - - std::string description; - - // [open: current basic types are bytes/{u}int16/32/64, real32/64, BytePtr. Is this - // the right list? Should we have std::string, jsonstd::string?] - NTA_BasicType dataType; - // 1 = scalar; > 1 = array o fixed sized; 0 = array of unknown size - // TODO: should be size_t? Serialization issues? - size_t count; - std::string constraints; - std::string defaultValue; // JSON representation; empty std::string means parameter is required - AccessMode accessMode; - - }; - - - struct Spec - { - // Return a printable string with Spec information - // TODO: should this be in the base API or layered? In the API right - // now since we do not build layered libraries. - std::string toString() const; - - // Some RegionImpls support only a single node in a region. - // Such regions always have dimension [1] - bool singleNodeOnly; - - // Description of the node as a whole - std::string description; - - Collection inputs; - Collection outputs; - Collection commands; - Collection parameters; +namespace nupic { +class InputSpec { +public: + InputSpec() {} + InputSpec(std::string description, NTA_BasicType dataType, UInt32 count, + bool required, bool regionLevel, bool isDefaultInput, + bool requireSplitterMap = true); + + std::string description; + NTA_BasicType dataType; + // TBD: Omit? isn't it always of unknown size? + // 1 = scalar; > 1 = array of fixed sized; 0 = array of unknown size + UInt32 count; + // TBD. Omit? 
what is "required"? Is it ok to be zero length? + bool required; + bool regionLevel; + bool isDefaultInput; + bool requireSplitterMap; +}; + +class OutputSpec { +public: + OutputSpec() {} + OutputSpec(std::string description, const NTA_BasicType dataType, + size_t count, bool regionLevel, bool isDefaultOutput); + + std::string description; + NTA_BasicType dataType; + // Size, in number of elements. If size is fixed, specify it here. + // Value of 0 means it is determined dynamically + size_t count; + bool regionLevel; + bool isDefaultOutput; +}; + +class CommandSpec { +public: + CommandSpec() {} + CommandSpec(std::string description); + + std::string description; +}; + +class ParameterSpec { +public: + typedef enum { CreateAccess, ReadOnlyAccess, ReadWriteAccess } AccessMode; + + ParameterSpec() {} + /** + * @param defaultValue -- a JSON-encoded value + */ + ParameterSpec(std::string description, NTA_BasicType dataType, size_t count, + std::string constraints, std::string defaultValue, + AccessMode accessMode); + + std::string description; + + // [open: current basic types are bytes/{u}int16/32/64, real32/64, BytePtr. Is + // this the right list? Should we have std::string, jsonstd::string?] + NTA_BasicType dataType; + // 1 = scalar; > 1 = array o fixed sized; 0 = array of unknown size + // TODO: should be size_t? Serialization issues? + size_t count; + std::string constraints; + std::string defaultValue; // JSON representation; empty std::string means + // parameter is required + AccessMode accessMode; +}; + +struct Spec { + // Return a printable string with Spec information + // TODO: should this be in the base API or layered? In the API right + // now since we do not build layered libraries. + std::string toString() const; + + // Some RegionImpls support only a single node in a region. + // Such regions always have dimension [1] + bool singleNodeOnly; + + // Description of the node as a whole + std::string description; + + Collection inputs; + Collection outputs; + Collection commands; + Collection parameters; #ifdef NTA_INTERNAL - Spec(); + Spec(); - // TODO: decide whether/how to wrap these - std::string getDefaultOutputName() const; - std::string getDefaultInputName() const; + // TODO: decide whether/how to wrap these + std::string getDefaultOutputName() const; + std::string getDefaultInputName() const; - // TODO: need Spec validation, to make sure - // that default input/output are defined - // Currently this is checked in getDefault*, above + // TODO: need Spec validation, to make sure + // that default input/output are defined + // Currently this is checked in getDefault*, above #endif // NTA_INTERNAL - - }; +}; } // namespace nupic diff --git a/src/nupic/engine/TestFanIn2LinkPolicy.cpp b/src/nupic/engine/TestFanIn2LinkPolicy.cpp index 4bc5dd2cc8..c681c7dc6d 100644 --- a/src/nupic/engine/TestFanIn2LinkPolicy.cpp +++ b/src/nupic/engine/TestFanIn2LinkPolicy.cpp @@ -20,137 +20,120 @@ * --------------------------------------------------------------------- */ - #include -#include #include +#include #include #include -namespace nupic -{ - TestFanIn2LinkPolicy::TestFanIn2LinkPolicy(const std::string params, Link* link) : link_(link), initialized_(false) - { +namespace nupic { +TestFanIn2LinkPolicy::TestFanIn2LinkPolicy(const std::string params, Link *link) + : link_(link), initialized_(false) {} + +TestFanIn2LinkPolicy::~TestFanIn2LinkPolicy() { + // We don't own link_ -- it is a reference to our parent. 
+} + +void TestFanIn2LinkPolicy::setSrcDimensions(Dimensions &dims) { + + // This method should never be called if we've already been set + NTA_CHECK(srcDimensions_.isUnspecified()) + << "Internal error on link " << link_->toString(); + NTA_CHECK(destDimensions_.isUnspecified()) + << "Internal error on link " << link_->toString(); + + if (dims.isUnspecified()) + NTA_THROW << "Invalid unspecified source dimensions for link " + << link_->toString(); + + if (dims.isDontcare()) + NTA_THROW << "Invalid dontcare source dimensions for link " + << link_->toString(); + + // Induce destination dimensions from src dimensions based on a fan-in of 2 + Dimensions destDims; + for (size_t i = 0; i < dims.size(); i++) { + destDims.push_back(dims[i] / 2); + if (destDims[i] * 2 != dims[i]) + NTA_THROW << "Invalid source dimensions " << dims.toString() + << " for link " << link_->toString() + << ". Dimensions must be multiples of 2"; } - TestFanIn2LinkPolicy::~TestFanIn2LinkPolicy() - { - // We don't own link_ -- it is a reference to our parent. + srcDimensions_ = dims; + destDimensions_ = destDims; +} + +void TestFanIn2LinkPolicy::setDestDimensions(Dimensions &dims) { + // This method should never be called if we've already been set + NTA_CHECK(srcDimensions_.isUnspecified()) + << "Internal error on link " << link_->toString(); + NTA_CHECK(destDimensions_.isUnspecified()) + << "Internal error on link " << link_->toString(); + + if (dims.isUnspecified()) + NTA_THROW << "Invalid unspecified dest dimensions for link " + << link_->toString(); + + if (dims.isDontcare()) + NTA_THROW << "Invalid dontcare dest dimensions for link " + << link_->toString(); + + Dimensions srcDims; + for (size_t i = 0; i < dims.size(); i++) { + // Induce src dimensions from destination dimensions based on a fan-in of 2 + // from src to dest which looks like fan-out of 2 from dest to src + srcDims.push_back(dims[i] * 2); } - void TestFanIn2LinkPolicy::setSrcDimensions(Dimensions& dims) - { - - // This method should never be called if we've already been set - NTA_CHECK(srcDimensions_.isUnspecified()) << "Internal error on link " << link_->toString(); - NTA_CHECK(destDimensions_.isUnspecified()) << "Internal error on link " << link_->toString(); - - if (dims.isUnspecified()) - NTA_THROW << "Invalid unspecified source dimensions for link " << link_->toString(); - - if (dims.isDontcare()) - NTA_THROW << "Invalid dontcare source dimensions for link " << link_->toString(); - - // Induce destination dimensions from src dimensions based on a fan-in of 2 - Dimensions destDims; - for (size_t i = 0; i < dims.size(); i++) - { - destDims.push_back(dims[i]/2); - if (destDims[i] * 2 != dims[i]) - NTA_THROW << "Invalid source dimensions " << dims.toString() << " for link " - << link_->toString() << ". 
Dimensions must be multiples of 2"; - } - - srcDimensions_ = dims; - destDimensions_ = destDims; - } + srcDimensions_ = srcDims; + destDimensions_ = dims; +} - void TestFanIn2LinkPolicy::setDestDimensions(Dimensions& dims) - { - // This method should never be called if we've already been set - NTA_CHECK(srcDimensions_.isUnspecified()) << "Internal error on link " << link_->toString(); - NTA_CHECK(destDimensions_.isUnspecified()) << "Internal error on link " << link_->toString(); - - if (dims.isUnspecified()) - NTA_THROW << "Invalid unspecified dest dimensions for link " << link_->toString(); - - if (dims.isDontcare()) - NTA_THROW << "Invalid dontcare dest dimensions for link " << link_->toString(); - - Dimensions srcDims; - for (size_t i = 0; i < dims.size(); i++) - { - // Induce src dimensions from destination dimensions based on a fan-in of 2 - // from src to dest which looks like fan-out of 2 from dest to src - srcDims.push_back(dims[i]*2); - } - - srcDimensions_ = srcDims; - destDimensions_ = dims; - - } - - const Dimensions& TestFanIn2LinkPolicy::getSrcDimensions() const - { - return srcDimensions_; - } +const Dimensions &TestFanIn2LinkPolicy::getSrcDimensions() const { + return srcDimensions_; +} - const Dimensions& TestFanIn2LinkPolicy::getDestDimensions() const - { - return destDimensions_; - } +const Dimensions &TestFanIn2LinkPolicy::getDestDimensions() const { + return destDimensions_; +} - void TestFanIn2LinkPolicy::setNodeOutputElementCount(size_t elementCount) - { - elementCount_ = elementCount; - } +void TestFanIn2LinkPolicy::setNodeOutputElementCount(size_t elementCount) { + elementCount_ = elementCount; +} - void TestFanIn2LinkPolicy::buildProtoSplitterMap(Input::SplitterMap& splitter) const - { - - NTA_CHECK(isInitialized()); - // node [i, j] in the source region sends data to node [i/2, j/2] in the dest region. - // For N dimensions, this is naturally done as N nested loops. Do just for N=1,2 for now - if (srcDimensions_.size() == 1) - { - for (size_t i = 0; i < srcDimensions_[0]; i++) - { - splitter[i/2].push_back(i); - } - } else if (srcDimensions_.size() == 2) - { - for (size_t y = 0; y < srcDimensions_[1]; y++) - { - for (size_t x = 0; x < srcDimensions_[0]; x++) - { - size_t srcIndex = srcDimensions_.getIndex(Dimensions(x, y)); - size_t destIndex = destDimensions_.getIndex(Dimensions(x/2, y/2)); - - size_t baseOffset = srcIndex*elementCount_; - for (size_t element = 0; element < elementCount_; element++) - { - splitter[destIndex].push_back(baseOffset + element); - } +void TestFanIn2LinkPolicy::buildProtoSplitterMap( + Input::SplitterMap &splitter) const { + + NTA_CHECK(isInitialized()); + // node [i, j] in the source region sends data to node [i/2, j/2] in the dest + // region. For N dimensions, this is naturally done as N nested loops. Do just + // for N=1,2 for now + if (srcDimensions_.size() == 1) { + for (size_t i = 0; i < srcDimensions_[0]; i++) { + splitter[i / 2].push_back(i); + } + } else if (srcDimensions_.size() == 2) { + for (size_t y = 0; y < srcDimensions_[1]; y++) { + for (size_t x = 0; x < srcDimensions_[0]; x++) { + size_t srcIndex = srcDimensions_.getIndex(Dimensions(x, y)); + size_t destIndex = destDimensions_.getIndex(Dimensions(x / 2, y / 2)); + + size_t baseOffset = srcIndex * elementCount_; + for (size_t element = 0; element < elementCount_; element++) { + splitter[destIndex].push_back(baseOffset + element); } } - } else { - NTA_THROW << "TestFanIn2 link policy does not support " << srcDimensions_.size() - << "-dimensional topologies. 
FIXME!"; } + } else { + NTA_THROW << "TestFanIn2 link policy does not support " + << srcDimensions_.size() << "-dimensional topologies. FIXME!"; } +} - void TestFanIn2LinkPolicy::initialize() - { - initialized_ = true; - } +void TestFanIn2LinkPolicy::initialize() { initialized_ = true; } - bool TestFanIn2LinkPolicy::isInitialized() const - { - return initialized_; - } +bool TestFanIn2LinkPolicy::isInitialized() const { return initialized_; } } // namespace nupic - - - diff --git a/src/nupic/engine/TestFanIn2LinkPolicy.hpp b/src/nupic/engine/TestFanIn2LinkPolicy.hpp index fe1bd0f8d6..114e26210d 100644 --- a/src/nupic/engine/TestFanIn2LinkPolicy.hpp +++ b/src/nupic/engine/TestFanIn2LinkPolicy.hpp @@ -20,56 +20,51 @@ * --------------------------------------------------------------------- */ - #ifndef NTA_TESTFANIN2LINKPOLICY_HPP #define NTA_TESTFANIN2LINKPOLICY_HPP -#include #include #include +#include -namespace nupic -{ +namespace nupic { - class Link; +class Link; - class TestFanIn2LinkPolicy : public LinkPolicy - { - public: - TestFanIn2LinkPolicy(const std::string params, Link* link); +class TestFanIn2LinkPolicy : public LinkPolicy { +public: + TestFanIn2LinkPolicy(const std::string params, Link *link); - ~TestFanIn2LinkPolicy(); + ~TestFanIn2LinkPolicy(); - void setSrcDimensions(Dimensions& dims) override; + void setSrcDimensions(Dimensions &dims) override; - void setDestDimensions(Dimensions& dims) override; - - const Dimensions& getSrcDimensions() const override; + void setDestDimensions(Dimensions &dims) override; - const Dimensions& getDestDimensions() const override; + const Dimensions &getSrcDimensions() const override; - void buildProtoSplitterMap(Input::SplitterMap& splitter) const override; + const Dimensions &getDestDimensions() const override; - void setNodeOutputElementCount(size_t elementCount) override; + void buildProtoSplitterMap(Input::SplitterMap &splitter) const override; - void initialize() override; + void setNodeOutputElementCount(size_t elementCount) override; - bool isInitialized() const override; + void initialize() override; + + bool isInitialized() const override; private: - Link* link_; - - Dimensions srcDimensions_; - Dimensions destDimensions_; + Link *link_; - size_t elementCount_; + Dimensions srcDimensions_; + Dimensions destDimensions_; - bool initialized_; + size_t elementCount_; + bool initialized_; - }; // TestFanIn2 +}; // TestFanIn2 } // namespace nupic - #endif // NTA_TESTFANIN2LINKPOLICY_HPP diff --git a/src/nupic/engine/TestNode.cpp b/src/nupic/engine/TestNode.cpp index 142fc7ed60..93030a44fa 100644 --- a/src/nupic/engine/TestNode.cpp +++ b/src/nupic/engine/TestNode.cpp @@ -25,10 +25,10 @@ #else #include #endif +#include #include #include #include // std::accumulate -#include #include // Workaround windows.h collision: @@ -36,890 +36,723 @@ #undef VOID #include -#include -#include +#include +#include #include -#include -#include // IWrite/ReadBuffer +#include +#include #include -#include #include -#include -#include -#include +#include // IWrite/ReadBuffer +#include #include +#include using capnp::AnyPointer; -namespace nupic -{ - - TestNode::TestNode(const ValueMap& params, Region *region) : - RegionImpl(region), - computeCallback_(nullptr), - nodeCount_(1) - - { - // params for get/setParameter testing - int32Param_ = params.getScalarT("int32Param", 32); - uint32Param_ = params.getScalarT("uint32Param", 33); - int64Param_ = params.getScalarT("int64Param", 64); - uint64Param_ = params.getScalarT("uint64Param", 65); - real32Param_ = 
params.getScalarT("real32Param", 32.1); - real64Param_ = params.getScalarT("real64Param", 64.1); - boolParam_ = params.getScalarT("boolParam", false); - - shouldCloneParam_ = params.getScalarT("shouldCloneParam", 1) != 0; - - stringParam_ = *params.getString("stringParam"); - - real32ArrayParam_.resize(8); - for (size_t i = 0; i < 8; i++) - { - real32ArrayParam_[i] = float(i * 32); - } - - int64ArrayParam_.resize(4); - for (size_t i = 0; i < 4; i++) - { - int64ArrayParam_[i] = i * 64; - } - - boolArrayParam_.resize(4); - for (size_t i = 0; i < 4; i++) - { - boolArrayParam_[i] = (i % 2) == 1; - } - +namespace nupic { - unclonedParam_.resize(nodeCount_); - unclonedParam_[0] = params.getScalarT("unclonedParam", 0); +TestNode::TestNode(const ValueMap ¶ms, Region *region) + : RegionImpl(region), computeCallback_(nullptr), nodeCount_(1) - possiblyUnclonedParam_.resize(nodeCount_); - possiblyUnclonedParam_[0] = params.getScalarT("possiblyUnclonedParam", 0); - - unclonedInt64ArrayParam_.resize(nodeCount_); - std::vector v(4, 0); //length 4 vector, each element == 0 - unclonedInt64ArrayParam_[0] = v; - - // params used for computation - outputElementCount_ = 2; - delta_ = 1; - iter_ = 0; +{ + // params for get/setParameter testing + int32Param_ = params.getScalarT("int32Param", 32); + uint32Param_ = params.getScalarT("uint32Param", 33); + int64Param_ = params.getScalarT("int64Param", 64); + uint64Param_ = params.getScalarT("uint64Param", 65); + real32Param_ = params.getScalarT("real32Param", 32.1); + real64Param_ = params.getScalarT("real64Param", 64.1); + boolParam_ = params.getScalarT("boolParam", false); + + shouldCloneParam_ = params.getScalarT("shouldCloneParam", 1) != 0; + + stringParam_ = *params.getString("stringParam"); + + real32ArrayParam_.resize(8); + for (size_t i = 0; i < 8; i++) { + real32ArrayParam_[i] = float(i * 32); + } + int64ArrayParam_.resize(4); + for (size_t i = 0; i < 4; i++) { + int64ArrayParam_[i] = i * 64; } - TestNode::TestNode(BundleIO& bundle, Region* region) : - RegionImpl(region) - { - deserialize(bundle); + boolArrayParam_.resize(4); + for (size_t i = 0; i < 4; i++) { + boolArrayParam_[i] = (i % 2) == 1; } + unclonedParam_.resize(nodeCount_); + unclonedParam_[0] = params.getScalarT("unclonedParam", 0); - TestNode::TestNode(AnyPointer::Reader& proto, Region* region) : - RegionImpl(region), - computeCallback_(nullptr) + possiblyUnclonedParam_.resize(nodeCount_); + possiblyUnclonedParam_[0] = + params.getScalarT("possiblyUnclonedParam", 0); - { - read(proto); - } + unclonedInt64ArrayParam_.resize(nodeCount_); + std::vector v(4, 0); // length 4 vector, each element == 0 + unclonedInt64ArrayParam_[0] = v; + // params used for computation + outputElementCount_ = 2; + delta_ = 1; + iter_ = 0; +} - TestNode::~TestNode() - { - } +TestNode::TestNode(BundleIO &bundle, Region *region) : RegionImpl(region) { + deserialize(bundle); +} +TestNode::TestNode(AnyPointer::Reader &proto, Region *region) + : RegionImpl(region), computeCallback_(nullptr) +{ + read(proto); +} - void - TestNode::compute() - { - if (computeCallback_ != nullptr) - computeCallback_(getName()); - - const Array & outputArray = bottomUpOut_->getData(); - NTA_CHECK(outputArray.getCount() == nodeCount_ * outputElementCount_); - NTA_CHECK(outputArray.getType() == NTA_BasicType_Real64); - Real64 *baseOutputBuffer = (Real64*) outputArray.getBuffer(); - - // See TestNode.hpp for description of the computation - std::vector nodeInput; - Real64* nodeOutputBuffer; - for (UInt32 node = 0; node < nodeCount_; node++) - { - 
nodeOutputBuffer = baseOutputBuffer + node * outputElementCount_; - bottomUpIn_->getInputForNode(node, nodeInput); - - // output[0] = number of inputs to this baby node + current iteration number - nodeOutputBuffer[0] = nupic::Real64(nodeInput.size() + iter_); - - // output[n] = node + sum(inputs) + (n-1) * delta - Real64 sum = std::accumulate(nodeInput.begin(), nodeInput.end(), 0.0); - for (size_t i = 1; i < outputElementCount_; i++) - nodeOutputBuffer[i] = node + sum + (i-1)*delta_; - } +TestNode::~TestNode() {} + +void TestNode::compute() { + if (computeCallback_ != nullptr) + computeCallback_(getName()); + + const Array &outputArray = bottomUpOut_->getData(); + NTA_CHECK(outputArray.getCount() == nodeCount_ * outputElementCount_); + NTA_CHECK(outputArray.getType() == NTA_BasicType_Real64); + Real64 *baseOutputBuffer = (Real64 *)outputArray.getBuffer(); - iter_++; + // See TestNode.hpp for description of the computation + std::vector nodeInput; + Real64 *nodeOutputBuffer; + for (UInt32 node = 0; node < nodeCount_; node++) { + nodeOutputBuffer = baseOutputBuffer + node * outputElementCount_; + bottomUpIn_->getInputForNode(node, nodeInput); + // output[0] = number of inputs to this baby node + current iteration number + nodeOutputBuffer[0] = nupic::Real64(nodeInput.size() + iter_); + // output[n] = node + sum(inputs) + (n-1) * delta + Real64 sum = std::accumulate(nodeInput.begin(), nodeInput.end(), 0.0); + for (size_t i = 1; i < outputElementCount_; i++) + nodeOutputBuffer[i] = node + sum + (i - 1) * delta_; } - Spec* - TestNode::createSpec() - { - auto ns = new Spec; - - /* ---- parameters ------ */ - - ns->parameters.add( - "int32Param", - ParameterSpec( - "Int32 scalar parameter", // description - NTA_BasicType_Int32, - 1, // elementCount - "", // constraints - "32", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "uint32Param", - ParameterSpec( - "UInt32 scalar parameter", // description - NTA_BasicType_UInt32, - 1, // elementCount - "", // constraints - "33", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "int64Param", - ParameterSpec( - "Int64 scalar parameter", // description - NTA_BasicType_Int64, - 1, // elementCount - "", // constraints - "64", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "uint64Param", - ParameterSpec( - "UInt64 scalar parameter", // description - NTA_BasicType_UInt64, - 1, // elementCount - "", // constraints - "65", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "real32Param", - ParameterSpec( - "Real32 scalar parameter", // description - NTA_BasicType_Real32, - 1, // elementCount - "", // constraints - "32.1", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "real64Param", - ParameterSpec( - "Real64 scalar parameter", // description - NTA_BasicType_Real64, - 1, // elementCount - "", // constraints - "64.1", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "boolParam", - ParameterSpec( - "bool scalar parameter", // description - NTA_BasicType_Bool, - 1, // elementCount - "", // constraints - "false", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "real32ArrayParam", - ParameterSpec( - "int32 array parameter", - NTA_BasicType_Real32, - 0, // array - "", - "", - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "int64ArrayParam", - ParameterSpec( - "int64 array parameter", - NTA_BasicType_Int64, - 0, // array - "", - "", - 
ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "boolArrayParam", - ParameterSpec( - "bool array parameter", - NTA_BasicType_Bool, - 0, // array - "", - "", - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( + iter_++; +} + +Spec *TestNode::createSpec() { + auto ns = new Spec; + + /* ---- parameters ------ */ + + ns->parameters.add("int32Param", + ParameterSpec("Int32 scalar parameter", // description + NTA_BasicType_Int32, + 1, // elementCount + "", // constraints + "32", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("uint32Param", + ParameterSpec("UInt32 scalar parameter", // description + NTA_BasicType_UInt32, + 1, // elementCount + "", // constraints + "33", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("int64Param", + ParameterSpec("Int64 scalar parameter", // description + NTA_BasicType_Int64, + 1, // elementCount + "", // constraints + "64", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("uint64Param", + ParameterSpec("UInt64 scalar parameter", // description + NTA_BasicType_UInt64, + 1, // elementCount + "", // constraints + "65", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("real32Param", + ParameterSpec("Real32 scalar parameter", // description + NTA_BasicType_Real32, + 1, // elementCount + "", // constraints + "32.1", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("real64Param", + ParameterSpec("Real64 scalar parameter", // description + NTA_BasicType_Real64, + 1, // elementCount + "", // constraints + "64.1", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("boolParam", + ParameterSpec("bool scalar parameter", // description + NTA_BasicType_Bool, + 1, // elementCount + "", // constraints + "false", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("real32ArrayParam", + ParameterSpec("int32 array parameter", + NTA_BasicType_Real32, + 0, // array + "", "", ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("int64ArrayParam", + ParameterSpec("int64 array parameter", NTA_BasicType_Int64, + 0, // array + "", "", ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("boolArrayParam", + ParameterSpec("bool array parameter", NTA_BasicType_Bool, + 0, // array + "", "", ParameterSpec::ReadWriteAccess)); + + ns->parameters.add( "computeCallback", - ParameterSpec( - "address of a function that is called at every compute()", - NTA_BasicType_Handle, - 1, - "", - "", // handles must not have a default value - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( - "stringParam", - ParameterSpec( - "string parameter", - NTA_BasicType_Byte, - 0, // length=0 required for strings - "", - "nodespec value", - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( + ParameterSpec("address of a function that is called at every compute()", + NTA_BasicType_Handle, 1, "", + "", // handles must not have a default value + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("stringParam", + ParameterSpec("string parameter", NTA_BasicType_Byte, + 0, // length=0 required for strings + "", "nodespec value", + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add( "unclonedParam", - ParameterSpec( - "has a separate value for each node", //description - NTA_BasicType_UInt32, - 1, //elementCount - "", //constraints - "", //defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( + ParameterSpec("has a separate value for each node", // description + 
NTA_BasicType_UInt32, + 1, // elementCount + "", // constraints + "", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add( "shouldCloneParam", - ParameterSpec( - "whether possiblyUnclonedParam should clone", //description - NTA_BasicType_UInt32, - 1, //elementCount - "enum: 0, 1", //constraints - "1", //defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( + ParameterSpec("whether possiblyUnclonedParam should clone", // description + NTA_BasicType_UInt32, + 1, // elementCount + "enum: 0, 1", // constraints + "1", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add( "possiblyUnclonedParam", - ParameterSpec( - "cloned if shouldCloneParam is true", //description - NTA_BasicType_UInt32, - 1, //elementCount - "", //constraints - "", //defaultValue - ParameterSpec::ReadWriteAccess)); - - ns->parameters.add( + ParameterSpec("cloned if shouldCloneParam is true", // description + NTA_BasicType_UInt32, + 1, // elementCount + "", // constraints + "", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add( "unclonedInt64ArrayParam", - ParameterSpec( - "has a separate array for each node", //description - NTA_BasicType_Int64, - 0, //array //elementCount - "", //constraints - "", //defaultValue - ParameterSpec::ReadWriteAccess)); - - - /* ----- inputs ------- */ - ns->inputs.add( - "bottomUpIn", - InputSpec( - "Primary input for the node", - NTA_BasicType_Real64, - 0, // count. omit? - true, // required? - false, // isRegionLevel, - true // isDefaultInput - )); - - /* ----- outputs ------ */ - ns->outputs.add( - "bottomUpOut", - OutputSpec( - "Primary output for the node", - NTA_BasicType_Real64, - 0, // count is dynamic - false, // isRegionLevel - true // isDefaultOutput - )); - - /* ----- commands ------ */ - // commands TBD - - return ns; - } - + ParameterSpec("has a separate array for each node", // description + NTA_BasicType_Int64, + 0, // array //elementCount + "", // constraints + "", // defaultValue + ParameterSpec::ReadWriteAccess)); + + /* ----- inputs ------- */ + ns->inputs.add("bottomUpIn", + InputSpec("Primary input for the node", NTA_BasicType_Real64, + 0, // count. omit? + true, // required? 
+ false, // isRegionLevel, + true // isDefaultInput + )); + + /* ----- outputs ------ */ + ns->outputs.add("bottomUpOut", OutputSpec("Primary output for the node", + NTA_BasicType_Real64, + 0, // count is dynamic + false, // isRegionLevel + true // isDefaultOutput + )); + + /* ----- commands ------ */ + // commands TBD + + return ns; +} - Real64 TestNode::getParameterReal64(const std::string& name, Int64 index) - { - if (name == "real64Param") - { - return real64Param_; - } - else - { - NTA_THROW << "TestNode::getParameter -- unknown parameter " << name; - } +Real64 TestNode::getParameterReal64(const std::string &name, Int64 index) { + if (name == "real64Param") { + return real64Param_; + } else { + NTA_THROW << "TestNode::getParameter -- unknown parameter " << name; } +} - - void TestNode::setParameterReal64(const std::string& name, Int64 index, Real64 value) - { - if (name == "real64Param") - { - real64Param_ = value; - } - else - { - NTA_THROW << "TestNode::setParameter -- unknown parameter " << name; - } +void TestNode::setParameterReal64(const std::string &name, Int64 index, + Real64 value) { + if (name == "real64Param") { + real64Param_ = value; + } else { + NTA_THROW << "TestNode::setParameter -- unknown parameter " << name; } +} - - void TestNode::getParameterFromBuffer(const std::string& name, - Int64 index, - IWriteBuffer& value) - { - if (name == "int32Param") { - value.write(int32Param_); - } else if (name == "uint32Param") { - value.write(uint32Param_); - } else if (name == "int64Param") { - value.write(int64Param_); - } else if (name == "uint64Param") { - value.write(uint64Param_); - } else if (name == "real32Param") { - value.write(real32Param_); - } else if (name == "real64Param") { - value.write(real64Param_); - } else if (name == "boolParam") { - value.write(boolParam_); - } else if (name == "stringParam") { - value.write(stringParam_.c_str(), stringParam_.size()); - } else if (name == "int64ArrayParam") { - for (auto & elem : int64ArrayParam_) - { - value.write(elem); - } - } else if (name == "real32ArrayParam") { - for (auto & elem : real32ArrayParam_) - { - value.write(elem); - } - } else if (name == "unclonedParam") { - if (index < 0) - { - NTA_THROW << "uncloned parameters cannot be accessed at region level"; - } - value.write(unclonedParam_[(UInt)index]); - } else if (name == "shouldCloneParam") { - value.write((UInt32)(shouldCloneParam_ ? 
1 : 0)); - } else if (name == "possiblyUnclonedParam") { - if (shouldCloneParam_) - { - value.write(possiblyUnclonedParam_[0]); - } - else - { - if (index < 0) - { - NTA_THROW << "uncloned parameters cannot be accessed at region level"; - } - value.write(possiblyUnclonedParam_[(UInt)index]); - } - } else if (name == "unclonedInt64ArrayParam") { - if (index < 0) - { - NTA_THROW << "uncloned parameters cannot be accessed at region level"; - } - UInt nodeIndex = (UInt)index; - for (auto & elem : unclonedInt64ArrayParam_[nodeIndex]) - { - value.write(elem); - } +void TestNode::getParameterFromBuffer(const std::string &name, Int64 index, + IWriteBuffer &value) { + if (name == "int32Param") { + value.write(int32Param_); + } else if (name == "uint32Param") { + value.write(uint32Param_); + } else if (name == "int64Param") { + value.write(int64Param_); + } else if (name == "uint64Param") { + value.write(uint64Param_); + } else if (name == "real32Param") { + value.write(real32Param_); + } else if (name == "real64Param") { + value.write(real64Param_); + } else if (name == "boolParam") { + value.write(boolParam_); + } else if (name == "stringParam") { + value.write(stringParam_.c_str(), stringParam_.size()); + } else if (name == "int64ArrayParam") { + for (auto &elem : int64ArrayParam_) { + value.write(elem); + } + } else if (name == "real32ArrayParam") { + for (auto &elem : real32ArrayParam_) { + value.write(elem); + } + } else if (name == "unclonedParam") { + if (index < 0) { + NTA_THROW << "uncloned parameters cannot be accessed at region level"; + } + value.write(unclonedParam_[(UInt)index]); + } else if (name == "shouldCloneParam") { + value.write((UInt32)(shouldCloneParam_ ? 1 : 0)); + } else if (name == "possiblyUnclonedParam") { + if (shouldCloneParam_) { + value.write(possiblyUnclonedParam_[0]); } else { - NTA_THROW << "TestNode::getParameter -- Unknown parameter " << name; - } - } - - void TestNode::setParameterFromBuffer(const std::string& name, - Int64 index, - IReadBuffer& value) - { - if (name == "int32Param") { - value.read(int32Param_); - } else if (name == "uint32Param") { - value.read(uint32Param_); - } else if (name == "int64Param") { - value.read(int64Param_); - } else if (name == "uint64Param") { - value.read(uint64Param_); - } else if (name == "real32Param") { - value.read(real32Param_); - } else if (name == "real64Param") { - value.read(real64Param_); - } else if (name == "boolParam") { - value.read(boolParam_); - } else if (name == "stringParam") { - stringParam_ = std::string(value.getData(), value.getSize()); - } else if (name == "int64ArrayParam") { - for (auto & elem : int64ArrayParam_) - { - value.read(elem); - } - } else if (name == "real32ArrayParam") { - for (auto & elem : real32ArrayParam_) - { - value.read(elem); - } - } else if (name == "unclonedParam") { - if (index < 0) - { - NTA_THROW << "uncloned parameters cannot be accessed at region level"; - } - value.read(unclonedParam_[(UInt)index]); - } else if (name == "shouldCloneParam") { - UInt64 ival; - value.read(ival); - shouldCloneParam_ = (ival ? 
1 : 0); - } else if (name == "possiblyUnclonedParam") { - if (shouldCloneParam_) - { - value.read(possiblyUnclonedParam_[0]); - } - else - { - if (index < 0) - { - NTA_THROW << "uncloned parameters cannot be accessed at region level"; - } - value.read(possiblyUnclonedParam_[(UInt)index]); - } - } else if (name == "unclonedInt64ArrayParam") { - if (index < 0) - { + if (index < 0) { NTA_THROW << "uncloned parameters cannot be accessed at region level"; } - UInt nodeIndex = (UInt)index; - for (auto & elem : unclonedInt64ArrayParam_[nodeIndex]) - { - value.read(elem); - } - } else if (name == "computeCallback") { - UInt64 ival; - value.read(ival); - computeCallback_ = (computeCallbackFunc)ival; - } else { - NTA_THROW << "TestNode::setParameter -- Unknown parameter " << name; + value.write(possiblyUnclonedParam_[(UInt)index]); } - - } - - size_t TestNode::getParameterArrayCount(const std::string& name, Int64 index) - { - if (name == "int64ArrayParam") - { - return int64ArrayParam_.size(); - } - else if (name == "real32ArrayParam") - { - return real32ArrayParam_.size(); + } else if (name == "unclonedInt64ArrayParam") { + if (index < 0) { + NTA_THROW << "uncloned parameters cannot be accessed at region level"; } - else if (name == "boolArrayParam") - { - return boolArrayParam_.size(); + UInt nodeIndex = (UInt)index; + for (auto &elem : unclonedInt64ArrayParam_[nodeIndex]) { + value.write(elem); } - else if (name == "unclonedInt64ArrayParam") - { - if (index < 0) - { + } else { + NTA_THROW << "TestNode::getParameter -- Unknown parameter " << name; + } +} + +void TestNode::setParameterFromBuffer(const std::string &name, Int64 index, + IReadBuffer &value) { + if (name == "int32Param") { + value.read(int32Param_); + } else if (name == "uint32Param") { + value.read(uint32Param_); + } else if (name == "int64Param") { + value.read(int64Param_); + } else if (name == "uint64Param") { + value.read(uint64Param_); + } else if (name == "real32Param") { + value.read(real32Param_); + } else if (name == "real64Param") { + value.read(real64Param_); + } else if (name == "boolParam") { + value.read(boolParam_); + } else if (name == "stringParam") { + stringParam_ = std::string(value.getData(), value.getSize()); + } else if (name == "int64ArrayParam") { + for (auto &elem : int64ArrayParam_) { + value.read(elem); + } + } else if (name == "real32ArrayParam") { + for (auto &elem : real32ArrayParam_) { + value.read(elem); + } + } else if (name == "unclonedParam") { + if (index < 0) { + NTA_THROW << "uncloned parameters cannot be accessed at region level"; + } + value.read(unclonedParam_[(UInt)index]); + } else if (name == "shouldCloneParam") { + UInt64 ival; + value.read(ival); + shouldCloneParam_ = (ival ? 
1 : 0); + } else if (name == "possiblyUnclonedParam") { + if (shouldCloneParam_) { + value.read(possiblyUnclonedParam_[0]); + } else { + if (index < 0) { NTA_THROW << "uncloned parameters cannot be accessed at region level"; } - return unclonedInt64ArrayParam_[(UInt)index].size(); - } - else - { - NTA_THROW << "TestNode::getParameterArrayCount -- unknown parameter " << name; - } + value.read(possiblyUnclonedParam_[(UInt)index]); + } + } else if (name == "unclonedInt64ArrayParam") { + if (index < 0) { + NTA_THROW << "uncloned parameters cannot be accessed at region level"; + } + UInt nodeIndex = (UInt)index; + for (auto &elem : unclonedInt64ArrayParam_[nodeIndex]) { + value.read(elem); + } + } else if (name == "computeCallback") { + UInt64 ival; + value.read(ival); + computeCallback_ = (computeCallbackFunc)ival; + } else { + NTA_THROW << "TestNode::setParameter -- Unknown parameter " << name; } +} +size_t TestNode::getParameterArrayCount(const std::string &name, Int64 index) { + if (name == "int64ArrayParam") { + return int64ArrayParam_.size(); + } else if (name == "real32ArrayParam") { + return real32ArrayParam_.size(); + } else if (name == "boolArrayParam") { + return boolArrayParam_.size(); + } else if (name == "unclonedInt64ArrayParam") { + if (index < 0) { + NTA_THROW << "uncloned parameters cannot be accessed at region level"; + } + return unclonedInt64ArrayParam_[(UInt)index].size(); + } else { + NTA_THROW << "TestNode::getParameterArrayCount -- unknown parameter " + << name; + } +} - void TestNode::initialize() - { - nodeCount_ = getDimensions().getCount(); - bottomUpOut_ = getOutput("bottomUpOut"); - bottomUpIn_ = getInput("bottomUpIn"); - - unclonedParam_.resize(nodeCount_); - for (unsigned int i = 1; i < nodeCount_; i++) - { - unclonedParam_[i] = unclonedParam_[0]; - } +void TestNode::initialize() { + nodeCount_ = getDimensions().getCount(); + bottomUpOut_ = getOutput("bottomUpOut"); + bottomUpIn_ = getInput("bottomUpIn"); - if (! 
shouldCloneParam_) - { - possiblyUnclonedParam_.resize(nodeCount_); - for (unsigned int i = 1; i < nodeCount_; i++) - { - possiblyUnclonedParam_[i] = possiblyUnclonedParam_[0]; - } - } + unclonedParam_.resize(nodeCount_); + for (unsigned int i = 1; i < nodeCount_; i++) { + unclonedParam_[i] = unclonedParam_[0]; + } - unclonedInt64ArrayParam_.resize(nodeCount_); - std::vector v(4, 0); //length 4 vector, each element == 0 - for (unsigned int i = 1; i < nodeCount_; i++) - { - unclonedInt64ArrayParam_[i] = v; + if (!shouldCloneParam_) { + possiblyUnclonedParam_.resize(nodeCount_); + for (unsigned int i = 1; i < nodeCount_; i++) { + possiblyUnclonedParam_[i] = possiblyUnclonedParam_[0]; } } + unclonedInt64ArrayParam_.resize(nodeCount_); + std::vector v(4, 0); // length 4 vector, each element == 0 + for (unsigned int i = 1; i < nodeCount_; i++) { + unclonedInt64ArrayParam_[i] = v; + } +} // This is the per-node output size - size_t TestNode::getNodeOutputElementCount(const std::string& outputName) - { - if (outputName == "bottomUpOut") - { - return outputElementCount_; - } - NTA_THROW << "TestNode::getOutputSize -- unknown output " << outputName; +size_t TestNode::getNodeOutputElementCount(const std::string &outputName) { + if (outputName == "bottomUpOut") { + return outputElementCount_; } + NTA_THROW << "TestNode::getOutputSize -- unknown output " << outputName; +} - std::string TestNode::executeCommand(const std::vector& args, Int64 index) - { - return ""; +std::string TestNode::executeCommand(const std::vector &args, + Int64 index) { + return ""; +} + +bool TestNode::isParameterShared(const std::string &name) { + if ((name == "int32Param") || (name == "uint32Param") || + (name == "int64Param") || (name == "uint64Param") || + (name == "real32Param") || (name == "real64Param") || + (name == "boolParam") || (name == "stringParam") || + (name == "int64ArrayParam") || (name == "real32ArrayParam") || + (name == "boolArrayParam") || (name == "shouldCloneParam")) { + return true; + } else if ((name == "unclonedParam") || (name == "unclonedInt64ArrayParam")) { + return false; + } else if (name == "possiblyUnclonedParam") { + return shouldCloneParam_; + } else { + NTA_THROW << "TestNode::isParameterShared -- Unknown parameter " << name; } +} - bool TestNode::isParameterShared(const std::string& name) - { - if ((name == "int32Param") || - (name == "uint32Param") || - (name == "int64Param") || - (name == "uint64Param") || - (name == "real32Param") || - (name == "real64Param") || - (name == "boolParam") || - (name == "stringParam") || - (name == "int64ArrayParam") || - (name == "real32ArrayParam") || - (name == "boolArrayParam") || - (name == "shouldCloneParam")) { - return true; - } else if ((name == "unclonedParam") || - (name == "unclonedInt64ArrayParam")) { - return false; - } else if (name == "possiblyUnclonedParam") { - return shouldCloneParam_; - } else { - NTA_THROW << "TestNode::isParameterShared -- Unknown parameter " << name; - } +template +static void arrayOut(std::ostream &s, const std::vector &array, + const std::string &name) { + s << "ARRAY_" << name << " "; + s << array.size() << " "; + for (auto elem : array) { + s << elem << " "; } +} - template static void - arrayOut(std::ostream& s, const std::vector& array, const std::string& name) - { - s << "ARRAY_" << name << " "; - s << array.size() << " "; - for (auto elem : array) - { - s << elem << " "; - } +template +static void arrayIn(std::istream &s, std::vector &array, + const std::string &name) { + std::string expectedCookie = 
std::string("ARRAY_") + name; + std::string cookie; + s >> cookie; + if (cookie != expectedCookie) + NTA_THROW << "Bad cookie '" << cookie + << "' for serialized array. Expected '" << expectedCookie << "'"; + size_t sz; + s >> sz; + array.resize(sz); + for (size_t ix = 0; ix < sz; ix++) { + s >> array[ix]; } +} - template static void - arrayIn(std::istream& s, std::vector& array, const std::string& name) +void TestNode::serialize(BundleIO &bundle) { { - std::string expectedCookie = std::string("ARRAY_") + name; - std::string cookie; - s >> cookie; - if (cookie != expectedCookie) - NTA_THROW << "Bad cookie '" << cookie << "' for serialized array. Expected '" << expectedCookie << "'"; - size_t sz; - s >> sz; - array.resize(sz); - for (size_t ix = 0; ix < sz; ix++) - { - s >> array[ix]; - } + std::ofstream &f = bundle.getOutputStream("main"); + // There is more than one way to do this. We could serialize to YAML, which + // would make a readable format, or we could serialize directly to the + // stream Choose the easier one. + f << "TestNode-v2" + << " " << nodeCount_ << " " << int32Param_ << " " << uint32Param_ << " " + << int64Param_ << " " << uint64Param_ << " " << real32Param_ << " " + << real64Param_ << " " << boolParam_ << " " << outputElementCount_ << " " + << delta_ << " " << iter_ << " "; + + arrayOut(f, real32ArrayParam_, "real32ArrayParam_"); + arrayOut(f, int64ArrayParam_, "int64ArrayParam_"); + arrayOut(f, boolArrayParam_, "boolArrayParam_"); + arrayOut(f, unclonedParam_, "unclonedParam_"); + f << shouldCloneParam_ << " "; + + // outer vector needs to be done by hand. + f << "unclonedArray "; + f << unclonedInt64ArrayParam_.size() << " "; // number of nodes + for (size_t i = 0; i < unclonedInt64ArrayParam_.size(); i++) { + std::stringstream name; + name << "unclonedInt64ArrayParam[" << i << "]"; + arrayOut(f, unclonedInt64ArrayParam_[i], name.str()); + } + f.close(); + } // main file + + // auxilliary file using stream + { + std::ofstream &f = bundle.getOutputStream("aux"); + f << "This is an auxilliary file!\n"; + f.close(); } - - void TestNode::serialize(BundleIO& bundle) + // auxilliary file using path { - { - std::ofstream& f = bundle.getOutputStream("main"); - // There is more than one way to do this. We could serialize to YAML, which - // would make a readable format, or we could serialize directly to the stream - // Choose the easier one. - f << "TestNode-v2" << " " - << nodeCount_ << " " - << int32Param_ << " " - << uint32Param_ << " " - << int64Param_ << " " - << uint64Param_ << " " - << real32Param_ << " " - << real64Param_ << " " - << boolParam_ << " " - << outputElementCount_ << " " - << delta_ << " " - << iter_ << " "; - - arrayOut(f, real32ArrayParam_, "real32ArrayParam_"); - arrayOut(f, int64ArrayParam_, "int64ArrayParam_"); - arrayOut(f, boolArrayParam_, "boolArrayParam_"); - arrayOut(f, unclonedParam_, "unclonedParam_"); - f << shouldCloneParam_ << " "; - - // outer vector needs to be done by hand. 
- f << "unclonedArray "; - f << unclonedInt64ArrayParam_.size() << " "; // number of nodes - for (size_t i = 0; i < unclonedInt64ArrayParam_.size(); i++) - { - std::stringstream name; - name << "unclonedInt64ArrayParam[" << i << "]"; - arrayOut(f, unclonedInt64ArrayParam_[i], name.str()); - } - f.close(); - } // main file - + std::string path = bundle.getPath("aux2"); + std::ofstream f(path.c_str()); + f << "This is another auxilliary file!\n"; + f.close(); + } +} - // auxilliary file using stream - { - std::ofstream& f = bundle.getOutputStream("aux"); - f << "This is an auxilliary file!\n"; - f.close(); +void TestNode::deserialize(BundleIO &bundle) { + { + std::ifstream &f = bundle.getInputStream("main"); + // There is more than one way to do this. We could serialize to YAML, which + // would make a readable format, or we could serialize directly to the + // stream Choose the easier one. + std::string versionString; + f >> versionString; + if (versionString != "TestNode-v2") { + NTA_THROW << "Bad serialization for region '" << region_->getName() + << "' of type TestNode. Main serialization file must start " + << "with \"TestNode-v2\" but instead it starts with '" + << versionString << "'"; + } + f >> nodeCount_; + f >> int32Param_; + f >> uint32Param_; + f >> int64Param_; + f >> uint64Param_; + f >> real32Param_; + f >> real64Param_; + f >> boolParam_; + f >> outputElementCount_; + f >> delta_; + f >> iter_; + + arrayIn(f, real32ArrayParam_, "real32ArrayParam_"); + arrayIn(f, int64ArrayParam_, "int64ArrayParam_"); + arrayIn(f, int64ArrayParam_, "boolArrayParam_"); + arrayIn(f, unclonedParam_, "unclonedParam_"); + + f >> shouldCloneParam_; + + std::string label; + f >> label; + if (label != "unclonedArray") + NTA_THROW << "Missing label for uncloned array. Got '" << label << "'"; + size_t vecsize; + f >> vecsize; + unclonedInt64ArrayParam_.clear(); + unclonedInt64ArrayParam_.resize(vecsize); + for (size_t i = 0; i < vecsize; i++) { + std::stringstream name; + name << "unclonedInt64ArrayParam[" << i << "]"; + arrayIn(f, unclonedInt64ArrayParam_[i], name.str()); } + f.close(); + } // main file - // auxilliary file using path - { - std::string path = bundle.getPath("aux2"); - std::ofstream f(path.c_str()); - f << "This is another auxilliary file!\n"; - f.close(); - } + // auxilliary file using stream + { + std::ifstream &f = bundle.getInputStream("aux"); + char line1[100]; + f.read(line1, 100); + line1[f.gcount()] = '\0'; + if (std::string(line1) != "This is an auxilliary file!\n") { + NTA_THROW << "Invalid auxilliary serialization file for TestNode"; + } + f.close(); } - - void TestNode::deserialize(BundleIO& bundle) + // auxilliary file using path { - { - std::ifstream& f = bundle.getInputStream("main"); - // There is more than one way to do this. We could serialize to YAML, which - // would make a readable format, or we could serialize directly to the stream - // Choose the easier one. - std::string versionString; - f >> versionString; - if (versionString != "TestNode-v2") - { - NTA_THROW << "Bad serialization for region '" << region_->getName() - << "' of type TestNode. 
Main serialization file must start " - << "with \"TestNode-v2\" but instead it starts with '" - << versionString << "'"; - - } - f >> nodeCount_; - f >> int32Param_; - f >> uint32Param_; - f >> int64Param_; - f >> uint64Param_; - f >> real32Param_; - f >> real64Param_; - f >> boolParam_; - f >> outputElementCount_; - f >> delta_; - f >> iter_; - - arrayIn(f, real32ArrayParam_, "real32ArrayParam_"); - arrayIn(f, int64ArrayParam_, "int64ArrayParam_"); - arrayIn(f, int64ArrayParam_, "boolArrayParam_"); - arrayIn(f, unclonedParam_, "unclonedParam_"); - - f >> shouldCloneParam_; - - std::string label; - f >> label; - if (label != "unclonedArray") - NTA_THROW << "Missing label for uncloned array. Got '" << label << "'"; - size_t vecsize; - f >> vecsize; - unclonedInt64ArrayParam_.clear(); - unclonedInt64ArrayParam_.resize(vecsize); - for (size_t i = 0; i < vecsize; i++) - { - std::stringstream name; - name << "unclonedInt64ArrayParam[" << i << "]"; - arrayIn(f, unclonedInt64ArrayParam_[i], name.str()); - } - f.close(); - } // main file - - // auxilliary file using stream - { - std::ifstream& f = bundle.getInputStream("aux"); - char line1[100]; - f.read(line1, 100); - line1[f.gcount()] = '\0'; - if (std::string(line1) != "This is an auxilliary file!\n") - { - NTA_THROW << "Invalid auxilliary serialization file for TestNode"; - } - f.close(); + std::string path = bundle.getPath("aux2"); + std::ifstream f(path.c_str()); + char line1[100]; + f.read(line1, 100); + line1[f.gcount()] = '\0'; + if (std::string(line1) != "This is another auxilliary file!\n") { + NTA_THROW << "Invalid auxilliary2 serialization file for TestNode"; } - // auxilliary file using path - { - std::string path = bundle.getPath("aux2"); - std::ifstream f(path.c_str()); - char line1[100]; - f.read(line1, 100); - line1[f.gcount()] = '\0'; - if (std::string(line1) != "This is another auxilliary file!\n") - { - NTA_THROW << "Invalid auxilliary2 serialization file for TestNode"; - } - - f.close(); - } + f.close(); } +} +void TestNode::write(AnyPointer::Builder &anyProto) const { + TestNodeProto::Builder proto = anyProto.getAs(); + + proto.setInt32Param(int32Param_); + proto.setUint32Param(uint32Param_); + proto.setInt64Param(int64Param_); + proto.setUint64Param(uint64Param_); + proto.setReal32Param(real32Param_); + proto.setReal64Param(real64Param_); + proto.setBoolParam(boolParam_); + proto.setStringParam(stringParam_.c_str()); + + auto real32ArrayProto = proto.initReal32ArrayParam(real32ArrayParam_.size()); + for (UInt i = 0; i < real32ArrayParam_.size(); i++) { + real32ArrayProto.set(i, real32ArrayParam_[i]); + } - void TestNode::write(AnyPointer::Builder& anyProto) const - { - TestNodeProto::Builder proto = anyProto.getAs(); - - proto.setInt32Param(int32Param_); - proto.setUint32Param(uint32Param_); - proto.setInt64Param(int64Param_); - proto.setUint64Param(uint64Param_); - proto.setReal32Param(real32Param_); - proto.setReal64Param(real64Param_); - proto.setBoolParam(boolParam_); - proto.setStringParam(stringParam_.c_str()); - - auto real32ArrayProto = - proto.initReal32ArrayParam(real32ArrayParam_.size()); - for (UInt i = 0; i < real32ArrayParam_.size(); i++) - { - real32ArrayProto.set(i, real32ArrayParam_[i]); - } - - auto int64ArrayProto = proto.initInt64ArrayParam(int64ArrayParam_.size()); - for (UInt i = 0; i < int64ArrayParam_.size(); i++) - { - int64ArrayProto.set(i, int64ArrayParam_[i]); - } + auto int64ArrayProto = proto.initInt64ArrayParam(int64ArrayParam_.size()); + for (UInt i = 0; i < int64ArrayParam_.size(); i++) { + 
int64ArrayProto.set(i, int64ArrayParam_[i]); + } - auto boolArrayProto = proto.initBoolArrayParam(boolArrayParam_.size()); - for (UInt i = 0; i < boolArrayParam_.size(); i++) - { - boolArrayProto.set(i, boolArrayParam_[i]); - } + auto boolArrayProto = proto.initBoolArrayParam(boolArrayParam_.size()); + for (UInt i = 0; i < boolArrayParam_.size(); i++) { + boolArrayProto.set(i, boolArrayParam_[i]); + } - proto.setIterations(iter_); - proto.setOutputElementCount(outputElementCount_); - proto.setDelta(delta_); + proto.setIterations(iter_); + proto.setOutputElementCount(outputElementCount_); + proto.setDelta(delta_); - proto.setShouldCloneParam(shouldCloneParam_); + proto.setShouldCloneParam(shouldCloneParam_); - auto unclonedParamProto = proto.initUnclonedParam(unclonedParam_.size()); - for (UInt i = 0; i < unclonedParam_.size(); i++) - { - unclonedParamProto.set(i, unclonedParam_[i]); - } + auto unclonedParamProto = proto.initUnclonedParam(unclonedParam_.size()); + for (UInt i = 0; i < unclonedParam_.size(); i++) { + unclonedParamProto.set(i, unclonedParam_[i]); + } - auto unclonedInt64ArrayParamProto = - proto.initUnclonedInt64ArrayParam(unclonedInt64ArrayParam_.size()); - for (UInt i = 0; i < unclonedInt64ArrayParam_.size(); i++) - { - auto innerUnclonedParamProto = - unclonedInt64ArrayParamProto.init( - i, unclonedInt64ArrayParam_[i].size()); - for (UInt j = 0; j < unclonedInt64ArrayParam_[i].size(); j++) - { - innerUnclonedParamProto.set(j, unclonedInt64ArrayParam_[i][j]); - } + auto unclonedInt64ArrayParamProto = + proto.initUnclonedInt64ArrayParam(unclonedInt64ArrayParam_.size()); + for (UInt i = 0; i < unclonedInt64ArrayParam_.size(); i++) { + auto innerUnclonedParamProto = unclonedInt64ArrayParamProto.init( + i, unclonedInt64ArrayParam_[i].size()); + for (UInt j = 0; j < unclonedInt64ArrayParam_[i].size(); j++) { + innerUnclonedParamProto.set(j, unclonedInt64ArrayParam_[i][j]); } - - proto.setNodeCount(nodeCount_); } + proto.setNodeCount(nodeCount_); +} - void TestNode::read(AnyPointer::Reader& anyProto) - { - TestNodeProto::Reader proto = anyProto.getAs(); - - int32Param_ = proto.getInt32Param(); - uint32Param_ = proto.getUint32Param(); - int64Param_ = proto.getInt64Param(); - uint64Param_ = proto.getUint64Param(); - real32Param_ = proto.getReal32Param(); - real64Param_ = proto.getReal64Param(); - boolParam_ = proto.getBoolParam(); - stringParam_ = proto.getStringParam().cStr(); - - real32ArrayParam_.clear(); - auto real32ArrayParamProto = proto.getReal32ArrayParam(); - real32ArrayParam_.resize(real32ArrayParamProto.size()); - for (UInt i = 0; i < real32ArrayParamProto.size(); i++) - { - real32ArrayParam_[i] = real32ArrayParamProto[i]; - } +void TestNode::read(AnyPointer::Reader &anyProto) { + TestNodeProto::Reader proto = anyProto.getAs(); + + int32Param_ = proto.getInt32Param(); + uint32Param_ = proto.getUint32Param(); + int64Param_ = proto.getInt64Param(); + uint64Param_ = proto.getUint64Param(); + real32Param_ = proto.getReal32Param(); + real64Param_ = proto.getReal64Param(); + boolParam_ = proto.getBoolParam(); + stringParam_ = proto.getStringParam().cStr(); + + real32ArrayParam_.clear(); + auto real32ArrayParamProto = proto.getReal32ArrayParam(); + real32ArrayParam_.resize(real32ArrayParamProto.size()); + for (UInt i = 0; i < real32ArrayParamProto.size(); i++) { + real32ArrayParam_[i] = real32ArrayParamProto[i]; + } - int64ArrayParam_.clear(); - auto int64ArrayParamProto = proto.getInt64ArrayParam(); - int64ArrayParam_.resize(int64ArrayParamProto.size()); - for (UInt i = 0; i 
< int64ArrayParamProto.size(); i++) - { - int64ArrayParam_[i] = int64ArrayParamProto[i]; - } + int64ArrayParam_.clear(); + auto int64ArrayParamProto = proto.getInt64ArrayParam(); + int64ArrayParam_.resize(int64ArrayParamProto.size()); + for (UInt i = 0; i < int64ArrayParamProto.size(); i++) { + int64ArrayParam_[i] = int64ArrayParamProto[i]; + } - boolArrayParam_.clear(); - auto boolArrayParamProto = proto.getBoolArrayParam(); - boolArrayParam_.resize(boolArrayParamProto.size()); - for (UInt i = 0; i < boolArrayParamProto.size(); i++) - { - boolArrayParam_[i] = boolArrayParamProto[i]; - } + boolArrayParam_.clear(); + auto boolArrayParamProto = proto.getBoolArrayParam(); + boolArrayParam_.resize(boolArrayParamProto.size()); + for (UInt i = 0; i < boolArrayParamProto.size(); i++) { + boolArrayParam_[i] = boolArrayParamProto[i]; + } - iter_ = proto.getIterations(); - outputElementCount_ = proto.getOutputElementCount(); - delta_ = proto.getDelta(); + iter_ = proto.getIterations(); + outputElementCount_ = proto.getOutputElementCount(); + delta_ = proto.getDelta(); - shouldCloneParam_ = proto.getShouldCloneParam(); + shouldCloneParam_ = proto.getShouldCloneParam(); - unclonedParam_.clear(); - auto unclonedParamProto = proto.getUnclonedParam(); - unclonedParam_.resize(unclonedParamProto.size()); - for (UInt i = 0; i < unclonedParamProto.size(); i++) - { - unclonedParam_[i] = unclonedParamProto[i]; - } + unclonedParam_.clear(); + auto unclonedParamProto = proto.getUnclonedParam(); + unclonedParam_.resize(unclonedParamProto.size()); + for (UInt i = 0; i < unclonedParamProto.size(); i++) { + unclonedParam_[i] = unclonedParamProto[i]; + } - unclonedInt64ArrayParam_.clear(); - auto unclonedInt64ArrayProto = proto.getUnclonedInt64ArrayParam(); - unclonedInt64ArrayParam_.resize(unclonedInt64ArrayProto.size()); - for (UInt i = 0; i < unclonedInt64ArrayProto.size(); i++) - { - auto innerProto = unclonedInt64ArrayProto[i]; - unclonedInt64ArrayParam_[i].resize(innerProto.size()); - for (UInt j = 0; j < innerProto.size(); j++) - { - unclonedInt64ArrayParam_[i][j] = innerProto[j]; - } + unclonedInt64ArrayParam_.clear(); + auto unclonedInt64ArrayProto = proto.getUnclonedInt64ArrayParam(); + unclonedInt64ArrayParam_.resize(unclonedInt64ArrayProto.size()); + for (UInt i = 0; i < unclonedInt64ArrayProto.size(); i++) { + auto innerProto = unclonedInt64ArrayProto[i]; + unclonedInt64ArrayParam_[i].resize(innerProto.size()); + for (UInt j = 0; j < innerProto.size(); j++) { + unclonedInt64ArrayParam_[i][j] = innerProto[j]; } - - nodeCount_ = proto.getNodeCount(); } + nodeCount_ = proto.getNodeCount(); } + +} // namespace nupic diff --git a/src/nupic/engine/TestNode.hpp b/src/nupic/engine/TestNode.hpp index 38682f505b..ce91b61dab 100644 --- a/src/nupic/engine/TestNode.hpp +++ b/src/nupic/engine/TestNode.hpp @@ -20,11 +20,9 @@ * --------------------------------------------------------------------- */ - #ifndef NTA_TESTNODE_HPP #define NTA_TESTNODE_HPP - #include #include @@ -36,120 +34,123 @@ #include #include -namespace nupic -{ - - /* - * TestNode is does simple computations of inputs->outputs - * inputs and outputs are Real64 arrays - * - * delta is a parameter used for the computation. defaults to 1 - * - * Size of each node output is given by the outputSize parameter (cg) - * which defaults to 2 and cannot be less than 1. 
(parameter not yet implemented) - * - * Here is the totally lame "computation" - * output[0] = number of inputs to this baby node + current iteration number (0 for first compute) - * output[1] = baby node num + sum of inputs to this baby node - * output[2] = baby node num + sum of inputs + (delta) - * output[3] = baby node num + sum of inputs + (2*delta) - * ... - * output[n] = baby node num + sum of inputs + ((n-1) * delta) - - * It can act as a sensor if no inputs are connected (sum of inputs = 0) - */ - - class BundleIO; - - class TestNode : public RegionImpl - { - public: - typedef void (*computeCallbackFunc)(const std::string&); - TestNode(const ValueMap& params, Region *region); - TestNode(BundleIO& bundle, Region* region); - TestNode(capnp::AnyPointer::Reader& proto, Region* region); - virtual ~TestNode(); - - /* ----------- Required RegionImpl Interface methods ------- */ - - // Used by RegionImplFactory to create and cache - // a nodespec. Ownership is transferred to the caller. - static Spec* createSpec(); - - std::string getNodeType() { return "TestNode"; }; - void compute() override; - std::string executeCommand(const std::vector& args, Int64 index) override; - - size_t getNodeOutputElementCount(const std::string& outputName) override; - void getParameterFromBuffer(const std::string& name, Int64 index, IWriteBuffer& value) override; - void setParameterFromBuffer(const std::string& name, Int64 index, IReadBuffer& value) override; - - void initialize() override; - - void serialize(BundleIO& bundle) override; - void deserialize(BundleIO& bundle) override; - - using RegionImpl::write; - virtual void write(capnp::AnyPointer::Builder& anyProto) const override; - - using RegionImpl::read; - virtual void read(capnp::AnyPointer::Reader& anyProto) override; - - /* ----------- Optional RegionImpl Interface methods ------- */ - - size_t getParameterArrayCount(const std::string& name, Int64 index) override; - - // Override for Real64 only - // We choose Real64 in the test node to preserve precision. All other type - // go through read/write buffer serialization, and floating point values may get - // truncated in the conversion to/from ascii. 
- Real64 getParameterReal64(const std::string& name, Int64 index) override; - void setParameterReal64(const std::string& name, Int64 index, Real64 value) override; - - bool isParameterShared(const std::string& name) override; - - private: - TestNode(); - - // parameters - // cgs parameters for parameter testing - Int32 int32Param_; - UInt32 uint32Param_; - Int64 int64Param_; - UInt64 uint64Param_; - Real32 real32Param_; - Real64 real64Param_; - bool boolParam_; - std::string stringParam_; - computeCallbackFunc computeCallback_; - - std::vector real32ArrayParam_; - std::vector int64ArrayParam_; - std::vector boolArrayParam_; - - // read-only count of iterations since initialization - UInt64 iter_; - - // Constructor param specifying per-node output size - UInt32 outputElementCount_; - - // parameter used for computation - Int64 delta_; - - // cloning parameters - std::vector unclonedParam_; - bool shouldCloneParam_; - std::vector possiblyUnclonedParam_; - std::vector< std::vector > unclonedInt64ArrayParam_; - - /* ----- cached info from region ----- */ - size_t nodeCount_; - - // Input/output buffers for the whole region - const Input *bottomUpIn_; - const Output *bottomUpOut_; - }; -} +namespace nupic { -#endif // NTA_TESTNODE_HPP +/* + * TestNode is does simple computations of inputs->outputs + * inputs and outputs are Real64 arrays + * + * delta is a parameter used for the computation. defaults to 1 + * + * Size of each node output is given by the outputSize parameter (cg) + * which defaults to 2 and cannot be less than 1. (parameter not yet + implemented) + * + * Here is the totally lame "computation" + * output[0] = number of inputs to this baby node + current iteration number (0 + for first compute) + * output[1] = baby node num + sum of inputs to this baby node + * output[2] = baby node num + sum of inputs + (delta) + * output[3] = baby node num + sum of inputs + (2*delta) + * ... + * output[n] = baby node num + sum of inputs + ((n-1) * delta) + + * It can act as a sensor if no inputs are connected (sum of inputs = 0) + */ + +class BundleIO; + +class TestNode : public RegionImpl { +public: + typedef void (*computeCallbackFunc)(const std::string &); + TestNode(const ValueMap ¶ms, Region *region); + TestNode(BundleIO &bundle, Region *region); + TestNode(capnp::AnyPointer::Reader &proto, Region *region); + virtual ~TestNode(); + + /* ----------- Required RegionImpl Interface methods ------- */ + + // Used by RegionImplFactory to create and cache + // a nodespec. Ownership is transferred to the caller. 
+ static Spec *createSpec(); + + std::string getNodeType() { return "TestNode"; }; + void compute() override; + std::string executeCommand(const std::vector &args, + Int64 index) override; + + size_t getNodeOutputElementCount(const std::string &outputName) override; + void getParameterFromBuffer(const std::string &name, Int64 index, + IWriteBuffer &value) override; + void setParameterFromBuffer(const std::string &name, Int64 index, + IReadBuffer &value) override; + + void initialize() override; + + void serialize(BundleIO &bundle) override; + void deserialize(BundleIO &bundle) override; + + using RegionImpl::write; + virtual void write(capnp::AnyPointer::Builder &anyProto) const override; + using RegionImpl::read; + virtual void read(capnp::AnyPointer::Reader &anyProto) override; + + /* ----------- Optional RegionImpl Interface methods ------- */ + + size_t getParameterArrayCount(const std::string &name, Int64 index) override; + + // Override for Real64 only + // We choose Real64 in the test node to preserve precision. All other type + // go through read/write buffer serialization, and floating point values may + // get truncated in the conversion to/from ascii. + Real64 getParameterReal64(const std::string &name, Int64 index) override; + void setParameterReal64(const std::string &name, Int64 index, + Real64 value) override; + + bool isParameterShared(const std::string &name) override; + +private: + TestNode(); + + // parameters + // cgs parameters for parameter testing + Int32 int32Param_; + UInt32 uint32Param_; + Int64 int64Param_; + UInt64 uint64Param_; + Real32 real32Param_; + Real64 real64Param_; + bool boolParam_; + std::string stringParam_; + computeCallbackFunc computeCallback_; + + std::vector real32ArrayParam_; + std::vector int64ArrayParam_; + std::vector boolArrayParam_; + + // read-only count of iterations since initialization + UInt64 iter_; + + // Constructor param specifying per-node output size + UInt32 outputElementCount_; + + // parameter used for computation + Int64 delta_; + + // cloning parameters + std::vector unclonedParam_; + bool shouldCloneParam_; + std::vector possiblyUnclonedParam_; + std::vector> unclonedInt64ArrayParam_; + + /* ----- cached info from region ----- */ + size_t nodeCount_; + + // Input/output buffers for the whole region + const Input *bottomUpIn_; + const Output *bottomUpOut_; +}; +} // namespace nupic + +#endif // NTA_TESTNODE_HPP diff --git a/src/nupic/engine/UniformLinkPolicy.cpp b/src/nupic/engine/UniformLinkPolicy.cpp index 1e3260ec6d..cb21cc5c99 100644 --- a/src/nupic/engine/UniformLinkPolicy.cpp +++ b/src/nupic/engine/UniformLinkPolicy.cpp @@ -20,34 +20,26 @@ * --------------------------------------------------------------------- */ - -#include #include +#include -#include #include +#include #include #include -#include #include +#include #include -namespace nupic -{ - +namespace nupic { // used to detect an uninitialized value static const size_t uninitializedElementCount = 987654321; - -UniformLinkPolicy::UniformLinkPolicy(const std::string params, - Link* link) : - link_(link), - elementCount_(uninitializedElementCount), - parameterDimensionality_(0), - initialized_(false) -{ +UniformLinkPolicy::UniformLinkPolicy(const std::string params, Link *link) + : link_(link), elementCount_(uninitializedElementCount), + parameterDimensionality_(0), initialized_(false) { setValidParameters(); readParameters(params); validateParameterDimensionality(); @@ -55,87 +47,57 @@ UniformLinkPolicy::UniformLinkPolicy(const std::string params, 
validateParameterConsistency(); } -UniformLinkPolicy::~UniformLinkPolicy() -{ -} +UniformLinkPolicy::~UniformLinkPolicy() {} -void UniformLinkPolicy::readParameters(const std::string& params) -{ +void UniformLinkPolicy::readParameters(const std::string ¶ms) { ValueMap paramMap = YAMLUtils::toValueMap(params.c_str(), parameters_); boost::shared_ptr mappingStr = paramMap.getString("mapping"); - if(*mappingStr == "in") - { + if (*mappingStr == "in") { mapping_ = inMapping; - } - else if(*mappingStr == "out") - { + } else if (*mappingStr == "out") { mapping_ = outMapping; - } - else if(*mappingStr == "full") - { + } else if (*mappingStr == "full") { mapping_ = fullMapping; - } - else - { + } else { NTA_THROW << "Internal error: ParameterSpec constraint not enforced, " - "Invalid mapping type utilized with UniformLinkPolicy."; + "Invalid mapping type utilized with UniformLinkPolicy."; } - populateArrayParamVector(rfSize_, - paramMap, - "rfSize"); + populateArrayParamVector(rfSize_, paramMap, "rfSize"); - populateArrayParamVector(rfOverlap_, - paramMap, - "rfOverlap"); + populateArrayParamVector(rfOverlap_, paramMap, "rfOverlap"); boost::shared_ptr rfGranularityStr = - paramMap.getString("rfGranularity"); + paramMap.getString("rfGranularity"); - if(*rfGranularityStr == "nodes") - { + if (*rfGranularityStr == "nodes") { rfGranularity_ = nodesGranularity; - } - else if (*rfGranularityStr == "elements") - { + } else if (*rfGranularityStr == "elements") { rfGranularity_ = elementsGranularity; - } - else - { + } else { NTA_THROW << "Internal error: ParameterSpec constraint not enforced, " - "Invalid rfGranularity type utilized with " - "UniformLinkPolicy."; + "Invalid rfGranularity type utilized with " + "UniformLinkPolicy."; } - populateArrayParamVector(overhang_, - paramMap, - "overhang"); + populateArrayParamVector(overhang_, paramMap, "overhang"); - populateArrayParamVector(overhangType_, - paramMap, + populateArrayParamVector(overhangType_, paramMap, "overhangType"); - populateArrayParamVector(span_, - paramMap, - "span"); + populateArrayParamVector(span_, paramMap, "span"); - boost::shared_ptr strictStr = - paramMap.getString("strict"); + boost::shared_ptr strictStr = paramMap.getString("strict"); - if(*strictStr == "true") - { + if (*strictStr == "true") { strict_ = true; - } - else if(*strictStr == "false") - { + } else if (*strictStr == "false") { strict_ = false; - } - else - { + } else { NTA_THROW << "Internal error: ParameterSpec constraint not enforced, " - "Invalid strict setting utilized with UniformLinkPolicy."; + "Invalid strict setting utilized with UniformLinkPolicy."; } } @@ -144,9 +106,8 @@ void UniformLinkPolicy::readParameters(const std::string& params) // them here. See the declaration of parameterDimensionality_ in the hpp for // more details. // --- -void UniformLinkPolicy::validateParameterDimensionality() -{ - std::map dimensionalityMap; +void UniformLinkPolicy::validateParameterDimensionality() { + std::map dimensionalityMap; dimensionalityMap["rfSize"] = rfSize_.size(); dimensionalityMap["rfOverlap"] = rfOverlap_.size(); @@ -157,26 +118,20 @@ void UniformLinkPolicy::validateParameterDimensionality() std::stringstream parameterDimensionalityMsg; bool parametersAreInconsistent = false; - for(auto & elem : dimensionalityMap) - { + for (auto &elem : dimensionalityMap) { parameterDimensionalityMsg << elem.first << ": "; - elem.second == 1 ? (parameterDimensionalityMsg << "*") : - (parameterDimensionalityMsg << elem.second); + elem.second == 1 ? 
(parameterDimensionalityMsg << "*") + : (parameterDimensionalityMsg << elem.second); - if(elem.second != parameterDimensionality_) - { - switch(parameterDimensionality_) - { + if (elem.second != parameterDimensionality_) { + switch (parameterDimensionality_) { case 0: - case 1: - { + case 1: { parameterDimensionality_ = elem.second; break; } - default: - { - if(elem.second != 1) - { + default: { + if (elem.second != 1) { parametersAreInconsistent = true; parameterDimensionalityMsg << " <-- Inconsistent"; } @@ -188,10 +143,10 @@ void UniformLinkPolicy::validateParameterDimensionality() parameterDimensionalityMsg << "\n"; } - if(parametersAreInconsistent) - { + if (parametersAreInconsistent) { NTA_THROW << "The dimensionality of the parameters are inconsistent:" - << "\n\n" << parameterDimensionalityMsg.str(); + << "\n\n" + << parameterDimensionalityMsg.str(); } } @@ -199,16 +154,12 @@ void UniformLinkPolicy::validateParameterDimensionality() // Certain combinations of parameters are not valid when used in combination, // so we check to ensure our parameters are mutually consistent here // --- -void UniformLinkPolicy::validateParameterConsistency() -{ - for(size_t i = 0; i < parameterDimensionality_; i++) - { - if(strict_) - { - if(!(workingParams_.span[i].isNaturalNumber())) - { +void UniformLinkPolicy::validateParameterConsistency() { + for (size_t i = 0; i < parameterDimensionality_; i++) { + if (strict_) { + if (!(workingParams_.span[i].isNaturalNumber())) { NTA_THROW << "When using a granularity of nodes in combination with " - "strict, the specified span must be a natural number"; + "strict, the specified span must be a natural number"; } } @@ -216,86 +167,72 @@ void UniformLinkPolicy::validateParameterConsistency() // We don't yet know the size of the source dimensions, so we can't // perform this check now. We'll do it at initialization instead. 
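A note on the dimensionality rule enforced above: a parameter supplied as a single-element array acts as a wildcard for any number of dimensions, while all multi-element parameters must agree with each other. The following is a minimal standalone restatement of that rule, not NuPIC code; the function name and the illustrative sizes are made up.

// Standalone sketch of the wildcard rule used by
// validateParameterDimensionality(): a size-1 parameter array matches any
// dimensionality; all larger sizes must agree with each other.
#include <map>
#include <stdexcept>
#include <string>

static size_t resolveDimensionality(const std::map<std::string, size_t> &sizes) {
  size_t dimensionality = 1; // start from the wildcard value
  for (const auto &kv : sizes) {
    if (kv.second == 1)
      continue; // wildcard, compatible with anything
    if (dimensionality != 1 && kv.second != dimensionality)
      throw std::runtime_error("inconsistent dimensionality for " + kv.first);
    dimensionality = kv.second;
  }
  return dimensionality;
}

// resolveDimensionality({{"rfSize", 2}, {"rfOverlap", 1}, {"span", 2}}) == 2
// resolveDimensionality({{"rfSize", 2}, {"span", 3}}) throws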
// --- - //if(workingParams_.overhang[i] > srcDimensions_[i]) + // if(workingParams_.overhang[i] > srcDimensions_[i]) //{ // NTA_THROW << "The overhang can't exceed the size of the source " // "dimensions"; //} - if(workingParams_.rfOverlap[i] == workingParams_.rfSize[i]) - { + if (workingParams_.rfOverlap[i] == workingParams_.rfSize[i]) { NTA_THROW << "100% overlap is not permitted; use a mapping of \"full\"" - " instead"; + " instead"; } - if(workingParams_.rfOverlap[i] > workingParams_.rfSize[i]) - { + if (workingParams_.rfOverlap[i] > workingParams_.rfSize[i]) { NTA_THROW << "An overlap greater than the rfSize is not valid"; } } } -void UniformLinkPolicy::populateWorkingParameters() -{ +void UniformLinkPolicy::populateWorkingParameters() { // --- // First, convert our vectors of real values to vectors of fractions // // This is necessary to remove floating point precision issues when // calculating strict uniformity using non integer values // --- - copyRealVecToFractionVec(rfSize_, - workingParams_.rfSize); + copyRealVecToFractionVec(rfSize_, workingParams_.rfSize); - copyRealVecToFractionVec(rfOverlap_, - workingParams_.rfOverlap); + copyRealVecToFractionVec(rfOverlap_, workingParams_.rfOverlap); - copyRealVecToFractionVec(overhang_, - workingParams_.overhang); + copyRealVecToFractionVec(overhang_, workingParams_.overhang); - copyRealVecToFractionVec(span_, - workingParams_.span); + copyRealVecToFractionVec(span_, workingParams_.span); NTA_CHECK(workingParams_.overhangType.size() == 0); - for(auto & elem : overhangType_) - { + for (auto &elem : overhangType_) { workingParams_.overhangType.push_back(elem); } } template -void UniformLinkPolicy::populateArrayParamVector( - std::vector& vec, - const ValueMap& paramMap, - const std::string& paramName) -{ +void UniformLinkPolicy::populateArrayParamVector(std::vector &vec, + const ValueMap ¶mMap, + const std::string ¶mName) { NTA_CHECK(vec.size() == 0); boost::shared_ptr arrayVal = paramMap.getArray(paramName); - T* buf = (T*) arrayVal->getBuffer(); + T *buf = (T *)arrayVal->getBuffer(); vec.reserve(arrayVal->getCount()); - for(size_t i = 0; i < arrayVal->getCount(); i++) - { + for (size_t i = 0; i < arrayVal->getCount(); i++) { vec.push_back(buf[i]); } } void UniformLinkPolicy::copyRealVecToFractionVec( - const std::vector& sourceVec, - DefaultValuedVector& destVec) -{ + const std::vector &sourceVec, + DefaultValuedVector &destVec) { NTA_CHECK(destVec.size() == 0); - for(auto & elem : sourceVec) - { + for (auto &elem : sourceVec) { destVec.push_back(Fraction::fromDouble(elem)); } } -void UniformLinkPolicy::setValidParameters() -{ +void UniformLinkPolicy::setValidParameters() { // --- // The Network::link() method specifies the direction of the link (i.e. // source and destination regions), and this parameter specifies the @@ -323,14 +260,11 @@ void UniformLinkPolicy::setValidParameters() // parameter (see rfGranularity), the mapping may operate on finer // structure than at the Node level. // --- - parameters_.add("mapping", - ParameterSpec("Source to Destination Mapping " - "(\"in\", \"out\", \"full\")", - NTA_BasicType_Byte, - 0, - "enumeration:in, out, full", - "in", - ParameterSpec::ReadWriteAccess)); + parameters_.add("mapping", ParameterSpec("Source to Destination Mapping " + "(\"in\", \"out\", \"full\")", + NTA_BasicType_Byte, 0, + "enumeration:in, out, full", "in", + ParameterSpec::ReadWriteAccess)); // --- // Specifies the size of the receptive field topology. 
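The conversion to Fraction in populateWorkingParameters() exists because the strict-uniformity checks rely on exact remainders, which binary floating point cannot provide for values such as 0.1. A small self-contained illustration follows; Frac is a toy stand-in for nupic::Fraction, and the behaviour shown assumes ordinary IEEE-754 doubles.

// Why exact rational arithmetic matters for the strict-uniformity checks:
// with IEEE-754 doubles 0.1 * 3 != 0.3, so a remainder test can fail even
// though the configuration is mathematically valid.
#include <cassert>
#include <numeric>

struct Frac { // toy stand-in for nupic::Fraction
  long num, den;
  Frac(long n, long d = 1) : num(n), den(d) {}
};

static Frac operator*(Frac a, Frac b) {
  Frac r(a.num * b.num, a.den * b.den);
  long g = std::gcd(r.num, r.den);
  return Frac(r.num / g, r.den / g);
}

int main() {
  double d = 0.1 * 3.0;           // 0.30000000000000004 on IEEE-754 doubles
  assert(d != 0.3);               // floating point: spurious remainder
  Frac f = Frac(1, 10) * Frac(3); // exactly 3/10
  assert(f.num == 3 && f.den == 10);
  return 0;
}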
@@ -367,11 +301,8 @@ void UniformLinkPolicy::setValidParameters() // parameter. // --- parameters_.add("rfSize", - ParameterSpec("Receptive Field Size", - NTA_BasicType_Real64, - 0, - "interval:[0,...)", - "[1]", + ParameterSpec("Receptive Field Size", NTA_BasicType_Real64, 0, + "interval:[0,...)", "[1]", ParameterSpec::ReadWriteAccess)); // --- @@ -387,11 +318,8 @@ void UniformLinkPolicy::setValidParameters() // one. In this case, the given number is used for all dimensions. // --- parameters_.add("rfOverlap", - ParameterSpec("Receptive Field Overlap", - NTA_BasicType_Real64, - 0, - "interval:[0,...)", - "[0]", + ParameterSpec("Receptive Field Overlap", NTA_BasicType_Real64, + 0, "interval:[0,...)", "[0]", ParameterSpec::ReadWriteAccess)); // --- @@ -402,10 +330,8 @@ void UniformLinkPolicy::setValidParameters() parameters_.add("rfGranularity", ParameterSpec("Receptive Field Granularity " "(\"nodes\", \"elements\")", - NTA_BasicType_Byte, - 0, - "enumeration:nodes, elements", - "nodes", + NTA_BasicType_Byte, 0, + "enumeration:nodes, elements", "nodes", ParameterSpec::ReadWriteAccess)); // --- @@ -427,11 +353,8 @@ void UniformLinkPolicy::setValidParameters() // This parameter is invalid for a mapping of "full". // --- parameters_.add("overhang", - ParameterSpec("Region Overhang", - NTA_BasicType_Real64, - 0, - "interval:[0,...)", - "[0]", + ParameterSpec("Region Overhang", NTA_BasicType_Real64, 0, + "interval:[0,...)", "[0]", ParameterSpec::ReadWriteAccess)); // --- @@ -460,11 +383,8 @@ void UniformLinkPolicy::setValidParameters() parameters_.add("overhangType", ParameterSpec("Receptive Field Overhang Type " "(null=0, wrap=1)", - NTA_BasicType_UInt32, - 0, - "enumeration:0, 1", - "[0]", - ParameterSpec::ReadWriteAccess)); + NTA_BasicType_UInt32, 0, "enumeration:0, 1", + "[0]", ParameterSpec::ReadWriteAccess)); // --- // Specifies the length, in Nodes, of a span. A span represents an atomic @@ -488,26 +408,19 @@ void UniformLinkPolicy::setValidParameters() // 2) As an array of real numbers; the length of the array being equal to // one. In this case, the given number is used for all dimensions. // --- - parameters_.add("span", - ParameterSpec("Span group size", - NTA_BasicType_Real64, - 0, - "interval:[0,...)", - "[0]", - ParameterSpec::ReadWriteAccess)); + parameters_.add("span", ParameterSpec("Span group size", NTA_BasicType_Real64, + 0, "interval:[0,...)", "[0]", + ParameterSpec::ReadWriteAccess)); // --- // Specifies if strict uniformity is required. If this is set to false, // then the linkage is built "as close to uniform as possible". 
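For reference, the parameters registered in setValidParameters() add up to a small YAML map. The string below is a hypothetical, untested example of what a UniformLinkPolicy params argument could look like: the keys are the registered parameter names, the values are purely illustrative, and whether flow-style YAML is accepted depends on YAMLUtils::toValueMap().

// Hypothetical params string for UniformLinkPolicy (values are illustrative;
// every key is one of the parameters registered in setValidParameters()).
const char *params =
    "{mapping: in, rfSize: [4], rfOverlap: [0], rfGranularity: nodes, "
    "overhang: [0], overhangType: [0], span: [0], strict: true}";
// Given some Link *link, this would be passed to the constructor, e.g.
//   UniformLinkPolicy policy(params, link);
// A single-element array such as [4] applies to every dimension.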
// --- - parameters_.add("strict", - ParameterSpec("Require Strict Uniformity " - "(\"true\", \"false\")", - NTA_BasicType_Byte, - 0, - "enumeration:true, false", - "true", - ParameterSpec::ReadWriteAccess)); + parameters_.add("strict", ParameterSpec("Require Strict Uniformity " + "(\"true\", \"false\")", + NTA_BasicType_Byte, 0, + "enumeration:true, false", "true", + ParameterSpec::ReadWriteAccess)); } // --- @@ -524,41 +437,40 @@ void UniformLinkPolicy::setValidParameters() // V_i = Overlap, in Nodes, for dimension i // --- -void UniformLinkPolicy::setSrcDimensions(Dimensions& specifiedDims) -{ +void UniformLinkPolicy::setSrcDimensions(Dimensions &specifiedDims) { if (elementCount_ == uninitializedElementCount) - NTA_THROW << "Internal error: output element count not initialized on link " << link_->toString(); + NTA_THROW << "Internal error: output element count not initialized on link " + << link_->toString(); Dimensions dims = specifiedDims; if (dims.isOnes() && dims.size() != parameterDimensionality_) dims.promote(parameterDimensionality_); // This method should never be called if we've already been set - NTA_CHECK(srcDimensions_.isUnspecified()) << "Internal error on link " << - link_->toString(); - NTA_CHECK(destDimensions_.isUnspecified()) << "Internal error on link " << - link_->toString(); + NTA_CHECK(srcDimensions_.isUnspecified()) + << "Internal error on link " << link_->toString(); + NTA_CHECK(destDimensions_.isUnspecified()) + << "Internal error on link " << link_->toString(); - if(dims.isUnspecified()) - NTA_THROW << "Invalid unspecified source dimensions for link " << - link_->toString(); + if (dims.isUnspecified()) + NTA_THROW << "Invalid unspecified source dimensions for link " + << link_->toString(); - if(dims.isDontcare()) - NTA_THROW << "Invalid dontcare source dimensions for link " << - link_->toString(); + if (dims.isDontcare()) + NTA_THROW << "Invalid dontcare source dimensions for link " + << link_->toString(); // --- // validate that the parameter dimensionality matches the requested // dimensions // --- - if(parameterDimensionality_ != 1) - { - if(parameterDimensionality_ != dims.size() && !dims.isOnes()) - { + if (parameterDimensionality_ != 1) { + if (parameterDimensionality_ != dims.size() && !dims.isOnes()) { NTA_THROW << "Invalid parameter dimensionality; the parameters " - "have dimensionality " << parameterDimensionality_ << - " but the source dimensions supplied have dimensionality " + "have dimensionality " + << parameterDimensionality_ + << " but the source dimensions supplied have dimensionality " << dims.size(); } } @@ -566,12 +478,9 @@ void UniformLinkPolicy::setSrcDimensions(Dimensions& specifiedDims) Dimensions inducedDims; // Induce destination dimensions from source dimensions - switch(mapping_) - { - case inMapping: - { - if(strict_) - { + switch (mapping_) { + case inMapping: { + if (strict_) { // --- // if we are set to strict uniformity, we need to validate that the // requested dimensions are valid @@ -583,8 +492,7 @@ void UniformLinkPolicy::setSrcDimensions(Dimensions& specifiedDims) // Then: R_d,i = (S_i - V_i)/(F_s,i - V_i) * (R_s,i + 2 * H_i)/S_i // --- - for(size_t i = 0; i < dims.size(); i++) - { + for (size_t i = 0; i < dims.size(); i++) { // --- // If the span for this dimension is zero (indicating no atomic // groups of overlapping nodes), then S_i = R_s,i + 2 * H_i @@ -594,123 +502,144 @@ void UniformLinkPolicy::setSrcDimensions(Dimensions& specifiedDims) // R_d,i = (R_s,i + 2 * H_i - V_i)/(F_s,i - V_i) // --- - 
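Plugging numbers into the no-span formula quoted just above makes the induction concrete. The sketch below uses plain integers in place of nupic::Fraction, and the sizes are invented for illustration.

// Induced destination size for one dimension, no span, strict mapping:
//   R_d = (R_s + 2*H - V) / (F_s - V)
// With a source size of 10, receptive fields of 4, overlap 2, no overhang:
//   R_d = (10 + 0 - 2) / (4 - 2) = 4 destination nodes.
#include <cassert>

int main() {
  const int srcSize = 10, rfSize = 4, overlap = 2, overhang = 0;
  const int numerator = srcSize + 2 * overhang - overlap; // 8
  const int step = rfSize - overlap;                      // 2
  assert(numerator % step == 0); // the "validityCheck" in setSrcDimensions()
  const int destSize = numerator / step;
  assert(destSize == 4);
  return 0;
}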
if(workingParams_.span[i].getNumerator() == 0) - { + if (workingParams_.span[i].getNumerator() == 0) { Fraction validityCheck = - (Fraction(dims[i]) + - workingParams_.overhang[i] * 2 - - workingParams_.rfSize[i]) % (workingParams_.rfSize[i] - - workingParams_.rfOverlap[i]); + (Fraction(dims[i]) + workingParams_.overhang[i] * 2 - + workingParams_.rfSize[i]) % + (workingParams_.rfSize[i] - workingParams_.rfOverlap[i]); - if(validityCheck.getNumerator() != 0) - { + if (validityCheck.getNumerator() != 0) { NTA_THROW << "Invalid source dimensions " << dims.toString() - << " for link " << link_->toString() << ".\n\n" - "For dimension " << i+1 << ", given the specified" - " overlap of " << workingParams_.rfOverlap[i] << - ", each successive receptive field of size " << - workingParams_.rfSize[i] << " as requested will " - "add " << workingParams_.rfSize[i] - - workingParams_.rfOverlap[i] << " required nodes. " - "Since no span was provided, the source region " - "size (" << dims[i] << " for this dimension) + 2 " - "* the overhang (" << workingParams_.overhang[i] + << " for link " << link_->toString() + << ".\n\n" + "For dimension " + << i + 1 + << ", given the specified" + " overlap of " + << workingParams_.rfOverlap[i] + << ", each successive receptive field of size " + << workingParams_.rfSize[i] + << " as requested will " + "add " + << workingParams_.rfSize[i] - workingParams_.rfOverlap[i] + << " required nodes. " + "Since no span was provided, the source region " + "size (" + << dims[i] + << " for this dimension) + 2 " + "* the overhang (" + << workingParams_.overhang[i] << " for this dimension) must equal the receptive" - " field size plus an integer multiple of the " - "amount added by successive receptive fields."; + " field size plus an integer multiple of the " + "amount added by successive receptive fields."; } validityCheck = workingParams_.rfSize[i] * elementCount_; - if(!validityCheck.isNaturalNumber()) - { + if (!validityCheck.isNaturalNumber()) { NTA_THROW << "Invalid source dimensions " << dims.toString() - << " for link " << link_->toString() << ".\n\n" - "For dimension " << i+1 << ", the specified " - "receptive field size of " - << workingParams_.rfSize[i] << "is invalid since it " - "would require " << validityCheck << " elements " - "(given the source region's " << elementCount_ << - " elements per node). Elements cannot be " - "subdivided, therefore a strict mapping with this" - " configuration is not possible."; + << " for link " << link_->toString() + << ".\n\n" + "For dimension " + << i + 1 + << ", the specified " + "receptive field size of " + << workingParams_.rfSize[i] + << "is invalid since it " + "would require " + << validityCheck + << " elements " + "(given the source region's " + << elementCount_ + << " elements per node). 
Elements cannot be " + "subdivided, therefore a strict mapping with this" + " configuration is not possible."; } // R_d,i = (R_s,i + 2 * H_i - V_i)/(F_s,i - V_i) - Fraction inducedDim = (Fraction(dims[i]) + - workingParams_.overhang[i] * 2 - - workingParams_.rfOverlap[i]) / - (workingParams_.rfSize[i] - - workingParams_.rfOverlap[i]); + Fraction inducedDim = + (Fraction(dims[i]) + workingParams_.overhang[i] * 2 - + workingParams_.rfOverlap[i]) / + (workingParams_.rfSize[i] - workingParams_.rfOverlap[i]); NTA_CHECK(inducedDim.isNaturalNumber()); inducedDim.reduce(); inducedDims.push_back(inducedDim.getNumerator()); - } - else - { - Fraction validityCheck = ((Fraction(dims[i])) + - workingParams_.overhang[i] * 2) % - workingParams_.span[i]; - - if(validityCheck.getNumerator() != 0) - { + } else { + Fraction validityCheck = + ((Fraction(dims[i])) + workingParams_.overhang[i] * 2) % + workingParams_.span[i]; + + if (validityCheck.getNumerator() != 0) { NTA_THROW << "Invalid source dimensions " << dims.toString() - << " for link " << link_->toString() << ".\n\n" - "For dimension " << i+1 << ", the source size (" - << dims[i] << "plus 2 times the overhang (" << - workingParams_.overhang[i] << " per side) must be" - " an integer multiple of the specified span (" << - workingParams_.span[i] << ")."; + << " for link " << link_->toString() + << ".\n\n" + "For dimension " + << i + 1 << ", the source size (" << dims[i] + << "plus 2 times the overhang (" + << workingParams_.overhang[i] + << " per side) must be" + " an integer multiple of the specified span (" + << workingParams_.span[i] << ")."; } - validityCheck = (workingParams_.span[i] - - workingParams_.rfSize[i]) % - (workingParams_.rfSize[i] - - workingParams_.rfOverlap[i]); + validityCheck = + (workingParams_.span[i] - workingParams_.rfSize[i]) % + (workingParams_.rfSize[i] - workingParams_.rfOverlap[i]); - if(validityCheck.getNumerator() != 0) - { + if (validityCheck.getNumerator() != 0) { NTA_THROW << "Invalid source dimensions " << dims.toString() - << " for link " << link_->toString() << ".\n\n" - "For dimension " << i+1 << ", given the specified" - " overlap of " << workingParams_.rfOverlap[i] << - ", each successive receptive field of size " << - workingParams_.rfSize[i] << " as requested will " - "add " << workingParams_.rfSize[i] - - workingParams_.rfOverlap[i] << " required nodes. " - "Each span in this dimension (having specified " - "size: " << workingParams_.span[i] << ") must " - "equal the receptive field size plus an integer " - "multiple of the amount added by successive " - "receptive fields."; + << " for link " << link_->toString() + << ".\n\n" + "For dimension " + << i + 1 + << ", given the specified" + " overlap of " + << workingParams_.rfOverlap[i] + << ", each successive receptive field of size " + << workingParams_.rfSize[i] + << " as requested will " + "add " + << workingParams_.rfSize[i] - workingParams_.rfOverlap[i] + << " required nodes. 
" + "Each span in this dimension (having specified " + "size: " + << workingParams_.span[i] + << ") must " + "equal the receptive field size plus an integer " + "multiple of the amount added by successive " + "receptive fields."; } validityCheck = workingParams_.rfSize[i] * elementCount_; - if(!validityCheck.isNaturalNumber()) - { + if (!validityCheck.isNaturalNumber()) { NTA_THROW << "Invalid source dimensions " << dims.toString() - << " for link " << link_->toString() << ".\n\n" - "For dimension " << i+1 << ", the specified " - "receptive field size of " - << workingParams_.rfSize[i] << "is invalid since it " - "would require " << validityCheck << " elements " - "(given the source region's " << elementCount_ << - " elements per node). Elements cannot be " - "subdivided, therefore a strict mapping with this" - " configuration is not possible."; + << " for link " << link_->toString() + << ".\n\n" + "For dimension " + << i + 1 + << ", the specified " + "receptive field size of " + << workingParams_.rfSize[i] + << "is invalid since it " + "would require " + << validityCheck + << " elements " + "(given the source region's " + << elementCount_ + << " elements per node). Elements cannot be " + "subdivided, therefore a strict mapping with this" + " configuration is not possible."; } // R_d,i = (S_i - V_i)/(F_s,i - V_i) * (R_s,i + 2 * H_i)/S_i - Fraction inducedDim = (workingParams_.span[i] - - workingParams_.rfOverlap[i]) / - (workingParams_.rfSize[i] - - workingParams_.rfOverlap[i]) * - (Fraction(dims[i]) + - workingParams_.overhang[i] * 2) / - workingParams_.span[i]; + Fraction inducedDim = + (workingParams_.span[i] - workingParams_.rfOverlap[i]) / + (workingParams_.rfSize[i] - workingParams_.rfOverlap[i]) * + (Fraction(dims[i]) + workingParams_.overhang[i] * 2) / + workingParams_.span[i]; NTA_CHECK(inducedDim.isNaturalNumber()); @@ -718,9 +647,7 @@ void UniformLinkPolicy::setSrcDimensions(Dimensions& specifiedDims) inducedDims.push_back(inducedDim.getNumerator()); } } - } - else - { + } else { // --- // Since we are set to non-strict uniformity, we don't need to // validate dimensions. So we'll just calculate the ideal fit using @@ -733,31 +660,25 @@ void UniformLinkPolicy::setSrcDimensions(Dimensions& specifiedDims) // n - delta source nodes. This implies rounding down. 
// --- - for(size_t i = 0; i < dims.size(); i++) - { + for (size_t i = 0; i < dims.size(); i++) { Fraction inducedDim; - if(workingParams_.span[i].getNumerator() == 0) - { + if (workingParams_.span[i].getNumerator() == 0) { // R_d,i = (R_s,i + 2 * H_i - V_i)/(F_s,i - V_i) - inducedDim = (Fraction(dims[i]) + - workingParams_.overhang[i] * 2) / - (workingParams_.rfSize[i] - - workingParams_.rfOverlap[i]); - } - else - { - Fraction numSpans = (Fraction(dims[i]) + - workingParams_.overhang[i] * 2) / - workingParams_.span[i]; + inducedDim = (Fraction(dims[i]) + workingParams_.overhang[i] * 2) / + (workingParams_.rfSize[i] - workingParams_.rfOverlap[i]); + } else { + Fraction numSpans = + (Fraction(dims[i]) + workingParams_.overhang[i] * 2) / + workingParams_.span[i]; - Fraction nodesPerSpan = Fraction(1) + - (workingParams_.span[i] - - workingParams_.rfSize[i]) / - (workingParams_.rfSize[i] - - workingParams_.rfOverlap[i]); + Fraction nodesPerSpan = + Fraction(1) + + (workingParams_.span[i] - workingParams_.rfSize[i]) / + (workingParams_.rfSize[i] - workingParams_.rfOverlap[i]); - int numWholeSpans = numSpans.getNumerator() / numSpans.getDenominator(); + int numWholeSpans = + numSpans.getNumerator() / numSpans.getDenominator(); inducedDim = Fraction(numWholeSpans) * nodesPerSpan; } @@ -769,10 +690,9 @@ void UniformLinkPolicy::setSrcDimensions(Dimensions& specifiedDims) break; } - default: - { + default: { NTA_THROW << "UniformLinkPolicy mappings other than 'in' are not yet " - "implemented."; + "implemented."; break; } @@ -782,51 +702,47 @@ void UniformLinkPolicy::setSrcDimensions(Dimensions& specifiedDims) destDimensions_ = inducedDims; } -void UniformLinkPolicy::setDestDimensions(Dimensions& specifiedDims) -{ +void UniformLinkPolicy::setDestDimensions(Dimensions &specifiedDims) { Dimensions dims = specifiedDims; if (dims.isOnes() && dims.size() != parameterDimensionality_) dims.promote(parameterDimensionality_); // This method should never be called if we've already been set - NTA_CHECK(srcDimensions_.isUnspecified()) << "Internal error on link " << - link_->toString(); - NTA_CHECK(destDimensions_.isUnspecified()) << "Internal error on link " << - link_->toString(); + NTA_CHECK(srcDimensions_.isUnspecified()) + << "Internal error on link " << link_->toString(); + NTA_CHECK(destDimensions_.isUnspecified()) + << "Internal error on link " << link_->toString(); - if(dims.isUnspecified()) - NTA_THROW << "Invalid unspecified destination dimensions for link " << - link_->toString(); + if (dims.isUnspecified()) + NTA_THROW << "Invalid unspecified destination dimensions for link " + << link_->toString(); - if(dims.isDontcare()) - NTA_THROW << "Invalid dontcare destination dimensions for link " << - link_->toString(); + if (dims.isDontcare()) + NTA_THROW << "Invalid dontcare destination dimensions for link " + << link_->toString(); // --- // validate that the parameter dimensionality matches the requested // dimensions // --- - if(parameterDimensionality_ != 1) - { - if(parameterDimensionality_ != dims.size()) - { + if (parameterDimensionality_ != 1) { + if (parameterDimensionality_ != dims.size()) { NTA_THROW << "Invalid parameter dimensionality; the parameters " - "have dimensionality " << parameterDimensionality_ << - " but the destination dimensions supplied have " - " dimensionality " << dims.size(); + "have dimensionality " + << parameterDimensionality_ + << " but the destination dimensions supplied have " + " dimensionality " + << dims.size(); } } Dimensions inducedDims; // Induce destination 
dimensions from source dimensions - switch(mapping_) - { - case inMapping: - { - if(strict_) - { + switch (mapping_) { + case inMapping: { + if (strict_) { // --- // Since the requested mapping is of type "in" and we are provided // destination dimensions, we can always calculate valid source @@ -842,38 +758,44 @@ void UniformLinkPolicy::setDestDimensions(Dimensions& specifiedDims) // Then: R_s,i = (R_d,i * S_i * (F_s,i - V_i))/(S_i - V_i) - 2 * H_i // --- - for(size_t i = 0; i < dims.size(); i++) - { - if(!workingParams_.rfSize[i].isNaturalNumber()) - { - if(rfGranularity_ != elementsGranularity) - { + for (size_t i = 0; i < dims.size(); i++) { + if (!workingParams_.rfSize[i].isNaturalNumber()) { + if (rfGranularity_ != elementsGranularity) { NTA_THROW << "Invalid dest dimensions " << dims.toString() - << " for link " << link_->toString() << ".\n\n" - "For dimension " << i+1 << ", a fractional " - "receptive field size of " << - workingParams_.rfSize[i] << " was specified in " - "combination with a strict mapping with a " - "granularity of nodes. Fractional receptive " - "fields are only valid with strict mappings when " - "rfGranularity is set to elements."; + << " for link " << link_->toString() + << ".\n\n" + "For dimension " + << i + 1 + << ", a fractional " + "receptive field size of " + << workingParams_.rfSize[i] + << " was specified in " + "combination with a strict mapping with a " + "granularity of nodes. Fractional receptive " + "fields are only valid with strict mappings when " + "rfGranularity is set to elements."; } - Fraction validityCheck = - workingParams_.rfSize[i] * elementCount_; + Fraction validityCheck = workingParams_.rfSize[i] * elementCount_; - if(!validityCheck.isNaturalNumber()) - { + if (!validityCheck.isNaturalNumber()) { NTA_THROW << "Invalid dest dimensions " << dims.toString() - << " for link " << link_->toString() << ".\n\n" - "For dimension " << i+1 << ", the specified " - "receptive field size of " - << workingParams_.rfSize[i] << "is invalid since it " - "would require " << validityCheck << " elements " - "(given the source region's " << elementCount_ << - " elements per node). Elements cannot be " - "subdivided, therefore a strict mapping with this" - " configuration is not possible."; + << " for link " << link_->toString() + << ".\n\n" + "For dimension " + << i + 1 + << ", the specified " + "receptive field size of " + << workingParams_.rfSize[i] + << "is invalid since it " + "would require " + << validityCheck + << " elements " + "(given the source region's " + << elementCount_ + << " elements per node). 
Elements cannot be " + "subdivided, therefore a strict mapping with this" + " configuration is not possible."; } } @@ -885,30 +807,24 @@ void UniformLinkPolicy::setDestDimensions(Dimensions& specifiedDims) // R_s,i = R_d,i * (F_s,i - V_i) + V_i - 2 * H_i // --- - if(workingParams_.span[i].getNumerator() == 0) - { + if (workingParams_.span[i].getNumerator() == 0) { // R_s,i = R_d,i * (F_s,i - V_i) + V_i - 2 * H_i - Fraction inducedDim = Fraction(dims[i]) * - (workingParams_.rfSize[i] - - workingParams_.rfOverlap[i]) + - workingParams_.rfOverlap[i] - - workingParams_.overhang[i] * 2; + Fraction inducedDim = + Fraction(dims[i]) * + (workingParams_.rfSize[i] - workingParams_.rfOverlap[i]) + + workingParams_.rfOverlap[i] - workingParams_.overhang[i] * 2; NTA_CHECK(inducedDim.isNaturalNumber()); inducedDim.reduce(); inducedDims.push_back(inducedDim.getNumerator()); - } - else - { + } else { // R_s,i = (R_d,i * S_i * (F_s,i - V_i))/(S_i - V_i) - 2 * H_i - Fraction inducedDim = (Fraction(dims[i]) * - workingParams_.span[i] * - (workingParams_.rfSize[i] - - workingParams_.rfOverlap[i])) / - (workingParams_.span[i] - - workingParams_.rfOverlap[i]) - - workingParams_.overhang[i] * 2; + Fraction inducedDim = + (Fraction(dims[i]) * workingParams_.span[i] * + (workingParams_.rfSize[i] - workingParams_.rfOverlap[i])) / + (workingParams_.span[i] - workingParams_.rfOverlap[i]) - + workingParams_.overhang[i] * 2; NTA_CHECK(inducedDim.isNaturalNumber()); @@ -916,9 +832,7 @@ void UniformLinkPolicy::setDestDimensions(Dimensions& specifiedDims) inducedDims.push_back(inducedDim.getNumerator()); } } - } - else - { + } else { // --- // Since we are set to non-strict uniformity, we don't need to // validate dimensions. So we'll just calculate the ideal fit using @@ -932,82 +846,70 @@ void UniformLinkPolicy::setDestDimensions(Dimensions& specifiedDims) // implies rounding up. 
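The reverse direction (destination known, source induced) uses the inverse formula quoted a little above, R_s,i = R_d,i * (F_s,i - V_i) + V_i - 2 * H_i. Re-using the numbers from the earlier sketch shows the two formulas round-trip; again plain integers stand in for nupic::Fraction.

// Inverse of the earlier example: R_s = R_d*(F_s - V) + V - 2*H
// With 4 destination nodes, rfSize 4, overlap 2, no overhang:
//   R_s = 4*(4 - 2) + 2 - 0 = 10, matching the source size we started from.
#include <cassert>

int main() {
  const int destSize = 4, rfSize = 4, overlap = 2, overhang = 0;
  const int srcSize = destSize * (rfSize - overlap) + overlap - 2 * overhang;
  assert(srcSize == 10);
  return 0;
}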
// --- - for(size_t i = 0; i < dims.size(); i++) - { + for (size_t i = 0; i < dims.size(); i++) { Fraction inducedDim; - if(workingParams_.span[i].getNumerator() == 0) - { + if (workingParams_.span[i].getNumerator() == 0) { // R_s,i = R_d,i * (F_s,i - V_i) + V_i - 2 * H_i - inducedDim = Fraction(dims[i]) * - (workingParams_.rfSize[i] - - workingParams_.rfOverlap[i]) + - workingParams_.rfOverlap[i] - - workingParams_.overhang[i] * 2; - } - else - { + inducedDim = + Fraction(dims[i]) * + (workingParams_.rfSize[i] - workingParams_.rfOverlap[i]) + + workingParams_.rfOverlap[i] - workingParams_.overhang[i] * 2; + } else { // R_s,i = (R_d,i * S_i * (F_s,i - V_i))/(S_i - V_i) - 2 * H_i - inducedDim = (Fraction(dims[i]) * - workingParams_.span[i] * - (workingParams_.rfSize[i] - - workingParams_.rfOverlap[i])) / - (workingParams_.span[i] - - workingParams_.rfOverlap[i]) - - workingParams_.overhang[i] * 2; - - Fraction numSpans = (inducedDim + - workingParams_.overhang[i] * 2) / - workingParams_.span[i]; - - Fraction nodesPerSpan = Fraction(1) + - (workingParams_.span[i] - - workingParams_.rfSize[i]) / - (workingParams_.rfSize[i] - - workingParams_.rfOverlap[i]); - - int numWholeSpans = numSpans.getNumerator() / numSpans.getDenominator(); + inducedDim = + (Fraction(dims[i]) * workingParams_.span[i] * + (workingParams_.rfSize[i] - workingParams_.rfOverlap[i])) / + (workingParams_.span[i] - workingParams_.rfOverlap[i]) - + workingParams_.overhang[i] * 2; + + Fraction numSpans = (inducedDim + workingParams_.overhang[i] * 2) / + workingParams_.span[i]; + + Fraction nodesPerSpan = + Fraction(1) + + (workingParams_.span[i] - workingParams_.rfSize[i]) / + (workingParams_.rfSize[i] - workingParams_.rfOverlap[i]); + + int numWholeSpans = + numSpans.getNumerator() / numSpans.getDenominator(); Fraction properDestDim = Fraction(numWholeSpans) * nodesPerSpan; - unsigned int properWholeDestDim = properDestDim.getNumerator() / - properDestDim.getDenominator(); + unsigned int properWholeDestDim = + properDestDim.getNumerator() / properDestDim.getDenominator(); - if(properWholeDestDim != dims[i]) - { + if (properWholeDestDim != dims[i]) { NTA_WARN << "Since a span was specified, the destination " - "dimensions are treated such that they are " - "compatible with the requested span. In " - "non-strict mappings, extra source nodes are " - "divided amongst spans and then distributed as " - "evenly as possible. Given the specified " - "parameters, the destination dimensions being set " - "will result in " << dims[i] - properWholeDestDim + "dimensions are treated such that they are " + "compatible with the requested span. In " + "non-strict mappings, extra source nodes are " + "divided amongst spans and then distributed as " + "evenly as possible. 
Given the specified " + "parameters, the destination dimensions being set " + "will result in " + << dims[i] - properWholeDestDim << " destination nodes receiving no input for " - "dimension " << i+1 << "."; + "dimension " + << i + 1 << "."; } } - if(inducedDim.isNaturalNumber()) - { + if (inducedDim.isNaturalNumber()) { inducedDims.push_back(inducedDim.getNumerator() / inducedDim.getDenominator()); - } - else - { - inducedDims.push_back((inducedDim.getNumerator() / - inducedDim.getDenominator()) + - 1); + } else { + inducedDims.push_back( + (inducedDim.getNumerator() / inducedDim.getDenominator()) + 1); } } } break; } - default: - { + default: { NTA_THROW << "UniformLinkPolicy mappings other than 'in' are not yet " - "implemented."; + "implemented."; break; } @@ -1017,25 +919,21 @@ void UniformLinkPolicy::setDestDimensions(Dimensions& specifiedDims) srcDimensions_ = inducedDims; } -const Dimensions& UniformLinkPolicy::getSrcDimensions() const -{ +const Dimensions &UniformLinkPolicy::getSrcDimensions() const { return srcDimensions_; } -const Dimensions& UniformLinkPolicy::getDestDimensions() const -{ +const Dimensions &UniformLinkPolicy::getDestDimensions() const { return destDimensions_; } -void UniformLinkPolicy::setNodeOutputElementCount(size_t elementCount) -{ +void UniformLinkPolicy::setNodeOutputElementCount(size_t elementCount) { elementCount_ = elementCount; } std::pair UniformLinkPolicy::getInputBoundsForNode(Coordinate nodeCoordinate, - size_t dimension) const -{ + size_t dimension) const { NTA_CHECK(isInitialized()); Fraction lowerIndex(0), upperIndex(0); @@ -1059,37 +957,29 @@ UniformLinkPolicy::getInputBoundsForNode(Coordinate nodeCoordinate, // to // J_i*S_i + (F_s,i - 1) + (K_i - J_i*T_i)*(F_s,i - V_i) - H_i // --- - switch(mapping_) - { - case inMapping: - { - if(strict_) - { + switch (mapping_) { + case inMapping: { + if (strict_) { Fraction destNodesPerSpan = (workingParams_.span[dimension] - workingParams_.rfOverlap[dimension]) / - (workingParams_.rfSize[dimension] - - workingParams_.rfOverlap[dimension]); - - Fraction nodeInSpanFrac = Fraction(nodeCoordinate[dimension]) / - destNodesPerSpan; - - size_t nodeInSpan = nodeInSpanFrac.getNumerator() / - nodeInSpanFrac.getDenominator(); - - lowerIndex = workingParams_.span[dimension] * - nodeInSpan + - (Fraction(nodeCoordinate[dimension]) - - destNodesPerSpan * - nodeInSpan) * - (workingParams_.rfSize[dimension] - - workingParams_.rfOverlap[dimension]) - - workingParams_.overhang[dimension]; - - upperIndex = lowerIndex + - workingParams_.rfSize[dimension] - Fraction(1); - } - else - { + (workingParams_.rfSize[dimension] - + workingParams_.rfOverlap[dimension]); + + Fraction nodeInSpanFrac = + Fraction(nodeCoordinate[dimension]) / destNodesPerSpan; + + size_t nodeInSpan = + nodeInSpanFrac.getNumerator() / nodeInSpanFrac.getDenominator(); + + lowerIndex = workingParams_.span[dimension] * nodeInSpan + + (Fraction(nodeCoordinate[dimension]) - + destNodesPerSpan * nodeInSpan) * + (workingParams_.rfSize[dimension] - + workingParams_.rfOverlap[dimension]) - + workingParams_.overhang[dimension]; + + upperIndex = lowerIndex + workingParams_.rfSize[dimension] - Fraction(1); + } else { // --- // Since we're not strict, we will determine our bounds in several // steps. 
First, we need to calculate the overage over an ideal @@ -1104,96 +994,82 @@ UniformLinkPolicy::getInputBoundsForNode(Coordinate nodeCoordinate, Fraction srcNodeOverage = (Fraction(srcDimensions_[dimension]) + workingParams_.overhang[dimension] * 2) % - workingParams_.span[dimension]; + workingParams_.span[dimension]; - Fraction numberOfSpans = (Fraction(srcDimensions_[dimension]) + - workingParams_.overhang[dimension] * 2 - - srcNodeOverage) / - workingParams_.span[dimension]; + Fraction numberOfSpans = + (Fraction(srcDimensions_[dimension]) + + workingParams_.overhang[dimension] * 2 - srcNodeOverage) / + workingParams_.span[dimension]; NTA_CHECK(numberOfSpans.isNaturalNumber()); Fraction overagePerSpan = srcNodeOverage / numberOfSpans; - Fraction numRfsPerSpan = (workingParams_.span[dimension] - - workingParams_.rfSize[dimension]) / - (workingParams_.rfSize[dimension] - - workingParams_.rfOverlap[dimension]) - + 1; + Fraction numRfsPerSpan = + (workingParams_.span[dimension] - workingParams_.rfSize[dimension]) / + (workingParams_.rfSize[dimension] - + workingParams_.rfOverlap[dimension]) + + 1; - Fraction effectiveRfSize = workingParams_.rfSize[dimension] + - (overagePerSpan / - numRfsPerSpan); + Fraction effectiveRfSize = + workingParams_.rfSize[dimension] + (overagePerSpan / numRfsPerSpan); - Fraction effectiveSpan = workingParams_.span[dimension] + - overagePerSpan; + Fraction effectiveSpan = workingParams_.span[dimension] + overagePerSpan; Fraction destNodesPerSpan = (workingParams_.span[dimension] - workingParams_.rfOverlap[dimension]) / - (workingParams_.rfSize[dimension] - - workingParams_.rfOverlap[dimension]); - - Fraction nodeInSpanFrac = Fraction(nodeCoordinate[dimension]) / - destNodesPerSpan; - - size_t nodeInSpan = nodeInSpanFrac.getNumerator() / - nodeInSpanFrac.getDenominator(); - - lowerIndex = effectiveSpan * - nodeInSpan + - (Fraction(nodeCoordinate[dimension]) - - destNodesPerSpan * - nodeInSpan) * - (effectiveRfSize - - workingParams_.rfOverlap[dimension]) - - workingParams_.overhang[dimension]; - - upperIndex = lowerIndex + - effectiveRfSize - Fraction(1); - - if(rfGranularity_ == nodesGranularity) - { - if(!lowerIndex.isNaturalNumber()) - { - lowerIndex = Fraction(lowerIndex.getNumerator() / - lowerIndex.getDenominator()); + (workingParams_.rfSize[dimension] - + workingParams_.rfOverlap[dimension]); + + Fraction nodeInSpanFrac = + Fraction(nodeCoordinate[dimension]) / destNodesPerSpan; + + size_t nodeInSpan = + nodeInSpanFrac.getNumerator() / nodeInSpanFrac.getDenominator(); + + lowerIndex = effectiveSpan * nodeInSpan + + (Fraction(nodeCoordinate[dimension]) - + destNodesPerSpan * nodeInSpan) * + (effectiveRfSize - workingParams_.rfOverlap[dimension]) - + workingParams_.overhang[dimension]; + + upperIndex = lowerIndex + effectiveRfSize - Fraction(1); + + if (rfGranularity_ == nodesGranularity) { + if (!lowerIndex.isNaturalNumber()) { + lowerIndex = + Fraction(lowerIndex.getNumerator() / lowerIndex.getDenominator()); } - if(!upperIndex.isNaturalNumber()) - { - upperIndex = Fraction(upperIndex.getNumerator() / - upperIndex.getDenominator()); + if (!upperIndex.isNaturalNumber()) { + upperIndex = + Fraction(upperIndex.getNumerator() / upperIndex.getDenominator()); } - } - else - { + } else { Fraction wholeElementCheck = lowerIndex * elementCount_; - if(!wholeElementCheck.isNaturalNumber()) - { + if (!wholeElementCheck.isNaturalNumber()) { lowerIndex = Fraction((wholeElementCheck.getNumerator() / - wholeElementCheck.getDenominator()) - + 1) / - Fraction(elementCount_); 
+ wholeElementCheck.getDenominator()) + + 1) / + Fraction(elementCount_); } wholeElementCheck = upperIndex * elementCount_; - if(!wholeElementCheck.isNaturalNumber()) - { + if (!wholeElementCheck.isNaturalNumber()) { upperIndex = Fraction((wholeElementCheck.getNumerator() / wholeElementCheck.getDenominator())) / - Fraction(elementCount_); + Fraction(elementCount_); } } } break; } - default: - { + default: { NTA_THROW << "UniformLinkPolicy mappings other than 'in' are not yet " - "implemented."; + "implemented."; break; } @@ -1204,8 +1080,7 @@ UniformLinkPolicy::getInputBoundsForNode(Coordinate nodeCoordinate, std::pair UniformLinkPolicy::getInputBoundsForNode(size_t nodeIndex, - size_t dimension) const -{ + size_t dimension) const { NTA_CHECK(isInitialized()); return getInputBoundsForNode(destDimensions_.getCoordinate(nodeIndex), @@ -1213,21 +1088,19 @@ UniformLinkPolicy::getInputBoundsForNode(size_t nodeIndex, } void UniformLinkPolicy::getInputForNode(Coordinate nodeCoordinate, - std::vector& input) const -{ + std::vector &input) const { // --- // We need to get the input bounds for our node in each dimension. // The bounds correspond to edges of an orthotope, and the elements // contained in the orthotope are the input for the node. // --- - std::vector > orthotopeBounds; + std::vector> orthotopeBounds; orthotopeBounds.reserve(destDimensions_.size()); - for(size_t d = 0; d < destDimensions_.size(); d++) - { + for (size_t d = 0; d < destDimensions_.size(); d++) { // get the bounds (inclusive) in Nodes for this dimension std::pair dimensionBounds = - getInputBoundsForNode(nodeCoordinate,d); + getInputBoundsForNode(nodeCoordinate, d); // convert to an exclusive upper bound dimensionBounds.second = dimensionBounds.second + 1; @@ -1241,21 +1114,19 @@ void UniformLinkPolicy::getInputForNode(Coordinate nodeCoordinate, // create an empty subCoordinate to pass in to the recursive routine // --- std::vector subCoordinate; - populateInputElements(input,orthotopeBounds,subCoordinate); + populateInputElements(input, orthotopeBounds, subCoordinate); } void UniformLinkPolicy::getInputForNode(size_t nodeIndex, - std::vector& input) const -{ + std::vector &input) const { getInputForNode(destDimensions_.getCoordinate(nodeIndex), input); } void UniformLinkPolicy::populateInputElements( - std::vector& input, - std::vector > orthotopeBounds, - std::vector& subCoordinate) const -{ - size_t dimension = orthotopeBounds.size() - subCoordinate.size()- 1; + std::vector &input, + std::vector> orthotopeBounds, + std::vector &subCoordinate) const { + size_t dimension = orthotopeBounds.size() - subCoordinate.size() - 1; // --- // When handling element level linking, a Node's elements are treated @@ -1293,27 +1164,20 @@ void UniformLinkPolicy::populateInputElements( // the second from [2, 2]. 
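Each splitter-map row ends up holding flat element indices of the form nodeIndex * elementCount_ + offset (see the push_back at the end of this function, further below). A self-contained 1-D illustration with invented sizes:

// How a splitter-map row is assembled: every destination node lists the flat
// element indices of its receptive field. Illustrative 1-D case with
// 2 elements per node and a receptive field covering source nodes 2..3.
#include <cassert>
#include <vector>

int main() {
  const size_t elementsPerNode = 2;
  std::vector<size_t> row; // one SplitterMap entry, i.e. splitter[destNode]
  for (size_t node = 2; node <= 3; ++node)
    for (size_t e = 0; e < elementsPerNode; ++e)
      row.push_back(node * elementsPerNode + e);
  assert((row == std::vector<size_t>{4, 5, 6, 7}));
  return 0;
}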
// --- - for(Fraction i = orthotopeBounds[dimension].first; - i < orthotopeBounds[dimension].second; - i = i+1) - { + for (Fraction i = orthotopeBounds[dimension].first; + i < orthotopeBounds[dimension].second; i = i + 1) { subCoordinate.insert(subCoordinate.begin(), i); - if(dimension != 0) - { + if (dimension != 0) { populateInputElements(input, orthotopeBounds, subCoordinate); - } - else - { + } else { Coordinate nodeCoordinate; std::pair elementOffset = - std::pair(std::numeric_limits::max(), - std::numeric_limits::min()); + std::pair(std::numeric_limits::max(), + std::numeric_limits::min()); - for(size_t x = 0; x < subCoordinate.size(); x++) - { - if(subCoordinate[x].getNumerator() < 0) - { + for (size_t x = 0; x < subCoordinate.size(); x++) { + if (subCoordinate[x].getNumerator() < 0) { // --- // We got a negative number which implies we're in overhang for // this dimension. If our overhang type is null then we won't @@ -1322,23 +1186,21 @@ void UniformLinkPolicy::populateInputElements( // dimension. // --- - switch(workingParams_.overhangType[x]) - { - case wrapOverhang: - { - Fraction effectiveSubCoordinate = Fraction(srcDimensions_[x]) + - subCoordinate[x]; + switch (workingParams_.overhangType[x]) { + case wrapOverhang: { + Fraction effectiveSubCoordinate = + Fraction(srcDimensions_[x]) + subCoordinate[x]; nodeCoordinate.push_back(effectiveSubCoordinate.getNumerator() / effectiveSubCoordinate.getDenominator()); Fraction fractionalComponent = - (effectiveSubCoordinate - nodeCoordinate[x]) * elementCount_; + (effectiveSubCoordinate - nodeCoordinate[x]) * elementCount_; NTA_CHECK(fractionalComponent.isNaturalNumber()); size_t fractionalOffset = fractionalComponent.getNumerator() / - fractionalComponent.getDenominator(); + fractionalComponent.getDenominator(); // --- // If a subCoordinate component is at the lower bound for that @@ -1359,26 +1221,19 @@ void UniformLinkPolicy::populateInputElements( // dimension will overlap all of the elements and thus they // should all be included. // --- - if(subCoordinate[x] == orthotopeBounds[x].first) - { - if(fractionalOffset < elementOffset.first) - { + if (subCoordinate[x] == orthotopeBounds[x].first) { + if (fractionalOffset < elementOffset.first) { elementOffset.first = fractionalOffset; } elementOffset.second = elementCount_; - } - else if(subCoordinate[x] == orthotopeBounds[x].second) - { + } else if (subCoordinate[x] == orthotopeBounds[x].second) { elementOffset.first = 0; - if(fractionalOffset > elementOffset.second) - { + if (fractionalOffset > elementOffset.second) { elementOffset.second = fractionalOffset; } - } - else - { + } else { elementOffset.first = 0; elementOffset.second = elementCount_; } @@ -1386,19 +1241,16 @@ void UniformLinkPolicy::populateInputElements( break; } case nullOverhang: - default: - { + default: { nodeCoordinate.push_back(0); - elementOffset = std::pair(0,0); + elementOffset = std::pair(0, 0); break; } } - } - else if(((size_t) - (subCoordinate[x].getNumerator() / - subCoordinate[x].getDenominator())) > srcDimensions_[x]) - { + } else if (((size_t)(subCoordinate[x].getNumerator() / + subCoordinate[x].getDenominator())) > + srcDimensions_[x]) { // --- // We got a number larger than our source dimension which implies // we're in overhang for this dimension. If our overhang type is @@ -1407,23 +1259,21 @@ void UniformLinkPolicy::populateInputElements( // the applicable dimension. 
// --- - switch(workingParams_.overhangType[x]) - { - case wrapOverhang: - { - Fraction effectiveSubCoordinate = subCoordinate[x] - - Fraction(srcDimensions_[x]); + switch (workingParams_.overhangType[x]) { + case wrapOverhang: { + Fraction effectiveSubCoordinate = + subCoordinate[x] - Fraction(srcDimensions_[x]); nodeCoordinate.push_back(effectiveSubCoordinate.getNumerator() / effectiveSubCoordinate.getDenominator()); Fraction fractionalComponent = - (effectiveSubCoordinate - nodeCoordinate[x]) * elementCount_; + (effectiveSubCoordinate - nodeCoordinate[x]) * elementCount_; NTA_CHECK(fractionalComponent.isNaturalNumber()); size_t fractionalOffset = fractionalComponent.getNumerator() / - fractionalComponent.getDenominator(); + fractionalComponent.getDenominator(); // --- // If a subCoordinate component is at the lower bound for that @@ -1444,26 +1294,19 @@ void UniformLinkPolicy::populateInputElements( // dimension will overlap all of the elements and thus they // should all be included. // --- - if(subCoordinate[x] == orthotopeBounds[x].first) - { - if(fractionalOffset < elementOffset.first) - { + if (subCoordinate[x] == orthotopeBounds[x].first) { + if (fractionalOffset < elementOffset.first) { elementOffset.first = fractionalOffset; } elementOffset.second = elementCount_; - } - else if(subCoordinate[x] == orthotopeBounds[x].second) - { + } else if (subCoordinate[x] == orthotopeBounds[x].second) { elementOffset.first = 0; - if(fractionalOffset > elementOffset.second) - { + if (fractionalOffset > elementOffset.second) { elementOffset.second = fractionalOffset; } - } - else - { + } else { elementOffset.first = 0; elementOffset.second = elementCount_; } @@ -1471,27 +1314,24 @@ void UniformLinkPolicy::populateInputElements( break; } case nullOverhang: - default: - { + default: { nodeCoordinate.push_back(0); - elementOffset = std::pair(0,0); + elementOffset = std::pair(0, 0); break; } } - } - else - { + } else { nodeCoordinate.push_back(subCoordinate[x].getNumerator() / subCoordinate[x].getDenominator()); Fraction fractionalComponent = - (subCoordinate[x] - nodeCoordinate[x]) * elementCount_; + (subCoordinate[x] - nodeCoordinate[x]) * elementCount_; NTA_CHECK(fractionalComponent.isNaturalNumber()); size_t fractionalOffset = fractionalComponent.getNumerator() / - fractionalComponent.getDenominator(); + fractionalComponent.getDenominator(); // --- // If a subCoordinate component is at the lower bound for that @@ -1512,26 +1352,19 @@ void UniformLinkPolicy::populateInputElements( // overlap all of the elements and thus they should all be // included. 
// --- - if(subCoordinate[x] == orthotopeBounds[x].first) - { - if(fractionalOffset < elementOffset.first) - { + if (subCoordinate[x] == orthotopeBounds[x].first) { + if (fractionalOffset < elementOffset.first) { elementOffset.first = fractionalOffset; } elementOffset.second = elementCount_; - } - else if(subCoordinate[x] == orthotopeBounds[x].second) - { + } else if (subCoordinate[x] == orthotopeBounds[x].second) { elementOffset.first = 0; - if(fractionalOffset > elementOffset.second) - { + if (fractionalOffset > elementOffset.second) { elementOffset.second = fractionalOffset; } - } - else - { + } else { elementOffset.first = 0; elementOffset.second = elementCount_; } @@ -1540,8 +1373,7 @@ void UniformLinkPolicy::populateInputElements( size_t elementIndex = srcDimensions_.getIndex(nodeCoordinate); - for(size_t x = elementOffset.first; x < elementOffset.second; x++) - { + for (size_t x = elementOffset.first; x < elementOffset.second; x++) { input.push_back(elementIndex * elementCount_ + x); } } @@ -1551,8 +1383,7 @@ void UniformLinkPolicy::populateInputElements( } void UniformLinkPolicy::buildProtoSplitterMap( - Input::SplitterMap& splitter) const -{ + Input::SplitterMap &splitter) const { NTA_CHECK(isInitialized()); // --- @@ -1560,21 +1391,18 @@ void UniformLinkPolicy::buildProtoSplitterMap( // splitter map entries // --- size_t numDestNodes = 1; - for(size_t i = 0; i < destDimensions_.size(); i++) - { + for (size_t i = 0; i < destDimensions_.size(); i++) { numDestNodes *= destDimensions_[i]; } NTA_CHECK(splitter.size() == numDestNodes); - for(size_t i = 0; i < numDestNodes; i++) - { + for (size_t i = 0; i < numDestNodes; i++) { getInputForNode(i, splitter[i]); } } -void UniformLinkPolicy::initialize() -{ +void UniformLinkPolicy::initialize() { // --- // Both Regions now have dimensions, so we will convert spans specified as // being equal to zero to the appropriate size. This simplifies the @@ -1587,30 +1415,21 @@ void UniformLinkPolicy::initialize() // to have the full dimensionality since individual dimensions may vary in // size. // --- - if(workingParams_.span.size() == 1 && - workingParams_.span[0].getNumerator() == 0) - { - for(size_t i = 1; i < srcDimensions_.size(); i++) - { + if (workingParams_.span.size() == 1 && + workingParams_.span[0].getNumerator() == 0) { + for (size_t i = 1; i < srcDimensions_.size(); i++) { workingParams_.span.push_back(Fraction(0)); } } - for(size_t i = 0; i < workingParams_.span.size(); i++) - { - if(workingParams_.span[i].getNumerator() == 0) - { - switch(mapping_) - { - case inMapping: - { - if(strict_) - { - workingParams_.span[i] = Fraction(srcDimensions_[i]) + - workingParams_.overhang[i] * 2; - } - else - { + for (size_t i = 0; i < workingParams_.span.size(); i++) { + if (workingParams_.span[i].getNumerator() == 0) { + switch (mapping_) { + case inMapping: { + if (strict_) { + workingParams_.span[i] = + Fraction(srcDimensions_[i]) + workingParams_.overhang[i] * 2; + } else { // --- // We aren't strict, so we want our span to be the ideal span as // if we had qualified as strict (the overage of elements/nodes @@ -1618,27 +1437,22 @@ void UniformLinkPolicy::initialize() // bounds). 
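Indexing workingParams_ members such as span[i] above works even when a parameter was supplied as a single value, because DefaultValuedVector (defined near the end of this file) answers every index from element 0 whenever it holds exactly one entry. A minimal restatement of that indexing rule, independent of NuPIC:

// Restatement of DefaultValuedVector's indexing rule: a size-1 vector answers
// every index with its only element; otherwise normal bounds-checked access.
#include <cassert>
#include <vector>

template <typename T> struct DefaultValued : std::vector<T> {
  using std::vector<T>::vector;
  const T &at(size_t i) const {
    return this->size() == 1 ? std::vector<T>::at(0) : std::vector<T>::at(i);
  }
};

int main() {
  DefaultValued<int> wildcard{7}; // one entry -> applies to any dimension
  DefaultValued<int> explicitDims{1, 2, 3};
  assert(wildcard.at(5) == 7);
  assert(explicitDims.at(2) == 3);
  return 0;
}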
// --- - workingParams_.span[i] = Fraction(srcDimensions_[i]) - - ((Fraction(srcDimensions_[i]) + - workingParams_.overhang[i] * 2 - - workingParams_.rfSize[i]) % - (workingParams_.rfSize[i] - - workingParams_.rfOverlap[i])); + workingParams_.span[i] = + Fraction(srcDimensions_[i]) - + ((Fraction(srcDimensions_[i]) + workingParams_.overhang[i] * 2 - + workingParams_.rfSize[i]) % + (workingParams_.rfSize[i] - workingParams_.rfOverlap[i])); } break; } - case outMapping: - { - workingParams_.span[i] = Fraction(destDimensions_[i]) + - workingParams_.overhang[i] * 2; + case outMapping: { + workingParams_.span[i] = + Fraction(destDimensions_[i]) + workingParams_.overhang[i] * 2; break; } - default: - { - break; - } + default: { break; } } } } @@ -1648,70 +1462,53 @@ void UniformLinkPolicy::initialize() // we couldn't perform this check in validateParameterConsistency(). Now // that we know our dimensions, we'll perform the check. // --- - for(size_t i = 0; i < parameterDimensionality_; i++) - { - if(workingParams_.overhang[i] > srcDimensions_[i]) - { + for (size_t i = 0; i < parameterDimensionality_; i++) { + if (workingParams_.overhang[i] > srcDimensions_[i]) { NTA_THROW << "The overhang can't exceed the size of the source " - "dimensions"; + "dimensions"; } } initialized_ = true; } -bool UniformLinkPolicy::isInitialized() const -{ - return initialized_; -} - +bool UniformLinkPolicy::isInitialized() const { return initialized_; } template -UniformLinkPolicy::DefaultValuedVector::DefaultValuedVector() -{}; - - +UniformLinkPolicy::DefaultValuedVector::DefaultValuedVector(){}; template -T UniformLinkPolicy::DefaultValuedVector::operator[](const size_type index) const -{ +T UniformLinkPolicy::DefaultValuedVector:: +operator[](const size_type index) const { return at(index); } template -T& UniformLinkPolicy::DefaultValuedVector::operator[](const size_type index) -{ +T &UniformLinkPolicy::DefaultValuedVector:: +operator[](const size_type index) { return at(index); } template -T UniformLinkPolicy::DefaultValuedVector::at(const size_type index) const -{ - if(std::vector::size()==1) - { +T UniformLinkPolicy::DefaultValuedVector::at(const size_type index) const { + if (std::vector::size() == 1) { return std::vector::at(0); - } - else - { + } else { return std::vector::at(index); } } template -T& UniformLinkPolicy::DefaultValuedVector::at(const size_type index) -{ - if(std::vector::size()==1) - { +T &UniformLinkPolicy::DefaultValuedVector::at(const size_type index) { + if (std::vector::size() == 1) { return std::vector::at(0); - } - else - { + } else { return std::vector::at(index); } } template struct nupic::UniformLinkPolicy::DefaultValuedVector; -template struct nupic::UniformLinkPolicy::DefaultValuedVector; - -} +template struct nupic::UniformLinkPolicy::DefaultValuedVector< + nupic::UniformLinkPolicy::OverhangType>; +} // namespace nupic diff --git a/src/nupic/engine/UniformLinkPolicy.hpp b/src/nupic/engine/UniformLinkPolicy.hpp index bc973de15c..8380399e2c 100644 --- a/src/nupic/engine/UniformLinkPolicy.hpp +++ b/src/nupic/engine/UniformLinkPolicy.hpp @@ -20,7 +20,6 @@ * --------------------------------------------------------------------- */ - #ifndef NTA_UNIFORMLINKPOLICY_HPP #define NTA_UNIFORMLINKPOLICY_HPP @@ -28,215 +27,189 @@ #include #include -#include #include -#include +#include #include +#include // We use the ParameterSpec which is defined in the Spec header #include #include -namespace nupic -{ - class Link; - class ValueMap; +namespace nupic { +class Link; +class ValueMap; + +// --- 
+// The UniformLinkPolicy implements a linkage structure between two Regions +// wherein the topology of the receptive fields are uniform*. +// +// * To be precise, this should say more-or-less uniform since we allow +// strict uniformity to be disabled via parameter (in which case we build +// a linkage "as close to uniform as possible"). +// +// (Refer to GradedLinkPolicy and SparseLinkPolicy for examples of link +// policies with non-uniform receptive field topologies.) +// +// In the simplest case, this is a direct one-to-one mapping (and +// consequently this allows for linkage of Region level inputs and outputs +// without specifying any parameters). However, this can also take the form +// of more complex receptive field mappings as configured via parameters. +// --- + +class UniformLinkPolicy : public LinkPolicy { + // --- + // We make our unit test class a friend so that we can test dimension + // calculations and splitter map generation without requiring the rest of + // the NuPIC infrastructure. + // --- + friend class UniformLinkPolicyInspector; + +public: + UniformLinkPolicy(const std::string params, Link *link); + + ~UniformLinkPolicy(); + + // LinkPolicy Interface + void setSrcDimensions(Dimensions &dims) override; + void setDestDimensions(Dimensions &dims) override; + const Dimensions &getSrcDimensions() const override; + const Dimensions &getDestDimensions() const override; + void setNodeOutputElementCount(size_t elementCount) override; + void buildProtoSplitterMap(Input::SplitterMap &splitter) const override; + void initialize() override; + bool isInitialized() const override; + +private: + Link *link_; + + enum MappingType { inMapping, outMapping, fullMapping }; + + enum GranularityType { nodesGranularity, elementsGranularity }; + + enum OverhangType { nullOverhang = 0, wrapOverhang }; + + MappingType mapping_; + std::vector rfSize_; + std::vector rfOverlap_; + GranularityType rfGranularity_; + std::vector overhang_; + std::vector overhangType_; + std::vector span_; + bool strict_; + + template struct DefaultValuedVector : public std::vector { + typedef typename std::vector::size_type size_type; + DefaultValuedVector(); + T operator[](const size_type index) const; + T &operator[](const size_type index); + T at(const size_type index) const; + T &at(const size_type index); + }; + + struct WorkingParameters { + DefaultValuedVector rfSize; + DefaultValuedVector rfOverlap; + DefaultValuedVector overhang; + DefaultValuedVector overhangType; + DefaultValuedVector span; + }; + + WorkingParameters workingParams_; + + void setValidParameters(); + void readParameters(const std::string ¶ms); + void validateParameterDimensionality(); + void validateParameterConsistency(); + void populateWorkingParameters(); + + void copyRealVecToFractionVec(const std::vector &sourceVec, + DefaultValuedVector &destVec); + + template + void populateArrayParamVector(std::vector &vec, const ValueMap ¶mMap, + const std::string ¶mName); // --- - // The UniformLinkPolicy implements a linkage structure between two Regions - // wherein the topology of the receptive fields are uniform*. - // - // * To be precise, this should say more-or-less uniform since we allow - // strict uniformity to be disabled via parameter (in which case we build - // a linkage "as close to uniform as possible"). - // - // (Refer to GradedLinkPolicy and SparseLinkPolicy for examples of link - // policies with non-uniform receptive field topologies.) 
+ // Returns a pair of fractions denoting the inclusive lower and upper + // bounds for a destination node's receptive field in the specified + // dimension. This is used when calculating the splitter map (via + // getInputForNode(). This will also be utilized when calculating the + // getIncomingConnections() API for use by inspectors. + // --- + std::pair getInputBoundsForNode(size_t nodeIndex, + size_t dimension) const; + + std::pair getInputBoundsForNode(Coordinate nodeCoordinate, + size_t dimension) const; + + // --- + // Calculates the entire set of bounds for a destination node's + // receptive field, and then utilizes populateInputElements() to fill + // in the splitter map. + // --- + void getInputForNode(size_t nodeIndex, std::vector &input) const; + + void getInputForNode(Coordinate nodeCoordinate, + std::vector &input) const; + + // --- + // Recursive method which walks the entire set of bounds and populates the + // vector "input" (the splitter map) accordingly. // - // In the simplest case, this is a direct one-to-one mapping (and - // consequently this allows for linkage of Region level inputs and outputs - // without specifying any parameters). However, this can also take the form - // of more complex receptive field mappings as configured via parameters. - // --- - - class UniformLinkPolicy : public LinkPolicy - { - // --- - // We make our unit test class a friend so that we can test dimension - // calculations and splitter map generation without requiring the rest of - // the NuPIC infrastructure. - // --- - friend class UniformLinkPolicyInspector; - - public: - UniformLinkPolicy(const std::string params, Link* link); - - ~UniformLinkPolicy(); - - // LinkPolicy Interface - void setSrcDimensions(Dimensions& dims) override; - void setDestDimensions(Dimensions& dims) override; - const Dimensions& getSrcDimensions() const override; - const Dimensions& getDestDimensions() const override; - void setNodeOutputElementCount(size_t elementCount) override; - void buildProtoSplitterMap(Input::SplitterMap& splitter) const override; - void initialize() override; - bool isInitialized() const override; - - private: - - Link* link_; - - enum MappingType - { - inMapping, - outMapping, - fullMapping - }; - - enum GranularityType - { - nodesGranularity, - elementsGranularity - }; - - enum OverhangType - { - nullOverhang = 0, - wrapOverhang - }; - - MappingType mapping_; - std::vector rfSize_; - std::vector rfOverlap_; - GranularityType rfGranularity_; - std::vector overhang_; - std::vector overhangType_; - std::vector span_; - bool strict_; - - template - struct DefaultValuedVector : public std::vector - { - typedef typename std::vector::size_type size_type; - DefaultValuedVector(); - T operator[](const size_type index) const; - T& operator[](const size_type index); - T at(const size_type index) const; - T& at(const size_type index); - }; - - - - struct WorkingParameters - { - DefaultValuedVector rfSize; - DefaultValuedVector rfOverlap; - DefaultValuedVector overhang; - DefaultValuedVector overhangType; - DefaultValuedVector span; - }; - - WorkingParameters workingParams_; - - void setValidParameters(); - void readParameters(const std::string& params); - void validateParameterDimensionality(); - void validateParameterConsistency(); - void populateWorkingParameters(); - - void copyRealVecToFractionVec(const std::vector& sourceVec, - DefaultValuedVector& destVec); - - template void populateArrayParamVector( - std::vector& vec, - const ValueMap& paramMap, - const std::string& paramName); - - // 
--- - // Returns a pair of fractions denoting the inclusive lower and upper - // bounds for a destination node's receptive field in the specified - // dimension. This is used when calculating the splitter map (via - // getInputForNode(). This will also be utilized when calculating the - // getIncomingConnections() API for use by inspectors. - // --- - std::pair getInputBoundsForNode( - size_t nodeIndex, - size_t dimension) const; - - std::pair getInputBoundsForNode( - Coordinate nodeCoordinate, - size_t dimension) const; - - // --- - // Calculates the entire set of bounds for a destination node's - // receptive field, and then utilizes populateInputElements() to fill - // in the splitter map. - // --- - void getInputForNode(size_t nodeIndex, - std::vector& input) const; - - void getInputForNode(Coordinate nodeCoordinate, - std::vector& input) const; - - // --- - // Recursive method which walks the entire set of bounds and populates the - // vector "input" (the splitter map) accordingly. - // - // For a uniform linkage, the set of bounds defines an "orthotope" - - // the generalization of a rectangle to n-dimensions. That is, the - // orthotope bounds is a collection of inclusive bounds, one for each - // dimension, which correspond to the edges of an n-dimensional box. - // --- - void populateInputElements( - std::vector& input, - std::vector > orthotopeBounds, - std::vector& subCoordinate) const; - - // --- - // The dimensions of the source Region, as specified by a call to - // setSrcDimensions() or induced by a call to setDestDimensions(). - // --- - Dimensions srcDimensions_; - - // --- - // The dimensions of the destination Region, as specified by a call to - // setDestDimensions() or induced by a call to setSrcDimensions(). - // --- - Dimensions destDimensions_; - - // --- - // The amount of elements per Node as specified by a call to - // setNodeOutputElementCount() - // --- - size_t elementCount_; - - // --- - // Parameters passed into the link policy can have varying dimensionality - // (i.e. quantity of dimensions). Since parameters with a dimensionality - // of 1 can be wildcards for any number of dimensions, it is necessary to - // calculate the true dimensionality of the parameters so as to validate - // a requested linkage topology. validateParameterDimensionality() checks - // that parameter dimensionality is consistent, and sets - // parameterDimensionality_ to the maximum dimensionality. - // --- - size_t parameterDimensionality_; - - // --- - // Set after a call to initialize whereupon the working parameters are - // valid for splitter map calculation - // --- - bool initialized_; - - // --- - // A collection of parameters valid for this link policy. Populated by - // setValidParameters() - // --- - Collection parameters_; - }; // UniformLinkPolicy + // For a uniform linkage, the set of bounds defines an "orthotope" - + // the generalization of a rectangle to n-dimensions. That is, the + // orthotope bounds is a collection of inclusive bounds, one for each + // dimension, which correspond to the edges of an n-dimensional box. + // --- + void populateInputElements( + std::vector &input, + std::vector> orthotopeBounds, + std::vector &subCoordinate) const; -} // namespace nupic + // --- + // The dimensions of the source Region, as specified by a call to + // setSrcDimensions() or induced by a call to setDestDimensions(). 
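The populateInputElements() method declared above recursively walks the inclusive per-dimension bounds (the "orthotope") and appends every input element it visits to the splitter map. As a hedged illustration of that recursion only, not the NuPIC implementation (which works with fractional bounds and real element indices), here is a minimal sketch using integer bounds; walkOrthotope and visit are hypothetical names:

#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Sketch: enumerate every integer coordinate inside an n-dimensional box
// given inclusive [lo, hi] bounds per dimension.
static void walkOrthotope(
    const std::vector<std::pair<std::size_t, std::size_t>> &bounds,
    std::vector<std::size_t> &coord, std::size_t dim,
    const std::function<void(const std::vector<std::size_t> &)> &visit) {
  if (dim == bounds.size()) {
    visit(coord); // one source coordinate inside the receptive field
    return;
  }
  for (std::size_t i = bounds[dim].first; i <= bounds[dim].second; ++i) {
    coord[dim] = i;
    walkOrthotope(bounds, coord, dim + 1, visit);
  }
}

Each visited coordinate would then be flattened to a linear element index, which is roughly what the real method records in the splitter map for each destination node.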
+ // --- + Dimensions srcDimensions_; + + // --- + // The dimensions of the destination Region, as specified by a call to + // setDestDimensions() or induced by a call to setSrcDimensions(). + // --- + Dimensions destDimensions_; + + // --- + // The amount of elements per Node as specified by a call to + // setNodeOutputElementCount() + // --- + size_t elementCount_; + + // --- + // Parameters passed into the link policy can have varying dimensionality + // (i.e. quantity of dimensions). Since parameters with a dimensionality + // of 1 can be wildcards for any number of dimensions, it is necessary to + // calculate the true dimensionality of the parameters so as to validate + // a requested linkage topology. validateParameterDimensionality() checks + // that parameter dimensionality is consistent, and sets + // parameterDimensionality_ to the maximum dimensionality. + // --- + size_t parameterDimensionality_; + + // --- + // Set after a call to initialize whereupon the working parameters are + // valid for splitter map calculation + // --- + bool initialized_; + // --- + // A collection of parameters valid for this link policy. Populated by + // setValidParameters() + // --- + Collection parameters_; +}; // UniformLinkPolicy + +} // namespace nupic #endif // NTA_UNIFORMLINKPOLICY_HPP diff --git a/src/nupic/engine/YAMLUtils.cpp b/src/nupic/engine/YAMLUtils.cpp index 5adb46dcc7..ee382523fd 100644 --- a/src/nupic/engine/YAMLUtils.cpp +++ b/src/nupic/engine/YAMLUtils.cpp @@ -20,38 +20,34 @@ * --------------------------------------------------------------------- */ -#include +#include #include -#include #include #include -#include +#include +#include #include // strlen #include #include -namespace nupic -{ -namespace YAMLUtils -{ +namespace nupic { +namespace YAMLUtils { /* * These functions are used internally by toValue and toValueMap */ -static void _toScalar(const YAML::Node& node, boost::shared_ptr& s); -static void _toArray(const YAML::Node& node, boost::shared_ptr& a); -static Value toValue(const YAML::Node& node, NTA_BasicType dataType); - +static void _toScalar(const YAML::Node &node, boost::shared_ptr &s); +static void _toArray(const YAML::Node &node, boost::shared_ptr &a); +static Value toValue(const YAML::Node &node, NTA_BasicType dataType); -static void _toScalar(const YAML::Node& node, boost::shared_ptr& s) -{ +static void _toScalar(const YAML::Node &node, boost::shared_ptr &s) { NTA_CHECK(node.Type() == YAML::NodeType::Scalar); - switch(s->getType()) - { + switch (s->getType()) { case NTA_BasicType_Byte: // We should have already detected this and gone down the string path - NTA_THROW << "Internal error: attempting to convert YAML string to scalar of type Byte"; + NTA_THROW << "Internal error: attempting to convert YAML string to scalar " + "of type Byte"; break; case NTA_BasicType_UInt16: node >> s->value.uint16; @@ -87,53 +83,52 @@ static void _toScalar(const YAML::Node& node, boost::shared_ptr& s) // should not happen std::string val; node >> val; - NTA_THROW << "Unknown data type " << s->getType() << " for yaml node '" << val << "'"; + NTA_THROW << "Unknown data type " << s->getType() << " for yaml node '" + << val << "'"; } } - -static void _toArray(const YAML::Node& node, boost::shared_ptr& a) -{ + +static void _toArray(const YAML::Node &node, boost::shared_ptr &a) { NTA_CHECK(node.Type() == YAML::NodeType::Sequence); - + a->allocateBuffer(node.size()); - void* buffer = a->getBuffer(); - - for (size_t i = 0; i < node.size(); i++) - { - const YAML::Node& item = node[i]; + 
void *buffer = a->getBuffer(); + + for (size_t i = 0; i < node.size(); i++) { + const YAML::Node &item = node[i]; NTA_CHECK(item.Type() == YAML::NodeType::Scalar); - switch(a->getType()) - { + switch (a->getType()) { case NTA_BasicType_Byte: // We should have already detected this and gone down the string path - NTA_THROW << "Internal error: attempting to convert YAML string to array of type Byte"; + NTA_THROW << "Internal error: attempting to convert YAML string to array " + "of type Byte"; break; case NTA_BasicType_UInt16: - item.Read(((UInt16*)buffer)[i]); + item.Read(((UInt16 *)buffer)[i]); break; case NTA_BasicType_Int16: - item.Read(((Int16*)buffer)[i]); + item.Read(((Int16 *)buffer)[i]); break; case NTA_BasicType_UInt32: - item.Read(((UInt32*)buffer)[i]); + item.Read(((UInt32 *)buffer)[i]); break; case NTA_BasicType_Int32: - item.Read(((Int32*)buffer)[i]); + item.Read(((Int32 *)buffer)[i]); break; case NTA_BasicType_UInt64: - item.Read(((UInt64*)buffer)[i]); + item.Read(((UInt64 *)buffer)[i]); break; case NTA_BasicType_Int64: - item.Read(((Int64*)buffer)[i]); + item.Read(((Int64 *)buffer)[i]); break; case NTA_BasicType_Real32: - item.Read(((Real32*)buffer)[i]); + item.Read(((Real32 *)buffer)[i]); break; case NTA_BasicType_Real64: - item.Read(((Real64*)buffer)[i]); + item.Read(((Real64 *)buffer)[i]); break; case NTA_BasicType_Bool: - item.Read(((bool*)buffer)[i]); + item.Read(((bool *)buffer)[i]); break; default: // should not happen @@ -142,16 +137,13 @@ static void _toArray(const YAML::Node& node, boost::shared_ptr& a) } } -static Value toValue(const YAML::Node& node, NTA_BasicType dataType) -{ - if (node.Type() == YAML::NodeType::Map || node.Type() == YAML::NodeType::Null) - { +static Value toValue(const YAML::Node &node, NTA_BasicType dataType) { + if (node.Type() == YAML::NodeType::Map || + node.Type() == YAML::NodeType::Null) { NTA_THROW << "YAML string does not not represent a value."; } - if (node.Type() == YAML::NodeType::Scalar) - { - if (dataType == NTA_BasicType_Byte) - { + if (node.Type() == YAML::NodeType::Scalar) { + if (dataType == NTA_BasicType_Byte) { // node >> *str; std::string val; node.Read(val); @@ -173,16 +165,14 @@ static Value toValue(const YAML::Node& node, NTA_BasicType dataType) } } - -/* +/* * For converting default values specified in nodespec */ -Value toValue(const std::string& yamlstring, NTA_BasicType dataType) -{ +Value toValue(const std::string &yamlstring, NTA_BasicType dataType) { // IMemStream s(yamlstring, ::strlen(yamlstring)); // yaml-cpp bug: append a space if it is only one character - // This is very inefficient, but should be ok since it is + // This is very inefficient, but should be ok since it is // just used at construction time for short strings std::string paddedstring(yamlstring); if (paddedstring.size() < 2) @@ -192,40 +182,37 @@ Value toValue(const std::string& yamlstring, NTA_BasicType dataType) // TODO -- return value? exceptions? bool success = false; YAML::Node doc; - try - { + try { YAML::Parser parser(s); success = parser.GetNextDocument(doc); // } catch(YAML::ParserException& e) { - } catch(...) { + } catch (...) 
{ success = false; } - if (!success) - { + if (!success) { std::string ys(paddedstring); - if (ys.size() > 30) - { + if (ys.size() > 30) { ys = ys.substr(0, 30) + "..."; } - NTA_THROW << "Unable to parse YAML string '" << ys << "' for a scalar value"; + NTA_THROW << "Unable to parse YAML string '" << ys + << "' for a scalar value"; } Value v = toValue(doc, dataType); return v; } -/* +/* * For converting param specs for Regions and LinkPolicies */ -ValueMap toValueMap(const char* yamlstring, - Collection& parameters, - const std::string & nodeType, - const std::string & regionName) -{ - +ValueMap toValueMap(const char *yamlstring, + Collection ¶meters, + const std::string &nodeType, + const std::string ®ionName) { + ValueMap vm; // yaml-cpp bug: append a space if it is only one character - // This is very inefficient, but should be ok since it is + // This is very inefficient, but should be ok since it is // just used at construction time for short strings std::string paddedstring(yamlstring); // TODO: strip white space to determine if empty @@ -238,8 +225,7 @@ ValueMap toValueMap(const char* yamlstring, // TODO: utf-8 compatible? YAML::Node doc; - if (!empty) - { + if (!empty) { YAML::Parser parser(s); bool success = parser.GetNextDocument(doc); @@ -247,95 +233,86 @@ ValueMap toValueMap(const char* yamlstring, NTA_THROW << "Unable to find document in YAML string"; // A ValueMap is specified as a dictionary - if (doc.Type() != YAML::NodeType::Map) - { + if (doc.Type() != YAML::NodeType::Map) { std::string ys(yamlstring); - if (ys.size() > 30) - { + if (ys.size() > 30) { ys = ys.substr(0, 30) + "..."; } - NTA_THROW << "YAML string '" << ys - << "' does not not specify a dictionary of key-value pairs. " - << "Region and Link parameters must be specified at a dictionary"; + NTA_THROW + << "YAML string '" << ys + << "' does not not specify a dictionary of key-value pairs. " + << "Region and Link parameters must be specified at a dictionary"; } } // Grab each value out of the YAML dictionary and put into the ValueMap // if it is allowed by the nodespec. 
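As a quick usage sketch for the toValue() overload defined earlier in this file, parsing a single scalar the way nodespec default values are parsed; the literal value and the function name exampleToValue are illustrative only:

#include <nupic/engine/YAMLUtils.hpp>

void exampleToValue() {
  // "42" is parsed as a YAML scalar and converted to the requested type.
  nupic::Value v = nupic::YAMLUtils::toValue("42", NTA_BasicType_Int32);
  // v.isScalar() should be true here, wrapping an Int32 scalar.
}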
YAML::Iterator i; - for (i = doc.begin(); i != doc.end(); i++) - { + for (i = doc.begin(); i != doc.end(); i++) { const std::string key = i.first().to(); - if (!parameters.contains(key)) - { + if (!parameters.contains(key)) { std::stringstream ss; - for (UInt j = 0; j < parameters.getCount(); j++) - { + for (UInt j = 0; j < parameters.getCount(); j++) { ss << " " << parameters.getByIndex(j).first << "\n"; } - - if (nodeType == std::string("")) - { - NTA_THROW << "Unknown parameter '" << key << "'\n" - << "Valid parameters are:\n" << ss.str(); - } - else - { + + if (nodeType == std::string("")) { + NTA_THROW << "Unknown parameter '" << key << "'\n" + << "Valid parameters are:\n" + << ss.str(); + } else { NTA_CHECK(regionName != std::string("")); NTA_THROW << "Unknown parameter '" << key << "' for region '" - << regionName << "' of type '" << nodeType << "'\n" - << "Valid parameters are:\n" << ss.str(); + << regionName << "' of type '" << nodeType << "'\n" + << "Valid parameters are:\n" + << ss.str(); } } if (vm.contains(key)) - NTA_THROW << "Parameter '" << key << "' specified more than once in YAML document"; + NTA_THROW << "Parameter '" << key + << "' specified more than once in YAML document"; ParameterSpec spec = parameters.getByName(key); - try - { + try { Value v = toValue(i.second(), spec.dataType); - if (v.isScalar() && spec.count != 1) - { + if (v.isScalar() && spec.count != 1) { throw std::runtime_error("Expected array value but got scalar value"); } - if (!v.isScalar() && spec.count == 1) - { + if (!v.isScalar() && spec.count == 1) { throw std::runtime_error("Expected scalar value but got array value"); } vm.add(key, v); - } catch (std::runtime_error& e) { + } catch (std::runtime_error &e) { NTA_THROW << "Unable to set parameter '" << key << "'. " << e.what(); } } - // Populate ValueMap with default values if they were not specified in the YAML dictionary. - for (size_t i = 0; i < parameters.getCount(); i++) - { - std::pair& item = parameters.getByIndex(i); - if (!vm.contains(item.first)) - { - ParameterSpec & ps = item.second; - if (ps.defaultValue != "") - { - // TODO: This check should be uncommented after dropping NuPIC 1.x nodes (which don't comply) - // if (ps.accessMode != ParameterSpec::CreateAccess) + // Populate ValueMap with default values if they were not specified in the + // YAML dictionary. + for (size_t i = 0; i < parameters.getCount(); i++) { + std::pair &item = parameters.getByIndex(i); + if (!vm.contains(item.first)) { + ParameterSpec &ps = item.second; + if (ps.defaultValue != "") { + // TODO: This check should be uncommented after dropping NuPIC 1.x nodes + // (which don't comply) if (ps.accessMode != + // ParameterSpec::CreateAccess) // { - // NTA_THROW << "Default value for non-create parameter: " << item.first; + // NTA_THROW << "Default value for non-create parameter: " << + // item.first; // } - + try { #ifdef YAMLDEBUG - NTA_DEBUG << "Adding default value '" << ps.defaultValue - << "' to parameter " << item.first - << " of type " << BasicType::getName(ps.dataType) - << " count " << ps.count; + NTA_DEBUG << "Adding default value '" << ps.defaultValue + << "' to parameter " << item.first << " of type " + << BasicType::getName(ps.dataType) << " count " << ps.count; #endif Value v = toValue(ps.defaultValue, ps.dataType); vm.add(item.first, v); } catch (...) 
{ - NTA_THROW << "Unable to set default value for item '" - << item.first << "' of datatype " - << BasicType::getName(ps.dataType) - <<" with value '" << ps.defaultValue << "'"; + NTA_THROW << "Unable to set default value for item '" << item.first + << "' of datatype " << BasicType::getName(ps.dataType) + << " with value '" << ps.defaultValue << "'"; } } } @@ -344,7 +321,6 @@ ValueMap toValueMap(const char* yamlstring, return vm; } -} // end of YAMLUtils namespace +} // namespace YAMLUtils } // end of namespace nupic - diff --git a/src/nupic/engine/YAMLUtils.hpp b/src/nupic/engine/YAMLUtils.hpp index cf7b24c07e..fb59ff1fbd 100644 --- a/src/nupic/engine/YAMLUtils.hpp +++ b/src/nupic/engine/YAMLUtils.hpp @@ -22,35 +22,28 @@ #ifndef NTA_YAML_HPP #define NTA_YAML_HPP - -#include -#include -#include #include +#include +#include +#include -namespace nupic -{ - - namespace YAMLUtils - { - /* - * For converting default values - */ - Value toValue(const std::string& yamlstring, NTA_BasicType dataType); +namespace nupic { - /* - * For converting param specs for Regions and LinkPolicies - */ - ValueMap toValueMap( - const char* yamlstring, - Collection& parameters, - const std::string & nodeType = "", - const std::string & regionName = "" - ); +namespace YAMLUtils { +/* + * For converting default values + */ +Value toValue(const std::string &yamlstring, NTA_BasicType dataType); +/* + * For converting param specs for Regions and LinkPolicies + */ +ValueMap toValueMap(const char *yamlstring, + Collection ¶meters, + const std::string &nodeType = "", + const std::string ®ionName = ""); - } // namespace YAMLUtils +} // namespace YAMLUtils } // namespace nupic #endif // NTA_YAML_HPP - diff --git a/src/nupic/math/Array2D.hpp b/src/nupic/math/Array2D.hpp index dcdbc766ee..2c6d5ce49f 100644 --- a/src/nupic/math/Array2D.hpp +++ b/src/nupic/math/Array2D.hpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * A dense matrix with contiguous storage. */ @@ -40,73 +40,60 @@ /** * A fixed size matrix, allocated as a single contiguous chunk of memory. 
*/ -template -class array2D -{ +template class array2D { public: typedef S size_type; typedef T value_type; - typedef value_type* iterator; - typedef const value_type* const_iterator; - + typedef value_type *iterator; + typedef const value_type *const_iterator; + size_type nrows_, ncols_; - iterator d; // __attribute__ ((aligned (16))); + iterator d; // __attribute__ ((aligned (16))); iterator d_end; // __attribute__ ((aligned (16))); - - inline array2D() - : nrows_(0), ncols_(0), d(0), d_end(0) - {} + + inline array2D() : nrows_(0), ncols_(0), d(0), d_end(0) {} inline array2D(size_type m_, size_type n_) - : nrows_(m_), ncols_(n_), - d(new value_type[m_ * n_]), d_end(d + m_ * n_) - {} - - inline array2D(size_type m_, size_type n_, const value_type& init_val) - : nrows_(m_), ncols_(n_), - d(new value_type[m_ * n_]), d_end(d + m_ * n_) - { + : nrows_(m_), ncols_(n_), d(new value_type[m_ * n_]), d_end(d + m_ * n_) { + } + + inline array2D(size_type m_, size_type n_, const value_type &init_val) + : nrows_(m_), ncols_(n_), d(new value_type[m_ * n_]), d_end(d + m_ * n_) { for (iterator i = begin(); i != end(); ++i) *i = init_val; } - inline array2D(size_type m_, size_type n_, value_type* array) - : nrows_(m_), ncols_(n_), - d(new value_type[m_ * n_]), d_end(d + m_ * n_) - { - size_type n = size(); - for (size_type i = 0; i != n; ++i) - *(d+i) = *(array + i); - } - - inline array2D(const array2D& b) - : nrows_(b.nrows_), ncols_(b.ncols_), - d(new value_type [b.nelts()]), d_end(d + b.nelts()) - { + inline array2D(size_type m_, size_type n_, value_type *array) + : nrows_(m_), ncols_(n_), d(new value_type[m_ * n_]), d_end(d + m_ * n_) { + size_type n = size(); + for (size_type i = 0; i != n; ++i) + *(d + i) = *(array + i); + } + + inline array2D(const array2D &b) + : nrows_(b.nrows_), ncols_(b.ncols_), d(new value_type[b.nelts()]), + d_end(d + b.nelts()) { this->copy(b); } - inline ~array2D() - { - delete [] d; + inline ~array2D() { + delete[] d; d = d_end = nullptr; } - inline array2D& operator=(const array2D& b) - { - if (this != &b) + inline array2D &operator=(const array2D &b) { + if (this != &b) this->copy(b); return *this; } - inline void copy(const array2D& b) - { + inline void copy(const array2D &b) { if (nelts() != b.nelts()) { - delete [] d; + delete[] d; nrows_ = b.nrows_; ncols_ = b.ncols_; size_type n = nrows_ * ncols_; - d = new value_type [n]; + d = new value_type[n]; d_end = d + n; } @@ -117,13 +104,11 @@ class array2D *this_it = *b_it; } - template - inline void copy(It array_it) - { - size_type n = size(); - for (size_type i = 0; i != n; ++i, ++array_it) - *(d+i) = *array_it; - } + template inline void copy(It array_it) { + size_type n = size(); + for (size_type i = 0; i != n; ++i, ++array_it) + *(d + i) = *array_it; + } inline size_type nrows() const { return nrows_; } inline size_type ncols() const { return ncols_; } @@ -134,94 +119,78 @@ class array2D inline iterator end() { return d_end; } inline const_iterator begin() const { return d; } inline const_iterator end() const { return d_end; } - inline iterator begin(size_type i) { return d + i*ncols_; } - inline iterator end(size_type i) { return d + (i+1)*ncols_; } - inline const_iterator begin(size_type i) const { return d + i*ncols_; } - inline const_iterator end(size_type i) const { return d + (i+1)*ncols_; } - - inline const value_type operator()(size_type i, size_type j) const - { - return d[i*ncols_+j]; - } + inline iterator begin(size_type i) { return d + i * ncols_; } + inline iterator end(size_type i) { return d + (i + 
1) * ncols_; } + inline const_iterator begin(size_type i) const { return d + i * ncols_; } + inline const_iterator end(size_type i) const { return d + (i + 1) * ncols_; } - inline value_type& operator()(size_type i, size_type j) - { - return d[i*ncols_+j]; + inline const value_type operator()(size_type i, size_type j) const { + return d[i * ncols_ + j]; } - inline const value_type at(size_type i, size_type j) const - { - return d[i*ncols_+j]; + inline value_type &operator()(size_type i, size_type j) { + return d[i * ncols_ + j]; } - inline value_type at(size_type i, size_type j) - { - return d[i*ncols_+j]; + inline const value_type at(size_type i, size_type j) const { + return d[i * ncols_ + j]; } - template - inline void getRow(size_type row, It v_it) const - { + inline value_type at(size_type i, size_type j) { return d[i * ncols_ + j]; } + + template inline void getRow(size_type row, It v_it) const { const_iterator it = begin(row), it_end = end(row); while (it != it_end) { *v_it = *it; - ++v_it; ++it; + ++v_it; + ++it; } } - template - inline void setRow(size_type row, It v_it) - { + template inline void setRow(size_type row, It v_it) { iterator it = begin(row), it_end = end(row); while (it != it_end) { *it = *v_it; - ++it; ++v_it; + ++it; + ++v_it; } } - template - inline void getColumn(size_type col, It v_it) const - { - for (size_type i = 0; i < nrows(); ++i, ++v_it) + template inline void getColumn(size_type col, It v_it) const { + for (size_type i = 0; i < nrows(); ++i, ++v_it) *v_it = this->operator()(i, col); } - template - inline void setColumn(size_type col, It v_it) - { - for (size_type i = 0; i < nrows(); ++i, ++v_it) + template inline void setColumn(size_type col, It v_it) { + for (size_type i = 0; i < nrows(); ++i, ++v_it) this->operator()(i, col) = *v_it; } - inline void operator+=(const value_type& val) - { + inline void operator+=(const value_type &val) { for (iterator it = begin(); it != end(); ++it) *it += val; } - inline void operator-=(const value_type& val) - { + inline void operator-=(const value_type &val) { for (iterator it = begin(); it != end(); ++it) *it -= val; } - inline void operator*=(const value_type& val) - { + inline void operator*=(const value_type &val) { for (iterator it = begin(); it != end(); ++it) *it *= val; } - inline void operator/=(const value_type& val) - { + inline void operator/=(const value_type &val) { for (iterator it = begin(); it != end(); ++it) *it /= val; } - inline T trace() const - { + inline T trace() const { size_type step = ncols_ + 1; const_iterator it = begin(), it_end = end() + step; - T t = *it; ++it; + T t = *it; + ++it; for (; it != it_end; it += step) t += *it; return t; @@ -230,39 +199,37 @@ class array2D /** * Multiply row r by vector x. 
*/ - template - inline T row_mult(size_type r, It x_it) const - { + template inline T row_mult(size_type r, It x_it) const { const_iterator it = begin(r), it_end = end(r); - value_type val = *it * *x_it; ++x_it; ++it; + value_type val = *it * *x_it; + ++x_it; + ++it; while (it != it_end) { val = *it * *x_it; - ++x_it; ++it; + ++x_it; + ++it; } return val; } template - inline void save(stream_type& outStream) const - { + inline void save(stream_type &outStream) const { outStream << nrows() << ' ' << ncols() << ' '; const_iterator it = begin(), it_end = end(); while (it != it_end) { outStream << *it << ' '; - ++it; + ++it; } } - template - inline void load(stream_type& inStream) - { + template inline void load(stream_type &inStream) { inStream >> nrows_ >> ncols_; assert(nrows_ >= 0); assert(ncols_ >= 0); size_type n = nrows_ * ncols_; - delete [] d; + delete[] d; d = new value_type[n]; d_end = d + n; iterator it = d; @@ -275,9 +242,8 @@ class array2D //-------------------------------------------------------------------------------- template -inline stream_type& operator<<(stream_type& out, const array2D& m) -{ - typedef typename array2D::const_iterator const_iterator; +inline stream_type &operator<<(stream_type &out, const array2D &m) { + typedef typename array2D::const_iterator const_iterator; const_iterator it = m.begin(), end = m.end(), row_end = it; while (it != end) { @@ -294,9 +260,8 @@ inline stream_type& operator<<(stream_type& out, const array2D& m) //-------------------------------------------------------------------------------- template -inline stream_type& operator<<(stream_type& out, const array2D& m) -{ - typedef typename array2D::const_iterator const_iterator; +inline stream_type &operator<<(stream_type &out, const array2D &m) { + typedef typename array2D::const_iterator const_iterator; const_iterator it = m.begin(), end = m.end(), row_end = it; while (it != end) { @@ -313,8 +278,8 @@ inline stream_type& operator<<(stream_type& out, const array2D& m) //-------------------------------------------------------------------------------- template -inline void print(stream_type& out, const array2D& v, S m1, S n1, S m2, S n2) -{ +inline void print(stream_type &out, const array2D &v, S m1, S n1, S m2, + S n2) { if (m2 > v.nrows()) m2 = v.nrows(); @@ -323,13 +288,13 @@ inline void print(stream_type& out, const array2D& v, S m1, S n1, S m2, S n if (n2 > v.ncols()) n2 = v.ncols(); - + if (n1 >= n2) return; for (S i = m1; i < m2; ++i) { for (S j = n1; j < n2; ++j) - out << v(i,j) << ' '; + out << v(i, j) << ' '; out << " ... \n"; } out << " ..."; diff --git a/src/nupic/math/ArrayAlgo.hpp b/src/nupic/math/ArrayAlgo.hpp index da9e3c9c5a..bf96a03298 100644 --- a/src/nupic/math/ArrayAlgo.hpp +++ b/src/nupic/math/ArrayAlgo.hpp @@ -29,358 +29,336 @@ #ifndef NTA_ARRAY_ALGO_HPP #define NTA_ARRAY_ALGO_HPP -#include -#include #include +#include +#include #if defined(NTA_OS_WINDOWS) && defined(NTA_COMPILER_MSVC) - #include - #include +#include +#include #endif -#include // For the official Numenta RNG #include #include +#include // For the official Numenta RNG namespace nupic { - //-------------------------------------------------------------------------------- - // Checks whether the SSE supports the operations we need, i.e. SSE3 and SSE4. - // Returns highest SSE level supported by the CPU: 1, 2, 3 or 41 or 42. It also - // returns -1 if SSE is not present at all. - // - // Refer to Intel manuals for details. 
Basically, after call to cpuid, the - // interesting bits are set to 1 in either ecx or edx: - // If 25th bit of edx is 1, we have sse: 2^25 = 33554432 = 1<<25 - // If 26th bit of edx is 1, we have sse2. - // If 0th bit of ecx is 1, we have sse3. - // If 19th bit of ecx is 1, we have sse4.1. - // If 20th bit of ecx is 1, we have sse4.2. - //-------------------------------------------------------------------------------- - static int checkSSE() - { - unsigned int c = 0, d = 0; - const unsigned int SSE = 1<<25, - SSE2 = 1<<26, - SSE3 = 1<<0, - SSE41 = 1<<19, - SSE42 = 1<<20; +//-------------------------------------------------------------------------------- +// Checks whether the SSE supports the operations we need, i.e. SSE3 and SSE4. +// Returns highest SSE level supported by the CPU: 1, 2, 3 or 41 or 42. It also +// returns -1 if SSE is not present at all. +// +// Refer to Intel manuals for details. Basically, after call to cpuid, the +// interesting bits are set to 1 in either ecx or edx: +// If 25th bit of edx is 1, we have sse: 2^25 = 33554432 = 1<<25 +// If 26th bit of edx is 1, we have sse2. +// If 0th bit of ecx is 1, we have sse3. +// If 19th bit of ecx is 1, we have sse4.1. +// If 20th bit of ecx is 1, we have sse4.2. +//-------------------------------------------------------------------------------- +static int checkSSE() { + unsigned int c = 0, d = 0; + const unsigned int SSE = 1 << 25, SSE2 = 1 << 26, SSE3 = 1 << 0, + SSE41 = 1 << 19, SSE42 = 1 << 20; #ifdef NTA_ASM - #if defined(NTA_ARCH_32) - #if defined(NTA_OS_WINDOWS) && defined(NTA_COMPILER_MSVC) +#if defined(NTA_ARCH_32) +#if defined(NTA_OS_WINDOWS) && defined(NTA_COMPILER_MSVC) - // VC asm + // VC asm - unsigned int f = 1; - __asm { + unsigned int f = 1; + __asm { mov eax, f cpuid mov c, ecx mov d, edx - } - - // TODO: add asm code for gcc/clang/... on Windows - - #elif defined(NTA_OS_LINUX) || defined(NTA_OS_DARWIN) - - unsigned int a = 0,b = 0, f = 1; - - // PIC-compliant asm - __asm__ __volatile__( - "pushl %%ebx\n\t" - "cpuid\n\t" - "movl %%ebx, %1\n\t" - "popl %%ebx\n\t" - : "=a" (a), "=r" (b), "=c" (c), "=d" (d) - : "a" (f) - : "cc" - ); - #endif - #elif defined(NTA_ARCH_64) - #if defined(NTA_OS_LINUX) || defined(NTA_OS_DARWIN) - - __asm__ __volatile__ ( - "pushq %%rbx\n\t" - - "movl $1, %%eax\n\t" - "cpuid\n\t" - "movl %%ecx, %0\n\t" - "movl %%edx, %1\n\t" - - "popq %%rbx\n\t" - : "=c" (c), "=d" (d) - : - : - ); - #elif defined(NTA_OS_WINDOWS) && defined(NTA_COMPILER_MSVC) - - std::array cpui; - __cpuid(cpui.data(), 1); - c = cpui[2]; - d = cpui[3]; - #endif - #endif -#endif //NTA_ASM - - int ret = -1; - if (d & SSE) ret = 1; - if (d & SSE2) ret = 2; - if (c & SSE3) ret = 3; - if (c & SSE41) ret = 41; - if (c & SSE42) ret = 42; - - return ret; } - //-------------------------------------------------------------------------------- - // Highest SSE level supported by the CPU: 1, 2, 3 or 41 or 42. - // Note that the asm routines are written for gcc only so far, so we turn them - // off for all platforms except darwin86. Also, they won't work properly on 64 bits - // platforms for now. - //-------------------------------------------------------------------------------- - static const int SSE_LEVEL = checkSSE(); - - //-------------------------------------------------------------------------------- - // TESTS - // - // TODO: nearly zero for positive numbers - // TODO: is C++ trying to use that for all types?? 
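checkSSE() collapses the cpuid feature bits listed in the comment above into a single level, which is then cached in SSE_LEVEL. A hedged sketch of acting on that level; reportSSE is a hypothetical name:

#include <iostream>
#include <nupic/math/ArrayAlgo.hpp>

void reportSSE() {
  if (nupic::SSE_LEVEL >= 41)
    std::cout << "SSE4.1+ available: the ptest fast paths can be used\n";
  else if (nupic::SSE_LEVEL >= 1)
    std::cout << "SSE level " << nupic::SSE_LEVEL
              << ": scalar fallbacks are used\n";
  else
    std::cout << "no SSE detected (checkSSE() returned -1)\n";
}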
- //-------------------------------------------------------------------------------- - template - inline bool - nearlyZeroRange(It begin, It end, - const typename std::iterator_traits::value_type epsilon =nupic::Epsilon) - { - { - NTA_ASSERT(begin <= end) - << "nearlyZeroRange: Invalid input range"; - } + // TODO: add asm code for gcc/clang/... on Windows - while (begin != end) - if (!nearlyZero(*begin++, epsilon)) - return false; - return true; - } +#elif defined(NTA_OS_LINUX) || defined(NTA_OS_DARWIN) - //-------------------------------------------------------------------------------- - template - inline bool - nearlyEqualRange(It1 begin1, It1 end1, It2 begin2, It2 end2, - const typename std::iterator_traits::value_type epsilon =nupic::Epsilon) - { - { - NTA_ASSERT(begin1 <= end1) - << "nearlyZeroRange: Invalid first input range"; - NTA_ASSERT(begin2 <= end2) - << "nearlyZeroRange: Invalid second input range"; - NTA_ASSERT(end1 - begin1 <= end2 - begin2) - << "nearlyZeroRange: Incompatible ranges"; - } + unsigned int a = 0, b = 0, f = 1; - while (begin1 != end1) - if (!nearlyEqual(*begin1++, *begin2++, epsilon)) - return false; - return true; + // PIC-compliant asm + __asm__ __volatile__("pushl %%ebx\n\t" + "cpuid\n\t" + "movl %%ebx, %1\n\t" + "popl %%ebx\n\t" + : "=a"(a), "=r"(b), "=c"(c), "=d"(d) + : "a"(f) + : "cc"); +#endif +#elif defined(NTA_ARCH_64) +#if defined(NTA_OS_LINUX) || defined(NTA_OS_DARWIN) + + __asm__ __volatile__("pushq %%rbx\n\t" + + "movl $1, %%eax\n\t" + "cpuid\n\t" + "movl %%ecx, %0\n\t" + "movl %%edx, %1\n\t" + + "popq %%rbx\n\t" + : "=c"(c), "=d"(d) + : + :); +#elif defined(NTA_OS_WINDOWS) && defined(NTA_COMPILER_MSVC) + + std::array cpui; + __cpuid(cpui.data(), 1); + c = cpui[2]; + d = cpui[3]; +#endif +#endif +#endif // NTA_ASM + + int ret = -1; + if (d & SSE) + ret = 1; + if (d & SSE2) + ret = 2; + if (c & SSE3) + ret = 3; + if (c & SSE41) + ret = 41; + if (c & SSE42) + ret = 42; + + return ret; +} + +//-------------------------------------------------------------------------------- +// Highest SSE level supported by the CPU: 1, 2, 3 or 41 or 42. +// Note that the asm routines are written for gcc only so far, so we turn them +// off for all platforms except darwin86. Also, they won't work properly on 64 +// bits platforms for now. +//-------------------------------------------------------------------------------- +static const int SSE_LEVEL = checkSSE(); + +//-------------------------------------------------------------------------------- +// TESTS +// +// TODO: nearly zero for positive numbers +// TODO: is C++ trying to use that for all types?? 
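The range comparisons defined just below test floating point values against nupic::Epsilon. A hedged usage sketch; the values are illustrative and exampleNearlyEqual is a hypothetical name:

#include <vector>
#include <nupic/math/ArrayAlgo.hpp>

void exampleNearlyEqual() {
  std::vector<float> a = {0.0f, 1.0f, 2.0f};
  std::vector<float> b = {0.0f, 1.0f, 2.0f + 1e-9f};
  bool closeEnough = nupic::nearlyEqualVector(a, b);         // expected: true
  bool allZero = nupic::nearlyZeroRange(a.begin(), a.end()); // expected: false
  (void)closeEnough;
  (void)allZero;
}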
+//-------------------------------------------------------------------------------- +template +inline bool +nearlyZeroRange(It begin, It end, + const typename std::iterator_traits::value_type epsilon = + nupic::Epsilon) { + { NTA_ASSERT(begin <= end) << "nearlyZeroRange: Invalid input range"; } + + while (begin != end) + if (!nearlyZero(*begin++, epsilon)) + return false; + return true; +} + +//-------------------------------------------------------------------------------- +template +inline bool +nearlyEqualRange(It1 begin1, It1 end1, It2 begin2, It2 end2, + const typename std::iterator_traits::value_type epsilon = + nupic::Epsilon) { + { + NTA_ASSERT(begin1 <= end1) << "nearlyZeroRange: Invalid first input range"; + NTA_ASSERT(begin2 <= end2) << "nearlyZeroRange: Invalid second input range"; + NTA_ASSERT(end1 - begin1 <= end2 - begin2) + << "nearlyZeroRange: Incompatible ranges"; } - //-------------------------------------------------------------------------------- - template - inline bool - nearlyEqualVector(const Container1& c1, const Container2& c2, - const typename Container1::value_type& epsilon =nupic::Epsilon) - { - if (c1.size() != c2.size()) + while (begin1 != end1) + if (!nearlyEqual(*begin1++, *begin2++, epsilon)) return false; + return true; +} + +//-------------------------------------------------------------------------------- +template +inline bool nearlyEqualVector( + const Container1 &c1, const Container2 &c2, + const typename Container1::value_type &epsilon = nupic::Epsilon) { + if (c1.size() != c2.size()) + return false; + + return nearlyEqualRange(c1.begin(), c1.end(), c2.begin(), c2.end()); +} + +//-------------------------------------------------------------------------------- +// IS ZERO +//-------------------------------------------------------------------------------- +template inline bool is_zero(const T &x) { return x == 0; } + +//-------------------------------------------------------------------------------- +template +inline bool is_zero(const std::pair &x) { + return x.first == 0 && x.second == 0; +} + +//-------------------------------------------------------------------------------- +template inline bool is_zero(const std::vector &x) { + for (size_t i = 0; i != x.size(); ++i) + if (!is_zero(x[i])) + return false; + return true; +} + +//-------------------------------------------------------------------------------- +// DENSE isZero +//-------------------------------------------------------------------------------- +/** + * Scans a binary 0/1 vector to decide whether it is uniformly zero, + * or if it contains non-zeros (4X faster than C++ loop). + * + * If vector x is not aligned on a 16 bytes boundary, the function + * reverts to slow C++. This can happen when using it with slices of numpy + * arrays. + * + * TODO: find 16 bytes aligned block that can be sent to SSE. + * TODO: support win32/win64 for the fast path. + * TODO: can we go faster if working on ints rather than floats? + */ +template +inline bool isZero_01(InputIterator x, InputIterator x_end) { + { NTA_ASSERT(x <= x_end); } - return nearlyEqualRange(c1.begin(), c1.end(), c2.begin(), c2.end()); - } + // This test can be moved to compile time using a template with an int + // parameter, and partial specializations that will match the static + // const int SSE_LEVEL. 
+ if (SSE_LEVEL >= 41) { // ptest is a SSE 4.1 instruction - //-------------------------------------------------------------------------------- - // IS ZERO - //-------------------------------------------------------------------------------- - template - inline bool is_zero(const T& x) - { - return x == 0; - } + // On win32, the asm syntax is not correct. +#if defined(NTA_ASM) && defined(NTA_ARCH_32) && !defined(NTA_OS_WINDOWS) - //-------------------------------------------------------------------------------- - template - inline bool is_zero(const std::pair& x) - { - return x.first == 0 && x.second == 0; - } + // n is the total number of floats to process. + // n1 is the number of floats we can process in parallel using SSE. + // If x is not aligned on a 4 bytes boundary, we eschew all asm. + int result = 0; + int n = (int)(x_end - x); + int n1 = 0; + if (((long)x) % 16 == 0) + n1 = 8 * (n / 8); // we are going to process 2x4 floats at a time - //-------------------------------------------------------------------------------- - template - inline bool is_zero(const std::vector& x) - { - for (size_t i = 0; i != x.size(); ++i) - if (!is_zero(x[i])) + if (n1 > 0) { + + __asm__ __volatile__( + "pusha\n\t" // save all registers + + // fill xmm4 with all 1's, + // our mask to detect if there are on bits + // in the vector or not + "subl $16, %%esp\n\t" // allocate 4 floats on the stack + "movl $0xffffffff, (%%esp)\n\t" // copy mask 4 times, + "movl $0xffffffff, 4(%%esp)\n\t" // then move 16 bytes at once + "movl $0xffffffff, 8(%%esp)\n\t" // using movaps + "movl $0xffffffff, 12(%%esp)\n\t" + "movaps (%%esp), %%xmm4\n\t" + "addl $16, %%esp\n\t" // deallocate 4 floats on the stack + + "0:\n\t" + // esi and edi point to the same x, but staggered, so + // that we can load 2x4 bytes into xmm0 and xmm1 + "movaps (%%edi), %%xmm0\n\t" // move 4 floats from x + "movaps (%%esi), %%xmm1\n\t" // move another 4 floats from same x + "ptest %%xmm4, %%xmm0\n\t" // ptest first 4 floats, in xmm0 + "jne 1f\n\t" // jump if ZF = 0, some bit is not zero + "ptest %%xmm4, %%xmm1\n\t" // ptest second 4 floats, in xmm1 + "jne 1f\n\t" // jump if ZF = 0, some bit is not zero + + "addl $32, %%edi\n\t" // jump over 4 floats + "addl $32, %%esi\n\t" // and another 4 floats here + "subl $8, %%ecx\n\t" // processed 8 floats + "ja 0b\n\t" + + "movl $0, %0\n\t" // didn't find anything, result = 0 (int) + "jmp 2f\n\t" // exit + + "1:\n\t" // found something + "movl $0x1, %0\n\t" // result = 1 (int) + + "2:\n\t" // exit + "popa\n\t" // restore all registers + + : "=m"(result), "=D"(x) + : "D"(x), "S"(x + 4), "c"(n1) + :); + + if (result == 1) return false; - return true; - } + } // n1>0 end - //-------------------------------------------------------------------------------- - // DENSE isZero - //-------------------------------------------------------------------------------- - /** - * Scans a binary 0/1 vector to decide whether it is uniformly zero, - * or if it contains non-zeros (4X faster than C++ loop). - * - * If vector x is not aligned on a 16 bytes boundary, the function - * reverts to slow C++. This can happen when using it with slices of numpy - * arrays. - * - * TODO: find 16 bytes aligned block that can be sent to SSE. - * TODO: support win32/win64 for the fast path. - * TODO: can we go faster if working on ints rather than floats? - */ - template - inline bool isZero_01(InputIterator x, InputIterator x_end) - { - { - NTA_ASSERT(x <= x_end); - } + // Complete computation by iterating over "stragglers" one by one. 
+ for (int i = n1; i != n; ++i) + if (*(x + i) > 0) + return false; + return true; +#elif defined(NTA_ASM) && defined(NTA_ARCH_64) && !defined(NTA_OS_WINDOWS) - // This test can be moved to compile time using a template with an int - // parameter, and partial specializations that will match the static - // const int SSE_LEVEL. - if (SSE_LEVEL >= 41) { // ptest is a SSE 4.1 instruction + // n is the total number of floats to process. + // n1 is the number of floats we can process in parallel using SSE. + // If x is not aligned on a 4 bytes boundary, we eschew all asm. + int result = 0; + int n = (int)(x_end - x); + int n1 = 0; - // On win32, the asm syntax is not correct. -#if defined(NTA_ASM) && defined(NTA_ARCH_32) && !defined(NTA_OS_WINDOWS) + if (((long)x) % 16 == 0) + n1 = 8 * (n / 8); // we are going to process 2x4 floats at a time - // n is the total number of floats to process. - // n1 is the number of floats we can process in parallel using SSE. - // If x is not aligned on a 4 bytes boundary, we eschew all asm. - int result = 0; - int n = (int)(x_end - x); - int n1 = 0; - if (((long)x) % 16 == 0) - n1 = 8 * (n / 8); // we are going to process 2x4 floats at a time - - if (n1 > 0) { - - __asm__ __volatile__( - "pusha\n\t" // save all registers - - // fill xmm4 with all 1's, - // our mask to detect if there are on bits - // in the vector or not - "subl $16, %%esp\n\t" // allocate 4 floats on the stack - "movl $0xffffffff, (%%esp)\n\t" // copy mask 4 times, - "movl $0xffffffff, 4(%%esp)\n\t" // then move 16 bytes at once - "movl $0xffffffff, 8(%%esp)\n\t" // using movaps - "movl $0xffffffff, 12(%%esp)\n\t" - "movaps (%%esp), %%xmm4\n\t" - "addl $16, %%esp\n\t" // deallocate 4 floats on the stack - - "0:\n\t" - // esi and edi point to the same x, but staggered, so - // that we can load 2x4 bytes into xmm0 and xmm1 - "movaps (%%edi), %%xmm0\n\t" // move 4 floats from x - "movaps (%%esi), %%xmm1\n\t" // move another 4 floats from same x - "ptest %%xmm4, %%xmm0\n\t" // ptest first 4 floats, in xmm0 - "jne 1f\n\t" // jump if ZF = 0, some bit is not zero - "ptest %%xmm4, %%xmm1\n\t" // ptest second 4 floats, in xmm1 - "jne 1f\n\t" // jump if ZF = 0, some bit is not zero - - "addl $32, %%edi\n\t" // jump over 4 floats - "addl $32, %%esi\n\t" // and another 4 floats here - "subl $8, %%ecx\n\t" // processed 8 floats - "ja 0b\n\t" - - "movl $0, %0\n\t" // didn't find anything, result = 0 (int) - "jmp 2f\n\t" // exit - - "1:\n\t" // found something - "movl $0x1, %0\n\t" // result = 1 (int) - - "2:\n\t" // exit - "popa\n\t" // restore all registers - - : "=m" (result), "=D" (x) - : "D" (x), "S" (x + 4), "c" (n1) - : - ); - - if (result == 1) - return false; - } // n1>0 end - - // Complete computation by iterating over "stragglers" one by one. 
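For reference, the same all-zero test can also be written with SSE4.1 intrinsics rather than inline asm. This is a hedged alternative sketch, not the code in this file; it assumes a 16-byte aligned float buffer and a compiler targeting SSE4.1, and isZeroAlignedFloats is a hypothetical name:

#include <smmintrin.h> // SSE4.1 intrinsics, including _mm_testz_si128

inline bool isZeroAlignedFloats(const float *x, const float *x_end) {
  for (; x + 4 <= x_end; x += 4) {
    __m128i v = _mm_castps_si128(_mm_load_ps(x)); // load 4 aligned floats
    if (!_mm_testz_si128(v, v)) // some bit is set, so a value is non-zero
      return false;
  }
  for (; x != x_end; ++x) // stragglers, as in the asm version above
    if (*x > 0.0f)
      return false;
  return true;
}

The bit test treats any non-zero bit pattern (including -0.0f) as non-zero, which is fine for the 0/1 vectors this function is documented to accept.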
- for (int i = n1; i != n; ++i) - if (*(x+i) > 0) - return false; - return true; + if (n1 > 0) { -#elif defined(NTA_ASM) && defined(NTA_ARCH_64) && !defined(NTA_OS_WINDOWS) + __asm__ __volatile__( + // fill xmm4 with all 1's, + // our mask to detect if there are on bits + // in the vector or not + "subq $16, %%rsp\n\t" // allocate 4 floats on the stack + "movl $0xffffffff, (%%rsp)\n\t" // copy mask 4 times, + "movl $0xffffffff, 4(%%rsp)\n\t" // then move 16 bytes at once + "movl $0xffffffff, 8(%%rsp)\n\t" // using movaps + "movl $0xffffffff, 12(%%rsp)\n\t" + "movaps (%%rsp), %%xmm4\n\t" + "addq $16, %%rsp\n\t" // deallocate 4 floats on the stack + + "0:\n\t" + // rsi and rdi point to the same x, but staggered, so + // that we can load 2x4 bytes into xmm0 and xmm1 + "movaps (%%rdi), %%xmm0\n\t" // move 4 floats from x + "movaps (%%rsi), %%xmm1\n\t" // move another 4 floats from same x + "ptest %%xmm4, %%xmm0\n\t" // ptest first 4 floats, in xmm0 + "jne 1f\n\t" // jump if ZF = 0, some bit is not zero + "ptest %%xmm4, %%xmm1\n\t" // ptest second 4 floats, in xmm1 + "jne 1f\n\t" // jump if ZF = 0, some bit is not zero + + "addq $32, %%rdi\n\t" // jump over 4 floats + "addq $32, %%rsi\n\t" // and another 4 floats here + "subq $8, %%rcx\n\t" // processed 8 floats + "ja 0b\n\t" + + "movl $0, %0\n\t" // didn't find anything, result = 0 (int) + "jmp 2f\n\t" // exit + + "1:\n\t" // found something + "movl $0x1, %0\n\t" // result = 1 (int) + + "2:\n\t" // exit + + : "=m"(result), "=D"(x) + : "D"(x), "S"(x + 4), "c"(n1) + :); + + if (result == 1) + return false; + else + return true; + } // n1>0 end - // n is the total number of floats to process. - // n1 is the number of floats we can process in parallel using SSE. - // If x is not aligned on a 4 bytes boundary, we eschew all asm. 
- int result = 0; - int n = (int)(x_end - x); - int n1 = 0; - - if (((long)x) % 16 == 0) - n1 = 8 * (n / 8); // we are going to process 2x4 floats at a time - - if (n1 > 0) { - - __asm__ __volatile__( - // fill xmm4 with all 1's, - // our mask to detect if there are on bits - // in the vector or not - "subq $16, %%rsp\n\t" // allocate 4 floats on the stack - "movl $0xffffffff, (%%rsp)\n\t" // copy mask 4 times, - "movl $0xffffffff, 4(%%rsp)\n\t" // then move 16 bytes at once - "movl $0xffffffff, 8(%%rsp)\n\t" // using movaps - "movl $0xffffffff, 12(%%rsp)\n\t" - "movaps (%%rsp), %%xmm4\n\t" - "addq $16, %%rsp\n\t" // deallocate 4 floats on the stack - - "0:\n\t" - // rsi and rdi point to the same x, but staggered, so - // that we can load 2x4 bytes into xmm0 and xmm1 - "movaps (%%rdi), %%xmm0\n\t" // move 4 floats from x - "movaps (%%rsi), %%xmm1\n\t" // move another 4 floats from same x - "ptest %%xmm4, %%xmm0\n\t" // ptest first 4 floats, in xmm0 - "jne 1f\n\t" // jump if ZF = 0, some bit is not zero - "ptest %%xmm4, %%xmm1\n\t" // ptest second 4 floats, in xmm1 - "jne 1f\n\t" // jump if ZF = 0, some bit is not zero - - "addq $32, %%rdi\n\t" // jump over 4 floats - "addq $32, %%rsi\n\t" // and another 4 floats here - "subq $8, %%rcx\n\t" // processed 8 floats - "ja 0b\n\t" - - "movl $0, %0\n\t" // didn't find anything, result = 0 (int) - "jmp 2f\n\t" // exit - - "1:\n\t" // found something - "movl $0x1, %0\n\t" // result = 1 (int) - - "2:\n\t" // exit - - : "=m" (result), "=D" (x) - : "D" (x), "S" (x+4), "c" (n1) - : - ); - - if (result == 1) - return false; - else - return true; - } // n1>0 end - - // Revert to slow c++ version if array is not on 16 byte boundary - for (; x != x_end; ++x) - if (*x > 0) - return false; - return true; + // Revert to slow c++ version if array is not on 16 byte boundary + for (; x != x_end; ++x) + if (*x > 0) + return false; + return true; #else for (; x != x_end; ++x) @@ -389,5627 +367,5108 @@ namespace nupic { return true; #endif - } else { // not SSE4.2 + } else { // not SSE4.2 - for (; x != x_end; ++x) - if (*x > 0) - return false; - return true; - } - } //end method - - //-------------------------------------------------------------------------------- - /** - * 10X faster than function just above. - */ - inline bool - is_zero_01(const ByteVector& x, size_t begin, size_t end) - { - const Byte* x_beg = &x[begin]; - const Byte* x_end = &x[end]; + for (; x != x_end; ++x) + if (*x > 0) + return false; + return true; + } +} // end method - // On win32, the asm syntax is not correct. +//-------------------------------------------------------------------------------- +/** + * 10X faster than function just above. + */ +inline bool is_zero_01(const ByteVector &x, size_t begin, size_t end) { + const Byte *x_beg = &x[begin]; + const Byte *x_end = &x[end]; - // This test can be moved to compile time using a template with an int - // parameter, and partial specializations that will match the static - // const int SSE_LEVEL. - if (SSE_LEVEL >= 41) { // ptest is a SSE 4.1 instruction + // On win32, the asm syntax is not correct. + + // This test can be moved to compile time using a template with an int + // parameter, and partial specializations that will match the static + // const int SSE_LEVEL. + if (SSE_LEVEL >= 41) { // ptest is a SSE 4.1 instruction #if defined(NTA_ASM) && defined(NTA_ARCH_32) && !defined(NTA_OS_WINDOWS) - // n is the total number of floats to process. - // n1 is the number of floats we can process in parallel using SSE. 
- // If x is not aligned on a 4 bytes boundary, we eschew all asm. - int result = 0; - int n = (int)(x_end - x_beg); - int n1 = 0; - if (((long)x_beg) % 16 == 0) - n1 = 32 * (n / 32); // we are going to process 32 bytes at a time - - if (n1 > 0) { - - __asm__ __volatile__( - "pusha\n\t" // save all registers - - // fill xmm4 with all 1's, - // our mask to detect if there are on bits - // in the vector or not - "subl $16, %%esp\n\t" // allocate 4 floats on the stack - "movl $0xffffffff, (%%esp)\n\t" // copy mask 4 times, - "movl $0xffffffff, 4(%%esp)\n\t" // then move 16 bytes at once - "movl $0xffffffff, 8(%%esp)\n\t" // using movaps - "movl $0xffffffff, 12(%%esp)\n\t" - "movaps (%%esp), %%xmm4\n\t" - "addl $16, %%esp\n\t" // deallocate 4 floats on the stack - - "0:\n\t" - // esi and edi point to the same x, but staggered, so - // that we can load 2x4 bytes into xmm0 and xmm1 - "movaps (%%edi), %%xmm0\n\t" // move 4 floats from x - "movaps (%%esi), %%xmm1\n\t" // move another 4 floats from same x - "ptest %%xmm4, %%xmm0\n\t" // ptest first 4 floats, in xmm0 - "jne 1f\n\t" // jump if ZF = 0, some bit is not zero - "ptest %%xmm4, %%xmm1\n\t" // ptest second 4 floats, in xmm1 - "jne 1f\n\t" // jump if ZF = 0, some bit is not zero - - "addl $32, %%edi\n\t" // jump 32 bytes (16 in xmm0 + 16 in xmm1) - "addl $32, %%esi\n\t" // and another 32 bytes - "subl $32, %%ecx\n\t" // processed 32 bytes - "ja 0b\n\t" - - "movl $0, %0\n\t" // didn't find anything, result = 0 (int) - "jmp 2f\n\t" // exit - - "1:\n\t" // found something - "movl $0x1, %0\n\t" // result = 1 (int) - - "2:\n\t" // exit - "popa\n\t" // restore all registers - - : "=m" (result), "=D" (x_beg) - : "D" (x_beg), "S" (x_beg + 16), "c" (n1) - : - ); - - if (result == 1) - return false; - } + // n is the total number of floats to process. + // n1 is the number of floats we can process in parallel using SSE. + // If x is not aligned on a 4 bytes boundary, we eschew all asm. + int result = 0; + int n = (int)(x_end - x_beg); + int n1 = 0; + if (((long)x_beg) % 16 == 0) + n1 = 32 * (n / 32); // we are going to process 32 bytes at a time - // Complete computation by iterating over "stragglers" one by one. 
- for (int i = n1; i != n; ++i) - if (*(x_beg+i) > 0) - return false; - return true; + if (n1 > 0) { + + __asm__ __volatile__( + "pusha\n\t" // save all registers + + // fill xmm4 with all 1's, + // our mask to detect if there are on bits + // in the vector or not + "subl $16, %%esp\n\t" // allocate 4 floats on the stack + "movl $0xffffffff, (%%esp)\n\t" // copy mask 4 times, + "movl $0xffffffff, 4(%%esp)\n\t" // then move 16 bytes at once + "movl $0xffffffff, 8(%%esp)\n\t" // using movaps + "movl $0xffffffff, 12(%%esp)\n\t" + "movaps (%%esp), %%xmm4\n\t" + "addl $16, %%esp\n\t" // deallocate 4 floats on the stack + + "0:\n\t" + // esi and edi point to the same x, but staggered, so + // that we can load 2x4 bytes into xmm0 and xmm1 + "movaps (%%edi), %%xmm0\n\t" // move 4 floats from x + "movaps (%%esi), %%xmm1\n\t" // move another 4 floats from same x + "ptest %%xmm4, %%xmm0\n\t" // ptest first 4 floats, in xmm0 + "jne 1f\n\t" // jump if ZF = 0, some bit is not zero + "ptest %%xmm4, %%xmm1\n\t" // ptest second 4 floats, in xmm1 + "jne 1f\n\t" // jump if ZF = 0, some bit is not zero + + "addl $32, %%edi\n\t" // jump 32 bytes (16 in xmm0 + 16 in xmm1) + "addl $32, %%esi\n\t" // and another 32 bytes + "subl $32, %%ecx\n\t" // processed 32 bytes + "ja 0b\n\t" + + "movl $0, %0\n\t" // didn't find anything, result = 0 (int) + "jmp 2f\n\t" // exit + + "1:\n\t" // found something + "movl $0x1, %0\n\t" // result = 1 (int) + + "2:\n\t" // exit + "popa\n\t" // restore all registers + + : "=m"(result), "=D"(x_beg) + : "D"(x_beg), "S"(x_beg + 16), "c"(n1) + :); + + if (result == 1) + return false; + } + + // Complete computation by iterating over "stragglers" one by one. + for (int i = n1; i != n; ++i) + if (*(x_beg + i) > 0) + return false; + return true; #elif defined(NTA_ASM) && !defined(NTA_OS_WINDOWS) - // n is the total number of floats to process. - // n1 is the number of floats we can process in parallel using SSE. - // If x is not aligned on a 4 bytes boundary, we eschew all asm. 
- int result = 0; - int n = (int)(x_end - x_beg); - int n1 = 0; - if (((long)x_beg) % 16 == 0) - n1 = 32 * (n / 32); // we are going to process 32 bytes at a time - - if (n1 > 0) { - - __asm__ __volatile__( - // fill xmm4 with all 1's, - // our mask to detect if there are on bits - // in the vector or not - "subq $16, %%rsp\n\t" // allocate 4 floats on the stack - "movq $0xffffffff, (%%rsp)\n\t" // copy mask 4 times, - "movq $0xffffffff, 4(%%rsp)\n\t" // then move 16 bytes at once - "movq $0xffffffff, 8(%%rsp)\n\t" // using movaps - "movq $0xffffffff, 12(%%rsp)\n\t" - "movaps (%%rsp), %%xmm4\n\t" - "addq $16, %%rsp\n\t" // deallocate 4 floats on the stack - - "0:\n\t" - // rsi and rdi point to the same x, but staggered, so - // that we can load 2x4 bytes into xmm0 and xmm1 - "movaps (%%rdi), %%xmm0\n\t" // move 4 floats from x - "movaps (%%rsi), %%xmm1\n\t" // move another 4 floats from same x - "ptest %%xmm4, %%xmm0\n\t" // ptest first 4 floats, in xmm0 - "jne 1f\n\t" // jump if ZF = 0, some bit is not zero - "ptest %%xmm4, %%xmm1\n\t" // ptest second 4 floats, in xmm1 - "jne 1f\n\t" // jump if ZF = 0, some bit is not zero - - "addq $32, %%rdi\n\t" // jump 32 bytes (16 in xmm0 + 16 in xmm1) - "addq $32, %%rsi\n\t" // and another 32 bytes - "subq $32, %%rcx\n\t" // processed 32 bytes - "ja 0b\n\t" - - "movl $0, %0\n\t" // didn't find anything, result = 0 (int) - "jmp 2f\n\t" // exit - - "1:\n\t" // found something - "movl $0x1, %0\n\t" // result = 1 (int) - - "2:\n\t" // exit - "popa\n\t" // restore all registers - - : "=m" (result), "=D" (x_beg) - : "D" (x_beg), "S" (x_beg + 16), "c" (n1) - : - ); - - if (result == 1) - return false; - else - return true; - } + // n is the total number of floats to process. + // n1 is the number of floats we can process in parallel using SSE. + // If x is not aligned on a 4 bytes boundary, we eschew all asm. 
+ int result = 0; + int n = (int)(x_end - x_beg); + int n1 = 0; + if (((long)x_beg) % 16 == 0) + n1 = 32 * (n / 32); // we are going to process 32 bytes at a time - // if n1 is not on a 32 byte boundary then use slower code - for (; x_beg != x_end; ++x_beg) - if (*x_beg > 0) - return false; - return true; -#endif - } else { // SSE 4.1 + if (n1 > 0) { - for (; x_beg != x_end; ++x_beg) - if (*x_beg > 0) - return false; - return true; + __asm__ __volatile__( + // fill xmm4 with all 1's, + // our mask to detect if there are on bits + // in the vector or not + "subq $16, %%rsp\n\t" // allocate 4 floats on the stack + "movq $0xffffffff, (%%rsp)\n\t" // copy mask 4 times, + "movq $0xffffffff, 4(%%rsp)\n\t" // then move 16 bytes at once + "movq $0xffffffff, 8(%%rsp)\n\t" // using movaps + "movq $0xffffffff, 12(%%rsp)\n\t" + "movaps (%%rsp), %%xmm4\n\t" + "addq $16, %%rsp\n\t" // deallocate 4 floats on the stack + + "0:\n\t" + // rsi and rdi point to the same x, but staggered, so + // that we can load 2x4 bytes into xmm0 and xmm1 + "movaps (%%rdi), %%xmm0\n\t" // move 4 floats from x + "movaps (%%rsi), %%xmm1\n\t" // move another 4 floats from same x + "ptest %%xmm4, %%xmm0\n\t" // ptest first 4 floats, in xmm0 + "jne 1f\n\t" // jump if ZF = 0, some bit is not zero + "ptest %%xmm4, %%xmm1\n\t" // ptest second 4 floats, in xmm1 + "jne 1f\n\t" // jump if ZF = 0, some bit is not zero + + "addq $32, %%rdi\n\t" // jump 32 bytes (16 in xmm0 + 16 in xmm1) + "addq $32, %%rsi\n\t" // and another 32 bytes + "subq $32, %%rcx\n\t" // processed 32 bytes + "ja 0b\n\t" + + "movl $0, %0\n\t" // didn't find anything, result = 0 (int) + "jmp 2f\n\t" // exit + + "1:\n\t" // found something + "movl $0x1, %0\n\t" // result = 1 (int) + + "2:\n\t" // exit + "popa\n\t" // restore all registers + + : "=m"(result), "=D"(x_beg) + : "D"(x_beg), "S"(x_beg + 16), "c"(n1) + :); + + if (result == 1) + return false; + else + return true; } + // if n1 is not on a 32 byte boundary then use slower code for (; x_beg != x_end; ++x_beg) if (*x_beg > 0) return false; return true; +#endif + } else { // SSE 4.1 - } - - //-------------------------------------------------------------------------------- - template - inline bool - positive_less_than(InIter begin, InIter end, - const typename std::iterator_traits::value_type threshold) - { - { - NTA_ASSERT(begin <= end) - << "positive_less_than: Invalid input range"; - } - - for (; begin != end; ++begin) - if (*begin > threshold) + for (; x_beg != x_end; ++x_beg) + if (*x_beg > 0) return false; return true; } - //-------------------------------------------------------------------------------- - template - inline void print_bits(const T& x) - { - for (int i = sizeof(T) - 1; 0 <= i; --i) { - unsigned char* b = (unsigned char*)(&x) + i; - for (int j = 7; 0 <= j; --j) - std::cout << ((*b & (1 << j)) / (1 << j)); - std::cout << ' '; - } - } - - //-------------------------------------------------------------------------------- - // N BYTES - //-------------------------------------------------------------------------------- - /** - * For primitive types. 
- */ - template - inline size_t n_bytes(const T&) - { - return sizeof(T); + for (; x_beg != x_end; ++x_beg) + if (*x_beg > 0) + return false; + return true; +} + +//-------------------------------------------------------------------------------- +template +inline bool positive_less_than( + InIter begin, InIter end, + const typename std::iterator_traits::value_type threshold) { + { NTA_ASSERT(begin <= end) << "positive_less_than: Invalid input range"; } + + for (; begin != end; ++begin) + if (*begin > threshold) + return false; + return true; +} + +//-------------------------------------------------------------------------------- +template inline void print_bits(const T &x) { + for (int i = sizeof(T) - 1; 0 <= i; --i) { + unsigned char *b = (unsigned char *)(&x) + i; + for (int j = 7; 0 <= j; --j) + std::cout << ((*b & (1 << j)) / (1 << j)); + std::cout << ' '; + } +} + +//-------------------------------------------------------------------------------- +// N BYTES +//-------------------------------------------------------------------------------- +/** + * For primitive types. + */ +template inline size_t n_bytes(const T &) { return sizeof(T); } + +//-------------------------------------------------------------------------------- +template +inline size_t n_bytes(const std::pair &p) { + size_t n = n_bytes(p.first) + n_bytes(p.second); + return n; +} + +//-------------------------------------------------------------------------------- +/** + * For more bytes for alignment on x86 with darwin: darwin32 always allocates on + * 16 bytes boundaries, so the three pointers in the STL vectors (of 32 bits + * each in -m32), become: 3 * 4 + 4 = 16 bytes. The capacity similarly needs to + * be adjusted for aligment. On other platforms, the alignment might be + * different. + * + * NOTE/WARNING: this is really "accurate" only on darwin32. And even, it's + * probably only approximate. 
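The n_bytes() overload documented above estimates a vector's footprint by rounding both the data block and the vector header up to the allocator alignment. A standalone sketch of that rounding; the 16-byte alignment and the header size shown in the comments are platform-dependent assumptions, as the NOTE/WARNING says:

#include <cstddef>
#include <iostream>
#include <vector>

// Same rounding rule as n_bytes(): pad a byte count up to the next
// multiple of the assumed allocator alignment (16 bytes here).
std::size_t pad_to(std::size_t n, std::size_t alignment = 16) {
  return (n % alignment == 0) ? n : alignment * (n / alignment + 1);
}

int main() {
  std::vector<float> v;
  v.reserve(10); // capacity of (at least) 10 floats -> ~40 bytes of data
  std::size_t data = pad_to(v.capacity() * sizeof(float)); // 40 -> 48 when capacity is exactly 10
  std::size_t header = pad_to(sizeof(std::vector<float>)); // e.g. 24 -> 32 on a typical 64-bit build
  std::cout << "~" << (data + header) << " bytes\n";
  return 0;
}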
+ */ +template +inline size_t n_bytes(const std::vector &a, size_t alignment = 16) { + size_t n1 = a.capacity() * sizeof(T); + if (n1 % alignment != 0) + n1 = alignment * (n1 / alignment + 1); + + size_t n2 = sizeof(std::vector); + if (n2 % alignment != 0) + n2 = alignment * (n2 / alignment + 1); + + return n1 + n2; +} + +//-------------------------------------------------------------------------------- +template +inline size_t n_bytes(const std::vector> &a, + size_t alignment = 16) { + size_t n = sizeof(std::vector>); + if (n % alignment != 0) + n = alignment * (n / alignment + 1); + + for (size_t i = 0; i != a.size(); ++i) + n += n_bytes(a[i]); + + return n; +} + +//-------------------------------------------------------------------------------- +template inline float load_factor(const std::vector &x) { + return (float)x.size() / (float)x.capacity(); +} + +//-------------------------------------------------------------------------------- +template +inline void adjust_load_factor(std::vector &x, float target) { + NTA_ASSERT(0.0 <= target && target <= 1.0); + + size_t new_capacity = (size_t)((float)x.size() / target); + + std::vector y; + y.reserve(new_capacity); + y.resize(x.size()); + std::copy(x.begin(), x.end(), y.begin()); + x.swap(y); +} + +//-------------------------------------------------------------------------------- +// VARIOUS +//-------------------------------------------------------------------------------- +inline std::string operator+(const std::string &str, size_t idx) { + std::stringstream buff; + buff << str << idx; + return buff.str(); +} + +//-------------------------------------------------------------------------------- +template +inline void append(const std::vector &a, std::vector &b) { + b.insert(b.end(), a.begin(), a.end()); +} + +//-------------------------------------------------------------------------------- +template +inline std::vector &operator+=(std::vector &b, const std::vector &a) { + append(a, b); + return b; +} + +//-------------------------------------------------------------------------------- +template inline void append(const std::set &a, std::set &b) { + b.insert(a.begin(), a.end()); +} + +//-------------------------------------------------------------------------------- +template +inline std::set &operator+=(std::set &b, const std::set &a) { + append(a, b); + return b; +} + +//-------------------------------------------------------------------------------- +// map insert or increment +template +inline void increment(std::map &m, const T1 &key, const T2 &init = 1) { + typename std::map::iterator it = m.find(key); + if (it != m.end()) + ++it->second; + else + m[key] = init; +} + +//-------------------------------------------------------------------------------- +template +inline bool is_in(const K &key, const std::map &m) { + return m.find(key) != m.end(); +} + +//-------------------------------------------------------------------------------- +// Deriving from std::map to add frequently used functionality +template , + typename A = std::allocator>> +struct dict : public std::map { + inline bool has_key(const K &key) const { return is_in(key, *this); } + + // Often useful for histograms, where V is an integral type + inline void increment(const K &key, const V &init = 1) { + nupic::increment(*this, key, init); + } + + // Inserts once in the map, or return false if already inserted + // (saves having to write find(...) 
== this->end()) + inline bool insert_once(const K &key, const V &v) { + if (has_key(key)) + return false; + else + this->insert(std::make_pair(key, v)); + return true; } - //-------------------------------------------------------------------------------- - template - inline size_t n_bytes(const std::pair& p) + /* + // Returns an existing value for the key, if it is in the dict already, + // or creates one and returns it. (operator[] on std::map does that?) + inline V& operator(const K& key) { - size_t n = n_bytes(p.first) + n_bytes(p.second); - return n; + iterator it = this->find(key); + if (key == end) { + (*this)[key] = V(); + return (*this)[key]; + } else + return *it; } + */ +}; - //-------------------------------------------------------------------------------- - /** - * For more bytes for alignment on x86 with darwin: darwin32 always allocates on - * 16 bytes boundaries, so the three pointers in the STL vectors (of 32 bits each - * in -m32), become: 3 * 4 + 4 = 16 bytes. The capacity similarly needs to be - * adjusted for aligment. On other platforms, the alignment might be different. - * - * NOTE/WARNING: this is really "accurate" only on darwin32. And even, it's probably - * only approximate. - */ - template - inline size_t n_bytes(const std::vector& a, size_t alignment =16) - { - size_t n1 = a.capacity() * sizeof(T); - if (n1 % alignment != 0) - n1 = alignment * (n1 / alignment + 1); +//-------------------------------------------------------------------------------- +// INIT LIST +//-------------------------------------------------------------------------------- +template struct vector_init_list { + std::vector &v; - size_t n2 = sizeof(std::vector); - if (n2 % alignment != 0) - n2 = alignment * (n2 / alignment + 1); + inline vector_init_list(std::vector &v_ref) : v(v_ref) {} + inline vector_init_list(const vector_init_list &o) : v(o.v) {} - return n1 + n2; + inline vector_init_list &operator=(const vector_init_list &o) { + v(o.v); + return *this; } - //-------------------------------------------------------------------------------- - template - inline size_t n_bytes(const std::vector >& a, size_t alignment =16) - { - size_t n = sizeof(std::vector >); - if (n % alignment != 0) - n = alignment * (n / alignment + 1); - - for (size_t i = 0; i != a.size(); ++i) - n += n_bytes(a[i]); - - return n; + template inline vector_init_list &operator,(const T2 &x) { + v.push_back((T)x); + return *this; } +}; - //-------------------------------------------------------------------------------- - template - inline float load_factor(const std::vector& x) - { - return (float) x.size() / (float) x.capacity(); - } +//-------------------------------------------------------------------------------- +template +inline vector_init_list operator+=(std::vector &v, const T2 &x) { + v.push_back((T)x); + return vector_init_list(v); +} - //-------------------------------------------------------------------------------- - template - inline void adjust_load_factor(std::vector& x, float target) - { - NTA_ASSERT(0.0 <= target && target <= 1.0); +//-------------------------------------------------------------------------------- +// TODO: merge with preceding by changing parametrization? 
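vector_init_list and the operator+= / operator, pair above implement a pre-C++11 comma-style initializer: += appends the first value and returns a proxy, and each subsequent comma appends another value. A minimal self-contained sketch of the idiom (these are stand-in types, not the nupic ones):

#include <iostream>
#include <vector>

// Proxy returned by operator+=; each comma-separated value is appended.
template <typename T> struct init_list_proxy {
  std::vector<T> &v;
  explicit init_list_proxy(std::vector<T> &v_ref) : v(v_ref) {}
  init_list_proxy &operator,(const T &x) {
    v.push_back(x);
    return *this;
  }
};

template <typename T>
init_list_proxy<T> operator+=(std::vector<T> &v, const T &x) {
  v.push_back(x);
  return init_list_proxy<T>(v);
}

int main() {
  std::vector<int> v;
  v += 1, 2, 3, 5, 8; // appends all five values
  for (int x : v)
    std::cout << x << ' ';
  std::cout << '\n';
  return 0;
}

set_init_list, introduced just below, mirrors the same pattern for std::set with insert() in place of push_back().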
+//-------------------------------------------------------------------------------- +template struct set_init_list { + std::set &v; - size_t new_capacity = (size_t)((float)x.size() / target); + inline set_init_list(std::set &v_ref) : v(v_ref) {} + inline set_init_list(const set_init_list &o) : v(o.v) {} - std::vector y; - y.reserve(new_capacity); - y.resize(x.size()); - std::copy(x.begin(), x.end(), y.begin()); - x.swap(y); + inline set_init_list &operator=(const set_init_list &o) { + v(o.v); + return *this; } - //-------------------------------------------------------------------------------- - // VARIOUS - //-------------------------------------------------------------------------------- - inline std::string operator+(const std::string& str, size_t idx) - { - std::stringstream buff; - buff << str << idx; - return buff.str(); + template inline set_init_list &operator,(const T2 &x) { + v.insert((T)x); + return *this; + } +}; + +//-------------------------------------------------------------------------------- +template +inline set_init_list operator+=(std::set &v, const T2 &x) { + v.insert((T)x); + return set_init_list(v); +} + +//-------------------------------------------------------------------------------- +// FIND IN VECTOR +//-------------------------------------------------------------------------------- +// T1 and T2 to get around constness with pointers +template +inline int find_index(const T1 &x, const std::vector &v) { + for (size_t i = 0; i != v.size(); ++i) + if (v[i] == x) + return (int)i; + return -1; +} + +//-------------------------------------------------------------------------------- +template +inline int find_index(const T1 &x, const std::vector> &v) { + for (size_t i = 0; i != v.size(); ++i) + if (v[i].first == x) + return (int)i; + return -1; +} + +//-------------------------------------------------------------------------------- +template inline bool not_in(const T &x, const std::vector &v) { + return std::find(v.begin(), v.end(), x) == v.end(); +} + +//-------------------------------------------------------------------------------- +template +inline bool not_in(const T1 &x, const std::vector> &v) { + typename std::vector>::const_iterator it; + for (it = v.begin(); it != v.end(); ++it) + if (it->first == x) + return false; + return true; +} + +//-------------------------------------------------------------------------------- +template inline bool not_in(const T &x, const std::set &s) { + return s.find(x) == s.end(); +} + +//-------------------------------------------------------------------------------- +template inline bool is_in(const T &x, const std::vector &v) { + return !not_in(x, v); +} + +//-------------------------------------------------------------------------------- +template +inline bool is_in(const T1 &x, const std::vector> &v) { + return !not_in(x, v); +} + +//-------------------------------------------------------------------------------- +template inline bool is_in(const T &x, const std::set &s) { + return !not_in(x, s); +} + +//-------------------------------------------------------------------------------- +template +inline bool is_sorted(It begin, It end, bool ascending = true, + bool unique = true) { + if (begin < end) { + for (It prev = begin, it = ++begin; it < end; ++it, ++prev) + + if (ascending) { + if (unique) { + if (*prev >= *it) + return false; + } else { + if (*prev > *it) + return false; + } + } else { + if (unique) { + if (*prev <= *it) + return false; + } else { + if (*prev < *it) + return false; + } + } } - 
//-------------------------------------------------------------------------------- - template - inline void append(const std::vector& a, std::vector& b) - { - b.insert(b.end(), a.begin(), a.end()); + return true; +} + +//-------------------------------------------------------------------------------- +template +inline bool is_sorted(const std::vector &x, bool ascending = true, + bool unique = true) { + return is_sorted(x.begin(), x.end(), ascending, unique); +} + +//-------------------------------------------------------------------------------- +template +inline bool operator==(const std::vector &a, const std::vector &b) { + NTA_ASSERT(a.size() == b.size()); + if (a.size() != b.size()) + return false; + for (size_t i = 0; i != a.size(); ++i) + if (a[i] != b[i]) + return false; + return true; +} + +//-------------------------------------------------------------------------------- +template +inline bool operator!=(const std::vector &a, const std::vector &b) { + return !(a == b); +} + +//-------------------------------------------------------------------------------- +template +inline bool operator==(const std::map &a, const std::map &b) { + typename std::map::const_iterator ita = a.begin(), itb = b.begin(); + for (; ita != a.end(); ++ita, ++itb) + if (ita->first != itb->first || ita->second != itb->second) + return false; + return true; +} + +//-------------------------------------------------------------------------------- +template +inline bool operator!=(const std::map &a, const std::map &b) { + return !(a == b); +} + +//-------------------------------------------------------------------------------- +/** + * Proxy for an insert iterator that allows inserting at the second element + * when iterating over a container of pairs. + */ +template struct inserter_second { + typedef typename std::iterator_traits::value_type pair_type; + typedef typename pair_type::second_type second_type; + typedef second_type value_type; + + Iterator it; + + inline inserter_second(Iterator _it) : it(_it) {} + inline second_type &operator*() { return it->second; } + inline void operator++() { ++it; } +}; + +template inserter_second insert_2nd(Iterator it) { + return inserter_second(it); +} + +//-------------------------------------------------------------------------------- +/** + * Proxy for an insert iterator that allows inserting at the second element when + * iterating over a container of pairs, while setting the first element to the + * current index value (watch out if iterator passed to constructor is not + * pointing to the beginning of the container!) 
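In is_sorted() above, ascending=true combined with unique=true means "strictly increasing": the function rejects any adjacent pair with prev >= next. The same check expressed with the standard library, for comparison:

#include <algorithm>
#include <functional>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> a = {1, 2, 2, 3};
  // Strictly increasing <=> no adjacent pair where the left element is >= the right.
  bool strictly_increasing =
      std::adjacent_find(a.begin(), a.end(), std::greater_equal<int>()) == a.end();
  std::cout << std::boolalpha << strictly_increasing << '\n'; // false: 2 repeats
  return 0;
}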
+ */ +template struct inserter_second_incrementer_first { + typedef typename std::iterator_traits::value_type pair_type; + typedef typename pair_type::second_type second_type; + typedef second_type value_type; + + Iterator it; + size_t i; + + inline inserter_second_incrementer_first(Iterator _it) : it(_it), i(0) {} + inline second_type &operator*() { return it->second; } + inline void operator++() { + it->first = i++; + ++it; + } +}; + +template +inserter_second_incrementer_first insert_2nd_inc(Iterator it) { + return inserter_second_incrementer_first(it); +} + +//-------------------------------------------------------------------------------- +template +inline T2 dot(const std::vector &x, const Buffer &y) { + size_t n1 = x.size(), n2 = y.nnz, i1 = 0, i2 = 0; + T2 s = 0; + + while (i1 != n1 && i2 != n2) + if (x[i1] < y[i2]) { + ++i1; + } else if (y[i2] < x[i1]) { + ++i2; + } else { + ++s; + ++i1; + ++i2; + } + + return s; +} + +//-------------------------------------------------------------------------------- +inline float dot(const float *x, const float *x_end, const float *y) { + float result = 0; + for (; x != x_end; ++x, ++y) + result += *x * *y; + return result; +} + +//-------------------------------------------------------------------------------- +// copy +//-------------------------------------------------------------------------------- +template +inline void copy(It1 begin, It1 end, It2 out_begin, It2 out_end) { + std::copy(begin, end, out_begin); +} + +//-------------------------------------------------------------------------------- +/** + * Copies a whole container into another. + * + * Does not allocate memory for b: b needs to have enough room for + * a.size() elements. + * + * @param a the source container + * @param b the destination container + */ +template inline void copy(const T1 &a, T2 &b) { + b.resize(a.size()); + copy(a.begin(), a.end(), b.begin(), b.end()); +} + +//-------------------------------------------------------------------------------- +template +inline void copy(const std::vector &a, size_t n, std::vector &b, + size_t o = 0) { + NTA_ASSERT(o + n <= b.size()); + std::copy(a.begin(), a.begin() + n, b.begin() + o); +} + +//-------------------------------------------------------------------------------- +template +inline void copy(const std::vector &a, size_t i, size_t j, + std::vector &b) { + std::copy(a.begin() + i, a.begin() + j, b.begin() + i); +} + +//-------------------------------------------------------------------------------- +template +inline void copy(const std::vector &a, std::vector &b, size_t offset) { + NTA_ASSERT(offset + a.size() <= b.size()); + std::copy(a.begin(), a.end(), b.begin() + offset); +} + +//-------------------------------------------------------------------------------- +template +inline void copy_indices(const SparseVector &x, Buffer &y) { + NTA_ASSERT(x.nnz <= y.size()); + + for (size_t i = 0; i != x.nnz; ++i) + y[i] = x[i].first; + y.nnz = x.nnz; +} + +//-------------------------------------------------------------------------------- +// TO DENSE +//-------------------------------------------------------------------------------- +template +inline void to_dense_01(It1 ind, It1 ind_end, It2 dense, It2 dense_end) { + { + NTA_ASSERT(ind <= ind_end) << "to_dense: Mismatched iterators"; + NTA_ASSERT(dense <= dense_end) << "to_dense: Mismatched iterators"; + NTA_ASSERT(ind_end - ind <= dense_end - dense) + << "to_dense: Not enough memory"; } - //-------------------------------------------------------------------------------- - 
template - inline std::vector& operator+=(std::vector& b, const std::vector& a) - { - append(a, b); - return b; - } + typedef typename std::iterator_traits::value_type value_type; - //-------------------------------------------------------------------------------- - template - inline void append(const std::set& a, std::set& b) - { - b.insert(a.begin(), a.end()); - } + // TODO: make faster with single pass? + // (but if's for all the elements might be slower) + std::fill(dense, dense_end, (value_type)0); - //-------------------------------------------------------------------------------- - template - inline std::set& operator+=(std::set& b, const std::set& a) - { - append(a, b); - return b; - } + for (; ind != ind_end; ++ind) + *(dense + *ind) = (value_type)1; +} - //-------------------------------------------------------------------------------- - // map insert or increment - template - inline void increment(std::map& m, const T1& key, const T2& init =1) - { - typename std::map::iterator it = m.find(key); - if (it != m.end()) - ++ it->second; - else - m[key] = init; - } +//-------------------------------------------------------------------------------- +template +inline void to_dense_01(const std::vector &sparse, std::vector &dense) { + to_dense_01(sparse.begin(), sparse.end(), dense.begin(), dense.end()); +} - //-------------------------------------------------------------------------------- - template - inline bool is_in(const K& key, const std::map& m) - { - return m.find(key) != m.end(); - } +//-------------------------------------------------------------------------------- +template +inline void to_dense_01(const Buffer &buffer, OutIt y, OutIt y_end) { + typedef typename std::iterator_traits::value_type value_type; - //-------------------------------------------------------------------------------- - // Deriving from std::map to add frequently used functionality - template , - typename A =std::allocator > > - struct dict : public std::map - { - inline bool has_key(const K& key) const - { - return is_in(key, *this); - } + std::fill(y, y_end, (value_type)0); - // Often useful for histograms, where V is an integral type - inline void increment(const K& key, const V& init =1) - { - nupic::increment(*this, key, init); - } + for (size_t i = 0; i != buffer.nnz; ++i) + y[buffer[i]] = (value_type)1; +} - // Inserts once in the map, or return false if already inserted - // (saves having to write find(...) == this->end()) - inline bool insert_once(const K& key, const V& v) - { - if (has_key(key)) - return false; - else - this->insert(std::make_pair(key, v)); - return true; - } +//-------------------------------------------------------------------------------- +template +inline void to_dense_01(It begin, It end, std::vector &dense) { + to_dense_01(begin, end, dense.begin(), dense.end()); +} - /* - // Returns an existing value for the key, if it is in the dict already, - // or creates one and returns it. (operator[] on std::map does that?) 
- inline V& operator(const K& key) - { - iterator it = this->find(key); - if (key == end) { - (*this)[key] = V(); - return (*this)[key]; - } else - return *it; - } - */ - }; - - //-------------------------------------------------------------------------------- - // INIT LIST - //-------------------------------------------------------------------------------- - template - struct vector_init_list - { - std::vector& v; +//-------------------------------------------------------------------------------- +template +inline void to_dense_1st_01(const SparseVector &x, OutIt y, OutIt y_end) { + typedef typename std::iterator_traits::value_type value_type; - inline vector_init_list(std::vector& v_ref) : v(v_ref) {} - inline vector_init_list(const vector_init_list& o) : v(o.v) {} + std::fill(y, y_end, (value_type)0); - inline vector_init_list& operator=(const vector_init_list& o) - { v(o.v); return *this; } + for (size_t i = 0; i != x.nnz; ++i) + y[x[i].first] = (value_type)1; +} - template - inline vector_init_list& operator,(const T2& x) - { - v.push_back((T)x); - return *this; - } - }; +//-------------------------------------------------------------------------------- +template +inline void to_dense_01(size_t n, const std::vector &buffer, OutIt y, + OutIt y_end) { + NTA_ASSERT(n <= buffer.size()); - //-------------------------------------------------------------------------------- - template - inline vector_init_list operator+=(std::vector& v, const T2& x) - { - v.push_back((T)x); - return vector_init_list(v); + typedef typename std::iterator_traits::value_type value_type; + + std::fill(y, y_end, (value_type)0); + + const T *b = &buffer[0], *b_end = b + n; + for (; b != b_end; ++b) { + NTA_ASSERT(*b < (size_t)(y_end - y)); + y[*b] = (value_type)1; } +} - //-------------------------------------------------------------------------------- - // TODO: merge with preceding by changing parametrization? - //-------------------------------------------------------------------------------- - template - struct set_init_list +//-------------------------------------------------------------------------------- +/** + * Converts a sparse range described with indices and values to a dense + * range. 
+ */ +template +inline void to_dense(It1 ind, It1 ind_end, It2 nz, It2 nz_end, It3 dense, + It3 dense_end) { { - std::set& v; - - inline set_init_list(std::set& v_ref) : v(v_ref) {} - inline set_init_list(const set_init_list& o) : v(o.v) {} + NTA_ASSERT(ind <= ind_end) << "to_dense: Mismatched ind iterators"; + NTA_ASSERT(dense <= dense_end) << "to_dense: Mismatched dense iterators"; + NTA_ASSERT(ind_end - ind <= dense_end - dense) + << "to_dense: Not enough memory"; + NTA_ASSERT(nz_end - nz == ind_end - ind) + << "to_dense: Mismatched ind and nz ranges"; + } - inline set_init_list& operator=(const set_init_list& o) - { v(o.v); return *this; } + typedef typename std::iterator_traits::value_type value_type; - template - inline set_init_list& operator,(const T2& x) - { - v.insert((T)x); - return *this; - } - }; + std::fill(dense, dense + (ind_end - ind), (value_type)0); - //-------------------------------------------------------------------------------- - template - inline set_init_list operator+=(std::set& v, const T2& x) - { - v.insert((T)x); - return set_init_list(v); - } + for (; ind != ind_end; ++ind, ++nz) + *(dense + *ind) = *nz; +} - //-------------------------------------------------------------------------------- - // FIND IN VECTOR - //-------------------------------------------------------------------------------- - // T1 and T2 to get around constness with pointers - template - inline int find_index(const T1& x, const std::vector& v) - { - for (size_t i = 0; i != v.size(); ++i) - if (v[i] == x) - return (int) i; - return -1; +//-------------------------------------------------------------------------------- +/** + * Needs non-zero indices to be sorted! + */ +template +inline void in_place_sparse_to_dense_01(int n, It begin, It end) { + for (int i = n - 1; i >= 0; --i) { + int p = (int)*(begin + i); + std::fill(begin + p, end, 0); + *(begin + p) = 1; + end = begin + p; } - //-------------------------------------------------------------------------------- - template - inline int find_index(const T1& x, const std::vector >& v) - { - for (size_t i = 0; i != v.size(); ++i) - if (v[i].first == x) - return (int) i; - return -1; - } + std::fill(begin, end, 0); +} - //-------------------------------------------------------------------------------- - template - inline bool not_in(const T& x, const std::vector& v) - { - return std::find(v.begin(), v.end(), x) == v.end(); - } +//-------------------------------------------------------------------------------- +/** + * Pb with size of the vectors? + */ +template +inline void in_place_sparse_to_dense_01(int n, std::vector &x) { + in_place_sparse_to_dense_01(n, x.begin(), x.end()); +} + +//-------------------------------------------------------------------------------- +/** + * Converts a sparse range stored in a dense vector into an (index,value) + * representation. 
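to_dense() above scatters an (indices, values) pair into a dense range, and from_dense(), documented here, inverts that by keeping entries whose magnitude exceeds an epsilon. A standalone round-trip sketch in plain STL (the epsilon value is arbitrary):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Sparse form: matching index and value arrays.
  std::vector<std::size_t> ind = {1, 3};
  std::vector<float> nz = {3.5f, 2.0f};

  // Sparse -> dense: scatter the values at their indices.
  std::vector<float> dense(5, 0.0f);
  for (std::size_t k = 0; k != ind.size(); ++k)
    dense[ind[k]] = nz[k]; // dense = {0, 3.5, 0, 2, 0}

  // Dense -> sparse: keep (index, value) for entries above a small epsilon.
  std::vector<std::size_t> ind2;
  std::vector<float> nz2;
  const float eps = 1e-6f;
  for (std::size_t i = 0; i != dense.size(); ++i)
    if (dense[i] > eps || dense[i] < -eps) {
      ind2.push_back(i);
      nz2.push_back(dense[i]);
    }

  std::cout << ind2.size() << " non-zeros\n"; // prints: 2 non-zeros
  return 0;
}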
+ * + * @param begin + * @param end + * @param ind + * @param nz + * @param eps + */ +template +inline void from_dense( + It1 begin, It1 end, It2 ind, It3 nz, + typename std::iterator_traits::value_type eps = nupic::Epsilon) { + { NTA_ASSERT(begin <= end) << "from_dense: Mismatched dense iterators"; } - //-------------------------------------------------------------------------------- - template - inline bool not_in(const T1& x, const std::vector >& v) - { - typename std::vector >::const_iterator it; - for (it = v.begin(); it != v.end(); ++it) - if (it->first == x) - return false; - return true; - } + typedef size_t size_type; + typedef typename std::iterator_traits::value_type value_type; - //-------------------------------------------------------------------------------- - template - inline bool not_in(const T& x, const std::set& s) - { - return s.find(x) == s.end(); - } + Abs abs_f; - //-------------------------------------------------------------------------------- - template - inline bool is_in(const T& x, const std::vector& v) - { - return ! not_in(x, v); + for (It1 it = begin; it != end; ++it) { + value_type val = *it; + if (abs_f(val) > eps) { + *ind = (size_type)(it - begin); + *nz = val; + ++ind; + ++nz; + } } +} - //-------------------------------------------------------------------------------- - template - inline bool is_in(const T1& x, const std::vector >& v) - { - return ! not_in(x, v); - } +//-------------------------------------------------------------------------------- +template +inline void from_dense(It begin, It end, Buffer &buffer) { + NTA_ASSERT((size_t)(end - begin) <= buffer.size()); - //-------------------------------------------------------------------------------- - template - inline bool is_in(const T& x, const std::set& s) - { - return ! not_in(x, s); - } + typename Buffer::iterator it2 = buffer.begin(); - //-------------------------------------------------------------------------------- - template - inline bool is_sorted(It begin, It end, bool ascending =true, bool unique =true) - { - if (begin < end) { - for (It prev = begin, it = ++begin; it < end; ++it, ++prev) - - if (ascending) { - if (unique) { - if (*prev >= *it) - return false; - } else { - if (*prev > *it) - return false; - } - } else { - if (unique) { - if (*prev <= *it) - return false; - } else { - if (*prev < *it) - return false; - } - } + for (It it = begin; it != end; ++it) + if (*it != 0) { + *it2++ = (T)(it - begin); } - return true; - } - - //-------------------------------------------------------------------------------- - template - inline bool is_sorted(const std::vector& x, bool ascending =true, bool unique =true) - { - return is_sorted(x.begin(), x.end(), ascending, unique); - } + buffer.nnz = it2 - buffer.begin(); +} - //-------------------------------------------------------------------------------- - template - inline bool operator==(const std::vector& a, const std::vector& b) - { - NTA_ASSERT(a.size() == b.size()); - if (a.size() != b.size()) - return false; - for (size_t i = 0; i != a.size(); ++i) - if (a[i] != b[i]) - return false; - return true; +//-------------------------------------------------------------------------------- +// erase from vector +//-------------------------------------------------------------------------------- +/** + * Erases a value from a vector. + * + * The STL process to really remove a value from a vector is tricky. 
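The remove() helper documented here wraps the erase-remove idiom that the comment calls tricky: std::remove only shifts the kept elements forward, and the trailing erase() is what actually shrinks the vector. Spelled out standalone:

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> v = {1, 2, 3, 2, 4};
  v.erase(std::remove(v.begin(), v.end(), 2), v.end());
  for (int x : v)
    std::cout << x << ' '; // prints: 1 3 4
  std::cout << '\n';
  return 0;
}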
+ * + * @param v the vector + * @param val the value to remove + */ +template inline void remove(const T &del, std::vector &v) { + v.erase(std::remove(v.begin(), v.end(), del), v.end()); +} + +//-------------------------------------------------------------------------------- +template +inline void remove(const std::vector &del, std::vector &b) { + for (size_t i = 0; i != del.size(); ++i) + remove(del[i], b); +} + +//-------------------------------------------------------------------------------- +template +inline void remove_for_pairs(const T1 &key, std::vector> &v) { + typename std::vector>::const_iterator it; + for (it = v.begin(); it != v.end() && it->first != key; ++it) + ; + remove(*it, v); +} + +//-------------------------------------------------------------------------------- +template +inline void remove_from_end(const T &elt, std::vector &a) { + for (int i = a.size() - 1; i >= 0; --i) { + if (a[i] == elt) { + for (size_t j = i; j < a.size() - 1; ++j) + a[j] = a[j + 1]; + a.resize(a.size() - 1); + return; + } } +} - //-------------------------------------------------------------------------------- - template - inline bool operator!=(const std::vector& a, const std::vector& b) - { - return !(a == b); +//-------------------------------------------------------------------------------- +/** + * Given a vector of indices, removes the elements of a at those indices + * (indices before any removal is carried out), where a is a vector of pairs. + * + * Need to pass in non-empty vector of sorted, unique indices to delete. + */ +// TODO: remove this? Should be covered just below?? +template +inline void remove_for_pairs(const std::vector &del, + std::vector> &a) { + NTA_ASSERT(std::set(del.begin(), del.end()).size() == del.size()); + + if (del.empty()) + return; + + size_t old = del[0] + 1, cur = del[0], d = 1; + + while (old < a.size() && d < del.size()) { + if (old == (size_t)del[d]) { + ++d; + ++old; + } else if ((size_t)del[d] < old) { + ++d; + } else { + a[cur++] = a[old++]; + } } - //-------------------------------------------------------------------------------- - template - inline bool operator==(const std::map& a, const std::map& b) - { - typename std::map::const_iterator ita = a.begin(), itb = b.begin(); - for (; ita != a.end(); ++ita, ++itb) - if (ita->first != itb->first || ita->second != itb->second) - return false; - return true; - } + while (old < a.size()) + a[cur++] = a[old++]; - //-------------------------------------------------------------------------------- - template - inline bool operator!=(const std::map& a, const std::map& b) - { - return !(a == b); - } + a.resize(a.size() - del.size()); +} - //-------------------------------------------------------------------------------- - /** - * Proxy for an insert iterator that allows inserting at the second element - * when iterating over a container of pairs. - */ - template - struct inserter_second - { - typedef typename std::iterator_traits::value_type pair_type; - typedef typename pair_type::second_type second_type; - typedef second_type value_type; +//-------------------------------------------------------------------------------- +/** + * Remove several elements from a vector, the elements to remove being specified + * by their index (in del). After this call, a's size is reduced. Requires + * default constructor on T to be defined (for resize). O(n). 
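remove_at(), documented here, compacts a vector in a single pass given a sorted, unique list of indices to drop. A standalone sketch of the same compaction on concrete data (not the nupic function itself):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> a = {10, 11, 12, 13, 14};
  std::vector<std::size_t> del = {1, 3}; // indices to drop (sorted, unique)

  std::size_t cur = 0, d = 0;
  for (std::size_t old = 0; old < a.size(); ++old) {
    if (d < del.size() && del[d] == old)
      ++d;               // skip this element
    else
      a[cur++] = a[old]; // keep it, compacting in place
  }
  a.resize(cur);

  for (int x : a)
    std::cout << x << ' '; // prints: 10 12 14
  std::cout << '\n';
  return 0;
}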
+ */ +template +inline void remove_at(const std::vector &del, std::vector &a) { + NTA_ASSERT(std::set(del.begin(), del.end()).size() == del.size()); - Iterator it; + if (del.empty()) + return; - inline inserter_second(Iterator _it) : it(_it) {} - inline second_type& operator*() { return it->second; } - inline void operator++() { ++it; } - }; + size_t old = del[0] + 1, cur = del[0], d = 1; - template - inserter_second insert_2nd(Iterator it) - { - return inserter_second(it); + while (old < a.size() && d < del.size()) { + if (old == (size_t)del[d]) { + ++d; + ++old; + } else if ((size_t)del[d] < old) { + ++d; + } else { + a[cur++] = a[old++]; + } } - //-------------------------------------------------------------------------------- - /** - * Proxy for an insert iterator that allows inserting at the second element when - * iterating over a container of pairs, while setting the first element to the - * current index value (watch out if iterator passed to constructor is not - * pointing to the beginning of the container!) - */ - template - struct inserter_second_incrementer_first - { - typedef typename std::iterator_traits::value_type pair_type; - typedef typename pair_type::second_type second_type; - typedef second_type value_type; - - Iterator it; - size_t i; + while (old < a.size()) + a[cur++] = a[old++]; - inline inserter_second_incrementer_first(Iterator _it) - : it(_it), i(0) {} - inline second_type& operator*() { return it->second; } - inline void operator++() { it->first = i++; ++it; } - }; + a.resize(a.size() - del.size()); +} - template - inserter_second_incrementer_first insert_2nd_inc(Iterator it) - { - return inserter_second_incrementer_first(it); - } +//-------------------------------------------------------------------------------- +/** + * Finds index of elt in ref, and removes corresponding element of a. + */ +template +inline void remove(const T2 &elt, std::vector &a, + const std::vector &ref) { + a.erase(a.begin() + find_index(elt, ref)); +} + +//-------------------------------------------------------------------------------- +template +inline void remove(const std::vector &del, std::set &a) { + for (size_t i = 0; i != del.size(); ++i) + a.erase(del[i]); +} + +//-------------------------------------------------------------------------------- +template +inline void remove(const std::set &y, std::vector &x) { + std::vector del; + + for (size_t i = 0; i != x.size(); ++i) + if (y.find(x[i]) != y.end()) { + NTA_ASSERT(not_in(x[i], del)); + del.push_back(x[i]); + } + + nupic::remove(del, x); +} + +//-------------------------------------------------------------------------------- +template +inline std::set &operator-=(std::set &a, const std::vector &b) { + remove(b, a); + return a; +} + +//-------------------------------------------------------------------------------- +template +inline std::vector &operator-=(std::vector &a, const std::vector &b) { + remove(b, a); + return a; +} + +//-------------------------------------------------------------------------------- +// DIFFERENCES +//-------------------------------------------------------------------------------- +/** + * Returns a vector that contains the indices of the positions where x and y + * have different values. 
+ */ +template +inline void find_all_differences(const std::vector &x, + const std::vector &y, + std::vector &diffs) { + NTA_ASSERT(x.size() == y.size()); + diffs.clear(); + for (size_t i = 0; i != x.size(); ++i) + if (x[i] != y[i]) + diffs.push_back(i); +} + +//-------------------------------------------------------------------------------- +// fill +//-------------------------------------------------------------------------------- +/** + * Fills a container with the given value. + * + * @param a + * @param val + */ +template +inline void fill(T &a, const typename T::value_type &val) { + typename T::iterator i = a.begin(), e = a.end(); - //-------------------------------------------------------------------------------- - template - inline T2 dot(const std::vector& x, const Buffer& y) - { - size_t n1 = x.size(), n2 = y.nnz, i1 = 0, i2 = 0; - T2 s = 0; - - while (i1 != n1 && i2 != n2) - if (x[i1] < y[i2]) { - ++i1; - } else if (y[i2] < x[i1]) { - ++i2; - } else { - ++s; - ++i1; ++i2; - } + for (; i != e; ++i) + *i = val; +} - return s; - } +//-------------------------------------------------------------------------------- +/** + * Zeroes out a range. + * + * @param begin + * @param end + */ +template inline void zero(It begin, It end) { + { NTA_ASSERT(begin <= end) << "zero: Invalid input range"; } - //-------------------------------------------------------------------------------- - inline float dot(const float* x, const float* x_end, const float* y) - { - float result = 0; - for (; x != x_end; ++x, ++y) - result += *x * *y; - return result; - } + typedef typename std::iterator_traits::value_type T; - //-------------------------------------------------------------------------------- - // copy - //-------------------------------------------------------------------------------- - template - inline void copy(It1 begin, It1 end, It2 out_begin, It2 out_end) - { - std::copy(begin, end, out_begin); - } + for (; begin != end; ++begin) + *begin = T(0); +} - //-------------------------------------------------------------------------------- - /** - * Copies a whole container into another. - * - * Does not allocate memory for b: b needs to have enough room for - * a.size() elements. - * - * @param a the source container - * @param b the destination container - */ - template - inline void copy(const T1& a, T2& b) - { - b.resize(a.size()); - copy(a.begin(), a.end(), b.begin(), b.end()); - } +//-------------------------------------------------------------------------------- +/** + * Zeroes out a whole container. 
+ * + * @param a the container + */ +template inline void zero(T &a) { zero(a.begin(), a.end()); } - //-------------------------------------------------------------------------------- - template - inline void copy(const std::vector& a, size_t n, std::vector& b, size_t o =0) - { - NTA_ASSERT(o + n <= b.size()); - std::copy(a.begin(), a.begin() + n, b.begin() + o); - } +//-------------------------------------------------------------------------------- +template inline void set_to_zero(T &a) { zero(a); } - //-------------------------------------------------------------------------------- - template - inline void copy(const std::vector& a, size_t i, size_t j, std::vector& b) - { - std::copy(a.begin() + i, a.begin() + j, b.begin() + i); - } +//-------------------------------------------------------------------------------- +template +inline void set_to_zero(std::vector &a, size_t begin, size_t end) { + zero(a.begin() + begin, a.begin() + end); +} - //-------------------------------------------------------------------------------- - template - inline void copy(const std::vector& a, std::vector& b, size_t offset) - { - NTA_ASSERT(offset + a.size() <= b.size()); - std::copy(a.begin(), a.end(), b.begin() + offset); - } +//-------------------------------------------------------------------------------- +/** + * Fills a range with ones. + * + * @param begin + * @param end + */ +template inline void ones(It begin, It end) { + { NTA_ASSERT(begin <= end) << "ones: Invalid input range"; } - //-------------------------------------------------------------------------------- - template - inline void copy_indices(const SparseVector& x, Buffer& y) - { - NTA_ASSERT(x.nnz <= y.size()); + typedef typename std::iterator_traits::value_type T; - for (size_t i = 0; i != x.nnz; ++i) - y[i] = x[i].first; - y.nnz = x.nnz; - } + for (; begin != end; ++begin) + *begin = T(1); +} - //-------------------------------------------------------------------------------- - // TO DENSE - //-------------------------------------------------------------------------------- - template - inline void to_dense_01(It1 ind, It1 ind_end, It2 dense, It2 dense_end) - { - { - NTA_ASSERT(ind <= ind_end) - << "to_dense: Mismatched iterators"; - NTA_ASSERT(dense <= dense_end) - << "to_dense: Mismatched iterators"; - NTA_ASSERT(ind_end - ind <= dense_end - dense) - << "to_dense: Not enough memory"; - } - - typedef typename std::iterator_traits::value_type value_type; +//-------------------------------------------------------------------------------- +/** + * Fills a container with ones. + * + * @param a the container + */ +template inline void ones(T &a) { ones(a.begin(), a.end()); } - // TODO: make faster with single pass? 
- // (but if's for all the elements might be slower) - std::fill(dense, dense_end, (value_type) 0); +//-------------------------------------------------------------------------------- +template inline void set_to_one(std::vector &a) { ones(a); } - for (; ind != ind_end; ++ind) - *(dense + *ind) = (value_type) 1; - } +//-------------------------------------------------------------------------------- +template +inline void set_to_one(std::vector &a, size_t begin, size_t end) { + ones(a.begin() + begin, a.begin() + end); +} - //-------------------------------------------------------------------------------- - template - inline void to_dense_01(const std::vector& sparse, std::vector& dense) +//-------------------------------------------------------------------------------- +/** + * Sets a range to 0, except for a single value at pos, which will be equal to + * val. + * + * @param pos the position of the single non-zero value + * @param begin + * @param end + * @param val the value of the non-zero value in the range + */ +template +inline void dirac(size_t pos, It begin, It end, + typename std::iterator_traits::value_type val = 1) { { - to_dense_01(sparse.begin(), sparse.end(), dense.begin(), dense.end()); + NTA_ASSERT(begin <= end) << "dirac: Invalid input range"; + + NTA_ASSERT(0 <= pos && pos < (size_t)(end - begin)) + << "dirac: Invalid position: " << pos + << " - Should be between 0 and: " << (size_t)(end - begin); } - //-------------------------------------------------------------------------------- - template - inline void to_dense_01(const Buffer& buffer, OutIt y, OutIt y_end) - { - typedef typename std::iterator_traits::value_type value_type; + typedef typename std::iterator_traits::value_type value_type; - std::fill(y, y_end, (value_type) 0); + std::fill(begin, end, (value_type)0); + *(begin + pos) = val; +} - for (size_t i = 0; i != buffer.nnz; ++i) - y[buffer[i]] = (value_type) 1; +//-------------------------------------------------------------------------------- +/** + * Sets a range to 0, except for a single value at pos, which will be equal to + * val. + * + * @param pos the position of the single non-zero value + * @param c the container + * @param val the value of the Dirac + */ +template +inline void dirac(size_t pos, C &c, typename C::value_type val = 1) { + { + NTA_ASSERT(pos >= 0 && pos < c.size()) + << "dirac: Can't set Dirac at pos: " << pos + << " when container has size: " << c.size(); } - //-------------------------------------------------------------------------------- - template - inline void to_dense_01(It begin, It end, std::vector& dense) + dirac(pos, c.begin(), c.end(), val); +} + +//-------------------------------------------------------------------------------- +/** + * Computes the CDF of the given range seen as a discrete PMF. 
+ * + * @param begin1 the beginning of the discrete PMF range + * @param end1 one past the end of the discrete PMF range + * @param begin2 the beginning of the CDF range + */ +template +inline void cumulative(It1 begin1, It1 end1, It2 begin2, It2 end2) { { - to_dense_01(begin, end, dense.begin(), dense.end()); + NTA_ASSERT(begin1 <= end1) << "cumulative: Invalid input range"; + NTA_ASSERT(begin2 <= end2) << "cumulative: Invalid output range"; + NTA_ASSERT(end1 - begin1 == end2 - begin2) + << "cumulative: Incompatible sizes"; } - //-------------------------------------------------------------------------------- - template - inline void to_dense_1st_01(const SparseVector& x, OutIt y, OutIt y_end) - { - typedef typename std::iterator_traits::value_type value_type; + typedef typename std::iterator_traits::value_type value_type; - std::fill(y, y_end, (value_type) 0); + It2 prev = begin2; + *begin2++ = (value_type)*begin1++; + for (; begin1 < end1; ++begin1, ++begin2, ++prev) + *begin2 = *prev + (value_type)*begin1; +} - for (size_t i = 0; i != x.nnz; ++i) - y[x[i].first] = (value_type) 1; +//-------------------------------------------------------------------------------- +/** + * Computes the CDF of a discrete PMF. + * + * @param pmf the PMF + * @param cdf the CDF + */ +template +inline void cumulative(const C1 &pmf, C2 &cdf) { + cumulative(pmf.begin(), pmf.end(), cdf.begin(), cdf.end()); +} + +//-------------------------------------------------------------------------------- +/** + * Finds percentiles. + */ +template +inline void percentiles(size_t n_percentiles, It1 begin1, It1 end1, It2 begin2, + It2 end2, bool alreadyNormalized = false) { + { + NTA_ASSERT(begin1 <= end1) << "percentiles: Invalid input range"; + NTA_ASSERT(begin2 <= end2) << "percentiles: Invalid output range"; + NTA_ASSERT(end1 - begin1 == end2 - begin2) + << "percentiles: Mismatched ranges"; } - //-------------------------------------------------------------------------------- - template - inline void - to_dense_01(size_t n, const std::vector& buffer, OutIt y, OutIt y_end) - { - NTA_ASSERT(n <= buffer.size()); + typedef typename std::iterator_traits::value_type value_type; + typedef typename std::iterator_traits::value_type size_type; - typedef typename std::iterator_traits::value_type value_type; + value_type n = (value_type)(alreadyNormalized ? 1.0f : 0.0f); - std::fill(y, y_end, (value_type) 0); + if (!alreadyNormalized) + for (It1 it = begin1; it != end1; ++it) + n += *it; - const T* b = &buffer[0], *b_end = b + n; - for (; b != b_end; ++b) { - NTA_ASSERT(*b < (size_t)(y_end - y)); - y[*b] = (value_type) 1; - } + value_type increment = n / value_type(n_percentiles); + value_type sum = (value_type)0.0f; + size_type p = (size_type)0; + + for (value_type v = increment; v < n; v += increment) { + for (; sum < v; ++p) + sum += *begin1++; + *begin2++ = p; } +} + +//-------------------------------------------------------------------------------- +template +inline void percentiles(size_t n_percentiles, const C1 &pmf, C2 &pcts) { + percentiles(n_percentiles, pmf.begin(), pmf.end(), pcts.begin()); +} - //-------------------------------------------------------------------------------- - /** - * Converts a sparse range described with indices and values to a dense - * range. 
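cumulative() above is a running sum that turns a discrete PMF into its CDF, and percentiles() then walks that sum in equal increments. The running-sum step is the same computation as std::partial_sum:

#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<float> pmf = {0.1f, 0.2f, 0.3f, 0.4f};
  std::vector<float> cdf(pmf.size());
  std::partial_sum(pmf.begin(), pmf.end(), cdf.begin());
  for (float c : cdf)
    std::cout << c << ' '; // approximately: 0.1 0.3 0.6 1
  std::cout << '\n';
  return 0;
}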
- */ - template - inline void to_dense(It1 ind, It1 ind_end, It2 nz, It2 nz_end, - It3 dense, It3 dense_end) +//-------------------------------------------------------------------------------- +template +inline void rand_range( + It begin, It end, const typename std::iterator_traits::value_type &min_, + const typename std::iterator_traits::value_type &max_, RNG &rng) { { - { - NTA_ASSERT(ind <= ind_end) - << "to_dense: Mismatched ind iterators"; - NTA_ASSERT(dense <= dense_end) - << "to_dense: Mismatched dense iterators"; - NTA_ASSERT(ind_end - ind <= dense_end - dense) - << "to_dense: Not enough memory"; - NTA_ASSERT(nz_end - nz == ind_end - ind) - << "to_dense: Mismatched ind and nz ranges"; - } + NTA_ASSERT(begin <= end) << "rand_range: Invalid input range"; + NTA_ASSERT(min_ < max_) + << "rand_range: Invalid min/max: " << min_ << " " << max_; + } - typedef typename std::iterator_traits::value_type value_type; + typedef typename std::iterator_traits::value_type value_type; - std::fill(dense, dense + (ind_end - ind), (value_type) 0); + double range = double(max_ - min_) / double(rng.max() - rng.min()); + for (; begin != end; ++begin) + *begin = value_type(double(rng()) * range + min_); +} - for (; ind != ind_end; ++ind, ++nz) - *(dense + *ind) = *nz; +//-------------------------------------------------------------------------------- +/** + * Initializes a range with random values. + * + * @param begin + * @param end + * @param min_ + * @param max_ + */ +template +inline void +rand_range(It begin, It end, + const typename std::iterator_traits::value_type &min_, + const typename std::iterator_traits::value_type &max_) { + nupic::Random rng; + rand_range(begin, end, min_, max_, rng); +} + +//-------------------------------------------------------------------------------- +template +inline void rand_range(T &a, const typename T::value_type &min, + const typename T::value_type &max, RNG &rng) { + rand_range(a.begin(), a.end(), min, max, rng); +} + +//-------------------------------------------------------------------------------- +/** + * Initializes a container with random values. + * + * @param a the container + * @param min + * @param max + */ +template +inline void rand_range(T &a, const typename T::value_type &min, + const typename T::value_type &max) { + rand_range(a.begin(), a.end(), min, max); +} + +//-------------------------------------------------------------------------------- +template +inline void rand_float_range(std::vector &x, size_t start, size_t end, + RNG &rng) { + for (size_t i = start; i != end; ++i) + x[i] = (float)rng.getReal64(); +} + +//-------------------------------------------------------------------------------- +/** + * Initializes a range with the normal distribution. 
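rand_range() above scales raw RNG draws into [min_, max_). The standard <random> facilities package the same scaling as a distribution object; a sketch using std::mt19937 instead of nupic::Random:

#include <iostream>
#include <random>
#include <vector>

int main() {
  std::mt19937 rng(42); // fixed seed for a repeatable example
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f); // [min_, max_)
  std::vector<float> x(8);
  for (float &v : x)
    v = dist(rng);
  for (float v : x)
    std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}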
+ * + * @param begin + * @param end + * @param mean + * @param stddev + */ +template +inline void +normal_range(It begin, It end, + const typename std::iterator_traits::value_type &mean, + const typename std::iterator_traits::value_type &stddev) { + { NTA_ASSERT(begin <= end) << "normal_range: Invalid input range"; } + + // TODO implement numerical recipes' method +} + +//-------------------------------------------------------------------------------- +template +inline void rand_range_01(It begin, It end, double pct, RNG &rng) { + { + NTA_ASSERT(begin <= end) << "rand_range_01: Invalid input range"; + NTA_ASSERT(0 <= pct && pct < 1) + << "rand_range_01: Invalid threshold: " << pct + << " - Should be between 0 and 1"; } - //-------------------------------------------------------------------------------- - /** - * Needs non-zero indices to be sorted! - */ - template - inline void in_place_sparse_to_dense_01(int n, It begin, It end) - { - for (int i = n - 1; i >= 0; --i) { - int p = (int) *(begin + i); - std::fill(begin + p, end, 0); - *(begin + p) = 1; - end = begin + p; - } + typedef typename std::iterator_traits::value_type value_type; - std::fill(begin, end, 0); - } + for (; begin != end; ++begin) + *begin = (value_type)(double(rng()) / double(rng.max() - rng.min()) > pct); +} - //-------------------------------------------------------------------------------- - /** - * Pb with size of the vectors? - */ - template - inline void in_place_sparse_to_dense_01(int n, std::vector& x) - { - in_place_sparse_to_dense_01(n, x.begin(), x.end()); - } +//-------------------------------------------------------------------------------- +/** + * Initializes a range with a random binary vector. + * + * @param begin + * @param end + * @param pct the percentage of ones + */ +template +inline void rand_range_01(It begin, It end, double pct = .5) { + nupic::Random rng; + rand_range_01(begin, end, pct, rng); +} + +//-------------------------------------------------------------------------------- +template +inline void rand_range_01(T &a, double pct, RNG &rng) { + rand_range_01(a.begin(), a.end(), pct, rng); +} + +//-------------------------------------------------------------------------------- +/** + * Initializes a container with a random binary vector. + * + * @param a the container + * @param pct the percentage of ones + */ +template inline void rand_range_01(T &a, double pct = .5) { + nupic::Random rng; + rand_range_01(a.begin(), a.end(), pct, rng); +} + +//-------------------------------------------------------------------------------- +/** + * Initializes a range with a ramp function. + * + * @param begin + * @param end + * @param start the start value of the ramp + * @param step the step of the ramp + */ +template +inline void ramp_range(It begin, It end, T start = 0, T step = 1) { + { NTA_ASSERT(begin <= end) << "ramp_range: Invalid input range"; } - //-------------------------------------------------------------------------------- - /** - * Converts a sparse range stored in a dense vector into an (index,value) - * representation. - * - * @param begin - * @param end - * @param ind - * @param nz - * @param eps - */ - template - inline void - from_dense(It1 begin, It1 end, It2 ind, It3 nz, - typename std::iterator_traits::value_type eps = nupic::Epsilon) + for (; begin != end; ++begin, start += step) + *begin = start; +} + +//-------------------------------------------------------------------------------- +/** + * Initializes a range with a ramp function. 
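rand_range_01() above fills a range with a random 0/1 pattern controlled by pct (documented as the percentage of ones). A standalone sketch of producing such a binary vector with <random>; the 10% figure is arbitrary:

#include <iostream>
#include <random>
#include <vector>

int main() {
  std::mt19937 rng(42);
  std::bernoulli_distribution bit(0.1); // roughly 10% ones
  std::vector<int> x(20);
  for (int &v : x)
    v = bit(rng) ? 1 : 0;
  for (int v : x)
    std::cout << v;
  std::cout << '\n';
  return 0;
}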
+ * + * @param a the container + * @param start the start value of the ramp + * @param step the step of the ramp + */ +template +inline void ramp_range(T &a, typename T::value_type start = 0, + typename T::value_type step = 1) { + ramp_range(a.begin(), a.end(), start); +} + +//-------------------------------------------------------------------------------- +/** + * Fills a range with values taken randomly from another range. + * + * @param begin + * @param end + * @param enum_begin the values to draw from + * @param enum_end the values to draw from + * @param replace whether to draw with or without replacements + */ +template +inline void rand_enum_range(It1 begin, It1 end, It2 enum_begin, It2 enum_end, + bool replace, RNG &rng) { { - { - NTA_ASSERT(begin <= end) - << "from_dense: Mismatched dense iterators"; - } + NTA_ASSERT(begin <= end) << "rand_enum_range: Invalid input range"; + + NTA_ASSERT(enum_begin <= enum_end) + << "rand_enum_range: Invalid values range"; + } - typedef size_t size_type; - typedef typename std::iterator_traits::value_type value_type; + typedef typename std::iterator_traits::value_type value_type; - Abs abs_f; + size_t n = (size_t)(enum_end - enum_begin); - for (It1 it = begin; it != end; ++it) { - value_type val = *it; - if (abs_f(val) > eps) { - *ind = (size_type) (it - begin); - *nz = val; - ++ind; ++nz; - } - } - } + if (replace) { - //-------------------------------------------------------------------------------- - template - inline void from_dense(It begin, It end, Buffer& buffer) - { - NTA_ASSERT((size_t)(end - begin) <= buffer.size()); + for (; begin != end; ++begin) + *begin = (value_type) * (enum_begin + rng() % n); - typename Buffer::iterator it2 = buffer.begin(); + } else { - for (It it = begin; it != end; ++it) - if (*it != 0) { - *it2++ = (T)(it - begin); - } + std::vector ind(n); + ramp_range(ind); - buffer.nnz = it2 - buffer.begin(); + for (; begin != end; ++begin) { + size_t p = rng() % ind.size(); + *begin = (value_type) * (enum_begin + p); + remove(p, ind); + } } +} - //-------------------------------------------------------------------------------- - // erase from vector - //-------------------------------------------------------------------------------- - /** - * Erases a value from a vector. - * - * The STL process to really remove a value from a vector is tricky. - * - * @param v the vector - * @param val the value to remove - */ - template - inline void remove(const T& del, std::vector& v) +//-------------------------------------------------------------------------------- +/** + * Fills a range with values taken randomly from another range. 
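rand_enum_range() above draws values from another range, with or without replacement. For the without-replacement case, C++17's std::sample expresses a similar idea (it preserves the relative order of the chosen values, which the helper above does not guarantee):

#include <algorithm> // std::sample (C++17)
#include <iostream>
#include <random>
#include <vector>

int main() {
  std::vector<int> values = {2, 3, 5, 7, 11, 13};
  std::vector<int> out(3);
  std::mt19937 rng(42);
  std::sample(values.begin(), values.end(), out.begin(), out.size(), rng);
  for (int v : out)
    std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}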
+ * + * @param begin + * @param end + * @param enum_begin the values to draw from + * @param enum_end the values to draw from + * @param replace whether to draw with or without replacements + */ +template +inline void rand_enum_range(It1 begin, It1 end, It2 enum_begin, It2 enum_end, + bool replace = false) { { - v.erase(std::remove(v.begin(), v.end(), del), v.end()); + NTA_ASSERT(begin <= end) << "rand_enum_range: Invalid input range"; + NTA_ASSERT(enum_begin <= enum_end) << "rand_enum_range: Invalid enum range"; } - //-------------------------------------------------------------------------------- - template - inline void remove(const std::vector& del, std::vector& b) + nupic::Random rng; + rand_enum_range(begin, end, enum_begin, enum_end, replace, rng); +} + +//-------------------------------------------------------------------------------- +/** + * Fills a container with values taken randomly from another container. + * + * @param a the container to fill + * @param b the container of the values to use + * @param replace whether the draw with replacements or not + */ +template +inline void rand_enum_range(C1 &a, const C2 &b, bool replace = false) { + rand_enum_range(a.begin(), a.end(), b.begin(), b.end(), replace); +} + +//-------------------------------------------------------------------------------- +template +inline void rand_enum_range(C1 &a, const C2 &b, bool replace, RNG &rng) { + rand_enum_range(a.begin(), a.end(), b.begin(), b.end(), replace, rng); +} + +//-------------------------------------------------------------------------------- +/** + * Fills a range with a random permutation of a ramp. + * + * @param begin + * @param end + * @param start the start value of the ramp + * @param step the step value of the ramp + */ +template +inline void random_perm_interval( + It begin, It end, typename std::iterator_traits::value_type start, + typename std::iterator_traits::value_type step, RNG &rng) { + { NTA_ASSERT(begin <= end) << "random_perm_interval 1: Invalid input range"; } + + ramp_range(begin, end, start, step); + std::random_shuffle(begin, end, rng); +} + +//-------------------------------------------------------------------------------- +/** + * Fills a range with a random permutation of a ramp. + * + * @param begin + * @param end + * @param start the start value of the ramp + * @param step the step value of the ramp + */ +template +inline void +random_perm_interval(It begin, It end, + typename std::iterator_traits::value_type start = 0, + typename std::iterator_traits::value_type step = 1) { + { NTA_ASSERT(begin <= end) << "random_perm_interval 2: Invalid input range"; } + + nupic::Random rng; + random_perm_interval(begin, end, start, step, rng); +} + +//-------------------------------------------------------------------------------- +/** + * Fills a container with a random permutation of a ramp. 
+ * + * @param c the container to fill + * @param start the start value of the ramp + * @param step the step value of the ramp + */ +template +inline void random_perm_interval(C &c, typename C::value_type start = 0, + typename C::value_type step = 1) { + random_perm_interval(c.begin(), c.end(), start, step); +} + +//-------------------------------------------------------------------------------- +template +inline void random_perm_interval(C &c, typename C::value_type start, + typename C::value_type step, RNG &rng) { + random_perm_interval(c.begin(), c.end(), start, step, rng); +} + +//-------------------------------------------------------------------------------- +/** + * Draws a random sample from a range, without replacements. + * + * Currently assumes that the first range is larger than the second. + * + * @param begin1 + * @param end1 + * @param begin2 + * @param end2 + */ +template +inline void random_sample(It1 begin1, It1 end1, It2 begin2, It2 end2, + RNG &rng) { { - for (size_t i = 0; i != del.size(); ++i) - remove(del[i], b); + NTA_ASSERT(begin1 <= end1) << "random_sample 1: Invalid value range"; + NTA_ASSERT(begin2 <= end2) << "random_sample 1: Invalid output range"; } - //-------------------------------------------------------------------------------- - template - inline void remove_for_pairs(const T1& key, std::vector >& v) + size_t n1 = (size_t)(end1 - begin1); + std::vector perm(n1); + random_perm_interval(perm, 0, n1, rng); + for (size_t p = 0; begin2 != end2; ++begin2, ++p) + *begin2 = *(begin1 + p); +} + +//-------------------------------------------------------------------------------- +/** + * Draws a random sample from a range, without replacements. + * + * Currently assumes that the first range is larger than the second. + * + * @param begin1 + * @param end1 + * @param begin2 + * @param end2 + */ +template +inline void random_sample(It1 begin1, It1 end1, It2 begin2, It2 end2) { { - typename std::vector >::const_iterator it; - for (it = v.begin(); it != v.end() && it->first != key; ++it); - remove(*it, v); + NTA_ASSERT(begin1 <= end1) << "random_sample 2: Invalid value range"; + NTA_ASSERT(begin2 <= end2) << "random_sample 2: Invalid output range"; } - //-------------------------------------------------------------------------------- - template - inline void remove_from_end(const T& elt, std::vector& a) + nupic::Random rng; + random_sample(begin1, end1, begin2, end2, rng); +} + +//-------------------------------------------------------------------------------- +/** + * Draws a random sample from a container, and uses the values to initialize + * another container. + * + * @param c1 the container to initialize + * @param c2 the container from which to take the values + */ +template +inline void random_sample(const std::vector &c1, std::vector &c2) { + random_sample(c1.begin(), c1.end(), c2.begin(), c2.end()); +} + +//-------------------------------------------------------------------------------- +template +inline void random_sample(const std::vector &c1, std::vector &c2, + RNG &rng) { + random_sample(c1.begin(), c1.end(), c2.begin(), c2.end(), rng); +} + +//-------------------------------------------------------------------------------- +/** + * Initializes a container with elements taken from the specified ramp range, + * randomly. 
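A similar sketch for the permutation and sampling helpers above (illustrative only, not part of the patch; same assumed include path and namespace as in the earlier sketch):

#include <vector>
#include <nupic/math/ArrayAlgo.hpp> // assumed include path

void permAndSampleSketch() { // hypothetical demo function
  std::vector<size_t> perm(10);
  nupic::random_perm_interval(perm);   // a random permutation of 0..9

  std::vector<size_t> sample(3);
  // Draw 3 values from 'perm' without replacement (the source range is
  // assumed to be at least as large as the destination, per the comment).
  nupic::random_sample(perm, sample);
}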
+ * + * @param c the container to fill + * @param size the size of the ramp + * @param start the first value of the ramp + * @param step the step of the ramp + */ +template +inline void random_sample(std::vector &c, size_t size, size_t start, + size_t step, RNG &rng) { + ramp_range(c, start, step); + std::random_shuffle(c.begin(), c.end(), rng); +} + +//-------------------------------------------------------------------------------- +template +inline void random_sample(std::vector &c, size_t size, size_t start = 0, + size_t step = 1) { + nupic::Random rng; + random_sample(c, size, start, step, rng); +} + +//-------------------------------------------------------------------------------- +template inline void random_sample(size_t n, std::vector &a) { + NTA_ASSERT(0 < a.size()); + + std::vector x(n); + for (size_t i = 0; i != n; ++i) + x[i] = i; + std::random_shuffle(x.begin(), x.end()); + std::copy(x.begin(), x.begin() + a.size(), a.begin()); +} + +//-------------------------------------------------------------------------------- +template inline void random_sample(std::vector &a) { + random_sample(a, a.size()); +} + +//-------------------------------------------------------------------------------- +template +inline void random_sample(const std::set &a, std::vector &b) { + NTA_ASSERT(0 < b.size()); + + std::vector aa(a.begin(), a.end()); + std::random_shuffle(aa.begin(), aa.end()); + std::copy(aa.begin(), aa.begin() + b.size(), b.begin()); +} + +//-------------------------------------------------------------------------------- +template +inline void random_binary(float proba, std::vector &x) { + size_t threshold = (size_t)(proba * 65535); + std::fill(x.begin(), x.end(), 0); + for (size_t i = 0; i != x.size(); ++i) + x[i] = ((size_t)rand() % 65535 < threshold) ? 1 : 0; +} + +//-------------------------------------------------------------------------------- +/** + * Generates a matrix of random (index,value) pairs from [0..ncols], with + * nnzpc numbers per row, and n columns [generates a constant sparse matrix + * with constant number of non-zeros per row]. This uses a uniform distribution + * of the non-zero bits. + */ +template +inline void random_pair_sample(size_t nrows, size_t ncols, size_t nnzpr, + std::vector> &a, + const T2 &init_nz_val, int seed = -1, + bool sorted = true) { { - for (int i = a.size() - 1; i >= 0; --i) { - if (a[i] == elt) { - for (size_t j = i; j < a.size() - 1; ++j) - a[j] = a[j+1]; - a.resize(a.size() - 1); - return; - } - } + NTA_ASSERT(0 < a.size()); + NTA_ASSERT(nnzpr <= ncols); } - //-------------------------------------------------------------------------------- - /** - * Given a vector of indices, removes the elements of a at those indices (indices - * before any removal is carried out), where a is a vector of pairs. - * - * Need to pass in non-empty vector of sorted, unique indices to delete. - */ - // TODO: remove this? Should be covered just below?? - template - inline void - remove_for_pairs(const std::vector& del, std::vector >& a) - { - NTA_ASSERT(std::set(del.begin(),del.end()).size() == del.size()); + a.resize(nrows * nnzpr); - if (del.empty()) - return; +#if defined(NTA_ARCH_32) && defined(NTA_OS_DARWIN) + nupic::Random rng(seed == -1 ? arc4random() : seed); +#else + nupic::Random rng(seed == -1 ? 
rand() : seed); +#endif - size_t old = del[0] + 1, cur = del[0], d = 1; + std::vector x(ncols); + for (size_t i = 0; i != ncols; ++i) + x[i] = i; + for (size_t i = 0; i != nrows; ++i) { + std::random_shuffle(x.begin(), x.end(), rng); + if (sorted) + std::sort(x.begin(), x.begin() + nnzpr); + size_t offset = i * nnzpr; + for (size_t j = 0; j != nnzpr; ++j) + a[offset + j] = std::pair(x[j], init_nz_val); + } +} + +//-------------------------------------------------------------------------------- +/** + * Generates a matrix of random (index,value) pairs from [0..ncols], with + * nnzpr numbers per row, and n columns [generates a constant sparse matrix + * with constant number of non-zeros per row]. This uses a 2D Gaussian + * distribution for the on bits of each coincidence. That is, each coincidence + * is seen as a folded 2D array, and a 2D Gaussian is used to distribute the on + * bits of each coincidence. + * + * Each row is seen as an image of size (ncols / rf_x) by rf_x. + * 'sigma' is the parameter of the Gaussian, which is centered at the center of + * the image. Uses a symmetric Gaussian, specified only by the location of its + * max and a single sigma parameter (no Sigma matrix). We use the symmetry of + * the 2d gaussian to simplify computations. The conditional distribution + * obtained from the 2d gaussian by fixing y is again a gaussian, with + * parameters than can be easily deduced from the original 2d gaussian. + */ +template +inline void gaussian_2d_pair_sample(size_t nrows, size_t ncols, size_t nnzpr, + size_t rf_x, T2 sigma, + std::vector> &a, + const T2 &init_nz_val, int seed = -1, + bool sorted = true) { + { + NTA_ASSERT(ncols % rf_x == 0); + NTA_ASSERT(nnzpr <= ncols); + NTA_ASSERT(0 < sigma); + } - while (old < a.size() && d < del.size()) { - if (old == (size_t) del[d]) { - ++d; ++old; - } else if ((size_t) del[d] < old) { - ++d; - } else { - a[cur++] = a[old++]; - } - } + a.resize(nrows * nnzpr); - while (old < a.size()) - a[cur++] = a[old++]; +#if defined(NTA_ARCH_32) && defined(NTA_OS_DARWIN) + nupic::Random rng(seed == -1 ? arc4random() : seed); +#else + nupic::Random rng(seed == -1 ? rand() : seed); +#endif - a.resize(a.size() - del.size()); - } + size_t rf_y = ncols / rf_x; + T2 c_x = float(rf_x - 1.0) / 2.0, c_y = float(rf_y - 1.0) / 2.0; + Gaussian2D sg2d(c_x, c_y, sigma * sigma, 0, 0, sigma * sigma); + std::vector z(ncols); - //-------------------------------------------------------------------------------- - /** - * Remove several elements from a vector, the elements to remove being specified - * by their index (in del). After this call, a's size is reduced. Requires - * default constructor on T to be defined (for resize). O(n). - */ - template - inline void - remove_at(const std::vector& del, std::vector& a) - { - NTA_ASSERT(std::set(del.begin(),del.end()).size() == del.size()); + // Renormalize because we've lost some mass + // with a compact domain of definition. 
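To make the (index, value) layout produced by random_pair_sample above concrete, a small hedged sketch (illustrative only, not part of the patch; include path and wrapper name are assumptions):

#include <utility>
#include <vector>
#include <nupic/math/ArrayAlgo.hpp> // assumed include path

void sparseRowsSketch() { // hypothetical demo function
  // 4 rows, 16 columns, 3 non-zeros per row, every non-zero set to 1.0f.
  // The vector starts non-empty because the function asserts 0 < a.size()
  // in debug builds, then resizes it to nrows * nnzpr itself.
  std::vector<std::pair<size_t, float>> rows(1);
  nupic::random_pair_sample(4, 16, 3, rows, 1.0f, /*seed=*/42, /*sorted=*/true);
  // rows.size() == 12; rows[3*i .. 3*i+2] hold the (column, value) pairs of
  // row i, with column indices sorted within the row.
}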
+ float s = 0; + for (size_t j = 0; j != ncols; ++j) + s += z[j] = sg2d(j / rf_y, j % rf_y); - if (del.empty()) - return; + for (size_t j = 0; j != ncols; ++j) + z[j] /= s; + // z[j] = 1.0f / (float)(rf_x * rf_y); - size_t old = del[0] + 1, cur = del[0], d = 1; + // std::vector counts(ncols, 0); - while (old < a.size() && d < del.size()) { - if (old == (size_t) del[d]) { - ++d; ++old; - } else if ((size_t) del[d] < old) { - ++d; - } else { - a[cur++] = a[old++]; - } - } + // TODO: argsort z so that the bigger bins come first, and it's faster + // to draw samples in the area where the pdf is higher + for (size_t i = 0; i != nrows; ++i) { - while (old < a.size()) - a[cur++] = a[old++]; + std::set b; - a.resize(a.size() - del.size()); - } + while (b.size() < nnzpr) { + T2 s = z[0], p = T2(rng.getReal64()); + size_t k = 0; + while (s < p && k < ncols - 1) + s += z[++k]; + //++counts[k]; + b.insert(k); + } - //-------------------------------------------------------------------------------- - /** - * Finds index of elt in ref, and removes corresponding element of a. - */ - template - inline void remove(const T2& elt, std::vector& a, const std::vector& ref) - { - a.erase(a.begin() + find_index(elt, ref)); + size_t offset = i * nnzpr; + auto it = b.begin(); + for (size_t j = 0; j != nnzpr; ++j, ++it) + a[offset + j] = std::pair(*it, init_nz_val); } - //-------------------------------------------------------------------------------- - template - inline void remove(const std::vector& del, std::set& a) - { - for (size_t i = 0; i != del.size(); ++i) - a.erase(del[i]); - } + /* + for (size_t i = 0; i != counts.size(); ++i) + std::cout << counts[i] << " "; + std::cout << std::endl; + */ +} + +//-------------------------------------------------------------------------------- +template inline void random_shuffle(std::vector &x) { + std::random_shuffle(x.begin(), x.end()); +} + +//-------------------------------------------------------------------------------- +// generate +//-------------------------------------------------------------------------------- +/** + * Initializes a range by calling gen() repetitively. + * + * @param c the container to initialize + * @param gen the generator functor + */ +template +inline void generate(Container &c, Generator gen) { + typename Container::iterator i = c.begin(), e = c.end(); + + for (; i != e; ++i) + *i = gen(); +} + +//-------------------------------------------------------------------------------- +// concatenate +//-------------------------------------------------------------------------------- +/** + * Concatenates multiple sub-ranges of a source range into a single range. 
+ * + * @param x_begin the beginning of the source range + * @param seg_begin the beginning of the ranges that describe the + * sub-ranges (start, size) + * @param seg_end one past the end of the ranges that describe the + * sub-ranges + * @param y_begin the beginning of the concatenated range + */ +template +inline void concatenate(InIt1 x_begin, InIt2 seg_begin, InIt2 seg_end, + OutIt y_begin) { + { NTA_ASSERT(seg_begin <= seg_end) << "concatenate: Invalid segment range"; } + + for (; seg_begin != seg_end; ++seg_begin) { + InIt1 begin = x_begin + seg_begin->first; + InIt1 end = begin + seg_begin->second; + std::copy(begin, end, y_begin); + y_begin += seg_begin->second; + } +} + +//-------------------------------------------------------------------------------- +// Clip, threshold, binarize +//-------------------------------------------------------------------------------- +/** + * Clip the values in a range to be between min (included) and max (included): + * any value less than min becomes min and any value greater than max becomes + * max. + * + * @param begin + * @param end + * @param _min the minimum value + * @param _max the maximum value + */ +template +inline void clip(It begin, It end, + const typename std::iterator_traits::value_type &_min, + const typename std::iterator_traits::value_type &_max) { + { NTA_ASSERT(begin <= end) << "clip: Invalid range"; } + + while (begin != end) { + typename std::iterator_traits::value_type val = *begin; + if (val > _max) + *begin = _max; + else if (val < _min) + *begin = _min; + ++begin; + } +} + +//-------------------------------------------------------------------------------- +/** + * Clip the values in a container to be between min (included) and max + * (included): any value less than min becomes min and any value greater than + * max becomes max. + * + * @param a the container + * @param _min + * @param _max + */ +template +inline void clip(T &a, const typename T::value_type &_min, + const typename T::value_type &_max) { + clip(a.begin(), a.end(), _min, _max); +} + +//-------------------------------------------------------------------------------- +/** + * Threshold a range and puts the values that were not eliminated into + * another (sparse) range (index, value). 
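The clip helpers above are easiest to see on a tiny vector; an illustrative sketch (not part of the patch; include path and wrapper name assumed):

#include <vector>
#include <nupic/math/ArrayAlgo.hpp> // assumed include path

void clipSketch() { // hypothetical demo function
  std::vector<float> v = {-2.5f, 0.3f, 0.9f, 7.0f};
  nupic::clip(v, 0.0f, 1.0f); // v becomes {0.0f, 0.3f, 0.9f, 1.0f}
}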
+ * + * @param begin + * @param end + * @param ind the beginning of the sparse indices + * @param nz the beginning of the sparse values + * @param th the threshold to use + */ +template +inline size_t +threshold(InIter begin, InIter end, OutIter1 ind, OutIter2 nz, + const typename std::iterator_traits::value_type &th, + bool above = true) { + { NTA_ASSERT(begin <= end) << "threshold: Invalid range"; } - //-------------------------------------------------------------------------------- - template - inline void remove(const std::set& y, std::vector& x) - { - std::vector del; + typedef typename std::iterator_traits::value_type value_type; + typedef size_t size_type; - for (size_t i = 0; i != x.size(); ++i) - if (y.find(x[i]) != y.end()) { - NTA_ASSERT(not_in(x[i], del)); - del.push_back(x[i]); - } + size_type n = 0; - nupic::remove(del, x); - } + if (above) { - //-------------------------------------------------------------------------------- - template - inline std::set& operator-=(std::set& a, const std::vector& b) - { - remove(b, a); - return a; - } + for (InIter it = begin; it != end; ++it) { + value_type val = (value_type)*it; + if (val >= th) { + *ind = (size_type)(it - begin); + *nz = val; + ++ind; + ++nz; + ++n; + } + } - //-------------------------------------------------------------------------------- - template - inline std::vector& operator-=(std::vector& a, const std::vector& b) - { - remove(b, a); - return a; - } + } else { - //-------------------------------------------------------------------------------- - // DIFFERENCES - //-------------------------------------------------------------------------------- - /** - * Returns a vector that contains the indices of the positions where x and y - * have different values. - */ - template - inline void - find_all_differences(const std::vector& x, const std::vector& y, - std::vector& diffs) - { - NTA_ASSERT(x.size() == y.size()); - diffs.clear(); - for (size_t i = 0; i != x.size(); ++i) - if (x[i] != y[i]) - diffs.push_back(i); + for (InIter it = begin; it != end; ++it) { + value_type val = (value_type)*it; + if (val < th) { + *ind = (size_type)(it - begin); + *nz = val; + ++ind; + ++nz; + ++n; + } + } } - //-------------------------------------------------------------------------------- - // fill - //-------------------------------------------------------------------------------- - /** - * Fills a container with the given value. - * - * @param a - * @param val - */ - template - inline void fill(T& a, const typename T::value_type& val) - { - typename T::iterator i = a.begin(), e = a.end(); + return n; +} - for (; i != e; ++i) - *i = val; +//-------------------------------------------------------------------------------- +/** + * Given a threshold and a dense vector x, returns another dense vectors with + * 1's where the value of x is > threshold, and 0 elsewhere. Also returns the + * count of 1's. 
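A sketch of how the sparse (index, value) output of threshold above is typically consumed (illustrative only, not part of the patch; include path and wrapper name assumed):

#include <vector>
#include <nupic/math/ArrayAlgo.hpp> // assumed include path

void thresholdSketch() { // hypothetical demo function
  std::vector<float> dense = {0.1f, 0.8f, 0.0f, 0.6f};
  std::vector<size_t> idx(dense.size());
  std::vector<float> val(dense.size());
  // Keep values >= 0.5f, recording their positions and values.
  size_t nnz = nupic::threshold(dense.begin(), dense.end(), idx.begin(),
                                val.begin(), 0.5f);
  // nnz == 2; only the first nnz entries of idx/val are meaningful:
  // idx = {1, 3, ...}, val = {0.8f, 0.6f, ...}.
  (void)nnz;
}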
+ */ +template +inline nupic::UInt32 +binarize_with_threshold(nupic::Real32 threshold, InputIterator x, + InputIterator x_end, OutputIterator y, + OutputIterator y_end) { + { NTA_ASSERT(x_end - x == y_end - y); } + + nupic::UInt32 count = 0; + + for (; x != x_end; ++x, ++y) + if (*x > threshold) { + *y = 1; + ++count; + } else + *y = 0; + + return count; +} + +//-------------------------------------------------------------------------------- +// INDICATORS +//-------------------------------------------------------------------------------- + +//-------------------------------------------------------------------------------- +/** + * Given a dense 2D array of 0 and 1, return a vector that has as many rows as x + * a 1 wherever x as a non-zero row, and a 0 elsewhere. I.e. the result is the + * indicator of non-zero rows. Gets fast by not scanning a row more than is + * necessary, i.e. stops as soon as a 1 is found on the row. + */ +template +inline void nonZeroRowsIndicator_01(nupic::UInt32 nrows, nupic::UInt32 ncols, + InputIterator x, InputIterator x_end, + OutputIterator y, OutputIterator y_end) { + { + NTA_ASSERT(0 < nrows); + NTA_ASSERT(0 < ncols); + NTA_ASSERT((nupic::UInt32)(x_end - x) == nrows * ncols); + NTA_ASSERT((nupic::UInt32)(y_end - y) == nrows); +#ifdef NTA_ASSERTION_ON + for (nupic::UInt32 i = 0; i != nrows * ncols; ++i) + NTA_ASSERT(x[i] == 0 || x[i] == 1); +#endif } - //-------------------------------------------------------------------------------- - /** - * Zeroes out a range. - * - * @param begin - * @param end - */ - template - inline void zero(It begin, It end) - { - { - NTA_ASSERT(begin <= end) - << "zero: Invalid input range"; - } + for (nupic::UInt32 r = 0; r != nrows; ++r, ++y) { - typedef typename std::iterator_traits::value_type T; + InputIterator it = x + r * ncols, it_end = it + ncols; + nupic::UInt32 found = 0; - for (; begin != end; ++begin) - *begin = T(0); - } + while (it != it_end && found == 0) + found = nupic::UInt32(*it++); - //-------------------------------------------------------------------------------- - /** - * Zeroes out a whole container. - * - * @param a the container - */ - template - inline void zero(T& a) - { - zero(a.begin(), a.end()); + *y = found; } +} - //-------------------------------------------------------------------------------- - template - inline void set_to_zero(T& a) +//-------------------------------------------------------------------------------- +/** + * Given a dense 2D array of 0 and 1, return the number of rows that have + * at least one non-zero. Gets fast by not scanning a row more than is + * necessary, i.e. stops as soon as a 1 is found on the row. + */ +template +inline nupic::UInt32 nNonZeroRows_01(nupic::UInt32 nrows, nupic::UInt32 ncols, + InputIterator x, InputIterator x_end) { { - zero(a); + NTA_ASSERT(0 < nrows); + NTA_ASSERT(0 < ncols); + NTA_ASSERT((nupic::UInt32)(x_end - x) == nrows * ncols); +#ifdef NTA_ASSERTION_ON + for (nupic::UInt32 i = 0; i != nrows * ncols; ++i) + NTA_ASSERT(x[i] == 0 || x[i] == 1); +#endif } - //-------------------------------------------------------------------------------- - template - inline void set_to_zero(std::vector& a, size_t begin, size_t end) - { - zero(a.begin() + begin, a.begin() + end); - } + nupic::UInt32 count = 0; - //-------------------------------------------------------------------------------- - /** - * Fills a range with ones. 
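binarize_with_threshold above, in use (illustrative sketch, not part of the patch; include path and wrapper name assumed):

#include <vector>
#include <nupic/math/ArrayAlgo.hpp> // assumed include path

void binarizeSketch() { // hypothetical demo function
  std::vector<float> x = {0.2f, 0.7f, 0.0f, 0.9f};
  std::vector<int> y(x.size());
  // y becomes {0, 1, 0, 1}; 'ones' counts the values strictly above 0.5f.
  auto ones = nupic::binarize_with_threshold(0.5f, x.begin(), x.end(),
                                             y.begin(), y.end());
  (void)ones;
}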
- * - * @param begin - * @param end - */ - template - inline void ones(It begin, It end) - { - { - NTA_ASSERT(begin <= end) - << "ones: Invalid input range"; - } + for (nupic::UInt32 r = 0; r != nrows; ++r) { - typedef typename std::iterator_traits::value_type T; + InputIterator it = x + r * ncols, it_end = it + ncols; + nupic::UInt32 found = 0; - for (; begin != end; ++begin) - *begin = T(1); - } + while (it != it_end && found == 0) + found = nupic::UInt32(*it++); - //-------------------------------------------------------------------------------- - /** - * Fills a container with ones. - * - * @param a the container - */ - template - inline void ones(T& a) - { - ones(a.begin(), a.end()); + count += found; } - //-------------------------------------------------------------------------------- - template - inline void set_to_one(std::vector& a) - { - ones(a); - } + return count; +} - //-------------------------------------------------------------------------------- - template - inline void set_to_one(std::vector& a, size_t begin, size_t end) - { - ones(a.begin() + begin, a.begin() + end); +//-------------------------------------------------------------------------------- +/** + * Given a dense 2D array of 0 and 1 x, return a vector that has as many cols as + * x a 1 wherever x as a non-zero col, and a 0 elsewhere. I.e. the result is the + * indicator of non-zero cols. Gets fast by not scanning a row more than is + * necessary, i.e. stops as soon as a 1 is found on the col. + */ +template +inline void nonZeroColsIndicator_01(nupic::UInt32 nrows, nupic::UInt32 ncols, + InputIterator x, InputIterator x_end, + OutputIterator y, OutputIterator y_end) { + { + NTA_ASSERT(0 < nrows); + NTA_ASSERT(0 < ncols); + NTA_ASSERT((nupic::UInt32)(x_end - x) == nrows * ncols); + NTA_ASSERT((nupic::UInt32)(y_end - y) == ncols); +#ifdef NTA_ASSERTION_ON + for (nupic::UInt32 i = 0; i != nrows * ncols; ++i) + NTA_ASSERT(x[i] == 0 || x[i] == 1); +#endif } - //-------------------------------------------------------------------------------- - /** - * Sets a range to 0, except for a single value at pos, which will be equal to val. - * - * @param pos the position of the single non-zero value - * @param begin - * @param end - * @param val the value of the non-zero value in the range - */ - template - inline void dirac(size_t pos, It begin, It end, - typename std::iterator_traits::value_type val =1) - { - { - NTA_ASSERT(begin <= end) - << "dirac: Invalid input range"; - - NTA_ASSERT(0 <= pos && pos < (size_t)(end - begin)) - << "dirac: Invalid position: " << pos - << " - Should be between 0 and: " << (size_t)(end - begin); - } + nupic::UInt32 N = nrows * ncols; - typedef typename std::iterator_traits::value_type value_type; + for (nupic::UInt32 c = 0; c != ncols; ++c, ++y) { - std::fill(begin, end, (value_type) 0); - *(begin + pos) = val; - } + InputIterator it = x + c, it_end = it + N; + nupic::UInt32 found = 0; - //-------------------------------------------------------------------------------- - /** - * Sets a range to 0, except for a single value at pos, which will be equal to val. 
- * - * @param pos the position of the single non-zero value - * @param c the container - * @param val the value of the Dirac - */ - template - inline void dirac(size_t pos, C& c, typename C::value_type val =1) - { - { - NTA_ASSERT(pos >= 0 && pos < c.size()) - << "dirac: Can't set Dirac at pos: " << pos - << " when container has size: " << c.size(); + while (it != it_end && found == 0) { + found = nupic::UInt32(*it); + it += ncols; } - dirac(pos, c.begin(), c.end(), val); + *y = found; } +} - //-------------------------------------------------------------------------------- - /** - * Computes the CDF of the given range seen as a discrete PMF. - * - * @param begin1 the beginning of the discrete PMF range - * @param end1 one past the end of the discrete PMF range - * @param begin2 the beginning of the CDF range - */ - template - inline void cumulative(It1 begin1, It1 end1, It2 begin2, It2 end2) +//-------------------------------------------------------------------------------- +/** + * Given a dense 2D array of 0 and 1, return the number of columns that have + * at least one non-zero. + * Gets fast by not scanning a col more than is necessary, i.e. stops as soon as + * a 1 is found on the col. + */ +template +inline nupic::UInt32 nNonZeroCols_01(nupic::UInt32 nrows, nupic::UInt32 ncols, + InputIterator x, InputIterator x_end) { { - { - NTA_ASSERT(begin1 <= end1) - << "cumulative: Invalid input range"; - NTA_ASSERT(begin2 <= end2) - << "cumulative: Invalid output range"; - NTA_ASSERT(end1 - begin1 == end2 - begin2) - << "cumulative: Incompatible sizes"; - } + NTA_ASSERT(0 < nrows); + NTA_ASSERT(0 < ncols); + NTA_ASSERT((nupic::UInt32)(x_end - x) == nrows * ncols); +#ifdef NTA_ASSERTION_ON + for (nupic::UInt32 i = 0; i != nrows * ncols; ++i) + NTA_ASSERT(x[i] == 0 || x[i] == 1); +#endif + } - typedef typename std::iterator_traits::value_type value_type; + nupic::UInt32 count = 0; + nupic::UInt32 N = nrows * ncols; - It2 prev = begin2; - *begin2++ = (value_type) *begin1++; - for (; begin1 < end1; ++begin1, ++begin2, ++prev) - *begin2 = *prev + (value_type) *begin1; - } + for (nupic::UInt32 c = 0; c != ncols; ++c) { - //-------------------------------------------------------------------------------- - /** - * Computes the CDF of a discrete PMF. - * - * @param pmf the PMF - * @param cdf the CDF - */ - template - inline void cumulative(const C1& pmf, C2& cdf) - { - cumulative(pmf.begin(), pmf.end(), cdf.begin(), cdf.end()); - } + InputIterator it = x + c, it_end = it + N; + nupic::UInt32 found = 0; - //-------------------------------------------------------------------------------- - /** - * Finds percentiles. - */ - template - inline void percentiles(size_t n_percentiles, - It1 begin1, It1 end1, - It2 begin2, It2 end2, - bool alreadyNormalized =false) - { - { - NTA_ASSERT(begin1 <= end1) - << "percentiles: Invalid input range"; - NTA_ASSERT(begin2 <= end2) - << "percentiles: Invalid output range"; - NTA_ASSERT(end1 - begin1 == end2 - begin2) - << "percentiles: Mismatched ranges"; + while (it != it_end && found == 0) { + found = nupic::UInt32(*it); + it += ncols; } - typedef typename std::iterator_traits::value_type value_type; - typedef typename std::iterator_traits::value_type size_type; + count += found; + } + + return count; +} - value_type n = (value_type) (alreadyNormalized ? 1.0f : 0.0f); +//-------------------------------------------------------------------------------- +// MASK +//-------------------------------------------------------------------------------- +/** + * Mask an array. 
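The row/column indicator helpers above expect a row-major array of 0s and 1s; a hedged sketch (illustrative only, not part of the patch; include path and wrapper name assumed):

#include <vector>
#include <nupic/math/ArrayAlgo.hpp> // assumed include path

void nonZeroRowsSketch() { // hypothetical demo function
  // A 3 x 4 array of 0/1 values, stored row-major.
  std::vector<unsigned int> x = {0, 0, 0, 0,
                                 0, 1, 0, 1,
                                 1, 0, 0, 0};
  // Two of the three rows contain at least one 1, so n == 2.
  auto n = nupic::nNonZeroRows_01(3, 4, x.begin(), x.end());
  (void)n;
}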
+ */
+template <typename InIter>
+inline void mask(InIter begin, InIter end, InIter zone_begin, InIter zone_end,
+                 const typename std::iterator_traits<InIter>::value_type &v = 0,
+                 bool maskOutside = true) {
+  { // Pre-conditions
+    NTA_ASSERT(begin <= end) << "mask 1: Invalid range for vector";
+    NTA_ASSERT(zone_begin <= zone_end) << "mask 1: Invalid range for mask";
+    NTA_ASSERT(begin <= zone_begin && zone_end <= end)
+        << "mask 1: Mask incompatible with vector";
+  } // End pre-conditions
+
+  if (maskOutside) {
+    std::fill(begin, zone_begin, v);
+    std::fill(zone_end, end, v);
+  } else {
+    std::fill(zone_begin, zone_end, v);
+  }
+}
+
+//--------------------------------------------------------------------------------
+template <typename value_type>
+inline void mask(std::vector<value_type> &x,
+                 typename std::vector<value_type>::size_type zone_begin,
+                 typename std::vector<value_type>::size_type zone_end,
+                 const value_type &v = 0, bool maskOutside = true) {
+  { // Pre-conditions
+    NTA_ASSERT(0 <= zone_begin && zone_begin <= zone_end &&
+               zone_end <= x.size())
+        << "mask 2: Mask incompatible with vector";
+  } // End pre-conditions
+
+  mask(x.begin(), x.end(), x.begin() + zone_begin, x.begin() + zone_end, v,
+       maskOutside);
+}
+
+//--------------------------------------------------------------------------------
+template <typename value_type1, typename value_type2>
+inline void mask(std::vector<value_type1> &x,
+                 const std::vector<value_type2> &mask,
+                 bool multiplyYesNo = false,
+                 value_type2 eps = (value_type2)nupic::Epsilon) {
+  { // Pre-conditions
+    NTA_ASSERT(x.size() == mask.size())
+        << "mask 3: Need mask and vector to have same size";
+  } // End pre-conditions
-    value_type n = (value_type) (alreadyNormalized ? 1.0f : 0.0f);
+  typedef typename std::vector<value_type1>::size_type size_type;
-    if (!alreadyNormalized)
-      for (It1 it = begin1; it != end1; ++it)
-        n += *it;
+  if (multiplyYesNo) {
+    for (size_type i = 0; i != x.size(); ++i)
+      if (!nearlyZero(mask[i], eps))
+        x[i] *= (value_type1)mask[i];
+      else
+        x[i] = (value_type1)0;
-    value_type increment = n/value_type(n_percentiles);
-    value_type sum = (value_type) 0.0f;
-    size_type p = (size_type) 0;
+  } else {
+    for (size_type i = 0; i != x.size(); ++i)
+      if (nearlyZero(mask[i], eps))
+        x[i] = (value_type1)0;
-    for (value_type v = increment; v < n; v += increment) {
-      for (; sum < v; ++p)
-        sum += *begin1++;
-      *begin2++ = p;
-    }
   }
+}
-  //--------------------------------------------------------------------------------
-  template <typename C1, typename C2>
-  inline void percentiles(size_t n_percentiles, const C1& pmf, C2& pcts)
-  {
-    percentiles(n_percentiles, pmf.begin(), pmf.end(), pcts.begin());
+//--------------------------------------------------------------------------------
+// NORMS
+//--------------------------------------------------------------------------------
+/**
+ * A class that provides init and operator(), to be used in distance
+ * computations when using the Hamming (L0) norm.
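The third mask overload above (vector plus mask vector) zeroes the entries whose mask value is near zero; a minimal sketch (illustrative only, not part of the patch; include path and wrapper name assumed):

#include <vector>
#include <nupic/math/ArrayAlgo.hpp> // assumed include path

void maskSketch() { // hypothetical demo function
  std::vector<float> x = {3.0f, 1.0f, 4.0f, 1.0f, 5.0f};
  std::vector<float> m = {1.0f, 0.0f, 1.0f, 0.0f, 1.0f};
  // With multiplyYesNo = false (the default), x[i] is zeroed wherever m[i]
  // is nearly zero: x becomes {3.0f, 0.0f, 4.0f, 0.0f, 5.0f}.
  nupic::mask(x, m);
}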
+ */ +template struct Lp0 { + typedef T value_type; + + inline value_type operator()(value_type &a, value_type b) const { + value_type inc = value_type(b < -nupic::Epsilon || b > nupic::Epsilon); + a += inc; + return inc; } - //-------------------------------------------------------------------------------- - template - inline void rand_range(It begin, It end, - const typename std::iterator_traits::value_type& min_, - const typename std::iterator_traits::value_type& max_, - RNG& rng) - { - { - NTA_ASSERT(begin <= end) - << "rand_range: Invalid input range"; - NTA_ASSERT(min_ < max_) - << "rand_range: Invalid min/max: " << min_ << " " << max_; - } + inline value_type root(value_type x) const { return x; } +}; - typedef typename std::iterator_traits::value_type value_type; +//-------------------------------------------------------------------------------- +/** + * A class that provides init and operator(), to be used in distance + * computations when using the Manhattan (L1) norm. + */ +template struct Lp1 { + typedef T value_type; - double range = double(max_ - min_) / double(rng.max() - rng.min()); - for (; begin != end; ++begin) - *begin = value_type(double(rng()) * range + min_); + inline value_type operator()(value_type &a, value_type b) const { + value_type inc = fabs(b); // b > 0.0 ? b : -b; + a += inc; + return inc; } - //-------------------------------------------------------------------------------- - /** - * Initializes a range with random values. - * - * @param begin - * @param end - * @param min_ - * @param max_ - */ - template - inline void rand_range(It begin, It end, - const typename std::iterator_traits::value_type& min_, - const typename std::iterator_traits::value_type& max_) - { - nupic::Random rng; - rand_range(begin, end, min_, max_, rng); - } + inline value_type root(value_type x) const { return x; } +}; - //-------------------------------------------------------------------------------- - template - inline void rand_range(T& a, - const typename T::value_type& min, - const typename T::value_type& max, - RNG& rng) - { - rand_range(a.begin(), a.end(), min, max, rng); - } +//-------------------------------------------------------------------------------- +/** + * A class that provides square and square root methods, to be + * used in distance computations when using L2 norm. + */ +template struct Lp2 { + typedef T value_type; - //-------------------------------------------------------------------------------- - /** - * Initializes a container with random values. - * - * @param a the container - * @param min - * @param max - */ - template - inline void rand_range(T& a, - const typename T::value_type& min, - const typename T::value_type& max) - { - rand_range(a.begin(), a.end(), min, max); - } + nupic::Sqrt s; - //-------------------------------------------------------------------------------- - template - inline void rand_float_range(std::vector& x, size_t start, size_t end, RNG& rng) - { - for (size_t i = start; i != end; ++i) - x[i] = (float) rng.getReal64(); + inline value_type operator()(value_type &a, value_type b) const { + value_type inc = b * b; + a += inc; + return inc; } - //-------------------------------------------------------------------------------- - /** - * Initializes a range with the normal distribution. 
- * - * @param begin - * @param end - * @param mean - * @param stddev - */ - template - inline void normal_range(It begin, It end, - const typename std::iterator_traits::value_type& mean, - const typename std::iterator_traits::value_type& stddev) - { - { - NTA_ASSERT(begin <= end) - << "normal_range: Invalid input range"; - } + inline value_type root(value_type x) const { return s(x); } +}; - //TODO implement numerical recipes' method - } +//-------------------------------------------------------------------------------- +/** + * A class that provides power p and root p methods, to be + * used in distance computations using Lp norm. + */ +template struct Lp { + typedef T value_type; - //-------------------------------------------------------------------------------- - template - inline void rand_range_01(It begin, It end, double pct, RNG& rng) - { - { - NTA_ASSERT(begin <= end) - << "rand_range_01: Invalid input range"; - NTA_ASSERT(0 <= pct && pct < 1) - << "rand_range_01: Invalid threshold: " << pct - << " - Should be between 0 and 1"; - } + nupic::Pow pf; - typedef typename std::iterator_traits::value_type value_type; + Lp(value_type p_) : p(p_), inv_p((value_type)1.0) { + // We allow only positive values of p for now, as this + // keeps the root function monotonically increasing, which + // results in further speed-ups. + NTA_ASSERT(p_ > (value_type)0.0) + << "NearestNeighbor::PP(): " + << "Invalid value for p: " << p_ << " - p needs to be > 0"; - for (; begin != end; ++begin) - *begin = (value_type)(double(rng()) / double(rng.max() - rng.min()) > pct); + inv_p = (value_type)1.0 / p; } - //-------------------------------------------------------------------------------- - /** - * Initializes a range with a random binary vector. - * - * @param begin - * @param end - * @param pct the percentage of ones - */ - template - inline void rand_range_01(It begin, It end, double pct =.5) - { - nupic::Random rng; - rand_range_01(begin, end, pct, rng); - } + value_type p, inv_p; - //-------------------------------------------------------------------------------- - template - inline void rand_range_01(T& a, double pct, RNG& rng) - { - rand_range_01(a.begin(), a.end(), pct, rng); + inline value_type operator()(value_type &a, value_type b) const { + value_type inc = pf(b > 0.0 ? b : -b, p); + a += inc; + return inc; } - //-------------------------------------------------------------------------------- - /** - * Initializes a container with a random binary vector. - * - * @param a the container - * @param pct the percentage of ones - */ - template - inline void rand_range_01(T& a, double pct =.5) - { - nupic::Random rng; - rand_range_01(a.begin(), a.end(), pct, rng); + inline value_type root(value_type x) const { + // skipping abs, because we know we've been adding positive + // numbers when root is called when computing a norm + return pf(x, inv_p); } +}; - //-------------------------------------------------------------------------------- - /** - * Initializes a range with a ramp function. - * - * @param begin - * @param end - * @param start the start value of the ramp - * @param step the step of the ramp - */ - template - inline void ramp_range(It begin, It end, T start =0, T step =1) - { - { - NTA_ASSERT(begin <= end) - << "ramp_range: Invalid input range"; - } +//-------------------------------------------------------------------------------- +/** + * A class that provides power p and root p methods, to be + * used in distance computations using LpMax norm. 
+ */ +template struct LpMax { + typedef T value_type; - for (; begin != end; ++begin, start += step) - *begin = start; - } + nupic::Max m; - //-------------------------------------------------------------------------------- - /** - * Initializes a range with a ramp function. - * - * @param a the container - * @param start the start value of the ramp - * @param step the step of the ramp - */ - template - inline void ramp_range(T& a, - typename T::value_type start =0, - typename T::value_type step =1) - { - ramp_range(a.begin(), a.end(), start); + inline value_type operator()(value_type &a, value_type b) const { + value_type inc = m(a, b > 0 ? b : -b); + a = inc; + return inc; } - //-------------------------------------------------------------------------------- - /** - * Fills a range with values taken randomly from another range. - * - * @param begin - * @param end - * @param enum_begin the values to draw from - * @param enum_end the values to draw from - * @param replace whether to draw with or without replacements - */ - template - inline void rand_enum_range(It1 begin, It1 end, It2 enum_begin, It2 enum_end, - bool replace, RNG& rng) - { - { - NTA_ASSERT(begin <= end) - << "rand_enum_range: Invalid input range"; + inline value_type root(value_type x) const { return x; } +}; - NTA_ASSERT(enum_begin <= enum_end) - << "rand_enum_range: Invalid values range"; - } +//-------------------------------------------------------------------------------- +/** + * Hamming norm. + */ +template +inline typename std::iterator_traits::value_type l0_norm(It begin, It end, + bool = true) { + { NTA_ASSERT(begin <= end) << "l0_norm: Invalid range"; } - typedef typename std::iterator_traits::value_type value_type; + typedef typename std::iterator_traits::value_type value_type; - size_t n = (size_t)(enum_end - enum_begin); + value_type n = (value_type)0; + Lp0 lp0; - if (replace) { + for (; begin != end; ++begin) + lp0(n, *begin); - for (; begin != end; ++begin) - *begin = (value_type) *(enum_begin + rng() % n); + return n; +} - } else { +//-------------------------------------------------------------------------------- +/** + * Hamming norm on a container + */ +template +inline typename T::value_type l0_norm(const T &a, bool = true) { + return l0_norm(a.begin(), a.end()); +} + +//-------------------------------------------------------------------------------- +/** + * Manhattan norm. + */ +template +inline typename std::iterator_traits::value_type l1_norm(It begin, It end, + bool = true) { + { NTA_ASSERT(begin <= end) << "l1_norm: Invalid range"; } - std::vector ind(n); - ramp_range(ind); + typedef typename std::iterator_traits::value_type value_type; - for (; begin != end; ++begin) { - size_t p = rng() % ind.size(); - *begin = (value_type) *(enum_begin + p); - remove(p, ind); - } - } - } + value_type n = (value_type)0; + Lp1 lp1; - //-------------------------------------------------------------------------------- - /** - * Fills a range with values taken randomly from another range. 
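The Lp functors above back the norm helpers; l0_norm and l1_norm in use (illustrative sketch, not part of the patch; include path and wrapper name assumed):

#include <vector>
#include <nupic/math/ArrayAlgo.hpp> // assumed include path

void normSketch() { // hypothetical demo function
  std::vector<float> v = {0.0f, -2.0f, 0.0f, 1.5f};
  float nnz = nupic::l0_norm(v.begin(), v.end()); // 2.0f: entries away from 0
  float l1 = nupic::l1_norm(v.begin(), v.end());  // 3.5f: sum of |v[i]|
  (void)nnz;
  (void)l1;
}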
- * - * @param begin - * @param end - * @param enum_begin the values to draw from - * @param enum_end the values to draw from - * @param replace whether to draw with or without replacements - */ - template - inline void rand_enum_range(It1 begin, It1 end, It2 enum_begin, It2 enum_end, - bool replace =false) - { - { - NTA_ASSERT(begin <= end) - << "rand_enum_range: Invalid input range"; - NTA_ASSERT(enum_begin <= enum_end) - << "rand_enum_range: Invalid enum range"; - } + for (; begin != end; ++begin) + lp1(n, *begin); - nupic::Random rng; - rand_enum_range(begin, end, enum_begin, enum_end, replace, rng); - } + return n; +} - //-------------------------------------------------------------------------------- - /** - * Fills a container with values taken randomly from another container. - * - * @param a the container to fill - * @param b the container of the values to use - * @param replace whether the draw with replacements or not - */ - template - inline void rand_enum_range(C1& a, const C2& b, bool replace =false) - { - rand_enum_range(a.begin(), a.end(), b.begin(), b.end(), replace); - } - - //-------------------------------------------------------------------------------- - template - inline void rand_enum_range(C1& a, const C2& b, bool replace, RNG& rng) - { - rand_enum_range(a.begin(), a.end(), b.begin(), b.end(), replace, rng); - } - - //-------------------------------------------------------------------------------- - /** - * Fills a range with a random permutation of a ramp. - * - * @param begin - * @param end - * @param start the start value of the ramp - * @param step the step value of the ramp - */ - template - inline void - random_perm_interval(It begin, It end, - typename std::iterator_traits::value_type start, - typename std::iterator_traits::value_type step, - RNG& rng) - { - { - NTA_ASSERT(begin <= end) - << "random_perm_interval 1: Invalid input range"; - } - - ramp_range(begin, end, start, step); - std::random_shuffle(begin, end, rng); - } - - //-------------------------------------------------------------------------------- - /** - * Fills a range with a random permutation of a ramp. - * - * @param begin - * @param end - * @param start the start value of the ramp - * @param step the step value of the ramp - */ - template - inline void - random_perm_interval(It begin, It end, - typename std::iterator_traits::value_type start =0, - typename std::iterator_traits::value_type step =1) - { - { - NTA_ASSERT(begin <= end) - << "random_perm_interval 2: Invalid input range"; - } - - nupic::Random rng; - random_perm_interval(begin, end, start, step, rng); - } - - //-------------------------------------------------------------------------------- - /** - * Fills a container with a random permutation of a ramp. - * - * @param c the container to fill - * @param start the start value of the ramp - * @param step the step value of the ramp - */ - template - inline void random_perm_interval(C& c, - typename C::value_type start =0, - typename C::value_type step =1) - { - random_perm_interval(c.begin(), c.end(), start, step); - } - - //-------------------------------------------------------------------------------- - template - inline void random_perm_interval(C& c, - typename C::value_type start, - typename C::value_type step, - RNG& rng) - { - random_perm_interval(c.begin(), c.end(), start, step, rng); - } - - //-------------------------------------------------------------------------------- - /** - * Draws a random sample from a range, without replacements. 
- * - * Currently assumes that the first range is larger than the second. - * - * @param begin1 - * @param end1 - * @param begin2 - * @param end2 - */ - template - inline void random_sample(It1 begin1, It1 end1, It2 begin2, It2 end2, RNG& rng) - { - { - NTA_ASSERT(begin1 <= end1) - << "random_sample 1: Invalid value range"; - NTA_ASSERT(begin2 <= end2) - << "random_sample 1: Invalid output range"; - } - - size_t n1 = (size_t) (end1 - begin1); - std::vector perm(n1); - random_perm_interval(perm, 0, n1, rng); - for (size_t p = 0; begin2 != end2; ++begin2, ++p) - *begin2 = *(begin1 + p); - } - - //-------------------------------------------------------------------------------- - /** - * Draws a random sample from a range, without replacements. - * - * Currently assumes that the first range is larger than the second. - * - * @param begin1 - * @param end1 - * @param begin2 - * @param end2 - */ - template - inline void random_sample(It1 begin1, It1 end1, It2 begin2, It2 end2) - { - { - NTA_ASSERT(begin1 <= end1) - << "random_sample 2: Invalid value range"; - NTA_ASSERT(begin2 <= end2) - << "random_sample 2: Invalid output range"; - } - - nupic::Random rng; - random_sample(begin1, end1, begin2, end2, rng); - } - - //-------------------------------------------------------------------------------- - /** - * Draws a random sample from a container, and uses the values to initialize - * another container. - * - * @param c1 the container to initialize - * @param c2 the container from which to take the values - */ - template - inline void random_sample(const std::vector& c1, std::vector& c2) - { - random_sample(c1.begin(), c1.end(), c2.begin(), c2.end()); - } - - //-------------------------------------------------------------------------------- - template - inline void random_sample(const std::vector& c1, std::vector& c2, RNG& rng) - { - random_sample(c1.begin(), c1.end(), c2.begin(), c2.end(), rng); - } - - //-------------------------------------------------------------------------------- - /** - * Initializes a container with elements taken from the specified ramp range, - * randomly. 
- * - * @param c the container to fill - * @param size the size of the ramp - * @param start the first value of the ramp - * @param step the step of the ramp - */ - template - inline void random_sample(std::vector& c, - size_t size, - size_t start, - size_t step, - RNG& rng) - { - ramp_range(c, start, step); - std::random_shuffle(c.begin(), c.end(), rng); - } - - //-------------------------------------------------------------------------------- - template - inline void random_sample(std::vector& c, - size_t size, - size_t start =0, - size_t step =1) - { - nupic::Random rng; - random_sample(c, size, start, step, rng); - } - - //-------------------------------------------------------------------------------- - template - inline void random_sample(size_t n, std::vector& a) - { - NTA_ASSERT(0 < a.size()); - - std::vector x(n); - for (size_t i = 0; i != n; ++i) - x[i] = i; - std::random_shuffle(x.begin(), x.end()); - std::copy(x.begin(), x.begin() + a.size(), a.begin()); - } - - //-------------------------------------------------------------------------------- - template - inline void random_sample(std::vector& a) - { - random_sample(a, a.size()); - } - - //-------------------------------------------------------------------------------- - template - inline void random_sample(const std::set& a, std::vector& b) - { - NTA_ASSERT(0 < b.size()); - - std::vector aa(a.begin(), a.end()); - std::random_shuffle(aa.begin(), aa.end()); - std::copy(aa.begin(), aa.begin() + b.size(), b.begin()); - } - - //-------------------------------------------------------------------------------- - template - inline void random_binary(float proba, std::vector& x) - { - size_t threshold = (size_t) (proba * 65535); - std::fill(x.begin(), x.end(), 0); - for (size_t i = 0; i != x.size(); ++i) - x[i] = ((size_t) rand() % 65535 < threshold) ? 1 : 0; - } - - //-------------------------------------------------------------------------------- - /** - * Generates a matrix of random (index,value) pairs from [0..ncols], with - * nnzpc numbers per row, and n columns [generates a constant sparse matrix - * with constant number of non-zeros per row]. This uses a uniform distribution - * of the non-zero bits. - */ - template - inline void - random_pair_sample(size_t nrows, size_t ncols, size_t nnzpr, - std::vector >& a, - const T2& init_nz_val, - int seed =-1, - bool sorted =true) - { - { - NTA_ASSERT(0 < a.size()); - NTA_ASSERT(nnzpr <= ncols); - } - - a.resize(nrows * nnzpr); - -#if defined(NTA_ARCH_32) && defined(NTA_OS_DARWIN) - nupic::Random rng(seed == -1 ? arc4random() : seed); -#else - nupic::Random rng(seed == -1 ? rand() : seed); -#endif - - std::vector x(ncols); - for (size_t i = 0; i != ncols; ++i) - x[i] = i; - for (size_t i = 0; i != nrows; ++i) { - std::random_shuffle(x.begin(), x.end(), rng); - if (sorted) - std::sort(x.begin(), x.begin() + nnzpr); - size_t offset = i*nnzpr; - for (size_t j = 0; j != nnzpr; ++j) - a[offset + j] = std::pair(x[j], init_nz_val); - } - } - - //-------------------------------------------------------------------------------- - /** - * Generates a matrix of random (index,value) pairs from [0..ncols], with - * nnzpr numbers per row, and n columns [generates a constant sparse matrix - * with constant number of non-zeros per row]. This uses a 2D Gaussian distribution - * for the on bits of each coincidence. That is, each coincidence is seen as a - * folded 2D array, and a 2D Gaussian is used to distribute the on bits of each - * coincidence. 
- * - * Each row is seen as an image of size (ncols / rf_x) by rf_x. - * 'sigma' is the parameter of the Gaussian, which is centered at the center of - * the image. Uses a symmetric Gaussian, specified only by the location of its - * max and a single sigma parameter (no Sigma matrix). We use the symmetry of the - * 2d gaussian to simplify computations. The conditional distribution obtained - * from the 2d gaussian by fixing y is again a gaussian, with parameters than can - * be easily deduced from the original 2d gaussian. - */ - template - inline void - gaussian_2d_pair_sample(size_t nrows, size_t ncols, size_t nnzpr, size_t rf_x, - T2 sigma, - std::vector >& a, - const T2& init_nz_val, - int seed =-1, - bool sorted =true) - { - { - NTA_ASSERT(ncols % rf_x == 0); - NTA_ASSERT(nnzpr <= ncols); - NTA_ASSERT(0 < sigma); - } - - a.resize(nrows * nnzpr); - -#if defined(NTA_ARCH_32) && defined(NTA_OS_DARWIN) - nupic::Random rng(seed == -1 ? arc4random() : seed); -#else - nupic::Random rng(seed == -1 ? rand() : seed); -#endif - - size_t rf_y = ncols / rf_x; - T2 c_x = float(rf_x - 1.0) / 2.0, c_y = float(rf_y - 1.0) / 2.0; - Gaussian2D sg2d(c_x, c_y, sigma*sigma, 0, 0, sigma*sigma); - std::vector z(ncols); - - // Renormalize because we've lost some mass - // with a compact domain of definition. - float s = 0; - for (size_t j = 0; j != ncols; ++j) - s += z[j] = sg2d(j / rf_y, j % rf_y); - - for (size_t j = 0; j != ncols; ++j) - z[j] /= s; - //z[j] = 1.0f / (float)(rf_x * rf_y); - - //std::vector counts(ncols, 0); - - // TODO: argsort z so that the bigger bins come first, and it's faster - // to draw samples in the area where the pdf is higher - for (size_t i = 0; i != nrows; ++i) { - - std::set b; - - while (b.size() < nnzpr) { - T2 s = z[0], p = T2(rng.getReal64()); - size_t k = 0; - while (s < p && k < ncols-1) - s += z[++k]; - //++counts[k]; - b.insert(k); - } - - size_t offset = i*nnzpr; - auto it = b.begin(); - for (size_t j = 0; j != nnzpr; ++j, ++it) - a[offset + j] = std::pair(*it, init_nz_val); - } - - /* - for (size_t i = 0; i != counts.size(); ++i) - std::cout << counts[i] << " "; - std::cout << std::endl; - */ - } - - //-------------------------------------------------------------------------------- - template - inline void random_shuffle(std::vector& x) - { - std::random_shuffle(x.begin(), x.end()); - } - - //-------------------------------------------------------------------------------- - // generate - //-------------------------------------------------------------------------------- - /** - * Initializes a range by calling gen() repetitively. - * - * @param c the container to initialize - * @param gen the generator functor - */ - template - inline void generate(Container& c, Generator gen) - { - typename Container::iterator i = c.begin(), e = c.end(); - - for (; i != e; ++i) - *i = gen(); - } - - //-------------------------------------------------------------------------------- - // concatenate - //-------------------------------------------------------------------------------- - /** - * Concatenates multiple sub-ranges of a source range into a single range. 
- * - * @param x_begin the beginning of the source range - * @param seg_begin the beginning of the ranges that describe the - * sub-ranges (start, size) - * @param seg_end one past the end of the ranges that describe the - * sub-ranges - * @param y_begin the beginning of the concatenated range - */ - template - inline void - concatenate(InIt1 x_begin, InIt2 seg_begin, InIt2 seg_end, OutIt y_begin) - { - { - NTA_ASSERT(seg_begin <= seg_end) - << "concatenate: Invalid segment range"; - } - - for (; seg_begin != seg_end; ++seg_begin) { - InIt1 begin = x_begin + seg_begin->first; - InIt1 end = begin + seg_begin->second; - std::copy(begin, end, y_begin); - y_begin += seg_begin->second; - } - } - - //-------------------------------------------------------------------------------- - // Clip, threshold, binarize - //-------------------------------------------------------------------------------- - /** - * Clip the values in a range to be between min (included) and max (included): - * any value less than min becomes min and any value greater than max becomes - * max. - * - * @param begin - * @param end - * @param _min the minimum value - * @param _max the maximum value - */ - template - inline void clip(It begin, It end, - const typename std::iterator_traits::value_type& _min, - const typename std::iterator_traits::value_type& _max) - { - { - NTA_ASSERT(begin <= end) - << "clip: Invalid range"; - } - - while (begin != end) { - typename std::iterator_traits::value_type val = *begin; - if (val > _max) - *begin = _max; - else if (val < _min) - *begin = _min; - ++begin; - } - } - - //-------------------------------------------------------------------------------- - /** - * Clip the values in a container to be between min (included) and max (included): - * any value less than min becomes min and any value greater than max becomes - * max. - * - * @param a the container - * @param _min - * @param _max - */ - template - inline void clip(T& a, - const typename T::value_type& _min, - const typename T::value_type& _max) - { - clip(a.begin(), a.end(), _min, _max); - } - - //-------------------------------------------------------------------------------- - /** - * Threshold a range and puts the values that were not eliminated into - * another (sparse) range (index, value). - * - * @param begin - * @param end - * @param ind the beginning of the sparse indices - * @param nz the beginning of the sparse values - * @param th the threshold to use - */ - template - inline size_t threshold(InIter begin, InIter end, - OutIter1 ind, OutIter2 nz, - const typename std::iterator_traits::value_type& th, - bool above =true) - { - { - NTA_ASSERT(begin <= end) - << "threshold: Invalid range"; - } - - typedef typename std::iterator_traits::value_type value_type; - typedef size_t size_type; - - size_type n = 0; - - if (above) { - - for (InIter it = begin; it != end; ++it) { - value_type val = (value_type) *it; - if (val >= th) { - *ind = (size_type) (it - begin); - *nz = val; - ++ind; ++nz; ++n; - } - } - - } else { - - for (InIter it = begin; it != end; ++it) { - value_type val = (value_type) *it; - if (val < th) { - *ind = (size_type) (it - begin); - *nz = val; - ++ind; ++nz; ++n; - } - } - } - - return n; - } - - //-------------------------------------------------------------------------------- - /** - * Given a threshold and a dense vector x, returns another dense vectors with 1's - * where the value of x is > threshold, and 0 elsewhere. Also returns the count - * of 1's. 
- */ - template - inline nupic::UInt32 - binarize_with_threshold(nupic::Real32 threshold, - InputIterator x, InputIterator x_end, - OutputIterator y, OutputIterator y_end) - { - { - NTA_ASSERT(x_end - x == y_end - y); - } - - nupic::UInt32 count = 0; - - for (; x != x_end; ++x, ++y) - if (*x > threshold) { - *y = 1; - ++count; - } else - *y = 0; - - return count; - } - - //-------------------------------------------------------------------------------- - // INDICATORS - //-------------------------------------------------------------------------------- - - //-------------------------------------------------------------------------------- - /** - * Given a dense 2D array of 0 and 1, return a vector that has as many rows as x - * a 1 wherever x as a non-zero row, and a 0 elsewhere. I.e. the result is the - * indicator of non-zero rows. Gets fast by not scanning a row more than is - * necessary, i.e. stops as soon as a 1 is found on the row. - */ - template - inline void - nonZeroRowsIndicator_01(nupic::UInt32 nrows, nupic::UInt32 ncols, - InputIterator x, InputIterator x_end, - OutputIterator y, OutputIterator y_end) - { - { - NTA_ASSERT(0 < nrows); - NTA_ASSERT(0 < ncols); - NTA_ASSERT((nupic::UInt32)(x_end - x) == nrows * ncols); - NTA_ASSERT((nupic::UInt32)(y_end - y) == nrows); -#ifdef NTA_ASSERTION_ON - for (nupic::UInt32 i = 0; i != nrows * ncols; ++i) - NTA_ASSERT(x[i] == 0 || x[i] == 1); -#endif - } - - for (nupic::UInt32 r = 0; r != nrows; ++r, ++y) { - - InputIterator it = x + r * ncols, it_end = it + ncols; - nupic::UInt32 found = 0; - - while (it != it_end && found == 0) - found = nupic::UInt32(*it++); - - *y = found; - } - } - - //-------------------------------------------------------------------------------- - /** - * Given a dense 2D array of 0 and 1, return the number of rows that have - * at least one non-zero. Gets fast by not scanning a row more than is - * necessary, i.e. stops as soon as a 1 is found on the row. - */ - template - inline nupic::UInt32 - nNonZeroRows_01(nupic::UInt32 nrows, nupic::UInt32 ncols, - InputIterator x, InputIterator x_end) - { - { - NTA_ASSERT(0 < nrows); - NTA_ASSERT(0 < ncols); - NTA_ASSERT((nupic::UInt32)(x_end - x) == nrows * ncols); -#ifdef NTA_ASSERTION_ON - for (nupic::UInt32 i = 0; i != nrows * ncols; ++i) - NTA_ASSERT(x[i] == 0 || x[i] == 1); -#endif - } - - nupic::UInt32 count = 0; - - for (nupic::UInt32 r = 0; r != nrows; ++r) { - - InputIterator it = x + r * ncols, it_end = it + ncols; - nupic::UInt32 found = 0; - - while (it != it_end && found == 0) - found = nupic::UInt32(*it++); - - count += found; - } - - return count; - } - - //-------------------------------------------------------------------------------- - /** - * Given a dense 2D array of 0 and 1 x, return a vector that has as many cols as x - * a 1 wherever x as a non-zero col, and a 0 elsewhere. I.e. the result is the - * indicator of non-zero cols. Gets fast by not scanning a row more than is - * necessary, i.e. stops as soon as a 1 is found on the col. 
- */ - template - inline void - nonZeroColsIndicator_01(nupic::UInt32 nrows, nupic::UInt32 ncols, - InputIterator x, InputIterator x_end, - OutputIterator y, OutputIterator y_end) - { - { - NTA_ASSERT(0 < nrows); - NTA_ASSERT(0 < ncols); - NTA_ASSERT((nupic::UInt32)(x_end - x) == nrows * ncols); - NTA_ASSERT((nupic::UInt32)(y_end - y) == ncols); -#ifdef NTA_ASSERTION_ON - for (nupic::UInt32 i = 0; i != nrows * ncols; ++i) - NTA_ASSERT(x[i] == 0 || x[i] == 1); -#endif - } - - nupic::UInt32 N = nrows*ncols; - - for (nupic::UInt32 c = 0; c != ncols; ++c, ++y) { - - InputIterator it = x + c, it_end = it + N; - nupic::UInt32 found = 0; - - while (it != it_end && found == 0) { - found = nupic::UInt32(*it); - it += ncols; - } - - *y = found; - } - } - - //-------------------------------------------------------------------------------- - /** - * Given a dense 2D array of 0 and 1, return the number of columns that have - * at least one non-zero. - * Gets fast by not scanning a col more than is necessary, i.e. stops as soon as - * a 1 is found on the col. - */ - template - inline nupic::UInt32 - nNonZeroCols_01(nupic::UInt32 nrows, nupic::UInt32 ncols, - InputIterator x, InputIterator x_end) - { - { - NTA_ASSERT(0 < nrows); - NTA_ASSERT(0 < ncols); - NTA_ASSERT((nupic::UInt32)(x_end - x) == nrows * ncols); -#ifdef NTA_ASSERTION_ON - for (nupic::UInt32 i = 0; i != nrows * ncols; ++i) - NTA_ASSERT(x[i] == 0 || x[i] == 1); -#endif - } - - nupic::UInt32 count = 0; - nupic::UInt32 N = nrows*ncols; - - for (nupic::UInt32 c = 0; c != ncols; ++c) { - - InputIterator it = x + c, it_end = it + N; - nupic::UInt32 found = 0; - - while (it != it_end && found == 0) { - found = nupic::UInt32(*it); - it += ncols; - } - - count += found; - } - - return count; - } - - //-------------------------------------------------------------------------------- - // MASK - //-------------------------------------------------------------------------------- - /** - * Mask an array. 
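The column-wise counterparts walk the flat array with a stride of ncols, again bailing out at the first 1 found in each column. A standalone sketch of the counting variant:

#include <cstddef>
#include <vector>

// Count columns of a row-major nrows x ncols 0/1 matrix with at least one 1.
std::size_t n_non_zero_cols(std::size_t nrows, std::size_t ncols,
                            const std::vector<int> &x) {
  std::size_t count = 0;
  for (std::size_t c = 0; c != ncols; ++c) {
    for (std::size_t r = 0; r != nrows; ++r) {
      if (x[r * ncols + c] != 0) { // stride-ncols walk, early exit per column
        ++count;
        break;
      }
    }
  }
  return count;
}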
- */ - template - inline void mask(InIter begin, InIter end, InIter zone_begin, InIter zone_end, - const typename std::iterator_traits::value_type& v =0, - bool maskOutside =true) - { - { // Pre-conditions - NTA_ASSERT(begin <= end) - << "mask 1: Invalid range for vector"; - NTA_ASSERT(zone_begin <= zone_end) - << "mask 1: Invalid range for mask"; - NTA_ASSERT(begin <= zone_begin && zone_end <= end) - << "mask 1: Mask incompatible with vector"; - } // End pre-conditions - - if (maskOutside) { - std::fill(begin, zone_begin, v); - std::fill(zone_end, end, v); - } else { - std::fill(zone_begin, zone_end, v); - } - } - - //-------------------------------------------------------------------------------- - template - inline void mask(std::vector& x, - typename std::vector::size_type zone_begin, - typename std::vector::size_type zone_end, - const value_type& v =0, - bool maskOutside =true) - { - { // Pre-conditions - NTA_ASSERT(0 <= zone_begin && zone_begin <= zone_end && zone_end <= x.size()) - << "mask 2: Mask incompatible with vector"; - } // End pre-conditions - - mask(x.begin(), x.end(), x.begin() + zone_begin, x.begin() + zone_end, - v, maskOutside); - } - - //-------------------------------------------------------------------------------- - template - inline void mask(std::vector& x, const std::vector& mask, - bool multiplyYesNo =false, value_type2 eps =(value_type2)nupic::Epsilon) - { - { // Pre-conditions - NTA_ASSERT(x.size() == mask.size()) - << "mask 3: Need mask and vector to have same size"; - } // End pre-conditions - - typedef typename std::vector::size_type size_type; - - if (multiplyYesNo) { - for (size_type i = 0; i != x.size(); ++i) - if (!nearlyZero(mask[i]), eps) - x[i] *= (value_type1) mask[i]; - else - x[i] = (value_type1) 0; - - } else { - for (size_type i = 0; i != x.size(); ++i) - if (nearlyZero(mask[i], eps)) - x[i] = (value_type1) 0; - } - } - - //-------------------------------------------------------------------------------- - // NORMS - //-------------------------------------------------------------------------------- - /** - * A class that provides init and operator(), to be used in distance - * computations when using the Hamming (L0) norm. - */ - template - struct Lp0 - { - typedef T value_type; - - inline value_type operator()(value_type& a, value_type b) const - { - value_type inc = value_type(b < -nupic::Epsilon || b > nupic::Epsilon); - a += inc; - return inc; - } - - inline value_type root(value_type x) const { return x; } - }; - - //-------------------------------------------------------------------------------- - /** - * A class that provides init and operator(), to be used in distance - * computations when using the Manhattan (L1) norm. - */ - template - struct Lp1 - { - typedef T value_type; - - inline value_type operator()(value_type& a, value_type b) const - { - value_type inc = fabs(b); //b > 0.0 ? b : -b; - a += inc; - return inc; - } - - inline value_type root(value_type x) const { return x; } - }; - - //-------------------------------------------------------------------------------- - /** - * A class that provides square and square root methods, to be - * used in distance computations when using L2 norm. 
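The zone form of mask fills everything outside (or, with maskOutside = false, inside) a kept sub-range with a constant. A small sketch of the maskOutside = true case built directly on std::fill; the float element type is an assumption:

#include <algorithm>
#include <cstddef>
#include <vector>

// Zero everything outside [zone_begin, zone_end) of v, as with maskOutside = true.
void mask_outside(std::vector<float> &v, std::size_t zone_begin,
                  std::size_t zone_end, float fill_value = 0.0f) {
  std::fill(v.begin(), v.begin() + zone_begin, fill_value);
  std::fill(v.begin() + zone_end, v.end(), fill_value);
}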
- */ - template - struct Lp2 - { - typedef T value_type; - - nupic::Sqrt s; - - inline value_type operator()(value_type& a, value_type b) const - { - value_type inc = b * b; - a += inc; - return inc; - } - - inline value_type root(value_type x) const - { - return s(x); - } - }; - - //-------------------------------------------------------------------------------- - /** - * A class that provides power p and root p methods, to be - * used in distance computations using Lp norm. - */ - template - struct Lp - { - typedef T value_type; - - nupic::Pow pf; - - Lp(value_type p_) - : p(p_), inv_p((value_type)1.0) - { - // We allow only positive values of p for now, as this - // keeps the root function monotonically increasing, which - // results in further speed-ups. - NTA_ASSERT(p_ > (value_type)0.0) - << "NearestNeighbor::PP(): " - << "Invalid value for p: " << p_ - << " - p needs to be > 0"; - - inv_p = (value_type)1.0/p; - } - - value_type p, inv_p; - - inline value_type operator()(value_type& a, value_type b) const - { - value_type inc = pf(b > 0.0 ? b : -b, p); - a += inc; - return inc; - } - - inline value_type root(value_type x) const - { - // skipping abs, because we know we've been adding positive - // numbers when root is called when computing a norm - return pf(x, inv_p); - } - }; - - //-------------------------------------------------------------------------------- - /** - * A class that provides power p and root p methods, to be - * used in distance computations using LpMax norm. - */ - template - struct LpMax - { - typedef T value_type; - - nupic::Max m; - - inline value_type operator()(value_type& a, value_type b) const - { - value_type inc = m(a, b > 0 ? b : -b); - a = inc; - return inc; - } - - inline value_type root(value_type x) const { return x; } - }; - - //-------------------------------------------------------------------------------- - /** - * Hamming norm. - */ - template - inline typename std::iterator_traits::value_type - l0_norm(It begin, It end, bool =true) - { - { - NTA_ASSERT(begin <= end) - << "l0_norm: Invalid range"; - } - - typedef typename std::iterator_traits::value_type value_type; - - value_type n = (value_type) 0; - Lp0 lp0; - - for (; begin != end; ++begin) - lp0(n, *begin); - - return n; - } - - //-------------------------------------------------------------------------------- - /** - * Hamming norm on a container - */ - template - inline typename T::value_type l0_norm(const T& a, bool =true) - { - return l0_norm(a.begin(), a.end()); - } - - //-------------------------------------------------------------------------------- - /** - * Manhattan norm. - */ - template - inline typename std::iterator_traits::value_type - l1_norm(It begin, It end, bool =true) - { - { - NTA_ASSERT(begin <= end) - << "l1_norm: Invalid range"; - } - - typedef typename std::iterator_traits::value_type value_type; - - value_type n = (value_type) 0; - Lp1 lp1; - - for (; begin != end; ++begin) - lp1(n, *begin); - - return n; - } - - //-------------------------------------------------------------------------------- - /** - * Manhattan norm on a container. - */ - template - inline typename T::value_type l1_norm(const T& a, bool =true) - { - return l1_norm(a.begin(), a.end()); - } - - //-------------------------------------------------------------------------------- - /** - * Euclidean norm. 
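The norms are assembled from small accumulator functors: operator() folds one element into the running total and root() finishes the norm. A cut-down sketch of that pattern for the L2 case, with names invented here for illustration:

#include <cmath>
#include <vector>

// Accumulator in the spirit of Lp2: op() adds b*b into a, root() finishes the norm.
struct SquareAccumulator {
  float operator()(float &a, float b) const { a += b * b; return b * b; }
  float root(float x) const { return std::sqrt(x); }
};

float l2_norm_sketch(const std::vector<float> &v) {
  SquareAccumulator acc;
  float n = 0.0f;
  for (float x : v)
    acc(n, x);
  return acc.root(n);
}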
- */ - - //-------------------------------------------------------------------------------- - template - inline typename std::iterator_traits::value_type - l2_norm(It begin, It end, bool take_root =true) - { - { - NTA_ASSERT(begin <= end) - << "l2_norm: Invalid range"; - } - - typedef typename std::iterator_traits::value_type value_type; - value_type n = (value_type) 0; - - Lp2 lp2; - - for (; begin != end; ++begin) - lp2(n, *begin); - - if (take_root) - n = lp2.root(n); - - return n; - } - - //-------------------------------------------------------------------------------- - /** - * Euclidean norm on a container. - */ - template - inline typename T::value_type l2_norm(const T& a, bool take_root =true) - { - return l2_norm(a.begin(), a.end(), take_root); - } - - //-------------------------------------------------------------------------------- - /** - * p-norm. - */ - template - inline typename std::iterator_traits::value_type - lp_norm(typename std::iterator_traits::value_type p, - It begin, It end, bool take_root =true) - { - { - NTA_ASSERT(begin <= end) - << "lp_norm: Invalid range"; - } - - typedef typename std::iterator_traits::value_type value_type; - - value_type n = (value_type) 0; - Lp lp(p); - - for (; begin != end; ++begin) - lp(n, *begin); - - if (take_root) - n = lp.root(n); - - return n; - } - - //-------------------------------------------------------------------------------- - /** - * p-norm on a container. - */ - template - inline typename T::value_type - lp_norm(typename T::value_type p, const T& a, bool take_root =true) - { - return lp_norm(p, a.begin(), a.end(), take_root); - } - - //-------------------------------------------------------------------------------- - /** - * L inf / L max norm. - */ - template - inline typename std::iterator_traits::value_type - lmax_norm(It begin, It end, bool =true) - { - { - NTA_ASSERT(begin <= end) - << "lmax_norm: Invalid range"; - } - - typedef typename std::iterator_traits::value_type value_type; - - value_type n = (value_type) 0; - LpMax lmax; - - for (; begin != end; ++begin) - lmax(n, *begin); - - return n; - } - - //-------------------------------------------------------------------------------- - /** - * L inf / L max norm on a container. - */ - template - inline typename T::value_type lmax_norm(const T& a, bool =true) - { - return lmax_norm(a.begin(), a.end()); - } - - //-------------------------------------------------------------------------------- - /** - * Norm function. 
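For a general p > 0 the accumulator raises |x| to the p-th power and the final root applies 1/p. A direct standalone equivalent using std::pow:

#include <cmath>
#include <vector>

// p-norm of v for p > 0: (sum |x|^p)^(1/p); take_root = false returns the raw sum.
double lp_norm_sketch(double p, const std::vector<double> &v,
                      bool take_root = true) {
  double s = 0.0;
  for (double x : v)
    s += std::pow(std::fabs(x), p);
  return take_root ? std::pow(s, 1.0 / p) : s;
}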
- * - * @param p the norm - * @param begin - * @param end - * @param take_root whether to take the p-th root or not - */ - template - inline typename std::iterator_traits::value_type - norm(typename std::iterator_traits::value_type p, - It begin, It end, bool take_root =true) - { - { - NTA_ASSERT(begin <= end) - << "norm: Invalid range"; - } - - typedef typename std::iterator_traits::value_type value_type; - - if (p == (value_type) 0) - return l0_norm(begin, end); - else if (p == (value_type) 1) - return l1_norm(begin, end); - else if (p == (value_type) 2) - return l2_norm(begin, end, take_root); - else if (p == std::numeric_limits::max()) - return lmax_norm(begin, end); - else - return lp_norm(p, begin, end, take_root); - } - - //-------------------------------------------------------------------------------- - /** - */ - template - inline void - multiply_val(It begin, It end, - const typename std::iterator_traits::value_type& val) - { - { - NTA_ASSERT(begin <= end) - << "multiply_val: Invalid range"; - } - - if (val == 1.0f) - return; - - for (; begin != end; ++begin) - *begin *= val; - } - - //-------------------------------------------------------------------------------- - /** - */ - template - inline void multiply_val(T& x, const typename T::value_type& val) - { - multiply_val(x.begin(), x.end(), val); - } - - //-------------------------------------------------------------------------------- - /** - * Norm on a whole container. - */ - template - inline typename T::value_type - norm(typename T::value_type p, const T& a, bool take_root =true) - { - return norm(p, a.begin(), a.end(), take_root); - } - - //-------------------------------------------------------------------------------- - /** - * Normalize a range, according to the p norm, so that the values sum up to n. - * - * @param begin - * @param end - * @param p the norm - * @param n the value of the sum of the elements after normalization - */ - template - inline void - normalize(It begin, It end, - const typename std::iterator_traits::value_type& p =1.0, - const typename std::iterator_traits::value_type& n =1.0) - { - { - NTA_ASSERT(begin <= end) - << "normalize: Invalid input range"; - } - - typedef typename std::iterator_traits::value_type value_type; - - value_type s = (value_type) 0; - - if (p == (value_type) 0) - s = l0_norm(begin, end); - else if (p == (value_type) 1) - s = l1_norm(begin, end); - else if (p == (value_type) 2) - s = l2_norm(begin, end); - else if (p == std::numeric_limits::max()) - s = lmax_norm(begin, end); - - if (s != (value_type) 0) - multiply_val(begin, end, n/s); - } - - //-------------------------------------------------------------------------------- - /** - * Normalize a container, with p-th norm and so that values add up to n. - */ - template - inline void normalize(T& a, - const typename T::value_type& p =1.0, - const typename T::value_type& n =1.0) - { - normalize(a.begin(), a.end(), p, n); - } - - //-------------------------------------------------------------------------------- - /** - * Normalization according to LpMax: finds the max of the range, - * and then divides all the values so that the max is n. - * Makes it nicer to call normalize when using LpMax. 
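normalize computes the chosen p-norm once and then rescales every element so the norm equals n. A sketch of the common L1 case (sum of absolute values rescaled to a target):

#include <cmath>
#include <vector>

// Rescale v so that the sum of absolute values equals target (L1 normalization).
void normalize_l1(std::vector<double> &v, double target = 1.0) {
  double s = 0.0;
  for (double x : v)
    s += std::fabs(x);
  if (s != 0.0)
    for (double &x : v)
      x *= target / s;
}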
- */ - template - inline void - normalize_max(It begin, It end, - const typename std::iterator_traits::value_type& n = 1.0) - { - { - NTA_ASSERT(begin <= end) - << "normalize_max: Invalid range"; - } - - typedef typename std::iterator_traits::value_type value_type; - - normalize(begin, end, std::numeric_limits::max(), n); - } - - //-------------------------------------------------------------------------------- - /** - * Normalization according to LpMax. - */ - template - inline void normalize_max(std::vector& x, const value_type& n = 1.0) - { - normalize_max(x.begin(), x.end(), n); - } - - //-------------------------------------------------------------------------------- - /** - * Fills the container with a range of values. - */ - template - inline void generate_range(T& t, - typename T::value_type start, - typename T::value_type end, - typename T::value_type increment =1) - { - std::insert_iterator it(t, t.begin()); - - for (typename T::value_type i = start; i < end; i += increment, ++it) - *it = i; - } - - //-------------------------------------------------------------------------------- - /** - * Initializes a range with the uniform distribution. - * - * @param begin beginning of the range - * @param end one past the end of the range - * @param val the value to which the sum of the range will be equal to - */ - template - inline void - uniform_range(It begin, It end, - typename std::iterator_traits::value_type val =1) - { - { - NTA_ASSERT(begin <= end) - << "uniform_range: Invalid input range"; - } - - typedef typename std::iterator_traits::value_type value_type; - - std::fill(begin, end, (value_type) 1); - normalize(begin, end, val); - } - - //-------------------------------------------------------------------------------- - /** - * Initializes a container with the uniform distribution. - * - * @param a the container - * @param val the value for normalization - */ - template - inline void uniform_range(C& a, typename C::value_type val =1) - { - uniform_range(a.begin(), a.end(), val); - } - - //-------------------------------------------------------------------------------- - // DISTANCES - //-------------------------------------------------------------------------------- - /** - * Returns the max of the absolute values of the differences. - * - * @param begin1 - * @param end1 - * @param begin2 - */ - template - inline typename std::iterator_traits::value_type - max_abs_diff(It1 begin1, It1 end1, It2 begin2, It2 end2) - { - { - NTA_ASSERT(begin1 <= end1) - << "max_abs_diff: Invalid range 1"; - NTA_ASSERT(begin2 <= end2) - << "max_abs_diff: Invalid range 2"; - NTA_ASSERT(end1 - begin1 == end2 - begin2) - << "max_abs_diff: Ranges of different sizes"; - } - - typename std::iterator_traits::value_type d(0), val(0); - - while (begin1 != end1) { - val = *begin1 - *begin2; - val = val > 0 ? val : -val; - if (val > d) - d = val; - ++begin1; ++begin2; - } - - return d; - } - - //-------------------------------------------------------------------------------- - /** - * Returns the max of the absolute values of the differences. - * - * @param a first container - * @param b second container - */ - template - inline typename T1::value_type max_abs_diff(const T1& a, const T2& b) - { - return max_abs_diff(a.begin(), a.end(), b.begin(), b.end()); - } - - //-------------------------------------------------------------------------------- - /** - * Returns the Hamming distance of the two ranges. 
- * - * @param begin1 - * @param end1 - * @param begin2 - */ - template - inline typename std::iterator_traits::value_type - hamming_distance(It1 begin1, It1 end1, It2 begin2, It2 end2) - { - { - NTA_ASSERT(begin1 <= end1) - << "hamming_distance: Invalid range 1"; - NTA_ASSERT(begin2 <= end2) - << "hamming_distance: Invalid range 2"; - NTA_ASSERT(end1 - begin1 == end2 - begin2) - << "hamming_distance: Ranges of different sizes"; - } - - typename std::iterator_traits::value_type d(0); - - while (begin1 != end1) { - d += *begin1 != *begin2; - ++begin1; ++begin2; - } - - return d; - } - - //-------------------------------------------------------------------------------- - /** - * Returns the Hamming distance of the two containers. - * - * @param a first container - * @param b second container - */ - template - inline typename T1::value_type - hamming_distance(const T1& a, const T2& b) - { - return hamming_distance(a.begin(), a.end(), b.begin(), b.end()); - } - - //-------------------------------------------------------------------------------- - /** - * [begin1, end1) and [begin2, end2) are index encodings of binary 0/1 ranges. - */ - template - inline size_t - sparse_hamming_distance(It1 begin1, It1 end1, It2 begin2, It2 end2) - { - { - // todo: check that ranges are valid sparse indices ranges - // (increasing, no duplicate...) - NTA_ASSERT(begin1 <= end1) - << "sparse_hamming_distance: Invalid range 1"; - NTA_ASSERT(begin2 <= end2) - << "sparse_hamming_distance: Invalid range 2"; - } - - typedef size_t size_type; - - size_type d = 0; - - while (begin1 != end1 && begin2 != end2) { - if (*begin1 < *begin2) { - ++d; - ++begin1; - } else if (*begin2 < *begin1) { - ++d; - ++begin2; - } else { - ++begin1; - ++begin2; - } - } - - d += (size_type)(end1 - begin1); - d += (size_type)(end2 - begin2); - - return d; - } - - //-------------------------------------------------------------------------------- - template - inline typename T1::size_type - sparse_hamming_distance(const T1& a, const T2& b) - { - return sparse_hamming_distance(a.begin(), a.end(), b.begin(), b.end()); - } - - //-------------------------------------------------------------------------------- - /** - * Returns the Manhattan distance of the two ranges. - * - * @param begin1 - * @param end1 - * @param begin2 - */ - template - inline typename std::iterator_traits::value_type - manhattan_distance(It1 begin1, It1 end1, It2 begin2, It2 end2) - { - { - NTA_ASSERT(begin1 <= end1) - << "manhattan_distance: Invalid range 1"; - NTA_ASSERT(begin2 <= end2) - << "manhattan_distance: Invalid range 2"; - NTA_ASSERT(end1 - begin1 == end2 - begin2) - << "manhattan_distance: Ranges of different sizes"; - } - - typedef typename std::iterator_traits::value_type value_type; - - value_type d = (value_type) 0; - Lp1 lp1; - - for (; begin1 != end1; ++begin1, ++begin2) - lp1(d, *begin1 - *begin2); - - return d; - } - - //-------------------------------------------------------------------------------- - /** - * Returns the Manhattan distance of the two containers. - * - * @param a first container - * @param b second container - */ - template - inline typename T1::value_type - manhattan_distance(const T1& a, const T2& b) - { - return manhattan_distance(a.begin(), a.end(), b.begin(), b.end()); - } - - //-------------------------------------------------------------------------------- - /** - * Returns the Euclidean distance of the two ranges. 
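sparse_hamming_distance walks two sorted index lists in lock-step; an index present in only one list contributes 1, and whatever remains in either list at the end also differs by definition. A standalone sketch:

#include <cstddef>
#include <vector>

// Hamming distance between two 0/1 vectors given as sorted lists of 1-indices.
std::size_t sparse_hamming(const std::vector<std::size_t> &a,
                           const std::vector<std::size_t> &b) {
  std::size_t d = 0, i = 0, j = 0;
  while (i != a.size() && j != b.size()) {
    if (a[i] < b[j])      { ++d; ++i; }
    else if (b[j] < a[i]) { ++d; ++j; }
    else                  { ++i; ++j; } // index present in both: no difference
  }
  return d + (a.size() - i) + (b.size() - j); // leftovers each count as 1
}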
- * - * @param begin1 - * @param end1 - * @param begin2 - */ - template - inline typename std::iterator_traits::value_type - euclidean_distance(It1 begin1, It1 end1, It2 begin2, It2 end2, bool take_root =true) - { - { - NTA_ASSERT(begin1 <= end1) - << "euclidean_distance: Invalid range 1"; - NTA_ASSERT(begin2 <= end2) - << "euclidean_distance: Invalid range 2"; - NTA_ASSERT(end1 - begin1 == end2 - begin2) - << "euclidean_distance: Ranges of different sizes"; - } - - typedef typename std::iterator_traits::value_type value_type; - - value_type d = (value_type) 0; - Lp2 lp2; - - for (; begin1 != end1; ++begin1, ++begin2) - lp2(d, *begin1 - *begin2); - - if (take_root) - d = lp2.root(d); - - return d; - } - - //-------------------------------------------------------------------------------- - /** - * Returns the Euclidean distance of the two containers. - * - * @param a first container - * @param b second container - */ - template - inline typename T1::value_type - euclidean_distance(const T1& a, const T2& b, bool take_root =true) - { - return euclidean_distance(a.begin(), a.end(), b.begin(), b.end(), take_root); - } - - //-------------------------------------------------------------------------------- - /** - * Returns the Lp distance of the two ranges. - * - * @param begin1 - * @param end1 - * @param begin2 - */ - template - inline typename std::iterator_traits::value_type - lp_distance(typename std::iterator_traits::value_type p, - It1 begin1, It1 end1, It2 begin2, It2 end2, bool take_root =true) - { - { - NTA_ASSERT(begin1 <= end1) - << "lp_distance: Invalid range 1"; - NTA_ASSERT(begin2 <= end2) - << "lp_distance: Invalid range 2"; - NTA_ASSERT(end1 - begin1 == end2 - begin2) - << "lp_distance: Ranges of different sizes"; - } - - typedef typename std::iterator_traits::value_type value_type; - - value_type d = (value_type) 0; - Lp lp(p); - - for (; begin1 != end1; ++begin1, ++begin2) - lp(d, *begin1 - *begin2); - - if (take_root) - d = lp.root(d); - - return d; - } - - //-------------------------------------------------------------------------------- - /** - * Returns the Lp distance of the two containers. - * - * @param a first container - * @param b second container - */ - template - inline typename T1::value_type - lp_distance(typename T1::value_type p, - const T1& a, const T2& b, bool take_root =true) - { - return lp_distance(p, a.begin(), a.end(), b.begin(), b.end(), take_root); - } - - //-------------------------------------------------------------------------------- - /** - * Returns the Lmax distance of the two ranges. - * - * @param begin1 - * @param end1 - * @param begin2 - */ - template - inline typename std::iterator_traits::value_type - lmax_distance(It1 begin1, It1 end1, It2 begin2, It2 end2, bool =true) - { - { - NTA_ASSERT(begin1 <= end1) - << "lmax_distance: Invalid range 1"; - NTA_ASSERT(begin2 <= end2) - << "lmax_distance: Invalid range 2"; - NTA_ASSERT(end1 - begin1 == end2 - begin2) - << "lmax_distance: Ranges of different sizes"; - } +//-------------------------------------------------------------------------------- +/** + * Manhattan norm on a container. + */ +template +inline typename T::value_type l1_norm(const T &a, bool = true) { + return l1_norm(a.begin(), a.end()); +} + +//-------------------------------------------------------------------------------- +/** + * Euclidean norm. 
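The dense distances reuse the norm accumulators on element-wise differences. A plain Euclidean-distance sketch over two equal-length std::vector<double>s (the container type is an assumption):

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Euclidean (L2) distance between two equal-length vectors.
double euclidean_distance_sketch(const std::vector<double> &a,
                                 const std::vector<double> &b,
                                 bool take_root = true) {
  assert(a.size() == b.size());
  double d = 0.0;
  for (std::size_t i = 0; i != a.size(); ++i)
    d += (a[i] - b[i]) * (a[i] - b[i]);
  return take_root ? std::sqrt(d) : d;
}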
+ */ - typedef typename std::iterator_traits::value_type value_type; +//-------------------------------------------------------------------------------- +template +inline typename std::iterator_traits::value_type +l2_norm(It begin, It end, bool take_root = true) { + { NTA_ASSERT(begin <= end) << "l2_norm: Invalid range"; } - value_type d = (value_type) 0; - LpMax lmax; + typedef typename std::iterator_traits::value_type value_type; + value_type n = (value_type)0; - for (; begin1 != end1; ++begin1, ++begin2) - lmax(d, *begin1 - *begin2); + Lp2 lp2; - return d; - } + for (; begin != end; ++begin) + lp2(n, *begin); - //-------------------------------------------------------------------------------- - /** - * Returns the Lmax distance of the two containers. - * - * @param a first container - * @param b second container - */ - template - inline typename T1::value_type - lmax_distance(const T1& a, const T2& b, bool =true) - { - return lmax_distance(a.begin(), a.end(), b.begin(), b.end()); - } + if (take_root) + n = lp2.root(n); - //-------------------------------------------------------------------------------- - /** - * Returns the distance of the two ranges. - * - * @param begin1 - * @param end1 - * @param begin2 - */ - template - inline typename std::iterator_traits::value_type - distance(typename std::iterator_traits::value_type p, - It1 begin1, It1 end1, It2 begin2, It2 end2, bool take_root =true) - { - { - NTA_ASSERT(begin1 <= end1) - << "distance: Invalid range 1"; - NTA_ASSERT(begin2 <= end2) - << "distance: Invalid range 2"; - NTA_ASSERT(end1 - begin1 == end2 - begin2) - << "distance: Ranges of different sizes"; - } + return n; +} - typedef typename std::iterator_traits::value_type value_type; +//-------------------------------------------------------------------------------- +/** + * Euclidean norm on a container. + */ +template +inline typename T::value_type l2_norm(const T &a, bool take_root = true) { + return l2_norm(a.begin(), a.end(), take_root); +} + +//-------------------------------------------------------------------------------- +/** + * p-norm. + */ +template +inline typename std::iterator_traits::value_type +lp_norm(typename std::iterator_traits::value_type p, It begin, It end, + bool take_root = true) { + { NTA_ASSERT(begin <= end) << "lp_norm: Invalid range"; } - if (p == (value_type) 0) - return hamming_distance(begin1, end1, begin2); - else if (p == (value_type) 1) - return manhattan_distance(begin1, end1, begin2); - else if (p == (value_type) 2) - return euclidean_distance(begin1, end1, begin2, take_root); - else if (p == std::numeric_limits::max()) - return lmax_distance(begin1, end1, begin2); - else - return lp_distance(p, begin1, end1, begin2, take_root); - } + typedef typename std::iterator_traits::value_type value_type; - //-------------------------------------------------------------------------------- - /** - * Returns the distance of the two containers. - * - * @param a first container - * @param b second container - */ - template - inline typename T1::value_type - distance(typename T1::value_type p, const T1& a, const T2& b, bool take_root =true) - { - return distance(p, a.begin(), a.end(), b.begin(), b.end(), take_root); - } + value_type n = (value_type)0; + Lp lp(p); - //-------------------------------------------------------------------------------- - // Counting - //-------------------------------------------------------------------------------- - /** - * Counts the elements which satisfy the passed predicate in the given range. 
- */ - template - inline size_t count_if(const C& c, Predicate pred) - { - return std::count_if(c.begin(), c.end(), pred); - } + for (; begin != end; ++begin) + lp(n, *begin); - //-------------------------------------------------------------------------------- - /** - * Counts the number of zeros in the given range. - */ - template - inline size_t - count_zeros(It begin, It end, - const typename std::iterator_traits::value_type& eps =nupic::Epsilon) - { - { - NTA_ASSERT(begin <= end) - << "count_zeros: Invalid range"; - } + if (take_root) + n = lp.root(n); - typedef typename std::iterator_traits::value_type value_type; - return std::count_if(begin, end, IsNearlyZero >(eps)); - } + return n; +} - //-------------------------------------------------------------------------------- - /** - * Counts the number of zeros in the container passed in. - */ - template - inline size_t count_zeros(const C& c, const typename C::value_type& eps =nupic::Epsilon) - { - return count_zeros(c.begin, c.end(), eps); - } +//-------------------------------------------------------------------------------- +/** + * p-norm on a container. + */ +template +inline typename T::value_type lp_norm(typename T::value_type p, const T &a, + bool take_root = true) { + return lp_norm(p, a.begin(), a.end(), take_root); +} + +//-------------------------------------------------------------------------------- +/** + * L inf / L max norm. + */ +template +inline typename std::iterator_traits::value_type lmax_norm(It begin, It end, + bool = true) { + { NTA_ASSERT(begin <= end) << "lmax_norm: Invalid range"; } - //-------------------------------------------------------------------------------- - /** - * Count the number of ones in the given range. - */ - template - inline size_t - count_ones(It begin, It end, - const typename std::iterator_traits::value_type& eps =nupic::Epsilon) - { - { - NTA_ASSERT(begin <= end) - << "count_ones: Invalid range"; - } + typedef typename std::iterator_traits::value_type value_type; - typedef typename std::iterator_traits::value_type value_type; - return std::count_if(begin, end, IsNearlyZero >(eps)); - } + value_type n = (value_type)0; + LpMax lmax; - //-------------------------------------------------------------------------------- - /** - * Count the number of ones in the container passed in. - */ - template - inline size_t count_ones(const C& c, const typename C::value_type& eps =nupic::Epsilon) - { - return count_ones(c.begin(), c.end(), eps); - } + for (; begin != end; ++begin) + lmax(n, *begin); - //-------------------------------------------------------------------------------- - /** - * Counts the number of values greater than a given threshold in a given range. - * - * Asm SSE is many times faster than C++ (almost 10X), and C++ is 10X faster than - * numpy (some_array > threshold).sum(). The asm code doesn't have branchs, which - * is probably very good for the CPU front-end. - * - * This is not as general as a count_gt that would be parameterized on the type - * of the elements in the range, and it requires passing in a Python arrays - * that are .astype(float32). - * - * Doesn't work on win32. - */ - inline nupic::UInt32 - count_gt(nupic::Real32* begin, nupic::Real32* end, nupic::Real32 threshold) - { - NTA_ASSERT(begin <= end); + return n; +} - // Need this, because the asm syntax is not correct for win32, - // we simply can't compile the code as is on win32. +//-------------------------------------------------------------------------------- +/** + * L inf / L max norm on a container. 
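Zero/one counting here is tolerance based: an element counts as zero when it lies within eps of zero rather than being exactly equal. A sketch with std::count_if and a lambda standing in for the nearly-zero predicate:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Number of elements within eps of zero, mirroring the epsilon-based predicate.
std::size_t count_near_zeros(const std::vector<float> &v, float eps = 1e-6f) {
  return std::count_if(v.begin(), v.end(),
                       [eps](float x) { return std::fabs(x) <= eps; });
}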
+ */ +template +inline typename T::value_type lmax_norm(const T &a, bool = true) { + return lmax_norm(a.begin(), a.end()); +} + +//-------------------------------------------------------------------------------- +/** + * Norm function. + * + * @param p the norm + * @param begin + * @param end + * @param take_root whether to take the p-th root or not + */ +template +inline typename std::iterator_traits::value_type +norm(typename std::iterator_traits::value_type p, It begin, It end, + bool take_root = true) { + { NTA_ASSERT(begin <= end) << "norm: Invalid range"; } + + typedef typename std::iterator_traits::value_type value_type; + + if (p == (value_type)0) + return l0_norm(begin, end); + else if (p == (value_type)1) + return l1_norm(begin, end); + else if (p == (value_type)2) + return l2_norm(begin, end, take_root); + else if (p == std::numeric_limits::max()) + return lmax_norm(begin, end); + else + return lp_norm(p, begin, end, take_root); +} + +//-------------------------------------------------------------------------------- +/** + */ +template +inline void +multiply_val(It begin, It end, + const typename std::iterator_traits::value_type &val) { + { NTA_ASSERT(begin <= end) << "multiply_val: Invalid range"; } - // Need this, because even on darwin32 (which is darwin86), some older machines might - // not have the right SSE instructions. - if (SSE_LEVEL >= 3) { + if (val == 1.0f) + return; - // Compute offsets into array [begin..end): - // start is the first 4 bytes aligned address (to start movaps) - // n0 is the number of floats before we reach start and can use parallel - // xmm operations - // n1 is the number floats we can process in parallel with xmm - // n2 is the number of "stragglers" what we will have to do one by one ( < 4) - nupic::Real32 count = 0; - NTA_UIntPtr x_addr = (NTA_UIntPtr) begin; // 8 bytes on 64 bits platforms - nupic::Real32* start = (x_addr % 16 == 0) ? begin : (nupic::Real32*) (16*(x_addr/16+1)); - int n0 = (int)(start - begin); - int n1 = 4 * ((end - start) / 4); - int n2 = (int)(end - start - n1); + for (; begin != end; ++begin) + *begin *= val; +} +//-------------------------------------------------------------------------------- +/** + */ +template +inline void multiply_val(T &x, const typename T::value_type &val) { + multiply_val(x.begin(), x.end(), val); +} + +//-------------------------------------------------------------------------------- +/** + * Norm on a whole container. + */ +template +inline typename T::value_type norm(typename T::value_type p, const T &a, + bool take_root = true) { + return norm(p, a.begin(), a.end(), take_root); +} + +//-------------------------------------------------------------------------------- +/** + * Normalize a range, according to the p norm, so that the values sum up to n. 
+ * + * @param begin + * @param end + * @param p the norm + * @param n the value of the sum of the elements after normalization + */ +template +inline void +normalize(It begin, It end, + const typename std::iterator_traits::value_type &p = 1.0, + const typename std::iterator_traits::value_type &n = 1.0) { + { NTA_ASSERT(begin <= end) << "normalize: Invalid input range"; } + + typedef typename std::iterator_traits::value_type value_type; + + value_type s = (value_type)0; + + if (p == (value_type)0) + s = l0_norm(begin, end); + else if (p == (value_type)1) + s = l1_norm(begin, end); + else if (p == (value_type)2) + s = l2_norm(begin, end); + else if (p == std::numeric_limits::max()) + s = lmax_norm(begin, end); + + if (s != (value_type)0) + multiply_val(begin, end, n / s); +} + +//-------------------------------------------------------------------------------- +/** + * Normalize a container, with p-th norm and so that values add up to n. + */ +template +inline void normalize(T &a, const typename T::value_type &p = 1.0, + const typename T::value_type &n = 1.0) { + normalize(a.begin(), a.end(), p, n); +} + +//-------------------------------------------------------------------------------- +/** + * Normalization according to LpMax: finds the max of the range, + * and then divides all the values so that the max is n. + * Makes it nicer to call normalize when using LpMax. + */ +template +inline void +normalize_max(It begin, It end, + const typename std::iterator_traits::value_type &n = 1.0) { + { NTA_ASSERT(begin <= end) << "normalize_max: Invalid range"; } -#if defined(NTA_ARCH_64) && !defined(NTA_OS_WINDOWS) + typedef typename std::iterator_traits::value_type value_type; - #if defined(NTA_OS_DARWIN) + normalize(begin, end, std::numeric_limits::max(), n); +} - // DO NOT CHANGE THESE NEXT TWO LINES, OTHERWISE THE ASM CODE BELOW WILL BREAK. - // 'localThreshold' MUST BE STATIC!! Must always assign threshold to localThreshold also! - static float localThreshold; - localThreshold = threshold; - #endif +//-------------------------------------------------------------------------------- +/** + * Normalization according to LpMax. + */ +template +inline void normalize_max(std::vector &x, + const value_type &n = 1.0) { + normalize_max(x.begin(), x.end(), n); +} + +//-------------------------------------------------------------------------------- +/** + * Fills the container with a range of values. + */ +template +inline void generate_range(T &t, typename T::value_type start, + typename T::value_type end, + typename T::value_type increment = 1) { + std::insert_iterator it(t, t.begin()); + + for (typename T::value_type i = start; i < end; i += increment, ++it) + *it = i; +} + +//-------------------------------------------------------------------------------- +/** + * Initializes a range with the uniform distribution. + * + * @param begin beginning of the range + * @param end one past the end of the range + * @param val the value to which the sum of the range will be equal to + */ +template +inline void +uniform_range(It begin, It end, + typename std::iterator_traits::value_type val = 1) { + { NTA_ASSERT(begin <= end) << "uniform_range: Invalid input range"; } - __asm__ __volatile__( + typedef typename std::iterator_traits::value_type value_type; - #if defined(NTA_OS_DARWIN) - // We need to access localThreshold by it's mangled name here because g++ and - // clang++ do things differently on OS X. 
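normalize_max delegates to normalize with p = max, i.e. it divides by the largest absolute value so that the new maximum becomes n. A direct sketch of that effect:

#include <algorithm>
#include <cmath>
#include <vector>

// Divide all values by the current Lmax norm so the largest magnitude equals target.
void normalize_by_max(std::vector<float> &v, float target = 1.0f) {
  float m = 0.0f;
  for (float x : v)
    m = std::max(m, std::fabs(x));
  if (m != 0.0f)
    for (float &x : v)
      x *= target / m;
}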
They clobber eax by the time they - // get here and 'threshold' is not properly loaded into rax by the constraint - // at the end of this asm snippet. This means we need to use 'rip relative' - // addressing to access a 'static' variable (a global would also work) here - // and then load it into eax manually. Then things work fine. - "movq __ZZN5nupic8count_gtEPfS0_fE14localThreshold@GOTPCREL(%%rip), %%r11\n\t" - "movl (%%r11), %%eax\n\t" - #endif - - "subq $16, %%rsp\n\t" // allocate 4 floats on stack - "movl %%eax, (%%rsp)\n\t" // copy threshold to 4 locations - "movl %%eax, 4(%%rsp)\n\t" // on stack: we want threshold - "movl %%eax, 8(%%rsp)\n\t" // to be filling xmm1 and xmm - "movl %%eax, 12(%%rsp)\n\t" // (operate on 4 floats at a time) - "movaps (%%rsp), %%xmm1\n\t" // move 4 thresholds into xmm1 - "movaps %%xmm1, %%xmm2\n\t" // copy 4 thresholds to xmm2 - - "movl $0x3f800000, (%%rsp)\n\t" // $0x3f800000 = (float) 1.0 - "movl $0x3f800000, 4(%%rsp)\n\t" // we want to have that constant - "movl $0x3f800000, 8(%%rsp)\n\t" // 8 times, in xmm3 and xmm4, - "movl $0x3f800000, 12(%%rsp)\n\t"// since the xmm4 registers allow - "movaps (%%rsp), %%xmm3\n\t" // us to operate on 4 floats at - "movaps (%%rsp), %%xmm4\n\t" // a time - - "addq $16, %%rsp\n\t" // deallocate 4 floats on stack - - "xorps %%xmm5, %%xmm5\n\t" // set xmm5 to 0 - - // Loop over individual floats till we reach the right alignment - // that was computed in n0. If we don't start handling 4 floats - // at a time with SSE on a 4 bytes boundary, we get a crash - // in movaps (here, we use only movss, moving only 1 float at a - // time). - "0:\n\t" - "test %%rcx, %%rcx\n\t" // if n0 == 0, jump to next loop - "jz 1f\n\t" - - "movss (%%rsi), %%xmm0\n\t" // move a single float to xmm0 - "cmpss $1, %%xmm0, %%xmm1\n\t" // compare to threshold - "andps %%xmm1, %%xmm3\n\t" // and with all 1s - "addss %%xmm3, %%xmm5\n\t" // add result to xmm5 (=count!) - "movaps %%xmm2, %%xmm1\n\t" // restore threshold in xmm1 - "movaps %%xmm4, %%xmm3\n\t" // restore all 1s in xmm3 - "addq $4, %%rsi\n\t" // move to next float (4 bytes) - "decq %%rcx\n\t" // decrement rcx, which started at n0 - "ja 0b\n\t" // jump if not done yet - - // Loop over 4 floats at a time: this time, we have reached - // the proper alignment for movaps, so we can operate in parallel - // on 4 floats at a time. The code is the same as the previous loop - // except that the "ss" instructions are now "ps" instructions. - "1:\n\t" - "test %%rdx, %%rdx\n\t" - "jz 2f\n\t" - - "movaps (%%rsi), %%xmm0\n\t" // note movaps, not movss - "cmpps $1, %%xmm0, %%xmm1\n\t" - "andps %%xmm1, %%xmm3\n\t" - "addps %%xmm3, %%xmm5\n\t" // addps, not addss - "movaps %%xmm2, %%xmm1\n\t" - "movaps %%xmm4, %%xmm3\n\t" - "addq $16, %%rsi\n\t" // jump over 4 floats - "subq $4, %%rdx\n\t" // decrement rdx (n1) by 4 - "ja 1b\n\t" - - // Tally up count so far into last float of xmm5: we were - // doing operations in parallels on the 4 floats in the xmm - // registers, resulting in 4 partial counts in xmm5. - "xorps %%xmm0, %%xmm0\n\t" - "haddps %%xmm0, %%xmm5\n\t" - "haddps %%xmm0, %%xmm5\n\t" - - // Last loop, for stragglers in case the array is not evenly - // divisible by 4. We are back to operating on a single float - // at a time, using movss and addss. 
- "2:\n\t" - "test %%rdi, %%rdi\n\t" - "jz 3f\n\t" - - "movss (%%rsi), %%xmm0\n\t" - "cmpss $1, %%xmm0, %%xmm1\n\t" - "andps %%xmm1, %%xmm3\n\t" - "addss %%xmm3, %%xmm5\n\t" - "movaps %%xmm2, %%xmm1\n\t" - "movaps %%xmm4, %%xmm3\n\t" - "addq $4, %%rsi\n\t" - "decq %%rdi\n\t" - "ja 0b\n\t" - - // Push result from xmm5 to variable count in memory. - "3:\n\t" - "movss %%xmm5, %0\n\t" - - : "=m" (count) - : "S" (begin), "a" (threshold), "c" (n0), "d" (n1), "D" (n2) - : - ); - - return (int) count; + std::fill(begin, end, (value_type)1); + normalize(begin, end, val); +} -#else - return std::count_if(begin, end, std::bind2nd(std::greater(), threshold)); -#endif - } else { - return std::count_if(begin, end, std::bind2nd(std::greater(), threshold)); - } +//-------------------------------------------------------------------------------- +/** + * Initializes a container with the uniform distribution. + * + * @param a the container + * @param val the value for normalization + */ +template +inline void uniform_range(C &a, typename C::value_type val = 1) { + uniform_range(a.begin(), a.end(), val); +} + +//-------------------------------------------------------------------------------- +// DISTANCES +//-------------------------------------------------------------------------------- +/** + * Returns the max of the absolute values of the differences. + * + * @param begin1 + * @param end1 + * @param begin2 + */ +template +inline typename std::iterator_traits::value_type +max_abs_diff(It1 begin1, It1 end1, It2 begin2, It2 end2) { + { + NTA_ASSERT(begin1 <= end1) << "max_abs_diff: Invalid range 1"; + NTA_ASSERT(begin2 <= end2) << "max_abs_diff: Invalid range 2"; + NTA_ASSERT(end1 - begin1 == end2 - begin2) + << "max_abs_diff: Ranges of different sizes"; } - //-------------------------------------------------------------------------------- - /** - * Counts the number of values greater than or equal to a given threshold in a - * given range. - * - * This is not as general as a count_gt that would be parameterized on the type - * of the elements in the range, and it requires passing in a Python arrays - * that are .astype(float32). - * - */ - inline nupic::UInt32 - count_gte(nupic::Real32* begin, nupic::Real32* end, nupic::Real32 threshold) - { - NTA_ASSERT(begin <= end); + typename std::iterator_traits::value_type d(0), val(0); - return std::count_if(begin, end, - std::bind2nd(std::greater_equal(), - threshold)); + while (begin1 != end1) { + val = *begin1 - *begin2; + val = val > 0 ? val : -val; + if (val > d) + d = val; + ++begin1; + ++begin2; } + return d; +} - //-------------------------------------------------------------------------------- - /** - * Counts the number of non-zeros in a vector. - */ - inline size_t count_non_zeros(nupic::Real32* begin, nupic::Real32* end) +//-------------------------------------------------------------------------------- +/** + * Returns the max of the absolute values of the differences. + * + * @param a first container + * @param b second container + */ +template +inline typename T1::value_type max_abs_diff(const T1 &a, const T2 &b) { + return max_abs_diff(a.begin(), a.end(), b.begin(), b.end()); +} + +//-------------------------------------------------------------------------------- +/** + * Returns the Hamming distance of the two ranges. 
+ * + * @param begin1 + * @param end1 + * @param begin2 + */ +template +inline typename std::iterator_traits::value_type +hamming_distance(It1 begin1, It1 end1, It2 begin2, It2 end2) { { - NTA_ASSERT(begin <= end); - return count_gt(begin, end, 0); + NTA_ASSERT(begin1 <= end1) << "hamming_distance: Invalid range 1"; + NTA_ASSERT(begin2 <= end2) << "hamming_distance: Invalid range 2"; + NTA_ASSERT(end1 - begin1 == end2 - begin2) + << "hamming_distance: Ranges of different sizes"; } - //-------------------------------------------------------------------------------- - /** - * Counts the number of non-zeros in a vector. - * Doesn't work with vector. - */ - template - inline size_t count_non_zeros(const std::vector& x) - { - NTA_ASSERT(sizeof(T) == 4); - nupic::Real32* begin = (nupic::Real32*) &x[0]; - nupic::Real32* end = begin + x.size(); - return count_gt(begin, end, 0); - } + typename std::iterator_traits::value_type d(0); - //-------------------------------------------------------------------------------- - /** - * TODO: Use SSE. Maybe requires having our own vector so that we can avoid - * the shenanigans with the bit references and iterators. - */ - template <> - inline size_t count_non_zeros(const std::vector& x) - { - size_t count = 0; - for (size_t i = 0; i != x.size(); ++i) - count += x[i]; - return count; + while (begin1 != end1) { + d += *begin1 != *begin2; + ++begin1; + ++begin2; } - //-------------------------------------------------------------------------------- - template - inline size_t count_non_zeros(const std::vector >& x) - { - size_t count = 0; - for (size_t i = 0; i != x.size(); ++i) - if (! is_zero(x[i])) - ++count; - return count; - } + return d; +} - //-------------------------------------------------------------------------------- - /** - * Counts the number of values less than a given threshold in a given range. - */ - template - inline size_t - count_lt(It begin, It end, const typename std::iterator_traits::value_type& thres) +//-------------------------------------------------------------------------------- +/** + * Returns the Hamming distance of the two containers. + * + * @param a first container + * @param b second container + */ +template +inline typename T1::value_type hamming_distance(const T1 &a, const T2 &b) { + return hamming_distance(a.begin(), a.end(), b.begin(), b.end()); +} + +//-------------------------------------------------------------------------------- +/** + * [begin1, end1) and [begin2, end2) are index encodings of binary 0/1 ranges. + */ +template +inline size_t sparse_hamming_distance(It1 begin1, It1 end1, It2 begin2, + It2 end2) { { - typedef typename std::iterator_traits::value_type value_type; - return std::count_if(begin, end, std::bind2nd(std::less(), thres)); + // todo: check that ranges are valid sparse indices ranges + // (increasing, no duplicate...) 
+ NTA_ASSERT(begin1 <= end1) << "sparse_hamming_distance: Invalid range 1"; + NTA_ASSERT(begin2 <= end2) << "sparse_hamming_distance: Invalid range 2"; } - //-------------------------------------------------------------------------------- - // Rounding - //-------------------------------------------------------------------------------- - /** - */ - template - inline void - round_01(It begin, It end, - const typename std::iterator_traits::value_type& threshold =.5) - { - { - NTA_ASSERT(begin <= end) - << "round_01: Invalid range"; - } + typedef size_t size_type; - typename std::iterator_traits::value_type val; + size_type d = 0; - while (begin != end) { - val = *begin; - if (val >= threshold) - val = 1; - else - val = 0; - *begin = val; - ++begin; + while (begin1 != end1 && begin2 != end2) { + if (*begin1 < *begin2) { + ++d; + ++begin1; + } else if (*begin2 < *begin1) { + ++d; + ++begin2; + } else { + ++begin1; + ++begin2; } } - //-------------------------------------------------------------------------------- - /** - */ - template - inline void round_01(T& a, const typename T::value_type& threshold =.5) + d += (size_type)(end1 - begin1); + d += (size_type)(end2 - begin2); + + return d; +} + +//-------------------------------------------------------------------------------- +template +inline typename T1::size_type sparse_hamming_distance(const T1 &a, + const T2 &b) { + return sparse_hamming_distance(a.begin(), a.end(), b.begin(), b.end()); +} + +//-------------------------------------------------------------------------------- +/** + * Returns the Manhattan distance of the two ranges. + * + * @param begin1 + * @param end1 + * @param begin2 + */ +template +inline typename std::iterator_traits::value_type +manhattan_distance(It1 begin1, It1 end1, It2 begin2, It2 end2) { { - round_01(a.begin(), a.end(), threshold); + NTA_ASSERT(begin1 <= end1) << "manhattan_distance: Invalid range 1"; + NTA_ASSERT(begin2 <= end2) << "manhattan_distance: Invalid range 2"; + NTA_ASSERT(end1 - begin1 == end2 - begin2) + << "manhattan_distance: Ranges of different sizes"; } - //-------------------------------------------------------------------------------- - // Addition... - //-------------------------------------------------------------------------------- - /** - * Computes the sum of the elements in a range. - * - * Note: a previous version used veclib on Mac's and vDSP. vDSP is much faster - * than C++, even optimized by gcc, but for now this works - * only with float (rather than double), and only on darwin86. With these - * restrictions the speed-up is usually better than 5X over optimized C++. - * vDSP also handles unaligned vectors correctly, and has good performance - * also when the vectors are small, not just when they are big. - */ - inline nupic::Real32 sum(nupic::Real32* begin, nupic::Real32* end) - { - { - NTA_ASSERT(begin <= end) - << "sum: Invalid range"; - } + typedef typename std::iterator_traits::value_type value_type; - nupic::Real32 result = 0; - for (; begin != end; ++begin) - result += *begin; - return result; + value_type d = (value_type)0; + Lp1 lp1; - } + for (; begin1 != end1; ++begin1, ++begin2) + lp1(d, *begin1 - *begin2); - //-------------------------------------------------------------------------------- - /** - * Compute the sum of a whole container. - * Here we revert to C++, which is going to be slower than the preceding function, - * but it will work for a container of anything, that container not necessarily - * being a contiguous vector of numbers. 
- */ - template - inline typename T::value_type sum(const T& x) - { - typename T::value_type result = 0; - typename T::const_iterator it; - for (it = x.begin(); it != x.end(); ++it) - result += *it; - return result; - } + return d; +} - //-------------------------------------------------------------------------------- - template - inline void sum(const std::vector& a, const std::vector& b, - size_t begin, size_t end, std::vector& c) - { - for (size_t i = begin; i != end; ++i) - c[i] = a[i] + b[i]; +//-------------------------------------------------------------------------------- +/** + * Returns the Manhattan distance of the two containers. + * + * @param a first container + * @param b second container + */ +template +inline typename T1::value_type manhattan_distance(const T1 &a, const T2 &b) { + return manhattan_distance(a.begin(), a.end(), b.begin(), b.end()); +} + +//-------------------------------------------------------------------------------- +/** + * Returns the Euclidean distance of the two ranges. + * + * @param begin1 + * @param end1 + * @param begin2 + */ +template +inline typename std::iterator_traits::value_type +euclidean_distance(It1 begin1, It1 end1, It2 begin2, It2 end2, + bool take_root = true) { + { + NTA_ASSERT(begin1 <= end1) << "euclidean_distance: Invalid range 1"; + NTA_ASSERT(begin2 <= end2) << "euclidean_distance: Invalid range 2"; + NTA_ASSERT(end1 - begin1 == end2 - begin2) + << "euclidean_distance: Ranges of different sizes"; } - //-------------------------------------------------------------------------------- - /** - * Computes the product of the elements in a range. - */ - template - inline typename std::iterator_traits::value_type product(It begin, It end) - { - { - NTA_ASSERT(begin <= end) - << "product: Invalid range"; - } + typedef typename std::iterator_traits::value_type value_type; - typename std::iterator_traits::value_type p(1); + value_type d = (value_type)0; + Lp2 lp2; - for (; begin != end; ++begin) - p *= *begin; + for (; begin1 != end1; ++begin1, ++begin2) + lp2(d, *begin1 - *begin2); - return p; - } + if (take_root) + d = lp2.root(d); - //-------------------------------------------------------------------------------- - /** - * Computes the product of all the elements in a container. - */ - template - inline typename T::value_type product(const T& x) - { - return product(x.begin(), x.end()); + return d; +} + +//-------------------------------------------------------------------------------- +/** + * Returns the Euclidean distance of the two containers. + * + * @param a first container + * @param b second container + */ +template +inline typename T1::value_type euclidean_distance(const T1 &a, const T2 &b, + bool take_root = true) { + return euclidean_distance(a.begin(), a.end(), b.begin(), b.end(), take_root); +} + +//-------------------------------------------------------------------------------- +/** + * Returns the Lp distance of the two ranges. 
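The generic sum and product fallbacks are plain left folds; std::accumulate expresses both directly, as in this sketch:

#include <functional>
#include <numeric>
#include <vector>

// Sum and product of a container via std::accumulate.
double sum_of(const std::vector<double> &v) {
  return std::accumulate(v.begin(), v.end(), 0.0);
}

double product_of(const std::vector<double> &v) {
  return std::accumulate(v.begin(), v.end(), 1.0, std::multiplies<double>());
}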
+ * + * @param begin1 + * @param end1 + * @param begin2 + */ +template +inline typename std::iterator_traits::value_type +lp_distance(typename std::iterator_traits::value_type p, It1 begin1, + It1 end1, It2 begin2, It2 end2, bool take_root = true) { + { + NTA_ASSERT(begin1 <= end1) << "lp_distance: Invalid range 1"; + NTA_ASSERT(begin2 <= end2) << "lp_distance: Invalid range 2"; + NTA_ASSERT(end1 - begin1 == end2 - begin2) + << "lp_distance: Ranges of different sizes"; } - //-------------------------------------------------------------------------------- - /** - */ - template - inline void add_val(It begin, It end, - const typename std::iterator_traits::value_type& val) - { - { - NTA_ASSERT(begin <= end) - << "add_val: Invalid range"; - } + typedef typename std::iterator_traits::value_type value_type; - if (val == 0.0f) - return; + value_type d = (value_type)0; + Lp lp(p); - for (; begin != end; ++begin) - *begin += val; - } + for (; begin1 != end1; ++begin1, ++begin2) + lp(d, *begin1 - *begin2); - //-------------------------------------------------------------------------------- - /** - */ - template - inline void add_val(T& x, const typename T::value_type& val) - { - add_val(x.begin(), x.end(), val); - } + if (take_root) + d = lp.root(d); - //-------------------------------------------------------------------------------- - /** - */ - template - inline void subtract_val(It begin, It end, - const typename std::iterator_traits::value_type& val) - { - add_val(begin, end, -val); - } + return d; +} - //-------------------------------------------------------------------------------- - /** - */ - template - inline void subtract_val(T& x, const typename T::value_type& val) +//-------------------------------------------------------------------------------- +/** + * Returns the Lp distance of the two containers. + * + * @param a first container + * @param b second container + */ +template +inline typename T1::value_type lp_distance(typename T1::value_type p, + const T1 &a, const T2 &b, + bool take_root = true) { + return lp_distance(p, a.begin(), a.end(), b.begin(), b.end(), take_root); +} + +//-------------------------------------------------------------------------------- +/** + * Returns the Lmax distance of the two ranges. 
+ * + * @param begin1 + * @param end1 + * @param begin2 + */ +template +inline typename std::iterator_traits::value_type +lmax_distance(It1 begin1, It1 end1, It2 begin2, It2 end2, bool = true) { { - subtract_val(x.begin(), x.end(), val); + NTA_ASSERT(begin1 <= end1) << "lmax_distance: Invalid range 1"; + NTA_ASSERT(begin2 <= end2) << "lmax_distance: Invalid range 2"; + NTA_ASSERT(end1 - begin1 == end2 - begin2) + << "lmax_distance: Ranges of different sizes"; } - //-------------------------------------------------------------------------------- - /** - */ - template - inline void negate(It begin, It end) - { - { - NTA_ASSERT(begin <= end) - << "negate: Invalid range"; - } + typedef typename std::iterator_traits::value_type value_type; - for (; begin != end; ++begin) - *begin = -*begin; - } + value_type d = (value_type)0; + LpMax lmax; - //-------------------------------------------------------------------------------- - /** - */ - template - inline void negate(T& x) - { - negate(x.begin(), x.end()); - } + for (; begin1 != end1; ++begin1, ++begin2) + lmax(d, *begin1 - *begin2); - //-------------------------------------------------------------------------------- - /** - */ - template - inline void - divide_val(It begin, It end, - const typename std::iterator_traits::value_type& val) - { - { - NTA_ASSERT(begin <= end) - << "divide_val: Invalid range"; - NTA_ASSERT(val != 0) - << "divide_val: Division by zero"; - } + return d; +} - multiply_val(begin, end, 1.0f/val); +//-------------------------------------------------------------------------------- +/** + * Returns the Lmax distance of the two containers. + * + * @param a first container + * @param b second container + */ +template +inline typename T1::value_type lmax_distance(const T1 &a, const T2 &b, + bool = true) { + return lmax_distance(a.begin(), a.end(), b.begin(), b.end()); +} + +//-------------------------------------------------------------------------------- +/** + * Returns the distance of the two ranges. + * + * @param begin1 + * @param end1 + * @param begin2 + */ +template +inline typename std::iterator_traits::value_type +distance(typename std::iterator_traits::value_type p, It1 begin1, It1 end1, + It2 begin2, It2 end2, bool take_root = true) { + { + NTA_ASSERT(begin1 <= end1) << "distance: Invalid range 1"; + NTA_ASSERT(begin2 <= end2) << "distance: Invalid range 2"; + NTA_ASSERT(end1 - begin1 == end2 - begin2) + << "distance: Ranges of different sizes"; } - //-------------------------------------------------------------------------------- - // TODO: what if val == 0? 
- /** - */ - template - inline void divide_val(T& x, const typename T::value_type& val) - { - divide_val(x.begin(), x.end(), val); - } + typedef typename std::iterator_traits::value_type value_type; - //-------------------------------------------------------------------------------- - /** - */ - template - inline void add(It1 begin1, It1 end1, It2 begin2, It2 end2) - { - { - NTA_ASSERT(begin1 <= end1) - << "add: Invalid range"; - NTA_ASSERT(end1 - begin1 <= end2 - begin2) - << "add: Incompatible ranges"; - } + if (p == (value_type)0) + return hamming_distance(begin1, end1, begin2); + else if (p == (value_type)1) + return manhattan_distance(begin1, end1, begin2); + else if (p == (value_type)2) + return euclidean_distance(begin1, end1, begin2, take_root); + else if (p == std::numeric_limits::max()) + return lmax_distance(begin1, end1, begin2); + else + return lp_distance(p, begin1, end1, begin2, take_root); +} - for (; begin1 != end1; ++begin1, ++begin2) - *begin1 += *begin2; - } +//-------------------------------------------------------------------------------- +/** + * Returns the distance of the two containers. + * + * @param a first container + * @param b second container + */ +template +inline typename T1::value_type distance(typename T1::value_type p, const T1 &a, + const T2 &b, bool take_root = true) { + return distance(p, a.begin(), a.end(), b.begin(), b.end(), take_root); +} + +//-------------------------------------------------------------------------------- +// Counting +//-------------------------------------------------------------------------------- +/** + * Counts the elements which satisfy the passed predicate in the given range. + */ +template +inline size_t count_if(const C &c, Predicate pred) { + return std::count_if(c.begin(), c.end(), pred); +} + +//-------------------------------------------------------------------------------- +/** + * Counts the number of zeros in the given range. + */ +template +inline size_t count_zeros( + It begin, It end, + const typename std::iterator_traits::value_type &eps = nupic::Epsilon) { + { NTA_ASSERT(begin <= end) << "count_zeros: Invalid range"; } + + typedef typename std::iterator_traits::value_type value_type; + return std::count_if(begin, end, + IsNearlyZero>(eps)); +} + +//-------------------------------------------------------------------------------- +/** + * Counts the number of zeros in the container passed in. + */ +template +inline size_t count_zeros(const C &c, + const typename C::value_type &eps = nupic::Epsilon) { + return count_zeros(c.begin, c.end(), eps); +} + +//-------------------------------------------------------------------------------- +/** + * Count the number of ones in the given range. + */ +template +inline size_t count_ones( + It begin, It end, + const typename std::iterator_traits::value_type &eps = nupic::Epsilon) { + { NTA_ASSERT(begin <= end) << "count_ones: Invalid range"; } + + typedef typename std::iterator_traits::value_type value_type; + return std::count_if(begin, end, + IsNearlyZero>(eps)); +} + +//-------------------------------------------------------------------------------- +/** + * Count the number of ones in the container passed in. + */ +template +inline size_t count_ones(const C &c, + const typename C::value_type &eps = nupic::Epsilon) { + return count_ones(c.begin(), c.end(), eps); +} + +//-------------------------------------------------------------------------------- +/** + * Counts the number of values greater than a given threshold in a given range. 
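+ * Semantically this is simply "count the elements v with v > threshold",
+ * i.e. what the non-SSE fallback below computes with std::count_if and
+ * std::greater; the hand-written SSE path exists only for speed.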
+ *
+ * Asm SSE is many times faster than C++ (almost 10X), and C++ is 10X faster
+ * than numpy (some_array > threshold).sum(). The asm code doesn't have branches,
+ * which is probably very good for the CPU front-end.
+ *
+ * This is not as general as a count_gt that would be parameterized on the type
+ * of the elements in the range, and it requires passing in Python arrays
+ * that are .astype(float32).
+ *
+ * Doesn't work on win32.
+ */
+inline nupic::UInt32 count_gt(nupic::Real32 *begin, nupic::Real32 *end,
+ nupic::Real32 threshold) {
+ NTA_ASSERT(begin <= end);
+
+ // Need this, because the asm syntax is not correct for win32,
+ // we simply can't compile the code as is on win32.
+
+ // Need this, because even on darwin32 (which is darwin86), some older
+ // machines might not have the right SSE instructions.
+ if (SSE_LEVEL >= 3) {
+
+ // Compute offsets into array [begin..end):
+ // start is the first 4 bytes aligned address (to start movaps)
+ // n0 is the number of floats before we reach start and can use parallel
+ // xmm operations
+ // n1 is the number of floats we can process in parallel with xmm
+ // n2 is the number of "stragglers" that we will have to do one by one ( <
+ // 4)
+ nupic::Real32 count = 0;
+ NTA_UIntPtr x_addr = (NTA_UIntPtr)begin; // 8 bytes on 64 bits platforms
+ nupic::Real32 *start =
+ (x_addr % 16 == 0) ? begin : (nupic::Real32 *)(16 * (x_addr / 16 + 1));
+ int n0 = (int)(start - begin);
+ int n1 = 4 * ((end - start) / 4);
+ int n2 = (int)(end - start - n1);
- //--------------------------------------------------------------------------------
- /**
- */
- template
- inline void add(T1& x, const T2& y)
- {
- add(x.begin(), x.end(), y.begin(), y.end());
- }
+#if defined(NTA_ARCH_64) && !defined(NTA_OS_WINDOWS)
- //--------------------------------------------------------------------------------
- /**
- */
- template
- inline void subtract(It1 begin1, It1 end1, It2 begin2, It2 end2)
- {
- {
- NTA_ASSERT(begin1 <= end1)
- << "subtract: Invalid range";
- NTA_ASSERT(end1 - begin1 <= end2 - begin2)
- << "subtract: Incompatible ranges";
- }
+#if defined(NTA_OS_DARWIN)
- for (; begin1 != end1; ++begin1, ++begin2)
- *begin1 -= *begin2;
- }
+ // DO NOT CHANGE THESE NEXT TWO LINES, OTHERWISE THE ASM CODE BELOW WILL
+ // BREAK. 'localThreshold' MUST BE STATIC!! Must always assign threshold to
+ // localThreshold also!
+ static float localThreshold;
+ localThreshold = threshold;
+#endif
- //--------------------------------------------------------------------------------
- // TODO: should we have the same argument ordering as copy??
- /**
- */
- template
- inline void subtract(T1& x, const T2& y)
- {
- subtract(x.begin(), x.end(), y.begin(), y.end());
- }
+ __asm__ __volatile__(
+
+#if defined(NTA_OS_DARWIN)
+ // We need to access localThreshold by its mangled name here because
+ // g++ and clang++ do things differently on OS X. They clobber eax by
+ // the time they get here and 'threshold' is not properly loaded into
+ // rax by the constraint at the end of this asm snippet. This means we
+ // need to use 'rip relative' addressing to access a 'static' variable
+ // (a global would also work) here and then load it into eax manually.
+ // Then things work fine.
+ "movq __ZZN5nupic8count_gtEPfS0_fE14localThreshold@GOTPCREL(%%rip), " + "%%r11\n\t" + "movl (%%r11), %%eax\n\t" +#endif - //-------------------------------------------------------------------------------- - /** - */ - template - inline void multiply(It1 begin1, It1 end1, It2 begin2, It2 end2) - { - { - NTA_ASSERT(begin1 <= end1) - << "Binary multiply: Invalid range"; - NTA_ASSERT(end1 - begin1 <= end2 - begin2) - << "Binary multiply: Incompatible ranges"; - } + "subq $16, %%rsp\n\t" // allocate 4 floats on stack + "movl %%eax, (%%rsp)\n\t" // copy threshold to 4 locations + "movl %%eax, 4(%%rsp)\n\t" // on stack: we want threshold + "movl %%eax, 8(%%rsp)\n\t" // to be filling xmm1 and xmm + "movl %%eax, 12(%%rsp)\n\t" // (operate on 4 floats at a time) + "movaps (%%rsp), %%xmm1\n\t" // move 4 thresholds into xmm1 + "movaps %%xmm1, %%xmm2\n\t" // copy 4 thresholds to xmm2 + + "movl $0x3f800000, (%%rsp)\n\t" // $0x3f800000 = (float) 1.0 + "movl $0x3f800000, 4(%%rsp)\n\t" // we want to have that constant + "movl $0x3f800000, 8(%%rsp)\n\t" // 8 times, in xmm3 and xmm4, + "movl $0x3f800000, 12(%%rsp)\n\t" // since the xmm4 registers allow + "movaps (%%rsp), %%xmm3\n\t" // us to operate on 4 floats at + "movaps (%%rsp), %%xmm4\n\t" // a time + + "addq $16, %%rsp\n\t" // deallocate 4 floats on stack + + "xorps %%xmm5, %%xmm5\n\t" // set xmm5 to 0 + + // Loop over individual floats till we reach the right alignment + // that was computed in n0. If we don't start handling 4 floats + // at a time with SSE on a 4 bytes boundary, we get a crash + // in movaps (here, we use only movss, moving only 1 float at a + // time). + "0:\n\t" + "test %%rcx, %%rcx\n\t" // if n0 == 0, jump to next loop + "jz 1f\n\t" + + "movss (%%rsi), %%xmm0\n\t" // move a single float to xmm0 + "cmpss $1, %%xmm0, %%xmm1\n\t" // compare to threshold + "andps %%xmm1, %%xmm3\n\t" // and with all 1s + "addss %%xmm3, %%xmm5\n\t" // add result to xmm5 (=count!) + "movaps %%xmm2, %%xmm1\n\t" // restore threshold in xmm1 + "movaps %%xmm4, %%xmm3\n\t" // restore all 1s in xmm3 + "addq $4, %%rsi\n\t" // move to next float (4 bytes) + "decq %%rcx\n\t" // decrement rcx, which started at n0 + "ja 0b\n\t" // jump if not done yet + + // Loop over 4 floats at a time: this time, we have reached + // the proper alignment for movaps, so we can operate in parallel + // on 4 floats at a time. The code is the same as the previous loop + // except that the "ss" instructions are now "ps" instructions. + "1:\n\t" + "test %%rdx, %%rdx\n\t" + "jz 2f\n\t" + + "movaps (%%rsi), %%xmm0\n\t" // note movaps, not movss + "cmpps $1, %%xmm0, %%xmm1\n\t" + "andps %%xmm1, %%xmm3\n\t" + "addps %%xmm3, %%xmm5\n\t" // addps, not addss + "movaps %%xmm2, %%xmm1\n\t" + "movaps %%xmm4, %%xmm3\n\t" + "addq $16, %%rsi\n\t" // jump over 4 floats + "subq $4, %%rdx\n\t" // decrement rdx (n1) by 4 + "ja 1b\n\t" + + // Tally up count so far into last float of xmm5: we were + // doing operations in parallels on the 4 floats in the xmm + // registers, resulting in 4 partial counts in xmm5. + "xorps %%xmm0, %%xmm0\n\t" + "haddps %%xmm0, %%xmm5\n\t" + "haddps %%xmm0, %%xmm5\n\t" + + // Last loop, for stragglers in case the array is not evenly + // divisible by 4. We are back to operating on a single float + // at a time, using movss and addss. 
+ "2:\n\t" + "test %%rdi, %%rdi\n\t" + "jz 3f\n\t" + + "movss (%%rsi), %%xmm0\n\t" + "cmpss $1, %%xmm0, %%xmm1\n\t" + "andps %%xmm1, %%xmm3\n\t" + "addss %%xmm3, %%xmm5\n\t" + "movaps %%xmm2, %%xmm1\n\t" + "movaps %%xmm4, %%xmm3\n\t" + "addq $4, %%rsi\n\t" + "decq %%rdi\n\t" + "ja 0b\n\t" + + // Push result from xmm5 to variable count in memory. + "3:\n\t" + "movss %%xmm5, %0\n\t" + + : "=m"(count) + : "S"(begin), "a"(threshold), "c"(n0), "d"(n1), "D"(n2) + :); + + return (int)count; - for (; begin1 != end1; ++begin1, ++begin2) - *begin1 *= *begin2; +#else + return std::count_if( + begin, end, std::bind2nd(std::greater(), threshold)); +#endif + } else { + return std::count_if( + begin, end, std::bind2nd(std::greater(), threshold)); } +} - //-------------------------------------------------------------------------------- - /** - */ - template - inline void multiply(T1& x, const T2& y) - { - multiply(x.begin(), x.end(), y.begin(), y.end()); +//-------------------------------------------------------------------------------- +/** + * Counts the number of values greater than or equal to a given threshold in a + * given range. + * + * This is not as general as a count_gt that would be parameterized on the type + * of the elements in the range, and it requires passing in a Python arrays + * that are .astype(float32). + * + */ +inline nupic::UInt32 count_gte(nupic::Real32 *begin, nupic::Real32 *end, + nupic::Real32 threshold) { + NTA_ASSERT(begin <= end); + + return std::count_if( + begin, end, std::bind2nd(std::greater_equal(), threshold)); +} + +//-------------------------------------------------------------------------------- +/** + * Counts the number of non-zeros in a vector. + */ +inline size_t count_non_zeros(nupic::Real32 *begin, nupic::Real32 *end) { + NTA_ASSERT(begin <= end); + return count_gt(begin, end, 0); +} + +//-------------------------------------------------------------------------------- +/** + * Counts the number of non-zeros in a vector. + * Doesn't work with vector. + */ +template inline size_t count_non_zeros(const std::vector &x) { + NTA_ASSERT(sizeof(T) == 4); + nupic::Real32 *begin = (nupic::Real32 *)&x[0]; + nupic::Real32 *end = begin + x.size(); + return count_gt(begin, end, 0); +} + +//-------------------------------------------------------------------------------- +/** + * TODO: Use SSE. Maybe requires having our own vector so that we can + * avoid the shenanigans with the bit references and iterators. + */ +template <> inline size_t count_non_zeros(const std::vector &x) { + size_t count = 0; + for (size_t i = 0; i != x.size(); ++i) + count += x[i]; + return count; +} + +//-------------------------------------------------------------------------------- +template +inline size_t count_non_zeros(const std::vector> &x) { + size_t count = 0; + for (size_t i = 0; i != x.size(); ++i) + if (!is_zero(x[i])) + ++count; + return count; +} + +//-------------------------------------------------------------------------------- +/** + * Counts the number of values less than a given threshold in a given range. 
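+ *
+ * Illustrative usage (a sketch, not part of the original header), assuming
+ * a std::vector<nupic::Real32> v:
+ *
+ *   size_t n = count_lt(v.begin(), v.end(), (nupic::Real32)0.5);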
+ */ +template +inline size_t +count_lt(It begin, It end, + const typename std::iterator_traits::value_type &thres) { + typedef typename std::iterator_traits::value_type value_type; + return std::count_if(begin, end, + std::bind2nd(std::less(), thres)); +} + +//-------------------------------------------------------------------------------- +// Rounding +//-------------------------------------------------------------------------------- +/** + */ +template +inline void +round_01(It begin, It end, + const typename std::iterator_traits::value_type &threshold = .5) { + { NTA_ASSERT(begin <= end) << "round_01: Invalid range"; } + + typename std::iterator_traits::value_type val; + + while (begin != end) { + val = *begin; + if (val >= threshold) + val = 1; + else + val = 0; + *begin = val; + ++begin; } +} - //-------------------------------------------------------------------------------- - template - inline void multiply(It1 begin1, It1 end1, It2 begin2, It2 end2, - It3 begin3, It3 end3) - { - { - NTA_ASSERT(begin1 <= end1) - << "Ternary multiply: Invalid range"; - NTA_ASSERT(end1 - begin1 <= end2 - begin2) - << "Ternary multiply: Incompatible input ranges"; - NTA_ASSERT(end1 - begin1 <= end3 - begin3) - << "Ternary multiply: Not enough memory for result"; - } +//-------------------------------------------------------------------------------- +/** + */ +template +inline void round_01(T &a, const typename T::value_type &threshold = .5) { + round_01(a.begin(), a.end(), threshold); +} + +//-------------------------------------------------------------------------------- +// Addition... +//-------------------------------------------------------------------------------- +/** + * Computes the sum of the elements in a range. + * + * Note: a previous version used veclib on Mac's and vDSP. vDSP is much faster + * than C++, even optimized by gcc, but for now this works + * only with float (rather than double), and only on darwin86. With these + * restrictions the speed-up is usually better than 5X over optimized C++. + * vDSP also handles unaligned vectors correctly, and has good performance + * also when the vectors are small, not just when they are big. + */ +inline nupic::Real32 sum(nupic::Real32 *begin, nupic::Real32 *end) { + { NTA_ASSERT(begin <= end) << "sum: Invalid range"; } + + nupic::Real32 result = 0; + for (; begin != end; ++begin) + result += *begin; + return result; +} + +//-------------------------------------------------------------------------------- +/** + * Compute the sum of a whole container. + * Here we revert to C++, which is going to be slower than the preceding + * function, but it will work for a container of anything, that container not + * necessarily being a contiguous vector of numbers. + */ +template inline typename T::value_type sum(const T &x) { + typename T::value_type result = 0; + typename T::const_iterator it; + for (it = x.begin(); it != x.end(); ++it) + result += *it; + return result; +} + +//-------------------------------------------------------------------------------- +template +inline void sum(const std::vector &a, const std::vector &b, + size_t begin, size_t end, std::vector &c) { + for (size_t i = begin; i != end; ++i) + c[i] = a[i] + b[i]; +} + +//-------------------------------------------------------------------------------- +/** + * Computes the product of the elements in a range. 
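+ * The accumulator starts at 1, so an empty range yields 1.
+ *
+ * Illustrative usage (a sketch, not part of the original header):
+ *
+ *   std::vector<int> v = {2, 3, 4};
+ *   int p = product(v.begin(), v.end()); // p == 24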
+ */ +template +inline typename std::iterator_traits::value_type product(It begin, It end) { + { NTA_ASSERT(begin <= end) << "product: Invalid range"; } - typedef typename std::iterator_traits::value_type value_type; + typename std::iterator_traits::value_type p(1); - for (; begin1 != end1; ++begin1, ++begin2, ++begin3) - *begin3 = (value_type) *begin1 * *begin2; - } + for (; begin != end; ++begin) + p *= *begin; - //-------------------------------------------------------------------------------- - template - inline void multiply(const T1& x, const T2& y, T3& z) - { - multiply(x.begin(), x.end(), y.begin(), y.end(), z.begin(), z.end()); - } + return p; +} - //-------------------------------------------------------------------------------- - /** - * Given a vector of pairs and a value val, multiplies the values - * by val, but only if index is in indices. Needs x and indices to be sorted - * in order of increasing indices. - */ - template - inline void - multiply_val(T val, const Buffer& indices, SparseVector& x) - { - I n1 = indices.nnz, n2 = x.nnz, i1 = 0, i2 =0; +//-------------------------------------------------------------------------------- +/** + * Computes the product of all the elements in a container. + */ +template inline typename T::value_type product(const T &x) { + return product(x.begin(), x.end()); +} - while (i1 != n1 && i2 != n2) - if (x[i2].first < indices[i1]) { - ++i2; - } else if (indices[i1] < x[i2].first) { - ++i1; - } else { - x[i2].second *= val; - ++i1; ++i2; - } - } +//-------------------------------------------------------------------------------- +/** + */ +template +inline void add_val(It begin, It end, + const typename std::iterator_traits::value_type &val) { + { NTA_ASSERT(begin <= end) << "add_val: Invalid range"; } - //-------------------------------------------------------------------------------- - /** - */ - template - inline void divide(It1 begin1, It1 end1, It2 begin2, It2 end2, - typename std::iterator_traits::value_type fuzz =0) - { - { - NTA_ASSERT(begin1 <= end1) - << "divide: Invalid range"; - NTA_ASSERT(end1 - begin1 <= end2 - begin2) - << "divide: Incompatible ranges"; - } + if (val == 0.0f) + return; - if (fuzz == 0) - for (; begin1 != end1; ++begin1, ++begin2) - *begin1 /= *begin2; - else - for (; begin1 != end1; ++begin1, ++begin2) - *begin1 /= (*begin2 + fuzz); - } + for (; begin != end; ++begin) + *begin += val; +} - //-------------------------------------------------------------------------------- - // What if y contains one or more zeros? 
- /** - */ - template - inline void divide(T1& x, const T2& y, typename T1::value_type fuzz =0) - { - divide(x.begin(), x.end(), y.begin(), y.end(), fuzz); - } +//-------------------------------------------------------------------------------- +/** + */ +template +inline void add_val(T &x, const typename T::value_type &val) { + add_val(x.begin(), x.end(), val); +} - //-------------------------------------------------------------------------------- - template - inline void divide_by_max(It1 begin, It1 end) - { - { - NTA_ASSERT(begin <= end) - << "divide_by_max: Invalid range"; - } +//-------------------------------------------------------------------------------- +/** + */ +template +inline void +subtract_val(It begin, It end, + const typename std::iterator_traits::value_type &val) { + add_val(begin, end, -val); +} + +//-------------------------------------------------------------------------------- +/** + */ +template +inline void subtract_val(T &x, const typename T::value_type &val) { + subtract_val(x.begin(), x.end(), val); +} - typename std::iterator_traits::value_type max_val = - *(std::max_element(begin, end)); +//-------------------------------------------------------------------------------- +/** + */ +template inline void negate(It begin, It end) { + { NTA_ASSERT(begin <= end) << "negate: Invalid range"; } - if (!nupic::nearlyZero(max_val)) - for (It1 it = begin; it != end; ++it) - *it /= max_val; - } + for (; begin != end; ++begin) + *begin = -*begin; +} - //-------------------------------------------------------------------------------- - template - inline void divide_by_max(T1& v) +//-------------------------------------------------------------------------------- +/** + */ +template inline void negate(T &x) { negate(x.begin(), x.end()); } + +//-------------------------------------------------------------------------------- +/** + */ +template +inline void +divide_val(It begin, It end, + const typename std::iterator_traits::value_type &val) { { - divide_by_max(v.begin(), v.end()); + NTA_ASSERT(begin <= end) << "divide_val: Invalid range"; + NTA_ASSERT(val != 0) << "divide_val: Division by zero"; } - //-------------------------------------------------------------------------------- - /** - */ - template - inline void inverseNZ(It1 begin1, It1 end1, It2 out, It2 out_end, - TFuncIsNearlyZero fIsZero, TFuncHandleZero fHandleZero) - { - { - NTA_ASSERT(begin1 <= end1) - << "inverseNZ: Invalid input range"; - NTA_ASSERT(out <= out_end) - << "inverseNZ: Invalid output range"; - NTA_ASSERT(end1 - begin1 == out_end - out) - << "inverseNZ: Incompatible ranges"; - } + multiply_val(begin, end, 1.0f / val); +} - const typename std::iterator_traits::value_type one(1.0); +//-------------------------------------------------------------------------------- +// TODO: what if val == 0? +/** + */ +template +inline void divide_val(T &x, const typename T::value_type &val) { + divide_val(x.begin(), x.end(), val); +} - for (; begin1 != end1; ++begin1, ++out) { - if(fIsZero(*begin1)) - *out = fHandleZero(*begin1); // Can't pass one? 
- else - *out = one / *begin1; - } +//-------------------------------------------------------------------------------- +/** + */ +template +inline void add(It1 begin1, It1 end1, It2 begin2, It2 end2) { + { + NTA_ASSERT(begin1 <= end1) << "add: Invalid range"; + NTA_ASSERT(end1 - begin1 <= end2 - begin2) << "add: Incompatible ranges"; } - //-------------------------------------------------------------------------------- - /** - * Computes the reciprocal of each element of vector 'x' and writes - * result into vector 'out'. - * 'out' must be of at least the size of 'x'. - * Does not resize 'out'; behavior is undefined if 'out' is not of - * the correct size. - * Uses only 'value_type', 'begin()' and 'end()' of 'x' and - * 'value_type' and 'begin()' of 'out'. - * Checks the value of each element of 'x' with 'fIsNearlyZero', and - * if 'fIsNearlyZero' returns false, computes the reciprocal as - * T2::value_type(1.0) / element value. - * If 'fIsNearlyZero' returns true, computes uses the output of - * 'fHandleZero(element value)' as the result for that element. - * - * Usage: nupic::inverseNZ(input, output, - * nupic::IsNearlyZero< DistanceToZero >(), - * nupic::Identity()); - */ - template - inline void inverseNZ(const T1& x, T2 &out, - TFuncIsNearlyZero fIsNearlyZero, TFuncHandleZero fHandleZero) + for (; begin1 != end1; ++begin1, ++begin2) + *begin1 += *begin2; +} + +//-------------------------------------------------------------------------------- +/** + */ +template inline void add(T1 &x, const T2 &y) { + add(x.begin(), x.end(), y.begin(), y.end()); +} + +//-------------------------------------------------------------------------------- +/** + */ +template +inline void subtract(It1 begin1, It1 end1, It2 begin2, It2 end2) { { - inverseNZ(x.begin(), x.end(), out.begin(), out.end(), fIsNearlyZero, fHandleZero); + NTA_ASSERT(begin1 <= end1) << "subtract: Invalid range"; + NTA_ASSERT(end1 - begin1 <= end2 - begin2) + << "subtract: Incompatible ranges"; } - //-------------------------------------------------------------------------------- - template - inline void inverse(It1 begin1, It1 end1, It2 out, It2 out_end, - const typename std::iterator_traits::value_type one =1.0) - { - { - NTA_ASSERT(begin1 <= end1) - << "inverse: Invalid input range"; - NTA_ASSERT(out <= out_end) - << "inverse: Invalid output range"; - NTA_ASSERT(end1 - begin1 == out_end - out) - << "inverse: Incompatible ranges"; - } + for (; begin1 != end1; ++begin1, ++begin2) + *begin1 -= *begin2; +} - for (; begin1 != end1; ++begin1, ++out) - *out = one / *begin1; - } +//-------------------------------------------------------------------------------- +// TODO: should we have the same argument ordering as copy?? 
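+// (Note: the container overload below computes x -= y element-wise, so the
+// destination comes first, which is the opposite of std::copy's argument
+// order; hence the TODO above.)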
+/** + */ +template inline void subtract(T1 &x, const T2 &y) { + subtract(x.begin(), x.end(), y.begin(), y.end()); +} - //-------------------------------------------------------------------------------- - template - inline void inverse(const T1& x, T2 &out, const typename T2::value_type one=1.0) +//-------------------------------------------------------------------------------- +/** + */ +template +inline void multiply(It1 begin1, It1 end1, It2 begin2, It2 end2) { { - inverse(x.begin(), x.end(), out.begin(), out.end(), one); + NTA_ASSERT(begin1 <= end1) << "Binary multiply: Invalid range"; + NTA_ASSERT(end1 - begin1 <= end2 - begin2) + << "Binary multiply: Incompatible ranges"; } - //-------------------------------------------------------------------------------- - /** - * x += k y - */ - template - inline void add_ky(const typename std::iterator_traits::value_type& k, - It1 y, It1 y_end, It2 x, It2 x_end) - { - { - NTA_ASSERT(y <= y_end) - << "add_ky: Invalid y range"; - NTA_ASSERT(x <= x_end) - << "add_ky: Invalid x range"; - NTA_ASSERT(y_end - y <= x - x_end) - << "add_ky: Result range too small"; - } + for (; begin1 != end1; ++begin1, ++begin2) + *begin1 *= *begin2; +} - while (y != y_end) { - *x += k * *y; - ++x; ++y; - } - } +//-------------------------------------------------------------------------------- +/** + */ +template inline void multiply(T1 &x, const T2 &y) { + multiply(x.begin(), x.end(), y.begin(), y.end()); +} - //-------------------------------------------------------------------------------- - /** - */ - template - inline void add_ky(const typename T1::value_type& k, const T2& y, T1& x) +//-------------------------------------------------------------------------------- +template +inline void multiply(It1 begin1, It1 end1, It2 begin2, It2 end2, It3 begin3, + It3 end3) { { - add_ky(k, y.begin(), y.end(), x.begin(), x.end()); + NTA_ASSERT(begin1 <= end1) << "Ternary multiply: Invalid range"; + NTA_ASSERT(end1 - begin1 <= end2 - begin2) + << "Ternary multiply: Incompatible input ranges"; + NTA_ASSERT(end1 - begin1 <= end3 - begin3) + << "Ternary multiply: Not enough memory for result"; } - //-------------------------------------------------------------------------------- - /** - * x2 = x1 + k y - */ - template - inline void add_ky(It1 x1, It1 x1_end, - const typename std::iterator_traits::value_type& k, - It2 y, It3 x2) - { - while (x1 != x1_end) { - *x2 = *x1 + k * *y; - ++x2; ++x1; ++y; + typedef typename std::iterator_traits::value_type value_type; + + for (; begin1 != end1; ++begin1, ++begin2, ++begin3) + *begin3 = (value_type)*begin1 * *begin2; +} + +//-------------------------------------------------------------------------------- +template +inline void multiply(const T1 &x, const T2 &y, T3 &z) { + multiply(x.begin(), x.end(), y.begin(), y.end(), z.begin(), z.end()); +} + +//-------------------------------------------------------------------------------- +/** + * Given a vector of pairs and a value val, multiplies the values + * by val, but only if index is in indices. Needs x and indices to be sorted + * in order of increasing indices. 
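+ * In other words, for every (index, value) pair of the sparse vector x whose
+ * index also appears in 'indices', the value is scaled in place by 'val';
+ * all other pairs are left untouched.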
+ */ +template +inline void multiply_val(T val, const Buffer &indices, + SparseVector &x) { + I n1 = indices.nnz, n2 = x.nnz, i1 = 0, i2 = 0; + + while (i1 != n1 && i2 != n2) + if (x[i2].first < indices[i1]) { + ++i2; + } else if (indices[i1] < x[i2].first) { + ++i1; + } else { + x[i2].second *= val; + ++i1; + ++i2; } - } +} - //-------------------------------------------------------------------------------- - /** - */ - template - inline void add_ky(const T1& x1, - const typename T1::value_type& k, const T2& y, - T3& x2) +//-------------------------------------------------------------------------------- +/** + */ +template +inline void divide(It1 begin1, It1 end1, It2 begin2, It2 end2, + typename std::iterator_traits::value_type fuzz = 0) { { - ////assert(y.size() >= x.size()); - - add_ky(x1.begin(), x1.end(), k, y.begin(), x2.begin()); + NTA_ASSERT(begin1 <= end1) << "divide: Invalid range"; + NTA_ASSERT(end1 - begin1 <= end2 - begin2) << "divide: Incompatible ranges"; } - // TODO: write binary operations x = y + z ... + if (fuzz == 0) + for (; begin1 != end1; ++begin1, ++begin2) + *begin1 /= *begin2; + else + for (; begin1 != end1; ++begin1, ++begin2) + *begin1 /= (*begin2 + fuzz); +} - //-------------------------------------------------------------------------------- - /** - * x = a * x + y - * - * TODO: write the rest of BLAS level 1 - */ - template - inline void axpy(T1& x, const typename T1::value_type& a, const T2& y) - { - ////assert(y.size() >= x.size()); +//-------------------------------------------------------------------------------- +// What if y contains one or more zeros? +/** + */ +template +inline void divide(T1 &x, const T2 &y, typename T1::value_type fuzz = 0) { + divide(x.begin(), x.end(), y.begin(), y.end(), fuzz); +} - typename T1::iterator it_x = x.begin(), it_x_end = x.end(); - typename T2::const_iterator it_y = y.begin(); +//-------------------------------------------------------------------------------- +template inline void divide_by_max(It1 begin, It1 end) { + { NTA_ASSERT(begin <= end) << "divide_by_max: Invalid range"; } - while (it_x != it_x_end) { - *it_x = a * *it_x + *it_y; - ++it_x; ++it_y; - } - } + typename std::iterator_traits::value_type max_val = + *(std::max_element(begin, end)); - //-------------------------------------------------------------------------------- - /** - * x = a * x + b * y - */ - template - inline void axby(const typename std::iterator_traits::value_type& a, - X x, X x_end, - const typename std::iterator_traits::value_type& b, - Y y) - { - while (x != x_end) { - *x = a * *x + b * *y; - ++x; ++y; - } - } + if (!nupic::nearlyZero(max_val)) + for (It1 it = begin; it != end; ++it) + *it /= max_val; +} - //-------------------------------------------------------------------------------- - /** - */ - template - inline void axby(const typename T1::value_type& a, T1& x, - const typename T1::value_type& b, const T2& y) - { - ////assert(y.size() >= x.size()); +//-------------------------------------------------------------------------------- +template inline void divide_by_max(T1 &v) { + divide_by_max(v.begin(), v.end()); +} - axby(a, x.begin(), x.end(), b, y.begin()); +//-------------------------------------------------------------------------------- +/** + */ +template +inline void inverseNZ(It1 begin1, It1 end1, It2 out, It2 out_end, + TFuncIsNearlyZero fIsZero, TFuncHandleZero fHandleZero) { + { + NTA_ASSERT(begin1 <= end1) << "inverseNZ: Invalid input range"; + NTA_ASSERT(out <= out_end) << "inverseNZ: Invalid output range"; + 
NTA_ASSERT(end1 - begin1 == out_end - out) + << "inverseNZ: Incompatible ranges"; } - //-------------------------------------------------------------------------------- - /** - * exp(k * x) for all the elements of a range. - */ - template - inline void range_exp(typename std::iterator_traits::value_type k, - It begin, It end) - { - typedef typename std::iterator_traits::value_type value_type; - - Exp e_f; + const typename std::iterator_traits::value_type one(1.0); - for (; begin != end; ++begin) - *begin = e_f(k * *begin); + for (; begin1 != end1; ++begin1, ++out) { + if (fIsZero(*begin1)) + *out = fHandleZero(*begin1); // Can't pass one? + else + *out = one / *begin1; } - - //-------------------------------------------------------------------------------- - /** - */ - template - inline void range_exp(typename C::value_type k, C& c) - { - range_exp(k, c.begin(), c.end()); +} + +//-------------------------------------------------------------------------------- +/** + * Computes the reciprocal of each element of vector 'x' and writes + * result into vector 'out'. + * 'out' must be of at least the size of 'x'. + * Does not resize 'out'; behavior is undefined if 'out' is not of + * the correct size. + * Uses only 'value_type', 'begin()' and 'end()' of 'x' and + * 'value_type' and 'begin()' of 'out'. + * Checks the value of each element of 'x' with 'fIsNearlyZero', and + * if 'fIsNearlyZero' returns false, computes the reciprocal as + * T2::value_type(1.0) / element value. + * If 'fIsNearlyZero' returns true, computes uses the output of + * 'fHandleZero(element value)' as the result for that element. + * + * Usage: nupic::inverseNZ(input, output, + * nupic::IsNearlyZero< DistanceToZero >(), + * nupic::Identity()); + */ +template +inline void inverseNZ(const T1 &x, T2 &out, TFuncIsNearlyZero fIsNearlyZero, + TFuncHandleZero fHandleZero) { + inverseNZ(x.begin(), x.end(), out.begin(), out.end(), fIsNearlyZero, + fHandleZero); +} + +//-------------------------------------------------------------------------------- +template +inline void +inverse(It1 begin1, It1 end1, It2 out, It2 out_end, + const typename std::iterator_traits::value_type one = 1.0) { + { + NTA_ASSERT(begin1 <= end1) << "inverse: Invalid input range"; + NTA_ASSERT(out <= out_end) << "inverse: Invalid output range"; + NTA_ASSERT(end1 - begin1 == out_end - out) + << "inverse: Incompatible ranges"; } - //-------------------------------------------------------------------------------- - /** - * k1 * exp(k2 * x) for all the elements of a range. 
- */ - template - inline void range_exp(typename std::iterator_traits::value_type k1, - typename std::iterator_traits::value_type k2, - It begin, It end) - { - typedef typename std::iterator_traits::value_type value_type; + for (; begin1 != end1; ++begin1, ++out) + *out = one / *begin1; +} - Exp e_f; +//-------------------------------------------------------------------------------- +template +inline void inverse(const T1 &x, T2 &out, + const typename T2::value_type one = 1.0) { + inverse(x.begin(), x.end(), out.begin(), out.end(), one); +} - for (; begin != end; ++begin) - *begin = k1 * e_f(k2 * *begin); +//-------------------------------------------------------------------------------- +/** + * x += k y + */ +template +inline void add_ky(const typename std::iterator_traits::value_type &k, + It1 y, It1 y_end, It2 x, It2 x_end) { + { + NTA_ASSERT(y <= y_end) << "add_ky: Invalid y range"; + NTA_ASSERT(x <= x_end) << "add_ky: Invalid x range"; + NTA_ASSERT(y_end - y <= x - x_end) << "add_ky: Result range too small"; } - //-------------------------------------------------------------------------------- - /** - */ - template - inline void range_exp(typename C::value_type k1, typename C::value_type k2, C& c) - { - range_exp(k1, k2, c.begin(), c.end()); + while (y != y_end) { + *x += k * *y; + ++x; + ++y; } +} - //-------------------------------------------------------------------------------- - // Inner product - //-------------------------------------------------------------------------------- - /** - * Bypasses the STL API and its init value. - * TODO: when range is empty?? - */ - template - inline typename std::iterator_traits::value_type - inner_product(It1 it_x, It1 it_x_end, It2 it_y) - { - typename std::iterator_traits::value_type n(0); +//-------------------------------------------------------------------------------- +/** + */ +template +inline void add_ky(const typename T1::value_type &k, const T2 &y, T1 &x) { + add_ky(k, y.begin(), y.end(), x.begin(), x.end()); +} + +//-------------------------------------------------------------------------------- +/** + * x2 = x1 + k y + */ +template +inline void add_ky(It1 x1, It1 x1_end, + const typename std::iterator_traits::value_type &k, + It2 y, It3 x2) { + while (x1 != x1_end) { + *x2 = *x1 + k * *y; + ++x2; + ++x1; + ++y; + } +} + +//-------------------------------------------------------------------------------- +/** + */ +template +inline void add_ky(const T1 &x1, const typename T1::value_type &k, const T2 &y, + T3 &x2) { + ////assert(y.size() >= x.size()); - while (it_x != it_x_end) { - n += *it_x * *it_y; - ++it_x; ++it_y; - } + add_ky(x1.begin(), x1.end(), k, y.begin(), x2.begin()); +} - return n; - } +// TODO: write binary operations x = y + z ... - //-------------------------------------------------------------------------------- - /** - * In place transform of a range. 
- */ - template - inline void transform(It begin, It end, F1 f) - { - { - NTA_ASSERT(begin <= end) - << "transform: Invalid range"; - } +//-------------------------------------------------------------------------------- +/** + * x = a * x + y + * + * TODO: write the rest of BLAS level 1 + */ +template +inline void axpy(T1 &x, const typename T1::value_type &a, const T2 &y) { + ////assert(y.size() >= x.size()); - for (; begin != end; ++begin) - *begin = f(*begin); + typename T1::iterator it_x = x.begin(), it_x_end = x.end(); + typename T2::const_iterator it_y = y.begin(); + + while (it_x != it_x_end) { + *it_x = a * *it_x + *it_y; + ++it_x; + ++it_y; } +} - //-------------------------------------------------------------------------------- - /** - */ - template - inline void transform(T1& a, F1 f) - { - typename T1::iterator ia = a.begin(), iae = a.end(); +//-------------------------------------------------------------------------------- +/** + * x = a * x + b * y + */ +template +inline void axby(const typename std::iterator_traits::value_type &a, X x, + X x_end, const typename std::iterator_traits::value_type &b, + Y y) { + while (x != x_end) { + *x = a * *x + b * *y; + ++x; + ++y; + } +} + +//-------------------------------------------------------------------------------- +/** + */ +template +inline void axby(const typename T1::value_type &a, T1 &x, + const typename T1::value_type &b, const T2 &y) { + ////assert(y.size() >= x.size()); - for (; ia != iae; ++ia) - *ia = f(*ia); - } + axby(a, x.begin(), x.end(), b, y.begin()); +} - //-------------------------------------------------------------------------------- - /** - */ - template - inline void transform(const T1& a, T2& b, F1 f) - { - ////assert(b.size() >= a.size()); +//-------------------------------------------------------------------------------- +/** + * exp(k * x) for all the elements of a range. + */ +template +inline void range_exp(typename std::iterator_traits::value_type k, It begin, + It end) { + typedef typename std::iterator_traits::value_type value_type; - typename T1::const_iterator ia = a.begin(), iae = a.end(); - typename T2::iterator ib = b.begin(); + Exp e_f; - for (; ia != iae; ++ia, ++ib) - *ib = f(*ia); - } + for (; begin != end; ++begin) + *begin = e_f(k * *begin); +} - //-------------------------------------------------------------------------------- - /** - */ - template - inline void transform(const T1& a, const T2& b, T3& c, F2 f) - { - ////assert(c.size() >= a.size()); - ////assert(b.size() >= a.size()); - ////assert(c.size() >= b.size()); +//-------------------------------------------------------------------------------- +/** + */ +template inline void range_exp(typename C::value_type k, C &c) { + range_exp(k, c.begin(), c.end()); +} - typename T1::const_iterator ia = a.begin(), iae = a.end(); - typename T2::const_iterator ib = b.begin(); - typename T3::iterator ic = c.begin(); +//-------------------------------------------------------------------------------- +/** + * k1 * exp(k2 * x) for all the elements of a range. 
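+ * Each element x of the range is replaced in place by k1 * exp(k2 * x).
+ *
+ * Illustrative usage (a sketch, not part of the original header), assuming
+ * a std::vector<nupic::Real32> v:
+ *
+ *   range_exp((nupic::Real32)2, (nupic::Real32)-1, v.begin(), v.end());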
+ */ +template +inline void range_exp(typename std::iterator_traits::value_type k1, + typename std::iterator_traits::value_type k2, + It begin, It end) { + typedef typename std::iterator_traits::value_type value_type; - for (; ia != iae; ++ia, ++ib, ++ic) - *ic = f(*ia, *ib); - } + Exp e_f; - //-------------------------------------------------------------------------------- - /** - */ - template - inline void transform(const T1& a, const T2& b, const T3& c, T4& d, F3 f) - { - ////assert(d.size() >= a.size()); - ////assert(d.size() >= b.size()); - ////assert(d.size() >= c.size()); - ////assert(b.size() >= a.size()); - ////assert(c.size() >= a.size()); - - typename T1::const_iterator ia = a.begin(), iae = a.end(); - typename T2::const_iterator ib = b.begin(); - typename T3::const_iterator ic = c.begin(); - typename T4::iterator id = d.begin(); - - for (; ia != iae; ++ia, ++ib, ++ic, ++id) - *id = f(*ia, *ib, *ic); - } + for (; begin != end; ++begin) + *begin = k1 * e_f(k2 * *begin); +} - //-------------------------------------------------------------------------------- - // min_element / max_element - //-------------------------------------------------------------------------------- - /** - * Returns the position at which f takes its minimum between first and last. - */ - template - inline ForwardIterator - min_element(ForwardIterator first, ForwardIterator last, F f) - { - { - NTA_ASSERT(first <= last) - << "min_element: Invalid range"; - } +//-------------------------------------------------------------------------------- +/** + */ +template +inline void range_exp(typename C::value_type k1, typename C::value_type k2, + C &c) { + range_exp(k1, k2, c.begin(), c.end()); +} + +//-------------------------------------------------------------------------------- +// Inner product +//-------------------------------------------------------------------------------- +/** + * Bypasses the STL API and its init value. + * TODO: when range is empty?? + */ +template +inline typename std::iterator_traits::value_type +inner_product(It1 it_x, It1 it_x_end, It2 it_y) { + typename std::iterator_traits::value_type n(0); - typedef typename ForwardIterator::value_type value_type; + while (it_x != it_x_end) { + n += *it_x * *it_y; + ++it_x; + ++it_y; + } - ForwardIterator min_it = first; - value_type min_val = f(*first); + return n; +} - while (first != last) { - value_type val = f(*first); - if (val < min_val) { - min_it = first; - min_val = val; - } - ++first; - } +//-------------------------------------------------------------------------------- +/** + * In place transform of a range. + */ +template +inline void transform(It begin, It end, F1 f) { + { NTA_ASSERT(begin <= end) << "transform: Invalid range"; } - return min_it; - } + for (; begin != end; ++begin) + *begin = f(*begin); +} - //-------------------------------------------------------------------------------- - /** - * Returns the position at which f takes its maximum between first and last. 
- */ - template - inline ForwardIterator - max_element(ForwardIterator first, ForwardIterator last, F f) - { - { - NTA_ASSERT(first <= last) - << "max_element: Invalid range"; - } +//-------------------------------------------------------------------------------- +/** + */ +template inline void transform(T1 &a, F1 f) { + typename T1::iterator ia = a.begin(), iae = a.end(); - typedef typename ForwardIterator::value_type value_type; + for (; ia != iae; ++ia) + *ia = f(*ia); +} - ForwardIterator max_it = first; - value_type max_val = f(*first); +//-------------------------------------------------------------------------------- +/** + */ +template +inline void transform(const T1 &a, T2 &b, F1 f) { + ////assert(b.size() >= a.size()); - while (first != last) { - value_type val = f(*first); - if (val > max_val) { - max_it = first; - max_val = val; - } - ++first; - } + typename T1::const_iterator ia = a.begin(), iae = a.end(); + typename T2::iterator ib = b.begin(); - return max_it; - } + for (; ia != iae; ++ia, ++ib) + *ib = f(*ia); +} - //-------------------------------------------------------------------------------- - /** - * Finds the min element in a container. - */ - template - inline size_t min_element(const C& c) - { - if (c.empty()) - return (size_t) 0; - else - return (size_t) (std::min_element(c.begin(), c.end()) - c.begin()); - } +//-------------------------------------------------------------------------------- +/** + */ +template +inline void transform(const T1 &a, const T2 &b, T3 &c, F2 f) { + ////assert(c.size() >= a.size()); + ////assert(b.size() >= a.size()); + ////assert(c.size() >= b.size()); + + typename T1::const_iterator ia = a.begin(), iae = a.end(); + typename T2::const_iterator ib = b.begin(); + typename T3::iterator ic = c.begin(); + + for (; ia != iae; ++ia, ++ib, ++ic) + *ic = f(*ia, *ib); +} + +//-------------------------------------------------------------------------------- +/** + */ +template +inline void transform(const T1 &a, const T2 &b, const T3 &c, T4 &d, F3 f) { + ////assert(d.size() >= a.size()); + ////assert(d.size() >= b.size()); + ////assert(d.size() >= c.size()); + ////assert(b.size() >= a.size()); + ////assert(c.size() >= a.size()); + + typename T1::const_iterator ia = a.begin(), iae = a.end(); + typename T2::const_iterator ib = b.begin(); + typename T3::const_iterator ic = c.begin(); + typename T4::iterator id = d.begin(); + + for (; ia != iae; ++ia, ++ib, ++ic, ++id) + *id = f(*ia, *ib, *ic); +} + +//-------------------------------------------------------------------------------- +// min_element / max_element +//-------------------------------------------------------------------------------- +/** + * Returns the position at which f takes its minimum between first and last. + */ +template +inline ForwardIterator min_element(ForwardIterator first, ForwardIterator last, + F f) { + { NTA_ASSERT(first <= last) << "min_element: Invalid range"; } - //-------------------------------------------------------------------------------- - /** - * Finds the maximum element in a container. - */ - template - inline size_t max_element(const C& c) - { - if (c.empty()) - return (size_t) 0; - else - return (size_t) (std::max_element(c.begin(), c.end()) - c.begin()); - } + typedef typename ForwardIterator::value_type value_type; - //-------------------------------------------------------------------------------- - /** - * Writes the component-wise minimum to the output vector. 
- */ - template - inline void minimum(It1 begin1, It1 end1, It2 begin2, It3 out) - { - { - NTA_ASSERT(begin1 <= end1) - << "minimum: Invalid range"; - } + ForwardIterator min_it = first; + value_type min_val = f(*first); - typedef typename std::iterator_traits::value_type T; - for(; begin1!=end1; ++begin1, ++begin2, ++out) { - *out = std::min(*begin1, *begin2); + while (first != last) { + value_type val = f(*first); + if (val < min_val) { + min_it = first; + min_val = val; } + ++first; } - //-------------------------------------------------------------------------------- - /** - * Writes the component-wise minimum to the output vector. - */ - template - inline void minimum(const T1 &x, const T2 &y, T3 &out) - { - minimum(x.begin(), x.end(), y.begin(), out.begin()); - } + return min_it; +} - //-------------------------------------------------------------------------------- - // contains - //-------------------------------------------------------------------------------- - template - inline bool contains(const C& c, typename C::value_type& v) - { - return std::find(c.begin(), c.end(), v) != c.end(); - } +//-------------------------------------------------------------------------------- +/** + * Returns the position at which f takes its maximum between first and last. + */ +template +inline ForwardIterator max_element(ForwardIterator first, ForwardIterator last, + F f) { + { NTA_ASSERT(first <= last) << "max_element: Invalid range"; } - //-------------------------------------------------------------------------------- - template - inline bool is_subsequence(const C1& seq, const C2& sub) - { - return std::search(seq.begin(), seq.end(), sub.begin(), sub.end()) != seq.end(); - } + typedef typename ForwardIterator::value_type value_type; - //-------------------------------------------------------------------------------- - template - inline bool is_subsequence_of(const C1& c, const C2& sub) - { - bool found = false; - typename C1::const_iterator it, end = c.end(); - for (it = c.begin(); it != end; ++it) - if (is_subsequence(*it, sub)) - found = true; - return found; - } + ForwardIterator max_it = first; + value_type max_val = f(*first); - //-------------------------------------------------------------------------------- - // sample - //-------------------------------------------------------------------------------- - /** - * Sample n times from a given pdf. - */ - template - inline void sample(size_t n, It1 pdf_begin, It1 pdf_end, It2 output, RNG& rng) - { - { - NTA_ASSERT(pdf_begin <= pdf_end) - << "sample: Invalid range for pdf"; + while (first != last) { + value_type val = f(*first); + if (val > max_val) { + max_it = first; + max_val = val; } + ++first; + } - typedef typename std::iterator_traits::value_type size_type2; + return max_it; +} - size_t size = (size_t) (pdf_end - pdf_begin); - std::vector cdf(size, 0); - std::vector::const_iterator it = cdf.begin(); - cumulative(pdf_begin, pdf_end, cdf.begin(), cdf.end()); - double m = cdf[size-1]; +//-------------------------------------------------------------------------------- +/** + * Finds the min element in a container. + */ +template inline size_t min_element(const C &c) { + if (c.empty()) + return (size_t)0; + else + return (size_t)(std::min_element(c.begin(), c.end()) - c.begin()); +} + +//-------------------------------------------------------------------------------- +/** + * Finds the maximum element in a container. 
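+ * Note that, unlike std::max_element, this returns the *index* of the largest
+ * element (and 0 for an empty container), not an iterator.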
+ */ +template inline size_t max_element(const C &c) { + if (c.empty()) + return (size_t)0; + else + return (size_t)(std::max_element(c.begin(), c.end()) - c.begin()); +} + +//-------------------------------------------------------------------------------- +/** + * Writes the component-wise minimum to the output vector. + */ +template +inline void minimum(It1 begin1, It1 end1, It2 begin2, It3 out) { + { NTA_ASSERT(begin1 <= end1) << "minimum: Invalid range"; } - for (size_t i = 0; i != n; ++i, ++output) { - double p = m * double(rng()) / double(rng.max() - rng.min()); - it = std::lower_bound(cdf.begin(), cdf.end(), p); - *output = (size_type2) (it - cdf.begin()); - } + typedef typename std::iterator_traits::value_type T; + for (; begin1 != end1; ++begin1, ++begin2, ++out) { + *out = std::min(*begin1, *begin2); } +} - //-------------------------------------------------------------------------------- - /** - * Sample n times from a given pdf. - */ - template - inline void sample(size_t n, It1 pdf_begin, It1 pdf_end, It2 output) - { - { - NTA_ASSERT(pdf_begin <= pdf_end) - << "sample: Invalid range for pdf"; - } +//-------------------------------------------------------------------------------- +/** + * Writes the component-wise minimum to the output vector. + */ +template +inline void minimum(const T1 &x, const T2 &y, T3 &out) { + minimum(x.begin(), x.end(), y.begin(), out.begin()); +} + +//-------------------------------------------------------------------------------- +// contains +//-------------------------------------------------------------------------------- +template +inline bool contains(const C &c, typename C::value_type &v) { + return std::find(c.begin(), c.end(), v) != c.end(); +} + +//-------------------------------------------------------------------------------- +template +inline bool is_subsequence(const C1 &seq, const C2 &sub) { + return std::search(seq.begin(), seq.end(), sub.begin(), sub.end()) != + seq.end(); +} + +//-------------------------------------------------------------------------------- +template +inline bool is_subsequence_of(const C1 &c, const C2 &sub) { + bool found = false; + typename C1::const_iterator it, end = c.end(); + for (it = c.begin(); it != end; ++it) + if (is_subsequence(*it, sub)) + found = true; + return found; +} + +//-------------------------------------------------------------------------------- +// sample +//-------------------------------------------------------------------------------- +/** + * Sample n times from a given pdf. + */ +template +inline void sample(size_t n, It1 pdf_begin, It1 pdf_end, It2 output, RNG &rng) { + { NTA_ASSERT(pdf_begin <= pdf_end) << "sample: Invalid range for pdf"; } - nupic::Random rng; - sample(n, pdf_begin, pdf_end, output, rng); - } + typedef typename std::iterator_traits::value_type size_type2; - //-------------------------------------------------------------------------------- - /** - * Sample one time from a given pdf. 
- */ - template - inline size_t sample_one(const C1& pdf) - { - size_t c = 0; - sample(1, pdf.begin(), pdf.end(), &c); - return c; - } + size_t size = (size_t)(pdf_end - pdf_begin); + std::vector cdf(size, 0); + std::vector::const_iterator it = cdf.begin(); + cumulative(pdf_begin, pdf_end, cdf.begin(), cdf.end()); + double m = cdf[size - 1]; - //-------------------------------------------------------------------------------- - template - inline size_t sample_one(const C1& pdf, RNG& rng) - { - size_t c = 0; - sample(1, pdf.begin(), pdf.end(), &c, rng); - return c; + for (size_t i = 0; i != n; ++i, ++output) { + double p = m * double(rng()) / double(rng.max() - rng.min()); + it = std::lower_bound(cdf.begin(), cdf.end(), p); + *output = (size_type2)(it - cdf.begin()); } +} + +//-------------------------------------------------------------------------------- +/** + * Sample n times from a given pdf. + */ +template +inline void sample(size_t n, It1 pdf_begin, It1 pdf_end, It2 output) { + { NTA_ASSERT(pdf_begin <= pdf_end) << "sample: Invalid range for pdf"; } + + nupic::Random rng; + sample(n, pdf_begin, pdf_end, output, rng); +} - //-------------------------------------------------------------------------------- - // DENSE LOGICAL AND/OR - //-------------------------------------------------------------------------------- - /** - * For each corresponding elements of x and y, put the logical and of those two - * elements at the corresponding position in z. This is faster than the numpy - * logical_and, which doesn't seem to be using SSE. - * - * x, y and z are arrays of floats, but with 0/1 values. - * - * If any of the vectors is not aligned on a 16 bytes boundary, the function - * reverts to slow C++. This can happen when using it with slices of numpy - * arrays. - * - * Doesn't work on win32/win64. - * - * TODO: find 16 bytes aligned block that can be sent to SSE. - * TODO: support win32/win64 for the fast path. - */ - template - inline void logical_and(InputIterator x, InputIterator x_end, - InputIterator y, InputIterator y_end, - OutputIterator z, OutputIterator z_end) +//-------------------------------------------------------------------------------- +/** + * Sample one time from a given pdf. + */ +template inline size_t sample_one(const C1 &pdf) { + size_t c = 0; + sample(1, pdf.begin(), pdf.end(), &c); + return c; +} + +//-------------------------------------------------------------------------------- +template +inline size_t sample_one(const C1 &pdf, RNG &rng) { + size_t c = 0; + sample(1, pdf.begin(), pdf.end(), &c, rng); + return c; +} + +//-------------------------------------------------------------------------------- +// DENSE LOGICAL AND/OR +//-------------------------------------------------------------------------------- +/** + * For each corresponding elements of x and y, put the logical and of those two + * elements at the corresponding position in z. This is faster than the numpy + * logical_and, which doesn't seem to be using SSE. + * + * x, y and z are arrays of floats, but with 0/1 values. + * + * If any of the vectors is not aligned on a 16 bytes boundary, the function + * reverts to slow C++. This can happen when using it with slices of numpy + * arrays. + * + * Doesn't work on win32/win64. + * + * TODO: find 16 bytes aligned block that can be sent to SSE. + * TODO: support win32/win64 for the fast path. 
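+ *
+ * Illustrative usage (a sketch, not part of the original header), assuming
+ * three std::vector<nupic::Real32> x, y and z of equal size holding 0/1
+ * values:
+ *
+ *   logical_and(x.begin(), x.end(), y.begin(), y.end(), z.begin(), z.end());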
+ */ +template +inline void logical_and(InputIterator x, InputIterator x_end, InputIterator y, + InputIterator y_end, OutputIterator z, + OutputIterator z_end) { { - { - NTA_ASSERT(x_end - x == y_end - y); - NTA_ASSERT(x_end - x == z_end - z); - } + NTA_ASSERT(x_end - x == y_end - y); + NTA_ASSERT(x_end - x == z_end - z); + } - // See comments in count_gt. We need both conditional compilation and - // SSE_LEVEL check. + // See comments in count_gt. We need both conditional compilation and + // SSE_LEVEL check. #if !defined(NTA_OS_WINDOWS) && defined(NTA_ASM) - if (SSE_LEVEL >= 3) { - - // n is the total number of floats to process - // n1 is 0 if any of the arrays x,y,z is not aligned on a 4 bytes - // boundary, or the number of floats we'll be able to process in parallel - // using the xmm. - int n = (int)(x_end - x); - int n1 = 0; - if (((long)x) % 16 == 0 && ((long)y) % 16 == 0 && ((long)z) % 16 == 0) - n1 = 16 * (n / 16); - - // If we are not aligned on 4 bytes, n1 == 0, and we simply - // skip the asm. - if (n1 > 0) { - - #if defined(NTA_ARCH_32) - __asm__ __volatile__( - "pusha\n\t" // save all registers - - "0:\n\t" - "movaps (%%esi), %%xmm0\n\t" // move 4 floats of x to xmm0 - "andps (%%edi), %%xmm0\n\t" // parallel and with 4 floats of y - "movaps 16(%%esi), %%xmm1\n\t"// play again with next 4 floats - "andps 16(%%edi), %%xmm1\n\t" - "movaps 32(%%esi), %%xmm2\n\t"// and next 4 floats - "andps 32(%%edi), %%xmm2\n\t" - "movaps 48(%%esi), %%xmm3\n\t"// and next 4 floats: we've and'ed - "andps 48(%%edi), %%xmm3\n\t" // 16 floats of x and y at this point - - "movaps %%xmm0, (%%ecx)\n\t" // simply move 4 floats at a time to z - "movaps %%xmm1, 16(%%ecx)\n\t"// and next 4 floats - "movaps %%xmm2, 32(%%ecx)\n\t"// and next 4 floats - "movaps %%xmm3, 48(%%ecx)\n\t"// and next 4: moved 16 floats to z - - "addl $64, %%esi\n\t" // increment pointer into x by 16 floats - "addl $64, %%edi\n\t" // increment pointer into y - "addl $64, %%ecx\n\t" // increment pointer into z - "subl $16, %%edx\n\t" // we've processed 16 floats - "ja 0b\n\t" // loop - - "popa\n\t" // restore registers - - : - : "S" (x), "D" (y), "c" (z), "d" (n1) - : - ); - - #else - __asm__ __volatile__( - "pushq %%rsi\n\t" // save affected registers - "pushq %%rdi\n\t" // this 'shouldn't' be necessary - "pushq %%rcx\n\t" // but I was seeing some random - "pushq %%rdx\n\t" // crashes on OS X. Remove? 
- - "0:\n\t" - "movaps (%%rsi), %%xmm0\n\t" // move 4 floats of x to xmm0 - "andps (%%rdi), %%xmm0\n\t" // parallel and with 4 floats of y - "movaps 16(%%rsi), %%xmm1\n\t"// play again with next 4 floats - "andps 16(%%rdi), %%xmm1\n\t" - "movaps 32(%%rsi), %%xmm2\n\t"// and next 4 floats - "andps 32(%%rdi), %%xmm2\n\t" - "movaps 48(%%rsi), %%xmm3\n\t"// and next 4 floats: we've and'ed - "andps 48(%%rdi), %%xmm3\n\t" // 16 floats of x and y at this point - - "movaps %%xmm0, (%%rcx)\n\t" // simply move 4 floats at a time to z - "movaps %%xmm1, 16(%%rcx)\n\t"// and next 4 floats - "movaps %%xmm2, 32(%%rcx)\n\t"// and next 4 floats - "movaps %%xmm3, 48(%%rcx)\n\t"// and next 4: moved 16 floats to z - - "addq $64, %%rsi\n\t" // increment pointer into x by 16 floats - "addq $64, %%rdi\n\t" // increment pointer into y - "addq $64, %%rcx\n\t" // increment pointer into z - "subq $16, %%rdx\n\t" // we've processed 16 floats - - "ja 0b\n\t" // loop - - "popq %%rdx\n\t" // restore saved registers - "popq %%rcx\n\t" - "popq %%rdi\n\t" - "popq %%rsi\n\t" - - : - : "S" (x), "D" (y), "c" (z), "d" (n1) - : - ); - - #endif - } + if (SSE_LEVEL >= 3) { - // Finish up for stragglers in case the array length was not - // evenly divisible by 4 - for (int i = n1; i != n; ++i) - *(z+i) = *(x+i) && *(y+i); + // n is the total number of floats to process + // n1 is 0 if any of the arrays x,y,z is not aligned on a 4 bytes + // boundary, or the number of floats we'll be able to process in parallel + // using the xmm. + int n = (int)(x_end - x); + int n1 = 0; + if (((long)x) % 16 == 0 && ((long)y) % 16 == 0 && ((long)z) % 16 == 0) + n1 = 16 * (n / 16); - } else { + // If we are not aligned on 4 bytes, n1 == 0, and we simply + // skip the asm. + if (n1 > 0) { - for (; x != x_end; ++x, ++y, ++z) - *z = (*x) && (*y); +#if defined(NTA_ARCH_32) + __asm__ __volatile__( + "pusha\n\t" // save all registers + + "0:\n\t" + "movaps (%%esi), %%xmm0\n\t" // move 4 floats of x to xmm0 + "andps (%%edi), %%xmm0\n\t" // parallel and with 4 floats of y + "movaps 16(%%esi), %%xmm1\n\t" // play again with next 4 floats + "andps 16(%%edi), %%xmm1\n\t" + "movaps 32(%%esi), %%xmm2\n\t" // and next 4 floats + "andps 32(%%edi), %%xmm2\n\t" + "movaps 48(%%esi), %%xmm3\n\t" // and next 4 floats: we've and'ed + "andps 48(%%edi), %%xmm3\n\t" // 16 floats of x and y at this point + + "movaps %%xmm0, (%%ecx)\n\t" // simply move 4 floats at a time to z + "movaps %%xmm1, 16(%%ecx)\n\t" // and next 4 floats + "movaps %%xmm2, 32(%%ecx)\n\t" // and next 4 floats + "movaps %%xmm3, 48(%%ecx)\n\t" // and next 4: moved 16 floats to z + + "addl $64, %%esi\n\t" // increment pointer into x by 16 floats + "addl $64, %%edi\n\t" // increment pointer into y + "addl $64, %%ecx\n\t" // increment pointer into z + "subl $16, %%edx\n\t" // we've processed 16 floats + "ja 0b\n\t" // loop + + "popa\n\t" // restore registers + + : + : "S"(x), "D"(y), "c"(z), "d"(n1) + :); - } #else - for (; x != x_end; ++x, ++y, ++z) - *z = (*x) && (*y); -#endif - } + __asm__ __volatile__( + "pushq %%rsi\n\t" // save affected registers + "pushq %%rdi\n\t" // this 'shouldn't' be necessary + "pushq %%rcx\n\t" // but I was seeing some random + "pushq %%rdx\n\t" // crashes on OS X. Remove? 
+ + "0:\n\t" + "movaps (%%rsi), %%xmm0\n\t" // move 4 floats of x to xmm0 + "andps (%%rdi), %%xmm0\n\t" // parallel and with 4 floats of y + "movaps 16(%%rsi), %%xmm1\n\t" // play again with next 4 floats + "andps 16(%%rdi), %%xmm1\n\t" + "movaps 32(%%rsi), %%xmm2\n\t" // and next 4 floats + "andps 32(%%rdi), %%xmm2\n\t" + "movaps 48(%%rsi), %%xmm3\n\t" // and next 4 floats: we've and'ed + "andps 48(%%rdi), %%xmm3\n\t" // 16 floats of x and y at this point + + "movaps %%xmm0, (%%rcx)\n\t" // simply move 4 floats at a time to z + "movaps %%xmm1, 16(%%rcx)\n\t" // and next 4 floats + "movaps %%xmm2, 32(%%rcx)\n\t" // and next 4 floats + "movaps %%xmm3, 48(%%rcx)\n\t" // and next 4: moved 16 floats to z + + "addq $64, %%rsi\n\t" // increment pointer into x by 16 floats + "addq $64, %%rdi\n\t" // increment pointer into y + "addq $64, %%rcx\n\t" // increment pointer into z + "subq $16, %%rdx\n\t" // we've processed 16 floats + + "ja 0b\n\t" // loop + + "popq %%rdx\n\t" // restore saved registers + "popq %%rcx\n\t" + "popq %%rdi\n\t" + "popq %%rsi\n\t" + + : + : "S"(x), "D"(y), "c"(z), "d"(n1) + :); - //-------------------------------------------------------------------------------- - /** - * Same as previous logical_and, but puts the result back into y. - * Same comments. - */ - template - inline void in_place_logical_and(Iterator x, Iterator x_end, - Iterator y, Iterator y_end) - { - { - NTA_ASSERT(x_end - x == y_end - y); +#endif } - // See comments in count_gt. We need conditional compilation - // _AND_ SSE_LEVEL check. -#if (defined(NTA_OS_LINUX) || defined(NTA_OS_DARWIN)) && defined(NTA_ASM) - - if (SSE_LEVEL >= 3) { - - // See comments in logical_and. - int n = (int)(x_end - x); - int n1 = 0; - if (((long)x) % 16 == 0 && ((long)y) % 16 == 0) - n1 = 16 * (n / 16); - - if (n1 > 0) { - - #if defined(NTA_ARCH_32) - __asm__ __volatile__( - "pusha\n\t" - - "0:\n\t" - "movaps (%%esi), %%xmm0\n\t" - "movaps 16(%%esi), %%xmm1\n\t" - "movaps 32(%%esi), %%xmm2\n\t" - "movaps 48(%%esi), %%xmm3\n\t" - "andps (%%edi), %%xmm0\n\t" - "andps 16(%%edi), %%xmm1\n\t" - "andps 32(%%edi), %%xmm2\n\t" - "andps 48(%%edi), %%xmm3\n\t" - "movaps %%xmm0, (%%edi)\n\t" - "movaps %%xmm1, 16(%%edi)\n\t" - "movaps %%xmm2, 32(%%edi)\n\t" - "movaps %%xmm3, 48(%%edi)\n\t" - - "addl $64, %%esi\n\t" - "addl $64, %%edi\n\t" - "subl $16, %%edx\n\t" - "prefetch (%%esi)\n\t" - "ja 0b\n\t" - - "popa\n\t" - - : - : "S" (x), "D" (y), "d" (n1) - : - ); - - #else //64bit - __asm__ __volatile__( - "pushq %%rsi\n\t" // save affected registers - "pushq %%rdi\n\t" - "pushq %%rdx\n\t" - - "0:\n\t" - "movaps (%%rsi), %%xmm0\n\t" - "movaps 16(%%rsi), %%xmm1\n\t" - "movaps 32(%%rsi), %%xmm2\n\t" - "movaps 48(%%rsi), %%xmm3\n\t" - - "andps (%%rdi), %%xmm0\n\t" - "andps 16(%%rdi), %%xmm1\n\t" // in place and - "andps 32(%%rdi), %%xmm2\n\t" - "andps 48(%%rdi), %%xmm3\n\t" - - "movaps %%xmm0, (%%rdi)\n\t" - "movaps %%xmm1, 16(%%rdi)\n\t" - "movaps %%xmm2, 32(%%rdi)\n\t" - "movaps %%xmm3, 48(%%rdi)\n\t" - - "addq $64, %%rsi\n\t" // increment pointer into x by 16 floats - "addq $64, %%rdi\n\t" // increment pointer into y - "subq $16, %%rdx\n\t" // we've processed 16 floats - - "ja 0b\n\t" // loop - - "popq %%rdx\n\t" // restore saved registers - "popq %%rdi\n\t" - "popq %%rsi\n\t" - - : - : "S" (x), "D" (y), "d" (n1) - : - ); - #endif - } - - for (int i = n1; i != n; ++i) - *(y+i) = *(x+i) && *(y+i); + // Finish up for stragglers in case the array length was not + // evenly divisible by 4 + for (int i = n1; i != n; ++i) + *(z + i) = *(x + i) && 
*(y + i); - } else { + } else { - for (; x != x_end; ++x, ++y) - *y = (*x) && *(y); - } + for (; x != x_end; ++x, ++y, ++z) + *z = (*x) && (*y); + } #else - for (; x != x_end; ++x, ++y) - *y = (*x) && *(y); + for (; x != x_end; ++x, ++y, ++z) + *z = (*x) && (*y); #endif - } +} - //-------------------------------------------------------------------------------- - /** - * A specialization tuned for unsigned char. - * TODO: keep only one code that computes the right offsets based on - * the iterator value type? - * TODO: vectorize, but watch out for alignments - */ - inline void in_place_logical_and(const ByteVector& x, ByteVector& y, - int begin =-1, int end =-1) - { - if (begin == -1) - begin = 0; +//-------------------------------------------------------------------------------- +/** + * Same as previous logical_and, but puts the result back into y. + * Same comments. + */ +template +inline void in_place_logical_and(Iterator x, Iterator x_end, Iterator y, + Iterator y_end) { + { NTA_ASSERT(x_end - x == y_end - y); } - if (end == -1) - end = (int) x.size(); + // See comments in count_gt. We need conditional compilation + // _AND_ SSE_LEVEL check. +#if (defined(NTA_OS_LINUX) || defined(NTA_OS_DARWIN)) && defined(NTA_ASM) - for (int i = begin; i != end; ++i) - y[i] &= x[i]; - } + if (SSE_LEVEL >= 3) { + + // See comments in logical_and. + int n = (int)(x_end - x); + int n1 = 0; + if (((long)x) % 16 == 0 && ((long)y) % 16 == 0) + n1 = 16 * (n / 16); + + if (n1 > 0) { + +#if defined(NTA_ARCH_32) + __asm__ __volatile__("pusha\n\t" + + "0:\n\t" + "movaps (%%esi), %%xmm0\n\t" + "movaps 16(%%esi), %%xmm1\n\t" + "movaps 32(%%esi), %%xmm2\n\t" + "movaps 48(%%esi), %%xmm3\n\t" + "andps (%%edi), %%xmm0\n\t" + "andps 16(%%edi), %%xmm1\n\t" + "andps 32(%%edi), %%xmm2\n\t" + "andps 48(%%edi), %%xmm3\n\t" + "movaps %%xmm0, (%%edi)\n\t" + "movaps %%xmm1, 16(%%edi)\n\t" + "movaps %%xmm2, 32(%%edi)\n\t" + "movaps %%xmm3, 48(%%edi)\n\t" + + "addl $64, %%esi\n\t" + "addl $64, %%edi\n\t" + "subl $16, %%edx\n\t" + "prefetch (%%esi)\n\t" + "ja 0b\n\t" + + "popa\n\t" + + : + : "S"(x), "D"(y), "d"(n1) + :); + +#else // 64bit + __asm__ __volatile__( + "pushq %%rsi\n\t" // save affected registers + "pushq %%rdi\n\t" + "pushq %%rdx\n\t" + + "0:\n\t" + "movaps (%%rsi), %%xmm0\n\t" + "movaps 16(%%rsi), %%xmm1\n\t" + "movaps 32(%%rsi), %%xmm2\n\t" + "movaps 48(%%rsi), %%xmm3\n\t" + + "andps (%%rdi), %%xmm0\n\t" + "andps 16(%%rdi), %%xmm1\n\t" // in place and + "andps 32(%%rdi), %%xmm2\n\t" + "andps 48(%%rdi), %%xmm3\n\t" + + "movaps %%xmm0, (%%rdi)\n\t" + "movaps %%xmm1, 16(%%rdi)\n\t" + "movaps %%xmm2, 32(%%rdi)\n\t" + "movaps %%xmm3, 48(%%rdi)\n\t" + + "addq $64, %%rsi\n\t" // increment pointer into x by 16 floats + "addq $64, %%rdi\n\t" // increment pointer into y + "subq $16, %%rdx\n\t" // we've processed 16 floats + + "ja 0b\n\t" // loop + + "popq %%rdx\n\t" // restore saved registers + "popq %%rdi\n\t" + "popq %%rsi\n\t" + + : + : "S"(x), "D"(y), "d"(n1) + :); +#endif + } - //-------------------------------------------------------------------------------- - /** - * TODO: write with SSE for big enough vectors. 
- */ - inline void in_place_logical_or(const ByteVector& x, ByteVector& y, - int begin =-1, int end =-1) - { - if (begin == -1) - begin = 0; + for (int i = n1; i != n; ++i) + *(y + i) = *(x + i) && *(y + i); - if (end == -1) - end = (int) x.size(); + } else { - for (int i = begin; i != end; ++i) - y[i] |= x[i]; + for (; x != x_end; ++x, ++y) + *y = (*x) && *(y); } +#else + for (; x != x_end; ++x, ++y) + *y = (*x) && *(y); +#endif +} + +//-------------------------------------------------------------------------------- +/** + * A specialization tuned for unsigned char. + * TODO: keep only one code that computes the right offsets based on + * the iterator value type? + * TODO: vectorize, but watch out for alignments + */ +inline void in_place_logical_and(const ByteVector &x, ByteVector &y, + int begin = -1, int end = -1) { + if (begin == -1) + begin = 0; - //-------------------------------------------------------------------------------- - inline void - logical_or(size_t n, const ByteVector& x, const ByteVector& y, ByteVector& z) - { - for (size_t i = 0; i != n; ++i) - z[i] = x[i] || y[i]; - } + if (end == -1) + end = (int)x.size(); - //-------------------------------------------------------------------------------- - inline void in_place_logical_or(size_t n, const ByteVector& x, ByteVector& y) - { - for (size_t i = 0; i != n; ++i) - y[i] |= x[i]; - } + for (int i = begin; i != end; ++i) + y[i] &= x[i]; +} - //-------------------------------------------------------------------------------- - // SPARSE OR/AND - //-------------------------------------------------------------------------------- - template - inline size_t sparseOr(size_t n, - InputIterator1 begin1, InputIterator1 end1, - InputIterator2 begin2, InputIterator2 end2, - OutputIterator out, OutputIterator out_end) - { - { // Pre-conditions - NTA_ASSERT(0 <= end1 - begin1) +//-------------------------------------------------------------------------------- +/** + * TODO: write with SSE for big enough vectors. 
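
A small sketch of how the dense AND helpers above are typically called, assuming the header path shown; raw pointers are used so the 16-byte-alignment check applies directly to the underlying buffers:

#include <vector>
#include <nupic/math/ArrayAlgo.hpp> // assumed path

void logical_and_demo() {
  // 0/1-valued float arrays, as logical_and() expects.
  std::vector<float> x = {1, 0, 1, 1, 0, 1, 0, 0};
  std::vector<float> y = {1, 1, 0, 1, 0, 1, 1, 0};
  std::vector<float> z(x.size(), 0);

  // z[i] = x[i] && y[i]; the SSE path runs only if all three buffers happen
  // to be 16-byte aligned, otherwise the plain C++ loop is used.
  nupic::logical_and(x.data(), x.data() + x.size(),
                     y.data(), y.data() + y.size(),
                     z.data(), z.data() + z.size());

  // In-place variant: y[i] = x[i] && y[i].
  nupic::in_place_logical_and(x.data(), x.data() + x.size(),
                              y.data(), y.data() + y.size());
}
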
+ */ +inline void in_place_logical_or(const ByteVector &x, ByteVector &y, + int begin = -1, int end = -1) { + if (begin == -1) + begin = 0; + + if (end == -1) + end = (int)x.size(); + + for (int i = begin; i != end; ++i) + y[i] |= x[i]; +} + +//-------------------------------------------------------------------------------- +inline void logical_or(size_t n, const ByteVector &x, const ByteVector &y, + ByteVector &z) { + for (size_t i = 0; i != n; ++i) + z[i] = x[i] || y[i]; +} + +//-------------------------------------------------------------------------------- +inline void in_place_logical_or(size_t n, const ByteVector &x, ByteVector &y) { + for (size_t i = 0; i != n; ++i) + y[i] |= x[i]; +} + +//-------------------------------------------------------------------------------- +// SPARSE OR/AND +//-------------------------------------------------------------------------------- +template +inline size_t sparseOr(size_t n, InputIterator1 begin1, InputIterator1 end1, + InputIterator2 begin2, InputIterator2 end2, + OutputIterator out, OutputIterator out_end) { + { // Pre-conditions + NTA_ASSERT(0 <= end1 - begin1) << "sparseOr: Mismatched iterators for first vector"; - NTA_ASSERT(0 <= end2 - begin2) + NTA_ASSERT(0 <= end2 - begin2) << "sparseOr: Mismatched iterators for second vector"; - NTA_ASSERT(0 <= out_end - out) + NTA_ASSERT(0 <= out_end - out) << "sparseOr: Mismatched iterators for output vector"; - NTA_ASSERT(0 <= n) - << "sparseOr: Invalid max size: " << n; - NTA_ASSERT((size_t)(end1 - begin1) <= n) + NTA_ASSERT(0 <= n) << "sparseOr: Invalid max size: " << n; + NTA_ASSERT((size_t)(end1 - begin1) <= n) << "sparseOr: Invalid first vector size"; - NTA_ASSERT((size_t)(end2 - begin2) <= n) + NTA_ASSERT((size_t)(end2 - begin2) <= n) << "sparseOr: Invalid second vector size"; - NTA_ASSERT(n <= (size_t)(out_end - out)) + NTA_ASSERT(n <= (size_t)(out_end - out)) << "sparseOr: Insufficient memory for result"; - for (int i = 0; i < (int)(end1 - begin1); ++i) - NTA_ASSERT(/*0 <= *(begin1 + i) &&*/ *(begin1 + i) < n) + for (int i = 0; i < (int)(end1 - begin1); ++i) + NTA_ASSERT(/*0 <= *(begin1 + i) &&*/ *(begin1 + i) < n) << "sparseOr: Invalid index in first vector: " << *(begin1 + i); - for (int i = 1; i < (int)(end1 - begin1); ++i) - NTA_ASSERT(*(begin1 + i - 1) < *(begin1 + i)) + for (int i = 1; i < (int)(end1 - begin1); ++i) + NTA_ASSERT(*(begin1 + i - 1) < *(begin1 + i)) << "sparseOr: Indices need to be in strictly increasing order" << " (first vector)"; - for (int i = 0; i < (int)(end2 - begin2); ++i) - NTA_ASSERT(/*0 <= *(begin2 + i) &&*/ *(begin2 + i) < n) + for (int i = 0; i < (int)(end2 - begin2); ++i) + NTA_ASSERT(/*0 <= *(begin2 + i) &&*/ *(begin2 + i) < n) << "sparseOr: Invalid index in second vector: " << *(begin2 + i); - for (int i = 1; i < (int)(end2 - begin2); ++i) - NTA_ASSERT(*(begin2 + i - 1) < *(begin2 + i)) + for (int i = 1; i < (int)(end2 - begin2); ++i) + NTA_ASSERT(*(begin2 + i - 1) < *(begin2 + i)) << "sparseOr: Indices need to be in strictly increasing order" << " (second vector)"; - } // End pre-conditions + } // End pre-conditions - typedef typename std::iterator_traits::value_type value_type; + typedef typename std::iterator_traits::value_type value_type; - OutputIterator out_begin = out; + OutputIterator out_begin = out; - while (begin1 != end1 && begin2 != end2) { + while (begin1 != end1 && begin2 != end2) { - if (*begin1 < *begin2) { - *out++ = (value_type) *begin1++; - } else if (*begin2 < *begin1) { - *out++ = (value_type) *begin2++; - } else { - *out++ = 
(value_type) *begin1++; - ++begin2; - } + if (*begin1 < *begin2) { + *out++ = (value_type)*begin1++; + } else if (*begin2 < *begin1) { + *out++ = (value_type)*begin2++; + } else { + *out++ = (value_type)*begin1++; + ++begin2; } + } - for (; begin1 != end1; ++begin1) - *out++ = (value_type) *begin1; + for (; begin1 != end1; ++begin1) + *out++ = (value_type)*begin1; - for (; begin2 != end2; ++begin2) - *out++ = (value_type) *begin2; + for (; begin2 != end2; ++begin2) + *out++ = (value_type)*begin2; - return (size_t)(out - out_begin); - } + return (size_t)(out - out_begin); +} - //-------------------------------------------------------------------------------- - template - inline size_t sparseOr(size_t n, - const std::vector& x1, const std::vector& x2, - std::vector& out) - { - return sparseOr(n, x1.begin(), x1.end(), x2.begin(), x2.end(), - out.begin(), out.end()); - } +//-------------------------------------------------------------------------------- +template +inline size_t sparseOr(size_t n, const std::vector &x1, + const std::vector &x2, std::vector &out) { + return sparseOr(n, x1.begin(), x1.end(), x2.begin(), x2.end(), out.begin(), + out.end()); +} - //-------------------------------------------------------------------------------- - template - inline size_t sparseAnd(size_t n, - InputIterator1 begin1, InputIterator1 end1, - InputIterator2 begin2, InputIterator2 end2, - OutputIterator out, OutputIterator out_end) - { - { // Pre-conditions - NTA_ASSERT(0 <= end1 - begin1) +//-------------------------------------------------------------------------------- +template +inline size_t sparseAnd(size_t n, InputIterator1 begin1, InputIterator1 end1, + InputIterator2 begin2, InputIterator2 end2, + OutputIterator out, OutputIterator out_end) { + { // Pre-conditions + NTA_ASSERT(0 <= end1 - begin1) << "sparseAnd: Mismatched iterators for first vector"; - NTA_ASSERT(0 <= end2 - begin2) + NTA_ASSERT(0 <= end2 - begin2) << "sparseAnd: Mismatched iterators for second vector"; - NTA_ASSERT(0 <= out_end - out) + NTA_ASSERT(0 <= out_end - out) << "sparseAnd: Mismatched iterators for output vector"; - NTA_ASSERT(0 <= n) - << "sparseAnd: Invalid max size: " << n; - NTA_ASSERT((size_t)(end1 - begin1) <= n) + NTA_ASSERT(0 <= n) << "sparseAnd: Invalid max size: " << n; + NTA_ASSERT((size_t)(end1 - begin1) <= n) << "sparseAnd: Invalid first vector size"; - NTA_ASSERT((size_t)(end2 - begin2) <= n) + NTA_ASSERT((size_t)(end2 - begin2) <= n) << "sparseAnd: Invalid second vector size"; - //NTA_ASSERT(n <= (size_t)(out_end - out)) - //<< "sparseAnd: Insufficient memory for result"; - for (int i = 0; i < (int)(end1 - begin1); ++i) - NTA_ASSERT(/*0 <= *(begin1 + i) &&*/ *(begin1 + i) < n) + // NTA_ASSERT(n <= (size_t)(out_end - out)) + //<< "sparseAnd: Insufficient memory for result"; + for (int i = 0; i < (int)(end1 - begin1); ++i) + NTA_ASSERT(/*0 <= *(begin1 + i) &&*/ *(begin1 + i) < n) << "sparseAnd: Invalid index in first vector: " << *(begin1 + i); - for (int i = 1; i < (int)(end1 - begin1); ++i) - NTA_ASSERT(*(begin1 + i - 1) < *(begin1 + i)) + for (int i = 1; i < (int)(end1 - begin1); ++i) + NTA_ASSERT(*(begin1 + i - 1) < *(begin1 + i)) << "sparseAnd: Indices need to be in strictly increasing order" << " (first vector)"; - for (int i = 0; i < (int)(end2 - begin2); ++i) - NTA_ASSERT(/*0 <= *(begin2 + i) &&*/ *(begin2 + i) < n) + for (int i = 0; i < (int)(end2 - begin2); ++i) + NTA_ASSERT(/*0 <= *(begin2 + i) &&*/ *(begin2 + i) < n) << "sparseAnd: Invalid index in second vector: " << *(begin2 + i); - for (int i 
= 1; i < (int)(end2 - begin2); ++i) - NTA_ASSERT(*(begin2 + i - 1) < *(begin2 + i)) + for (int i = 1; i < (int)(end2 - begin2); ++i) + NTA_ASSERT(*(begin2 + i - 1) < *(begin2 + i)) << "sparseAnd: Indices need to be in strictly increasing order" << " (second vector)"; - } // End pre-conditions + } // End pre-conditions - typedef typename std::iterator_traits::value_type value_type; + typedef typename std::iterator_traits::value_type value_type; - OutputIterator out_begin = out; + OutputIterator out_begin = out; - while (begin1 != end1 && begin2 != end2) { + while (begin1 != end1 && begin2 != end2) { - if (*begin1 < *begin2) { - ++begin1; - } else if (*begin2 < *begin1) { - ++begin2; - } else { - *out++ = (value_type) *begin1++; - ++begin2; - } + if (*begin1 < *begin2) { + ++begin1; + } else if (*begin2 < *begin1) { + ++begin2; + } else { + *out++ = (value_type)*begin1++; + ++begin2; } - - return (size_t)(out - out_begin); } - //-------------------------------------------------------------------------------- - template - inline size_t sparseAnd(size_t n, - const std::vector& x1, const std::vector& x2, - std::vector& out) - { - return sparseAnd(n, x1.begin(), x1.end(), x2.begin(), x2.end(), - out.begin(), out.end()); - } + return (size_t)(out - out_begin); +} + +//-------------------------------------------------------------------------------- +template +inline size_t sparseAnd(size_t n, const std::vector &x1, + const std::vector &x2, std::vector &out) { + return sparseAnd(n, x1.begin(), x1.end(), x2.begin(), x2.end(), out.begin(), + out.end()); +} + +//-------------------------------------------------------------------------------- +// SORTING +//-------------------------------------------------------------------------------- + +//-------------------------------------------------------------------------------- +template inline void sort(C &c) { std::sort(c.begin(), c.end()); } + +//-------------------------------------------------------------------------------- +template inline void sort(C &c, F f) { + std::sort(c.begin(), c.end(), f); +} + +//-------------------------------------------------------------------------------- +template +inline void sort_on_first(It x_begin, It x_end, int direction = 1) { + typedef typename std::iterator_traits::value_type P; + typedef typename P::first_type I; + typedef typename P::second_type F; + + typedef select1st> sel1st; + + if (direction == -1) { + std::sort(x_begin, x_end, predicate_compose, sel1st>()); + } else { + std::sort(x_begin, x_end, predicate_compose, sel1st>()); + } +} + +//-------------------------------------------------------------------------------- +template +inline void sort_on_first(size_t n, std::vector> &x, + int direction = 1) { + sort_on_first(x.begin(), x.begin() + n, direction); +} + +//-------------------------------------------------------------------------------- +template +inline void sort_on_first(std::vector> &x, int direction = 1) { + sort_on_first(x.begin(), x.end(), direction); +} + +//-------------------------------------------------------------------------------- +template +inline void sort_on_first(SparseVector &x, int direction = 1) { + sort_on_first(x.begin(), x.begin() + x.nnz, direction); +} + +//-------------------------------------------------------------------------------- +/** + * Partial sort of a container given a functor. 
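
A usage sketch for the sparse set operations above, assuming the header path shown. Both functions expect strictly increasing index lists bounded by n, and the output vector must already be large enough to hold the result:

#include <cstddef>
#include <vector>
#include <nupic/math/ArrayAlgo.hpp> // assumed path

void sparse_demo() {
  const size_t n = 16; // size of the dense space the indices refer to
  // Indices of the "on" bits of two sparse binary vectors, strictly increasing.
  std::vector<size_t> a = {1, 3, 7, 9};
  std::vector<size_t> b = {3, 4, 9, 12};

  std::vector<size_t> u(n), w(n); // room for up to n indices each
  size_t n_union = nupic::sparseOr(n, a, b, u);  // u starts with 1,3,4,7,9,12
  size_t n_inter = nupic::sparseAnd(n, a, b, w); // w starts with 3,9
  u.resize(n_union);                             // keep only the valid prefix
  w.resize(n_inter);
}
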
+ */ +template +inline void partial_sort(I k, C &elts, F f) { + std::partial_sort(elts.begin(), elts.begin() + k, elts.end(), f); +} + +//-------------------------------------------------------------------------------- +/** + * Partial sort of a range, that returns the values and the indices. + */ +template +inline void partial_sort_2nd(size_type k, InputIterator in_begin, + InputIterator in_end, OutputIterator out_begin, + Order) { + typedef typename std::iterator_traits::value_type value_type; + typedef select2nd> sel2nd; - //-------------------------------------------------------------------------------- - // SORTING - //-------------------------------------------------------------------------------- + std::vector> v(in_end - in_begin); - //-------------------------------------------------------------------------------- - template - inline void sort(C& c) - { - std::sort(c.begin(), c.end()); - } + for (size_type i = 0; in_begin != in_end; ++in_begin, ++i) + v[i] = std::make_pair(i, *in_begin); - //-------------------------------------------------------------------------------- - template - inline void sort(C& c, F f) - { - std::sort(c.begin(), c.end(), f); - } + std::partial_sort(v.begin(), v.begin() + k, v.end(), + predicate_compose()); - //-------------------------------------------------------------------------------- - template - inline void sort_on_first(It x_begin, It x_end, int direction =1) - { - typedef typename std::iterator_traits::value_type P; - typedef typename P::first_type I; - typedef typename P::second_type F; + for (size_type i = 0; i != k; ++i, ++out_begin) + *out_begin = v[i]; +} - typedef select1st > sel1st; +//-------------------------------------------------------------------------------- +/** + * Partial sort of a container. + */ +template +inline void partial_sort_2nd(size_t k, const C1 &c1, OutputIterator out_begin, + Order order) { + partial_sort_2nd(k, c1.begin(), c1.end(), out_begin, order); +} + +//-------------------------------------------------------------------------------- +/** + * Partial sort of a range of vectors, based on a given order predicate for + * the vectors, putting the result into two iterators, + * one for the indices and one for the element values. + * Order needs to work for pairs (i.e., is a binary predicate). + * start_offset specifies an optional for the indices that will be generated + * for the pairs. This is useful when calling partial_sort_2nd repetitively + * for different ranges inside a larger range. + * If resort_on_first is true, the indices of the pairs are resorted, + * otherwise, the indices might come out in any order. 
+ */ +template +inline void partial_sort(size_type k, InIter in_begin, InIter in_end, + OutputIterator1 ind, OutputIterator2 nz, Order order, + size_type start_offset = 0, + bool resort_on_first = false) { + typedef typename std::iterator_traits::value_type value_type; + typedef select1st> sel1st; - if (direction == -1) { - std::sort(x_begin, x_end, predicate_compose, sel1st>()); - } else { - std::sort(x_begin, x_end, predicate_compose, sel1st>()); - } - } + std::vector> v(in_end - in_begin); - //-------------------------------------------------------------------------------- - template - inline void sort_on_first(size_t n, std::vector >& x, - int direction =1) - { - sort_on_first(x.begin(), x.begin() + n, direction); - } + for (size_type i = start_offset; in_begin != in_end; ++in_begin, ++i) + v[i - start_offset] = std::make_pair(i, *in_begin); - //-------------------------------------------------------------------------------- - template - inline void sort_on_first(std::vector >& x, int direction =1) - { - sort_on_first(x.begin(), x.end(), direction); + std::partial_sort(v.begin(), v.begin() + k, v.end(), order); + + if (resort_on_first) { + std::sort(v.begin(), v.begin() + k, + predicate_compose, sel1st>()); } - //-------------------------------------------------------------------------------- - template - inline void sort_on_first(SparseVector& x, int direction =1) - { - sort_on_first(x.begin(), x.begin() + x.nnz, direction); + for (size_type i = 0; i != k; ++i, ++ind, ++nz) { + *ind = v[i].first; + *nz = v[i].second; } +} - //-------------------------------------------------------------------------------- - /** - * Partial sort of a container given a functor. - */ - template - inline void partial_sort(I k, C& elts, F f) +//-------------------------------------------------------------------------------- +/** + * In place. + */ +template +inline void partial_argsort(I0 k, SparseVector &x, int direction = -1) { { - std::partial_sort(elts.begin(), elts.begin() + k, elts.end(), f); + NTA_ASSERT(0 < k); + NTA_ASSERT(k <= x.size()); + NTA_ASSERT(direction == -1 || direction == 1); } - //-------------------------------------------------------------------------------- - /** - * Partial sort of a range, that returns the values and the indices. - */ - template - inline void - partial_sort_2nd(size_type k, - InputIterator in_begin, InputIterator in_end, - OutputIterator out_begin, Order) - { - typedef typename std::iterator_traits::value_type value_type; - typedef select2nd > sel2nd; + typedef I size_type; + typedef T value_type; - std::vector > v(in_end - in_begin); + if (direction == -1) { - for (size_type i = 0; in_begin != in_end; ++in_begin, ++i) - v[i] = std::make_pair(i, *in_begin); + greater_2nd_no_ties order; + std::partial_sort(x.begin(), x.begin() + k, x.begin() + x.nnz, order); - std::partial_sort(v.begin(), v.begin() + k, v.end(), - predicate_compose()); + } else if (direction == 1) { - for (size_type i = 0; i != k; ++i, ++out_begin) - *out_begin = v[i]; + less_2nd order; + std::partial_sort(x.begin(), x.begin() + k, x.begin() + x.nnz, order); } +} - //-------------------------------------------------------------------------------- - /** - * Partial sort of a container. 
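
A sketch of calling partial_sort_2nd on a plain container, assuming the header path shown. The output receives (index, value) pairs for the k best elements according to the supplied order on the values:

#include <cstddef>
#include <functional>
#include <utility>
#include <vector>
#include <nupic/math/ArrayAlgo.hpp> // assumed path

void partial_sort_2nd_demo() {
  std::vector<float> v = {0.2f, 0.9f, 0.1f, 0.7f, 0.4f};

  // Top-3 (index, value) pairs, largest value first:
  // (1, 0.9), (3, 0.7), (4, 0.4).
  std::vector<std::pair<size_t, float>> top(3);
  nupic::partial_sort_2nd(3, v, top.begin(), std::greater<float>());
}
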
- */ - template - inline void - partial_sort_2nd(size_t k, const C1& c1, OutputIterator out_begin, Order order) - { - partial_sort_2nd(k, c1.begin(), c1.end(), out_begin, order); - } +//-------------------------------------------------------------------------------- +// Static buffer for partial_argsort, so that we don't have to allocate +// memory each time (faster). +//-------------------------------------------------------------------------------- +static SparseVector partial_argsort_buffer; - //-------------------------------------------------------------------------------- - /** - * Partial sort of a range of vectors, based on a given order predicate for - * the vectors, putting the result into two iterators, - * one for the indices and one for the element values. - * Order needs to work for pairs (i.e., is a binary predicate). - * start_offset specifies an optional for the indices that will be generated - * for the pairs. This is useful when calling partial_sort_2nd repetitively - * for different ranges inside a larger range. - * If resort_on_first is true, the indices of the pairs are resorted, - * otherwise, the indices might come out in any order. - */ - template - inline void - partial_sort(size_type k, InIter in_begin, InIter in_end, - OutputIterator1 ind, OutputIterator2 nz, - Order order, size_type start_offset =0, - bool resort_on_first =false) +//-------------------------------------------------------------------------------- +// A partial argsort that can use an already allocated buffer to avoid creating +// a data structure each time it's called. Assumes that the elements to be +// sorted are nupic::Real32, or at least that they have the same size. +// +// A partial sort is much faster than a full sort. The elements after the k +// first in the result are not sorted, except that they are greater (or lesser) +// than all the k first elements. If direction is -1, the sort is in decreasing +// order. If direction is 1, the sort is in increasing order. +// +// The result is returned in the first k positions of the buffer for speed. +// +// Uses a pre-allocated buffer to avoid allocating memory each time a sort +// is needed. +//-------------------------------------------------------------------------------- +template +inline void partial_argsort(size_t k, InIter begin, InIter end, OutIter sorted, + OutIter sorted_end, int direction = -1) { { - typedef typename std::iterator_traits::value_type value_type; - typedef select1st > sel1st; - - std::vector > v(in_end - in_begin); - - for (size_type i = start_offset; in_begin != in_end; ++in_begin, ++i) - v[i - start_offset] = std::make_pair(i, *in_begin); - - std::partial_sort(v.begin(), v.begin() + k, v.end(), order); - - if (resort_on_first) { - std::sort(v.begin(), v.begin() + k, - predicate_compose, sel1st>()); - } - - for (size_type i = 0; i != k; ++i, ++ind, ++nz) { - *ind = v[i].first; - *nz = v[i].second; - } + NTA_ASSERT(0 < k); + NTA_ASSERT(0 < end - begin); + NTA_ASSERT(k <= (size_t)(end - begin)); + NTA_ASSERT(k <= (size_t)(sorted_end - sorted)); + NTA_ASSERT(direction == -1 || direction == 1); } - //-------------------------------------------------------------------------------- - /** - * In place. 
- */ - template - inline void - partial_argsort(I0 k, SparseVector& x, int direction =-1) - { - { - NTA_ASSERT(0 < k); - NTA_ASSERT(k <= x.size()); - NTA_ASSERT(direction == -1 || direction == 1); - } + typedef size_t size_type; + typedef float value_type; - typedef I size_type; - typedef T value_type; + SparseVector &buff = partial_argsort_buffer; - if (direction == -1) { + size_type n = (size_type)(end - begin); - greater_2nd_no_ties order; - std::partial_sort(x.begin(), x.begin() + k, x.begin() + x.nnz, order); + // Need to clean up, lest the next sort, with a possibly smaller range, + // picks up values that are not in the current [begin,end). + buff.resize(n); + buff.nnz = n; - } else if (direction == 1) { + InIter it = begin; - less_2nd order; - std::partial_sort(x.begin(), x.begin() + k, x.begin() + x.nnz, order); - } + for (size_type i = 0; i != n; ++i, ++it) { + buff[i].first = i; + buff[i].second = *it; } - //-------------------------------------------------------------------------------- - // Static buffer for partial_argsort, so that we don't have to allocate - // memory each time (faster). - //-------------------------------------------------------------------------------- - static SparseVector partial_argsort_buffer; - - //-------------------------------------------------------------------------------- - // A partial argsort that can use an already allocated buffer to avoid creating - // a data structure each time it's called. Assumes that the elements to be sorted - // are nupic::Real32, or at least that they have the same size. - // - // A partial sort is much faster than a full sort. The elements after the k first - // in the result are not sorted, except that they are greater (or lesser) than - // all the k first elements. If direction is -1, the sort is in decreasing order. - // If direction is 1, the sort is in increasing order. - // - // The result is returned in the first k positions of the buffer for speed. - // - // Uses a pre-allocated buffer to avoid allocating memory each time a sort - // is needed. - //-------------------------------------------------------------------------------- - template - inline void partial_argsort(size_t k, InIter begin, InIter end, - OutIter sorted, OutIter sorted_end, - int direction =-1) - { - { - NTA_ASSERT(0 < k); - NTA_ASSERT(0 < end - begin); - NTA_ASSERT(k <= (size_t)(end - begin)); - NTA_ASSERT(k <= (size_t)(sorted_end - sorted)); - NTA_ASSERT(direction == -1 || direction == 1); - } - - typedef size_t size_type; - typedef float value_type; - - SparseVector& buff = partial_argsort_buffer; + partial_argsort(k, buff, direction); - size_type n = (size_type)(end - begin); + for (size_type i = 0; i != k; ++i) + sorted[i] = buff[i].first; +} - // Need to clean up, lest the next sort, with a possibly smaller range, - // picks up values that are not in the current [begin,end). - buff.resize(n); - buff.nnz = n; - - InIter it = begin; - - for (size_type i = 0; i != n; ++i, ++it) { - buff[i].first = i; - buff[i].second = *it; - } - - partial_argsort(k, buff, direction); - - for (size_type i = 0; i != k; ++i) - sorted[i] = buff[i].first; - } - - //-------------------------------------------------------------------------------- - /** - * Specialized partial argsort with selective random noise for breaking ties, to - * speed-up FDR C SP. 
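
A usage sketch for the buffered partial_argsort above, assuming the header path shown. It writes the indices of the k largest values, best first, reusing the static SparseVector buffer between calls:

#include <cstddef>
#include <vector>
#include <nupic/math/ArrayAlgo.hpp> // assumed path

void partial_argsort_demo() {
  // Values are Real32/float, which is what the shared buffer assumes.
  std::vector<nupic::Real32> scores = {0.3f, 0.9f, 0.1f, 0.8f, 0.5f};

  // Indices of the 2 largest scores, in decreasing order of score: 1, 3.
  std::vector<size_t> winners(2);
  nupic::partial_argsort(2, scores.begin(), scores.end(),
                         winners.begin(), winners.end(), -1);
}
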
- */ - template - inline void - partial_argsort_rnd_tie_break(size_t k, - InIter begin, InIter end, - OutIter sorted, OutIter sorted_end, - Random& rng, - bool real_random =false) +//-------------------------------------------------------------------------------- +/** + * Specialized partial argsort with selective random noise for breaking ties, to + * speed-up FDR C SP. + */ +template +inline void partial_argsort_rnd_tie_break(size_t k, InIter begin, InIter end, + OutIter sorted, OutIter sorted_end, + Random &rng, + bool real_random = false) { { - { - NTA_ASSERT(0 < k); - NTA_ASSERT(0 < end - begin); - NTA_ASSERT(k <= (size_t)(end - begin)); - NTA_ASSERT(k <= (size_t)(sorted_end - sorted)); - } - - typedef size_t size_type; - typedef float value_type; + NTA_ASSERT(0 < k); + NTA_ASSERT(0 < end - begin); + NTA_ASSERT(k <= (size_t)(end - begin)); + NTA_ASSERT(k <= (size_t)(sorted_end - sorted)); + } - SparseVector& buff = partial_argsort_buffer; + typedef size_t size_type; + typedef float value_type; - size_type n = (size_type)(end - begin); + SparseVector &buff = partial_argsort_buffer; - // Need to clean up, lest the next sort, with a possibly smaller range, - // picks up values that are not in the current [begin,end). - buff.resize(n); - buff.nnz = n; + size_type n = (size_type)(end - begin); - InIter it = begin; + // Need to clean up, lest the next sort, with a possibly smaller range, + // picks up values that are not in the current [begin,end). + buff.resize(n); + buff.nnz = n; - for (size_type i = 0; i != n; ++i, ++it) { - buff[i].first = i; - buff[i].second = *it; - } + InIter it = begin; - if (!real_random) { - greater_2nd order; - std::partial_sort(buff.begin(), buff.begin() + k, buff.begin() + buff.nnz, order); - } else { - greater_2nd_rnd_ties order(rng); - std::partial_sort(buff.begin(), buff.begin() + k, buff.begin() + buff.nnz, order); - } + for (size_type i = 0; i != n; ++i, ++it) { + buff[i].first = i; + buff[i].second = *it; + } - for (size_type i = 0; i != k; ++i) - sorted[i] = buff[i].first; + if (!real_random) { + greater_2nd order; + std::partial_sort(buff.begin(), buff.begin() + k, buff.begin() + buff.nnz, + order); + } else { + greater_2nd_rnd_ties order(rng); + std::partial_sort(buff.begin(), buff.begin() + k, buff.begin() + buff.nnz, + order); } - //-------------------------------------------------------------------------------- - // QUANTIZE - //-------------------------------------------------------------------------------- - template - inline void - update_with_indices_of_non_zeros(nupic::UInt32 segment_size, - It1 input_begin, It1 input_end, - It1 prev_begin, It1 prev_end, - It1 curr_begin, It1 curr_end) - { - typedef nupic::UInt32 size_type; + for (size_type i = 0; i != k; ++i) + sorted[i] = buff[i].first; +} - size_type input_size = (size_type)(input_end - input_begin); +//-------------------------------------------------------------------------------- +// QUANTIZE +//-------------------------------------------------------------------------------- +template +inline void update_with_indices_of_non_zeros(nupic::UInt32 segment_size, + It1 input_begin, It1 input_end, + It1 prev_begin, It1 prev_end, + It1 curr_begin, It1 curr_end) { + typedef nupic::UInt32 size_type; - std::fill(curr_begin, curr_end, 0); + size_type input_size = (size_type)(input_end - input_begin); - for (size_type i = 0; i != input_size; ++i) { + std::fill(curr_begin, curr_end, 0); - if (*(input_begin + i) == 0) - continue; + for (size_type i = 0; i != input_size; ++i) { - size_type begin = 
i*segment_size; - size_type end = begin + segment_size; - bool all_zero = true; + if (*(input_begin + i) == 0) + continue; - for (size_type j = begin; j != end; ++j) { + size_type begin = i * segment_size; + size_type end = begin + segment_size; + bool all_zero = true; - if (*(prev_begin + j) > 0) { - all_zero = false; - *(curr_begin + j) = 1; - } - } + for (size_type j = begin; j != end; ++j) { - if (all_zero) - std::fill(curr_begin + begin, curr_begin + end, 1); + if (*(prev_begin + j) > 0) { + all_zero = false; + *(curr_begin + j) = 1; + } } + + if (all_zero) + std::fill(curr_begin + begin, curr_begin + end, 1); } +} - //-------------------------------------------------------------------------------- - // Winner takes all - //-------------------------------------------------------------------------------- - /** - * Finds the maximum in each interval defined by the boundaries, replaces that - * maximum by a 1, and sets all the other values to 0. Returns the max value - * over all intervals, and its position. - */ - template - inline void - winnerTakesAll(const std::vector& boundaries, InIter begin1, OutIter begin2) - { - typedef typename std::iterator_traits::value_type value_type; - - I max_i = 0, size = (I) boundaries.size(); - value_type max_v = 0; - - for (I i = 0, k = 0; i < size; ++i) { - max_v = 0; - max_i = i == 0 ? 0 : boundaries[i-1]; - while (k < boundaries[i]) { - if (*begin1 > max_v) { - max_i = k; - max_v = *begin1; - } - ++k; - ++begin1; +//-------------------------------------------------------------------------------- +// Winner takes all +//-------------------------------------------------------------------------------- +/** + * Finds the maximum in each interval defined by the boundaries, replaces that + * maximum by a 1, and sets all the other values to 0. Returns the max value + * over all intervals, and its position. + */ +template +inline void winnerTakesAll(const std::vector &boundaries, InIter begin1, + OutIter begin2) { + typedef typename std::iterator_traits::value_type value_type; + + I max_i = 0, size = (I)boundaries.size(); + value_type max_v = 0; + + for (I i = 0, k = 0; i < size; ++i) { + max_v = 0; + max_i = i == 0 ? 0 : boundaries[i - 1]; + while (k < boundaries[i]) { + if (*begin1 > max_v) { + max_i = k; + max_v = *begin1; } - *begin2 = (value_type) max_i; - ++begin2; + ++k; + ++begin1; } + *begin2 = (value_type)max_i; + ++begin2; } +} - //-------------------------------------------------------------------------------- - /** - * Winner takes all 2. - */ - template - std::pair::value_type> - winnerTakesAll2(const std::vector& boundaries, InIter begin1, OutIter begin2) - { - I max_i = 0; - typedef typename std::iterator_traits::value_type value_type; - value_type max_v = 0; - - for (I i = 0, k = 0; i < boundaries.size(); ++i) { - max_v = 0; - max_i = i == 0 ? 0 : boundaries[i-1]; - while (k < boundaries[i]) { - if (begin1[k] > max_v) { - begin2[max_i] = 0; - max_i = k; - max_v = (value_type) (begin1[k]); - } else { - begin2[k] = 0; - } - ++k; +//-------------------------------------------------------------------------------- +/** + * Winner takes all 2. + */ +template +std::pair::value_type> +winnerTakesAll2(const std::vector &boundaries, InIter begin1, + OutIter begin2) { + I max_i = 0; + typedef typename std::iterator_traits::value_type value_type; + value_type max_v = 0; + + for (I i = 0, k = 0; i < boundaries.size(); ++i) { + max_v = 0; + max_i = i == 0 ? 
0 : boundaries[i - 1]; + while (k < boundaries[i]) { + if (begin1[k] > max_v) { + begin2[max_i] = 0; + max_i = k; + max_v = (value_type)(begin1[k]); + } else { + begin2[k] = 0; } - begin2[max_i] = 1; + ++k; } - return std::make_pair(max_i, max_v); + begin2[max_i] = 1; } + return std::make_pair(max_i, max_v); +} - //-------------------------------------------------------------------------------- - /** - * Keeps the values of k winners per segment, where each segment in [begin..end) - * has length seg_size, and zeroes-out all the other elements. - * Returns the indices and the values of the winners. - * For zero segments, we randomly pick a winner, and output its index, with the - * value zero. - * If a segment has only zeros, randomly picks a winner. - */ - template - inline void - winnerTakesAll3(I k, I seg_size, InIter begin, InIter end, - OutIter1 ind, OutIter2 nz, RNG& rng) - { - typedef I size_type; - typedef typename std::iterator_traits::value_type value_type; - - { // Pre-conditions - NTA_ASSERT(k > 0) - << "winnerTakesAll3: Invalid k: " << k - << " - Needs to be > 0"; - - NTA_ASSERT(seg_size > 0) +//-------------------------------------------------------------------------------- +/** + * Keeps the values of k winners per segment, where each segment in [begin..end) + * has length seg_size, and zeroes-out all the other elements. + * Returns the indices and the values of the winners. + * For zero segments, we randomly pick a winner, and output its index, with the + * value zero. + * If a segment has only zeros, randomly picks a winner. + */ +template +inline void winnerTakesAll3(I k, I seg_size, InIter begin, InIter end, + OutIter1 ind, OutIter2 nz, RNG &rng) { + typedef I size_type; + typedef typename std::iterator_traits::value_type value_type; + + { // Pre-conditions + NTA_ASSERT(k > 0) << "winnerTakesAll3: Invalid k: " << k + << " - Needs to be > 0"; + + NTA_ASSERT(seg_size > 0) << "winnerTakesAll3: Invalid segment size: " << seg_size << " - Needs to be > 0"; - NTA_ASSERT(k <= seg_size) - << "winnerTakesAll3: Invalid k (" << k << ") or " - << "segment size (" << seg_size << ")" - << " - k needs to be <= seg_size"; + NTA_ASSERT(k <= seg_size) << "winnerTakesAll3: Invalid k (" << k << ") or " + << "segment size (" << seg_size << ")" + << " - k needs to be <= seg_size"; - NTA_ASSERT((size_type) (end - begin) % seg_size == 0) + NTA_ASSERT((size_type)(end - begin) % seg_size == 0) << "winnerTakesAll3: Invalid input range of size: " - << (size_type) (end - begin) - << " - Needs to be integer multiple of segment size: " - << seg_size; - } // End pre-conditions - - typedef select2nd > sel2nd; + << (size_type)(end - begin) + << " - Needs to be integer multiple of segment size: " << seg_size; + } // End pre-conditions - InIter seg_begin = begin; - size_type offset = (size_type) 0; + typedef select2nd> sel2nd; - for (; seg_begin != end; seg_begin += seg_size, offset += seg_size) { + InIter seg_begin = begin; + size_type offset = (size_type)0; - InIter seg_end = seg_begin + seg_size; - size_type offset = (size_type) (seg_begin - begin); + for (; seg_begin != end; seg_begin += seg_size, offset += seg_size) { - if (nearlyZeroRange(seg_begin, seg_end)) { + InIter seg_end = seg_begin + seg_size; + size_type offset = (size_type)(seg_begin - begin); - std::vector indices(seg_size); - random_perm_interval(indices, offset, 1, rng); + if (nearlyZeroRange(seg_begin, seg_end)) { - sort(indices.begin(), indices.begin() + k, std::less()); - - for (size_type i = 0; i != k ; ++i, ++ind, ++nz) { - *ind = 
indices[i]; - *nz = (value_type) 0; - } + std::vector indices(seg_size); + random_perm_interval(indices, offset, 1, rng); - } else { + sort(indices.begin(), indices.begin() + k, std::less()); - partial_sort(k, seg_begin, seg_end, ind, nz, - predicate_compose, sel2nd>(), - offset, true); + for (size_type i = 0; i != k; ++i, ++ind, ++nz) { + *ind = indices[i]; + *nz = (value_type)0; } - } - } - //-------------------------------------------------------------------------------- - template - inline void - winnerTakesAll3(I k, I seg_size, InIter begin, InIter end, - OutIter1 ind, OutIter2 nz) - { - nupic::Random rng; - winnerTakesAll3(k, seg_size, begin, end, ind, nz, rng); - } + } else { - //-------------------------------------------------------------------------------- - // Dendritic tree activation - // - // Given a window size, a threshold, an array of indices and a vector of values, - // scans the vector of values with a sliding window on each row of the array of - // indices, and as soon as the activation in one window is above the threshold, - // declare that the corresponding line of the array of indices has "fired". Real - // dendrites branch, but we are not modelling that here. Learning of the synapses, - // i.e. populating the list of indices for each neuron, is not done here. Here, - // we just compute which neurons fire in a collection of neurons, given the - // synaspes on the dendrites for each neuron. - // - // The array of indices represents multiple neurons, one per row, and each row - // represents multiple segments of the dendritic tree of each neuron. However, - // the indices are not contiguous (a dendritic segment looks at random positions - // in its input vector). As soon as the activation in any window in any segment - // of the dendritic tree reaches the threshold, the neuron fires. Indices are added - // to the list of indices for a given neuron in a specific order, tying position to - // to time of activation of the synapses: the farther away the synapses, the earlier - // the signal was. - // - // ncells and max_dendrite_size are the number of rows and columns, respectively, - // of the indices matrix. If ncells is 10,000, max_dendrite_size would be, - // typically, 100, meaning that a given neuron has synapses in its dendritic - // tree with at most 100 other neurons. Those synapses are learnt, so during - // learning, there are actually less than 100 synapses in the dendritic tree. - // - // window_size is the size of the sliding window, i.e. the number of indices - // we use to sum up activation in dendritic neighborhood. In biology, activation - // might be "superlinear" for synapses further down the dendrite. - // - // threshold is the value which, if reached in any given dendritic neighborhood, - // triggers activation of the neuron. - // - // indices and indices_end are pointers to the start of the matrix of indices - // and one past the last value in that matrix. This matrix represents synapses - // that have been learnt between neurons. n_indices and n_indices_end describe - // a vector that contains the number of indices in the dendritic tree of each - // neuron. If ncells is 10,000, the values of the indices range from 0 to 9,999. - // - // input and input_end are a pointer to the vector of input, and one past the end - // of that vector. That vector is of size ncells. The inputs are real valued. - // - // output and output_end are a pointer to the vector of output, and one past the - // end of that vector. That vector is of size ncells. 
That vector contains either - // 0 if the corresponding neuron doesn't fire, or the real value of the activation. - // - // 'mode' controls which can of operation is performed. For now, mode can only be - // 0, which performs a sum in the sliding window. - //-------------------------------------------------------------------------------- - /* - template - inline void - dendritic_activation(S nsegs, S max_dendrite_size, - S window_size, T threshold, - S* indices, S* indices_end, - S* n_indices, S* n_indices_end, - T* input, T* input_end, - I* output, I* output_end, - S mode =0) - { - typedef S size_type; - typedef T value_type; - - { // Pre-conditions - NTA_ASSERT(0 < nsegs); - NTA_ASSERT(0 < max_dendrite_size); - NTA_ASSERT(max_dendrite_size <= nsegs); - NTA_ASSERT(0 < window_size); - NTA_ASSERT(window_size <= max_dendrite_size); - NTA_ASSERT(0 <= threshold); - NTA_ASSERT((S)(indices_end - indices) == nsegs * max_dendrite_size); - NTA_ASSERT((S)(n_indices_end - n_indices) == nsegs); - NTA_ASSERT((S)(input_end - input) == nsegs); - NTA_ASSERT((S)(output_end - output) == nsegs); + partial_sort(k, seg_begin, seg_end, ind, nz, + predicate_compose, sel2nd>(), + offset, true); + } + } +} + +//-------------------------------------------------------------------------------- +template +inline void winnerTakesAll3(I k, I seg_size, InIter begin, InIter end, + OutIter1 ind, OutIter2 nz) { + nupic::Random rng; + winnerTakesAll3(k, seg_size, begin, end, ind, nz, rng); +} + +//-------------------------------------------------------------------------------- +// Dendritic tree activation +// +// Given a window size, a threshold, an array of indices and a vector of values, +// scans the vector of values with a sliding window on each row of the array of +// indices, and as soon as the activation in one window is above the threshold, +// declare that the corresponding line of the array of indices has "fired". Real +// dendrites branch, but we are not modelling that here. Learning of the +// synapses, i.e. populating the list of indices for each neuron, is not done +// here. Here, we just compute which neurons fire in a collection of neurons, +// given the synaspes on the dendrites for each neuron. +// +// The array of indices represents multiple neurons, one per row, and each row +// represents multiple segments of the dendritic tree of each neuron. However, +// the indices are not contiguous (a dendritic segment looks at random positions +// in its input vector). As soon as the activation in any window in any segment +// of the dendritic tree reaches the threshold, the neuron fires. Indices are +// added to the list of indices for a given neuron in a specific order, tying +// position to to time of activation of the synapses: the farther away the +// synapses, the earlier the signal was. +// +// ncells and max_dendrite_size are the number of rows and columns, +// respectively, of the indices matrix. If ncells is 10,000, max_dendrite_size +// would be, typically, 100, meaning that a given neuron has synapses in its +// dendritic tree with at most 100 other neurons. Those synapses are learnt, so +// during learning, there are actually less than 100 synapses in the dendritic +// tree. +// +// window_size is the size of the sliding window, i.e. the number of indices +// we use to sum up activation in dendritic neighborhood. In biology, activation +// might be "superlinear" for synapses further down the dendrite. 
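
The sliding-window rule described in this comment can be written in a few lines of scalar code; the helper below is a hypothetical illustration of that rule only, not the dendritic_activation kernel that follows (which is commented out in the file):

#include <cstddef>
#include <vector>

// For one dendritic row with n_idx learnt synapse indices into `input`,
// return true if any window of `window_size` consecutive synapses has a
// summed input >= threshold. (hypothetical helper, for illustration)
inline bool dendrite_window_fires(const std::vector<size_t> &idx, size_t n_idx,
                                  const std::vector<float> &input,
                                  size_t window_size, float threshold) {
  if (n_idx < window_size)
    return false;
  float activation = 0;
  for (size_t i = 0; i != window_size; ++i) // first window
    activation += input[idx[i]];
  if (activation >= threshold)
    return true;
  for (size_t w = window_size; w != n_idx; ++w) { // slide by one synapse
    activation += input[idx[w]] - input[idx[w - window_size]];
    if (activation >= threshold)
      return true;
  }
  return false;
}
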
+// +// threshold is the value which, if reached in any given dendritic neighborhood, +// triggers activation of the neuron. +// +// indices and indices_end are pointers to the start of the matrix of indices +// and one past the last value in that matrix. This matrix represents synapses +// that have been learnt between neurons. n_indices and n_indices_end describe +// a vector that contains the number of indices in the dendritic tree of each +// neuron. If ncells is 10,000, the values of the indices range from 0 to 9,999. +// +// input and input_end are a pointer to the vector of input, and one past the +// end of that vector. That vector is of size ncells. The inputs are real +// valued. +// +// output and output_end are a pointer to the vector of output, and one past the +// end of that vector. That vector is of size ncells. That vector contains +// either 0 if the corresponding neuron doesn't fire, or the real value of the +// activation. +// +// 'mode' controls which can of operation is performed. For now, mode can only +// be 0, which performs a sum in the sliding window. +//-------------------------------------------------------------------------------- +/* +template +inline void +dendritic_activation(S nsegs, S max_dendrite_size, + S window_size, T threshold, + S* indices, S* indices_end, + S* n_indices, S* n_indices_end, + T* input, T* input_end, + I* output, I* output_end, + S mode =0) +{ + typedef S size_type; + typedef T value_type; + + { // Pre-conditions + NTA_ASSERT(0 < nsegs); + NTA_ASSERT(0 < max_dendrite_size); + NTA_ASSERT(max_dendrite_size <= nsegs); + NTA_ASSERT(0 < window_size); + NTA_ASSERT(window_size <= max_dendrite_size); + NTA_ASSERT(0 <= threshold); + NTA_ASSERT((S)(indices_end - indices) == nsegs * max_dendrite_size); + NTA_ASSERT((S)(n_indices_end - n_indices) == nsegs); + NTA_ASSERT((S)(input_end - input) == nsegs); + NTA_ASSERT((S)(output_end - output) == nsegs); #ifdef NTA_ASSERTION_ON - for (size_type c = 0; c != nsegs; ++c) - NTA_ASSERT(n_indices[c] == 0 || window_size <= n_indices[c]); + for (size_type c = 0; c != nsegs; ++c) + NTA_ASSERT(n_indices[c] == 0 || window_size <= n_indices[c]); #endif - } // End pre-conditions + } // End pre-conditions - for (size_type seg = 0; seg != nsegs; ++seg) { + for (size_type seg = 0; seg != nsegs; ++seg) { - output[seg] = (int) -1; + output[seg] = (int) -1; - if (n_indices[seg] == 0) - continue; + if (n_indices[seg] == 0) + continue; - // w_end is how far we can move the window down the dendritic segment - value_type activation = 0; - size_type seg_start = seg*max_dendrite_size; + // w_end is how far we can move the window down the dendritic segment + value_type activation = 0; + size_type seg_start = seg*max_dendrite_size; - for (size_type i = 0; i != window_size; ++i) - activation += input[indices[seg_start + i]]; + for (size_type i = 0; i != window_size; ++i) + activation += input[indices[seg_start + i]]; - if (activation >= threshold) { + if (activation >= threshold) { - output[seg] = (int) 0; + output[seg] = (int) 0; - } else { + } else { - int w_end = (int) n_indices[seg] - (int) window_size; + int w_end = (int) n_indices[seg] - (int) window_size; - for (int w_start = 0; w_start < w_end; ++w_start) { + for (int w_start = 0; w_start < w_end; ++w_start) { - size_type w_end1 = std::min(w_start + window_size, n_indices[seg]); - activation -= input[indices[seg_start + w_start]]; - activation += input[indices[seg_start + w_end1]]; + size_type w_end1 = std::min(w_start + window_size, n_indices[seg]); + activation -= 
input[indices[seg_start + w_start]]; + activation += input[indices[seg_start + w_end1]]; - if (activation >= threshold) { - output[seg] = (int) w_start + 1; - break; - } + if (activation >= threshold) { + output[seg] = (int) w_start + 1; + break; } } } + } - // for (size_type cell = 0; cell != ncells; ++cell, ++output) { + // for (size_type cell = 0; cell != ncells; ++cell, ++output) { - // if (n_indices[cell] == 0) { - // *output = (int) -1; - // continue; - // } + // if (n_indices[cell] == 0) { + // *output = (int) -1; + // continue; + // } - // size_type w_end = n_indices[cell] - window_size + 1; + // size_type w_end = n_indices[cell] - window_size + 1; - // for (size_type w_start = 0; w_start < w_end; ++w_start) { + // for (size_type w_start = 0; w_start < w_end; ++w_start) { - // size_type w_end1 = w_start + window_size; - // value_type activation = 0; + // size_type w_end1 = w_start + window_size; + // value_type activation = 0; - // // Maintain activation with +/- - // for (size_type i = w_start; i < w_end1; ++i) - // activation += input[indices[cell*max_dendrite_size+i]]; + // // Maintain activation with +/- + // for (size_type i = w_start; i < w_end1; ++i) + // activation += input[indices[cell*max_dendrite_size+i]]; - // if (activation >= threshold) { - // *output = (int) w_start; - // break; - // } + // if (activation >= threshold) { + // *output = (int) w_start; + // break; + // } - // *output = (int) -1; - // } - // } + // *output = (int) -1; + // } + // } - } +} */ - //-------------------------------------------------------------------------------- +//-------------------------------------------------------------------------------- } // end namespace nupic #endif // NTA_ARRAY_ALGO_HPP diff --git a/src/nupic/math/Convolution.hpp b/src/nupic/math/Convolution.hpp index 15f7640e3c..706c8043c9 100644 --- a/src/nupic/math/Convolution.hpp +++ b/src/nupic/math/Convolution.hpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Declarations for convolutions */ @@ -31,9 +31,7 @@ /** * Computes convolutions in 2D, for separable kernels. */ -template -struct SeparableConvolution2D -{ +template struct SeparableConvolution2D { typedef size_t size_type; typedef T value_type; @@ -46,20 +44,19 @@ struct SeparableConvolution2D size_type f1_middle_; size_type f2_middle_; - T* f1_; - T* f2_; + T *f1_; + T *f2_; T *f1_end_; T *f2_end_; - T* buffer_; + T *buffer_; /** * nrows is the number of rows in the original image, and ncols * is the number of columns. */ - inline void init(size_type nrows, size_type ncols, - size_type f1_size, size_type f2_size, - T* f1, T* f2) + inline void init(size_type nrows, size_type ncols, size_type f1_size, + size_type f2_size, T *f1, T *f2) /* : nrows_(nrows), ncols_(ncols), f1_size_(f1_size), f2_size_(f2_size), @@ -81,16 +78,13 @@ struct SeparableConvolution2D f2_ = f2; f1_end_ = f1 + f1_size; f2_end_ = f2 + f2_size; - buffer_ = new T[nrows*ncols]; + buffer_ = new T[nrows * ncols]; } - inline SeparableConvolution2D() : buffer_(NULL) - { - } + inline SeparableConvolution2D() : buffer_(NULL) {} - inline ~SeparableConvolution2D() - { - delete [] buffer_; + inline ~SeparableConvolution2D() { + delete[] buffer_; buffer_ = NULL; } @@ -100,32 +94,31 @@ struct SeparableConvolution2D * * Down-sampling? 
*/ - inline void compute(T* data, T* convolved, bool rotated45 =false) - { + inline void compute(T *data, T *convolved, bool rotated45 = false) { for (size_type i = 0; i != nrows_; ++i) { - T* b = buffer_ + i*ncols_ + f1_middle_, *d_row = data + i*ncols_; + T *b = buffer_ + i * ncols_ + f1_middle_, *d_row = data + i * ncols_; for (size_type j = 0; j != f1_end_j_; ++j) { - register T dot = 0, *f = f1_, *d = d_row + j; - while (f != f1_end_) - dot += *f++ * *d++; - *b++ = dot; + register T dot = 0, *f = f1_, *d = d_row + j; + while (f != f1_end_) + dot += *f++ * *d++; + *b++ = dot; } } for (size_type i = 0; i != f2_end_i_; ++i) { - T* c = convolved + (i + f2_middle_)*ncols_, *b_row = buffer_ + i*ncols_; + T *c = convolved + (i + f2_middle_) * ncols_, + *b_row = buffer_ + i * ncols_; for (size_type j = 0; j != ncols_; ++j) { - register T dot = 0, *f = f2_, *b = b_row + j; - while (f != f2_end_) { - dot += *f++ * *b; - b += ncols_; - } - *c++ = dot; + register T dot = 0, *f = f2_, *b = b_row + j; + while (f != f2_end_) { + dot += *f++ * *b; + b += ncols_; + } + *c++ = dot; } } } }; //-------------------------------------------------------------------------------- -#endif //NTA_CONVOLUTION_HPP - +#endif // NTA_CONVOLUTION_HPP diff --git a/src/nupic/math/DenseMatrix.hpp b/src/nupic/math/DenseMatrix.hpp index 22e9f2cd6b..4359914807 100644 --- a/src/nupic/math/DenseMatrix.hpp +++ b/src/nupic/math/DenseMatrix.hpp @@ -31,878 +31,789 @@ #include -#include #include - +#include namespace nupic { - //-------------------------------------------------------------------------------- - /** - */ - template - struct Dense - { - typedef std::vector Memory; - typedef typename Memory::iterator iterator; - typedef typename Memory::const_iterator const_iterator; - - Int nrows, ncols; - Memory m; - - Dense(Int nr = 0, Int nc = 0) - : nrows(nr), ncols(nc), m(nr*nc, 0) - {} - - Dense(Int nr, Int nc, Int nzr, bool small =false, bool emptyRows =false, - TRandom* r = nullptr) - : nrows(nr), ncols(nc), - m(nr*nc, 0) - { - if (small) - { - NTA_CHECK(r != nullptr) - << "Random number generator required for Dense() constructor" - << " when small is true"; - } +//-------------------------------------------------------------------------------- +/** + */ +template struct Dense { + typedef std::vector Memory; + typedef typename Memory::iterator iterator; + typedef typename Memory::const_iterator const_iterator; + + Int nrows, ncols; + Memory m; + + Dense(Int nr = 0, Int nc = 0) : nrows(nr), ncols(nc), m(nr * nc, 0) {} + + Dense(Int nr, Int nc, Int nzr, bool small = false, bool emptyRows = false, + TRandom *r = nullptr) + : nrows(nr), ncols(nc), m(nr * nc, 0) { + if (small) { + NTA_CHECK(r != nullptr) + << "Random number generator required for Dense() constructor" + << " when small is true"; + } + ITER_2(nrows, ncols) + if (!small) { + at(i, j) = Real(10 * i + j + 1); // none zero, positive + } else { + at(i, j) = 5 * nupic::Epsilon * r->getReal64(); + if (nupic::nearlyZero(at(i, j))) + at(i, j) = 0.0; + } + + if (nzr > 0 && ncols / nzr > 0) { ITER_2(nrows, ncols) - if (!small) { - at(i,j) = Real(10*i+j+1); // none zero, positive - } else { - at(i,j) = 5 * nupic::Epsilon * r->getReal64(); - if (nupic::nearlyZero(at(i,j))) - at(i,j) = 0.0; - } - - if (nzr > 0 && ncols / nzr > 0) { - ITER_2(nrows, ncols) - if (j % (ncols / nzr) == 0) - at(i,j) = 0.0; - } - - if (emptyRows) - for (Int i = 0; i < nrows; i += 2) - for (Int j = 0; j < ncols; ++j) - at(i,j) = 0.0; - } - - ~Dense() {} - - - inline iterator begin() { return m.begin(); } 
- inline const_iterator begin() const { return m.begin(); } - inline iterator begin(const Int i) { return m.begin() + i*ncols; } - inline const_iterator begin(const Int i) const { return m.begin() + i*ncols; } - - // row i, column j - inline Float& at(const Int i, const Int j) - { - return *(m.begin() + i*ncols + j); - } - - inline const Float& at(const Int i, const Int j) const - { - return *(m.begin() + i*ncols + j); - } - - inline void copy(const Dense& other) - { - nrows = other.nrows; - ncols = other.ncols; - std::vector new_m(other.m); - std::swap(m, new_m); - } - - template - inline void addRow(InIter begin) - { - Int nrows_new = nrows + 1; - std::vector new_m(nrows_new * ncols); - for (Int i = 0; i < nrows; ++i) - for (Int j = 0; j < ncols; ++j) - new_m[i*ncols+j] = at(i,j); - Int k = nrows*ncols; - for (Int j = 0; j < ncols; ++j, ++begin) - new_m[k+j] = *begin; - std::swap(m, new_m); - nrows = nrows_new; - } - - template - inline void deleteRows(InIter del, InIter del_end) - { - Int nrows_new = nrows - (del_end - del); - std::set del_set(del, del_end); - std::vector new_m(nrows_new * ncols); - Int i1 = 0; - for (Int i = 0; i < nrows; ++i) { - if (del_set.find(i) == del_set.end()) { - for (Int j = 0; j < ncols; ++j) - new_m[i1*ncols + j] = at(i,j); - ++i1; - } - } - std::swap(m, new_m); - nrows = nrows_new; + if (j % (ncols / nzr) == 0) + at(i, j) = 0.0; } - template - inline void deleteCols(InIter del, InIter del_end) - { - Int ncols_new = ncols - (del_end - del); - std::set del_set(del, del_end); - std::vector new_m(nrows * ncols_new); - Int j1 = 0; - for (Int j = 0; j < ncols; ++j) { - if (del_set.find(j) == del_set.end()) { - for (Int i = 0; i < nrows; ++i) - new_m[i*ncols_new + j1] = at(i,j); - ++j1; - } + if (emptyRows) + for (Int i = 0; i < nrows; i += 2) + for (Int j = 0; j < ncols; ++j) + at(i, j) = 0.0; + } + + ~Dense() {} + + inline iterator begin() { return m.begin(); } + inline const_iterator begin() const { return m.begin(); } + inline iterator begin(const Int i) { return m.begin() + i * ncols; } + inline const_iterator begin(const Int i) const { + return m.begin() + i * ncols; + } + + // row i, column j + inline Float &at(const Int i, const Int j) { + return *(m.begin() + i * ncols + j); + } + + inline const Float &at(const Int i, const Int j) const { + return *(m.begin() + i * ncols + j); + } + + inline void copy(const Dense &other) { + nrows = other.nrows; + ncols = other.ncols; + std::vector new_m(other.m); + std::swap(m, new_m); + } + + template inline void addRow(InIter begin) { + Int nrows_new = nrows + 1; + std::vector new_m(nrows_new * ncols); + for (Int i = 0; i < nrows; ++i) + for (Int j = 0; j < ncols; ++j) + new_m[i * ncols + j] = at(i, j); + Int k = nrows * ncols; + for (Int j = 0; j < ncols; ++j, ++begin) + new_m[k + j] = *begin; + std::swap(m, new_m); + nrows = nrows_new; + } + + template + inline void deleteRows(InIter del, InIter del_end) { + Int nrows_new = nrows - (del_end - del); + std::set del_set(del, del_end); + std::vector new_m(nrows_new * ncols); + Int i1 = 0; + for (Int i = 0; i < nrows; ++i) { + if (del_set.find(i) == del_set.end()) { + for (Int j = 0; j < ncols; ++j) + new_m[i1 * ncols + j] = at(i, j); + ++i1; } - std::swap(m, new_m); - ncols = ncols_new; } + std::swap(m, new_m); + nrows = nrows_new; + } - inline void resize(const Int new_nrows, const Int new_ncols) - { - std::vector new_m(new_nrows * new_ncols); - Int row_m = std::min(new_nrows, nrows); - Int col_m = std::min(new_ncols, ncols); - ITER_2(row_m, col_m) - new_m[i*new_ncols 
+ j] = at(i,j); - std::swap(m, new_m); - nrows = new_nrows; - ncols = new_ncols; - } + template + inline void deleteCols(InIter del, InIter del_end) { + Int ncols_new = ncols - (del_end - del); + std::set del_set(del, del_end); + std::vector new_m(nrows * ncols_new); + Int j1 = 0; + for (Int j = 0; j < ncols; ++j) { + if (del_set.find(j) == del_set.end()) { + for (Int i = 0; i < nrows; ++i) + new_m[i * ncols_new + j1] = at(i, j); + ++j1; + } + } + std::swap(m, new_m); + ncols = ncols_new; + } - inline void setRowToZero(Int row) - { - for (Int j = 0; j != ncols; ++j) - at(row,j) = 0; - } + inline void resize(const Int new_nrows, const Int new_ncols) { + std::vector new_m(new_nrows * new_ncols); + Int row_m = std::min(new_nrows, nrows); + Int col_m = std::min(new_ncols, ncols); + ITER_2(row_m, col_m) + new_m[i * new_ncols + j] = at(i, j); + std::swap(m, new_m); + nrows = new_nrows; + ncols = new_ncols; + } - inline void setColToZero(Int col) - { - for (Int i = 0; i != nrows; ++i) - at(i,col) = 0; - } - - void fromCSR(std::stringstream& stream) - { - std::string tag; - stream >> tag; - Int nnz, nnzr, j; - Float val; - stream >> nrows >> ncols >> nnz; - - m.resize(nrows*ncols); - std::fill(m.begin(), m.end(), Real(0)); - - for (Int i = 0; i < nrows; ++i) { - stream >> nnzr; - for (Int k = 0; k < nnzr; ++k) { - stream >> j >> val; - if (!nupic::nearlyZero(val)) - at(i,j) = val; - else - at(i,j) = 0; - } + inline void setRowToZero(Int row) { + for (Int j = 0; j != ncols; ++j) + at(row, j) = 0; + } + + inline void setColToZero(Int col) { + for (Int i = 0; i != nrows; ++i) + at(i, col) = 0; + } + + void fromCSR(std::stringstream &stream) { + std::string tag; + stream >> tag; + Int nnz, nnzr, j; + Float val; + stream >> nrows >> ncols >> nnz; + + m.resize(nrows * ncols); + std::fill(m.begin(), m.end(), Real(0)); + + for (Int i = 0; i < nrows; ++i) { + stream >> nnzr; + for (Int k = 0; k < nnzr; ++k) { + stream >> j >> val; + if (!nupic::nearlyZero(val)) + at(i, j) = val; + else + at(i, j) = 0; } } + } - void clear() - { - std::fill(m.begin(), m.end(), Real(0)); - } + void clear() { std::fill(m.begin(), m.end(), Real(0)); } - //-------------------------------------------------------------------------------- - // TESTS - //-------------------------------------------------------------------------------- + //-------------------------------------------------------------------------------- + // TESTS + //-------------------------------------------------------------------------------- - bool isZero() const - { - ITER_1(nrows*ncols) - if (!nupic::nearlyZero(m[i])) - return false; - return true; - } + bool isZero() const { + ITER_1(nrows * ncols) + if (!nupic::nearlyZero(m[i])) + return false; + return true; + } - inline Int nRows(){ return nrows;} - inline Int nCols(){ return ncols;} + inline Int nRows() { return nrows; } + inline Int nCols() { return ncols; } - Int nNonZerosOnRow(Int row) const - { - Int n = 0; - ITER_1(ncols) - if (!nupic::nearlyZero(at(row,i))) - ++n; - return n; - } + Int nNonZerosOnRow(Int row) const { + Int n = 0; + ITER_1(ncols) + if (!nupic::nearlyZero(at(row, i))) + ++n; + return n; + } - Int nNonZerosOnCol(Int col) const - { - Int n = 0; - ITER_2(nrows, ncols) - if (!nupic::nearlyZero(at(i,j)) && j == col) - ++n; - return n; - } + Int nNonZerosOnCol(Int col) const { + Int n = 0; + ITER_2(nrows, ncols) + if (!nupic::nearlyZero(at(i, j)) && j == col) + ++n; + return n; + } - bool isRowZero(Int row) const - { - return nNonZerosOnRow(row) == 0; - } + bool isRowZero(Int row) const { 
return nNonZerosOnRow(row) == 0; } - bool isColZero(Int col) const - { - return nNonZerosOnCol(col) == 0; - } - - Int nNonZeros() const - { - Int n = 0; - ITER_1(nrows) - n += nNonZerosOnRow(i); - return n; - } + bool isColZero(Int col) const { return nNonZerosOnCol(col) == 0; } - template - void nNonZerosPerRow(OutIter it) const - { - ITER_1(nrows) - *it++ = nNonZerosOnRow(i); - } + Int nNonZeros() const { + Int n = 0; + ITER_1(nrows) + n += nNonZerosOnRow(i); + return n; + } - template - void nNonZerosPerCol(OutIter it) const - { - std::fill(it, it + ncols, 0); + template void nNonZerosPerRow(OutIter it) const { + ITER_1(nrows) + *it++ = nNonZerosOnRow(i); + } - ITER_2(nrows, ncols) - if (!nupic::nearlyZero(at(i,j))) - *(it + j) += 1; - } + template void nNonZerosPerCol(OutIter it) const { + std::fill(it, it + ncols, 0); - //-------------------------------------------------------------------------------- + ITER_2(nrows, ncols) + if (!nupic::nearlyZero(at(i, j))) + *(it + j) += 1; + } - void transpose(Dense& tr) const - { - ITER_2(nrows, ncols) - tr.at(j,i) = at(i,j); - } - - template - void vecMaxProd(InIter x, OutIter y) const - { - for (Int i = 0; i < nrows; ++i) { - Float max = - std::numeric_limits::max(); - for (Int j = 0; j < ncols; ++j) { - Float val = at(i,j) * x[j]; - if (val > max) - max = val; - } - y[i] = max; + //-------------------------------------------------------------------------------- + + void transpose(Dense &tr) const { + ITER_2(nrows, ncols) + tr.at(j, i) = at(i, j); + } + + template + void vecMaxProd(InIter x, OutIter y) const { + for (Int i = 0; i < nrows; ++i) { + Float max = -std::numeric_limits::max(); + for (Int j = 0; j < ncols; ++j) { + Float val = at(i, j) * x[j]; + if (val > max) + max = val; } + y[i] = max; } + } - template - void rightVecProd(InIter x, OutIter y) const - { - for (Int i = 0; i < nrows; ++i) { - Float s = 0; - for (Int j = 0; j < ncols; ++j) - s += at(i,j) * x[j]; - y[i] = s; - } + template + void rightVecProd(InIter x, OutIter y) const { + for (Int i = 0; i < nrows; ++i) { + Float s = 0; + for (Int j = 0; j < ncols; ++j) + s += at(i, j) * x[j]; + y[i] = s; } + } + + template + Float rowLpDist(Float p, Int row, InIter x, bool take_root = false) const { + if (p == 0.0) + return rowL0Dist(row, x); + + Float val = 0; + for (Int j = 0; j != ncols; ++j) + val += ::pow(::fabs(x[j] - at(row, j)), p); + if (take_root) + val = ::pow(val, 1.0 / p); + return val; + } + + template Float rowL0Dist(Int row, InIter x) const { + Float val = 0; + for (Int j = 0; j != ncols; ++j) + val += ::fabs(x[j] - at(row, j)) > nupic::Epsilon; + return val; + } + + template Float rowLMaxDist(Int row, InIter x) const { + Float val = 0; + for (Int j = 0; j != ncols; ++j) + val = std::max(::fabs(x[j] - at(row, j)), val); + return val; + } - template - Float rowLpDist(Float p, Int row, InIter x, bool take_root =false) const - { - if (p == 0.0) - return rowL0Dist(row, x); + template + void LpDist(Float p, InIter x, OutIter y, bool take_root = false) const { + if (p == 0.0) { + L0Dist(x, y); + return; + } + for (Int i = 0; i != nrows; ++i) { Float val = 0; - for (Int j = 0; j != ncols; ++j) - val += ::pow(::fabs(x[j] - at(row,j)), p); - if (take_root) - val = ::pow(val, 1.0/p); - return val; - } - - template - Float rowL0Dist(Int row, InIter x) const - { + for (Int j = 0; j != ncols; ++j) + val += ::pow(::fabs(x[j] - at(i, j)), p); + y[i] = take_root ? 
::pow(val, 1.0 / p) : val; + } + } + + template + void L0Dist(InIter x, OutIter y) const { + for (Int i = 0; i != nrows; ++i) { Float val = 0; - for (Int j = 0; j != ncols; ++j) - val += ::fabs(x[j] - at(row,j)) > nupic::Epsilon; - return val; + for (Int j = 0; j != ncols; ++j) + val += ::fabs(x[j] - at(i, j)) > nupic::Epsilon; + y[i] = val; } + } - template - Float rowLMaxDist(Int row, InIter x) const - { + template + void LMaxDist(InIter x, OutIter y) const { + for (Int i = 0; i != nrows; ++i) { Float val = 0; - for (Int j = 0; j != ncols; ++j) - val = std::max(::fabs(x[j] - at(row,j)), val); - return val; + for (Int j = 0; j != ncols; ++j) + val = std::max(val, ::fabs(x[j] - at(i, j))); + y[i] = val; } + } - template - void LpDist(Float p, InIter x, OutIter y, bool take_root =false) const - { - if (p == 0.0) { - L0Dist(x, y); - return; - } + template + inline void LpNearest(Float p, InIter x, OutIter nn, Int k = 1, + bool take_root = false) const { + if (p == 0.0) { + L0Nearest(x, nn, k); + return; + } - for (Int i = 0; i != nrows; ++i) { - Float val = 0; - for (Int j = 0; j != ncols; ++j) - val += ::pow(::fabs(x[j] - at(i,j)), p); - y[i] = take_root ? ::pow(val, 1.0/p) : val; - } + std::vector> dists(nrows); + + for (Int i = 0; i != nrows; ++i) { + dists[i].first = i; + dists[i].second = 0; } - template - void L0Dist(InIter x, OutIter y) const - { - for (Int i = 0; i != nrows; ++i) { - Float val = 0; - for (Int j = 0; j != ncols; ++j) - val += ::fabs(x[j] - at(i,j)) > nupic::Epsilon; - y[i] = val; - } + for (Int i = 0; i < nrows; ++i) { + Float val = 0; + for (Int j = 0; j < ncols; ++j) + val += ::pow(::fabs(x[j] - at(i, j)), p); + dists[i].second = take_root ? ::pow(val, 1.0 / p) : val; } - template - void LMaxDist(InIter x, OutIter y) const - { - for (Int i = 0; i != nrows; ++i) { - Float val = 0; - for (Int j = 0; j != ncols; ++j) - val = std::max(val, ::fabs(x[j] - at(i,j))); - y[i] = val; - } + std::partial_sort( + dists.begin(), dists.begin() + k, dists.end(), + predicate_compose, + nupic::select2nd>>()); + + for (Int i = 0; i != nrows; ++i, ++nn) { + nn->first = dists[i].first; + nn->second = dists[i].second; } + } - template - inline void - LpNearest(Float p, InIter x, OutIter nn, Int k =1, bool take_root =false) const - { - if (p == 0.0) { - L0Nearest(x, nn, k); - return; - } + template + inline void L0Nearest(InIter x, OutIter nn, Int k = 1, + bool take_root = false) const { + std::vector> dists(nrows); - std::vector > dists(nrows); + for (Int i = 0; i != nrows; ++i) { + dists[i].first = i; + dists[i].second = 0; + } - for (Int i = 0; i != nrows; ++i) { - dists[i].first = i; - dists[i].second = 0; - } + for (Int i = 0; i < nrows; ++i) { + Float val = 0; + for (Int j = 0; j < ncols; ++j) + val += ::fabs(x[j] - at(i, j)) > nupic::Epsilon; + dists[i].second = val; + } - for (Int i = 0; i < nrows; ++i) { - Float val = 0; - for (Int j = 0; j < ncols; ++j) - val += ::pow(::fabs(x[j] - at(i,j)), p); - dists[i].second = take_root ? 
::pow(val, 1.0/p) : val; - } + std::partial_sort( + dists.begin(), dists.begin() + k, dists.end(), + predicate_compose, + nupic::select2nd>>()); - std::partial_sort(dists.begin(), dists.begin() + k, dists.end(), - predicate_compose, nupic::select2nd > >()); + for (Int i = 0; i != nrows; ++i, ++nn) { + nn->first = dists[i].first; + nn->second = dists[i].second; + } + } - for (Int i = 0; i != nrows; ++i, ++nn) { - nn->first = dists[i].first; - nn->second = dists[i].second; - } + template + inline void LMaxNearest(InIter x, OutIter nn, Int k = 1, + bool take_root = false) const { + std::vector> dists(nrows); + + for (Int i = 0; i != nrows; ++i) { + dists[i].first = i; + dists[i].second = 0; } - template - inline void - L0Nearest(InIter x, OutIter nn, Int k =1, bool take_root =false) const - { - std::vector > dists(nrows); + for (Int i = 0; i < nrows; ++i) { + Float val = 0; + for (Int j = 0; j < ncols; ++j) + val = std::max(val, ::fabs(x[j] - at(i, j))); + dists[i].second = val; + } - for (Int i = 0; i != nrows; ++i) { - dists[i].first = i; - dists[i].second = 0; - } + std::partial_sort( + dists.begin(), dists.begin() + k, dists.end(), + predicate_compose, + nupic::select2nd>>()); - for (Int i = 0; i < nrows; ++i) { - Float val = 0; - for (Int j = 0; j < ncols; ++j) - val += ::fabs(x[j] - at(i,j)) > nupic::Epsilon; - dists[i].second = val; - } + for (Int i = 0; i != nrows; ++i, ++nn) { + nn->first = dists[i].first; + nn->second = dists[i].second; + } + } - std::partial_sort(dists.begin(), dists.begin() + k, dists.end(), predicate_compose, nupic::select2nd > >()); + template std::pair dotNearest(InIter x) const { + Float val, max_val = -std::numeric_limits::max(); + Int arg_i = 0; - for (Int i = 0; i != nrows; ++i, ++nn) { - nn->first = dists[i].first; - nn->second = dists[i].second; + for (Int i = 0; i < nrows; ++i) { + val = 0; + for (Int j = 0; j < ncols; ++j) + val += at(i, j) * x[j]; + if (val > max_val) { + max_val = val; + arg_i = i; } } - template - inline void - LMaxNearest(InIter x, OutIter nn, Int k =1, bool take_root =false) const - { - std::vector > dists(nrows); + return make_pair(arg_i, max_val); + } - for (Int i = 0; i != nrows; ++i) { - dists[i].first = i; - dists[i].second = 0; - } + template void axby(Int r, Float a, Float b, InIter x) { + for (Int j = 0; j < ncols; ++j) + at(r, j) = a * at(r, j) + b * x[j]; - for (Int i = 0; i < nrows; ++i) { - Float val = 0; - for (Int j = 0; j < ncols; ++j) - val = std::max(val, ::fabs(x[j] - at(i,j))); - dists[i].second = val; - } + threshold(r, nupic::Epsilon); + } - std::partial_sort(dists.begin(), dists.begin() + k, dists.end(), - predicate_compose, nupic::select2nd > >()); + template void axby(Float a, Float b, InIter x) { + ITER_2(nrows, ncols) + at(i, j) = a * at(i, j) + b * x[j]; - for (Int i = 0; i != nrows; ++i, ++nn) { - nn->first = dists[i].first; - nn->second = dists[i].second; - } - } + threshold(nupic::Epsilon); + } - template - std::pair dotNearest(InIter x) const - { - Float val, max_val = - std::numeric_limits::max(); - Int arg_i = 0; - - for (Int i = 0; i < nrows; ++i) { - val = 0; - for (Int j = 0; j < ncols; ++j) - val += at(i,j) * x[j]; - if (val > max_val) { - max_val = val; - arg_i = i; + template + void xMaxAtNonZero(InIter x, OutIter y) const { + for (Int i = 0; i < nrows; ++i) { + Int arg_j = 0; + Float max_val = -std::numeric_limits::max(); + for (Int j = 0; j < ncols; ++j) + if (at(i, j) > 0 && x[j] > max_val) { + arg_j = j; + max_val = x[j]; } - } - - return make_pair(arg_i, max_val); + y[i] = Float(arg_j); } 
+ } - template - void axby(Int r, Float a, Float b, InIter x) - { - for (Int j = 0; j < ncols; ++j) - at(r,j) = a * at(r,j) + b * x[j]; + void normalizeRows(bool exact = false) { + for (Int i = 0; i != nrows; ++i) { - threshold(r, nupic::Epsilon); - } + Float val = 0; + bool oneMore = false; - template - void axby(Float a, Float b, InIter x) - { - ITER_2(nrows, ncols) - at(i,j) = a * at(i,j) + b * x[j]; + for (Int j = 0; j != ncols; ++j) + val += at(i, j); - threshold(nupic::Epsilon); - } + if (!nupic::nearlyZero(val)) + for (Int j = 0; j != ncols; ++j) { + at(i, j) /= val; + if (nupic::nearlyZero(at(i, j))) + oneMore = true; + } - template - void xMaxAtNonZero(InIter x, OutIter y) const - { - for (Int i = 0; i < nrows; ++i) { - Int arg_j = 0; - Float max_val = - std::numeric_limits::max(); - for (Int j = 0; j < ncols; ++j) - if (at(i,j) > 0 && x[j] > max_val) { - arg_j = j; - max_val = x[j]; - } - y[i] = Float(arg_j); + if (oneMore && exact) { + + threshold(i, nupic::Epsilon); + + val = 0; + + for (Int j = 0; j != ncols; ++j) + val += at(i, j); + + if (!nupic::nearlyZero(val)) + for (Int j = 0; j != ncols; ++j) + at(i, j) /= val; } } + } - void normalizeRows(bool exact =false) - { - for (Int i = 0; i != nrows; ++i) { - - Float val = 0; - bool oneMore = false; + void normalizeCols(bool exact = false) { + for (Int j = 0; j != ncols; ++j) { - for (Int j = 0; j != ncols; ++j) - val += at(i,j); + Float val = 0; + bool oneMore = false; - if (!nupic::nearlyZero(val)) - for (Int j = 0; j != ncols; ++j) { - at(i,j) /= val; - if (nupic::nearlyZero(at(i,j))) - oneMore = true; + for (Int i = 0; i != nrows; ++i) + val += at(i, j); + + if (!nupic::nearlyZero(val)) + for (Int i = 0; i != nrows; ++i) { + at(i, j) /= val; + if (nupic::nearlyZero(at(i, j))) { + at(i, j) = 0; + oneMore = true; } + } - if (oneMore && exact) { - - threshold(i, nupic::Epsilon); + if (oneMore && exact) { - val = 0; + val = 0; - for (Int j = 0; j != ncols; ++j) - val += at(i,j); + for (Int i = 0; i != nrows; ++i) + val += at(i, j); - if (!nupic::nearlyZero(val)) - for (Int j = 0; j != ncols; ++j) - at(i,j) /= val; - } + if (!nupic::nearlyZero(val)) + for (Int i = 0; i != nrows; ++i) + at(i, j) /= val; } } + } - void normalizeCols(bool exact =false) - { - for (Int j = 0; j != ncols; ++j) { - - Float val = 0; - bool oneMore = false; + template + void rowProd(InIter x, OutIter y) const { + for (Int i = 0; i < nrows; ++i) { + Float val = 1.0; + for (Int j = 0; j < ncols; ++j) + if (at(i, j) > 0) + val *= x[j]; + y[i] = val; + } + } - for (Int i = 0; i != nrows; ++i) - val += at(i,j); + void threshold(Int row, Float thres) { + for (Int j = 0; j < ncols; ++j) + if (::fabs(at(row, j)) <= thres) + at(row, j) = 0; + } - if (!nupic::nearlyZero(val)) - for (Int i = 0; i != nrows; ++i) { - at(i,j) /= val; - if (nupic::nearlyZero(at(i,j))) { - at(i,j) = 0; - oneMore = true; - } - } + void threshold(Float thres) { + for (Int i = 0; i < nrows; ++i) + threshold(i, thres); + } - if (oneMore && exact) { + void lerp(Float a, Float b, const Dense &B) { + ITER_2(nrows, ncols) + at(i, j) = a * at(i, j) + b * B.at(i, j); - val = 0; + threshold(nupic::Epsilon); + } - for (Int i = 0; i != nrows; ++i) - val += at(i,j); + template + void apply(InIter x, binary_functor f) { + ITER_2(nrows, ncols) + at(i, j) = f(at(i, j), x[j]); + } - if (!nupic::nearlyZero(val)) - for (Int i = 0; i != nrows; ++i) - at(i,j) /= val; - } - } - } + template + void apply(const Dense &B, Dense &C, + binary_functor f) { + ITER_2(nrows, ncols) + C.at(i, j) = f(at(i, j), B.at(i, 
j)); + } - template - void rowProd(InIter x, OutIter y) const - { - for (Int i = 0; i < nrows; ++i) { - Float val = 1.0; - for (Int j = 0; j < ncols; ++j) - if (at(i,j) > 0) - val *= x[j]; - y[i] = val; + template + Float accumulate_nz(const Int row, binary_functor f, const Float &init = 0) { + Float r = init; + for (Int j = 0; j < ncols; ++j) + if (!nupic::nearlyZero(at(row, j))) + r = f(r, at(row, j)); + return r; + } + + template + Float accumulate(const Int row, binary_functor f, const Float &init = 0) { + Float r = init; + for (Int j = 0; j != ncols; ++j) + r = f(r, at(row, j)); + return r; + } + + void multiply(const Dense &B, Dense &C) { + ITER_2(C.nrows, C.ncols) { + C.at(i, j) = 0; + for (Int k = 0; k < ncols; k++) { + C.at(i, j) += at(i, k) * B.at(k, j); } } + } - void threshold(Int row, Float thres) - { - for (Int j = 0; j < ncols; ++j) - if (::fabs(at(row,j)) <= thres) - at(row,j) = 0; - } + inline void setZero(Int i, Int j) { at(i, j) = 0; } - void threshold(Float thres) - { - for (Int i = 0; i < nrows; ++i) - threshold(i, thres); - } + inline void setNonZero(Int i, Int j, const Float &val) { at(i, j) = val; } - void lerp(Float a, Float b, const Dense& B) - { - ITER_2(nrows, ncols) - at(i,j) = a * at(i,j) + b * B.at(i,j); - - threshold(nupic::Epsilon); - } - - template - void apply(InIter x, binary_functor f) - { - ITER_2(nrows, ncols) - at(i,j) = f(at(i,j), x[j]); - } + inline void set(Int i, Int j, const Float &val) { at(i, j) = val; } - template - void apply(const Dense& B, Dense& C, binary_functor f) - { - ITER_2(nrows, ncols) - C.at(i,j) = f(at(i,j), B.at(i,j)); - } - - template - Float accumulate_nz(const Int row, binary_functor f, const Float& init =0) - { - Float r = init; - for (Int j = 0; j < ncols; ++j) - if (!nupic::nearlyZero(at(row,j))) - r = f(r, at(row,j)); - return r; - } + template inline void add(Int row, InIter x) { + for (Int j = 0; j < ncols; ++j) + at(row, j) += *x++; + } - template - Float accumulate(const Int row, binary_functor f, const Float& init =0) - { - Float r = init; - for (Int j = 0; j != ncols; ++j) - r = f(r, at(row,j)); - return r; - } - - void multiply(const Dense& B, Dense& C) - { - ITER_2(C.nrows, C.ncols) { - C.at(i,j) = 0; - for(Int k=0; k + inline void vecMaxAtNZ(InIter x, OutIter y) const { + Float currentMax; - inline void setNonZero(Int i, Int j, const Float& val) - { - at(i,j) = val; + for (Int i = 0; i < nrows; i++) { + currentMax = 0; + for (Int j = 0; j < ncols; j++) { + if (at(i, j) > currentMax) + currentMax = at(i, j); + } + *y = currentMax; } + } - inline void set(Int i, Int j, const Float& val) - { - at(i,j) = val; - } - - template - inline void add(Int row, InIter x) - { - for (Int j = 0; j < ncols; ++j) - at(row, j) += *x++; - } - - inline void add(const Dense& B) - { - for(Int i=0; i - inline void vecMaxAtNZ(InIter x, OutIter y) const - { - Float currentMax; - - for(Int i=0; i currentMax) - currentMax = at(i,j); - } - *y = currentMax; + template + inline void rowProd(InIter x, OutIter y, Float lb) const { + Float curProduct; + InIter x_begin = x; + + for (Int i = 0; i < nrows; i++) { + curProduct = 1; + x = x_begin; + for (Int j = 0; j < ncols; j++) { + if (at(i, j) != 0) + curProduct *= *x; + x++; } - } - - template - inline void rowProd(InIter x, OutIter y, Float lb) const - { - Float curProduct; - InIter x_begin=x; - - for(Int i=0; i - inline void getRowToDense(Int r, OutIter dense) const - { - for (Int j = 0; j != ncols; ++j) - *dense++ = at(r,j); - } - - template - inline void getColToDense(Int c, OutIter dense) 
const - { - for (Int i = 0; i != nrows; ++i) - *dense++ = at(i,c); - } + template + inline void getRowToDense(Int r, OutIter dense) const { + for (Int j = 0; j != ncols; ++j) + *dense++ = at(r, j); + } + + template + inline void getColToDense(Int c, OutIter dense) const { + for (Int i = 0; i != nrows; ++i) + *dense++ = at(i, c); + } - template - inline void getRowToSparse(Int r, OutIter1 indIt, OutIter2 nzIt) const - { - for(Int j=0; j + inline void getRowToSparse(Int r, OutIter1 indIt, OutIter2 nzIt) const { + for (Int j = 0; j < ncols; j++) { + if (at(r, j) != 0) { + *indIt++ = j; + *nzIt++ = at(r, j); } } + } - template - inline void getColToSparse(Int c, OutIter1 indIt, OutIter2 nzIt) const - { - for(Int j=0; j + inline void getColToSparse(Int c, OutIter1 indIt, OutIter2 nzIt) const { + for (Int j = 0; j < nrows; j++) { + if (at(j, c) != 0) { + *indIt++ = j; + *nzIt++ = at(j, c); } } - - template - inline Int findRow(const Int nnzr, IndIt ind_it, NzIt nz_it) - { - IndIt ind_it_begin = ind_it, ind_it_end = ind_it+nnzr; - NzIt nz_it_begin = nz_it; - - for (Int i=0; i + inline Int findRow(const Int nnzr, IndIt ind_it, NzIt nz_it) { + IndIt ind_it_begin = ind_it, ind_it_end = ind_it + nnzr; + NzIt nz_it_begin = nz_it; + + for (Int i = 0; i < nrows; ++i) { + ind_it = ind_it_begin; + nz_it = nz_it_begin; + while (ind_it != ind_it_end) { + if (at(i, *ind_it) != *nz_it) + break; + ++ind_it; + ++nz_it; } - - return nrows; + if (ind_it == ind_it_end) + return i; } - inline void max(Int& max_i, Int& max_j, Float& max_val) const - { - max_i = 0, max_j = 0; - max_val = - std::numeric_limits::max(); + return nrows; + } - ITER_2(nrows, ncols) - if (!nupic::nearlyZero(at(i,j)) && at(i,j) > max_val) { - max_val = at(i,j); - max_i = i; - max_j = j; - } - - if (max_val == -std::numeric_limits::max()) - max_val = 0; + inline void max(Int &max_i, Int &max_j, Float &max_val) const { + max_i = 0, max_j = 0; + max_val = -std::numeric_limits::max(); + + ITER_2(nrows, ncols) + if (!nupic::nearlyZero(at(i, j)) && at(i, j) > max_val) { + max_val = at(i, j); + max_i = i; + max_j = j; } - inline void min(Int& min_i, Int& min_j, Float& min_val) const - { - min_i = 0, min_j = 0; - min_val = std::numeric_limits::max(); + if (max_val == -std::numeric_limits::max()) + max_val = 0; + } - ITER_2(nrows, ncols) - if (!nupic::nearlyZero(at(i,j)) && at(i,j) < min_val) { - min_val = at(i,j); - min_i = i; - min_j = j; - } + inline void min(Int &min_i, Int &min_j, Float &min_val) const { + min_i = 0, min_j = 0; + min_val = std::numeric_limits::max(); - if (min_val == std::numeric_limits::max()) - min_val = 0; + ITER_2(nrows, ncols) + if (!nupic::nearlyZero(at(i, j)) && at(i, j) < min_val) { + min_val = at(i, j); + min_i = i; + min_j = j; } - template - inline void rowMax(Maxima maxima) const - { - for (Int i = 0; i != nrows; ++i) { - maxima[i].first = 0; - maxima[i].second = - std::numeric_limits::max(); - for (Int j = 0; j != ncols; ++j) { - if (!nupic::nearlyZero(at(i,j)) && at(i,j) > maxima[i].second) { - maxima[i].first = j; - maxima[i].second = at(i,j); - } - } - } - } + if (min_val == std::numeric_limits::max()) + min_val = 0; + } - template - inline void rowMin(Minima minima) const - { - for (Int i = 0; i != nrows; ++i) { - minima[i].first = 0; - minima[i].second = std::numeric_limits::max(); - for (Int j = 0; j != ncols; ++j) { - if (!nupic::nearlyZero(at(i,j)) && at(i,j) < minima[i].second) { - minima[i].first = j; - minima[i].second = at(i,j); - } - } + template inline void rowMax(Maxima maxima) const { + for (Int i = 
0; i != nrows; ++i) { + maxima[i].first = 0; + maxima[i].second = -std::numeric_limits::max(); + for (Int j = 0; j != ncols; ++j) { + if (!nupic::nearlyZero(at(i, j)) && at(i, j) > maxima[i].second) { + maxima[i].first = j; + maxima[i].second = at(i, j); + } } } + } - template - inline void colMax(Maxima maxima) const - { + template inline void rowMin(Minima minima) const { + for (Int i = 0; i != nrows; ++i) { + minima[i].first = 0; + minima[i].second = std::numeric_limits::max(); for (Int j = 0; j != ncols; ++j) { - maxima[j].first = 0; - maxima[j].second = - std::numeric_limits::max(); - for (Int i = 0; i != nrows; ++i) { - if (!nupic::nearlyZero(at(i,j)) && at(i,j) > maxima[j].second) { - maxima[j].first = i; - maxima[j].second = at(i,j); - } - } - if (maxima[j].second == - std::numeric_limits::max()) - maxima[j].second = 0; + if (!nupic::nearlyZero(at(i, j)) && at(i, j) < minima[i].second) { + minima[i].first = j; + minima[i].second = at(i, j); + } } } + } - template - inline void colMin(Minima minima) const - { - for (Int j = 0; j != ncols; ++j) { - minima[j].first = 0; - minima[j].second = std::numeric_limits::max(); - for (Int i = 0; i != nrows; ++i) { - if (!nupic::nearlyZero(at(i,j)) && at(i,j) < minima[j].second) { - minima[j].first = i; - minima[j].second = at(i,j); - } - } - if (minima[j].second == std::numeric_limits::max()) - minima[j].second = 0; + template inline void colMax(Maxima maxima) const { + for (Int j = 0; j != ncols; ++j) { + maxima[j].first = 0; + maxima[j].second = -std::numeric_limits::max(); + for (Int i = 0; i != nrows; ++i) { + if (!nupic::nearlyZero(at(i, j)) && at(i, j) > maxima[j].second) { + maxima[j].first = i; + maxima[j].second = at(i, j); + } } + if (maxima[j].second == -std::numeric_limits::max()) + maxima[j].second = 0; } - }; + } - //-------------------------------------------------------------------------------- - template - std::ostream& operator<<(std::ostream& out, const Dense& d) - { - for (Int i = 0; i < d.nrows; ++i) { - for (Int j = 0; j < d.ncols; ++j) - out << d.at(i,j) << " "; - out << std::endl; + template inline void colMin(Minima minima) const { + for (Int j = 0; j != ncols; ++j) { + minima[j].first = 0; + minima[j].second = std::numeric_limits::max(); + for (Int i = 0; i != nrows; ++i) { + if (!nupic::nearlyZero(at(i, j)) && at(i, j) < minima[j].second) { + minima[j].first = i; + minima[j].second = at(i, j); + } + } + if (minima[j].second == std::numeric_limits::max()) + minima[j].second = 0; } - - return out; + } +}; + +//-------------------------------------------------------------------------------- +template +std::ostream &operator<<(std::ostream &out, const Dense &d) { + for (Int i = 0; i < d.nrows; ++i) { + for (Int j = 0; j < d.ncols; ++j) + out << d.at(i, j) << " "; + out << std::endl; } - //-------------------------------------------------------------------------------- + return out; +} + +//-------------------------------------------------------------------------------- } // end namespace nupic #endif // NTA_DENSE_MATRIX_HPP diff --git a/src/nupic/math/Domain.hpp b/src/nupic/math/Domain.hpp index b1388e2ca3..d1a9259703 100644 --- a/src/nupic/math/Domain.hpp +++ b/src/nupic/math/Domain.hpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definition and implementation for Domain class */ @@ -35,389 +35,330 @@ namespace nupic { - /** - * A class that models a range along a given dimension. 
- * dim is the dimension number, lb is the lower bound, - * and ub the upper bound. The range includes its lower bound - * but does not contain its upper bound: [lb..ub), open - * on the right. - */ - template - class DimRange - { - public: - inline DimRange() - : dim_(0), lb_(0), ub_(0) - {} - - inline DimRange(const UInt dim, const UInt& lb, const UInt& ub) - : dim_(dim), lb_(lb), ub_(ub) - { - NTA_ASSERT(lb >= 0); - NTA_ASSERT(lb <= ub) - << "DimRange::DimRange(dim, lb, ub): " - << "Lower bound (" << lb << ") should be <= upper bound " - << "(" << ub << ") for dim: " << dim; - } - - inline DimRange(const DimRange& o) - : dim_(o.dim_), lb_(o.lb_), ub_(o.ub_) - {} +/** + * A class that models a range along a given dimension. + * dim is the dimension number, lb is the lower bound, + * and ub the upper bound. The range includes its lower bound + * but does not contain its upper bound: [lb..ub), open + * on the right. + */ +template class DimRange { +public: + inline DimRange() : dim_(0), lb_(0), ub_(0) {} + + inline DimRange(const UInt dim, const UInt &lb, const UInt &ub) + : dim_(dim), lb_(lb), ub_(ub) { + NTA_ASSERT(lb >= 0); + NTA_ASSERT(lb <= ub) << "DimRange::DimRange(dim, lb, ub): " + << "Lower bound (" << lb + << ") should be <= upper bound " + << "(" << ub << ") for dim: " << dim; + } - inline DimRange& operator=(const DimRange& o) - { - if (&o != this) { - dim_ = o.dim_; - lb_ = o.lb_; - ub_ = o.ub_; - } - return *this; - } - - inline const UInt getDim() const { return dim_; } - inline const UInt getLB() const { return lb_; } - inline const UInt getUB() const { return ub_; } - inline const UInt size() const { return ub_ - lb_; } - inline bool empty() const { return lb_ == ub_; } - - inline bool includes(const UInt& i) const - { - bool ok = false; - if (lb_ == ub_) { - if (i == lb_) - ok = true; - } - else if (lb_ <= i && i < ub_) - ok = true; - return ok; - } + inline DimRange(const DimRange &o) : dim_(o.dim_), lb_(o.lb_), ub_(o.ub_) {} - inline void set(const UInt dim, const UInt& lb, const UInt& ub) - { - NTA_ASSERT(lb <= ub) - << "DimRange::set(dim, lb, ub): " - << "Lower bound (" << lb << ") should be <= upper bound " - << "(" << ub << ") for dim: " << dim; - - dim_ = dim; - lb_ = lb; - ub_ = ub; + inline DimRange &operator=(const DimRange &o) { + if (&o != this) { + dim_ = o.dim_; + lb_ = o.lb_; + ub_ = o.ub_; } + return *this; + } - template - NTA_HIDDEN friend std::ostream& operator<<(std::ostream&, const DimRange&); + inline const UInt getDim() const { return dim_; } + inline const UInt getLB() const { return lb_; } + inline const UInt getUB() const { return ub_; } + inline const UInt size() const { return ub_ - lb_; } + inline bool empty() const { return lb_ == ub_; } - template - NTA_HIDDEN friend bool operator==(const DimRange& r1, const DimRange& r2); + inline bool includes(const UInt &i) const { + bool ok = false; + if (lb_ == ub_) { + if (i == lb_) + ok = true; + } else if (lb_ <= i && i < ub_) + ok = true; + return ok; + } - template - NTA_HIDDEN friend bool operator!=(const DimRange& r1, const DimRange& r2); + inline void set(const UInt dim, const UInt &lb, const UInt &ub) { + NTA_ASSERT(lb <= ub) << "DimRange::set(dim, lb, ub): " + << "Lower bound (" << lb + << ") should be <= upper bound " + << "(" << ub << ") for dim: " << dim; - private: - UInt dim_; - UInt lb_, ub_; - }; + dim_ = dim; + lb_ = lb; + ub_ = ub; + } - //-------------------------------------------------------------------------------- template - inline std::ostream& operator<<(std::ostream& 
outStream, const DimRange& r) - { - return outStream << "[" << r.dim_ << ": " << r.lb_ << ".." << r.ub_ << ")"; - } - + NTA_HIDDEN friend std::ostream &operator<<(std::ostream &, + const DimRange &); + template - inline bool operator==(const DimRange& r1, const DimRange& r2) - { - return r1.dim_ == r2.dim_ && r1.lb_ == r2.lb_ && r1.ub_ == r2.ub_; - } + NTA_HIDDEN friend bool operator==(const DimRange &r1, + const DimRange &r2); template - inline bool operator!=(const DimRange& r1, const DimRange& r2) + NTA_HIDDEN friend bool operator!=(const DimRange &r1, + const DimRange &r2); + +private: + UInt dim_; + UInt lb_, ub_; +}; + +//-------------------------------------------------------------------------------- +template +inline std::ostream &operator<<(std::ostream &outStream, + const DimRange &r) { + return outStream << "[" << r.dim_ << ": " << r.lb_ << ".." << r.ub_ << ")"; +} + +template +inline bool operator==(const DimRange &r1, const DimRange &r2) { + return r1.dim_ == r2.dim_ && r1.lb_ == r2.lb_ && r1.ub_ == r2.ub_; +} + +template +inline bool operator!=(const DimRange &r1, const DimRange &r2) { + return !(r1 == r2); +} + +//-------------------------------------------------------------------------------- +/** + * A class that models the cartesian product of several ranges + * along several dimensions. + */ +template class Domain { +public: + // Doesn't work on shona + /* + explicit inline Domain(UInt NDims, UInt d0, UInt lb0, UInt ub0, ...) { - return ! (r1 == r2); - } + ranges_.push_back(DimRange(d0, lb0, ub0)); + va_list indices; + va_start(indices, ub0); + for (UInt k = 1; k < NDims; ++k) + ranges_.push_back(DimRange(va_arg(indices, UInt), + va_arg(indices, UInt), + va_arg(indices, UInt))); + va_end(indices); - //-------------------------------------------------------------------------------- - /** - * A class that models the cartesian product of several ranges - * along several dimensions. - */ - template - class Domain - { - public: - // Doesn't work on shona - /* - explicit inline Domain(UInt NDims, UInt d0, UInt lb0, UInt ub0, ...) 
- { - ranges_.push_back(DimRange(d0, lb0, ub0)); - va_list indices; - va_start(indices, ub0); - for (UInt k = 1; k < NDims; ++k) - ranges_.push_back(DimRange(va_arg(indices, UInt), - va_arg(indices, UInt), - va_arg(indices, UInt))); - va_end(indices); - - { - for (UInt i = 0; i < NDims-1; ++i) - NTA_ASSERT(ranges_[i].getDim() < ranges_[i+1].getDim()) - << "Domain::Domain(...): " - << "Dimensions need to be in strictly increasing order"; - } - } - */ - - // Half-space constructor - template - explicit inline Domain(const Index& ub) - : ranges_() { - for (UInt k = 0; k < ub.size(); ++k) - ranges_.push_back(DimRange(k, 0, ub[k])); + for (UInt i = 0; i < NDims-1; ++i) + NTA_ASSERT(ranges_[i].getDim() < ranges_[i+1].getDim()) + << "Domain::Domain(...): " + << "Dimensions need to be in strictly increasing order"; } + } + */ - template - explicit inline Domain(const Index& lb, const Index& ub) - : ranges_() - { - { - NTA_ASSERT(lb.size() == ub.size()); - } + // Half-space constructor + template + explicit inline Domain(const Index &ub) : ranges_() { + for (UInt k = 0; k < ub.size(); ++k) + ranges_.push_back(DimRange(k, 0, ub[k])); + } - for (UInt k = 0; k < ub.size(); ++k) - ranges_.push_back(DimRange(k, lb[k], ub[k])); - } + template + explicit inline Domain(const Index &lb, const Index &ub) : ranges_() { + { NTA_ASSERT(lb.size() == ub.size()); } - inline Domain(const Domain& o) - : ranges_(o.ranges_) - {} + for (UInt k = 0; k < ub.size(); ++k) + ranges_.push_back(DimRange(k, lb[k], ub[k])); + } - inline Domain& operator=(const Domain& o) - { - if (&o != this) - ranges_ = o.ranges_; - return *this; - } + inline Domain(const Domain &o) : ranges_(o.ranges_) {} - inline UInt rank() const { return (UInt)ranges_.size(); } - inline bool empty() const { return size_elts() == 0; } + inline Domain &operator=(const Domain &o) { + if (&o != this) + ranges_ = o.ranges_; + return *this; + } - inline UInt size_elts() const - { - UInt n = 1; - for (UInt k = 0; k < rank(); ++k) - n *= ranges_[k].size(); - return n; - } + inline UInt rank() const { return (UInt)ranges_.size(); } + inline bool empty() const { return size_elts() == 0; } - inline DimRange operator[](const UInt& idx) const - { - { - NTA_ASSERT(0 <= idx && idx < rank()); - } + inline UInt size_elts() const { + UInt n = 1; + for (UInt k = 0; k < rank(); ++k) + n *= ranges_[k].size(); + return n; + } - return ranges_[idx]; - } + inline DimRange operator[](const UInt &idx) const { + { NTA_ASSERT(0 <= idx && idx < rank()); } - template - inline void getLB(Index& lb) const - { - { - NTA_ASSERT(lb.size() == rank()); - } + return ranges_[idx]; + } - for (UInt k = 0; k < rank(); ++k) - lb[k] = ranges_[k].getLB(); - } + template inline void getLB(Index &lb) const { + { NTA_ASSERT(lb.size() == rank()); } - template - inline void getUB(Index& ub) const - { - { - NTA_ASSERT(ub.size() == rank()); - } + for (UInt k = 0; k < rank(); ++k) + lb[k] = ranges_[k].getLB(); + } - for (UInt k = 0; k < rank(); ++k) - ub[k] = ranges_[k].getUB(); - } + template inline void getUB(Index &ub) const { + { NTA_ASSERT(ub.size() == rank()); } - template - inline void getIterationLast(Index& last) const - { - { - NTA_ASSERT(last.size() == rank()); - NTA_ASSERT(!hasClosedDims()); - } + for (UInt k = 0; k < rank(); ++k) + ub[k] = ranges_[k].getUB(); + } - for (UInt k = 0; k < rank(); ++k) - last[k] = ranges_[k].getUB() - 1; + template inline void getIterationLast(Index &last) const { + { + NTA_ASSERT(last.size() == rank()); + NTA_ASSERT(!hasClosedDims()); } - template - inline void 
getDims(Index& dims) const - { - { - NTA_ASSERT(dims.size() == rank()); - } + for (UInt k = 0; k < rank(); ++k) + last[k] = ranges_[k].getUB() - 1; + } - for (UInt k = 0; k < rank(); ++k) - dims[k] = ranges_[k].getDim(); - } + template inline void getDims(Index &dims) const { + { NTA_ASSERT(dims.size() == rank()); } - inline UInt getNOpenDims() const - { - UInt k, n; - for (k = 0, n = 0; k < rank(); ++k) - if (! ranges_[k].empty()) - ++n; - return n; - } + for (UInt k = 0; k < rank(); ++k) + dims[k] = ranges_[k].getDim(); + } - template - inline void getOpenDims(Index2& dims) const + inline UInt getNOpenDims() const { + UInt k, n; + for (k = 0, n = 0; k < rank(); ++k) + if (!ranges_[k].empty()) + ++n; + return n; + } + + template inline void getOpenDims(Index2 &dims) const { { - { - NTA_ASSERT(dims.size() == getNOpenDims()) + NTA_ASSERT(dims.size() == getNOpenDims()) << "Domain::getOpenDims(): " << "Wrong number of dimensions, passed: " << dims.size() << " - Should be " << getNOpenDims(); - } - - UInt k, k1; - for (k = 0, k1 = 0; k < rank(); ++k) - if (! ranges_[k].empty()) - dims[k1++] = ranges_[k].getDim(); } - inline bool hasClosedDims() const - { - for (UInt k = 0; k < rank(); ++k) - if (ranges_[k].empty()) - return true; - return false; - } + UInt k, k1; + for (k = 0, k1 = 0; k < rank(); ++k) + if (!ranges_[k].empty()) + dims[k1++] = ranges_[k].getDim(); + } - inline UInt getNClosedDims() const - { - return rank() - getNOpenDims(); - } + inline bool hasClosedDims() const { + for (UInt k = 0; k < rank(); ++k) + if (ranges_[k].empty()) + return true; + return false; + } + + inline UInt getNClosedDims() const { return rank() - getNOpenDims(); } - template - inline void getClosedDims(Index2& dims) const + template inline void getClosedDims(Index2 &dims) const { { - { - NTA_ASSERT(dims.size() == getNClosedDims()) + NTA_ASSERT(dims.size() == getNClosedDims()) << "Domain::getClosedDims(): " << "Wrong number of dimensions, passed: " << dims.size() << " - Should be " << getNClosedDims(); - } - - UInt k, k1; - for (k = 0, k1 = 0; k < rank(); ++k) - if (ranges_[k].empty()) - dims[k1++] = ranges_[k].getDim(); } - template - inline bool includes(const Index& index) const - { - { - NTA_ASSERT(index.size() == rank()); - } - - bool ok = true; - for (UInt k = 0; k < rank() && ok; ++k) - ok = ranges_[k].includes(index[k]); - return ok; - } + UInt k, k1; + for (k = 0, k1 = 0; k < rank(); ++k) + if (ranges_[k].empty()) + dims[k1++] = ranges_[k].getDim(); + } - /** - * Not strict inclusion. - */ - inline bool includes(const Domain& d) const - { - { - NTA_ASSERT(d.rank() == rank()); - } - - for (UInt k = 0; k < rank(); ++k) - if (d.ranges_[k].getLB() < ranges_[k].getLB() - || d.ranges_[k].getUB() > ranges_[k].getUB()) - return false; - - return true; - } + template inline bool includes(const Index &index) const { + { NTA_ASSERT(index.size() == rank()); } - template - NTA_HIDDEN friend std::ostream& operator<<(std::ostream& outStream, const Domain& dom); - - template - NTA_HIDDEN friend bool operator==(const Domain& d1, const Domain& d2); + bool ok = true; + for (UInt k = 0; k < rank() && ok; ++k) + ok = ranges_[k].includes(index[k]); + return ok; + } - template - NTA_HIDDEN friend bool operator!=(const Domain& d1, const Domain& d2); + /** + * Not strict inclusion. 
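+ * Example (a sketch with made-up bounds; Index here is assumed to be
+ * any random-access container such as std::vector<UInt>):
+ *
+ *   std::vector<UInt> lb1 = {0, 0},  ub1 = {10, 10};
+ *   std::vector<UInt> lb2 = {2, 3},  ub2 = {5, 7};
+ *   Domain<UInt> big(lb1, ub1), sub(lb2, ub2);
+ *   bool ok = big.includes(sub);   // true: [2,5)x[3,7) lies inside [0,10)x[0,10)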
+ */ + inline bool includes(const Domain &d) const { + { NTA_ASSERT(d.rank() == rank()); } - protected: - // Could be a compile time dimension, as it used to be, - // and that would be faster - // sorted by construction, and unique dims - std::vector > ranges_; + for (UInt k = 0; k < rank(); ++k) + if (d.ranges_[k].getLB() < ranges_[k].getLB() || + d.ranges_[k].getUB() > ranges_[k].getUB()) + return false; - Domain() {} - }; + return true; + } - //-------------------------------------------------------------------------------- template - inline std::ostream& operator<<(std::ostream& outStream, const Domain& dom) - { - outStream << "["; - for (UInt k = 0; k < dom.rank(); ++k) - outStream << dom[k] << " "; - return outStream << "]" << std::endl; - } + NTA_HIDDEN friend std::ostream &operator<<(std::ostream &outStream, + const Domain &dom); template - inline bool operator==(const Domain& d1, const Domain& d2) - { - if (d1.rank() != d2.rank()) - return false; - for (UInt k = 0; k < d1.rank(); ++k) - if (d1.ranges_[k] != d2.ranges_[k]) - return false; - return true; - } + NTA_HIDDEN friend bool operator==(const Domain &d1, const Domain &d2); template - inline bool operator!=(const Domain& d1, const Domain& d2) - { - if (d1.rank() != d2.rank()) - return true; - for (UInt k = 0; k < d1.rank(); ++k) - if (d1.ranges_[k] != d2.ranges_[k]) - return true; + NTA_HIDDEN friend bool operator!=(const Domain &d1, const Domain &d2); + +protected: + // Could be a compile time dimension, as it used to be, + // and that would be faster + // sorted by construction, and unique dims + std::vector> ranges_; + + Domain() {} +}; + +//-------------------------------------------------------------------------------- +template +inline std::ostream &operator<<(std::ostream &outStream, + const Domain &dom) { + outStream << "["; + for (UInt k = 0; k < dom.rank(); ++k) + outStream << dom[k] << " "; + return outStream << "]" << std::endl; +} + +template +inline bool operator==(const Domain &d1, const Domain &d2) { + if (d1.rank() != d2.rank()) return false; - } + for (UInt k = 0; k < d1.rank(); ++k) + if (d1.ranges_[k] != d2.ranges_[k]) + return false; + return true; +} - //-------------------------------------------------------------------------------- - template - class Domain2D : public Domain - { - public: - inline Domain2D(T first_row, T row_end, T first_col, T col_end) - { - this->ranges_.resize(2); - this->ranges_[0].set(0, first_row, row_end); - this->ranges_[1].set(1, first_col, col_end); - } +template +inline bool operator!=(const Domain &d1, const Domain &d2) { + if (d1.rank() != d2.rank()) + return true; + for (UInt k = 0; k < d1.rank(); ++k) + if (d1.ranges_[k] != d2.ranges_[k]) + return true; + return false; +} + +//-------------------------------------------------------------------------------- +template class Domain2D : public Domain { +public: + inline Domain2D(T first_row, T row_end, T first_col, T col_end) { + this->ranges_.resize(2); + this->ranges_[0].set(0, first_row, row_end); + this->ranges_[1].set(1, first_col, col_end); + } - inline T getFirstRow() const { return this->ranges_[0].getLB(); } - inline T getRowEnd() const { return this->ranges_[0].getUB(); } - inline T getFirstCol() const { return this->ranges_[1].getLB(); } - inline T getColEnd() const { return this->ranges_[1].getUB(); } - }; + inline T getFirstRow() const { return this->ranges_[0].getLB(); } + inline T getRowEnd() const { return this->ranges_[0].getUB(); } + inline T getFirstCol() const { return this->ranges_[1].getLB(); } + inline 
T getColEnd() const { return this->ranges_[1].getUB(); } +}; - //-------------------------------------------------------------------------------- +//-------------------------------------------------------------------------------- } // end namespace nupic #endif // NTA_DOMAIN_HPP diff --git a/src/nupic/math/Erosion.hpp b/src/nupic/math/Erosion.hpp index 1d953968d3..3aa82a3ad4 100644 --- a/src/nupic/math/Erosion.hpp +++ b/src/nupic/math/Erosion.hpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Erosion/dilation */ @@ -37,15 +37,13 @@ using namespace std; /** * Erode or dilate an image. */ -template -struct Erosion -{ +template struct Erosion { typedef size_t size_type; typedef T value_type; size_type nrows_; size_type ncols_; - T* buffer_; + T *buffer_; inline void init(size_type nrows, size_type ncols) /* @@ -55,16 +53,13 @@ struct Erosion { nrows_ = nrows; ncols_ = ncols; - buffer_ = new T[nrows*ncols]; + buffer_ = new T[nrows * ncols]; } - inline Erosion() : buffer_(NULL) - { - } + inline Erosion() : buffer_(NULL) {} - inline ~Erosion() - { - delete [] buffer_; + inline ~Erosion() { + delete[] buffer_; buffer_ = NULL; } @@ -72,81 +67,77 @@ struct Erosion * Erodes (or dilates) the image by convolving with a 3x3 min (or max) filter. * Number of iterations is the radius of the erosion/dilation. * Does the convolution separably. - */ - inline void compute(T* data, T* eroded, size_type iterations, - bool dilate=false) - { - for (size_type iter = 0; iter != iterations; ++iter) { - T* in; - if (!iter) { - in = data; // First pass - read from the input buffer - } else { - in = eroded; // Subsequent pass - read from the output buffer - } - // Rows (ignoring the first and last column) - for (size_type i = 0; i != nrows_; ++i) { - T* b = buffer_ + i*ncols_ + 1; // Write to b in the buffer - T* d = in + i*ncols_; // Start reading from one to the left of b - while (b != (buffer_ + (i + 1)*ncols_ - 1)) { - if (dilate) { - *b = max(max(*d, *(d+1)), *(d+2)); - } else { - *b = min(min(*d, *(d+1)), *(d+2)); - } - b++; - d++; - } - } - if (dilate) { - // Need to fill the first and last column, which were ignored - for (size_type row = 0; row < nrows_; row++) { - buffer_[row * ncols_] = - max(in[row * ncols_], in[row * ncols_ + 1]); - buffer_[(row+1) * ncols_ - 1] = - max(in[(row+1) * ncols_ - 2], in[(row+1) * ncols_ - 1]); - } - } else { - // Zero out the first and last column (they are always eroded away) - for (size_type row = 0; row < nrows_; row++) { - buffer_[row * ncols_] = 0; - buffer_[(row + 1) * ncols_ - 1] = 0; - } - } - - // Columns (ignoring the first and last row) - for (size_type i = 0; i != ncols_; ++i) { - T* b = eroded + i + ncols_; // Write to b in the output - T* d = buffer_ + i; // Start reading from one above b - while (b != (eroded + i + ncols_*(nrows_ - 1))) { - if (dilate) { - *b = max(max(*d, *(d + ncols_)), *(d + ncols_*2)); - } else { - *b = min(min(*d, *(d + ncols_)), *(d + ncols_*2)); - } - b += ncols_; - d += ncols_; - } - } - if (dilate) { - // Need to fill the first and last row, which were ignored - for (size_type col = 0; col < ncols_; col++) { - eroded[col] = max(buffer_[col], buffer_[col + ncols_]); - eroded[(nrows_ - 1) * ncols_ + col] = - max(buffer_[(nrows_ - 1) * ncols_ + col], - buffer_[(nrows_ - 2) * ncols_ + col]); - } - } else { - // Zero out the first and last row (they are always eroded away) - for (size_type col = 0; col < ncols_; col++) { - eroded[col] = 0; - eroded[(nrows_ - 1) * ncols_ + 
col] = 0; - } - } - } - } + */ + inline void compute(T *data, T *eroded, size_type iterations, + bool dilate = false) { + for (size_type iter = 0; iter != iterations; ++iter) { + T *in; + if (!iter) { + in = data; // First pass - read from the input buffer + } else { + in = eroded; // Subsequent pass - read from the output buffer + } + // Rows (ignoring the first and last column) + for (size_type i = 0; i != nrows_; ++i) { + T *b = buffer_ + i * ncols_ + 1; // Write to b in the buffer + T *d = in + i * ncols_; // Start reading from one to the left of b + while (b != (buffer_ + (i + 1) * ncols_ - 1)) { + if (dilate) { + *b = max(max(*d, *(d + 1)), *(d + 2)); + } else { + *b = min(min(*d, *(d + 1)), *(d + 2)); + } + b++; + d++; + } + } + if (dilate) { + // Need to fill the first and last column, which were ignored + for (size_type row = 0; row < nrows_; row++) { + buffer_[row * ncols_] = max(in[row * ncols_], in[row * ncols_ + 1]); + buffer_[(row + 1) * ncols_ - 1] = + max(in[(row + 1) * ncols_ - 2], in[(row + 1) * ncols_ - 1]); + } + } else { + // Zero out the first and last column (they are always eroded away) + for (size_type row = 0; row < nrows_; row++) { + buffer_[row * ncols_] = 0; + buffer_[(row + 1) * ncols_ - 1] = 0; + } + } + // Columns (ignoring the first and last row) + for (size_type i = 0; i != ncols_; ++i) { + T *b = eroded + i + ncols_; // Write to b in the output + T *d = buffer_ + i; // Start reading from one above b + while (b != (eroded + i + ncols_ * (nrows_ - 1))) { + if (dilate) { + *b = max(max(*d, *(d + ncols_)), *(d + ncols_ * 2)); + } else { + *b = min(min(*d, *(d + ncols_)), *(d + ncols_ * 2)); + } + b += ncols_; + d += ncols_; + } + } + if (dilate) { + // Need to fill the first and last row, which were ignored + for (size_type col = 0; col < ncols_; col++) { + eroded[col] = max(buffer_[col], buffer_[col + ncols_]); + eroded[(nrows_ - 1) * ncols_ + col] = + max(buffer_[(nrows_ - 1) * ncols_ + col], + buffer_[(nrows_ - 2) * ncols_ + col]); + } + } else { + // Zero out the first and last row (they are always eroded away) + for (size_type col = 0; col < ncols_; col++) { + eroded[col] = 0; + eroded[(nrows_ - 1) * ncols_ + col] = 0; + } + } + } + } }; //-------------------------------------------------------------------------------- -#endif //NTA_EROSION_HPP - +#endif // NTA_EROSION_HPP diff --git a/src/nupic/math/Functions.hpp b/src/nupic/math/Functions.hpp index 749f5a41a9..e8656c7874 100644 --- a/src/nupic/math/Functions.hpp +++ b/src/nupic/math/Functions.hpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Declarations for math functions */ @@ -29,97 +29,79 @@ #include // For NTA_ASSERT -#include -#include -#include #include +#include #include +#include +#include namespace nupic { - // TODO: replace other functions by boost/math +// TODO: replace other functions by boost/math - static const double pi = 3.14159265358979311600e+00; +static const double pi = 3.14159265358979311600e+00; - //-------------------------------------------------------------------------------- - template - inline T lgamma(T x) - { - return boost::math::lgamma(x); - } +//-------------------------------------------------------------------------------- +template inline T lgamma(T x) { return boost::math::lgamma(x); } - //-------------------------------------------------------------------------------- - template - inline T digamma(T x) - { - return boost::math::digamma(x); - } 
+//-------------------------------------------------------------------------------- +template inline T digamma(T x) { return boost::math::digamma(x); } - //-------------------------------------------------------------------------------- - template - inline T beta(T x, T y) - { - return boost::math::beta(x, y); - } +//-------------------------------------------------------------------------------- +template inline T beta(T x, T y) { + return boost::math::beta(x, y); +} - //-------------------------------------------------------------------------------- - template - inline T erf(T x) - { - return boost::math::erf(x); - } +//-------------------------------------------------------------------------------- +template inline T erf(T x) { return boost::math::erf(x); } - //-------------------------------------------------------------------------------- - double fact(unsigned long n) - { - static double a[171]; - static bool init = true; - - if (init) { - init = false; - a[0] = 1.0; - for (size_t i = 1; i != 171; ++i) - a[i] = i * a[i-1]; - } - - if (n < 171) - return a[n]; - else - return exp(lgamma(n+1.0)); +//-------------------------------------------------------------------------------- +double fact(unsigned long n) { + static double a[171]; + static bool init = true; + + if (init) { + init = false; + a[0] = 1.0; + for (size_t i = 1; i != 171; ++i) + a[i] = i * a[i - 1]; } - //-------------------------------------------------------------------------------- - double lfact(unsigned long n) - { - static double a[2000]; - static bool init = true; - - if (init) { - for (size_t i = 0; i != 2000; ++i) - a[i] = lgamma(i+1.0); - } - - if (n < 2000) - return a[n]; - else - return lgamma(n+1.0); + if (n < 171) + return a[n]; + else + return exp(lgamma(n + 1.0)); +} + +//-------------------------------------------------------------------------------- +double lfact(unsigned long n) { + static double a[2000]; + static bool init = true; + + if (init) { + for (size_t i = 0; i != 2000; ++i) + a[i] = lgamma(i + 1.0); } - //-------------------------------------------------------------------------------- - double binomial(unsigned long n, unsigned long k) + if (n < 2000) + return a[n]; + else + return lgamma(n + 1.0); +} + +//-------------------------------------------------------------------------------- +double binomial(unsigned long n, unsigned long k) { { - { - NTA_ASSERT(k <= n) - << "binomial: Wrong arguments: n= " << n << " k= " << k; - } - - if (n < 171) - return floor(0.5 + fact(n) / (fact(k) * fact(n-k))); - else - return floor(0.5 + exp(lfact(n) - lfact(k) - lfact(n-k))); + NTA_ASSERT(k <= n) << "binomial: Wrong arguments: n= " << n << " k= " << k; } - //-------------------------------------------------------------------------------- -}; + if (n < 171) + return floor(0.5 + fact(n) / (fact(k) * fact(n - k))); + else + return floor(0.5 + exp(lfact(n) - lfact(k) - lfact(n - k))); +} + +//-------------------------------------------------------------------------------- +}; // namespace nupic -#endif //NTA_MATH_FUNCTIONS_HPP +#endif // NTA_MATH_FUNCTIONS_HPP diff --git a/src/nupic/math/GraphAlgorithms.hpp b/src/nupic/math/GraphAlgorithms.hpp index e124436d51..80c1461566 100644 --- a/src/nupic/math/GraphAlgorithms.hpp +++ b/src/nupic/math/GraphAlgorithms.hpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definition and implementation of graph algorithms */ @@ -31,307 +31,305 @@ * Graph utilities. Not currently used in production code. 
*/ -#include #include -#include -#include #include #include +#include +#include +#include namespace nupic { - //-------------------------------------------------------------------------------- - // GRAPH ALGORITHMS - //-------------------------------------------------------------------------------- - - typedef std::vector Sequence; - typedef std::list Sequences; - - //-------------------------------------------------------------------------------- - /** - * Enumerates all the sequences in this matrix by following edges that - * have a value greater than threshold th. - * - * row_sums = tam.rowSums(); col_sums = tam.colSums() - * ratios = [c/r if r != 0 else 0 for r,c in zip(row_sums,col_sums)] - * front = [[i] for i in range(tam.nRows()) if ratios[i] >= threshold] - * seqs = [] - * - * while front: - * subseq = front[0]; l = len(subseq); p = subseq[-1] - * cands = tam.colNonZeros(p)[0] - * next = [n for n in cands if tam[n,p] > min_count and n not in subseq] - * if next: - * for n in next: - * front.insert(1, subseq+[n]) - * elif l > 1 and is_sublist_of(subseq, seqs) == -1: - * seqs.append(subseq) - * front.pop(0) - */ - template - inline void - EnumerateSequences(typename SM::value_type th, const SM& g, Sequences& sequences, - int rowsOrCols=0, int noSubsequences=0) - { - using namespace std; - - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - - const size_type N = rowsOrCols == 0 ? g.nCols() : g.nRows(); - Sequences front; - vector ind(N); - vector nz(N); - - for (size_type i = 0; i != N; ++i) { - Sequence s; s.push_back(i); front.push_back(s); - } +//-------------------------------------------------------------------------------- +// GRAPH ALGORITHMS +//-------------------------------------------------------------------------------- - while (!front.empty()) { - Sequence ss = front.front(); - bool more = false; - size_type l = ss.size(), p = ss[l-1]; - size_type n = 0; - if (rowsOrCols == 1) - n = g.getColToSparse(p, ind.begin(), nz.begin()); - else - n = g.getRowToSparse(p, ind.begin(), nz.begin()); - for (size_type i = 0; i != n; ++i) { - if (nz[i] > th && !contains(ss, ind[i])) { - more = true; - Sequence new_seq(ss); - new_seq.push_back(ind[i]); - front.insert(++front.begin(), new_seq); - } - } - if (!more && l > 1) { - if (noSubsequences && !is_subsequence_of(sequences, ss)) - sequences.push_back(ss); - else - sequences.push_back(ss); +typedef std::vector Sequence; +typedef std::list Sequences; + +//-------------------------------------------------------------------------------- +/** + * Enumerates all the sequences in this matrix by following edges that + * have a value greater than threshold th. 
+ * + * row_sums = tam.rowSums(); col_sums = tam.colSums() + * ratios = [c/r if r != 0 else 0 for r,c in zip(row_sums,col_sums)] + * front = [[i] for i in range(tam.nRows()) if ratios[i] >= threshold] + * seqs = [] + * + * while front: + * subseq = front[0]; l = len(subseq); p = subseq[-1] + * cands = tam.colNonZeros(p)[0] + * next = [n for n in cands if tam[n,p] > min_count and n not in subseq] + * if next: + * for n in next: + * front.insert(1, subseq+[n]) + * elif l > 1 and is_sublist_of(subseq, seqs) == -1: + * seqs.append(subseq) + * front.pop(0) + */ +template +inline void EnumerateSequences(typename SM::value_type th, const SM &g, + Sequences &sequences, int rowsOrCols = 0, + int noSubsequences = 0) { + using namespace std; + + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + + const size_type N = rowsOrCols == 0 ? g.nCols() : g.nRows(); + Sequences front; + vector ind(N); + vector nz(N); + + for (size_type i = 0; i != N; ++i) { + Sequence s; + s.push_back(i); + front.push_back(s); + } + + while (!front.empty()) { + Sequence ss = front.front(); + bool more = false; + size_type l = ss.size(), p = ss[l - 1]; + size_type n = 0; + if (rowsOrCols == 1) + n = g.getColToSparse(p, ind.begin(), nz.begin()); + else + n = g.getRowToSparse(p, ind.begin(), nz.begin()); + for (size_type i = 0; i != n; ++i) { + if (nz[i] > th && !contains(ss, ind[i])) { + more = true; + Sequence new_seq(ss); + new_seq.push_back(ind[i]); + front.insert(++front.begin(), new_seq); } - front.pop_front(); } + if (!more && l > 1) { + if (noSubsequences && !is_subsequence_of(sequences, ss)) + sequences.push_back(ss); + else + sequences.push_back(ss); + } + front.pop_front(); } +} + +//-------------------------------------------------------------------------------- +/** + * Finds connected components using a threshold. + * The returned components are not sorted. + * + * groups = [] + * cands = set(range(tam.nRows())) + * + * while cands: + * front = [cands.pop()] + * more = True + * while more: + * new_front = [] + * for x in front: + * cnz = tam.colNonZeros(x); rnz = tam.rowNonZeros(x) + * for n in [i for i,v in zip(cnz[0],cnz[1]) if v > th and i in cands]: + * new_front += [n]; cands.remove(n) + * for n in [i for i,v in zip(rnz[0],rnz[1]) if v > th and i in cands]: + * new_front += [n]; cands.remove(n) + * if len(new_front) > 0: + * front += new_front + * else: + * groups.append(front) + * more = False + * + */ +/* + groups = [] + cands = set(range(tam.nRows())) + ttam = copy.deepcopy(tam) + ttam.transpose() + + while cands: + + front = set([cands.pop()]) + groups.append(list(front)) + g = groups[-1] + + while front: + new_front = set([]) + for x in front: + rnz = tam.rowNonZeros(x); cnz = ttam.rowNonZeros(x) + l = zip(rnz[0],rnz[1]) + zip(cnz[0],cnz[1]) + for n,v in l: + if v > th and n in cands: + new_front.add(n); g += [n]; cands.remove(n) + front = new_front +*/ +template +inline void FindConnectedComponents(typename SM::value_type th, const SM &g, + Sequences &components) { + using namespace std; + using namespace boost; + + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + typedef unordered_set Set; + + const size_type N = g.nRows(); + + vector ind(2 * N); + vector nz(2 * N); - //-------------------------------------------------------------------------------- - /** - * Finds connected components using a threshold. - * The returned components are not sorted. 
- * - * groups = [] - * cands = set(range(tam.nRows())) - * - * while cands: - * front = [cands.pop()] - * more = True - * while more: - * new_front = [] - * for x in front: - * cnz = tam.colNonZeros(x); rnz = tam.rowNonZeros(x) - * for n in [i for i,v in zip(cnz[0],cnz[1]) if v > th and i in cands]: - * new_front += [n]; cands.remove(n) - * for n in [i for i,v in zip(rnz[0],rnz[1]) if v > th and i in cands]: - * new_front += [n]; cands.remove(n) - * if len(new_front) > 0: - * front += new_front - * else: - * groups.append(front) - * more = False - * - */ - /* - groups = [] - cands = set(range(tam.nRows())) - ttam = copy.deepcopy(tam) - ttam.transpose() - - while cands: - - front = set([cands.pop()]) - groups.append(list(front)) - g = groups[-1] - - while front: - new_front = set([]) - for x in front: - rnz = tam.rowNonZeros(x); cnz = ttam.rowNonZeros(x) - l = zip(rnz[0],rnz[1]) + zip(cnz[0],cnz[1]) - for n,v in l: - if v > th and n in cands: - new_front.add(n); g += [n]; cands.remove(n) - front = new_front - */ - template - inline void - FindConnectedComponents(typename SM::value_type th, const SM& g, Sequences& components) - { - using namespace std; - using namespace boost; - - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - typedef unordered_set Set; - - const size_type N = g.nRows(); - - vector ind(2*N); - vector nz(2*N); - - Set cands; - typename Set::iterator x, w; - - for (size_type i = 0; i != N; ++i) - cands.insert(i); - - SM tg; - g.transpose(tg); - - while (!cands.empty()) { - - size_type seed = *cands.begin(); - cands.erase(seed); - - Sequence group; - group.push_back(seed); - - Set front; - front.insert(seed); - - while (!front.empty()) { - - Set new_front; - - for (x = front.begin(); x != front.end(); ++x) { - - size_type n = g.getRowToSparse(*x, ind.begin(), nz.begin()); - n += tg.getRowToSparse(*x, ind.begin()+n, nz.begin()+n); - - for (size_type j = 0; j != n; ++j) { - size_type y = ind[j]; - if (nz[j] > th && (w = cands.find(y)) != cands.end()) { - new_front.insert(y); group.push_back(y); - cands.erase(w); - } + Set cands; + typename Set::iterator x, w; + + for (size_type i = 0; i != N; ++i) + cands.insert(i); + + SM tg; + g.transpose(tg); + + while (!cands.empty()) { + + size_type seed = *cands.begin(); + cands.erase(seed); + + Sequence group; + group.push_back(seed); + + Set front; + front.insert(seed); + + while (!front.empty()) { + + Set new_front; + + for (x = front.begin(); x != front.end(); ++x) { + + size_type n = g.getRowToSparse(*x, ind.begin(), nz.begin()); + n += tg.getRowToSparse(*x, ind.begin() + n, nz.begin() + n); + + for (size_type j = 0; j != n; ++j) { + size_type y = ind[j]; + if (nz[j] > th && (w = cands.find(y)) != cands.end()) { + new_front.insert(y); + group.push_back(y); + cands.erase(w); } } - front.swap(new_front); } - components.push_back(group); + front.swap(new_front); } + components.push_back(group); } - - //-------------------------------------------------------------------------------- - /** - * This for unit testing. - * The returned components are sorted. 
- */ - template - inline void FindConnectedComponents_boost(const SM& sm, Sequences& components) - { - using namespace std; - using namespace boost; - - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - typedef adjacency_list Graph; - - Graph G(sm.nCols()); - size_type n = sm.nNonZeros(); - vector nz_i(n), nz_j(n); - vector nz_v(n); - - sm.getAllNonZeros(nz_i.begin(),nz_j.begin(),nz_v.begin()); - - for (size_type i = 0; i != n; ++i) - add_edge(nz_i[i],nz_j[i], G); - - std::vector component(num_vertices(G)); - int num = connected_components(G, &component[0]); - - vector c(num); - - for (size_type i = 0; i != component.size(); ++i) - c[component[i]].push_back(i); - - for (int i = 0; i != num; ++i) - components.push_back(c[i]); +} + +//-------------------------------------------------------------------------------- +/** + * This for unit testing. + * The returned components are sorted. + */ +template +inline void FindConnectedComponents_boost(const SM &sm, Sequences &components) { + using namespace std; + using namespace boost; + + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + typedef adjacency_list Graph; + + Graph G(sm.nCols()); + size_type n = sm.nNonZeros(); + vector nz_i(n), nz_j(n); + vector nz_v(n); + + sm.getAllNonZeros(nz_i.begin(), nz_j.begin(), nz_v.begin()); + + for (size_type i = 0; i != n; ++i) + add_edge(nz_i[i], nz_j[i], G); + + std::vector component(num_vertices(G)); + int num = connected_components(G, &component[0]); + + vector c(num); + + for (size_type i = 0; i != component.size(); ++i) + c[component[i]].push_back(i); + + for (int i = 0; i != num; ++i) + components.push_back(c[i]); +} + +//-------------------------------------------------------------------------------- +//-------------------------------------------------------------------------------- +template +inline void CuthillMcKeeOrdering(const SM &sm, OutputIterator p, + OutputIterator rp) { + using namespace std; + using namespace boost; + + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + + typedef adjacency_list>> + Graph; + typedef graph_traits::vertex_descriptor Vertex; + + const size_type nrows = sm.nRows(); + + Graph G(nrows); + + for (size_type i = 0; i != nrows; ++i) { + const size_type *ind = sm.row_nz_index_begin(i); + const size_type *ind_end = sm.row_nz_index_end(i); + for (; ind != ind_end; ++ind) + add_edge(i, *ind, G); } - - //-------------------------------------------------------------------------------- - //-------------------------------------------------------------------------------- - template - inline void CuthillMcKeeOrdering(const SM& sm, OutputIterator p, OutputIterator rp) - { - using namespace std; - using namespace boost; - - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - - typedef adjacency_list > > Graph; - typedef graph_traits::vertex_descriptor Vertex; - - const size_type nrows = sm.nRows(); - - Graph G(nrows); - - for (size_type i = 0; i != nrows; ++i) { - const size_type* ind = sm.row_nz_index_begin(i); - const size_type* ind_end = sm.row_nz_index_end(i); - for (; ind != ind_end; ++ind) - add_edge(i, *ind, G); - } - graph_traits::vertex_iterator ui, ui_end; + graph_traits::vertex_iterator ui, ui_end; - property_map::type deg = get(vertex_degree, G); - for (boost::tie(ui, ui_end) = vertices(G); ui != ui_end; ++ui) - deg[*ui] = degree(*ui, G); + property_map::type deg = get(vertex_degree, G); + for (boost::tie(ui, 
ui_end) = vertices(G); ui != ui_end; ++ui) + deg[*ui] = degree(*ui, G); - property_map::type - index_map = get(vertex_index, G); + property_map::type index_map = get(vertex_index, G); - std::cout << "original bandwidth: " << bandwidth(G) << std::endl; + std::cout << "original bandwidth: " << bandwidth(G) << std::endl; - std::vector inv_perm(num_vertices(G)); - std::vector perm(num_vertices(G)); + std::vector inv_perm(num_vertices(G)); + std::vector perm(num_vertices(G)); - size_type best = nrows; + size_type best = nrows; - for (size_type i = 1; i != nrows; ++i) { + for (size_type i = 1; i != nrows; ++i) { - Vertex s = vertex(i, G); + Vertex s = vertex(i, G); - //reverse cuthill_mckee_ordering - cuthill_mckee_ordering(G, s, inv_perm.rbegin(), get(vertex_color, G), - get(vertex_degree, G)); - /* - cout << " "; - for (std::vector::const_iterator i = inv_perm.begin(); - i != inv_perm.end(); ++i) - cout << index_map[*i] << " "; - */ + // reverse cuthill_mckee_ordering + cuthill_mckee_ordering(G, s, inv_perm.rbegin(), get(vertex_color, G), + get(vertex_degree, G)); + /* + cout << " "; + for (std::vector::const_iterator i = inv_perm.begin(); + i != inv_perm.end(); ++i) + cout << index_map[*i] << " "; + */ - for (size_type c = 0; c != inv_perm.size(); ++c) - perm[index_map[inv_perm[c]]] = c; + for (size_type c = 0; c != inv_perm.size(); ++c) + perm[index_map[inv_perm[c]]] = c; - size_type bw = - bandwidth(G, make_iterator_property_map(&perm[0], index_map, perm[0])); + size_type bw = + bandwidth(G, make_iterator_property_map(&perm[0], index_map, perm[0])); - if (bw < best) { - best = bw; - std::cout << "bandwidth: " - << bw - << std::endl; - std::copy(perm.begin(), perm.end(), p); - std::copy(inv_perm.begin(), inv_perm.end(), rp); - } + if (bw < best) { + best = bw; + std::cout << "bandwidth: " << bw << std::endl; + std::copy(perm.begin(), perm.end(), p); + std::copy(inv_perm.begin(), inv_perm.end(), rp); } } +} - //-------------------------------------------------------------------------------- +//-------------------------------------------------------------------------------- } // end namespace nupic -#endif //GRAPH_ALGORITHMS +#endif // GRAPH_ALGORITHMS diff --git a/src/nupic/math/Index.hpp b/src/nupic/math/Index.hpp index bb570325ab..ac2e5b85f6 100644 --- a/src/nupic/math/Index.hpp +++ b/src/nupic/math/Index.hpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definition and implementation for Index class */ @@ -32,682 +32,166 @@ #include #include // for memcpy in gcc 4.4 -#include #include #include +#include //---------------------------------------------------------------------- - namespace nupic { +/** + * @b Responsibility + * Index is a multi-dimensional index, consisting of a series of integers. + * It is a fixed size index, where the size is fixed at compile time. + * The size is the parameter NDims of the template. UInt is the type + * of integers stored in the index. + * + * @b Rationale + * Index is useful when working multi-dimensional SparseTensors. + * + * @b Notes + * NDims > 0 (that is, 0 is not allowed...) + */ +template class Index { +public: + typedef UInt value_type; + typedef UInt *iterator; + typedef const UInt *const_iterator; + /** - * @b Responsibility - * Index is a multi-dimensional index, consisting of a series of integers. - * It is a fixed size index, where the size is fixed at compile time. - * The size is the parameter NDims of the template. UInt is the type - * of integers stored in the index. 
- * - * @b Rationale - * Index is useful when working multi-dimensional SparseTensors. - * - * @b Notes - * NDims > 0 (that is, 0 is not allowed...) + * Default constructor. + * Creates an index initialized to zero. */ - template - class Index - { - public: - typedef UInt value_type; - typedef UInt* iterator; - typedef const UInt* const_iterator; - - /** - * Default constructor. - * Creates an index initialized to zero. - */ - inline Index() - { - memset(i_, 0, NDims*sizeof(value_type)); - } - - /** - * Constructor from an array. - * Creates an index that has the values in the given array. - * - * @param i [Uint[NDims] ] the values to initialize the index with - */ - explicit inline Index(const UInt i[NDims]) - { - memcpy(i_, i, NDims*sizeof(value_type)); - } - - /** - * Constructor from a list. - * Creates an index initialized with the values passed in the list. - * - * @param i0 [Uint...] the list of values to initialize the index with - */ - explicit inline Index(UInt i0, ...) - { - i_[0] = i0; - - va_list indices; - va_start(indices, i0); - for (UInt k = 1; k < NDims; ++k) - { - value_type v = va_arg(indices, value_type); - i_[k] = v; - } - va_end(indices); - } - - /** - * Constructor from bounds and an ordinal. - * This constructor builds the index that corresponds to the - * given ordinal, with the given bounds. - * - * @param bounds [Index] the bounds to use to compute the index - * @param ordinal [UInt] the ordinal that will correspond to - * this index - */ - explicit inline Index(const Index& bounds, const value_type& ordinal) - { - fromOrdinal(bounds, ordinal); - } - - inline Index(const std::vector idx) - { - for (UInt k = 0; k < NDims; ++k) - i_[k] = idx[k]; - } - - /** - * Copy constructor. - * - * @param from [Index] the index to copy - */ - inline Index(const Index& from) - { - if (&from != this) - memcpy(i_, from.i_, NDims*sizeof(value_type)); - } - - /** - * Assignment operator. - * - * @param from [Index] the index to copy - */ - inline Index& operator=(const Index& from) - { - if (&from != this) - memcpy(i_, from.i_, NDims*sizeof(UInt)); - return *this; - } - - inline iterator begin() { return i_; } - inline iterator end() { return begin() + NDims; } - inline const_iterator begin() const { return i_; } - inline const_iterator end() const { return begin() + NDims; } - inline UInt size() const { return NDims; } - - inline UInt max() const - { - UInt M = 0; - for (UInt i = 0; i < NDims; ++i) - if (i_[i] > M) - M = i_[i]; - return M; - } - - /** - * Indexing operator. - * - * @param idx [0 <= UInt < NDims] index - * @retval [UInt&] the value at index 'idx' - */ - inline UInt& operator[](const UInt idx) - { - { - NTA_ASSERT(idx >= 0 && idx < NDims) - << "Index::operator[] " - << "Invalid index: " << idx - << " - Should be in [0.." << NDims << ")"; - } - - return i_[idx]; - } - - /** - * Const indexing operator. - * - * @param idx [0 <= UInt < NDims] index - * @retval [const UInt] the value at index 'idx' - */ - inline UInt operator[](const UInt& idx) const - { - { - NTA_ASSERT(idx >= 0 && idx < NDims) - << "Index::operator[] const " - << "Invalid index: " << idx - << " - Should be in [0.." << NDims << ")"; - } - - return i_[idx]; - } - - /** - * Resets the whole index to all zeros. - */ - inline void setToZero() - { - memset(i_, 0, NDims*sizeof(UInt)); - } - - /** - * Returns whether the values in this index constitute a set or not - * (there are no duplicates). 
- */ - inline bool isSet() const - { - std::set s; - for (UInt i = 0; i < NDims; ++i) - s.insert(i_[i]); - return s.size() == NDims; - } - - /** - * Increments this index, using bounds as the upper bound - * for the iteration. - * - * @param bounds [Index] the upper bound - * @retval bool whether we've reached the end of the iteration or not - */ - inline bool increment(const Index& bounds) - { - int curr = NDims-1; - ++i_[curr]; - while (i_[curr] >= bounds[curr]) { - i_[curr] = 0; - --curr; - if (curr < 0) - return false; - else - ++i_[curr]; - } - return true; - } - - /** - * Increment this index, using lb and ub as lower and upper - * bounds for the iteration. - * - * @param lb [index] the lower bound - * @param ub [Index] the upper bound - * @retval bool whether we've reached the end of the iteration or not - */ - inline bool increment(const Index& lb, const Index& ub) - { - int curr = NDims-1; - ++i_[curr]; - while (i_[curr] >= ub[curr]) { - i_[curr] = lb[curr]; - --curr; - if (curr < 0) - return false; - else - ++i_[curr]; - } - return true; - } - - /** - * Computes the ordinal corresponding to the "natural" - * order for this index. For example, with bounds = [3, 2], - * [0, 0] -> 0, [0, 1] -> 1, [1, 0] -> 2, [1, 1] -> 3, - * [2, 0] -> 4, [2, 1] -> 5. - * - * @param bounds [Index] the upper bound to use to compute - * the ordinal - * @retval UInt the ordinal for this index - */ - inline UInt ordinal(const Index& bounds) const - { - { - NTA_ASSERT(indexGtZero(bounds)); - } - - if (NDims == 1) // do specialization - return i_[0]; - - UInt p = bounds[NDims-1], pos = i_[NDims-1]; - - for (int k = NDims-2; k >= 1; p *= bounds[k], --k) - pos += i_[k] * p; - pos += i_[0] * p; - - return pos; - } - - inline void fromOrdinal(const Index& bounds, const value_type& ordinal) - { - { - NTA_ASSERT(indexGtZero(bounds)); - } - - value_type o = ordinal, p = bounds.product() / bounds[0]; - //TODO optimize / - for (UInt k = 0; k < NDims-1; o %= p, p /= bounds[k+1], ++k) - i_[k] = o / p; - i_[NDims-1] = o; - } - - /** - * Computes the stride for dimension dim of this Index. - * If the Index is :[6, 7, 5, 4], then: - * stride(0) = 7*5*4, - * stride(1) = 5*4, - * stride(2) = 4, - * stride(3) = 1. - * - * @param dim [UInt] the dim for which we want the stride - * @retval UInt the stride - */ - inline UInt stride(const UInt& dim) const - { - if (dim == NDims-1) - return 1; - - UInt s = i_[dim+1]; - for (UInt i = dim+2; i < NDims; ++i) - s*= i_[i]; - return s; - } - - /** - * Computes the distance between two indices, with respect to - * a given upper bound. That is: - * distance = other.ordinal(bounds) - this->ordinal(bounds). - * - * @param bounds [Index] the upper bound - * @param other [Index] the second index - * @retval UInt the distance between this and the second index - */ - inline UInt distance(const Index& bounds, const Index& other) const - { - return other.ordinal(bounds) - ordinal(bounds); - } - - /** - * Computes the product of all the values in this index. - * The result can be zero, if at least one of the indices - * is zero. - * - * @retval UInt the product - */ - inline UInt product() const - { - UInt n = i_[0]; - for (UInt k = 1; k < NDims; ++k) - n *= i_[k]; - return n; - } - - /** - * Computes the complement of this index. 
- * - * For example: - * (*this) [0, 2, 4] -(N=6)-> (idx) [1, 3, 5] - * (*this) [0, 2] -(N=3)-> (idx) [1] - * (*this) [0] -(N=2)-> (idx) [1] - * (*this) [0, 1] -(N=3)-> (idx) [2] - * - * @param idx [Index] the complement of this index - */ - template - inline void complement(Index& idx) const - { - const UInt N = NDims + R; - UInt k = 0, k1 = 0, k2 = 0; - - for (k = 0; k < NDims; ++k) { - for (; k1 < i_[k]; ++k1) - idx[k2++] = k1; - k1 = i_[k]+1; - } - - while (k1 < N) - idx[k2++] = k1++; - } - - /** - * Computes the projection of this index on to the dimensions - * specified: - * idx2[k] = (*this)[dims[k]], for k in [0..R). - * - * @param dims [Index] the dimensions to project onto - * @param idx2 [Index] the projection of this index - */ - template - inline void project(const Index& dims, Index& idx2) const - { - { - NTA_ASSERT(R <= NDims) - << "Index::project(): " - << "Invalid number of dimensions to project on: " << R - << " - Should be less than: " << NDims; - - for (UInt k = 0; k < R-1; ++k) - NTA_ASSERT(dims[k] < dims[k+1]) - << "Index::project(): " - << "Dimensions need to be in strictly increasing order, " - << "passed: " << dims; - - NTA_ASSERT(0 <= dims[0] && dims[R-1] <= NDims) - << "Index::project(): " - << "Invalid dimensions: " << dims - << " when projecting in: [0.." << R << ")"; - } - - for (UInt k = 0; k < R; ++k) - idx2[k] = i_[dims[k]]; - } - - /** - * Embeds the current index into an index of higher dimension: - * idx2[dims[k]] = (*this)[k], for k in [0..R). - * - * @param dims [Index] the dimensions to embed into - * @param idx [Index] the embedding of this index - */ - template - inline void embed(const Index& dims, Index& idx2) const - { - { - NTA_ASSERT(R2 >= NDims) - << "Index::embed(): " - << "Invalid number of dimensions to embed into: " << R2 - << " - Should be >= " << NDims; - - for (UInt k = 0; k < R-1; ++k) - NTA_ASSERT(dims[k] < dims[k+1]) - << "Index::embed(): " - << "Dimensions need to be in strictly increasing order, " - << "passed: " << dims; - - NTA_ASSERT(0 <= dims[0] && dims[R-1] <= R2) - << "Index::embed(): " - << "Invalid dimensions: " << dims - << " when embedding in: [0.." << R2 << ")"; - } - - for (UInt k = 0; k < R; ++k) - idx2[dims[k]] = i_[k]; - } - - /** - * Permutes this index according to the order specified in ind. - * Examples: - * [1 2 3 4 5] becomes [2 3 4 5 1] with ind = [1 2 3 4 0] - * [1 2 3] becomes [3 1 2] with ind = [2 1 0] - * - * @param ind [Index] the new order - * @param perm [Index] the resulting permutation - */ - inline void permute(const Index& ind, Index& perm) const - { - { - checkPermutation_(ind); - } - - for (UInt k = 0; k < NDims; ++k) - perm[k] = i_[ind[k]]; - } - - inline Index permute(const Index& ind) const - { - { - checkPermutation_(ind); - } + inline Index() { memset(i_, 0, NDims * sizeof(value_type)); } - Index perm; - permute(ind, perm); - return perm; - } - - /** - * Finds the permutation that transforms this index into perm. - * If this index is [1 2 3 4 5] and perm is [2 3 4 5 1], - * the permutation is [1 2 3 4 0] - * Slow: O(NDims^2) - */ - inline void findPermutation(Index& ind, const Index& perm) const - { - for (UInt k = 0; k < NDims; ++k) - for (UInt k1 = 0; k1 < NDims; ++k1) - if (perm[k1] == i_[k]) - ind[k1] = k; - } - - inline Index findPermutation(const Index& perm) const - { - Index ind; - findPermutation(ind, perm); - return ind; - } + /** + * Constructor from an array. + * Creates an index that has the values in the given array. 
+ * + * @param i [Uint[NDims] ] the values to initialize the index with + */ + explicit inline Index(const UInt i[NDims]) { + memcpy(i_, i, NDims * sizeof(value_type)); + } - /** - * Returns whether any of the values in this index is a zero. - * (That is, the "index" cannot be used to describe the dimensions - * of a tensor). - */ - inline bool hasZero() const - { - for (UInt i = 0; i < NDims; ++i) - if (i_[i] == 0) - return true; - return false; - } + /** + * Constructor from a list. + * Creates an index initialized with the values passed in the list. + * + * @param i0 [Uint...] the list of values to initialize the index with + */ + explicit inline Index(UInt i0, ...) { + i_[0] = i0; - /** - * Streaming operator (for debugging). - */ - //template - //NTA_HIDDEN friend std::ostream& operator<<(std::ostream&, const Index&); - UInt i_[NDims]; - private: - - /** - * This method becomes empty and optimized away when assertions are turned off. - */ - inline void checkPermutation_(const Index& ind) const - { -#ifdef NTA_ASSERTIONS_ON - std::set s; - for (UInt k = 0; k < NDims; ++k) { - NTA_ASSERT(ind[k] >= 0 && ind[k] < NDims); - s.insert(ind[k]); - } - NTA_ASSERT(s.size() == NDims); -#endif + va_list indices; + va_start(indices, i0); + for (UInt k = 1; k < NDims; ++k) { + value_type v = va_arg(indices, value_type); + i_[k] = v; } - }; - - //-------------------------------------------------------------------------------- - template - std::ostream& operator<<(std::ostream& outStream, const Index& i) - { - outStream << "["; - for (UInt k = 0; k < NDims; ++k) - outStream << i.i_[k] << (k < NDims-1 ? "," : ""); - return outStream << "]"; + va_end(indices); } - template - inline Index concatenate(const Index& i1, const Index& i2) - { - Index newIndex; - for (I k = 0; k < n1; ++k) - newIndex[k] = i1[k]; - for (I k = 0; k < n2; ++k) - newIndex[n1+k] = i2[k]; - return newIndex; - } - - template - inline std::vector concatenate(const std::vector& i1, - const std::vector& i2) - { - std::vector newIndex(i1.size() + i2.size(), 0); - for (UInt k = 0; k < i1.size(); ++k) - newIndex[k] = i1[k]; - for (UInt k = 0; k < i2.size(); ++k) - newIndex[i1.size()+k] = i2[k]; - return newIndex; + /** + * Constructor from bounds and an ordinal. + * This constructor builds the index that corresponds to the + * given ordinal, with the given bounds. + * + * @param bounds [Index] the bounds to use to compute the index + * @param ordinal [UInt] the ordinal that will correspond to + * this index + */ + explicit inline Index(const Index &bounds, const value_type &ordinal) { + fromOrdinal(bounds, ordinal); } - template - inline void setToZero(Index& idx) - { - const UInt NDims = (UInt)idx.size(); - for (UInt i = 0; i < NDims; ++i) - idx[i] = 0; + inline Index(const std::vector idx) { + for (UInt k = 0; k < NDims; ++k) + i_[k] = idx[k]; } /** - * Returns whether the values in this index constitute a set or not - * (there are no duplicates). + * Copy constructor. + * + * @param from [Index] the index to copy */ - template - inline bool isSet(const Index& idx) - { - std::set s; - const UInt NDims = (UInt)idx.size(); - for (UInt i = 0; i < NDims; ++i) - s.insert(idx[i]); - return s.size() == NDims; + inline Index(const Index &from) { + if (&from != this) + memcpy(i_, from.i_, NDims * sizeof(value_type)); } /** - * Returns whether any of the values in this index is a zero. - * (That is, the "index" cannot be used to describe the dimensions - * of a tensor). + * Assignment operator. 
+ * + * @param from [Index] the index to copy */ - template - inline bool hasZero(const Index& idx) - { - const UInt NDims = idx.size(); - for (UInt i = 0; i < NDims; ++i) - if (idx[i] == 0) - return true; - return false; + inline Index &operator=(const Index &from) { + if (&from != this) + memcpy(i_, from.i_, NDims * sizeof(UInt)); + return *this; } - template - inline bool isZero(const Index& idx) - { - const UInt NDims = idx.size(); - for (UInt i = 0; i < NDims; ++i) - if (idx[i] != 0) - return false; - return true; - } - - template - inline bool indexGtZero(const Index& idx) - { - const UInt NDims = idx.size(); + inline iterator begin() { return i_; } + inline iterator end() { return begin() + NDims; } + inline const_iterator begin() const { return i_; } + inline const_iterator end() const { return begin() + NDims; } + inline UInt size() const { return NDims; } + + inline UInt max() const { + UInt M = 0; for (UInt i = 0; i < NDims; ++i) - if (idx[i] <= 0) - return false; - return true; - } - - /** - * This is not the same as positiveInBounds. - */ - template - inline bool indexLt(const Index1& i1, const Index2& i2) - { - { - NTA_ASSERT(i1.size() == i2.size()); - } - - const UInt NDims = (UInt)i1.size(); - for (UInt k = 0; k < NDims; ++k) - if (i1[k] < i2[k]) - return true; - else if (i1[k] > i2[k]) - return false; - return false; + if (i_[i] > M) + M = i_[i]; + return M; } /** - * This is not the same as positiveInBounds. + * Indexing operator. + * + * @param idx [0 <= UInt < NDims] index + * @retval [UInt&] the value at index 'idx' */ - template - inline bool indexLe(const Index1& i1, const Index2& i2) - { + inline UInt &operator[](const UInt idx) { { - NTA_ASSERT(i1.size() == i2.size()); - } - - const UInt NDims = i1.size(); - for (UInt k = 0; k < NDims; ++k) - if (i1[k] < i2[k]) - return true; - else if (i1[k] > i2[k]) - return false; - return true; - } - - template - inline bool indexEq(const Index1& i1, const Index2& i2) - { - { - NTA_ASSERT(i1.size() == i2.size()); + NTA_ASSERT(idx >= 0 && idx < NDims) + << "Index::operator[] " + << "Invalid index: " << idx << " - Should be in [0.." << NDims << ")"; } - const UInt NDims = (UInt)i1.size(); - for (UInt k = 0; k < NDims; ++k) - if (i1[k] != i2[k]) - return false; - return true; + return i_[idx]; } /** - * 0 is included, ub is excluded. + * Const indexing operator. + * + * @param idx [0 <= UInt < NDims] index + * @retval [const UInt] the value at index 'idx' */ - template - inline bool positiveInBounds(const Index1& idx, const Index2& ub) - { + inline UInt operator[](const UInt &idx) const { { - NTA_ASSERT(idx.size() == ub.size()); + NTA_ASSERT(idx >= 0 && idx < NDims) + << "Index::operator[] const " + << "Invalid index: " << idx << " - Should be in [0.." << NDims << ")"; } - const UInt NDims = idx.size(); - for (UInt k = 0; k < NDims; ++k) - if (idx[k] >= ub[k]) - return false; - return true; + return i_[idx]; } /** - * lb is included, ub is excluded. + * Resets the whole index to all zeros. */ - template - inline bool inBounds(const Index1& lb, const Index2& idx, const Index3& ub) - { - { - NTA_ASSERT(idx.size() == lb.size()); - NTA_ASSERT(idx.size() == ub.size()); - } + inline void setToZero() { memset(i_, 0, NDims * sizeof(UInt)); } - const UInt NDims = idx.size(); - for (UInt k = 0; k < NDims; ++k) - if (idx[k] < lb[k] || idx[k] >= ub[k]) - return false; - return true; + /** + * Returns whether the values in this index constitute a set or not + * (there are no duplicates). 
+ */ + inline bool isSet() const { + std::set s; + for (UInt i = 0; i < NDims; ++i) + s.insert(i_[i]); + return s.size() == NDims; } /** @@ -717,23 +201,16 @@ namespace nupic { * @param bounds [Index] the upper bound * @retval bool whether we've reached the end of the iteration or not */ - template - inline bool increment(const Index1& bounds, Index2& idx) - { - { - NTA_ASSERT(bounds.size() == idx.size()); - NTA_ASSERT(positiveInBounds(idx, bounds)); - } - - int curr = (UInt)idx.size()-1; - ++idx[curr]; - while (idx[curr] >= bounds[curr]) { - idx[curr] = 0; + inline bool increment(const Index &bounds) { + int curr = NDims - 1; + ++i_[curr]; + while (i_[curr] >= bounds[curr]) { + i_[curr] = 0; --curr; if (curr < 0) return false; else - ++idx[curr]; + ++i_[curr]; } return true; } @@ -746,22 +223,16 @@ namespace nupic { * @param ub [Index] the upper bound * @retval bool whether we've reached the end of the iteration or not */ - template - inline bool increment(const Index1& lb, const Index2& ub, Index3& idx) - { - { - inBounds(lb, idx, ub); - } - - int curr = idx.size()-1; - ++idx[curr]; - while (idx[curr] >= ub[curr]) { - idx[curr] = lb[curr]; + inline bool increment(const Index &lb, const Index &ub) { + int curr = NDims - 1; + ++i_[curr]; + while (i_[curr] >= ub[curr]) { + i_[curr] = lb[curr]; --curr; if (curr < 0) return false; else - ++idx[curr]; + ++i_[curr]; } return true; } @@ -769,165 +240,172 @@ namespace nupic { /** * Computes the ordinal corresponding to the "natural" * order for this index. For example, with bounds = [3, 2], - * [0, 0] -> 0, [0, 1] -> 1, [1, 0] -> 2, [1, 1] -> 3, + * [0, 0] -> 0, [0, 1] -> 1, [1, 0] -> 2, [1, 1] -> 3, * [2, 0] -> 4, [2, 1] -> 5. * * @param bounds [Index] the upper bound to use to compute * the ordinal * @retval UInt the ordinal for this index */ - template - inline typename Index1::value_type ordinal(const Index1& bounds, const Index2& idx) - { - { - NTA_ASSERT(bounds.size() == idx.size()); - NTA_ASSERT(indexGtZero(bounds)); - NTA_ASSERT(positiveInBounds(idx, bounds)); - } + inline UInt ordinal(const Index &bounds) const { + { NTA_ASSERT(indexGtZero(bounds)); } + + if (NDims == 1) // do specialization + return i_[0]; + + UInt p = bounds[NDims - 1], pos = i_[NDims - 1]; - const UInt NDims = (UInt)idx.size(); - typename Index1::value_type p = bounds[NDims-1], pos = idx[NDims-1]; - - for (int k = NDims-2; k >= 1; p *= bounds[k], --k) - pos += idx[k] * p; - pos += idx[0] * p; + for (int k = NDims - 2; k >= 1; p *= bounds[k], --k) + pos += i_[k] * p; + pos += i_[0] * p; return pos; } - template - inline void setFromOrdinal(const Index1& bounds, - const typename Index1::value_type& ordinal, - Index2& idx) - { - { - NTA_ASSERT(bounds.size() == idx.size()); - NTA_ASSERT(indexGtZero(bounds)); - } - - const UInt NDims = (UInt)bounds.size(); - typename Index1::value_type o = ordinal, p = product(bounds) / bounds[0]; - //TODO optimize double / use (slow!) - for (UInt k = 0; k < NDims-1; ++k) { - o %= p; - p /= bounds[k]; - idx[k] = o / p; - } - idx[NDims-1] = o; + inline void fromOrdinal(const Index &bounds, const value_type &ordinal) { + { NTA_ASSERT(indexGtZero(bounds)); } + + value_type o = ordinal, p = bounds.product() / bounds[0]; + // TODO optimize / + for (UInt k = 0; k < NDims - 1; o %= p, p /= bounds[k + 1], ++k) + i_[k] = o / p; + i_[NDims - 1] = o; + } + + /** + * Computes the stride for dimension dim of this Index. + * If the Index is :[6, 7, 5, 4], then: + * stride(0) = 7*5*4, + * stride(1) = 5*4, + * stride(2) = 4, + * stride(3) = 1. 
+ * + * @param dim [UInt] the dim for which we want the stride + * @retval UInt the stride + */ + inline UInt stride(const UInt &dim) const { + if (dim == NDims - 1) + return 1; + + UInt s = i_[dim + 1]; + for (UInt i = dim + 2; i < NDims; ++i) + s *= i_[i]; + return s; + } + + /** + * Computes the distance between two indices, with respect to + * a given upper bound. That is: + * distance = other.ordinal(bounds) - this->ordinal(bounds). + * + * @param bounds [Index] the upper bound + * @param other [Index] the second index + * @retval UInt the distance between this and the second index + */ + inline UInt distance(const Index &bounds, const Index &other) const { + return other.ordinal(bounds) - ordinal(bounds); + } + + /** + * Computes the product of all the values in this index. + * The result can be zero, if at least one of the indices + * is zero. + * + * @retval UInt the product + */ + inline UInt product() const { + UInt n = i_[0]; + for (UInt k = 1; k < NDims; ++k) + n *= i_[k]; + return n; } /** * Computes the complement of this index. + * * For example: - * (*this) [0, 2, 4] -(N=6)-> (idx) [1, 3, 5] + * (*this) [0, 2, 4] -(N=6)-> (idx) [1, 3, 5] * (*this) [0, 2] -(N=3)-> (idx) [1] * (*this) [0] -(N=2)-> (idx) [1] * (*this) [0, 1] -(N=3)-> (idx) [2] * * @param idx [Index] the complement of this index */ - template - inline void complement(const Index& idx, Index2& c_idx) - { - const UInt NDims = (UInt)idx.size(); - const UInt R = (UInt)c_idx.size(); + template inline void complement(Index &idx) const { const UInt N = NDims + R; - UInt k = 0, k1 = 0, k2 = 0; + UInt k = 0, k1 = 0, k2 = 0; for (k = 0; k < NDims; ++k) { - for (; k1 < idx[k]; ++k1) - c_idx[k2++] = k1; - k1 = idx[k]+1; + for (; k1 < i_[k]; ++k1) + idx[k2++] = k1; + k1 = i_[k] + 1; } while (k1 < N) - c_idx[k2++] = k1++; + idx[k2++] = k1++; } /** * Computes the projection of this index on to the dimensions - * specified: + * specified: * idx2[k] = (*this)[dims[k]], for k in [0..R). * * @param dims [Index] the dimensions to project onto * @param idx2 [Index] the projection of this index */ - template - inline void project(const Index1& dims, const Index2& idx, Index3& idx2) - { - const UInt NDims = (UInt)idx.size(); - const UInt R = (UInt)idx2.size(); - + template + inline void project(const Index &dims, Index &idx2) const { { - NTA_ASSERT(idx2.size() == dims.size()); - NTA_ASSERT(R <= NDims) - << "Index::project(): " - << "Invalid number of dimensions to project on: " << R - << " - Should be less than: " << NDims; - - for (UInt k = 0; k < R-1; ++k) - NTA_ASSERT(dims[k] < dims[k+1]) << "Index::project(): " - << "Dimensions need to be in strictly increasing order, " - << "passed: " << dims; + << "Invalid number of dimensions to project on: " << R + << " - Should be less than: " << NDims; - NTA_ASSERT(0 <= dims[0] && dims[R-1] <= NDims) - << "Index::project(): " - << "Invalid dimensions: " << dims - << " when projecting in: [0.." << R << ")"; + for (UInt k = 0; k < R - 1; ++k) + NTA_ASSERT(dims[k] < dims[k + 1]) + << "Index::project(): " + << "Dimensions need to be in strictly increasing order, " + << "passed: " << dims; + + NTA_ASSERT(0 <= dims[0] && dims[R - 1] <= NDims) + << "Index::project(): " + << "Invalid dimensions: " << dims << " when projecting in: [0.." << R + << ")"; } for (UInt k = 0; k < R; ++k) - idx2[k] = idx[dims[k]]; + idx2[k] = i_[dims[k]]; } /** * Embeds the current index into an index of higher dimension: * idx2[dims[k]] = (*this)[k], for k in [0..R). 
- * Note that if there are coordinates already set in idx2, - * they will stay there and not be reset to 0: - * I6 i6; - * I2 i2(2, 4), dims(1, 3); - * I4 i4(1, 3, 5, 6), compDims; - * embed(dims, i2, i6); - * Test("Index embed 3", i6, I6(0, 2, 0, 4, 0, 0)); - * dims.complement(compDims); - * embed(compDims, i4, i6); - * Test("Index embed 4", i6, I6(1, 2, 3, 4, 5, 6)); * * @param dims [Index] the dimensions to embed into * @param idx [Index] the embedding of this index */ - template - inline void embed(const Index1& dims, const Index2& idx, Index3& idx2) - { - const UInt R = dims.size(); - const UInt NDims = idx.size(); - const UInt R2 = idx2.size(); - + template + inline void embed(const Index &dims, Index &idx2) const { { - NTA_ASSERT(idx.size() == dims.size()); - - NTA_ASSERT(R2 >= NDims) - << "Index::embed(): " - << "Invalid number of dimensions to embed into: " << R2 - << " - Should be >= " << NDims; - - for (UInt k = 0; k < R-1; ++k) - NTA_ASSERT(dims[k] < dims[k+1]) + NTA_ASSERT(R2 >= NDims) << "Index::embed(): " - << "Dimensions need to be in strictly increasing order, " - << "passed: " << dims; + << "Invalid number of dimensions to embed into: " << R2 + << " - Should be >= " << NDims; - NTA_ASSERT(0 <= dims[0] && dims[R-1] <= R2) - << "Index::embed(): " - << "Invalid dimensions: " << dims - << " when embedding in: [0.." << R2 << ")"; + for (UInt k = 0; k < R - 1; ++k) + NTA_ASSERT(dims[k] < dims[k + 1]) + << "Index::embed(): " + << "Dimensions need to be in strictly increasing order, " + << "passed: " << dims; + + NTA_ASSERT(0 <= dims[0] && dims[R - 1] <= R2) + << "Index::embed(): " + << "Invalid dimensions: " << dims << " when embedding in: [0.." << R2 + << ")"; } for (UInt k = 0; k < R; ++k) - idx2[dims[k]] = idx[k]; + idx2[dims[k]] = i_[k]; } /** @@ -935,93 +413,527 @@ namespace nupic { * Examples: * [1 2 3 4 5] becomes [2 3 4 5 1] with ind = [1 2 3 4 0] * [1 2 3] becomes [3 1 2] with ind = [2 1 0] - * + * * @param ind [Index] the new order * @param perm [Index] the resulting permutation */ - template - inline void permute(const Index1& ind, const Index2& idx, Index3& perm) - { - { -#ifdef NTA_ASSERTIONS_ON - std::set s; - for (UInt k = 0; k < ind.size(); ++k) { - NTA_ASSERT(ind[k] >= 0 && ind[k] < ind.size()); - s.insert(ind[k]); - } - NTA_ASSERT(s.size() == ind.size()); -#endif - NTA_ASSERT(ind.size() == idx.size()); - NTA_ASSERT(ind.size() == perm.size()); - } + inline void permute(const Index &ind, Index &perm) const { + { checkPermutation_(ind); } - const UInt NDims = (UInt)idx.size(); for (UInt k = 0; k < NDims; ++k) - perm[k] = idx[ind[k]]; + perm[k] = i_[ind[k]]; } - //-------------------------------------------------------------------------------- - template - inline bool operator==(const Index& i1, const Index& i2) - { - return indexEq(i1, i2); + inline Index permute(const Index &ind) const { + { checkPermutation_(ind); } + + Index perm; + permute(ind, perm); + return perm; } - template - inline bool operator!=(const Index& i1, const Index& i2) - { - return ! indexEq(i1, i2); + /** + * Finds the permutation that transforms this index into perm. 
+ * If this index is [1 2 3 4 5] and perm is [2 3 4 5 1], + * the permutation is [1 2 3 4 0] + * Slow: O(NDims^2) + */ + inline void findPermutation(Index &ind, const Index &perm) const { + for (UInt k = 0; k < NDims; ++k) + for (UInt k1 = 0; k1 < NDims; ++k1) + if (perm[k1] == i_[k]) + ind[k1] = k; } - template - inline bool operator<(const Index& i1, const Index& i2) - { - return indexLt(i1, i2); + inline Index findPermutation(const Index &perm) const { + Index ind; + findPermutation(ind, perm); + return ind; + } + + /** + * Returns whether any of the values in this index is a zero. + * (That is, the "index" cannot be used to describe the dimensions + * of a tensor). + */ + inline bool hasZero() const { + for (UInt i = 0; i < NDims; ++i) + if (i_[i] == 0) + return true; + return false; + } + + /** + * Streaming operator (for debugging). + */ + // template + // NTA_HIDDEN friend std::ostream& operator<<(std::ostream&, const Index&); + UInt i_[NDims]; + +private: + /** + * This method becomes empty and optimized away when assertions are turned + * off. + */ + inline void checkPermutation_(const Index &ind) const { +#ifdef NTA_ASSERTIONS_ON + std::set s; + for (UInt k = 0; k < NDims; ++k) { + NTA_ASSERT(ind[k] >= 0 && ind[k] < NDims); + s.insert(ind[k]); + } + NTA_ASSERT(s.size() == NDims); +#endif } +}; + +//-------------------------------------------------------------------------------- +template +std::ostream &operator<<(std::ostream &outStream, const Index &i) { + outStream << "["; + for (UInt k = 0; k < NDims; ++k) + outStream << i.i_[k] << (k < NDims - 1 ? "," : ""); + return outStream << "]"; +} + +template +inline Index concatenate(const Index &i1, + const Index &i2) { + Index newIndex; + for (I k = 0; k < n1; ++k) + newIndex[k] = i1[k]; + for (I k = 0; k < n2; ++k) + newIndex[n1 + k] = i2[k]; + return newIndex; +} + +template +inline std::vector concatenate(const std::vector &i1, + const std::vector &i2) { + std::vector newIndex(i1.size() + i2.size(), 0); + for (UInt k = 0; k < i1.size(); ++k) + newIndex[k] = i1[k]; + for (UInt k = 0; k < i2.size(); ++k) + newIndex[i1.size() + k] = i2[k]; + return newIndex; +} + +template inline void setToZero(Index &idx) { + const UInt NDims = (UInt)idx.size(); + for (UInt i = 0; i < NDims; ++i) + idx[i] = 0; +} + +/** + * Returns whether the values in this index constitute a set or not + * (there are no duplicates). + */ +template inline bool isSet(const Index &idx) { + std::set s; + const UInt NDims = (UInt)idx.size(); + for (UInt i = 0; i < NDims; ++i) + s.insert(idx[i]); + return s.size() == NDims; +} + +/** + * Returns whether any of the values in this index is a zero. + * (That is, the "index" cannot be used to describe the dimensions + * of a tensor). + */ +template inline bool hasZero(const Index &idx) { + const UInt NDims = idx.size(); + for (UInt i = 0; i < NDims; ++i) + if (idx[i] == 0) + return true; + return false; +} + +template inline bool isZero(const Index &idx) { + const UInt NDims = idx.size(); + for (UInt i = 0; i < NDims; ++i) + if (idx[i] != 0) + return false; + return true; +} + +template inline bool indexGtZero(const Index &idx) { + const UInt NDims = idx.size(); + for (UInt i = 0; i < NDims; ++i) + if (idx[i] <= 0) + return false; + return true; +} - template - inline bool operator<(const std::vector& i1, const std::vector& i2) +/** + * This is not the same as positiveInBounds. 
+ */ +template +inline bool indexLt(const Index1 &i1, const Index2 &i2) { + { NTA_ASSERT(i1.size() == i2.size()); } + + const UInt NDims = (UInt)i1.size(); + for (UInt k = 0; k < NDims; ++k) + if (i1[k] < i2[k]) + return true; + else if (i1[k] > i2[k]) + return false; + return false; +} + +/** + * This is not the same as positiveInBounds. + */ +template +inline bool indexLe(const Index1 &i1, const Index2 &i2) { + { NTA_ASSERT(i1.size() == i2.size()); } + + const UInt NDims = i1.size(); + for (UInt k = 0; k < NDims; ++k) + if (i1[k] < i2[k]) + return true; + else if (i1[k] > i2[k]) + return false; + return true; +} + +template +inline bool indexEq(const Index1 &i1, const Index2 &i2) { + { NTA_ASSERT(i1.size() == i2.size()); } + + const UInt NDims = (UInt)i1.size(); + for (UInt k = 0; k < NDims; ++k) + if (i1[k] != i2[k]) + return false; + return true; +} + +/** + * 0 is included, ub is excluded. + */ +template +inline bool positiveInBounds(const Index1 &idx, const Index2 &ub) { + { NTA_ASSERT(idx.size() == ub.size()); } + + const UInt NDims = idx.size(); + for (UInt k = 0; k < NDims; ++k) + if (idx[k] >= ub[k]) + return false; + return true; +} + +/** + * lb is included, ub is excluded. + */ +template +inline bool inBounds(const Index1 &lb, const Index2 &idx, const Index3 &ub) { { - return indexLt(i1, i2); + NTA_ASSERT(idx.size() == lb.size()); + NTA_ASSERT(idx.size() == ub.size()); } - template - inline bool operator<(const Index& i1, const std::vector& i2) + const UInt NDims = idx.size(); + for (UInt k = 0; k < NDims; ++k) + if (idx[k] < lb[k] || idx[k] >= ub[k]) + return false; + return true; +} + +/** + * Increments this index, using bounds as the upper bound + * for the iteration. + * + * @param bounds [Index] the upper bound + * @retval bool whether we've reached the end of the iteration or not + */ +template +inline bool increment(const Index1 &bounds, Index2 &idx) { { - return indexLt(i1, i2); + NTA_ASSERT(bounds.size() == idx.size()); + NTA_ASSERT(positiveInBounds(idx, bounds)); } - template - inline bool operator<(const std::vector& i1, const Index& i2) + int curr = (UInt)idx.size() - 1; + ++idx[curr]; + while (idx[curr] >= bounds[curr]) { + idx[curr] = 0; + --curr; + if (curr < 0) + return false; + else + ++idx[curr]; + } + return true; +} + +/** + * Increment this index, using lb and ub as lower and upper + * bounds for the iteration. + * + * @param lb [index] the lower bound + * @param ub [Index] the upper bound + * @retval bool whether we've reached the end of the iteration or not + */ +template +inline bool increment(const Index1 &lb, const Index2 &ub, Index3 &idx) { + { inBounds(lb, idx, ub); } + + int curr = idx.size() - 1; + ++idx[curr]; + while (idx[curr] >= ub[curr]) { + idx[curr] = lb[curr]; + --curr; + if (curr < 0) + return false; + else + ++idx[curr]; + } + return true; +} + +/** + * Computes the ordinal corresponding to the "natural" + * order for this index. For example, with bounds = [3, 2], + * [0, 0] -> 0, [0, 1] -> 1, [1, 0] -> 2, [1, 1] -> 3, + * [2, 0] -> 4, [2, 1] -> 5. 
+ * + * @param bounds [Index] the upper bound to use to compute + * the ordinal + * @retval UInt the ordinal for this index + */ +template +inline typename Index1::value_type ordinal(const Index1 &bounds, + const Index2 &idx) { { - return indexLt(i1, i2); + NTA_ASSERT(bounds.size() == idx.size()); + NTA_ASSERT(indexGtZero(bounds)); + NTA_ASSERT(positiveInBounds(idx, bounds)); } - template - inline bool operator<=(const Index& i1, const Index& i2) + const UInt NDims = (UInt)idx.size(); + typename Index1::value_type p = bounds[NDims - 1], pos = idx[NDims - 1]; + + for (int k = NDims - 2; k >= 1; p *= bounds[k], --k) + pos += idx[k] * p; + pos += idx[0] * p; + + return pos; +} + +template +inline void setFromOrdinal(const Index1 &bounds, + const typename Index1::value_type &ordinal, + Index2 &idx) { { - return indexLe(i1, i2); + NTA_ASSERT(bounds.size() == idx.size()); + NTA_ASSERT(indexGtZero(bounds)); } - template - inline bool operator<=(const std::vector& i1, const std::vector& i2) + const UInt NDims = (UInt)bounds.size(); + typename Index1::value_type o = ordinal, p = product(bounds) / bounds[0]; + // TODO optimize double / use (slow!) + for (UInt k = 0; k < NDims - 1; ++k) { + o %= p; + p /= bounds[k]; + idx[k] = o / p; + } + idx[NDims - 1] = o; +} + +/** + * Computes the complement of this index. + * For example: + * (*this) [0, 2, 4] -(N=6)-> (idx) [1, 3, 5] + * (*this) [0, 2] -(N=3)-> (idx) [1] + * (*this) [0] -(N=2)-> (idx) [1] + * (*this) [0, 1] -(N=3)-> (idx) [2] + * + * @param idx [Index] the complement of this index + */ +template +inline void complement(const Index &idx, Index2 &c_idx) { + const UInt NDims = (UInt)idx.size(); + const UInt R = (UInt)c_idx.size(); + const UInt N = NDims + R; + UInt k = 0, k1 = 0, k2 = 0; + + for (k = 0; k < NDims; ++k) { + for (; k1 < idx[k]; ++k1) + c_idx[k2++] = k1; + k1 = idx[k] + 1; + } + + while (k1 < N) + c_idx[k2++] = k1++; +} + +/** + * Computes the projection of this index on to the dimensions + * specified: + * idx2[k] = (*this)[dims[k]], for k in [0..R). + * + * @param dims [Index] the dimensions to project onto + * @param idx2 [Index] the projection of this index + */ +template +inline void project(const Index1 &dims, const Index2 &idx, Index3 &idx2) { + const UInt NDims = (UInt)idx.size(); + const UInt R = (UInt)idx2.size(); + { - return indexLe(i1, i2); + NTA_ASSERT(idx2.size() == dims.size()); + + NTA_ASSERT(R <= NDims) << "Index::project(): " + << "Invalid number of dimensions to project on: " + << R << " - Should be less than: " << NDims; + + for (UInt k = 0; k < R - 1; ++k) + NTA_ASSERT(dims[k] < dims[k + 1]) + << "Index::project(): " + << "Dimensions need to be in strictly increasing order, " + << "passed: " << dims; + + NTA_ASSERT(0 <= dims[0] && dims[R - 1] <= NDims) + << "Index::project(): " + << "Invalid dimensions: " << dims << " when projecting in: [0.." << R + << ")"; } - template - inline bool operator<=(const Index& i1, const std::vector& i2) + for (UInt k = 0; k < R; ++k) + idx2[k] = idx[dims[k]]; +} + +/** + * Embeds the current index into an index of higher dimension: + * idx2[dims[k]] = (*this)[k], for k in [0..R). 
+ * Note that if there are coordinates already set in idx2, + * they will stay there and not be reset to 0: + * I6 i6; + * I2 i2(2, 4), dims(1, 3); + * I4 i4(1, 3, 5, 6), compDims; + * embed(dims, i2, i6); + * Test("Index embed 3", i6, I6(0, 2, 0, 4, 0, 0)); + * dims.complement(compDims); + * embed(compDims, i4, i6); + * Test("Index embed 4", i6, I6(1, 2, 3, 4, 5, 6)); + * + * @param dims [Index] the dimensions to embed into + * @param idx [Index] the embedding of this index + */ +template +inline void embed(const Index1 &dims, const Index2 &idx, Index3 &idx2) { + const UInt R = dims.size(); + const UInt NDims = idx.size(); + const UInt R2 = idx2.size(); + { - return indexLe(i1, i2); + NTA_ASSERT(idx.size() == dims.size()); + + NTA_ASSERT(R2 >= NDims) + << "Index::embed(): " + << "Invalid number of dimensions to embed into: " << R2 + << " - Should be >= " << NDims; + + for (UInt k = 0; k < R - 1; ++k) + NTA_ASSERT(dims[k] < dims[k + 1]) + << "Index::embed(): " + << "Dimensions need to be in strictly increasing order, " + << "passed: " << dims; + + NTA_ASSERT(0 <= dims[0] && dims[R - 1] <= R2) + << "Index::embed(): " + << "Invalid dimensions: " << dims << " when embedding in: [0.." << R2 + << ")"; } - template - inline bool operator<=(const std::vector& i1, const Index& i2) + for (UInt k = 0; k < R; ++k) + idx2[dims[k]] = idx[k]; +} + +/** + * Permutes this index according to the order specified in ind. + * Examples: + * [1 2 3 4 5] becomes [2 3 4 5 1] with ind = [1 2 3 4 0] + * [1 2 3] becomes [3 1 2] with ind = [2 1 0] + * + * @param ind [Index] the new order + * @param perm [Index] the resulting permutation + */ +template +inline void permute(const Index1 &ind, const Index2 &idx, Index3 &perm) { { - return indexLe(i1, i2); +#ifdef NTA_ASSERTIONS_ON + std::set s; + for (UInt k = 0; k < ind.size(); ++k) { + NTA_ASSERT(ind[k] >= 0 && ind[k] < ind.size()); + s.insert(ind[k]); + } + NTA_ASSERT(s.size() == ind.size()); +#endif + NTA_ASSERT(ind.size() == idx.size()); + NTA_ASSERT(ind.size() == perm.size()); } - //-------------------------------------------------------------------------------- + const UInt NDims = (UInt)idx.size(); + for (UInt k = 0; k < NDims; ++k) + perm[k] = idx[ind[k]]; +} + +//-------------------------------------------------------------------------------- +template +inline bool operator==(const Index &i1, const Index &i2) { + return indexEq(i1, i2); +} + +template +inline bool operator!=(const Index &i1, const Index &i2) { + return !indexEq(i1, i2); +} + +template +inline bool operator<(const Index &i1, const Index &i2) { + return indexLt(i1, i2); +} + +template +inline bool operator<(const std::vector &i1, + const std::vector &i2) { + return indexLt(i1, i2); +} + +template +inline bool operator<(const Index &i1, const std::vector &i2) { + return indexLt(i1, i2); +} + +template +inline bool operator<(const std::vector &i1, const Index &i2) { + return indexLt(i1, i2); +} + +template +inline bool operator<=(const Index &i1, const Index &i2) { + return indexLe(i1, i2); +} + +template +inline bool operator<=(const std::vector &i1, + const std::vector &i2) { + return indexLe(i1, i2); +} + +template +inline bool operator<=(const Index &i1, const std::vector &i2) { + return indexLe(i1, i2); +} + +template +inline bool operator<=(const std::vector &i1, const Index &i2) { + return indexLe(i1, i2); +} + +//-------------------------------------------------------------------------------- } // end namespace nupic diff --git a/src/nupic/math/Math.hpp b/src/nupic/math/Math.hpp index 
8f585fe750..6743565eea 100644 --- a/src/nupic/math/Math.hpp +++ b/src/nupic/math/Math.hpp @@ -20,25 +20,25 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Declarations for maths routines */ #ifndef NTA_MATH_HPP #define NTA_MATH_HPP -#include #include -#include #include -#include #include #include +#include +#include +#include #include -#include #include +#include //-------------------------------------------------------------------------------- /** @@ -46,1021 +46,853 @@ */ // Assert that It obeys the STL forward iterator concept -#define ASSERT_INPUT_ITERATOR(It) \ - boost::function_requires >(); +#define ASSERT_INPUT_ITERATOR(It) \ + boost::function_requires>(); // Assert that It obeys the STL forward iterator concept -#define ASSERT_OUTPUT_ITERATOR(It, T) \ - boost::function_requires >(); +#define ASSERT_OUTPUT_ITERATOR(It, T) \ + boost::function_requires>(); // Assert that UnaryPredicate obeys the STL unary predicate concept -#define ASSERT_UNARY_PREDICATE(UnaryPredicate, Arg1) \ - boost::function_requires >(); +#define ASSERT_UNARY_PREDICATE(UnaryPredicate, Arg1) \ + boost::function_requires< \ + boost::UnaryPredicateConcept>(); // Assert that UnaryFunction obeys the STL unary function concept -#define ASSERT_UNARY_FUNCTION(UnaryFunction, Ret, Arg1) \ - boost::function_requires >(); +#define ASSERT_UNARY_FUNCTION(UnaryFunction, Ret, Arg1) \ + boost::function_requires< \ + boost::UnaryFunctionConcept>(); // Assert that BinaryFunction obeys the STL binary function concept -#define ASSERT_BINARY_FUNCTION(BinaryFunction, Ret, Arg1, Arg2) \ - boost::function_requires >(); +#define ASSERT_BINARY_FUNCTION(BinaryFunction, Ret, Arg1, Arg2) \ + boost::function_requires< \ + boost::BinaryFunctionConcept>(); //-------------------------------------------------------------------------------- namespace nupic { - //-------------------------------------------------------------------------------- - // ASSERTIONS - //-------------------------------------------------------------------------------- - /** - * This is used to check that a boolean condition holds, and to send a message - * on NTA_INFO if it doesn't. It's compiled in only when NTA_ASSERTIONS_ON is - * true. - */ - inline bool INVARIANT(bool cond, const char* msg) - { +//-------------------------------------------------------------------------------- +// ASSERTIONS +//-------------------------------------------------------------------------------- +/** + * This is used to check that a boolean condition holds, and to send a message + * on NTA_INFO if it doesn't. It's compiled in only when NTA_ASSERTIONS_ON is + * true. + */ +inline bool INVARIANT(bool cond, const char *msg) { #ifdef NTA_ASSERTIONS_ON - if (!(cond)) { - NTA_INFO << msg; - return false; - } -#endif - return true; + if (!(cond)) { + NTA_INFO << msg; + return false; } +#endif + return true; +} - //-------------------------------------------------------------------------------- - /** - * An assert used to validate that a range defined by two iterators is valid, - * that is: begin <= end. We could say that end < begin means the range is - * empty, but for debugging purposes, we want to know when the iterators - * are crossed. 
- */ - template - void ASSERT_VALID_RANGE(It begin, It end, const char* message) - { - NTA_ASSERT(begin <= end) << "Invalid iterators: " << message; - } +//-------------------------------------------------------------------------------- +/** + * An assert used to validate that a range defined by two iterators is valid, + * that is: begin <= end. We could say that end < begin means the range is + * empty, but for debugging purposes, we want to know when the iterators + * are crossed. + */ +template +void ASSERT_VALID_RANGE(It begin, It end, const char *message) { + NTA_ASSERT(begin <= end) << "Invalid iterators: " << message; +} - //-------------------------------------------------------------------------------- - /** - * Epsilon is defined for the whole math and algorithms of the Numenta Platform, - * independently of the concrete type chosen to handle - * floating point numbers. - * numeric_limits::epsilon() == 1.19209e-7 - * numeric_limits::epsilon() == 2.22045e-16 - */ - static const nupic::Real Epsilon = nupic::Real(1e-6); +//-------------------------------------------------------------------------------- +/** + * Epsilon is defined for the whole math and algorithms of the Numenta + * Platform, independently of the concrete type chosen to handle floating point + * numbers. numeric_limits::epsilon() == 1.19209e-7 + * numeric_limits::epsilon() == 2.22045e-16 + */ +static const nupic::Real Epsilon = nupic::Real(1e-6); - //-------------------------------------------------------------------------------- - /** - * Functions that test for positivity or negativity based on nupic::Epsilon. - */ - template - inline bool strictlyNegative(const T& a) - { - return a < -nupic::Epsilon; - } +//-------------------------------------------------------------------------------- +/** + * Functions that test for positivity or negativity based on nupic::Epsilon. + */ +template inline bool strictlyNegative(const T &a) { + return a < -nupic::Epsilon; +} - //-------------------------------------------------------------------------------- - template - inline bool strictlyPositive(const T& a) - { - return a > nupic::Epsilon; +//-------------------------------------------------------------------------------- +template inline bool strictlyPositive(const T &a) { + return a > nupic::Epsilon; +} + +//-------------------------------------------------------------------------------- +template inline bool negative(const T &a) { + return a <= nupic::Epsilon; +} + +//-------------------------------------------------------------------------------- +template inline bool positive(const T &a) { + return a >= -nupic::Epsilon; +} + +//-------------------------------------------------------------------------------- +/** + * A functions that implements the distance to zero function as a functor. + * Defining argument_type and result_type directly here instead of inheriting + * from std::unary_function so that we have an easier time in SWIG Python + * wrapping. + */ +template struct DistanceToZero { + typedef T argument_type; + typedef T result_type; + + inline T operator()(const T &x) const { return x >= 0 ? x : -x; } +}; + +//-------------------------------------------------------------------------------- +/** + * A specialization for UInts, where we only need one test (more efficient). + */ +template <> inline UInt DistanceToZero::operator()(const UInt &x) const { + return x; +} + +//-------------------------------------------------------------------------------- +/** + * Use this functor if T is guaranteed to be positive only. 
+ */ +template +struct DistanceToZeroPositive : public std::unary_function { + inline T operator()(const T &x) const { return x; } +}; + +//-------------------------------------------------------------------------------- +/** + * This computes the distance to 1 rather than to 0. + */ +template struct DistanceToOne { + typedef T argument_type; + typedef T result_type; + + inline T operator()(const T &x) const { + return x > (T)1 ? x - (T)1 : (T)1 - x; } +}; - //-------------------------------------------------------------------------------- - template - inline bool negative(const T& a) - { - return a <= nupic::Epsilon; +//-------------------------------------------------------------------------------- +/** + * This functor decides whether a number is almost zero or not, using the + * platform-wide nupic::Epsilon. + */ +template struct IsNearlyZero { + typedef typename D::result_type value_type; + + D dist_; + + // In the case where D::result_type is integral + // we convert nupic::Epsilon to zero! + inline IsNearlyZero() : dist_() {} + + inline IsNearlyZero(const IsNearlyZero &other) : dist_(other.dist_) {} + + inline IsNearlyZero &operator=(const IsNearlyZero &other) { + if (this != &other) + dist_ = other.dist_; + + return *this; } - //-------------------------------------------------------------------------------- - template - inline bool positive(const T& a) - { - return a >= -nupic::Epsilon; + inline bool operator()(const typename D::argument_type &x) const { + return dist_(x) <= nupic::Epsilon; } +}; - //-------------------------------------------------------------------------------- - /** - * A functions that implements the distance to zero function as a functor. - * Defining argument_type and result_type directly here instead of inheriting from - * std::unary_function so that we have an easier time in SWIG Python wrapping. - */ - template - struct DistanceToZero - { - typedef T argument_type; - typedef T result_type; - - inline T operator()(const T& x) const - { - return x >= 0 ? x : -x; - } - }; +//-------------------------------------------------------------------------------- +/** + * @b Responsibility: + * Tell whether an arithmetic value is zero or not, within some precision, + * or whether two values are equal or not, within some precision. + * + * @b Parameters: + * epsilon: accuracy of the comparison + * + * @b Returns: + * true if |a| <= epsilon + * false otherwise + * + * @b Requirements: + * T arithmetic + * T comparable with operator<= AND operator >= + * + * @b Restrictions: + * Doesn't compile if T is not arithmetic. + * In debug mode, NTA_ASSERTs if |a| > 10 + * In debug mode, NTA_ASSERTs if a == infinity, quiet_NaN or signaling_NaN + * + * @b Notes: + * Comparing floating point numbers is a pretty tricky business. Knuth's got + * many pages devoted to it in Vol II. Boost::test has a special function to + * handle that. One of the problems is that when more bits are allocated to + * the integer part as the number gets bigger, there is inherently less + * precision in the decimals. But, for comparisons to zero, we can continue + * using an absolute epsilon (instead of multiplying epsilon by the number). + * In our application, we are anticipating numbers mostly between 0 and 1, + * because they represent probabilities. + * + * Not clear why this is namespace std instead of nupic , but FA says there was + * a "good, ugly" reason to do it this way that he can't remember. 
- WCS 0206 + */ +template +inline bool nearlyZero(const T &a, const T &epsilon = T(nupic::Epsilon)) { + return a >= -epsilon && a <= epsilon; +} - //-------------------------------------------------------------------------------- - /** - * A specialization for UInts, where we only need one test (more efficient). - */ - template<> - inline - UInt DistanceToZero::operator()(const UInt &x) const { return x; } +//-------------------------------------------------------------------------------- +template +inline bool nearlyEqual(const T &a, const T &b, + const T &epsilon = nupic::Epsilon) { + return nearlyZero((b - a), epsilon); +} - //-------------------------------------------------------------------------------- - /** - * Use this functor if T is guaranteed to be positive only. - */ - template - struct DistanceToZeroPositive : public std::unary_function - { - inline T operator()(const T& x) const - { - return x; - } - }; +//-------------------------------------------------------------------------------- +/** + * Euclidean modulo function. + * + * Returns x % m, but keeps the value positive + * (similar to Python's modulo function). + */ +inline int emod(int x, int m) { + int r = x % m; + if (r < 0) + return r + m; + else + return r; +} - //-------------------------------------------------------------------------------- - /** - * This computes the distance to 1 rather than to 0. - */ - template - struct DistanceToOne - { - typedef T argument_type; - typedef T result_type; - - inline T operator()(const T& x) const - { - return x > (T) 1 ? x - (T) 1 : (T) 1 - x; - } - }; +//-------------------------------------------------------------------------------- +/** + * A boolean functor that returns true if the element's selected value is found + * in the (associative) container (needs to support find). + * + * Example: + * ======= + * typedef std::pair IdxVal; + * std::list row; + * std::set alreadyGrouped; + * row.remove_if(SetIncludes, select1st + * >(alreadyGrouped)); + * + * Will remove from row (a list of pairs) the pairs whose first element is + * already contained in alreadyGrouped. + */ +template struct IsIncluded { + IsIncluded(const C1 &container) : sel_(), container_(container) {} + + template inline bool operator()(const T &p) const { + if (f) + return container_.find(sel_(p)) == container_.end(); + else + return container_.find(sel_(p)) != container_.end(); + } - //-------------------------------------------------------------------------------- - /** - * This functor decides whether a number is almost zero or not, using the - * platform-wide nupic::Epsilon. - */ - template - struct IsNearlyZero - { - typedef typename D::result_type value_type; - - D dist_; - - // In the case where D::result_type is integral - // we convert nupic::Epsilon to zero! 
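A short hedged usage sketch for the epsilon comparisons defined above (nearlyZero and nearlyEqual, both relative to the platform-wide nupic::Epsilon of 1e-6), assuming the nupic/math/Math.hpp include path used in this diff:

#include <cassert>
#include <nupic/math/Math.hpp>

int main() {
  const nupic::Real a = 0.1f * 3.0f;               // accumulates float rounding error
  assert(nupic::nearlyEqual(a, nupic::Real(0.3))); // still within Epsilon = 1e-6
  assert(nupic::nearlyZero(nupic::Real(1e-8)));    // below Epsilon, treated as zero
  assert(!nupic::nearlyZero(nupic::Real(1e-3)));   // well above Epsilon
  return 0;
}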
- inline IsNearlyZero() - : dist_() - {} - - inline IsNearlyZero(const IsNearlyZero& other) - : dist_(other.dist_) - {} - - inline IsNearlyZero& operator=(const IsNearlyZero& other) - { - if (this != &other) - dist_ = other.dist_; - - return *this; - } + Selector sel_; + const C1 &container_; +}; - inline bool operator()(const typename D::argument_type& x) const - { - return dist_(x) <= nupic::Epsilon; - } - }; +//-------------------------------------------------------------------------------- +// PAIRS AND TRIPLETS +//-------------------------------------------------------------------------------- - //-------------------------------------------------------------------------------- - /** - * @b Responsibility: - * Tell whether an arithmetic value is zero or not, within some precision, - * or whether two values are equal or not, within some precision. - * - * @b Parameters: - * epsilon: accuracy of the comparison - * - * @b Returns: - * true if |a| <= epsilon - * false otherwise - * - * @b Requirements: - * T arithmetic - * T comparable with operator<= AND operator >= - * - * @b Restrictions: - * Doesn't compile if T is not arithmetic. - * In debug mode, NTA_ASSERTs if |a| > 10 - * In debug mode, NTA_ASSERTs if a == infinity, quiet_NaN or signaling_NaN - * - * @b Notes: - * Comparing floating point numbers is a pretty tricky business. Knuth's got - * many pages devoted to it in Vol II. Boost::test has a special function to - * handle that. One of the problems is that when more bits are allocated to - * the integer part as the number gets bigger, there is inherently less - * precision in the decimals. But, for comparisons to zero, we can continue - * using an absolute epsilon (instead of multiplying epsilon by the number). - * In our application, we are anticipating numbers mostly between 0 and 1, - * because they represent probabilities. - * - * Not clear why this is namespace std instead of nupic , but FA says there was - * a "good, ugly" reason to do it this way that he can't remember. - WCS 0206 - */ - template - inline bool nearlyZero(const T& a, const T& epsilon =T(nupic::Epsilon)) - { - return a >= -epsilon && a <= epsilon; +/** + * Lexicographic order: + * (1,1) < (1,2) < (1,10) < (2,5) < (3,6) < (3,7) ... + */ +template +struct lexicographic_2 + : public std::binary_function, std::pair> { + inline bool operator()(const std::pair &a, + const std::pair &b) const { + if (a.first < b.first) + return true; + else if (a.first == b.first) + if (a.second < b.second) + return true; + return false; } +}; - //-------------------------------------------------------------------------------- - template - inline bool nearlyEqual(const T& a, const T& b, const T& epsilon =nupic::Epsilon) - { - return nearlyZero((b-a), epsilon); +//-------------------------------------------------------------------------------- +/** + * Order based on the first member of a pair only: + * (1, 3.5) < (2, 5.6) < (10, 7.1) < (11, 8.5) + */ +template +struct less_1st + : public std::binary_function, std::pair> { + inline bool operator()(const std::pair &a, + const std::pair &b) const { + return a.first < b.first; } +}; - //-------------------------------------------------------------------------------- - /** - * Euclidean modulo function. - * - * Returns x % m, but keeps the value positive - * (similar to Python's modulo function). 
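As a quick illustration of the Euclidean modulo defined above, which keeps the result in [0, m) even for negative x (again assuming the nupic/math/Math.hpp include path shown in this diff):

#include <cassert>
#include <nupic/math/Math.hpp>

int main() {
  assert(-3 % 5 == -3);            // built-in % truncates toward zero (C++11)
  assert(nupic::emod(-3, 5) == 2); // emod wraps the result into [0, 5)
  assert(nupic::emod(7, 5) == 2);  // positive inputs behave like %
  return 0;
}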
- */ - inline int emod(int x, int m) { - int r = x % m; - if (r < 0) return r + m; - else return r; +//-------------------------------------------------------------------------------- +/** + * Order based on the second member of a pair only: + * (10, 3.5) < (1, 5.6) < (2, 7.1) < (11, 8.5) + */ +template +struct less_2nd + : public std::binary_function, std::pair> { + inline bool operator()(const std::pair &a, + const std::pair &b) const { + return a.second < b.second; } - - //-------------------------------------------------------------------------------- - /** - * A boolean functor that returns true if the element's selected value is found - * in the (associative) container (needs to support find). - * - * Example: - * ======= - * typedef std::pair IdxVal; - * std::list row; - * std::set alreadyGrouped; - * row.remove_if(SetIncludes, select1st >(alreadyGrouped)); - * - * Will remove from row (a list of pairs) the pairs whose first element is - * already contained in alreadyGrouped. - */ - template - struct IsIncluded - { - IsIncluded(const C1& container) - : sel_(), - container_(container) - {} - - template - inline bool operator()(const T& p) const - { - if (f) - return container_.find(sel_(p)) == container_.end(); - else - return container_.find(sel_(p)) != container_.end(); - } - - Selector sel_; - const C1& container_; - }; +}; - //-------------------------------------------------------------------------------- - // PAIRS AND TRIPLETS - //-------------------------------------------------------------------------------- +//-------------------------------------------------------------------------------- +/** + * Order based on the first member of a pair only: + * (10, 3.5) > (8, 5.6) > (2, 7.1) > (1, 8.5) + */ +template +struct greater_1st + : public std::binary_function, std::pair> { + inline bool operator()(const std::pair &a, + const std::pair &b) const { + return a.first > b.first; + } +}; - /** - * Lexicographic order: - * (1,1) < (1,2) < (1,10) < (2,5) < (3,6) < (3,7) ... - */ - template - struct lexicographic_2 - : public std::binary_function, std::pair > - { - inline bool operator()(const std::pair& a, const std::pair& b) const - { +//-------------------------------------------------------------------------------- +/** + * Order based on the second member of a pair only: + * (10, 3.5) > (1, 5.6) > (2, 7.1) > (11, 8.5) + */ +template +struct greater_2nd + : public std::binary_function, std::pair> { + inline bool operator()(const std::pair &a, + const std::pair &b) const { + return a.second > b.second; + } +}; + +//-------------------------------------------------------------------------------- +template +struct greater_2nd_p : public std::binary_function, + std::pair> { + inline bool operator()(const std::pair &a, + const std::pair &b) const { + return *(a.second) > *(b.second); + } +}; + +//-------------------------------------------------------------------------------- +/** + * A greater 2nd order that breaks ties, useful for debugging. 
+ */ +template +struct greater_2nd_no_ties + : public std::binary_function, std::pair> { + inline bool operator()(const std::pair &a, + const std::pair &b) const { + if (a.second > b.second) + return true; + else if (a.second == b.second) if (a.first < b.first) return true; - else if (a.first == b.first) - if (a.second < b.second) - return true; - return false; - } - }; + return false; + } +}; - //-------------------------------------------------------------------------------- - /** - * Order based on the first member of a pair only: - * (1, 3.5) < (2, 5.6) < (10, 7.1) < (11, 8.5) - */ - template - struct less_1st - : public std::binary_function, std::pair > - { - inline bool operator()(const std::pair& a, const std::pair& b) const - { - return a.first < b.first; - } - }; +//-------------------------------------------------------------------------------- +template +struct greater_2nd_rnd_ties + : public std::binary_function, std::pair> { + RND &rng; + + inline greater_2nd_rnd_ties(RND &_rng) : rng(_rng) {} + + inline bool operator()(const std::pair &a, + const std::pair &b) const { + T2 val_a = a.second, val_b = b.second; + if (val_a > val_b) + return true; + else if (val_a == val_b) + return rng.getReal64() >= .5; + return false; + } +}; - //-------------------------------------------------------------------------------- - /** - * Order based on the second member of a pair only: - * (10, 3.5) < (1, 5.6) < (2, 7.1) < (11, 8.5) - */ - template - struct less_2nd - : public std::binary_function, std::pair > - { - inline bool operator()(const std::pair& a, const std::pair& b) const - { - return a.second < b.second; - } - }; +//-------------------------------------------------------------------------------- +// A class used to work with lists of non-zeros represented in i,j,v format +//-------------------------------------------------------------------------------- +/** + * This class doesn't implement any algorithm, it just stores i,j and v. 
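A hedged sketch of how the ijv triplet below and its nested lexicographic comparator can be used to keep a list of sparse non-zeros sorted by (i, j); it assumes the two template parameters are a size type and a value type (nupic::UInt and nupic::Real here) and the usual nupic include layout:

#include <algorithm>
#include <vector>
#include <nupic/math/Math.hpp>

int main() {
  typedef nupic::ijv<nupic::UInt, nupic::Real> IJV;

  std::vector<IJV> nz;
  nz.push_back(IJV(1, 0, 0.50f)); // row 1, col 0, value 0.50
  nz.push_back(IJV(0, 2, 0.25f)); // row 0, col 2, value 0.25
  nz.push_back(IJV(0, 1, 0.75f)); // row 0, col 1, value 0.75

  // Sort by row, then column, using the comparator nested in ijv.
  std::sort(nz.begin(), nz.end(), IJV::lexicographic());

  // nz is now ordered (0,1), (0,2), (1,0).
  return 0;
}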
+ */ +template class ijv { + typedef T1 size_type; + typedef T2 value_type; - //-------------------------------------------------------------------------------- - /** - * Order based on the first member of a pair only: - * (10, 3.5) > (8, 5.6) > (2, 7.1) > (1, 8.5) - */ - template - struct greater_1st - : public std::binary_function, std::pair > - { - inline bool operator()(const std::pair& a, const std::pair& b) const - { - return a.first > b.first; - } - }; +private: + size_type i_, j_; + value_type v_; - //-------------------------------------------------------------------------------- - /** - * Order based on the second member of a pair only: - * (10, 3.5) > (1, 5.6) > (2, 7.1) > (11, 8.5) - */ - template - struct greater_2nd - : public std::binary_function, std::pair > - { - inline bool operator()(const std::pair& a, const std::pair& b) const - { - return a.second > b.second; - } - }; +public: + inline ijv() : i_(0), j_(0), v_(0) {} - //-------------------------------------------------------------------------------- - template - struct greater_2nd_p - : public std::binary_function, std::pair > - { - inline bool operator()(const std::pair& a, const std::pair& b) const - { - return *(a.second) > *(b.second); - } - }; + inline ijv(size_type i, size_type j, value_type v) : i_(i), j_(j), v_(v) {} + + inline ijv(const ijv &o) : i_(o.i_), j_(o.j_), v_(o.v_) {} + + inline ijv &operator=(const ijv &o) { + i_ = o.i_; + j_ = o.j_; + v_ = o.v_; + return *this; + } + + inline size_type i() const { return i_; } + inline size_type j() const { return j_; } + inline value_type v() const { return v_; } + inline void i(size_type ii) { i_ = ii; } + inline void j(size_type jj) { j_ = jj; } + inline void v(value_type vv) { v_ = vv; } //-------------------------------------------------------------------------------- /** - * A greater 2nd order that breaks ties, useful for debugging. + * See just above for definition. */ - template - struct greater_2nd_no_ties - : public std::binary_function, std::pair > - { - inline bool operator()(const std::pair& a, const std::pair& b) const - { - if (a.second > b.second) + struct lexicographic : public std::binary_function { + inline bool operator()(const ijv &a, const ijv &b) const { + if (a.i() < b.i()) return true; - else if (a.second == b.second) - if (a.first < b.first) + else if (a.i() == b.i()) + if (a.j() < b.j()) return true; return false; } }; //-------------------------------------------------------------------------------- - template - struct greater_2nd_rnd_ties - : public std::binary_function, std::pair > - { - RND& rng; - - inline greater_2nd_rnd_ties(RND& _rng) : rng(_rng) {} - - inline bool operator()(const std::pair& a, const std::pair& b) const - { - T2 val_a = a.second, val_b = b.second; - if (val_a > val_b) - return true; - else if (val_a == val_b) - return rng.getReal64() >= .5; - return false; + /** + * See just above for definition. + */ + struct less_value : public std::binary_function { + inline bool operator()(const ijv &a, const ijv &b) const { + return a.v() < b.v(); } }; - //-------------------------------------------------------------------------------- - // A class used to work with lists of non-zeros represented in i,j,v format //-------------------------------------------------------------------------------- /** - * This class doesn't implement any algorithm, it just stores i,j and v. + * See just above for definition. 
*/ - template - class ijv - { - typedef T1 size_type; - typedef T2 value_type; - - private: - size_type i_, j_; - value_type v_; - - public: - inline ijv() - : i_(0), j_(0), v_(0) {} - - inline ijv(size_type i, size_type j, value_type v) - : i_(i), j_(j), v_(v) {} - - inline ijv(const ijv& o) - : i_(o.i_), j_(o.j_), v_(o.v_) {} - - inline ijv& operator=(const ijv& o) - { - i_ = o.i_; - j_ = o.j_; - v_ = o.v_; - return *this; + struct greater_value : public std::binary_function { + inline bool operator()(const ijv &a, const ijv &b) const { + return a.v() > b.v(); } - - inline size_type i() const { return i_; } - inline size_type j() const { return j_; } - inline value_type v() const { return v_; } - inline void i(size_type ii) { i_ = ii; } - inline void j(size_type jj) { j_ = jj; } - inline void v(value_type vv) { v_ = vv; } - - //-------------------------------------------------------------------------------- - /** - * See just above for definition. - */ - struct lexicographic : public std::binary_function - { - inline bool operator()(const ijv& a, const ijv& b) const - { - if (a.i() < b.i()) - return true; - else if (a.i() == b.i()) - if (a.j() < b.j()) - return true; - return false; - } - }; - - //-------------------------------------------------------------------------------- - /** - * See just above for definition. - */ - struct less_value : public std::binary_function - { - inline bool operator()(const ijv& a, const ijv& b) const - { - return a.v() < b.v(); - } - }; - - //-------------------------------------------------------------------------------- - /** - * See just above for definition. - */ - struct greater_value : public std::binary_function - { - inline bool operator()(const ijv& a, const ijv& b) const - { - return a.v() > b.v(); - } - }; }; - - //-------------------------------------------------------------------------------- - /** - * These templates allow the implementation to use the right function depending on - * the data type, which yield significant speed-ups when working with floats. - * They also extend STL's arithmetic operations. - */ - //-------------------------------------------------------------------------------- - // Unary functions - //-------------------------------------------------------------------------------- +}; - template - struct Identity : public std::unary_function - { - inline T& operator()(T& x) const { return x; } - inline const T& operator()(const T& x) const { return x; } - }; +//-------------------------------------------------------------------------------- +/** + * These templates allow the implementation to use the right function depending + * on the data type, which yield significant speed-ups when working with floats. + * They also extend STL's arithmetic operations. + */ +//-------------------------------------------------------------------------------- +// Unary functions +//-------------------------------------------------------------------------------- - template - struct Negate : public std::unary_function - { - inline T operator()(const T& x) const { return -x; } - }; +template struct Identity : public std::unary_function { + inline T &operator()(T &x) const { return x; } + inline const T &operator()(const T &x) const { return x; } +}; - template - struct Abs : public std::unary_function - { - inline T operator()(const T& x) const { return x > 0.0 ? 
x : -x; } - }; - - template - struct Square : public std::unary_function - { - inline T operator()(const T& x) const { return x*x; } - }; +template struct Negate : public std::unary_function { + inline T operator()(const T &x) const { return -x; } +}; - template - struct Cube : public std::unary_function - { - inline T operator()(const T& x) const { return x*x*x; } - }; +template struct Abs : public std::unary_function { + inline T operator()(const T &x) const { return x > 0.0 ? x : -x; } +}; - template - struct Inverse : public std::unary_function - { - inline T operator()(const T& x) const { return 1.0/x; } - }; +template struct Square : public std::unary_function { + inline T operator()(const T &x) const { return x * x; } +}; - template - struct Sqrt : public std::unary_function - {}; +template struct Cube : public std::unary_function { + inline T operator()(const T &x) const { return x * x * x; } +}; - template <> - struct Sqrt : public std::unary_function - { - inline float operator()(const float& x) const { return sqrtf(x); } - }; +template struct Inverse : public std::unary_function { + inline T operator()(const T &x) const { return 1.0 / x; } +}; - template <> - struct Sqrt : public std::unary_function - { - inline double operator()(const double& x) const { return sqrt(x); } - }; +template struct Sqrt : public std::unary_function {}; - template - struct Exp : public std::unary_function - {}; - - template <> - struct Exp : public std::unary_function - { - // On x86_64, there is a bug in glibc that makes expf very slow - // (more than it should be), so we continue using exp on that - // platform as a workaround. - // https://bugzilla.redhat.com/show_bug.cgi?id=521190 - // To force the compiler to use exp instead of expf, the return - // type (and not the argument type!) needs to be double. - inline float operator()(const float& x) const { return expf(x); } - }; +template <> struct Sqrt : public std::unary_function { + inline float operator()(const float &x) const { return sqrtf(x); } +}; - template <> - struct Exp : public std::unary_function - { - inline double operator()(const double& x) const { return exp(x); } - }; +template <> struct Sqrt : public std::unary_function { + inline double operator()(const double &x) const { return sqrt(x); } +}; - template - struct Log : public std::unary_function - {}; +template struct Exp : public std::unary_function {}; - template <> - struct Log : public std::unary_function - { - inline float operator()(const float& x) const { return logf(x); } - }; +template <> struct Exp : public std::unary_function { + // On x86_64, there is a bug in glibc that makes expf very slow + // (more than it should be), so we continue using exp on that + // platform as a workaround. + // https://bugzilla.redhat.com/show_bug.cgi?id=521190 + // To force the compiler to use exp instead of expf, the return + // type (and not the argument type!) needs to be double. 
+ inline float operator()(const float &x) const { return expf(x); } +}; - template <> - struct Log : public std::unary_function - { - inline double operator()(const double& x) const { return log(x); } - }; +template <> struct Exp : public std::unary_function { + inline double operator()(const double &x) const { return exp(x); } +}; + +template struct Log : public std::unary_function {}; - template - struct Log2 : public std::unary_function - {}; +template <> struct Log : public std::unary_function { + inline float operator()(const float &x) const { return logf(x); } +}; - template <> - struct Log2 : public std::unary_function - { - inline float operator()(const float& x) const - { +template <> struct Log : public std::unary_function { + inline double operator()(const double &x) const { return log(x); } +}; + +template struct Log2 : public std::unary_function {}; + +template <> struct Log2 : public std::unary_function { + inline float operator()(const float &x) const { #if defined(NTA_OS_WINDOWS) - return (float) (log(x) / log(2.0)); + return (float)(log(x) / log(2.0)); #else - return log2f(x); + return log2f(x); #endif - } - }; + } +}; - template <> - struct Log2 : public std::unary_function - { - inline double operator()(const double& x) const - { +template <> struct Log2 : public std::unary_function { + inline double operator()(const double &x) const { #if defined(NTA_OS_WINDOWS) - return log(x) / log(2.0); + return log(x) / log(2.0); #else - return log2(x); + return log2(x); #endif - } - }; + } +}; - template - struct Log10 : public std::unary_function - {}; +template struct Log10 : public std::unary_function {}; - template <> - struct Log10 : public std::unary_function - { - inline float operator()(const float& x) const - { +template <> struct Log10 : public std::unary_function { + inline float operator()(const float &x) const { #if defined(NTA_OS_WINDOWS) - return (float) (log(x) / log(10.0)); + return (float)(log(x) / log(10.0)); #else - return log10f(x); + return log10f(x); #endif - } - }; + } +}; - template <> - struct Log10 : public std::unary_function - { - inline double operator()(const double& x) const - { +template <> struct Log10 : public std::unary_function { + inline double operator()(const double &x) const { #if defined(NTA_OS_WINDOWS) - return log(x) / log(10.0); + return log(x) / log(10.0); #else - return log10(x); + return log10(x); #endif - } - }; + } +}; - template - struct Log1p : public std::unary_function - {}; +template struct Log1p : public std::unary_function {}; - template <> - struct Log1p : public std::unary_function - { - inline float operator()(const float& x) const - { +template <> struct Log1p : public std::unary_function { + inline float operator()(const float &x) const { #if defined(NTA_OS_WINDOWS) - return (float) log(1.0 + x); + return (float)log(1.0 + x); #else - return log1pf(x); + return log1pf(x); #endif - } - }; + } +}; - template <> - struct Log1p : public std::unary_function - { - inline double operator()(const double& x) const - { +template <> struct Log1p : public std::unary_function { + inline double operator()(const double &x) const { #if defined(NTA_OS_WINDOWS) - return log(1.0 + x); + return log(1.0 + x); #else - return log1p(x); + return log1p(x); #endif - } - }; + } +}; + +/** + * Numerical approximation of derivative. + * Error is h^4 y^5/30. + */ +template +struct Derivative : public std::unary_function { + Derivative(const F &f) : f_(f) {} + + F f_; /** - * Numerical approximation of derivative. - * Error is h^4 y^5/30. 
+ * Approximates the derivative of F at x. */ - template - struct Derivative : public std::unary_function - { - Derivative(const F& f) : f_(f) {} - - F f_; - - /** - * Approximates the derivative of F at x. - */ - inline const Float operator()(const Float& x) const - { - const Float h = nupic::Epsilon; - return (-f_(x+2*h)+8*f_(x+h)-8*f_(x-h)+f_(x-2*h))/(12*h); - } - }; + inline const Float operator()(const Float &x) const { + const Float h = nupic::Epsilon; + return (-f_(x + 2 * h) + 8 * f_(x + h) - 8 * f_(x - h) + f_(x - 2 * h)) / + (12 * h); + } +}; - //-------------------------------------------------------------------------------- - // Binary functions - //-------------------------------------------------------------------------------- - template - struct Assign : public std::binary_function - { - inline T operator()(T& x, const T& y) const { x = y; return x; } - }; +//-------------------------------------------------------------------------------- +// Binary functions +//-------------------------------------------------------------------------------- +template struct Assign : public std::binary_function { + inline T operator()(T &x, const T &y) const { + x = y; + return x; + } +}; - template - struct Plus : public std::binary_function - { - inline T operator()(const T& x, const T& y) const { return x + y; } - }; +template struct Plus : public std::binary_function { + inline T operator()(const T &x, const T &y) const { return x + y; } +}; - template - struct Minus : public std::binary_function - { - inline T operator()(const T& x, const T& y) const { return x - y; } - }; +template struct Minus : public std::binary_function { + inline T operator()(const T &x, const T &y) const { return x - y; } +}; - template - struct Multiplies : public std::binary_function - { - inline T operator()(const T& x, const T& y) const { return x * y; } - }; +template struct Multiplies : public std::binary_function { + inline T operator()(const T &x, const T &y) const { return x * y; } +}; - template - struct Divides : public std::binary_function - { - inline T operator()(const T& x, const T& y) const { return x / y; } - }; +template struct Divides : public std::binary_function { + inline T operator()(const T &x, const T &y) const { return x / y; } +}; - template - struct Pow : public std::binary_function - {}; +template struct Pow : public std::binary_function {}; - template <> - struct Pow : public std::binary_function - { - inline float operator()(const float& x, const float& y) const - { - return powf(x,y); - } - }; +template <> +struct Pow : public std::binary_function { + inline float operator()(const float &x, const float &y) const { + return powf(x, y); + } +}; - template <> - struct Pow : public std::binary_function - { - inline double operator()(const double& x, const double& y) const - { - return pow(x,y); - } - }; +template <> +struct Pow : public std::binary_function { + inline double operator()(const double &x, const double &y) const { + return pow(x, y); + } +}; - template - struct Logk : public std::binary_function - {}; +template struct Logk : public std::binary_function {}; - template <> - struct Logk : public std::binary_function - { - inline float operator()(const float& x, const float& y) const - { - return logf(x)/logf(y); - } - }; +template <> +struct Logk : public std::binary_function { + inline float operator()(const float &x, const float &y) const { + return logf(x) / logf(y); + } +}; - template <> - struct Logk : public std::binary_function - { - inline double operator()(const 
double& x, const double& y) const - { - return log(x)/log(y); - } - }; +template <> +struct Logk : public std::binary_function { + inline double operator()(const double &x, const double &y) const { + return log(x) / log(y); + } +}; - template - struct Max : public std::binary_function - { - inline T operator()(const T& x, const T& y) const { return x > y ? x : y; } - }; +template struct Max : public std::binary_function { + inline T operator()(const T &x, const T &y) const { return x > y ? x : y; } +}; - template - struct Min : public std::binary_function - { - inline T operator()(const T& x, const T& y) const { return x < y ? x : y; } - }; +template struct Min : public std::binary_function { + inline T operator()(const T &x, const T &y) const { return x < y ? x : y; } +}; - //-------------------------------------------------------------------------------- - /** - * Gaussian: - * y = 1/(sigma * sqrt(2*pi)) * exp(-(x-mu)^2/(2*sigma^2)) as a functor. - */ - template - struct Gaussian : public std::unary_function - { - T k1, k2, mu; - - inline Gaussian(T m, T s) - : k1(0.0), - k2(0.0), - mu(m) - { - // For some reason, SWIG cannot parse 1 / (x), the parentheses in the - // denominator don't agree with it, so we have to initialize those - // constants here. - k1 = 1.0 / sqrt(2.0 * 3.1415926535); - k2 = -1.0 / (2.0 * s * s); +//-------------------------------------------------------------------------------- +/** + * Gaussian: + * y = 1/(sigma * sqrt(2*pi)) * exp(-(x-mu)^2/(2*sigma^2)) as a functor. + */ +template struct Gaussian : public std::unary_function { + T k1, k2, mu; + + inline Gaussian(T m, T s) : k1(0.0), k2(0.0), mu(m) { + // For some reason, SWIG cannot parse 1 / (x), the parentheses in the + // denominator don't agree with it, so we have to initialize those + // constants here. + k1 = 1.0 / sqrt(2.0 * 3.1415926535); + k2 = -1.0 / (2.0 * s * s); + } + + inline Gaussian(const Gaussian &o) : k1(o.k1), k2(o.k2), mu(o.mu) {} + + inline Gaussian &operator=(const Gaussian &o) { + if (&o != this) { + k1 = o.k1; + k2 = o.k2; + mu = o.mu; } - inline Gaussian(const Gaussian& o) - : k1(o.k1), k2(o.k2), mu(o.mu) - {} + return *this; + } - inline Gaussian& operator=(const Gaussian& o) - { - if (&o != this) { - k1 = o.k1; - k2 = o.k2; - mu = o.mu; - } + inline T operator()(T x) const { + T v = x - mu; + return k1 * exp(k2 * v * v); + } +}; - return *this; - } +//-------------------------------------------------------------------------------- +/** + * 2D Gaussian + */ +template +struct Gaussian2D // : public std::binary_function (SWIG pb) +{ + T c_x, c_y, s00, s01, s10, s11, s2, k1; + + inline Gaussian2D(T c_x_, T c_y_, T s00_, T s01_, T s10_, T s11_) + : c_x(c_x_), c_y(c_y_), s00(s00_), s01(s01_), s10(s10_), s11(s11_), + s2(s10 + s01), k1(0.0) { + // For some reason, SWIG cannot parse 1 / (x), the parentheses in the + // denominator don't agree with it, so we have to initialize those + // constants here. 
+ k1 = 1.0 / (2.0 * 3.1415926535 * sqrt(s00 * s11 - s10 * s01)); + T d = -2.0 * (s00 * s11 - s10 * s01); + s00 /= d; + s01 /= d; + s10 /= d; + s11 /= d; + s2 /= d; + } - inline T operator()(T x) const - { - T v = x - mu; - return k1 * exp(k2 * v * v); - } - }; + inline Gaussian2D(const Gaussian2D &o) + : c_x(o.c_x), c_y(o.c_y), s00(o.s00), s01(o.s01), s10(o.s10), s11(o.s11), + s2(o.s2), k1(o.k1) {} - //-------------------------------------------------------------------------------- - /** - * 2D Gaussian - */ - template - struct Gaussian2D // : public std::binary_function (SWIG pb) - { - T c_x, c_y, s00, s01, s10, s11, s2, k1; - - inline Gaussian2D(T c_x_, T c_y_, T s00_, T s01_, T s10_, T s11_) - : c_x(c_x_), c_y(c_y_), - s00(s00_), s01(s01_), s10(s10_), s11(s11_), s2(s10 + s01), - k1(0.0) - { - // For some reason, SWIG cannot parse 1 / (x), the parentheses in the - // denominator don't agree with it, so we have to initialize those - // constants here. - k1 = 1.0/(2.0 * 3.1415926535 * sqrt(s00 * s11 - s10 * s01)); - T d = -2.0 * (s00 * s11 - s10 * s01); - s00 /= d; - s01 /= d; - s10 /= d; - s11 /= d; - s2 /= d; + inline Gaussian2D &operator=(const Gaussian2D &o) { + if (&o != this) { + c_x = o.c_x; + c_y = o.c_y; + s00 = o.s00; + s01 = o.s01; + s10 = o.s10; + s11 = o.s11; + s2 = o.s2; + k1 = o.k1; } - inline Gaussian2D(const Gaussian2D& o) - : c_x(o.c_x), c_y(o.c_y), - s00(o.s00), s01(o.s01), s10(o.s10), s11(o.s11), s2(o.s2), - k1(o.k1) - {} - - inline Gaussian2D& operator=(const Gaussian2D& o) - { - if (&o != this) { - c_x = o.c_x; c_y = o.c_y; - s00 = o.s00; s01 = o.s01; s10 = o.s10; s11 = o.s11; - s2 = o.s2; - k1 = o.k1; - } - - return *this; - } - - inline T operator()(T x, T y) const - { - T v0 = x - c_x, v1 = y - c_y; - return k1 * exp(s11 * v0 * v0 + s2 * v0 * v1 + s00 * v1 * v1); - } - }; + return *this; + } - //-------------------------------------------------------------------------------- - /** - * Compose two unary functions. - */ - template - struct unary_compose - : public std::unary_function - { - typedef typename F1::argument_type argument_type; - typedef typename F2::result_type result_type; - - F1 f1; - F2 f2; - - inline result_type operator()(const argument_type& x) const - { - return f2(f1(x)); - } - }; + inline T operator()(T x, T y) const { + T v0 = x - c_x, v1 = y - c_y; + return k1 * exp(s11 * v0 * v0 + s2 * v0 * v1 + s00 * v1 * v1); + } +}; - //-------------------------------------------------------------------------------- - /** - * Compose an order predicate and a binary selector, so that we can write: - * sort(x.begin(), x.end(), compose, select2nd > >()); - * to sort pairs in increasing order of their second element. - */ - template - struct predicate_compose - : public std::binary_function - { - typedef bool result_type; - typedef typename S::argument_type argument_type; - - O o; - S s; - - inline result_type operator()(const argument_type& x, const argument_type& y) const - { - return o(s(x), s(y)); - } - }; +//-------------------------------------------------------------------------------- +/** + * Compose two unary functions. + */ +template +struct unary_compose : public std::unary_function { + typedef typename F1::argument_type argument_type; + typedef typename F2::result_type result_type; - //-------------------------------------------------------------------------------- - /** - * When dividing by a value less than min_exponent10, inf will be generated. 
- * numeric_limits::min_exponent10 = -37 - * numeric_limits::min_exponent10 = -307 - */ - template - inline bool isSafeForDivision(const T& x) - { - Log log_f; - return log_f(x) >= std::numeric_limits::min_exponent10; + F1 f1; + F2 f2; + + inline result_type operator()(const argument_type &x) const { + return f2(f1(x)); } +}; - //-------------------------------------------------------------------------------- - /** - * Returns the value passed in or a threshold if the value is >= threshold. - */ - template - struct ClipAbove : public std::unary_function - { - inline ClipAbove(const T& val) - : val_(val) - {} - - inline ClipAbove(const ClipAbove& c) - : val_(c.val_) - {} - - inline ClipAbove& operator=(const ClipAbove& c) - { - if (this != &c) - val_ = c.val_; - - return *this; - } +//-------------------------------------------------------------------------------- +/** + * Compose an order predicate and a binary selector, so that we can write: + * sort(x.begin(), x.end(), compose, select2nd > + * >()); to sort pairs in increasing order of their second element. + */ +template +struct predicate_compose + : public std::binary_function { + typedef bool result_type; + typedef typename S::argument_type argument_type; + + O o; + S s; + + inline result_type operator()(const argument_type &x, + const argument_type &y) const { + return o(s(x), s(y)); + } +}; - inline T operator()(const T& x) const - { - return x >= val_ ? val_ : x; - } +//-------------------------------------------------------------------------------- +/** + * When dividing by a value less than min_exponent10, inf will be generated. + * numeric_limits::min_exponent10 = -37 + * numeric_limits::min_exponent10 = -307 + */ +template inline bool isSafeForDivision(const T &x) { + Log log_f; + return log_f(x) >= std::numeric_limits::min_exponent10; +} - T val_; - }; +//-------------------------------------------------------------------------------- +/** + * Returns the value passed in or a threshold if the value is >= threshold. + */ +template struct ClipAbove : public std::unary_function { + inline ClipAbove(const T &val) : val_(val) {} - //-------------------------------------------------------------------------------- - /** - * Returns the value passed in or a threshold if the value is < threshold. - */ - template - struct ClipBelow : public std::unary_function - { - inline ClipBelow(const T& val) - : val_(val) - {} - - inline ClipBelow(const ClipBelow& c) - : val_(c.val_) - {} - - inline ClipBelow& operator=(const ClipBelow& c) - { - if (this != &c) - val_ = c.val_; - - return *this; - } + inline ClipAbove(const ClipAbove &c) : val_(c.val_) {} - inline T operator()(const T& x) const - { - return x < val_ ? val_ : x; - } + inline ClipAbove &operator=(const ClipAbove &c) { + if (this != &c) + val_ = c.val_; - T val_; - }; + return *this; + } - //-------------------------------------------------------------------------------- + inline T operator()(const T &x) const { return x >= val_ ? val_ : x; } + + T val_; +}; + +//-------------------------------------------------------------------------------- +/** + * Returns the value passed in or a threshold if the value is < threshold. + */ +template struct ClipBelow : public std::unary_function { + inline ClipBelow(const T &val) : val_(val) {} + + inline ClipBelow(const ClipBelow &c) : val_(c.val_) {} + + inline ClipBelow &operator=(const ClipBelow &c) { + if (this != &c) + val_ = c.val_; + + return *this; + } + + inline T operator()(const T &x) const { return x < val_ ? 
val_ : x; } + + T val_; +}; + +//-------------------------------------------------------------------------------- }; // namespace nupic diff --git a/src/nupic/math/NearestNeighbor.hpp b/src/nupic/math/NearestNeighbor.hpp index 8da43eb0e7..ade3654f60 100644 --- a/src/nupic/math/NearestNeighbor.hpp +++ b/src/nupic/math/NearestNeighbor.hpp @@ -20,1174 +20,1096 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definition for NearestNeighbor */ #ifndef NTA_NEAREST_NEIGHBOR_HPP #define NTA_NEAREST_NEIGHBOR_HPP -#include #include +#include //---------------------------------------------------------------------- namespace nupic { - template - class NearestNeighbor : public T - { - public: - typedef T parent_type; - typedef NearestNeighbor self_type; - - typedef typename parent_type::size_type size_type; - typedef typename parent_type::difference_type difference_type; - typedef typename parent_type::value_type value_type; - typedef typename parent_type::prec_value_type prec_value_type; - - //-------------------------------------------------------------------------------- - // CONSTRUCTORS - //-------------------------------------------------------------------------------- - inline NearestNeighbor() - : parent_type() - {} - - //-------------------------------------------------------------------------------- - /** - * Constructor with a number of columns and a hint for the number - * of rows. The SparseMatrix is empty. - * - * @param nrows [size_type >= 0] number of rows - * @param ncols [size_type >= 0] number of columns - * - * @b Exceptions: - * @li nrows < 0 (check) - * @li ncols < 0 (check) - * @li Not enough memory (error) - */ - inline NearestNeighbor(size_type nrows, size_type ncols) - : parent_type(nrows, ncols) - {} - - //-------------------------------------------------------------------------------- - /** - * Constructor from a dense matrix passed as an array of value_type. - * Uses the values in mat to initialize the SparseMatrix. - * - * @param nrows [size_type >= 0] number of rows - * @param ncols [size_type >= 0] number of columns - * @param dense [value_type** != NULL] initial array of values - * - * @b Exceptions: - * @li nrows <= 0 (check) - * @li ncols <= 0 (check) - * @li mat == NULL (check) - * @li NULL pointer in mat (check) - * @li Not enough memory (error) - */ - template - inline NearestNeighbor(size_type nrows, size_type ncols, InputIterator dense) - : parent_type(nrows, ncols, dense) - {} - - //-------------------------------------------------------------------------------- - /** - * Constructor from a stream in CSR format (don't forget number of bytes after - * 'csr' tag!). - */ - inline NearestNeighbor(std::istream& inStream) - : parent_type(inStream) - {} - - //-------------------------------------------------------------------------------- - /** - * Copy constructor. - * - * TODO copy part of a matrix? - * - * Copies the given NearestNeighbor into this one. - */ - inline NearestNeighbor(const NearestNeighbor& other) - : parent_type(other) - {} - - //-------------------------------------------------------------------------------- - /** - * Assignment operator. 
- */ - inline NearestNeighbor& operator=(const NearestNeighbor& other) - { - parent_type::operator=(other); - return *this; - } +template class NearestNeighbor : public T { +public: + typedef T parent_type; + typedef NearestNeighbor self_type; - //-------------------------------------------------------------------------------- - /** - * A method that computes the powers of x and their sum, according to f. - */ - private: - template - inline void - compute_powers_(value_type& Sp_x, value_type* p_x, InputIterator x, F f) - { - const size_type ncols = this->nCols(); - InputIterator end1 = x + 4*(ncols/4), end2 = x + ncols; - Sp_x = (value_type) 0; - - for (; x != end1; x += 4, p_x += 4) { - *p_x = f(Sp_x, *x); - *(p_x+1) = f(Sp_x, *(x+1)); - *(p_x+2) = f(Sp_x, *(x+2)); - *(p_x+3) = f(Sp_x, *(x+3)); - } + typedef typename parent_type::size_type size_type; + typedef typename parent_type::difference_type difference_type; + typedef typename parent_type::value_type value_type; + typedef typename parent_type::prec_value_type prec_value_type; - for (; x != end2; ++x) - *p_x++ = f(Sp_x, *x); - } + //-------------------------------------------------------------------------------- + // CONSTRUCTORS + //-------------------------------------------------------------------------------- + inline NearestNeighbor() : parent_type() {} - //-------------------------------------------------------------------------------- - /** - * A method that computes the sum of powers of the difference between x - * and a given row. - */ - private: - template - inline value_type - sum_of_p_diff_(size_type row, InputIterator x, value_type Sp_x, value_type *p_x, - F f) const - { - size_type nnzr = this->nnzr_[row], j, *ind = this->ind_[row]; - value_type *nz = this->nz_[row]; - value_type val = Sp_x, val1 = 0, val2 = 0; - size_type *end1 = ind + 4*(nnzr/4), *end2 = ind + nnzr; - - while (ind != end1) { - j = *ind++; - val1 = *nz++ - x[j]; - f(val, val1); - val -= p_x[j]; - j = *ind++; - val2 = *nz++ - x[j]; - f(val, val2); - val -= p_x[j]; - j = *ind++; - val1 = *nz++ - x[j]; - f(val, val1); - val -= p_x[j]; - j = *ind++; - val2 = *nz++ - x[j]; - f(val, val2); - val -= p_x[j]; - } - - while (ind != end2) { - j = *ind++; - val1 = *nz++ - x[j]; - f(val, val1); - val -= p_x[j]; - } + //-------------------------------------------------------------------------------- + /** + * Constructor with a number of columns and a hint for the number + * of rows. The SparseMatrix is empty. + * + * @param nrows [size_type >= 0] number of rows + * @param ncols [size_type >= 0] number of columns + * + * @b Exceptions: + * @li nrows < 0 (check) + * @li ncols < 0 (check) + * @li Not enough memory (error) + */ + inline NearestNeighbor(size_type nrows, size_type ncols) + : parent_type(nrows, ncols) {} - // Accuracy issues because of the subtractions, - // could return negative values - if (val <= (value_type) 0) - val = (value_type) 0; + //-------------------------------------------------------------------------------- + /** + * Constructor from a dense matrix passed as an array of value_type. + * Uses the values in mat to initialize the SparseMatrix. 
+ * + * @param nrows [size_type >= 0] number of rows + * @param ncols [size_type >= 0] number of columns + * @param dense [value_type** != NULL] initial array of values + * + * @b Exceptions: + * @li nrows <= 0 (check) + * @li ncols <= 0 (check) + * @li mat == NULL (check) + * @li NULL pointer in mat (check) + * @li Not enough memory (error) + */ + template + inline NearestNeighbor(size_type nrows, size_type ncols, InputIterator dense) + : parent_type(nrows, ncols, dense) {} + + //-------------------------------------------------------------------------------- + /** + * Constructor from a stream in CSR format (don't forget number of bytes after + * 'csr' tag!). + */ + inline NearestNeighbor(std::istream &inStream) : parent_type(inStream) {} + + //-------------------------------------------------------------------------------- + /** + * Copy constructor. + * + * TODO copy part of a matrix? + * + * Copies the given NearestNeighbor into this one. + */ + inline NearestNeighbor(const NearestNeighbor &other) : parent_type(other) {} - return val; + //-------------------------------------------------------------------------------- + /** + * Assignment operator. + */ + inline NearestNeighbor &operator=(const NearestNeighbor &other) { + parent_type::operator=(other); + return *this; + } + + //-------------------------------------------------------------------------------- + /** + * A method that computes the powers of x and their sum, according to f. + */ +private: + template + inline void compute_powers_(value_type &Sp_x, value_type *p_x, + InputIterator x, F f) { + const size_type ncols = this->nCols(); + InputIterator end1 = x + 4 * (ncols / 4), end2 = x + ncols; + Sp_x = (value_type)0; + + for (; x != end1; x += 4, p_x += 4) { + *p_x = f(Sp_x, *x); + *(p_x + 1) = f(Sp_x, *(x + 1)); + *(p_x + 2) = f(Sp_x, *(x + 2)); + *(p_x + 3) = f(Sp_x, *(x + 3)); } - //-------------------------------------------------------------------------------- - /** - * A method that computes the distance between x and the specified row, - * parameterized on the norm function. Can be instantiated for L0, L1 and Lmax. - * Here, we don't need to worry about taking a root, and there are no powers of x - * to cache, as opposed to one_row_dist_2. The complexity is: nnzr*f(). - */ - private: - template - inline value_type - one_row_dist_1(size_type row, InputIterator x, F f) const - { - const size_type ncols = this->nCols(); - size_type *ind = this->ind_[row], *ind_end = ind + this->nnzr_[row], j = 0; - value_type *nz = this->nz_[row], d = (value_type) 0.0; - - while (ind != ind_end) { - size_type j_end = *ind++; - while (j != j_end) - f(d, x[j++]); - f(d, x[j++] - *nz++); - } - - if (j < ncols) - while (j != ncols) - f(d, x[j++]); + for (; x != end2; ++x) + *p_x++ = f(Sp_x, *x); + } - return d; + //-------------------------------------------------------------------------------- + /** + * A method that computes the sum of powers of the difference between x + * and a given row. 
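The caching scheme behind compute_powers_ and sum_of_p_diff_ rests on the identity sum_j |r_j - x_j|^p = Sp_x + sum over the stored non-zeros j of (|r_j - x_j|^p - |x_j|^p), where Sp_x = sum_j |x_j|^p is precomputed once. That is why the loop starts from Sp_x, adds the power of each non-zero difference, and subtracts the cached p_x[j]. A self-contained numeric check of the identity, using only the standard library:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Dense reference: sum_j |r_j - x_j|^p over all columns.
static double lpPowDense(const std::vector<double> &r,
                         const std::vector<double> &x, double p) {
  double s = 0;
  for (std::size_t j = 0; j != x.size(); ++j)
    s += std::pow(std::fabs(r[j] - x[j]), p);
  return s;
}

// Sparse version of the same sum, visiting only the stored (index, value)
// pairs of the row, as sum_of_p_diff_ does.
static double lpPowSparse(const std::vector<std::pair<std::size_t, double> > &row,
                          const std::vector<double> &x, double p) {
  double Sp_x = 0;
  std::vector<double> p_x(x.size());
  for (std::size_t j = 0; j != x.size(); ++j)
    Sp_x += (p_x[j] = std::pow(std::fabs(x[j]), p)); // cache |x_j|^p and its sum

  double val = Sp_x;
  for (std::size_t k = 0; k != row.size(); ++k) {
    std::size_t j = row[k].first;
    val += std::pow(std::fabs(row[k].second - x[j]), p) - p_x[j];
  }
  return val < 0 ? 0 : val; // guard against rounding, as in the real code
}

int main() {
  std::vector<double> x = {0.5, 0.0, 2.0, 1.0};
  std::vector<double> dense = {0.0, 3.0, 0.0, -1.0};
  std::vector<std::pair<std::size_t, double> > sparse = {{1, 3.0}, {3, -1.0}};
  assert(std::fabs(lpPowDense(dense, x, 1.5) - lpPowSparse(sparse, x, 1.5)) < 1e-9);
  return 0;
}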
+ */ +private: + template + inline value_type sum_of_p_diff_(size_type row, InputIterator x, + value_type Sp_x, value_type *p_x, + F f) const { + size_type nnzr = this->nnzr_[row], j, *ind = this->ind_[row]; + value_type *nz = this->nz_[row]; + value_type val = Sp_x, val1 = 0, val2 = 0; + size_type *end1 = ind + 4 * (nnzr / 4), *end2 = ind + nnzr; + + while (ind != end1) { + j = *ind++; + val1 = *nz++ - x[j]; + f(val, val1); + val -= p_x[j]; + j = *ind++; + val2 = *nz++ - x[j]; + f(val, val2); + val -= p_x[j]; + j = *ind++; + val1 = *nz++ - x[j]; + f(val, val1); + val -= p_x[j]; + j = *ind++; + val2 = *nz++ - x[j]; + f(val, val2); + val -= p_x[j]; } - //-------------------------------------------------------------------------------- - /** - * A method that computes the distance between x and the specified row, - * parameterized on the norm function. Can be instantiated for L2 or Lp. - * Caches the powers of x in nzb_, so that we achieve complexity: - * nnzr*(2*diff+abs+f.power()) + ncols*f.power(), instead of: - * nnzr*(diff+abs+f.power()) + nrows*ncols*f.power(). - */ - private: - template - inline value_type - one_row_dist_2(size_type row, InputIterator x, F f, - bool take_root =false) const - { - value_type Sp_x = 0.0; + while (ind != end2) { + j = *ind++; + val1 = *nz++ - x[j]; + f(val, val1); + val -= p_x[j]; + } - const_cast(this)->compute_powers_(Sp_x, this->nzb_, x, f); - value_type val = sum_of_p_diff_(row, x, Sp_x, this->nzb_, f); + // Accuracy issues because of the subtractions, + // could return negative values + if (val <= (value_type)0) + val = (value_type)0; - if (take_root) - val = f.root(val); + return val; + } - return val; + //-------------------------------------------------------------------------------- + /** + * A method that computes the distance between x and the specified row, + * parameterized on the norm function. Can be instantiated for L0, L1 and + * Lmax. Here, we don't need to worry about taking a root, and there are no + * powers of x to cache, as opposed to one_row_dist_2. The complexity is: + * nnzr*f(). + */ +private: + template + inline value_type one_row_dist_1(size_type row, InputIterator x, F f) const { + const size_type ncols = this->nCols(); + size_type *ind = this->ind_[row], *ind_end = ind + this->nnzr_[row], j = 0; + value_type *nz = this->nz_[row], d = (value_type)0.0; + + while (ind != ind_end) { + size_type j_end = *ind++; + while (j != j_end) + f(d, x[j++]); + f(d, x[j++] - *nz++); } - //-------------------------------------------------------------------------------- - /** - * A method that computes the distances between x and all the rows in the matrix. - * The method is parameterized on the norm to use, F. - * Although it looks very similar to one_row_dist_1, the sums of powers are computed - * once for all the rows here, where they would be computed for each row if - * one_row_dist_1 was called. The complexity is: - * ncols*(pow) + nnz*(2*diff+f.power()) + nrows*f.root(). 
- */ - private: - template - inline void - all_rows_dist_(InputIterator x, OutputIterator y, F f, - bool take_root =false) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::all_rows_dist_(): " - << "No vector stored yet"; - } // End pre-conditions - - const size_type nrows = this->nRows(); - OutputIterator y_begin = y, y_end = y + nrows; - value_type Sp_x = 0.0; - - const_cast(this)->compute_powers_(Sp_x, this->nzb_, x, f); - - for (size_type i = 0; i != nrows; ++i, ++y) - *y = sum_of_p_diff_(i, x, Sp_x, this->nzb_, f); - - if (take_root) - for (y = y_begin; y != y_end; ++y) - *y = f.root(*y); - } + if (j < ncols) + while (j != ncols) + f(d, x[j++]); - //-------------------------------------------------------------------------------- - /** - * A method that finds the k top nearest neighbor, as defined by F. - */ - private: - template - inline void - k_nearest_(InputIterator x, OutputIterator nn, F f, - size_type k =1, bool take_root =false) const - { - { // Pre-conditions - NTA_ASSERT(k >= 1) - << "NearestNeighbor::k_nearest_(): " - << "Invalid number of nearest rows: " << k - << " - Should be >= 1, default value is 1"; - - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::k_nearest_(): " - << "No vector stored yet"; - } + return d; + } - std::vector b(this->nRows()); - all_rows_dist_(x, b.begin(), f, take_root); - partial_sort_2nd(k, b, nn, std::less()); - } + //-------------------------------------------------------------------------------- + /** + * A method that computes the distance between x and the specified row, + * parameterized on the norm function. Can be instantiated for L2 or Lp. + * Caches the powers of x in nzb_, so that we achieve complexity: + * nnzr*(2*diff+abs+f.power()) + ncols*f.power(), instead of: + * nnzr*(diff+abs+f.power()) + nrows*ncols*f.power(). + */ +private: + template + inline value_type one_row_dist_2(size_type row, InputIterator x, F f, + bool take_root = false) const { + value_type Sp_x = 0.0; + + const_cast(this)->compute_powers_(Sp_x, this->nzb_, x, f); + value_type val = sum_of_p_diff_(row, x, Sp_x, this->nzb_, f); + + if (take_root) + val = f.root(val); + + return val; + } - public: - //-------------------------------------------------------------------------------- - /** - * Computes the distance between vector x and a given row - * of this NearestNeighbor, using the L0 (Hamming) distance: - * - * dist(row, x) = sum(| row[i] - x[i] | > epsilon) - * - * Computations are performed on the non-zeros only. - * - * Non-mutating, O(nnzr) - * - * @param row [0 <= size_type < nrows] index of row to compute distance from - * @param x [InputIterator] x vector - * @retval [value_type] distance from x to row of index 'row' - * - * @b Exceptions: - * @li row < 0 || row >= nrows (check) - */ - template - inline value_type rowL0Dist(size_type row, InputIterator x) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::rowL0Dist(): " - << "No vector stored yet"; - - NTA_ASSERT(row >= 0 && row < this->nRows()) - << "NearestNeighbor::rowL0Dist(): " - << "Invalid row index: " << row - << " - Should be >= 0 and < nrows = " << this->nRows(); - } + //-------------------------------------------------------------------------------- + /** + * A method that computes the distances between x and all the rows in the + * matrix. The method is parameterized on the norm to use, F. 
Although it + * looks very similar to one_row_dist_1, the sums of powers are computed once + * for all the rows here, where they would be computed for each row if + * one_row_dist_1 was called. The complexity is: + * ncols*(pow) + nnz*(2*diff+f.power()) + nrows*f.root(). + */ +private: + template + inline void all_rows_dist_(InputIterator x, OutputIterator y, F f, + bool take_root = false) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::all_rows_dist_(): " + << "No vector stored yet"; + } // End pre-conditions + + const size_type nrows = this->nRows(); + OutputIterator y_begin = y, y_end = y + nrows; + value_type Sp_x = 0.0; + + const_cast(this)->compute_powers_(Sp_x, this->nzb_, x, f); + + for (size_type i = 0; i != nrows; ++i, ++y) + *y = sum_of_p_diff_(i, x, Sp_x, this->nzb_, f); + + if (take_root) + for (y = y_begin; y != y_end; ++y) + *y = f.root(*y); + } - return one_row_dist_1(row, x, Lp0()); + //-------------------------------------------------------------------------------- + /** + * A method that finds the k top nearest neighbor, as defined by F. + */ +private: + template + inline void k_nearest_(InputIterator x, OutputIterator nn, F f, + size_type k = 1, bool take_root = false) const { + { // Pre-conditions + NTA_ASSERT(k >= 1) << "NearestNeighbor::k_nearest_(): " + << "Invalid number of nearest rows: " << k + << " - Should be >= 1, default value is 1"; + + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::k_nearest_(): " + << "No vector stored yet"; } - //-------------------------------------------------------------------------------- - /** - * Computes the distance between vector x and a given row - * of this NearestNeighbor, using the L1 (Manhattan) distance: - * - * dist(row, x) = sum(| row[i] - x[i] |) - * - * Computations are performed on the non-zeros only. - * - * Non-mutating, O(nnzr) - * - * @param row [0 <= size_type < nrows] index of row to compute distance from - * @param x [InputIterator] x vector - * @retval [value_type] distance from x to row of index 'row' - * - * @b Exceptions: - * @li row < 0 || row >= nrows (check) - */ - template - inline value_type rowL1Dist(size_type row, InputIterator x) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::rowL1Dist(): " - << "No vector stored yet"; - - NTA_ASSERT(row >= 0 && row < this->nRows()) - << "NearestNeighbor::rowL1Dist(): " - << "Invalid row index: " << row - << " - Should be >= 0 and < nrows = " << this->nRows(); - } + std::vector b(this->nRows()); + all_rows_dist_(x, b.begin(), f, take_root); + partial_sort_2nd(k, b, nn, std::less()); + } - return one_row_dist_1(row, x, Lp1()); +public: + //-------------------------------------------------------------------------------- + /** + * Computes the distance between vector x and a given row + * of this NearestNeighbor, using the L0 (Hamming) distance: + * + * dist(row, x) = sum(| row[i] - x[i] | > epsilon) + * + * Computations are performed on the non-zeros only. 
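For the L0 distance described here, the value is just a count of coordinates whose difference exceeds a small threshold (the library uses its own epsilon constant; the one below is only a placeholder). A dense standalone version of that count, which the sparse member function reproduces by walking the row's stored non-zeros plus the remaining coordinates of x:

#include <cmath>
#include <cstddef>
#include <vector>

// L0 (Hamming-style) distance between two dense vectors: the number of
// coordinates that differ by more than epsilon.
static std::size_t l0Dist(const std::vector<double> &a,
                          const std::vector<double> &b,
                          double epsilon = 1e-6) {
  std::size_t count = 0;
  for (std::size_t j = 0; j != a.size(); ++j)
    if (std::fabs(a[j] - b[j]) > epsilon)
      ++count;
  return count;
}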
+ * + * Non-mutating, O(nnzr) + * + * @param row [0 <= size_type < nrows] index of row to compute distance from + * @param x [InputIterator] x vector + * @retval [value_type] distance from x to row of index 'row' + * + * @b Exceptions: + * @li row < 0 || row >= nrows (check) + */ + template + inline value_type rowL0Dist(size_type row, InputIterator x) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::rowL0Dist(): " + << "No vector stored yet"; + + NTA_ASSERT(row >= 0 && row < this->nRows()) + << "NearestNeighbor::rowL0Dist(): " + << "Invalid row index: " << row + << " - Should be >= 0 and < nrows = " << this->nRows(); } - //-------------------------------------------------------------------------------- - /** - * Computes the distance between vector x and a given row - * of this NearestNeighbor, using the Euclidean distance: - * - * dist(row, x) = [ sum((row[i] - x[i])^2) ] ^ 1/2 - * - * Computations are performed on the non-zeros only. - * The square root is optional, controlled by parameter take_root. - * - * Non-mutating, O(ncols + nnzr) - * - * @param row [0 <= size_type < nrows] index of row to compute distance from - * @param x [InputIterator] vector of the squared distances to x - * @param take_root [bool (false)] whether to return the square root of the distance - * or the exact value (the square root of the sum of the sqaures). Default is to - * return the square of the distance. - * @retval [value_type] distance from x to row of index 'row' - * - * @b Exceptions: - * @li row < 0 || row >= nrows (check) - */ - template - inline value_type - rowL2Dist(size_type row, InputIterator x, bool take_root =false) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::rowL2Dist(): " - << "No vector stored yet"; - - NTA_ASSERT(row >= 0 && row < this->nRows()) - << "NearestNeighbor::rowL2Dist(): " - << "Invalid row index: " << row - << " - Should be >= 0 and < nrows = " << this->nRows(); - } + return one_row_dist_1(row, x, Lp0()); + } - return one_row_dist_2(row, x, Lp2(), take_root); + //-------------------------------------------------------------------------------- + /** + * Computes the distance between vector x and a given row + * of this NearestNeighbor, using the L1 (Manhattan) distance: + * + * dist(row, x) = sum(| row[i] - x[i] |) + * + * Computations are performed on the non-zeros only. + * + * Non-mutating, O(nnzr) + * + * @param row [0 <= size_type < nrows] index of row to compute distance from + * @param x [InputIterator] x vector + * @retval [value_type] distance from x to row of index 'row' + * + * @b Exceptions: + * @li row < 0 || row >= nrows (check) + */ + template + inline value_type rowL1Dist(size_type row, InputIterator x) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::rowL1Dist(): " + << "No vector stored yet"; + + NTA_ASSERT(row >= 0 && row < this->nRows()) + << "NearestNeighbor::rowL1Dist(): " + << "Invalid row index: " << row + << " - Should be >= 0 and < nrows = " << this->nRows(); } - //-------------------------------------------------------------------------------- - /** - * Computes the distance between vector x and a given row - * of this NearestNeighbor, using the Lmax distance: - * - * dist(row, x) = max(| row[i] - x[i] |) - * - * Computations are performed on the non-zeros only. 
- * - * Non-mutating, O(nnzr) - * - * @param row [0 <= size_type < nrows] index of row to compute distance from - * @param x [InputIterator] x vector - * @retval [value_type] distance from x to row of index 'row' - * - * @b Exceptions: - * @li row < 0 || row >= nrows (check) - */ - template - inline value_type rowLMaxDist(size_type row, InputIterator x) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::rowLMaxDist(): " - << "No vector stored yet"; + return one_row_dist_1(row, x, Lp1()); + } - assert_valid_row_(row, "rowLMaxDist"); - } // End pre-conditions - - return one_row_dist_1(row, x, LpMax()); + //-------------------------------------------------------------------------------- + /** + * Computes the distance between vector x and a given row + * of this NearestNeighbor, using the Euclidean distance: + * + * dist(row, x) = [ sum((row[i] - x[i])^2) ] ^ 1/2 + * + * Computations are performed on the non-zeros only. + * The square root is optional, controlled by parameter take_root. + * + * Non-mutating, O(ncols + nnzr) + * + * @param row [0 <= size_type < nrows] index of row to compute distance from + * @param x [InputIterator] vector of the squared distances to x + * @param take_root [bool (false)] whether to return the square root of the + * distance or the exact value (the square root of the sum of the sqaures). + * Default is to return the square of the distance. + * @retval [value_type] distance from x to row of index 'row' + * + * @b Exceptions: + * @li row < 0 || row >= nrows (check) + */ + template + inline value_type rowL2Dist(size_type row, InputIterator x, + bool take_root = false) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::rowL2Dist(): " + << "No vector stored yet"; + + NTA_ASSERT(row >= 0 && row < this->nRows()) + << "NearestNeighbor::rowL2Dist(): " + << "Invalid row index: " << row + << " - Should be >= 0 and < nrows = " << this->nRows(); } - //-------------------------------------------------------------------------------- - /** - * Computes the distance between vector x and a given row - * of this NearestNeighbor, using the Lp distance: - * - * dist(row, x) = [ sum(|row[i] - x[i]|^p) ] ^ 1/p - * - * Computations are performed on the non-zeros only. - * The square root is optional, controlled by parameter take_root. - * - * Non-mutating. - * - * @param row [0 <= size_type < nrows] index of row to compute distance from - * @param x [InputIterator] vector of the squared distances to x - * @param take_root [bool (false)] whether to return the p-th power of the distance - * or the exact value (the p-th root of the sum of the p-powers). Default is to - * return the p-th power of the distance. - * @retval [value_type] distance from x to row of index 'row' - * - * @b Exceptions: - * @li row < 0 || row >= nrows (check) - */ - template - inline value_type - rowLpDist(value_type p, size_type row, InputIterator x, - bool take_root =false) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::rowLpDist(): " - << "No vector stored yet"; + return one_row_dist_2(row, x, Lp2(), take_root); + } - assert_valid_row_(row, "rowLpDist"); + //-------------------------------------------------------------------------------- + /** + * Computes the distance between vector x and a given row + * of this NearestNeighbor, using the Lmax distance: + * + * dist(row, x) = max(| row[i] - x[i] |) + * + * Computations are performed on the non-zeros only. 
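A hypothetical usage sketch for the single-row distances shown so far (rowL0Dist, rowL1Dist, rowL2Dist; the rowLMaxDist documentation continues directly below). The NN instantiation is assumed, and x must be a dense vector with nCols() elements:

// Illustrative only: NN is an assumed instantiation.
#include <vector>
#include <nupic/math/NearestNeighbor.hpp>
#include <nupic/math/SparseMatrix.hpp>
#include <nupic/types/Types.hpp>

typedef nupic::NearestNeighbor<nupic::SparseMatrix<nupic::UInt32, nupic::Real32> > NN;

void rowDistanceSketch(const NN &nn, const std::vector<nupic::Real32> &x) {
  // Requires x.size() == nn.nCols() and at least one stored row.
  NN::value_type hamming   = nn.rowL0Dist(0, x.begin());       // count of differing coords
  NN::value_type manhattan = nn.rowL1Dist(0, x.begin());       // sum of |row - x|
  NN::value_type squared   = nn.rowL2Dist(0, x.begin());       // sum of squares
  NN::value_type euclidean = nn.rowL2Dist(0, x.begin(), true); // square root taken
  (void)hamming; (void)manhattan; (void)squared; (void)euclidean;
}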
+ * + * Non-mutating, O(nnzr) + * + * @param row [0 <= size_type < nrows] index of row to compute distance from + * @param x [InputIterator] x vector + * @retval [value_type] distance from x to row of index 'row' + * + * @b Exceptions: + * @li row < 0 || row >= nrows (check) + */ + template + inline value_type rowLMaxDist(size_type row, InputIterator x) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::rowLMaxDist(): " + << "No vector stored yet"; + + assert_valid_row_(row, "rowLMaxDist"); + } // End pre-conditions + + return one_row_dist_1(row, x, LpMax()); + } - NTA_ASSERT(p >= (value_type)0.0) - << "NearestNeighbor::rowLpDist():" - << "Invalid value for parameter p: " << p - << " - Only positive values (p >= 0) are supported"; - } // End pre-conditions + //-------------------------------------------------------------------------------- + /** + * Computes the distance between vector x and a given row + * of this NearestNeighbor, using the Lp distance: + * + * dist(row, x) = [ sum(|row[i] - x[i]|^p) ] ^ 1/p + * + * Computations are performed on the non-zeros only. + * The square root is optional, controlled by parameter take_root. + * + * Non-mutating. + * + * @param row [0 <= size_type < nrows] index of row to compute distance from + * @param x [InputIterator] vector of the squared distances to x + * @param take_root [bool (false)] whether to return the p-th power of the + * distance or the exact value (the p-th root of the sum of the p-powers). + * Default is to return the p-th power of the distance. + * @retval [value_type] distance from x to row of index 'row' + * + * @b Exceptions: + * @li row < 0 || row >= nrows (check) + */ + template + inline value_type rowLpDist(value_type p, size_type row, InputIterator x, + bool take_root = false) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::rowLpDist(): " + << "No vector stored yet"; + + assert_valid_row_(row, "rowLpDist"); + + NTA_ASSERT(p >= (value_type)0.0) + << "NearestNeighbor::rowLpDist():" + << "Invalid value for parameter p: " << p + << " - Only positive values (p >= 0) are supported"; + } // End pre-conditions + + if (p == (value_type)0.0) + return rowL0Dist(row, x); + + if (p == (value_type)1.0) + return rowL1Dist(row, x); + + if (p == (value_type)2.0) + return rowL2Dist(row, x, take_root); + + return one_row_dist_2(row, x, Lp(p), take_root); + } - if (p == (value_type)0.0) - return rowL0Dist(row, x); + //-------------------------------------------------------------------------------- + /** + * Computes the distance between vector x and all the rows + * of this NearestNeighbor, using the L0 (Hamming) distance: + * + * dist(row, x) = sum(| row[i] - x[i] | > Epsilon) + * + * Computations are performed on the non-zeros only. 
+ * + * Non-mutating, O(nnzr) + * + * @param x [InputIterator] x vector + * @param y [OutputIterator] vector of distances of x to each row + * + * @b Exceptions: + * @li None + */ + template + inline void L0Dist(InputIterator x, OutputIterator y) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::L0Dist(): " + << "No vector stored yet"; + } - if (p == (value_type)1.0) - return rowL1Dist(row, x); + const size_type nrows = this->nRows(); + Lp0 f; - if (p == (value_type)2.0) - return rowL2Dist(row, x, take_root); + for (size_type i = 0; i != nrows; ++i, ++y) + *y = one_row_dist_1(i, x, f); + } - return one_row_dist_2(row, x, Lp(p), take_root); + //-------------------------------------------------------------------------------- + /** + * Computes the distance between vector x and all the rows + * of this NearestNeighbor, using the L1 (Manhattan) distance: + * + * dist(row, x) = sum(| row[i] - x[i] |) + * + * Computations are performed on the non-zeros only. + * + * Non-mutating, O(nnzr) + * + * @param x [InputIterator] x vector + * @param y [OutputIterator] vector of distances of x to each row + * + * @b Exceptions: + * @li None + */ + template + inline void L1Dist(InputIterator x, OutputIterator y) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::L1Dist(): " + << "No vector stored yet"; } - - //-------------------------------------------------------------------------------- - /** - * Computes the distance between vector x and all the rows - * of this NearestNeighbor, using the L0 (Hamming) distance: - * - * dist(row, x) = sum(| row[i] - x[i] | > Epsilon) - * - * Computations are performed on the non-zeros only. - * - * Non-mutating, O(nnzr) - * - * @param x [InputIterator] x vector - * @param y [OutputIterator] vector of distances of x to each row - * - * @b Exceptions: - * @li None - */ - template - inline void L0Dist(InputIterator x, OutputIterator y) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::L0Dist(): " - << "No vector stored yet"; - } - const size_type nrows = this->nRows(); - Lp0 f; + const size_type nrows = this->nRows(), ncols = this->nCols(); + value_type s = 0.0; + Lp1 f; + + InputIterator x_ptr = x; + for (size_type j = 0; j != ncols; ++j, ++x_ptr) + this->nzb_[j] = f(s, *x_ptr); + + for (size_type i = 0; i != nrows; ++i, ++y) { + size_type *ind = this->ind_[i], *ind_end = ind + this->nnzr_[i]; + value_type *nz = this->nz_[i], d = s; + for (; ind != ind_end; ++ind, ++nz) { + size_type j = *ind; + f(d, x[j] - *nz); + d -= this->nzb_[j]; + } + if (d <= (value_type)0) + d = (value_type)0; + *y = d; + } + } - for (size_type i = 0; i != nrows; ++i, ++y) - *y = one_row_dist_1(i, x, f); + //-------------------------------------------------------------------------------- + /** + * Computes the Euclidean distance between vector x + * and each row of this NearestNeighbor. + * + * Non-mutating, O(nnz) + * + * @param x [InputIterator] vector to compute the distance from + * @param y [OutputIterator] vector of the distances to x + * @param take_root [bool (false)] whether to return the square root of the + * distances + * or their exact value (the square root of the sum of the squares). Default + * is to return the square of the distances. 
+ * + * @b Exceptions: + * @li None + */ + template + inline void L2Dist(InputIterator x, OutputIterator y, + bool take_root = false) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::L2Dist(): " + << "No vector stored yet"; } - //-------------------------------------------------------------------------------- - /** - * Computes the distance between vector x and all the rows - * of this NearestNeighbor, using the L1 (Manhattan) distance: - * - * dist(row, x) = sum(| row[i] - x[i] |) - * - * Computations are performed on the non-zeros only. - * - * Non-mutating, O(nnzr) - * - * @param x [InputIterator] x vector - * @param y [OutputIterator] vector of distances of x to each row - * - * @b Exceptions: - * @li None - */ - template - inline void L1Dist(InputIterator x, OutputIterator y) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::L1Dist(): " - << "No vector stored yet"; - } + all_rows_dist_(x, y, Lp2(), take_root); + } - const size_type nrows = this->nRows(), ncols = this->nCols(); - value_type s = 0.0; - Lp1 f; - - InputIterator x_ptr = x; - for (size_type j = 0; j != ncols; ++j, ++x_ptr) - this->nzb_[j] = f(s, *x_ptr); - - for (size_type i = 0; i != nrows; ++i, ++y) { - size_type *ind = this->ind_[i], *ind_end = ind + this->nnzr_[i]; - value_type *nz = this->nz_[i], d = s; - for (; ind != ind_end; ++ind, ++nz) { - size_type j = *ind; - f(d, x[j] - *nz); - d -= this->nzb_[j]; - } - if (d <= (value_type) 0) - d = (value_type) 0; - *y = d; - } + //-------------------------------------------------------------------------------- + /** + * Computes the Lmax distance between vector x and each row of this + * NearestNeighbor. + * + * Non-mutating, O(nrows*ncols) + * + * @param x [InputIterator] vector to compute the distance from + * @param y [OutputIterator] vector of the distances to x + * + * @b Exceptions: + * @li None + */ + template + inline void LMaxDist(InputIterator x, OutputIterator y) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::LMaxDist(): " + << "No vector stored yet"; } - //-------------------------------------------------------------------------------- - /** - * Computes the Euclidean distance between vector x - * and each row of this NearestNeighbor. - * - * Non-mutating, O(nnz) - * - * @param x [InputIterator] vector to compute the distance from - * @param y [OutputIterator] vector of the distances to x - * @param take_root [bool (false)] whether to return the square root of the - * distances - * or their exact value (the square root of the sum of the squares). Default is - * to - * return the square of the distances. - * - * @b Exceptions: - * @li None - */ - template - inline void L2Dist(InputIterator x, OutputIterator y, bool take_root =false) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::L2Dist(): " - << "No vector stored yet"; - } + const size_type nrows = this->nRows(); + LpMax f; - all_rows_dist_(x, y, Lp2(), take_root); + for (size_type i = 0; i != nrows; ++i, ++y) + *y = one_row_dist_1(i, x, f); + } + + //-------------------------------------------------------------------------------- + /** + * Computes the p-th power of the Lp distance between vector x + * and each row of this NearestNeighbor. Puts the result + * in vector y. If take_root is true, we take the p-th root of the sums, + * if not, y will contain the sum of the p-th powers only. 
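A hedged usage sketch for the batch variants (the LpDist parameter notes continue below): each call writes one value per stored row into an output range of length nRows(). NN is the same assumed instantiation as in the earlier sketches:

// Illustrative only: NN is an assumed instantiation.
#include <vector>
#include <nupic/math/NearestNeighbor.hpp>
#include <nupic/math/SparseMatrix.hpp>
#include <nupic/types/Types.hpp>

typedef nupic::NearestNeighbor<nupic::SparseMatrix<nupic::UInt32, nupic::Real32> > NN;

void batchDistanceSketch(const NN &nn, const std::vector<nupic::Real32> &x) {
  std::vector<NN::value_type> dists(nn.nRows()); // one slot per stored row

  nn.L1Dist(x.begin(), dists.begin());           // Manhattan distances
  nn.L2Dist(x.begin(), dists.begin(), true);     // Euclidean, square root taken
  nn.LpDist(1.5, x.begin(), dists.begin());      // 1.5-th powers of the L1.5 distance
}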
+ * + * Non-mutating, O(nnz) + * + * @param x [InputIterator] vector to compute the distance from + * @param y [OutputIterator] vector of the squared distances to x + * @param take_root [bool (false)] whether to return the p-th power of the + * distances or their exact value (the p-th root of the sum of the p-powers). + * Default is to return the p-th power of the distances. + * + * @b Exceptions: + * @li None + */ + template + inline void LpDist(value_type p, InputIterator x, OutputIterator y, + bool take_root = false) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::LpDist(): " + << "No vector stored yet"; + + NTA_ASSERT(p >= (value_type)0.0) + << "NearestNeighbor::LpDist():" + << "Invalid value for parameter p: " << p + << " - Only positive values (p >= 0) are supported"; } - - //-------------------------------------------------------------------------------- - /** - * Computes the Lmax distance between vector x and each row of this NearestNeighbor. - * - * Non-mutating, O(nrows*ncols) - * - * @param x [InputIterator] vector to compute the distance from - * @param y [OutputIterator] vector of the distances to x - * - * @b Exceptions: - * @li None - */ - template - inline void LMaxDist(InputIterator x, OutputIterator y) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::LMaxDist(): " - << "No vector stored yet"; - } - const size_type nrows = this->nRows(); - LpMax f; - - for (size_type i = 0; i != nrows; ++i, ++y) - *y = one_row_dist_1(i, x, f); + if (p == (value_type)0.0) { + L0Dist(x, y); + return; } - //-------------------------------------------------------------------------------- - /** - * Computes the p-th power of the Lp distance between vector x - * and each row of this NearestNeighbor. Puts the result - * in vector y. If take_root is true, we take the p-th root of the sums, - * if not, y will contain the sum of the p-th powers only. - * - * Non-mutating, O(nnz) - * - * @param x [InputIterator] vector to compute the distance from - * @param y [OutputIterator] vector of the squared distances to x - * @param take_root [bool (false)] whether to return the p-th power of the distances - * or their exact value (the p-th root of the sum of the p-powers). Default is to - * return the p-th power of the distances. - * - * @b Exceptions: - * @li None - */ - template - inline void - LpDist(value_type p, - InputIterator x, OutputIterator y, bool take_root=false) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::LpDist(): " - << "No vector stored yet"; - - NTA_ASSERT(p >= (value_type)0.0) - << "NearestNeighbor::LpDist():" - << "Invalid value for parameter p: " << p - << " - Only positive values (p >= 0) are supported"; - } - - if (p == (value_type)0.0) { - L0Dist(x, y); - return; - } + if (p == (value_type)1.0) { + L1Dist(x, y); + return; + } - if (p == (value_type)1.0) { - L1Dist(x, y); - return; - } + if (p == (value_type)2.0) { + L2Dist(x, y, take_root); + return; + } - if (p == (value_type)2.0) { - L2Dist(x, y, take_root); - return; - } + all_rows_dist_(x, y, Lp(p), take_root); + } - all_rows_dist_(x, y, Lp(p), take_root); + //-------------------------------------------------------------------------------- + /** + * Finds the row nearest to x, where nearest is defined as the row which has + * the smallest L0 (Hamming) distance to x. If k > 1, finds the k nearest rows + * to x. + * + * Non-mutating, O(nnz) + complexity of partial sort up to k if k > 1. 
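L0Nearest and its siblings (the parameter notes continue below) delegate to k_nearest_, which computes a distance for every row and then keeps the k smallest together with their row indices via partial_sort_2nd, a NuPIC helper not shown in this file. The selection step can be sketched with the standard library alone:

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Given one distance per row, return the k smallest as (row index, distance)
// pairs, ordered by increasing distance.
static std::vector<std::pair<std::size_t, double> >
kSmallest(const std::vector<double> &dists, std::size_t k) {
  std::vector<std::pair<std::size_t, double> > tagged(dists.size());
  for (std::size_t i = 0; i != dists.size(); ++i)
    tagged[i] = std::make_pair(i, dists[i]);

  if (k > tagged.size())
    k = tagged.size();
  std::partial_sort(tagged.begin(), tagged.begin() + k, tagged.end(),
                    [](const std::pair<std::size_t, double> &a,
                       const std::pair<std::size_t, double> &b) {
                      return a.second < b.second; // compare on the distance
                    });
  tagged.resize(k);
  return tagged;
}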
+ * + * @param x [InputIterator] vector to compute the distance from + * @param nn [OutputIterator] the indices and distances of the nearest rows + * (pairs) + * @param k [size_type > 0, (1)] the number of nearest rows to retrieve + * + * @b Exceptions: + * @li If k < 1. + */ + template + inline void L0Nearest(InputIterator x, OutputIterator nn, + size_type k = 1) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::L0Nearest(): " + << "No vector stored yet"; + + NTA_ASSERT(k >= 1) << "NearestNeighbor::L0Nearest():" + << "Invalid number of nearest rows: " << k + << " - Should be >= 1, default is 1"; } - //-------------------------------------------------------------------------------- - /** - * Finds the row nearest to x, where nearest is defined as the row which has the - * smallest L0 (Hamming) distance to x. If k > 1, finds the k nearest rows to x. - * - * Non-mutating, O(nnz) + complexity of partial sort up to k if k > 1. - * - * @param x [InputIterator] vector to compute the distance from - * @param nn [OutputIterator] the indices and distances of the nearest rows (pairs) - * @param k [size_type > 0, (1)] the number of nearest rows to retrieve - * - * @b Exceptions: - * @li If k < 1. - */ - template - inline void - L0Nearest(InputIterator x, OutputIterator nn, size_type k =1) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::L0Nearest(): " - << "No vector stored yet"; - - NTA_ASSERT(k >= 1) - << "NearestNeighbor::L0Nearest():" - << "Invalid number of nearest rows: " << k - << " - Should be >= 1, default is 1"; - } + k_nearest_(x, nn, Lp0(), k); + } - k_nearest_(x, nn, Lp0(), k); + //-------------------------------------------------------------------------------- + /** + * Finds the row nearest to x, where nearest is defined as the row which has + * the smallest L1 (Manhattan) distance to x. If k > 1, finds the k nearest + * rows to x. + * + * Non-mutating, O(nnz) + complexity of partial sort up to k if k > 1. + * + * @param x [InputIterator] vector to compute the distance from + * @param nn [OutputIterator] the indices and distances of the nearest rows + * (pairs) + * @param k [size_type > 0, (1)] the number of nearest rows to retrieve + * + * @b Exceptions: + * @li If k < 1. + */ + template + inline void L1Nearest(InputIterator x, OutputIterator nn, + size_type k = 1) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::L1Nearest(): " + << "No vector stored yet"; + + NTA_ASSERT(k >= 1) << "NearestNeighbor::L1Nearest():" + << "Invalid number of nearest rows: " << k + << " - Should be >= 1, default is 1"; } - //-------------------------------------------------------------------------------- - /** - * Finds the row nearest to x, where nearest is defined as the row which has the - * smallest L1 (Manhattan) distance to x. If k > 1, finds the k nearest rows to x. - * - * Non-mutating, O(nnz) + complexity of partial sort up to k if k > 1. - * - * @param x [InputIterator] vector to compute the distance from - * @param nn [OutputIterator] the indices and distances of the nearest rows (pairs) - * @param k [size_type > 0, (1)] the number of nearest rows to retrieve - * - * @b Exceptions: - * @li If k < 1. 
- */ - template - inline void - L1Nearest(InputIterator x, OutputIterator nn, size_type k =1) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::L1Nearest(): " - << "No vector stored yet"; - - NTA_ASSERT(k >= 1) - << "NearestNeighbor::L1Nearest():" - << "Invalid number of nearest rows: " << k - << " - Should be >= 1, default is 1"; - } + k_nearest_(x, nn, Lp1(), k); + } - k_nearest_(x, nn, Lp1(), k); + //-------------------------------------------------------------------------------- + /** + * Finds the row nearest to x, where nearest is defined as the row which has + * the smallest L2 (Euclidean) distance to x. If k > 1, finds the k nearest + * rows to x. + * + * Non-mutating, O(nnz) + complexity of partial sort up to k if k > 1. + * + * @param x [InputIterator] vector to compute the distance from + * @param nn [OutputIterator] the indices and distances of the nearest rows + * (pairs) + * @param k [size_type > 0, (1)] the number of nearest rows to retrieve + * @param take_root [bool (false)] whether to return the square root of the + * distances or their exact value (the square root of the sum of the squares). + * Default is to return the square of the distances. + * + * @b Exceptions: + * @li If k < 1. + */ + template + inline void L2Nearest(InputIterator x, OutputIterator nn, size_type k = 1, + bool take_root = false) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::L2Nearest(): " + << "No vector stored yet"; + + NTA_ASSERT(k >= 1) << "NearestNeighbor::L2Nearest():" + << "Invalid number of nearest rows: " << k + << " - Should be >= 1, default is 1"; } - //-------------------------------------------------------------------------------- - /** - * Finds the row nearest to x, where nearest is defined as the row which has the - * smallest L2 (Euclidean) distance to x. If k > 1, finds the k nearest rows to x. - * - * Non-mutating, O(nnz) + complexity of partial sort up to k if k > 1. - * - * @param x [InputIterator] vector to compute the distance from - * @param nn [OutputIterator] the indices and distances of the nearest rows (pairs) - * @param k [size_type > 0, (1)] the number of nearest rows to retrieve - * @param take_root [bool (false)] whether to return the square root of the distances - * or their exact value (the square root of the sum of the squares). Default is to - * return the square of the distances. - * - * @b Exceptions: - * @li If k < 1. - */ - template - inline void - L2Nearest(InputIterator x, OutputIterator nn, size_type k =1, - bool take_root =false) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::L2Nearest(): " - << "No vector stored yet"; - - NTA_ASSERT(k >= 1) - << "NearestNeighbor::L2Nearest():" - << "Invalid number of nearest rows: " << k - << " - Should be >= 1, default is 1"; - } + k_nearest_(x, nn, Lp2(), k, take_root); + } - k_nearest_(x, nn, Lp2(), k, take_root); + //-------------------------------------------------------------------------------- + /** + * Finds the row nearest to x, where nearest is defined as the row which has + * the smallest Lmax distance to x. If k > 1, finds the k nearest rows to x. + * + * Non-mutating, O(nnz) + complexity of partial sort up to k if k > 1. + * + * @param x [InputIterator] vector to compute the distance from + * @param nn [OutputIterator] the indices and distances of the nearest rows + * (pairs) + * @param k [size_type > 0, (1)] the number of nearest rows to retrieve + * + * @b Exceptions: + * @li If k < 1. 
+ */ + template + inline void LMaxNearest(InputIterator x, OutputIterator nn, + size_type k = 1) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::LMaxNearest(): " + << "No vector stored yet"; + + NTA_ASSERT(k >= 1) << "NearestNeighbor::LMaxNearest():" + << "Invalid number of nearest rows: " << k + << " - Should be >= 1, default is 1"; } - //-------------------------------------------------------------------------------- - /** - * Finds the row nearest to x, where nearest is defined as the row which has the - * smallest Lmax distance to x. If k > 1, finds the k nearest rows to x. - * - * Non-mutating, O(nnz) + complexity of partial sort up to k if k > 1. - * - * @param x [InputIterator] vector to compute the distance from - * @param nn [OutputIterator] the indices and distances of the nearest rows (pairs) - * @param k [size_type > 0, (1)] the number of nearest rows to retrieve - * - * @b Exceptions: - * @li If k < 1. - */ - template - inline void - LMaxNearest(InputIterator x, OutputIterator nn, size_type k =1) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::LMaxNearest(): " - << "No vector stored yet"; - - NTA_ASSERT(k >= 1) - << "NearestNeighbor::LMaxNearest():" - << "Invalid number of nearest rows: " << k - << " - Should be >= 1, default is 1"; - } - - std::vector b(this->nRows()); - LMaxDist(x, b.begin()); - partial_sort_2nd(k, b, nn, std::less()); + std::vector b(this->nRows()); + LMaxDist(x, b.begin()); + partial_sort_2nd(k, b, nn, std::less()); + } + + //-------------------------------------------------------------------------------- + /** + * Finds the row nearest to x, where nearest is defined as the row which has + * the smallest Lp distance to x. If k > 1, finds the k nearest rows to x. + * + * Non-mutating, O(nnz) + complexity of partial sort up to k if k > 1. + * + * @param x [InputIterator] vector to compute the distance from + * @param nn [OutputIterator1] the indices and distances of the nearest rows + * (pairs) + * @param k [size_type > 0, (1)] the number of nearest rows to retrieve + * @param take_root [bool (false)] whether to return the p-th power of the + * distances or their exact value (the p-th root of the sum of the p-powers). + * Default is to return the p-th power of the distances. + * + * @b Exceptions: + * @li If p < 0. + * @li If k < 1. + */ + template + inline void LpNearest(value_type p, InputIterator x, OutputIterator nn, + size_type k = 1, bool take_root = false) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::LpNearest(): " + << "No vector stored yet"; + + NTA_ASSERT(p >= (value_type)0.0) + << "NearestNeighbor::LpNearest():" + << "Invalid value for parameter p: " << p + << " - Only positive values (p >= 0) are supported"; + + NTA_ASSERT(k >= 1) << "NearestNeighbor::LpNearest():" + << "Invalid number of nearest rows: " << k + << " - Should be >= 1, default is 1"; } - //-------------------------------------------------------------------------------- - /** - * Finds the row nearest to x, where nearest is defined as the row which has the - * smallest Lp distance to x. If k > 1, finds the k nearest rows to x. - * - * Non-mutating, O(nnz) + complexity of partial sort up to k if k > 1. 
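The *Nearest methods report their results as (row index, distance) pairs written through the output iterator. A hedged usage sketch, reusing the assumed NN instantiation and assuming the pair element types match size_type and value_type as the documentation indicates:

// Illustrative only: NN and the pair element types are assumptions.
#include <utility>
#include <vector>
#include <nupic/math/NearestNeighbor.hpp>
#include <nupic/math/SparseMatrix.hpp>
#include <nupic/types/Types.hpp>

typedef nupic::NearestNeighbor<nupic::SparseMatrix<nupic::UInt32, nupic::Real32> > NN;

void nearestSketch(const NN &nn, const std::vector<nupic::Real32> &x) {
  const NN::size_type k = 3;
  std::vector<std::pair<NN::size_type, NN::value_type> > nearest(k);

  // k rows closest to x in Euclidean distance, true distances (root taken).
  nn.L2Nearest(x.begin(), nearest.begin(), k, true);

  // Same idea for a general p; dispatches to the specialized L0/L1/L2 paths.
  nn.LpNearest(1.5, x.begin(), nearest.begin(), k);
}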
- * - * @param x [InputIterator] vector to compute the distance from - * @param nn [OutputIterator1] the indices and distances of the nearest rows (pairs) - * @param k [size_type > 0, (1)] the number of nearest rows to retrieve - * @param take_root [bool (false)] whether to return the p-th power of the distances - * or their exact value (the p-th root of the sum of the p-powers). Default is to - * return the p-th power of the distances. - * - * @b Exceptions: - * @li If p < 0. - * @li If k < 1. - */ - template - inline void - LpNearest(value_type p, InputIterator x, OutputIterator nn, - size_type k =1, bool take_root =false) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::LpNearest(): " - << "No vector stored yet"; - - NTA_ASSERT(p >= (value_type)0.0) - << "NearestNeighbor::LpNearest():" - << "Invalid value for parameter p: " << p - << " - Only positive values (p >= 0) are supported"; - - NTA_ASSERT(k >= 1) - << "NearestNeighbor::LpNearest():" - << "Invalid number of nearest rows: " << k - << " - Should be >= 1, default is 1"; - } + if (p == (value_type)0.0) { + L0Nearest(x, nn, k); + return; + } - if (p == (value_type)0.0) { - L0Nearest(x, nn, k); - return; - } + if (p == (value_type)1.0) { + L1Nearest(x, nn, k); + return; + } - if (p == (value_type)1.0) { - L1Nearest(x, nn, k); - return; - } + if (p == (value_type)2.0) { + L2Nearest(x, nn, k, take_root); + return; + } - if (p == (value_type)2.0) { - L2Nearest(x, nn, k, take_root); - return; - } + k_nearest_(x, nn, Lp(p), k, take_root); + } - k_nearest_(x, nn, Lp(p), k, take_root); - } + //-------------------------------------------------------------------------------- + template + inline void LpNearest(value_type p, InputIterator1 ind, + InputIterator1 ind_end, InputIterator2 nz, + OutputIterator nn, size_type k = 1, + bool take_root = false) const { + std::vector x(this->nCols()); + to_dense(ind, ind_end, nz, nz + (ind_end - ind), x.begin(), x.end()); + LpNearest(p, x.begin(), nn, k, take_root); + } - //-------------------------------------------------------------------------------- - template - inline void - LpNearest(value_type p, InputIterator1 ind, InputIterator1 ind_end, - InputIterator2 nz, OutputIterator nn, - size_type k =1, bool take_root =false) const + //-------------------------------------------------------------------------------- + /** + * Computes the "nearest-dot" distance between vector x + * and each row in this NearestNeighbor. Returns the index of + * the row that maximizes the dot product as well as the + * value of this dot-product. + * + * Note that this equivalent to L2Nearest if all the vectors are + * normalized. + * + * Non-mutating, O(nnz) + * + * @param x [InputIterator] vector to compute the distance from + * @retval [std::pair] index of the row nearest + * to x, and value of the distance between x and that row + * + * @b Exceptions: + * @li None + */ + template + inline std::pair dotNearest(InputIterator x) const { { - std::vector x(this->nCols()); - to_dense(ind, ind_end, nz, nz + (ind_end - ind), x.begin(), x.end()); - LpNearest(p, x.begin(), nn, k, take_root); + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::dotNearest(): " + << "No vector stored yet"; } - //-------------------------------------------------------------------------------- - /** - * Computes the "nearest-dot" distance between vector x - * and each row in this NearestNeighbor. Returns the index of - * the row that maximizes the dot product as well as the - * value of this dot-product. 
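dotNearest scans every stored row, accumulates the sparse dot product with x (the unrolled loop appears just below), and keeps the row with the largest value. The per-row quantity it maximizes is simply:

#include <cstddef>
#include <utility>
#include <vector>

// Dot product between a sparse row, stored as (column, value) pairs, and a
// dense vector x.
static double sparseDot(const std::vector<std::pair<std::size_t, double> > &row,
                        const std::vector<double> &x) {
  double val = 0;
  for (std::size_t k = 0; k != row.size(); ++k)
    val += row[k].second * x[row[k].first];
  return val;
}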
- * - * Note that this equivalent to L2Nearest if all the vectors are - * normalized. - * - * Non-mutating, O(nnz) - * - * @param x [InputIterator] vector to compute the distance from - * @retval [std::pair] index of the row nearest - * to x, and value of the distance between x and that row - * - * @b Exceptions: - * @li None - */ - template - inline std::pair dotNearest(InputIterator x) const - { - { - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::dotNearest(): " - << "No vector stored yet"; - } - - size_type i, k, nnzr, *ind, end, nrows = this->nRows(); - value_type val, *nz; + size_type i, k, nnzr, *ind, end, nrows = this->nRows(); + value_type val, *nz; - size_type arg_i = 0; - value_type max_v = - std::numeric_limits::max(); + size_type arg_i = 0; + value_type max_v = -std::numeric_limits::max(); - for (i = 0; i != nrows; ++i) { + for (i = 0; i != nrows; ++i) { - val = 0; - nnzr = this->nnzr_[i]; - ind = this->ind_[i]; - nz = this->nz_[i]; - end = 4 * (nnzr / 4); + val = 0; + nnzr = this->nnzr_[i]; + ind = this->ind_[i]; + nz = this->nz_[i]; + end = 4 * (nnzr / 4); - for (k = 0; k != end; k += 4) - val += nz[k] * x[ind[k]] + nz[k+1] * x[ind[k+1]] - + nz[k+2] * x[ind[k+2]] + nz[k+3] * x[ind[k+3]]; + for (k = 0; k != end; k += 4) + val += nz[k] * x[ind[k]] + nz[k + 1] * x[ind[k + 1]] + + nz[k + 2] * x[ind[k + 2]] + nz[k + 3] * x[ind[k + 3]]; - for (k = end; k != nnzr; ++k) - val += nz[k] * x[ind[k]]; + for (k = end; k != nnzr; ++k) + val += nz[k] * x[ind[k]]; - if (val > max_v) { - arg_i = i; - max_v = val; - } + if (val > max_v) { + arg_i = i; + max_v = val; } - - return std::make_pair(arg_i, max_v); } - //-------------------------------------------------------------------------------- - /** - * EXPERIMENTAL - * This method computes the std dev of each component of the vectors, and - * scales them by that standard deviation before computing the norms. - * Distance values are distorted by the standard deviation. - */ - std::vector stddev_; - - template - inline void - LpNearest_w(value_type p, InputIterator x, OutputIterator nn, - size_type k =1, bool take_root =false) - { - { // Pre-conditions - NTA_ASSERT(p >= (value_type)0.0) - << "NearestNeighbor::LpNearest_w():" - << "Invalid value for parameter p: " << p - << " - Only positive values (p >= 0) are supported"; - - NTA_ASSERT(k >= 1) - << "NearestNeighbor::LpNearest_w():" - << "Invalid number of nearest rows: " << k - << " - Should be >= 1, default is 1"; - } + return std::make_pair(arg_i, max_v); + } - const size_type nrows = this->nRows(), ncols = this->nCols(); - std::vector e(ncols, 0), e2(ncols, 0); + //-------------------------------------------------------------------------------- + /** + * EXPERIMENTAL + * This method computes the std dev of each component of the vectors, and + * scales them by that standard deviation before computing the norms. + * Distance values are distorted by the standard deviation. 
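LpNearest_w fills stddev_ lazily from the running sums e[j] = sum_i v[i][j] and e2[j] = sum_i v[i][j]^2, using stddev_[j] = sqrt((e2[j] - e[j]^2 / nrows) / (nrows - 1)), and then divides every coordinate difference by that value. A standalone check that this one-pass shortcut equals the textbook two-pass sample standard deviation:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One-pass sample standard deviation, as computed for stddev_[j].
static double stddevOnePass(const std::vector<double> &v) {
  double e = 0, e2 = 0;
  for (std::size_t i = 0; i != v.size(); ++i) {
    e += v[i];
    e2 += v[i] * v[i];
  }
  const double n = static_cast<double>(v.size());
  return std::sqrt((e2 - e * e / n) / (n - 1));
}

// Two-pass reference: subtract the mean, then average the squared deviations.
static double stddevTwoPass(const std::vector<double> &v) {
  double mean = 0;
  for (std::size_t i = 0; i != v.size(); ++i)
    mean += v[i];
  mean /= v.size();
  double ss = 0;
  for (std::size_t i = 0; i != v.size(); ++i)
    ss += (v[i] - mean) * (v[i] - mean);
  return std::sqrt(ss / (v.size() - 1));
}

int main() {
  std::vector<double> col = {1.0, 4.0, 2.0, 8.0, 5.0};
  assert(std::fabs(stddevOnePass(col) - stddevTwoPass(col)) < 1e-9);
  return 0;
}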
+ */ + std::vector stddev_; + + template + inline void LpNearest_w(value_type p, InputIterator x, OutputIterator nn, + size_type k = 1, bool take_root = false) { + { // Pre-conditions + NTA_ASSERT(p >= (value_type)0.0) + << "NearestNeighbor::LpNearest_w():" + << "Invalid value for parameter p: " << p + << " - Only positive values (p >= 0) are supported"; + + NTA_ASSERT(k >= 1) << "NearestNeighbor::LpNearest_w():" + << "Invalid number of nearest rows: " << k + << " - Should be >= 1, default is 1"; + } - if (stddev_.empty()) { + const size_type nrows = this->nRows(), ncols = this->nCols(); + std::vector e(ncols, 0), e2(ncols, 0); - stddev_.resize(ncols, 0); - - for (size_type i = 0; i != nrows; ++i) { - size_type *ind = this->ind_[i], *ind_end = ind + this->nnzr_[i]; - value_type *nz = this->nz_[i]; - while (ind != ind_end) { - size_type idx = *ind++; - value_type val = *nz++; - e[idx] += val; - e2[idx] += val * val; - } - } + if (stddev_.empty()) { - nupic::Sqrt sf; + stddev_.resize(ncols, 0); - for (size_type j = 0; j != ncols; ++j) - stddev_[j] = sf((e2[j] - e[j]*e[j]/nrows) / (nrows-1)); - } - - Lp f(p); - value_type Sp_x = 0; - for (size_type j = 0; j != ncols; ++j) - this->nzb_[j] = f(Sp_x, x[j]/stddev_[j]); - - std::vector b(nrows); - for (size_type i = 0; i != nrows; ++i) { - size_type *ind = this->ind_[i], *ind_end = ind + this->nnzr_[i]; - value_type *nz = this->nz_[i], d = Sp_x; - while (ind != ind_end) { - size_type j = *ind++; - f(d, (*nz++ - x[j])/stddev_[j]); - d -= this->nzb_[j]; - } - if (d <= (value_type) 0) - d = (value_type) 0; - b[i] = d; + size_type *ind = this->ind_[i], *ind_end = ind + this->nnzr_[i]; + value_type *nz = this->nz_[i]; + while (ind != ind_end) { + size_type idx = *ind++; + value_type val = *nz++; + e[idx] += val; + e2[idx] += val * val; + } } - partial_sort_2nd(k, b, nn, std::less()); - } + nupic::Sqrt sf; - //-------------------------------------------------------------------------------- - // RBF - //-------------------------------------------------------------------------------- - template - inline void rbf(value_type p, value_type k, - InputIterator in_begin, OutputIterator out_begin) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::rbf(): " - << "No vector stored yet"; - - NTA_ASSERT(p >= (value_type)0.0) - << "NearestNeighbor::rbf():" - << "Invalid value for parameter p: " << p - << " - Only positive values (p >= 0) are supported"; - } // End pre-conditions - - LpDist(p, in_begin, out_begin, false); - - range_exp(k, out_begin, out_begin + this->nRows()); + for (size_type j = 0; j != ncols; ++j) + stddev_[j] = sf((e2[j] - e[j] * e[j] / nrows) / (nrows - 1)); } - //-------------------------------------------------------------------------------- - // Proj nearest - //-------------------------------------------------------------------------------- - private: - template - inline void - proj_all_rows_dist_(InputIterator x, OutputIterator y, F f, - bool take_root =false) const - { - const size_type nrows = this->nRows(); - OutputIterator y_begin = y, y_end = y_begin + nrows; - - for (size_type row = 0; row != nrows; ++row, ++y) { - size_type *ind = this->ind_[row]; - size_type *ind_end = ind + this->nNonZerosOnRow(row); - value_type *nz = this->nz_[row], val = 0; - for (; ind != ind_end; ++ind, ++nz) - f(val, *nz - *(x + *ind)); - *y = val; - } + Lp f(p); + value_type Sp_x = 0; + for (size_type j = 0; j != ncols; ++j) + this->nzb_[j] = f(Sp_x, x[j] / stddev_[j]); + + std::vector b(nrows); - if (take_root) { - for (y 
= y_begin; y != y_end; ++y) - *y = f.root(*y); + for (size_type i = 0; i != nrows; ++i) { + size_type *ind = this->ind_[i], *ind_end = ind + this->nnzr_[i]; + value_type *nz = this->nz_[i], d = Sp_x; + while (ind != ind_end) { + size_type j = *ind++; + f(d, (*nz++ - x[j]) / stddev_[j]); + d -= this->nzb_[j]; } + if (d <= (value_type)0) + d = (value_type)0; + b[i] = d; } - //-------------------------------------------------------------------------------- - public: - template - inline void - projLpDist(value_type p, InputIterator x, OutputIterator y, - bool take_root =false) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::projLpDist(): " - << "No vector stored yet"; - - NTA_ASSERT(p >= (value_type)0.0) - << "NearestNeighbor::projLpDist():" - << "Invalid value for parameter p: " << p - << " - Only positive values (p >= 0) are supported"; - } // End pre-conditions - - if (p == (value_type) 0.0) { - proj_all_rows_dist_(x, y, Lp0(), take_root); - - } else if (p == (value_type) 1.0) { - proj_all_rows_dist_(x, y, Lp1(), take_root); - - } else if (p == (value_type) 2.0) { - proj_all_rows_dist_(x, y, Lp2(), take_root); - - } else { - proj_all_rows_dist_(x, y, Lp(p), take_root); - } - } - - //-------------------------------------------------------------------------------- - /** - * Finds the k-nearest neighbors to x, ignoring the zeros of each vector - * stored in this matrix. - */ - template - inline void - projLpNearest(value_type p, InputIterator x, OutputIterator nn, - size_type k =1, bool take_root =false) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::projLpNearest(): " - << "No vector stored yet"; - - NTA_ASSERT(p >= (value_type)0.0) - << "NearestNeighbor::projLpNearest():" - << "Invalid value for parameter p: " << p - << " - Only positive values (p >= 0) are supported"; - - NTA_ASSERT(k >= 1) - << "NearestNeighbor::projLpNearest():" - << "Invalid number of nearest rows: " << k - << " - Should be >= 1, default is 1"; - } // End pre-conditions - - std::vector b(this->nRows()); - projLpDist(p, x, b.begin(), take_root); - partial_sort_2nd(k, b, nn, std::less()); - } + partial_sort_2nd(k, b, nn, std::less()); + } - //-------------------------------------------------------------------------------- - template - inline void - projLpNearest(value_type p, InputIterator1 ind, InputIterator1 ind_end, - InputIterator2 nz, OutputIterator nn, - size_type k =1, bool take_root =false) const - { - std::vector x(this->nCols()); - to_dense(ind, ind_end, nz, nz + (ind_end - ind), x.begin(), x.end()); - projLpNearest(p, x.begin(), nn, k, take_root); - } + //-------------------------------------------------------------------------------- + // RBF + //-------------------------------------------------------------------------------- + template + inline void rbf(value_type p, value_type k, InputIterator in_begin, + OutputIterator out_begin) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::rbf(): " + << "No vector stored yet"; - //-------------------------------------------------------------------------------- - template - inline void projRbf(value_type p, value_type k, - InputIterator in_begin, OutputIterator out_begin) const - { - { // Pre-conditions - NTA_ASSERT(this->nRows() > 0) - << "NearestNeighbor::projRbf(): " - << "No vector stored yet"; + NTA_ASSERT(p >= (value_type)0.0) + << "NearestNeighbor::rbf():" + << "Invalid value for parameter p: " << p + << " - Only positive values (p >= 0) are supported"; 
+ } // End pre-conditions - NTA_ASSERT(p >= (value_type)0.0) - << "NearestNeighbor::projRbf():" - << "Invalid value for parameter p: " << p - << " - Only positive values (p >= 0) are supported"; - } // End pre-conditions + LpDist(p, in_begin, out_begin, false); - projLpDist(p, in_begin, out_begin, false); + range_exp(k, out_begin, out_begin + this->nRows()); + } - range_exp(k, out_begin, out_begin + this->nRows()); + //-------------------------------------------------------------------------------- + // Proj nearest + //-------------------------------------------------------------------------------- +private: + template + inline void proj_all_rows_dist_(InputIterator x, OutputIterator y, F f, + bool take_root = false) const { + const size_type nrows = this->nRows(); + OutputIterator y_begin = y, y_end = y_begin + nrows; + + for (size_type row = 0; row != nrows; ++row, ++y) { + size_type *ind = this->ind_[row]; + size_type *ind_end = ind + this->nNonZerosOnRow(row); + value_type *nz = this->nz_[row], val = 0; + for (; ind != ind_end; ++ind, ++nz) + f(val, *nz - *(x + *ind)); + *y = val; } - }; // end class NearestNeighbor + if (take_root) { + for (y = y_begin; y != y_end; ++y) + *y = f.root(*y); + } + } //-------------------------------------------------------------------------------- +public: + template + inline void projLpDist(value_type p, InputIterator x, OutputIterator y, + bool take_root = false) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::projLpDist(): " + << "No vector stored yet"; + + NTA_ASSERT(p >= (value_type)0.0) + << "NearestNeighbor::projLpDist():" + << "Invalid value for parameter p: " << p + << " - Only positive values (p >= 0) are supported"; + } // End pre-conditions + + if (p == (value_type)0.0) { + proj_all_rows_dist_(x, y, Lp0(), take_root); + + } else if (p == (value_type)1.0) { + proj_all_rows_dist_(x, y, Lp1(), take_root); + + } else if (p == (value_type)2.0) { + proj_all_rows_dist_(x, y, Lp2(), take_root); + + } else { + proj_all_rows_dist_(x, y, Lp(p), take_root); + } + } -} // end namespace nupic + //-------------------------------------------------------------------------------- + /** + * Finds the k-nearest neighbors to x, ignoring the zeros of each vector + * stored in this matrix. 
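The proj* family above restricts each comparison to the columns where the stored row has a non-zero, so a row's zero columns contribute nothing to the distance; the ordinary Lp distance, by contrast, also adds |x_j|^p wherever the row is zero. A small standalone contrast between the two sums:

#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

typedef std::vector<std::pair<std::size_t, double> > SparseRow;

// Ordinary Lp power sum: every column participates, including columns where
// the stored row is zero (there |0 - x_j|^p = |x_j|^p is added).
static double lpPow(const SparseRow &row, const std::vector<double> &x, double p) {
  std::vector<double> dense(x.size(), 0.0);
  for (std::size_t k = 0; k != row.size(); ++k)
    dense[row[k].first] = row[k].second;
  double s = 0;
  for (std::size_t j = 0; j != x.size(); ++j)
    s += std::pow(std::fabs(dense[j] - x[j]), p);
  return s;
}

// "Projected" Lp power sum: only the stored non-zeros of the row are compared.
static double projLpPow(const SparseRow &row, const std::vector<double> &x, double p) {
  double s = 0;
  for (std::size_t k = 0; k != row.size(); ++k)
    s += std::pow(std::fabs(row[k].second - x[row[k].first]), p);
  return s;
}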
+ */ + template + inline void projLpNearest(value_type p, InputIterator x, OutputIterator nn, + size_type k = 1, bool take_root = false) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::projLpNearest(): " + << "No vector stored yet"; + + NTA_ASSERT(p >= (value_type)0.0) + << "NearestNeighbor::projLpNearest():" + << "Invalid value for parameter p: " << p + << " - Only positive values (p >= 0) are supported"; + + NTA_ASSERT(k >= 1) << "NearestNeighbor::projLpNearest():" + << "Invalid number of nearest rows: " << k + << " - Should be >= 1, default is 1"; + } // End pre-conditions + + std::vector b(this->nRows()); + projLpDist(p, x, b.begin(), take_root); + partial_sort_2nd(k, b, nn, std::less()); + } -#endif // NTA_NEAREST_NEIGHBOR_HPP + //-------------------------------------------------------------------------------- + template + inline void projLpNearest(value_type p, InputIterator1 ind, + InputIterator1 ind_end, InputIterator2 nz, + OutputIterator nn, size_type k = 1, + bool take_root = false) const { + std::vector x(this->nCols()); + to_dense(ind, ind_end, nz, nz + (ind_end - ind), x.begin(), x.end()); + projLpNearest(p, x.begin(), nn, k, take_root); + } + //-------------------------------------------------------------------------------- + template + inline void projRbf(value_type p, value_type k, InputIterator in_begin, + OutputIterator out_begin) const { + { // Pre-conditions + NTA_ASSERT(this->nRows() > 0) << "NearestNeighbor::projRbf(): " + << "No vector stored yet"; + + NTA_ASSERT(p >= (value_type)0.0) + << "NearestNeighbor::projRbf():" + << "Invalid value for parameter p: " << p + << " - Only positive values (p >= 0) are supported"; + } // End pre-conditions + + projLpDist(p, in_begin, out_begin, false); + + range_exp(k, out_begin, out_begin + this->nRows()); + } +}; // end class NearestNeighbor +//-------------------------------------------------------------------------------- + +} // end namespace nupic + +#endif // NTA_NEAREST_NEIGHBOR_HPP diff --git a/src/nupic/math/Rotation.hpp b/src/nupic/math/Rotation.hpp index 7570836ae1..e5827c1fc7 100644 --- a/src/nupic/math/Rotation.hpp +++ b/src/nupic/math/Rotation.hpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Declarations for 2D matrix rotation by 45 degrees. */ @@ -31,52 +31,48 @@ * Used in GaborFilter */ -#define cos45 0.70710678118654746f // cos(pi/4) = 1/sqrt(2) +#define cos45 0.70710678118654746f // cos(pi/4) = 1/sqrt(2) -template -struct Rotation45 -{ +template struct Rotation45 { typedef size_t size_type; typedef T value_type; int srow_; int scol_; size_t offset_; - - inline T round(T x) { - return floor(x + 0.5); - } - /** + inline T round(T x) { return floor(x + 0.5); } + + /** * Rotate counter-clockwise by 45 degrees. * Fill in pixels in the larger, rotated version of the image. 
*/ - inline void rotate(T* original, T* rotated, size_t nrows, size_t ncols, - size_t z) - { - offset_ = size_t(T(ncols) * cos45); // Vertical offset + inline void rotate(T *original, T *rotated, size_t nrows, size_t ncols, + size_t z) { + offset_ = size_t(T(ncols) * cos45); // Vertical offset for (int j = -1 * offset_; j != int(z - offset_); j++) { for (int i = 0; i != int(z); i++) { // Compute the nearest source pixel for this destination pixel // Multiply the destination pixel by the rotation matrix srow_ = int(round(cos45 * T(j) + cos45 * T(i))); scol_ = int(round(-1 * cos45 * T(j) + cos45 * T(i))); - if (0 <= srow_ && srow_ < int(nrows) && 0 <= scol_ && scol_ < int(ncols)) { + if (0 <= srow_ && srow_ < int(nrows) && 0 <= scol_ && + scol_ < int(ncols)) { // Copy the source pixel to the destination pixel - rotated[size_t(j + offset_) * z + i] = original[srow_ * ncols + scol_]; + rotated[size_t(j + offset_) * z + i] = + original[srow_ * ncols + scol_]; } } } } - /** + /** * Rotate clockwise by 45 degrees. * Start with the larger, rotated image, and fill in the smaller image * of the original size. */ - inline void unrotate(T* unrotated, T* rotated, size_t nrows, size_t ncols, - size_t z) - { - offset_ = size_t(T(ncols) * cos45); // Vertical offset + inline void unrotate(T *unrotated, T *rotated, size_t nrows, size_t ncols, + size_t z) { + offset_ = size_t(T(ncols) * cos45); // Vertical offset for (size_t j = 0; j != nrows; j++) { for (size_t i = 0; i != ncols; i++) { // Compute the nearest source pixel for this destination pixel @@ -91,7 +87,6 @@ struct Rotation45 } } } - }; -#endif //NTA_ROTATION_HPP +#endif // NTA_ROTATION_HPP diff --git a/src/nupic/math/SegmentMatrixAdapter.hpp b/src/nupic/math/SegmentMatrixAdapter.hpp index ee420379c6..5174b848aa 100644 --- a/src/nupic/math/SegmentMatrixAdapter.hpp +++ b/src/nupic/math/SegmentMatrixAdapter.hpp @@ -28,435 +28,382 @@ #define NTA_SEGMENT_MATRIX_ADAPTER_HPP #include -#include #include +#include namespace nupic { +/** + * A data structure that stores dendrite segments as rows in a matrix. + * The matrix itself is part of this class's public API. This class stores the + * segments for each cell, and it can get the cell for each segment. + * + * This class is focused on Python consumers. C++ consumers could easily + * accomplish all of this directly with a matrix class, but Python consumers + * need a fast way of doing segment reads and writes in batches. This class + * makes it possible to add rows in batch, maintaining mappings between cells + * and segments, and providing batch lookups on those mappings. + */ +template class SegmentMatrixAdapter { +public: + typedef typename Matrix::size_type size_type; + +public: + SegmentMatrixAdapter(size_type nCells, size_type nCols) + : matrix(0, nCols), segmentsForCell_(nCells) {} + + /** + * Get the number of cells. + */ + size_type nCells() const { return segmentsForCell_.size(); } + + /** + * Get the number of segments. + */ + size_type nSegments() const { + return cellForSegment_.size() - destroyedSegments_.size(); + } + /** - * A data structure that stores dendrite segments as rows in a matrix. - * The matrix itself is part of this class's public API. This class stores the - * segments for each cell, and it can get the cell for each segment. + * Create a segment. * - * This class is focused on Python consumers. C++ consumers could easily - * accomplish all of this directly with a matrix class, but Python consumers - * need a fast way of doing segment reads and writes in batches. 
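// A minimal usage sketch of the batch API described in this comment, shown
// for illustration only. The choice of nupic::SparseMatrix with its default
// template parameters as the Matrix argument, and the include paths, are
// assumptions; any matrix type with the interface used below should work.
#include <cassert>
#include <vector>
#include <nupic/math/SegmentMatrixAdapter.hpp>
#include <nupic/math/SparseMatrix.hpp>

inline void segmentAdapterSketch() {
  using Matrix = nupic::SparseMatrix<>; // assumed default template arguments
  using Size = Matrix::size_type;

  nupic::SegmentMatrixAdapter<Matrix> connections(/*nCells=*/2048,
                                                  /*nCols=*/2048);

  std::vector<Size> cells = {0, 1, 1, 7}; // each listed cell gets one segment
  std::vector<Size> segments(cells.size());
  connections.createSegments(cells.begin(), cells.end(), segments.begin());

  std::vector<Size> owners(cells.size());
  connections.mapSegmentsToCells(segments.begin(), segments.end(),
                                 owners.begin());
  assert(owners == cells); // every new segment maps back to its cell
}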
This class - * makes it possible to add rows in batch, maintaining mappings between cells - * and segments, and providing batch lookups on those mappings. + * @param cell + * The cell that gets a new segment */ - template - class SegmentMatrixAdapter { - public: - typedef typename Matrix::size_type size_type; - - public: - SegmentMatrixAdapter(size_type nCells, size_type nCols) - : matrix(0, nCols), - segmentsForCell_(nCells) - { + size_type createSegment(size_type cell) { + assert_valid_cell_(cell, "createSegment"); + + if (destroyedSegments_.size() > 0) { + const size_type segment = destroyedSegments_.back(); + destroyedSegments_.pop_back(); + segmentsForCell_[cell].push_back(segment); + cellForSegment_[segment] = cell; + return segment; + } else { + const size_type segment = matrix.nRows(); + matrix.resize(matrix.nRows() + 1, matrix.nCols()); + segmentsForCell_[cell].push_back(segment); + cellForSegment_.push_back(cell); + return segment; } + } - /** - * Get the number of cells. - */ - size_type nCells() const - { - return segmentsForCell_.size(); - } + /** + * Create one segment on each of the specified cells. + * + * @param cells + * The cells that each get a new segment + * + * @param segments + * An output array with the same size as 'cells' + */ + template + void createSegments(InputIterator cells_begin, InputIterator cells_end, + OutputIterator segments_begin) { + assert_valid_cell_range_(cells_begin, cells_end, "createSegments"); + + InputIterator cell = cells_begin; + OutputIterator out = segments_begin; + + const size_type reclaimCount = std::min( + destroyedSegments_.size(), (size_t)std::distance(cell, cells_end)); + if (reclaimCount > 0) { + for (auto segment = destroyedSegments_.end() - reclaimCount; + segment != destroyedSegments_.end(); ++cell, ++out, ++segment) { + segmentsForCell_[*cell].push_back(*segment); + cellForSegment_[*segment] = *cell; + *out = *segment; + } - /** - * Get the number of segments. - */ - size_type nSegments() const - { - return cellForSegment_.size() - destroyedSegments_.size(); + destroyedSegments_.resize(destroyedSegments_.size() - reclaimCount); } - /** - * Create a segment. - * - * @param cell - * The cell that gets a new segment - */ - size_type createSegment(size_type cell) - { - assert_valid_cell_(cell, "createSegment"); - - if (destroyedSegments_.size() > 0) - { - const size_type segment = destroyedSegments_.back(); - destroyedSegments_.pop_back(); - segmentsForCell_[cell].push_back(segment); - cellForSegment_[segment] = cell; - return segment; - } - else - { - const size_type segment = matrix.nRows(); - matrix.resize(matrix.nRows() + 1, matrix.nCols()); - segmentsForCell_[cell].push_back(segment); - cellForSegment_.push_back(cell); - return segment; + const size_type newCount = std::distance(cell, cells_end); + if (newCount > 0) { + const size_type firstNewRow = matrix.nRows(); + matrix.resize(matrix.nRows() + newCount, matrix.nCols()); + cellForSegment_.reserve(cellForSegment_.size() + newCount); + + for (size_type segment = firstNewRow; cell != cells_end; + ++cell, ++out, ++segment) { + segmentsForCell_[*cell].push_back(segment); + cellForSegment_.push_back(*cell); + *out = segment; } } + } - /** - * Create one segment on each of the specified cells. 
- * - * @param cells - * The cells that each get a new segment - * - * @param segments - * An output array with the same size as 'cells' - */ - template - void createSegments(InputIterator cells_begin, InputIterator cells_end, - OutputIterator segments_begin) - { - assert_valid_cell_range_(cells_begin, cells_end, "createSegments"); - - InputIterator cell = cells_begin; - OutputIterator out = segments_begin; - - const size_type reclaimCount = - std::min(destroyedSegments_.size(), - (size_t)std::distance(cell, cells_end)); - if (reclaimCount > 0) - { - for (auto segment = destroyedSegments_.end() - reclaimCount; - segment != destroyedSegments_.end(); - ++cell, ++out, ++segment) - { - segmentsForCell_[*cell].push_back(*segment); - cellForSegment_[*segment] = *cell; - *out = *segment; - } - - destroyedSegments_.resize(destroyedSegments_.size() - reclaimCount); - } + /** + * Destroy a segment. Remove it from its cell and remove all of its synapses + * in the Matrix. + * + * This doesn't remove the segment's row from the Matrix, so the other + * segments' row numbers are unaffected. + * + * @param segment + * The segment to destroy + */ + void destroySegment(size_type segment) { + assert_valid_segment_(segment, "destroySegment"); - const size_type newCount = std::distance(cell, cells_end); - if (newCount > 0) - { - const size_type firstNewRow = matrix.nRows(); - matrix.resize(matrix.nRows() + newCount, matrix.nCols()); - cellForSegment_.reserve(cellForSegment_.size() + newCount); - - for (size_type segment = firstNewRow; - cell != cells_end; - ++cell, ++out, ++segment) - { - segmentsForCell_[*cell].push_back(segment); - cellForSegment_.push_back(*cell); - *out = segment; - } - } - } + matrix.setRowToZero(segment); - /** - * Destroy a segment. Remove it from its cell and remove all of its synapses - * in the Matrix. - * - * This doesn't remove the segment's row from the Matrix, so the other - * segments' row numbers are unaffected. - * - * @param segment - * The segment to destroy - */ - void destroySegment(size_type segment) - { - assert_valid_segment_(segment, "destroySegment"); - - matrix.setRowToZero(segment); - - std::vector& ownerList = + std::vector &ownerList = segmentsForCell_[cellForSegment_[segment]]; - ownerList.erase(std::find(ownerList.begin(), ownerList.end(), - segment)); + ownerList.erase(std::find(ownerList.begin(), ownerList.end(), segment)); - cellForSegment_[segment] = (size_type)-1; + cellForSegment_[segment] = (size_type)-1; - destroyedSegments_.push_back(segment); - } + destroyedSegments_.push_back(segment); + } - /** - * Destroy multiple segments. - * - * @param segments - * The segments to destroy - */ - template - void destroySegments(InputIterator segments_begin, InputIterator segments_end) - { - assert_valid_segment_range_(segments_begin, segments_end, "destroySegments"); - - destroyedSegments_.reserve(destroyedSegments_.size() + - std::distance(segments_begin, segments_end)); - - for (InputIterator segment = segments_begin; - segment != segments_end; - ++segment) - { - destroySegment(*segment); - } + /** + * Destroy multiple segments. 
+ * + * @param segments + * The segments to destroy + */ + template + void destroySegments(InputIterator segments_begin, + InputIterator segments_end) { + assert_valid_segment_range_(segments_begin, segments_end, + "destroySegments"); + + destroyedSegments_.reserve(destroyedSegments_.size() + + std::distance(segments_begin, segments_end)); + + for (InputIterator segment = segments_begin; segment != segments_end; + ++segment) { + destroySegment(*segment); } + } - /** - * Get the number of segments on each of the provided cells. - * - * @param cells - * The cells to check - * - * @param counts - * Output array with the same length as 'cells' - */ - template - void getSegmentCounts(InputIterator cells_begin, InputIterator cells_end, - OutputIterator counts_begin) const - { - assert_valid_cell_range_(cells_begin, cells_end, "getSegmentCounts"); - - OutputIterator out = counts_begin; - - for (InputIterator cell = cells_begin; - cell != cells_end; - ++cell, ++out) - { - *out = segmentsForCell_[*cell].size(); - } - } + /** + * Get the number of segments on each of the provided cells. + * + * @param cells + * The cells to check + * + * @param counts + * Output array with the same length as 'cells' + */ + template + void getSegmentCounts(InputIterator cells_begin, InputIterator cells_end, + OutputIterator counts_begin) const { + assert_valid_cell_range_(cells_begin, cells_end, "getSegmentCounts"); - /** - * Get the segments for a cell. - * - * @param cell - * The cell - */ - const std::vector& getSegmentsForCell(size_type cell) const - { - assert_valid_cell_(cell, "getSegmentsForCell"); - - return segmentsForCell_[cell]; - } + OutputIterator out = counts_begin; - /** - * Sort an array of segments by cell in increasing order. - * - * @param segments - * The segment array. It's sorted in-place. - */ - template - void sortSegmentsByCell(InputIterator segments_begin, - InputIterator segments_end) const - { - assert_valid_segment_range_(segments_begin, segments_end, - "sortSegmentsByCell"); - - std::sort(segments_begin, segments_end, - [&](size_type a, size_type b) - { - return cellForSegment_[a] < cellForSegment_[b]; - }); + for (InputIterator cell = cells_begin; cell != cells_end; ++cell, ++out) { + *out = segmentsForCell_[*cell].size(); } + } - /** - * Return the subset of segments that are on the provided cells. - * - * @param segments - * The segments to filter. Must be sorted by cell. - * - * @param cells - * The cells whose segments we want to keep. Must be sorted. - */ - template - std::vector filterSegmentsByCell( - InputIterator1 segments_begin, InputIterator1 segments_end, - InputIterator2 cells_begin, InputIterator2 cells_end) const - { - assert_valid_sorted_segment_range_(segments_begin, segments_end, - "filterSegmentsByCell"); - assert_valid_sorted_cell_range_(cells_begin, cells_end, - "filterSegmentsByCell"); - - std::vector filteredSegments; - - InputIterator1 segment = segments_begin; - InputIterator2 cell = cells_begin; - - bool finished = (segment == segments_end) || (cell == cells_end); - - while (!finished) - { - while (cellForSegment_[*segment] < *cell) - { - finished = (++segment == segments_end); - if (finished) break; - } - - if (finished) break; - - if (cellForSegment_[*segment] == *cell) - { - filteredSegments.push_back(*segment); - finished = (++segment == segments_end); - if (finished) break; - } - - while (*cell < cellForSegment_[*segment]) - { - finished = (++cell == cells_end); - if (finished) break; - } + /** + * Get the segments for a cell. 
+ * + * @param cell + * The cell + */ + const std::vector &getSegmentsForCell(size_type cell) const { + assert_valid_cell_(cell, "getSegmentsForCell"); + + return segmentsForCell_[cell]; + } + + /** + * Sort an array of segments by cell in increasing order. + * + * @param segments + * The segment array. It's sorted in-place. + */ + template + void sortSegmentsByCell(InputIterator segments_begin, + InputIterator segments_end) const { + assert_valid_segment_range_(segments_begin, segments_end, + "sortSegmentsByCell"); + + std::sort(segments_begin, segments_end, [&](size_type a, size_type b) { + return cellForSegment_[a] < cellForSegment_[b]; + }); + } + + /** + * Return the subset of segments that are on the provided cells. + * + * @param segments + * The segments to filter. Must be sorted by cell. + * + * @param cells + * The cells whose segments we want to keep. Must be sorted. + */ + template + std::vector filterSegmentsByCell(InputIterator1 segments_begin, + InputIterator1 segments_end, + InputIterator2 cells_begin, + InputIterator2 cells_end) const { + assert_valid_sorted_segment_range_(segments_begin, segments_end, + "filterSegmentsByCell"); + assert_valid_sorted_cell_range_(cells_begin, cells_end, + "filterSegmentsByCell"); + + std::vector filteredSegments; + + InputIterator1 segment = segments_begin; + InputIterator2 cell = cells_begin; + + bool finished = (segment == segments_end) || (cell == cells_end); + + while (!finished) { + while (cellForSegment_[*segment] < *cell) { + finished = (++segment == segments_end); + if (finished) + break; } - return filteredSegments; - } + if (finished) + break; - /** - * Get the cell for each provided segment. - * - * @param segments - * The segments to query - * - * @param cells - * Output array with the same length as 'segments' - */ - template - void mapSegmentsToCells( - InputIterator segments_begin, InputIterator segments_end, - OutputIterator cells_begin) const - { - assert_valid_segment_range_(segments_begin, segments_end, - "mapSegmentsToCells"); - - OutputIterator out = cells_begin; - - for (InputIterator segment = segments_begin; - segment != segments_end; - ++segment, ++out) - { - *out = cellForSegment_[*segment]; + if (cellForSegment_[*segment] == *cell) { + filteredSegments.push_back(*segment); + finished = (++segment == segments_end); + if (finished) + break; + } + + while (*cell < cellForSegment_[*segment]) { + finished = (++cell == cells_end); + if (finished) + break; } } - public: + return filteredSegments; + } - /** - * The underlying Matrix. Each row is a segment. - * - * Don't add or remove rows directly. Use createSegment / destroySegment. - */ - Matrix matrix; + /** + * Get the cell for each provided segment. + * + * @param segments + * The segments to query + * + * @param cells + * Output array with the same length as 'segments' + */ + template + void mapSegmentsToCells(InputIterator segments_begin, + InputIterator segments_end, + OutputIterator cells_begin) const { + assert_valid_segment_range_(segments_begin, segments_end, + "mapSegmentsToCells"); + + OutputIterator out = cells_begin; + + for (InputIterator segment = segments_begin; segment != segments_end; + ++segment, ++out) { + *out = cellForSegment_[*segment]; + } + } - private: +public: + /** + * The underlying Matrix. Each row is a segment. + * + * Don't add or remove rows directly. Use createSegment / destroySegment. 
+ */ + Matrix matrix; - void assert_valid_segment_(size_type segment, const char *where) const - { +private: + void assert_valid_segment_(size_type segment, const char *where) const { #ifdef NTA_ASSERTIONS_ON - NTA_ASSERT(segment < matrix.nRows()) + NTA_ASSERT(segment < matrix.nRows()) << "SegmentMatrixAdapter " << where << ": Invalid segment: " << segment << " - Should be < " << matrix.nRows(); - NTA_ASSERT(cellForSegment_[segment] != (size_type)-1) + NTA_ASSERT(cellForSegment_[segment] != (size_type)-1) << "SegmentMatrixAdapter " << where << ": Invalid segment: " << segment << " -- This segment has been destroyed."; #endif - } + } - template - void assert_valid_segment_range_(Iterator segments_begin, - Iterator segments_end, - const char *where) const - { + template + void assert_valid_segment_range_(Iterator segments_begin, + Iterator segments_end, + const char *where) const { #ifdef NTA_ASSERTIONS_ON - for (Iterator segment = segments_begin; - segment != segments_end; - ++segment) - { - assert_valid_segment_(*segment, where); - } -#endif + for (Iterator segment = segments_begin; segment != segments_end; + ++segment) { + assert_valid_segment_(*segment, where); } +#endif + } - template - void assert_valid_sorted_segment_range_(Iterator segments_begin, - Iterator segments_end, - const char *where) const - { + template + void assert_valid_sorted_segment_range_(Iterator segments_begin, + Iterator segments_end, + const char *where) const { #ifdef NTA_ASSERTIONS_ON - for (Iterator segment = segments_begin; - segment != segments_end; - ++segment) - { - assert_valid_segment_(*segment, where); - - if (segment != segments_begin) - { - NTA_ASSERT(cellForSegment_[*(segment - 1)] <= - cellForSegment_[*segment]) + for (Iterator segment = segments_begin; segment != segments_end; + ++segment) { + assert_valid_segment_(*segment, where); + + if (segment != segments_begin) { + NTA_ASSERT(cellForSegment_[*(segment - 1)] <= cellForSegment_[*segment]) << "SegmentMatrixAdapter " << where << ": Segments must be sorted " << "by cell. 
Found cell " << cellForSegment_[*(segment - 1)] << " before cell " << cellForSegment_[*segment]; - } } -#endif } +#endif + } - void assert_valid_cell_(size_type cell, const char *where) const { + void assert_valid_cell_(size_type cell, const char *where) const { #ifdef NTA_ASSERTIONS_ON - NTA_ASSERT(cell < nCells()) + NTA_ASSERT(cell < nCells()) << "SegmentMatrixAdapter " << where << ": Invalid cell: " << cell << " - Should be < " << nCells(); #endif - } + } - template - void assert_valid_cell_range_(Iterator cells_begin, - Iterator cells_end, - const char *where) const - { + template + void assert_valid_cell_range_(Iterator cells_begin, Iterator cells_end, + const char *where) const { #ifdef NTA_ASSERTIONS_ON - for (Iterator cell = cells_begin; cell != cells_end; ++cell) - { - assert_valid_cell_(*cell, where); - } -#endif + for (Iterator cell = cells_begin; cell != cells_end; ++cell) { + assert_valid_cell_(*cell, where); } +#endif + } - template - void assert_valid_sorted_cell_range_(Iterator cells_begin, - Iterator cells_end, - const char *where) const - { + template + void assert_valid_sorted_cell_range_(Iterator cells_begin, Iterator cells_end, + const char *where) const { #ifdef NTA_ASSERTIONS_ON - for (Iterator cell = cells_begin; cell != cells_end; ++cell) - { - assert_valid_cell_(*cell, where); + for (Iterator cell = cells_begin; cell != cells_end; ++cell) { + assert_valid_cell_(*cell, where); - if (cell != cells_begin) - { - NTA_ASSERT(*(cell - 1) <= *cell) + if (cell != cells_begin) { + NTA_ASSERT(*(cell - 1) <= *cell) << "SegmentMatrixAdapter " << where << ": Cells must be sorted. " << "Found cell " << *(cell - 1) << " before cell " << *cell; - } } -#endif } +#endif + } - private: - - // One-to-one mapping: segment -> cell - std::vector cellForSegment_; +private: + // One-to-one mapping: segment -> cell + std::vector cellForSegment_; - // One-to-many mapping: cell -> segments - std::vector > segmentsForCell_; + // One-to-many mapping: cell -> segments + std::vector> segmentsForCell_; - // Rather that deleting rows from the matrix, keep a list of rows that can - // be reused. Otherwise the segment numbers in the 'cellForSegment' and - // 'segmentsForCell' vectors would be invalidated every time a segment gets - // destroyed. - std::vector destroyedSegments_; - }; + // Rather that deleting rows from the matrix, keep a list of rows that can + // be reused. Otherwise the segment numbers in the 'cellForSegment' and + // 'segmentsForCell' vectors would be invalidated every time a segment gets + // destroyed. + std::vector destroyedSegments_; +}; } // end namespace nupic diff --git a/src/nupic/math/Set.hpp b/src/nupic/math/Set.hpp index 057261f508..59b1048766 100644 --- a/src/nupic/math/Set.hpp +++ b/src/nupic/math/Set.hpp @@ -22,8 +22,9 @@ /** @file * Our own set object, to beat Python, at least when computing intersection. - * TODO: this file is currently superceded by built-in python set(), keeping as a reference, - * and we should test which is faster for intersection workload, which is heavily used. + * TODO: this file is currently superceded by built-in python set(), keeping as + * a reference, and we should test which is faster for intersection workload, + * which is heavily used. */ #ifndef NTA_MATH_SET_HPP @@ -33,107 +34,97 @@ namespace nupic { - //-------------------------------------------------------------------------------- - // SET - // - // Represents the set with an indicator function stored in a bit array. - // - // T is an unsigned integral type. 
- // T_byte has the size of a byte. - // - // Test from Python: - // Mac PowerBook 2.8 GHz Core 2 Duo, 10.6.3, -O3 -DNDEBUG, gcc 4.2.1 (Apple 5659) - // m = 50000, n1 = 40, n2 = 10000: 0.00274658203125 0.00162267684937 1.69262415516 - // m = 50000, n1 = 80, n2 = 10000: 0.00458002090454 0.00179862976074 2.54639448568 - // m = 50000, n1 = 200, n2 = 10000: 0.0124213695526 0.00241708755493 5.13898204774 - // m = 50000, n1 = 500, n2 = 10000: 0.0339875221252 0.00330281257629 10.2904785967 - // m = 50000, n1 = 1000, n2 = 10000: 0.0573344230652 0.00443959236145 12.9143440202 - // m = 50000, n1 = 2500, n2 = 10000: 0.155576944351 0.00838160514832 18.5617124164 - // m = 50000, n1 = 5000, n2 = 10000: 0.256726026535 0.0143656730652 17.8707969595 - //-------------------------------------------------------------------------------- - template - class Set - { - private: - T m; // max value of non-zero indices - T n; // number of non-zeros in s - std::vector s; // indicator of the non-zeros - - public: - // For Python binding - inline Set() - {} - - /** - * Constructs from a list of n element indices ss, each element being - * in the interval [0,m[. - */ - inline Set(T _m, T _n, T* ss) - : m(_m), - n(_n), - s(m/8 + (m % 8 == 0 ? 0 : 1)) - { - construct(m, n, ss); - } +//-------------------------------------------------------------------------------- +// SET +// +// Represents the set with an indicator function stored in a bit array. +// +// T is an unsigned integral type. +// T_byte has the size of a byte. +// +// Test from Python: +// Mac PowerBook 2.8 GHz Core 2 Duo, 10.6.3, -O3 -DNDEBUG, gcc 4.2.1 (Apple +// 5659) m = 50000, n1 = 40, n2 = 10000: 0.00274658203125 +// 0.00162267684937 1.69262415516 m = 50000, n1 = 80, n2 = 10000: +// 0.00458002090454 0.00179862976074 2.54639448568 m = 50000, n1 = 200, n2 = +// 10000: 0.0124213695526 0.00241708755493 5.13898204774 m = 50000, n1 = 500, n2 +// = 10000: 0.0339875221252 0.00330281257629 10.2904785967 m = 50000, n1 = 1000, +// n2 = 10000: 0.0573344230652 0.00443959236145 12.9143440202 m = 50000, n1 = +// 2500, n2 = 10000: 0.155576944351 0.00838160514832 18.5617124164 m = 50000, n1 +// = 5000, n2 = 10000: 0.256726026535 0.0143656730652 17.8707969595 +//-------------------------------------------------------------------------------- +template class Set { +private: + T m; // max value of non-zero indices + T n; // number of non-zeros in s + std::vector s; // indicator of the non-zeros - inline void construct(T _m, T _n, T* ss) - { - m = _m; - n = _n; - s.resize(m/8 + (m % 8 == 0 ? 0 : 1)); +public: + // For Python binding + inline Set() {} - for (T i = 0; i != n; ++i) - s[ss[i] / 8] |= 1 << (ss[i] % 8); - } + /** + * Constructs from a list of n element indices ss, each element being + * in the interval [0,m[. + */ + inline Set(T _m, T _n, T *ss) + : m(_m), n(_n), s(m / 8 + (m % 8 == 0 ? 0 : 1)) { + construct(m, n, ss); + } + + inline void construct(T _m, T _n, T *ss) { + m = _m; + n = _n; + s.resize(m / 8 + (m % 8 == 0 ? 0 : 1)); + + for (T i = 0; i != n; ++i) + s[ss[i] / 8] |= 1 << (ss[i] % 8); + } + + inline Set(const Set &o) : s(o.s) {} + + inline Set &operator=(const Set &o) { + s = o.s; + return *this; + } - inline Set(const Set& o) - : s(o.s) - {} + inline T n_elements() const { return n; } + inline T max_index() const { return m; } + inline T n_bytes() const { return s.size(); } - inline Set& operator=(const Set& o) - { - s = o.s; - return *this; + /** + * Computes the intersection between this and another set (n2, s2). 
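// A small sketch of the intersection call documented here, for illustration.
// The template arguments for Set (and their defaults) and the include paths
// are assumptions; the constructor and intersection() used below are the
// members shown in this file's diff.
#include <cassert>
#include <vector>
#include <nupic/math/Set.hpp>
#include <nupic/types/Types.hpp>

inline void setIntersectionSketch() {
  using nupic::UInt32;

  UInt32 big[] = {2, 7, 11, 19, 42};            // element indices, all < m
  nupic::Set<UInt32> s(/*m=*/50, /*n=*/5, big); // bit-array indicator of big

  UInt32 small[] = {7, 8, 42};  // keep the smaller set as the argument (n2 << n)
  std::vector<UInt32> r(3);     // reusable result buffer, size >= n2
  UInt32 k = s.intersection(/*n2=*/3, small, r.data());
  assert(k == 2 && r[0] == 7 && r[1] == 42);
}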
+ * n2 is the number of element in the second s, s2 is a pointer to + * the first element, which needs to be stored contiguously. s2 needs + * to store the indices of the elements: (2,7,11) is the set of those + * 3 elements. + * r is the result set, and is also a list of element indices. + * This method also returns an integer, which is the number of elements + * in the intersection (so that r can be allocated once, and its first + * few positions reused over and over again). + * + * NOTE: for best performance, have n2 << n + */ + inline T intersection(T n2, T *s2, T *r) const { + /* + if (n < n2) { + std::cout << "Calling nupic::Set::intersection " + << "with small set: for best performance, " + << "call with smaller set as argument" + << std::endl; } + */ - inline T n_elements() const { return n; } - inline T max_index() const { return m; } - inline T n_bytes() const { return s.size(); } - - /** - * Computes the intersection between this and another set (n2, s2). - * n2 is the number of element in the second s, s2 is a pointer to - * the first element, which needs to be stored contiguously. s2 needs - * to store the indices of the elements: (2,7,11) is the set of those - * 3 elements. - * r is the result set, and is also a list of element indices. - * This method also returns an integer, which is the number of elements - * in the intersection (so that r can be allocated once, and its first - * few positions reused over and over again). - * - * NOTE: for best performance, have n2 << n - */ - inline T intersection(T n2, T* s2, T* r) const - { - /* - if (n < n2) { - std::cout << "Calling nupic::Set::intersection " - << "with small set: for best performance, " - << "call with smaller set as argument" - << std::endl; - } - */ - - T* rr = r; - for (T* i = s2, *i_end = s2 + n2; i != i_end; ++i) { - *r = *i; - r += (s[*i >> 3] & (1 << (*i % 8))) / (1 << (*i % 8)); - } - return (T) (r - rr); + T *rr = r; + for (T *i = s2, *i_end = s2 + n2; i != i_end; ++i) { + *r = *i; + r += (s[*i >> 3] & (1 << (*i % 8))) / (1 << (*i % 8)); } - }; + return (T)(r - rr); + } +}; } // end namespace nupic //-------------------------------------------------------------------------------- -#endif //NTA_MATH_SET_HPP +#endif // NTA_MATH_SET_HPP diff --git a/src/nupic/math/SparseBinaryMatrix.hpp b/src/nupic/math/SparseBinaryMatrix.hpp index ee41df7107..68e75d240b 100644 --- a/src/nupic/math/SparseBinaryMatrix.hpp +++ b/src/nupic/math/SparseBinaryMatrix.hpp @@ -30,9 +30,9 @@ #include #include +#include #include #include -#include #include #include @@ -511,8 +511,8 @@ class SparseBinaryMatrix : public Serializable { ASSERT_INPUT_ITERATOR(InputIterator1); ASSERT_INPUT_ITERATOR(InputIterator2); - NTA_ASSERT(nz_j_end - nz_j == nz_i_end - nz_i) << where - << "Invalid range"; + NTA_ASSERT(nz_j_end - nz_j == nz_i_end - nz_i) + << where << "Invalid range"; #ifdef NTA_ASSERTIONS_ON @@ -1119,7 +1119,7 @@ class SparseBinaryMatrix : public Serializable { n += sprintf(buffer, "%ld ", (long)ind_[row][j]); } return n; - } + } inline void fromCSR(std::istream &inStream) { const std::string where = "SparseBinaryMatrix::readState: "; @@ -1253,8 +1253,8 @@ class SparseBinaryMatrix : public Serializable { std::string version; inStream >> version; - NTA_CHECK(version == getVersion(true)) << where - << "Unknown format: " << version; + NTA_CHECK(version == getVersion(true)) + << where << "Unknown format: " << version; size_type nrows = 0; inStream >> nrows; diff --git a/src/nupic/math/SparseMatrix.hpp b/src/nupic/math/SparseMatrix.hpp index 
b0fd15eeba..38004ec334 100644 --- a/src/nupic/math/SparseMatrix.hpp +++ b/src/nupic/math/SparseMatrix.hpp @@ -192,8 +192,8 @@ class SparseMatrix : public Serializable { inline void assert_not_zero_value_(const value_type &val, const char *where) const { #ifdef NTA_ASSERTIONS_ON - NTA_ASSERT(!isZero_(val)) << "SparseMatrix " << where - << ": Zero value should be != 0"; + NTA_ASSERT(!isZero_(val)) + << "SparseMatrix " << where << ": Zero value should be != 0"; #endif } @@ -235,10 +235,10 @@ class SparseMatrix : public Serializable { assert_valid_row_(row_begin, where); if (row_begin < row_end) assert_valid_row_(row_end - 1, where); - NTA_ASSERT(row_begin <= row_end) << "SparseMatrix " << where - << ": Invalid row range: [" << row_begin - << ".." << row_end << "): " - << "- Beginning should be <= end of range"; + NTA_ASSERT(row_begin <= row_end) + << "SparseMatrix " << where << ": Invalid row range: [" << row_begin + << ".." << row_end << "): " + << "- Beginning should be <= end of range"; #endif } @@ -248,10 +248,10 @@ class SparseMatrix : public Serializable { assert_valid_col_(col_begin, where); if (col_begin < col_end) assert_valid_col_(col_end - 1, where); - NTA_ASSERT(col_begin <= col_end) << "SparseMatrix " << where - << ": Invalid col range: [" << col_begin - << ".." << col_end << "): " - << "- Beginning should be <= end of range"; + NTA_ASSERT(col_begin <= col_end) + << "SparseMatrix " << where << ": Invalid col range: [" << col_begin + << ".." << col_end << "): " + << "- Beginning should be <= end of range"; #endif } @@ -299,16 +299,16 @@ class SparseMatrix : public Serializable { ASSERT_INPUT_ITERATOR(InputIterator1); - NTA_ASSERT(ind_end - ind_it >= 0) << "SparseMatrix " << where - << ": Invalid iterators"; + NTA_ASSERT(ind_end - ind_it >= 0) + << "SparseMatrix " << where << ": Invalid iterators"; for (size_type j = 0, prev = 0; ind_it != ind_end; ++ind_it, ++j) { size_type index = *ind_it; - NTA_ASSERT(0 <= index && index < m) << "SparseMatrix " << where - << ": Invalid index: " << index - << " - Should be >= 0 and < " << m; + NTA_ASSERT(0 <= index && index < m) + << "SparseMatrix " << where << ": Invalid index: " << index + << " - Should be >= 0 and < " << m; if (j > 0) { NTA_ASSERT(prev < index) @@ -333,20 +333,20 @@ class SparseMatrix : public Serializable { ASSERT_INPUT_ITERATOR(InputIterator1); ASSERT_INPUT_ITERATOR(InputIterator2); - NTA_ASSERT(ind_end - ind_it >= 0) << "SparseMatrix " << where - << ": Invalid iterators"; + NTA_ASSERT(ind_end - ind_it >= 0) + << "SparseMatrix " << where << ": Invalid iterators"; for (size_type j = 0, prev = 0; ind_it != ind_end; ++ind_it, ++nz_it, ++j) { size_type index = *ind_it; - NTA_ASSERT(0 <= index && index < m) << "SparseMatrix " << where - << ": Invalid index: " << index - << " - Should be >= 0 and < " << m; + NTA_ASSERT(0 <= index && index < m) + << "SparseMatrix " << where << ": Invalid index: " << index + << " - Should be >= 0 and < " << m; - NTA_ASSERT(!isZero_(*nz_it)) << "SparseMatrix " << where - << ": Passed zero at index: " << j - << " - Should pass non-zeros only"; + NTA_ASSERT(!isZero_(*nz_it)) + << "SparseMatrix " << where << ": Passed zero at index: " << j + << " - Should pass non-zeros only"; if (j > 0) { NTA_ASSERT(prev < index) @@ -370,10 +370,10 @@ class SparseMatrix : public Serializable { NTA_ASSERT(!isZero_(*nz)) << where << "Near zero value: " << *nz << " at (" << row << ", " << *ind << ") " << "nupic::Epsilon= " << nupic::Epsilon; - NTA_ASSERT(row < nRows()) << where << "Invalid row index: " << row - << " 
nRows= " << nRows(); - NTA_ASSERT(*ind < nCols()) << where << "Invalid col index: " << *ind - << " nCols= " << nCols(); + NTA_ASSERT(row < nRows()) + << where << "Invalid row index: " << row << " nRows= " << nRows(); + NTA_ASSERT(*ind < nCols()) + << where << "Invalid col index: " << *ind << " nCols= " << nCols(); } assert_valid_sorted_index_range_(nCols(), ind_begin_(row), ind_end_(row), @@ -438,8 +438,8 @@ class SparseMatrix : public Serializable { */ inline void allocate_(size_type nrows_max, size_type ncols) { { // Pre-conditions - NTA_ASSERT(0 <= nrows_max) << "SparseMatrix allocate_: Bad nrows_max: " - << nrows_max; + NTA_ASSERT(0 <= nrows_max) + << "SparseMatrix allocate_: Bad nrows_max: " << nrows_max; NTA_ASSERT(0 <= ncols) << "SparseMatrix allocate_: Bad ncols: " << ncols; } // End pre-conditions @@ -929,10 +929,8 @@ class SparseMatrix : public Serializable { * @li Columns must be sorted */ template - inline size_type nZerosInRowOnColumns_( - size_type row, - InputIterator col_begin, InputIterator col_end) - { + inline size_type nZerosInRowOnColumns_(size_type row, InputIterator col_begin, + InputIterator col_end) { { // Pre-conditions ASSERT_INPUT_ITERATOR(InputIterator); assert_valid_row_(row, "nZerosInRowOnColumns_"); @@ -981,12 +979,11 @@ class SparseMatrix : public Serializable { * @li The caller must correctly count the number of zeros in the selection */ template - inline void insertRandomNonZerosIntoColumns_( - size_type row, - InputIterator col_begin, InputIterator col_end, - size_type numToInsert, size_type numZerosAvailable, value_type value, - Random& rng) - { + inline void + insertRandomNonZerosIntoColumns_(size_type row, InputIterator col_begin, + InputIterator col_end, size_type numToInsert, + size_type numZerosAvailable, + value_type value, Random &rng) { { // Pre-conditions ASSERT_INPUT_ITERATOR(InputIterator); NTA_ASSERT(numToInsert <= numZerosAvailable); @@ -1010,10 +1007,12 @@ class SparseMatrix : public Serializable { InputIterator selected_col = col_begin; size_type *prev_it = prev_ind_begin; - size_type nextNonzeroCol = prev_it != prev_ind_end ? - *prev_it : std::numeric_limits::max(); - size_type nextSelectedCol = selected_col != col_end ? - *selected_col : std::numeric_limits::max(); + size_type nextNonzeroCol = prev_it != prev_ind_end + ? *prev_it + : std::numeric_limits::max(); + size_type nextSelectedCol = selected_col != col_end + ? *selected_col + : std::numeric_limits::max(); for (size_type pos = 0; pos < nnzr; ++pos) { while (true) { @@ -1024,8 +1023,9 @@ class SparseMatrix : public Serializable { row_nz[pos] = nz_[row][prev_it - prev_ind_begin]; ++prev_it; - nextNonzeroCol = prev_it != prev_ind_end ? - *prev_it : std::numeric_limits::max(); + nextNonzeroCol = prev_it != prev_ind_end + ? *prev_it + : std::numeric_limits::max(); break; } else if (nextNonzeroCol == nextSelectedCol) { // The next selected column is nonzero. @@ -1035,17 +1035,19 @@ class SparseMatrix : public Serializable { ++prev_it; ++selected_col; - nextNonzeroCol = prev_it != prev_ind_end ? - *prev_it : std::numeric_limits::max(); - nextSelectedCol = selected_col != col_end ? - *selected_col : std::numeric_limits::max(); + nextNonzeroCol = prev_it != prev_ind_end + ? *prev_it + : std::numeric_limits::max(); + nextSelectedCol = selected_col != col_end + ? *selected_col + : std::numeric_limits::max(); break; } else { // The next selected column is a zero. // Maybe insert a nonzero. 
NTA_ASSERT(numRemainingAvailable > 0); - const bool insertNonzero = (rng.getUInt32(numRemainingAvailable) < - numRemainingToChoose); + const bool insertNonzero = + (rng.getUInt32(numRemainingAvailable) < numRemainingToChoose); if (insertNonzero) { row_ind[pos] = *selected_col; row_nz[pos] = value; @@ -1056,8 +1058,9 @@ class SparseMatrix : public Serializable { --numRemainingAvailable; ++selected_col; - nextSelectedCol = selected_col != col_end ? - *selected_col : std::numeric_limits::max(); + nextSelectedCol = selected_col != col_end + ? *selected_col + : std::numeric_limits::max(); if (insertNonzero) { break; @@ -1107,13 +1110,13 @@ class SparseMatrix : public Serializable { nz_mem_(nullptr), ind_(nullptr), nz_(nullptr), indb_(nullptr), nzb_(nullptr), isZero_() { { // Pre-conditions - NTA_CHECK(nrows >= 0) << "SparseMatrix::SparseMatrix(nrows, ncols): " - << "Invalid number of rows: " << nrows - << " - Should be >= 0"; + NTA_CHECK(nrows >= 0) + << "SparseMatrix::SparseMatrix(nrows, ncols): " + << "Invalid number of rows: " << nrows << " - Should be >= 0"; - NTA_CHECK(ncols >= 0) << "SparseMatrix::SparseMatrix(nrows, ncols): " - << "Invalid number of columns: " << ncols - << " - Should be >= 0"; + NTA_CHECK(ncols >= 0) + << "SparseMatrix::SparseMatrix(nrows, ncols): " + << "Invalid number of columns: " << ncols << " - Should be >= 0"; } // End pre-conditions allocate_(nrows, ncols); @@ -1208,17 +1211,8 @@ class SparseMatrix : public Serializable { inline SparseMatrix(const SparseMatrix &other, InputIterator take, InputIterator take_end, int rowCol = 1) // cols - : nrows_(0), - nrows_max_(0), - ncols_(0), - nnzr_(0), - ind_mem_(0), - nz_mem_(0), - ind_(0), - nz_(0), - indb_(0), - nzb_(0), - isZero_() { + : nrows_(0), nrows_max_(0), ncols_(0), nnzr_(0), ind_mem_(0), nz_mem_(0), + ind_(0), nz_(0), indb_(0), nzb_(0), isZero_() { { // Pre-conditions NTA_ASSERT(rowCol == 0 || rowCol == 1) << "SparseMatrix: constructor from set of rows/cols: " @@ -1266,7 +1260,7 @@ class SparseMatrix : public Serializable { /** * Deallocates this instance and initializes its non-zeros only it with the - * non-zeros of specified cols from other. The number of rows and cols is + * non-zeros of specified cols from other. The number of rows and cols is * the same as other. * The vector passed is a binary vector with 1 at the indices of the * cols that need to be copied and 0 elsewhere. 
@@ -1631,20 +1625,16 @@ class SparseMatrix : public Serializable { * @li O(rows_end - rows_begin) */ template - inline void nNonZerosPerRow(InputIterator rows_begin, - InputIterator rows_end, + inline void nNonZerosPerRow(InputIterator rows_begin, InputIterator rows_end, OutputIterator out_begin) const { { // Pre-conditions - assert_valid_row_it_range_(rows_begin, rows_end, - "nNonZerosPerRow"); + assert_valid_row_it_range_(rows_begin, rows_end, "nNonZerosPerRow"); ASSERT_OUTPUT_ITERATOR(OutputIterator, size_type); } // End pre-conditions InputIterator row; OutputIterator out; - for (row = rows_begin, out = out_begin; - row != rows_end; - ++row, ++out) { + for (row = rows_begin, out = out_begin; row != rows_end; ++row, ++out) { *out = nNonZerosOnRow(*row); } } @@ -1666,14 +1656,12 @@ class SparseMatrix : public Serializable { * @li O(number of nonzeros on the specified rows) */ template - inline void nNonZerosPerRowOnCols(InputIterator rows_begin, - InputIterator rows_end, - InputIterator cols_begin, - InputIterator cols_end, - OutputIterator out_begin) const { + inline void + nNonZerosPerRowOnCols(InputIterator rows_begin, InputIterator rows_end, + InputIterator cols_begin, InputIterator cols_end, + OutputIterator out_begin) const { { // Pre-conditions - assert_valid_row_it_range_(rows_begin, rows_end, - "nNonZerosPerRowOnCols"); + assert_valid_row_it_range_(rows_begin, rows_end, "nNonZerosPerRowOnCols"); assert_valid_sorted_index_range_(nCols(), cols_begin, cols_end, "nNonZerosPerRowOnCols"); ASSERT_OUTPUT_ITERATOR(OutputIterator, size_type); @@ -1681,9 +1669,7 @@ class SparseMatrix : public Serializable { InputIterator row; OutputIterator out; - for (row = rows_begin, out = out_begin; - row != rows_end; - ++row, ++out) { + for (row = rows_begin, out = out_begin; row != rows_end; ++row, ++out) { size_type numNonZeros = 0; @@ -1694,13 +1680,11 @@ class SparseMatrix : public Serializable { while (ind != ind_end && col != cols_end) { if (*ind < *col) { ++ind; - } - else if (*ind == *col) { + } else if (*ind == *col) { ++numNonZeros; ++ind; ++col; - } - else { + } else { ++col; } } @@ -2684,13 +2668,13 @@ class SparseMatrix : public Serializable { { // Pre-conditions ASSERT_INPUT_ITERATOR(InputIterator); - NTA_CHECK(nrows >= 0) << "SparseMatrix::fromDense(): " - << "Invalid number of rows: " << nrows - << " - Should be >= 0"; + NTA_CHECK(nrows >= 0) + << "SparseMatrix::fromDense(): " + << "Invalid number of rows: " << nrows << " - Should be >= 0"; - NTA_CHECK(ncols >= 0) << "SparseMatrix::fromDense(): " - << "Invalid number of columns: " << ncols - << " - Should be > 0"; + NTA_CHECK(ncols >= 0) + << "SparseMatrix::fromDense(): " + << "Invalid number of columns: " << ncols << " - Should be > 0"; } // End pre-conditions deallocate_(); @@ -3078,11 +3062,11 @@ class SparseMatrix : public Serializable { NTA_CHECK(s2 == sizeof(value_type)) << where << "Bad value_type: " << s2; - NTA_CHECK(s3 == sizeof(difference_type)) << where - << "Bad difference_type: " << s3; + NTA_CHECK(s3 == sizeof(difference_type)) + << where << "Bad difference_type: " << s3; - NTA_CHECK(s4 == sizeof(prec_value_type)) << where - << "Bad prec_value_type: " << s4; + NTA_CHECK(s4 == sizeof(prec_value_type)) + << where << "Bad prec_value_type: " << s4; } size_type nrows, nrows_max, ncols, nnz; @@ -3092,8 +3076,8 @@ class SparseMatrix : public Serializable { { NTA_CHECK(0 <= nrows) << where << "Bad number of rows: " << nrows; - NTA_CHECK(0 <= nrows_max) << where - << "Bad max number of rows: " << nrows_max; + NTA_CHECK(0 <= 
nrows_max) + << where << "Bad max number of rows: " << nrows_max; NTA_CHECK(nrows <= nrows_max) << where << "Number of rows: " << nrows << " should be less than max number of rows: " << nrows_max; @@ -3188,13 +3172,13 @@ class SparseMatrix : public Serializable { inline void resize(size_type new_nrows, size_type new_ncols, bool setToZero = false) { { // Pre-conditions - NTA_ASSERT(0 <= new_nrows) << "SparseMatrix resize: " - << "New number of rows: " << new_nrows - << " should be positive"; + NTA_ASSERT(0 <= new_nrows) + << "SparseMatrix resize: " + << "New number of rows: " << new_nrows << " should be positive"; - NTA_ASSERT(0 <= new_ncols) << "SparseMatrix resize: " - << "New number of columns: " << new_ncols - << " should be positive"; + NTA_ASSERT(0 <= new_ncols) + << "SparseMatrix resize: " + << "New number of columns: " << new_ncols << " should be positive"; } // End pre-conditions const size_type nrows = nRows(); @@ -3248,13 +3232,13 @@ class SparseMatrix : public Serializable { */ inline void reshape(size_type new_nrows, size_type new_ncols) { { // Pre-conditions - NTA_ASSERT(0 <= new_nrows) << "SparseMatrix reshape: " - << "New number of rows: " << new_nrows - << " should be positive"; + NTA_ASSERT(0 <= new_nrows) + << "SparseMatrix reshape: " + << "New number of rows: " << new_nrows << " should be positive"; - NTA_ASSERT(0 <= new_ncols) << "SparseMatrix reshape: " - << "New number of columns: " << new_ncols - << " should be positive"; + NTA_ASSERT(0 <= new_ncols) + << "SparseMatrix reshape: " + << "New number of columns: " << new_ncols << " should be positive"; NTA_ASSERT((double)new_nrows * new_ncols == (double)nRows() * nCols()) << "SparseMatrix reshape: " @@ -3415,11 +3399,11 @@ class SparseMatrix : public Serializable { << "SparseMatrix::deleteRows(): " << "Invalid row index: " << *d << " - Row indices should be between 0 and " << nRows(); - NTA_CHECK(*d < *d_next) << "SparseMatrix::deleteRows(): " - << "Invalid row indices " << *d << " and " - << *d_next - << " - Row indices need to be passed " - << "in strictly increasing order"; + NTA_CHECK(*d < *d_next) + << "SparseMatrix::deleteRows(): " + << "Invalid row indices " << *d << " and " << *d_next + << " - Row indices need to be passed " + << "in strictly increasing order"; ++d; ++d_next; } @@ -3577,11 +3561,11 @@ class SparseMatrix : public Serializable { << "SparseMatrix::deleteCols(): " << "Invalid column index: " << *d << " - Col indices should be between 0 and " << nCols(); - NTA_ASSERT(*d < *d_next) << "SparseMatrix::deleteCols(): " - << "Invalid column indices " << *d << " and " - << *d_next - << " - Col indices need to be passed " - << "in strictly increasing order"; + NTA_ASSERT(*d < *d_next) + << "SparseMatrix::deleteCols(): " + << "Invalid column indices " << *d << " and " << *d_next + << " - Col indices need to be passed " + << "in strictly increasing order"; ++d; ++d_next; } @@ -4302,10 +4286,10 @@ class SparseMatrix : public Serializable { * Amount to add to each selected nonzero */ template - inline void incrementNonZerosOnOuter( - InputIterator1 row_begin, InputIterator1 row_end, - InputIterator2 col_begin, InputIterator2 col_end, - value_type delta) { + inline void + incrementNonZerosOnOuter(InputIterator1 row_begin, InputIterator1 row_end, + InputIterator2 col_begin, InputIterator2 col_end, + value_type delta) { { // Pre-conditions ASSERT_INPUT_ITERATOR(InputIterator1); ASSERT_INPUT_ITERATOR(InputIterator2); @@ -4324,10 +4308,8 @@ class SparseMatrix : public Serializable { size_type *ind_begin = 
ind_begin_(*row); size_type *ind_end = ind_end_(*row); - for (size_type *ind = ind_begin; ind != ind_end; ++ind) - { - if (indb_[*ind] == 1) - { + for (size_type *ind = ind_begin; ind != ind_end; ++ind) { + if (indb_[*ind] == 1) { nz_[*row][ind - ind_begin] += delta; } } @@ -4357,10 +4339,11 @@ class SparseMatrix : public Serializable { * Amount to add to each selected nonzero */ template - inline void incrementNonZerosOnRowsExcludingCols( - InputIterator1 row_begin, InputIterator1 row_end, - InputIterator2 col_begin, InputIterator2 col_end, - value_type delta) { + inline void incrementNonZerosOnRowsExcludingCols(InputIterator1 row_begin, + InputIterator1 row_end, + InputIterator2 col_begin, + InputIterator2 col_end, + value_type delta) { { // Pre-conditions ASSERT_INPUT_ITERATOR(InputIterator1); ASSERT_INPUT_ITERATOR(InputIterator2); @@ -4379,10 +4362,8 @@ class SparseMatrix : public Serializable { size_type *ind_begin = ind_begin_(*row); size_type *ind_end = ind_end_(*row); - for (size_type *ind = ind_begin; ind != ind_end; ++ind) - { - if (indb_[*ind] != 1) - { + for (size_type *ind = ind_begin; ind != ind_end; ++ind) { + if (indb_[*ind] != 1) { nz_[*row][ind - ind_begin] += delta; } } @@ -4416,9 +4397,9 @@ class SparseMatrix : public Serializable { * @li a <= b */ template - inline void clipRowsBelowAndAbove( - InputIterator row_begin, InputIterator row_end, - value_type a, value_type b) { + inline void clipRowsBelowAndAbove(InputIterator row_begin, + InputIterator row_end, value_type a, + value_type b) { { // Pre-conditions ASSERT_INPUT_ITERATOR(InputIterator); assert_valid_row_it_range_(row_begin, row_end, "clipRowsBelowAndAbove"); @@ -4456,16 +4437,13 @@ class SparseMatrix : public Serializable { * @li Columns must be sorted */ template - inline void setZerosOnOuter( - InputIterator1 row_begin, InputIterator1 row_end, - InputIterator2 col_begin, InputIterator2 col_end, - value_type value) - { + inline void setZerosOnOuter(InputIterator1 row_begin, InputIterator1 row_end, + InputIterator2 col_begin, InputIterator2 col_end, + value_type value) { { // Pre-conditions ASSERT_INPUT_ITERATOR(InputIterator1); ASSERT_INPUT_ITERATOR(InputIterator2); - assert_valid_row_it_range_(row_begin, row_end, - "setZerosOnOuter"); + assert_valid_row_it_range_(row_begin, row_end, "setZerosOnOuter"); assert_valid_sorted_index_range_(nCols(), col_begin, col_end, "setZerosOnOuter"); } // End pre-conditions @@ -4478,10 +4456,11 @@ class SparseMatrix : public Serializable { InputIterator2 selected_col = col_begin; size_type *it = ind_begin; - size_type nextNonzeroCol = it != ind_end ? - *it : std::numeric_limits::max(); - size_type nextSelectedCol = selected_col != col_end ? - *selected_col : std::numeric_limits::max(); + size_type nextNonzeroCol = + it != ind_end ? *it : std::numeric_limits::max(); + size_type nextSelectedCol = selected_col != col_end + ? *selected_col + : std::numeric_limits::max(); size_type *indb_it = indb_; value_type *nzb_it = nzb_; @@ -4493,8 +4472,8 @@ class SparseMatrix : public Serializable { *nzb_it = nz_[*row][it - ind_begin]; ++it; - nextNonzeroCol = it != ind_end ? - *it : std::numeric_limits::max(); + nextNonzeroCol = + it != ind_end ? *it : std::numeric_limits::max(); } else if (nextNonzeroCol == nextSelectedCol) { // The next selected column is nonzero. // Copy it. @@ -4503,10 +4482,11 @@ class SparseMatrix : public Serializable { ++it; ++selected_col; - nextNonzeroCol = it != ind_end ? - *it : std::numeric_limits::max(); - nextSelectedCol = selected_col != col_end ? 
- *selected_col : std::numeric_limits::max(); + nextNonzeroCol = + it != ind_end ? *it : std::numeric_limits::max(); + nextSelectedCol = selected_col != col_end + ? *selected_col + : std::numeric_limits::max(); } else { // The next selected column is a zero. // Insert a nonzero. @@ -4514,8 +4494,9 @@ class SparseMatrix : public Serializable { *nzb_it = value; ++selected_col; - nextSelectedCol = selected_col != col_end ? - *selected_col : std::numeric_limits::max(); + nextSelectedCol = selected_col != col_end + ? *selected_col + : std::numeric_limits::max(); } } @@ -4574,28 +4555,26 @@ class SparseMatrix : public Serializable { * @li Columns must be sorted */ template - inline void setRandomZerosOnOuter( - InputIterator1 row_begin, InputIterator1 row_end, - InputIterator2 col_begin, InputIterator2 col_end, - difference_type numNewNonZerosPerRow, value_type value, Random& rng) { + inline void + setRandomZerosOnOuter(InputIterator1 row_begin, InputIterator1 row_end, + InputIterator2 col_begin, InputIterator2 col_end, + difference_type numNewNonZerosPerRow, value_type value, + Random &rng) { { // Pre-conditions ASSERT_INPUT_ITERATOR(InputIterator1); ASSERT_INPUT_ITERATOR(InputIterator2); - assert_valid_row_it_range_(row_begin, row_end, - "setRandomZerosOnOuter"); + assert_valid_row_it_range_(row_begin, row_end, "setRandomZerosOnOuter"); assert_valid_sorted_index_range_(nCols(), col_begin, col_end, "setRandomZerosOnOuter"); } // End pre-conditions for (InputIterator1 row = row_begin; row != row_end; ++row) { size_type numZeros = nZerosInRowOnColumns_(*row, col_begin, col_end); - difference_type numNewNonZeros = std::min(numNewNonZerosPerRow, - (difference_type)numZeros); - if (numNewNonZeros > 0) - { + difference_type numNewNonZeros = + std::min(numNewNonZerosPerRow, (difference_type)numZeros); + if (numNewNonZeros > 0) { insertRandomNonZerosIntoColumns_(*row, col_begin, col_end, - numNewNonZeros, numZeros, value, - rng); + numNewNonZeros, numZeros, value, rng); } } } @@ -4635,17 +4614,16 @@ class SparseMatrix : public Serializable { */ template - inline void setRandomZerosOnOuter( - InputIterator1 row_begin, InputIterator1 row_end, - InputIterator2 col_begin, InputIterator2 col_end, - InputIterator3 numNew_begin, InputIterator3 numNew_end, - value_type value, Random& rng) { + inline void + setRandomZerosOnOuter(InputIterator1 row_begin, InputIterator1 row_end, + InputIterator2 col_begin, InputIterator2 col_end, + InputIterator3 numNew_begin, InputIterator3 numNew_end, + value_type value, Random &rng) { { // Pre-conditions ASSERT_INPUT_ITERATOR(InputIterator1); ASSERT_INPUT_ITERATOR(InputIterator2); ASSERT_INPUT_ITERATOR(InputIterator3); - assert_valid_row_it_range_(row_begin, row_end, - "setRandomZerosOnOuter"); + assert_valid_row_it_range_(row_begin, row_end, "setRandomZerosOnOuter"); assert_valid_sorted_index_range_(nCols(), col_begin, col_end, "setRandomZerosOnOuter"); @@ -4655,18 +4633,15 @@ class SparseMatrix : public Serializable { InputIterator1 row; InputIterator3 numNew; - for (row = row_begin, numNew = numNew_begin; - row != row_end; + for (row = row_begin, numNew = numNew_begin; row != row_end; ++row, ++numNew) { if (*numNew > 0) { - size_type numZeros = nZerosInRowOnColumns_(*row, col_begin, col_end); - difference_type numNewNonZeros = std::min(*numNew, - (difference_type)numZeros); - if (numNewNonZeros > 0) - { - insertRandomNonZerosIntoColumns_(*row, col_begin, col_end, - numNewNonZeros, numZeros, value, - rng); + size_type numZeros = nZerosInRowOnColumns_(*row, col_begin, 
col_end); + difference_type numNewNonZeros = + std::min(*numNew, (difference_type)numZeros); + if (numNewNonZeros > 0) { + insertRandomNonZerosIntoColumns_( + *row, col_begin, col_end, numNewNonZeros, numZeros, value, rng); } } } @@ -4709,10 +4684,10 @@ class SparseMatrix : public Serializable { */ template inline void increaseRowNonZeroCountsOnOuterTo( - InputIterator1 row_begin, InputIterator1 row_end, - InputIterator2 col_begin, InputIterator2 col_end, - difference_type numDesiredNonzeros, value_type initialValue, Random& rng) - { + InputIterator1 row_begin, InputIterator1 row_end, + InputIterator2 col_begin, InputIterator2 col_end, + difference_type numDesiredNonzeros, value_type initialValue, + Random &rng) { { // Pre-conditions ASSERT_INPUT_ITERATOR(InputIterator1); ASSERT_INPUT_ITERATOR(InputIterator2); @@ -4725,18 +4700,15 @@ class SparseMatrix : public Serializable { for (InputIterator1 row = row_begin; row != row_end; ++row) { size_type numZeros = nZerosInRowOnColumns_(*row, col_begin, col_end); difference_type numNonZeros = (col_end - col_begin) - numZeros; - size_type numDesiredNewNonZeros = - (size_type) std::max((difference_type)0, - (difference_type)(numDesiredNonzeros - - numNonZeros)); + size_type numDesiredNewNonZeros = (size_type)std::max( + (difference_type)0, + (difference_type)(numDesiredNonzeros - numNonZeros)); size_type numActualNewNonZeros = - std::min(numDesiredNewNonZeros, numZeros); - if (numActualNewNonZeros > 0) - { + std::min(numDesiredNewNonZeros, numZeros); + if (numActualNewNonZeros > 0) { insertRandomNonZerosIntoColumns_(*row, col_begin, col_end, (size_type)numActualNewNonZeros, - numZeros, - initialValue, rng); + numZeros, initialValue, rng); } } } @@ -4833,10 +4805,10 @@ class SparseMatrix : public Serializable { if (*ind >= col_begin && *ind < col_end) { ijv_iterator->i(row); ijv_iterator->j(*ind); - NTA_ASSERT(!isZero_(*nz)) << "SparseMatrix::getAllNonZeros (rect): " - << "Zero at " << row << ", " << *ind - << ": " << *nz - << " epsilon= " << nupic::Epsilon; + NTA_ASSERT(!isZero_(*nz)) + << "SparseMatrix::getAllNonZeros (rect): " + << "Zero at " << row << ", " << *ind << ": " << *nz + << " epsilon= " << nupic::Epsilon; ijv_iterator->v(*nz); ++ijv_iterator; } @@ -7512,8 +7484,8 @@ class SparseMatrix : public Serializable { * @li None. 
*/ inline void threshold(const value_type &threshold = nupic::Epsilon) { - filter(std::bind(std::greater_equal(), - std::placeholders::_1, threshold)); + filter(std::bind(std::greater_equal(), std::placeholders::_1, + threshold)); } template @@ -7801,9 +7773,9 @@ class SparseMatrix : public Serializable { { // Pre-conditions } // End pre-conditions - return countWhere(begin_row, end_row, begin_col, end_col, - std::bind(std::equal_to(), - std::placeholders::_1, value)); + return countWhere( + begin_row, end_row, begin_col, end_col, + std::bind(std::equal_to(), std::placeholders::_1, value)); } /** @@ -7837,10 +7809,10 @@ class SparseMatrix : public Serializable { ASSERT_OUTPUT_ITERATOR(OutputIterator1, size_type); } // End pre-conditions - findIndices(begin_row, end_row, begin_col, end_col, - std::bind(std::equal_to(), - std::placeholders::_1, value), - row_it, col_it); + findIndices( + begin_row, end_row, begin_col, end_col, + std::bind(std::equal_to(), std::placeholders::_1, value), + row_it, col_it); } /** @@ -7868,9 +7840,9 @@ class SparseMatrix : public Serializable { { // Pre-conditions } // End pre-conditions - return countWhere(begin_row, end_row, begin_col, end_col, - std::bind(std::greater(), - std::placeholders::_1, value)); + return countWhere( + begin_row, end_row, begin_col, end_col, + std::bind(std::greater(), std::placeholders::_1, value)); } /** @@ -7904,10 +7876,10 @@ class SparseMatrix : public Serializable { ASSERT_OUTPUT_ITERATOR(OutputIterator1, size_type); } // End pre-conditions - findIndices(begin_row, end_row, begin_col, end_col, - std::bind(std::greater(), - std::placeholders::_1, value), - row_it, col_it); + findIndices( + begin_row, end_row, begin_col, end_col, + std::bind(std::greater(), std::placeholders::_1, value), + row_it, col_it); } /** @@ -7992,20 +7964,20 @@ class SparseMatrix : public Serializable { NTA_ASSERT(nnzr >= 0) << "SparseMatrix::findRow(): " << "Passed in " << nnzr << " non-zeros"; - NTA_ASSERT(nnzr <= nCols()) << "SparseMatrix::findRow(): " - << "Passed in " << nnzr << " non-zeros " - << "but there are only " << nCols() - << " columns"; + NTA_ASSERT(nnzr <= nCols()) + << "SparseMatrix::findRow(): " + << "Passed in " << nnzr << " non-zeros " + << "but there are only " << nCols() << " columns"; #ifdef NTA_ASSERTIONS_ON // to avoid compilation of loop IndIt jj = ind_it; NzIt nn = nz_it; size_type j = 0, prev = 0; for (j = 0; j != nnzr; ++j, ++jj, ++nn) { - NTA_ASSERT(0 <= *jj && *jj < nCols()) << "SparseMatrix::findRow(): " - << "Invalid column index" - << " - Should be >= 0 and < " - << nCols(); + NTA_ASSERT(0 <= *jj && *jj < nCols()) + << "SparseMatrix::findRow(): " + << "Invalid column index" + << " - Should be >= 0 and < " << nCols(); NTA_ASSERT(!isZero_(*nn)) << "SparseMatrix::findRow(): " << "Passed zero at index: " << *jj << " - Should pass non-zeros only"; @@ -10331,10 +10303,8 @@ class SparseMatrix : public Serializable { size_type *ind_end = ind_begin + nnzr_[row]; difference_type sum = 0; - for (size_type *ind = ind_begin; ind != ind_end; ++ind) - { - if (indb_[*ind]) - { + for (size_type *ind = ind_begin; ind != ind_end; ++ind) { + if (indb_[*ind]) { ++sum; } } @@ -10343,7 +10313,6 @@ class SparseMatrix : public Serializable { } } - /** * Like rightVecSumAtNZ, except that we add to the sum only if the value of * the non-zero is > threshold. @@ -10373,9 +10342,10 @@ class SparseMatrix : public Serializable { * than a float array. 
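
threshold(), countWhere() and findIndices() in the hunks above all build their unary predicates the same way: std::bind fixes the comparison value as the second argument and leaves std::placeholders::_1 free for the matrix element. A tiny standard-C++ illustration of that binder pattern (nothing here is nupic-specific):

#include <algorithm>
#include <functional>
#include <iostream>
#include <vector>

int main() {
  std::vector<float> v = {0.1f, 0.5f, 0.0f, 0.9f};
  // Same shape as the predicate bound in threshold()/countWhere(): fix the
  // comparison value as the second argument, leave _1 free for the element.
  auto atLeastHalf =
      std::bind(std::greater_equal<float>(), std::placeholders::_1, 0.5f);
  std::cout << std::count_if(v.begin(), v.end(), atLeastHalf) << "\n"; // 2
  return 0;
}

A lambda such as [=](float x) { return x >= 0.5f; } would be the modern equivalent; this diff keeps std::bind and only reflows the call sites.
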
*/ template - inline void rightVecSumAtNZGtThresholdSparse( - InputIterator x_ones_begin, InputIterator x_ones_end, - OutputIterator out_begin, value_type threshold) const { + inline void rightVecSumAtNZGtThresholdSparse(InputIterator x_ones_begin, + InputIterator x_ones_end, + OutputIterator out_begin, + value_type threshold) const { { // Pre-conditions ASSERT_INPUT_ITERATOR(InputIterator); ASSERT_OUTPUT_ITERATOR(OutputIterator, difference_type); @@ -10397,11 +10367,8 @@ class SparseMatrix : public Serializable { value_type *nz = nz_[row]; difference_type sum = 0; - for (size_type *ind = ind_begin; ind != ind_end; ++ind) - { - if (indb_[*ind] && - nz[ind - ind_begin] > threshold) - { + for (size_type *ind = ind_begin; ind != ind_end; ++ind) { + if (indb_[*ind] && nz[ind - ind_begin] > threshold) { ++sum; } } @@ -10438,9 +10405,10 @@ class SparseMatrix : public Serializable { * than a float array. */ template - inline void rightVecSumAtNZGteThresholdSparse( - InputIterator x_ones_begin, InputIterator x_ones_end, - OutputIterator out_begin, value_type threshold) const { + inline void rightVecSumAtNZGteThresholdSparse(InputIterator x_ones_begin, + InputIterator x_ones_end, + OutputIterator out_begin, + value_type threshold) const { { // Pre-conditions ASSERT_INPUT_ITERATOR(InputIterator); ASSERT_OUTPUT_ITERATOR(OutputIterator, difference_type); @@ -10462,11 +10430,8 @@ class SparseMatrix : public Serializable { value_type *nz = nz_[row]; difference_type sum = 0; - for (size_type *ind = ind_begin; ind != ind_end; ++ind) - { - if (indb_[*ind] && - nz[ind - ind_begin] >= threshold) - { + for (size_type *ind = ind_begin; ind != ind_end; ++ind) { + if (indb_[*ind] && nz[ind - ind_begin] >= threshold) { ++sum; } } @@ -10915,9 +10880,9 @@ class SparseMatrix : public Serializable { ASSERT_INPUT_ITERATOR(InputIterator); } // End pre-conditions - this->applyOuter(row_begin, row_end, col_begin, col_end, - std::bind(nupic::Plus(), - std::placeholders::_1, val)); + this->applyOuter( + row_begin, row_end, col_begin, col_end, + std::bind(nupic::Plus(), std::placeholders::_1, val)); } /** @@ -11028,8 +10993,8 @@ class SparseMatrix : public Serializable { * Replaces the non-zeros by the specified value. 
*/ inline void replaceNZ(const value_type &val = 1.0) { - elementNZApply(std::bind(nupic::Assign(), - std::placeholders::_1, val)); + elementNZApply( + std::bind(nupic::Assign(), std::placeholders::_1, val)); } /** @@ -11239,8 +11204,8 @@ class SparseMatrix : public Serializable { } inline void multiply(const value_type &val) { - elementNZApply(std::bind(nupic::Multiplies(), - std::placeholders::_1, val)); + elementNZApply( + std::bind(nupic::Multiplies(), std::placeholders::_1, val)); } inline void elementRowDivide(size_type idx, const value_type &val) { @@ -11258,38 +11223,38 @@ class SparseMatrix : public Serializable { NTA_ASSERT(!isZero_(val)) << "divide: Division by zero"; } // End pre-conditions - elementNZApply(std::bind(nupic::Divides(), - std::placeholders::_1, val)); + elementNZApply( + std::bind(nupic::Divides(), std::placeholders::_1, val)); } inline void elementRowNZPow(size_type idx, const value_type &val) { - elementRowNZApply(idx, std::bind(nupic::Pow(), - std::placeholders::_1, val)); + elementRowNZApply( + idx, std::bind(nupic::Pow(), std::placeholders::_1, val)); } inline void elementColNZPow(size_type idx, const value_type &val) { - elementColNZApply(idx, std::bind(nupic::Pow(), - std::placeholders::_1, val)); + elementColNZApply( + idx, std::bind(nupic::Pow(), std::placeholders::_1, val)); } inline void elementNZPow(const value_type &val) { - elementNZApply(std::bind(nupic::Pow(), - std::placeholders::_1, val)); + elementNZApply( + std::bind(nupic::Pow(), std::placeholders::_1, val)); } inline void elementRowNZLogk(size_type idx, const value_type &val) { - elementRowNZApply(idx, std::bind(nupic::Logk(), - std::placeholders::_1, val)); + elementRowNZApply( + idx, std::bind(nupic::Logk(), std::placeholders::_1, val)); } inline void elementColNZLogk(size_type idx, const value_type &val) { - elementColNZApply(idx, std::bind(nupic::Logk(), - std::placeholders::_1, val)); + elementColNZApply( + idx, std::bind(nupic::Logk(), std::placeholders::_1, val)); } inline void elementNZLogk(const value_type &val) { - elementNZApply(std::bind(nupic::Logk(), - std::placeholders::_1, val)); + elementNZApply( + std::bind(nupic::Logk(), std::placeholders::_1, val)); } template @@ -11323,38 +11288,38 @@ class SparseMatrix : public Serializable { } inline void rowAdd(size_type idx, const value_type &val) { - elementRowApply(idx, std::bind(nupic::Plus(), - std::placeholders::_1, val)); + elementRowApply( + idx, std::bind(nupic::Plus(), std::placeholders::_1, val)); } inline void colAdd(size_type idx, const value_type &val) { - elementColApply(idx, std::bind(nupic::Plus(), - std::placeholders::_1, val)); + elementColApply( + idx, std::bind(nupic::Plus(), std::placeholders::_1, val)); } inline void add(const value_type &val) { - elementApply(std::bind(nupic::Plus(), - std::placeholders::_1, val)); + elementApply( + std::bind(nupic::Plus(), std::placeholders::_1, val)); } inline void elementNZAdd(const value_type &val) { - elementNZApply(std::bind(nupic::Plus(), - std::placeholders::_1, val)); + elementNZApply( + std::bind(nupic::Plus(), std::placeholders::_1, val)); } inline void rowSubtract(size_type idx, const value_type &val) { - elementRowApply(idx, std::bind(nupic::Minus(), - std::placeholders::_1, val)); + elementRowApply( + idx, std::bind(nupic::Minus(), std::placeholders::_1, val)); } inline void colSubtract(size_type idx, const value_type &val) { - elementColApply(idx, std::bind(nupic::Minus(), - std::placeholders::_1, val)); + elementColApply( + idx, std::bind(nupic::Minus(), 
std::placeholders::_1, val)); } inline void subtract(const value_type &val) { - elementApply(std::bind(nupic::Minus(), - std::placeholders::_1, val)); + elementApply( + std::bind(nupic::Minus(), std::placeholders::_1, val)); } inline void elementNZMultiply(const SparseMatrix &other) { diff --git a/src/nupic/math/SparseMatrix01.hpp b/src/nupic/math/SparseMatrix01.hpp index 641b8eda1b..8416a7a8d0 100644 --- a/src/nupic/math/SparseMatrix01.hpp +++ b/src/nupic/math/SparseMatrix01.hpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definition and implementation for SparseMatrix01 class */ @@ -29,2159 +29,2066 @@ //---------------------------------------------------------------------- -#include +#include #include #include -#include +#include -#include // for memset in gcc 4.4 -#include #include +#include +#include // for memset in gcc 4.4 //---------------------------------------------------------------------- - namespace nupic { - //-------------------------------------------------------------------------------- - template - struct RowCompare - { - RowCompare(Int rowSize_) - : rowSize(rowSize_) - {} +//-------------------------------------------------------------------------------- +template struct RowCompare { + RowCompare(Int rowSize_) : rowSize(rowSize_) {} + + inline bool operator()(Int *row1, Int *row2) const { + for (Int i = 0; i < rowSize; ++i) + if (row1[i] > row2[i]) + return true; + else if (row1[i] < row2[i]) + return false; + return false; + } - inline bool operator()(Int* row1, Int* row2) const - { - for (Int i = 0; i < rowSize; ++i) - if (row1[i] > row2[i]) - return true; - else if (row1[i] < row2[i]) - return false; - return false; - } + Int rowSize; +}; + +//-------------------------------------------------------------------------------- +/** + * @b Responsibility: + * A sparse matrix class dedicated to supporting Numenta's algorithms. + * This is not a general sparse matrix. It's tuned specifically for + * speed, and to support Numenta's algorithms. + * + * @b Rationale: + * It is not a fully general sparse matrix class. Instead, it is intended + * to support Numenta's algorithms as efficiently as possible. + * + * @b Resource/Ownerships: + * This class manages its own memory. + * + * @b Invariants: + * + * @b Notes: + * Note 1: SparseMatrix has a limitation to max columns, or + * rows or non-zeros. + */ +template class SparseMatrix01 { +public: + typedef Int size_type; + typedef Float value_type; + +private: + /// number of rows >= 0 + size_type nrows_; + + /// max number of rows allocated >= 8 + size_type nrows_max_; + + /// number of colums > 0 + size_type ncols_; + + /// array of number of non-zeros per row + size_type *nzr_; + + /// array of arrays of indices of non-zeros, per row + size_type **ind_; + + /// buffer array of indices + size_type *indb_; + + /// buffer array of non-zeros + value_type *nzb_; + + /// whether this matrix is allocated contiguously or not + bool compact_; - Int rowSize; - }; + // Number of non-zeros per row, used when working with unique rows. + // This is used to block the arity of the row comparison functor, + // for speed. (could be retrieved from the map, that stores an instance + // of RowCompare that contains this data). + // This is used as a flag that we are working with unique rows and that + // counts_ has been initialized properly. If nnzr_ > 0, we are working + // with unique rows. + size_type nnzr_; + + // Map that contains a pair for each row. 
+ // The count is the number of times a row has been seen, when working + // with unique rows. + typedef std::pair Row_Count; + typedef std::map> Counts; + Counts counts_; //-------------------------------------------------------------------------------- /** - * @b Responsibility: - * A sparse matrix class dedicated to supporting Numenta's algorithms. - * This is not a general sparse matrix. It's tuned specifically for - * speed, and to support Numenta's algorithms. - * - * @b Rationale: - * It is not a fully general sparse matrix class. Instead, it is intended - * to support Numenta's algorithms as efficiently as possible. - * - * @b Resource/Ownerships: - * This class manages its own memory. - * - * @b Invariants: - * - * @b Notes: - * Note 1: SparseMatrix has a limitation to max columns, or - * rows or non-zeros. + * We try to limit reallocations of SparseMatrix01 instances + * by forbidding the copy constructor and assingment operators. + * We also forbid the default constructor, which would create + * empty matrices of little use. */ - template - class SparseMatrix01 - { - public: - typedef Int size_type; - typedef Float value_type; - - private: - /// number of rows >= 0 - size_type nrows_; - - /// max number of rows allocated >= 8 - size_type nrows_max_; - - /// number of colums > 0 - size_type ncols_; - - /// array of number of non-zeros per row - size_type* nzr_; - - /// array of arrays of indices of non-zeros, per row - size_type** ind_; - - /// buffer array of indices - size_type* indb_; - - /// buffer array of non-zeros - value_type* nzb_; - - /// whether this matrix is allocated contiguously or not - bool compact_; - - // Number of non-zeros per row, used when working with unique rows. - // This is used to block the arity of the row comparison functor, - // for speed. (could be retrieved from the map, that stores an instance - // of RowCompare that contains this data). - // This is used as a flag that we are working with unique rows and that - // counts_ has been initialized properly. If nnzr_ > 0, we are working - // with unique rows. - size_type nnzr_; - - // Map that contains a pair for each row. - // The count is the number of times a row has been seen, when working - // with unique rows. - typedef std::pair Row_Count; - typedef std::map > Counts; - Counts counts_; - - //-------------------------------------------------------------------------------- - /** - * We try to limit reallocations of SparseMatrix01 instances - * by forbidding the copy constructor and assingment operators. - * We also forbid the default constructor, which would create - * empty matrices of little use. - */ - NO_DEFAULTS(SparseMatrix01) - - //-------------------------------------------------------------------------------- - /** - * Decides whether val is not zero or not, by testing if val is - * outside closed ball [-nupic::Epsilon .. +nupic::Epsilon]. - * - * @param val [value_type] the value to test - * @retval [bool] whether val is different from zero or not - */ - inline bool isNotZero_(const value_type& val) const - { - return !nupic::nearlyZero(val); - } + NO_DEFAULTS(SparseMatrix01) - //-------------------------------------------------------------------------------- - /** - * Decides whether val is zero or not, by testing if val is - * inside open ball (-nupic::Epsilon .. +nupic::Epsilon). 
- * - * @param val [value_type] the value to test - * @retval [bool] whether val is zero or not - */ - inline bool isZero_(const value_type& val) const - { - return nupic::nearlyZero(val); - } + //-------------------------------------------------------------------------------- + /** + * Decides whether val is not zero or not, by testing if val is + * outside closed ball [-nupic::Epsilon .. +nupic::Epsilon]. + * + * @param val [value_type] the value to test + * @retval [bool] whether val is different from zero or not + */ + inline bool isNotZero_(const value_type &val) const { + return !nupic::nearlyZero(val); + } - //-------------------------------------------------------------------------------- - /** - * Allocates data structures of the SparseMatrix01. - * Mutating, allocates memory. - * - * @param nrows_max [size_type >= 0] max number of rows to allocate - * @param ncols [size_type > 0] number of columns for this matrix - * - * @b Exceptions: - * @li nrows_max < 0 (assert) - * @li ncols <= 0 (assert) - * @li Not enough memory (error) - */ - inline void allocate_(const size_type& nrows_max, const size_type& ncols) - { - { // Pre-conditions - NTA_ASSERT(nrows_max >= 0) + //-------------------------------------------------------------------------------- + /** + * Decides whether val is zero or not, by testing if val is + * inside open ball (-nupic::Epsilon .. +nupic::Epsilon). + * + * @param val [value_type] the value to test + * @retval [bool] whether val is zero or not + */ + inline bool isZero_(const value_type &val) const { + return nupic::nearlyZero(val); + } + + //-------------------------------------------------------------------------------- + /** + * Allocates data structures of the SparseMatrix01. + * Mutating, allocates memory. + * + * @param nrows_max [size_type >= 0] max number of rows to allocate + * @param ncols [size_type > 0] number of columns for this matrix + * + * @b Exceptions: + * @li nrows_max < 0 (assert) + * @li ncols <= 0 (assert) + * @li Not enough memory (error) + */ + inline void allocate_(const size_type &nrows_max, const size_type &ncols) { + { // Pre-conditions + NTA_ASSERT(nrows_max >= 0) << "SparseMatrix01::allocate_(): " - << "Invalid nrows_max = " << nrows_max - << " - Should be >= 0"; - - NTA_ASSERT(ncols > 0) + << "Invalid nrows_max = " << nrows_max << " - Should be >= 0"; + + NTA_ASSERT(ncols > 0) << "SparseMatrix01::allocate_(): " - << "Invalid ncols = " << ncols - << " - Should be > 0"; - } - - nrows_max_ = std::max(8, nrows_max); - ncols_ = ncols; - - try { - nzr_ = new size_type[nrows_max_]; - ind_ = new size_type*[nrows_max_]; - indb_ = new size_type[ncols_]; - nzb_ = new value_type[ncols_]; - - } catch (std::exception&) { - - NTA_THROW << "SparseMatrix01::allocate_(): " - << "Could not allocate enough memory:" - << " nrows_max = " << nrows_max - << " ncols = " << ncols; - } - - memset(nzr_, 0, nrows_max_ * sizeof(size_type)); - memset(ind_, 0, nrows_max_ * sizeof(size_type*)); - memset(indb_, 0, ncols_ * sizeof(size_type)); - memset(nzb_, 0, ncols_ * sizeof(value_type)); + << "Invalid ncols = " << ncols << " - Should be > 0"; } - //-------------------------------------------------------------------------------- - /** - * Deallocates data structures of the SparseMatrix01. 
- */ - inline void deallocate_() - { - if (!nzr_) - return; - - if (!isCompact()) { - size_type **ind = ind_, **ind_end = ind_ + nRows(); - while (ind != ind_end) - delete [] *ind++; - } else { - delete [] ind_[0]; - } + nrows_max_ = std::max(8, nrows_max); + ncols_ = ncols; - delete [] ind_; - ind_ = nullptr; - delete [] nzr_; - nzr_ = nullptr; - delete [] indb_; - indb_ = nullptr; - delete [] nzb_; - nzb_ = nullptr; + try { + nzr_ = new size_type[nrows_max_]; + ind_ = new size_type *[nrows_max_]; + indb_ = new size_type[ncols_]; + nzb_ = new value_type[ncols_]; - nrows_ = ncols_ = nrows_max_ = 0; + } catch (std::exception &) { + + NTA_THROW << "SparseMatrix01::allocate_(): " + << "Could not allocate enough memory:" + << " nrows_max = " << nrows_max << " ncols = " << ncols; } - //-------------------------------------------------------------------------------- - inline void swapCountKeys_(size_type* old_ptr, size_type* new_ptr) - { - typename Counts::iterator it = counts_.find(old_ptr); - const Row_Count row_val = it->second; - counts_.erase(it); - counts_[new_ptr] = row_val; + memset(nzr_, 0, nrows_max_ * sizeof(size_type)); + memset(ind_, 0, nrows_max_ * sizeof(size_type *)); + memset(indb_, 0, ncols_ * sizeof(size_type)); + memset(nzb_, 0, ncols_ * sizeof(value_type)); + } + + //-------------------------------------------------------------------------------- + /** + * Deallocates data structures of the SparseMatrix01. + */ + inline void deallocate_() { + if (!nzr_) + return; + + if (!isCompact()) { + size_type **ind = ind_, **ind_end = ind_ + nRows(); + while (ind != ind_end) + delete[] * ind++; + } else { + delete[] ind_[0]; } - //-------------------------------------------------------------------------------- - /** - * Adds a row stored in ind. - * Returns the index of the newly added row. - * - * WARNING: send in values without duplicates, in increasing order of indices - * - * ind needs to be a contiguous array of memory - */ - template - inline size_type addRow_(const size_type nnzr, InIter ind) - { - size_type row_num = nRows(); - - if (isCompact()) - decompact(); - - // Allocate possibly one more row in all cases, - // even if nnzr == 0 (identically zero row) - if (row_num == nrows_max_-1) { - - size_type *nzr_new = nullptr, **ind_new = nullptr; - nrows_max_ *= 2; - nzr_new = new size_type[nrows_max_]; - ind_new = new size_type*[nrows_max_]; - memcpy(nzr_new, nzr_, row_num * sizeof(size_type)); - memcpy(ind_new, ind_, row_num * sizeof(size_type*)); - delete [] nzr_; - delete [] ind_; - nzr_ = nzr_new; - ind_ = ind_new; - } - - // Can be a row of zeros, in which case nzr_[row_num] == 0 - nzr_[row_num] = nnzr; - ind_[row_num] = new size_type[nnzr]; - - InIter ind_it = ind, ind_end = ind + nnzr; - size_type *target_ind = ind_[row_num]; - while (ind_it != ind_end) { - *target_ind = *ind_it; - ++target_ind; ++ind_it; - } + delete[] ind_; + ind_ = nullptr; + delete[] nzr_; + nzr_ = nullptr; + delete[] indb_; + indb_ = nullptr; + delete[] nzb_; + nzb_ = nullptr; - ++nrows_; - return row_num; + nrows_ = ncols_ = nrows_max_ = 0; + } + + //-------------------------------------------------------------------------------- + inline void swapCountKeys_(size_type *old_ptr, size_type *new_ptr) { + typename Counts::iterator it = counts_.find(old_ptr); + const Row_Count row_val = it->second; + counts_.erase(it); + counts_[new_ptr] = row_val; + } + + //-------------------------------------------------------------------------------- + /** + * Adds a row stored in ind. 
+ * Returns the index of the newly added row. + * + * WARNING: send in values without duplicates, in increasing order of indices + * + * ind needs to be a contiguous array of memory + */ + template + inline size_type addRow_(const size_type nnzr, InIter ind) { + size_type row_num = nRows(); + + if (isCompact()) + decompact(); + + // Allocate possibly one more row in all cases, + // even if nnzr == 0 (identically zero row) + if (row_num == nrows_max_ - 1) { + + size_type *nzr_new = nullptr, **ind_new = nullptr; + nrows_max_ *= 2; + nzr_new = new size_type[nrows_max_]; + ind_new = new size_type *[nrows_max_]; + memcpy(nzr_new, nzr_, row_num * sizeof(size_type)); + memcpy(ind_new, ind_, row_num * sizeof(size_type *)); + delete[] nzr_; + delete[] ind_; + nzr_ = nzr_new; + ind_ = ind_new; } - //-------------------------------------------------------------------------------- - /** - * Compacts a row from nzb_ to (nzr_[r], ind_[r], nz_[r]). - * Mutating, allocates memory. - * - * @param r [0 <= size_type < nrows] row index - * - * @b Exceptions: - * @li r < 0 || r >= nrows_ (assert) - * @li Not enough memory (error) - */ - inline void compactRow_(const size_type& r) - { - { // Pre-conditions - // assert because private (likely to be checked already) - // and for speed - NTA_ASSERT(r >= 0 && r < nRows()) + // Can be a row of zeros, in which case nzr_[row_num] == 0 + nzr_[row_num] = nnzr; + ind_[row_num] = new size_type[nnzr]; + + InIter ind_it = ind, ind_end = ind + nnzr; + size_type *target_ind = ind_[row_num]; + while (ind_it != ind_end) { + *target_ind = *ind_it; + ++target_ind; + ++ind_it; + } + + ++nrows_; + return row_num; + } + + //-------------------------------------------------------------------------------- + /** + * Compacts a row from nzb_ to (nzr_[r], ind_[r], nz_[r]). + * Mutating, allocates memory. + * + * @param r [0 <= size_type < nrows] row index + * + * @b Exceptions: + * @li r < 0 || r >= nrows_ (assert) + * @li Not enough memory (error) + */ + inline void compactRow_(const size_type &r) { + { // Pre-conditions + // assert because private (likely to be checked already) + // and for speed + NTA_ASSERT(r >= 0 && r < nRows()) << "SparseMatrix01::compactRow_(): " - << "Invalid row index: " << r - << " - Should be >= 0 and < " << nRows(); - } + << "Invalid row index: " << r << " - Should be >= 0 and < " + << nRows(); + } - size_type *old_row_ptr = NULL, *new_row_ptr = NULL; - size_type* indb_it = indb_; - value_type* nzb_it = nzb_, nzb_end = nzb_ + nCols(); + size_type *old_row_ptr = NULL, *new_row_ptr = NULL; + size_type *indb_it = indb_; + value_type *nzb_it = nzb_, nzb_end = nzb_ + nCols(); - while (nzb_it != nzb_end) - if (isNotZero_(*nzb_it++)) - *indb_it++ = nzb_it - nzb_; + while (nzb_it != nzb_end) + if (isNotZero_(*nzb_it++)) + *indb_it++ = nzb_it - nzb_; - size_type nnzr = indb_it - indb_; - - if (nnzr > nzr_[r]) { // more non-zeros, need more memory - - if (isCompact()) // as late as possible, but... - decompact(); // changes ind_[r] and nz_[r]!!! 
but not nzr_[r] - - old_row_ptr = ind_[r]; - new_row_ptr = new size_type[nnzr]; + size_type nnzr = indb_it - indb_; - if (hasUniqueRows()) - swapCountKeys_(old_row_ptr, new_row_ptr); + if (nnzr > nzr_[r]) { // more non-zeros, need more memory - delete [] ind_[r]; // recycle, or delete later - } - - // We are paying here, with two copies each call - // to axby, because we don't have nnzr_max per row - memcpy(ind_[r], indb_, nnzr * sizeof(size_type)); - nzr_[r] = nnzr; // maybe a different value + if (isCompact()) // as late as possible, but... + decompact(); // changes ind_[r] and nz_[r]!!! but not nzr_[r] + + old_row_ptr = ind_[r]; + new_row_ptr = new size_type[nnzr]; + + if (hasUniqueRows()) + swapCountKeys_(old_row_ptr, new_row_ptr); + + delete[] ind_[r]; // recycle, or delete later } - public: - //-------------------------------------------------------------------------------- - /** - * Constructor with a number of columns and a hint for the number - * of rows. The SparseMatrix01 is empty. - * - * @param ncols [size_type > 0] number of columns - * @param hint [size_type (16) >= 0] hint for the initial number of rows - * - * @b Exceptions: - * @li ncols <= 0 (check) - * @li hint < 0 (check) - * @li Not enough memory (error) - */ - SparseMatrix01(const size_type& ncols, - const size_type& hint =16, - const size_type& nnzr =0) - - : nrows_(0), nrows_max_(0), ncols_(0), - nzr_(nullptr), ind_(nullptr), indb_(nullptr), nzb_(nullptr), - compact_(false), - nnzr_(nnzr), - counts_(RowCompare(nnzr)) - { - { // Pre-conditions - NTA_CHECK(ncols > 0) - << "SparseMatrix01::SparseMatrix01(ncols, hint): " - << "Invalid number of columns: " << ncols - << " - Should be > 0"; - - NTA_CHECK(hint >= 0) - << "SparseMatrix01::SparseMatrix01(ncols, hint): " - << "Invalid hint: " << hint - << " - Should be >= 0"; - } + // We are paying here, with two copies each call + // to axby, because we don't have nnzr_max per row + memcpy(ind_[r], indb_, nnzr * sizeof(size_type)); + nzr_[r] = nnzr; // maybe a different value + } - allocate_(hint, ncols); +public: + //-------------------------------------------------------------------------------- + /** + * Constructor with a number of columns and a hint for the number + * of rows. The SparseMatrix01 is empty. + * + * @param ncols [size_type > 0] number of columns + * @param hint [size_type (16) >= 0] hint for the initial number of rows + * + * @b Exceptions: + * @li ncols <= 0 (check) + * @li hint < 0 (check) + * @li Not enough memory (error) + */ + SparseMatrix01(const size_type &ncols, const size_type &hint = 16, + const size_type &nnzr = 0) + + : nrows_(0), nrows_max_(0), ncols_(0), nzr_(nullptr), ind_(nullptr), + indb_(nullptr), nzb_(nullptr), compact_(false), nnzr_(nnzr), + counts_(RowCompare(nnzr)) { + { // Pre-conditions + NTA_CHECK(ncols > 0) << "SparseMatrix01::SparseMatrix01(ncols, hint): " + << "Invalid number of columns: " << ncols + << " - Should be > 0"; + + NTA_CHECK(hint >= 0) << "SparseMatrix01::SparseMatrix01(ncols, hint): " + << "Invalid hint: " << hint << " - Should be >= 0"; } - //-------------------------------------------------------------------------------- - /** - * Constructor from a dense matrix passed as an array of value_type. - * Uses the values in mat to initialize the SparseMatrix01. 
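
The constructor reformatted above takes the number of columns, a capacity hint for rows, and an optional nnzr that, when non-zero, switches the matrix into unique-row counting mode. A minimal usage sketch; the UInt32/Real32 template arguments and the include paths are assumptions about the usual nupic instantiation, not something this diff shows:

#include <nupic/math/SparseMatrix01.hpp>
#include <nupic/types/Types.hpp>

using namespace nupic;

int main() {
  // 64 columns, capacity hinted at 16 rows, nnzr == 0: ordinary rows,
  // no unique-row counting.
  SparseMatrix01<UInt32, Real32> m(64, 16, 0);
  return (m.nRows() == 0 && m.nCols() == 64) ? 0 : 1;
}
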
- * - * @param nrows [size_type > 0] number of rows - * @param ncols [size_type > 0] number of columns - * @param mat [value_type** != NULL] initial array of values - * - * @b Exceptions: - * @li nrows <= 0 (check) - * @li ncols <= 0 (check) - * @li mat == NULL (check) - * @li NULL pointer in mat (check) - * @li Not enough memory (error) - */ - template - SparseMatrix01(const size_type& nrows, - const size_type& ncols, - InIter mat, - const size_type& nnzr) - - : nrows_(0), nrows_max_(0), ncols_(0), - nzr_(0), ind_(0), indb_(0), nzb_(0), - compact_(false), - nnzr_(nnzr), - counts_(RowCompare(nnzr)) - { - { // Pre-conditions - NTA_CHECK(nrows >= 0) - << "SparseMatrix01::SparseMatrix01(nrows, ncols, mat): " - << "Invalid number of rows: " << nrows - << " - Should be >= 0"; - - NTA_CHECK(ncols > 0) + allocate_(hint, ncols); + } + + //-------------------------------------------------------------------------------- + /** + * Constructor from a dense matrix passed as an array of value_type. + * Uses the values in mat to initialize the SparseMatrix01. + * + * @param nrows [size_type > 0] number of rows + * @param ncols [size_type > 0] number of columns + * @param mat [value_type** != NULL] initial array of values + * + * @b Exceptions: + * @li nrows <= 0 (check) + * @li ncols <= 0 (check) + * @li mat == NULL (check) + * @li NULL pointer in mat (check) + * @li Not enough memory (error) + */ + template + SparseMatrix01(const size_type &nrows, const size_type &ncols, InIter mat, + const size_type &nnzr) + + : nrows_(0), nrows_max_(0), ncols_(0), nzr_(0), ind_(0), indb_(0), + nzb_(0), compact_(false), nnzr_(nnzr), + counts_(RowCompare(nnzr)) { + { // Pre-conditions + NTA_CHECK(nrows >= 0) << "SparseMatrix01::SparseMatrix01(nrows, ncols, mat): " - << "Invalid number of columns: " << ncols - << " - Should be > 0"; - } + << "Invalid number of rows: " << nrows << " - Should be >= 0"; - fromDense(nrows, ncols, mat); + NTA_CHECK(ncols > 0) + << "SparseMatrix01::SparseMatrix01(nrows, ncols, mat): " + << "Invalid number of columns: " << ncols << " - Should be > 0"; } - //-------------------------------------------------------------------------------- - /** - * Destructor. - */ - ~SparseMatrix01() - { - deallocate_(); - } + fromDense(nrows, ncols, mat); + } - //-------------------------------------------------------------------------------- - /** - * Whether this matrix is zero or not. - * Non-mutating, O(nrows). - * This is computed, rather than maintained incrementally. - * - * @retval [bool] whether the matrix is zero or not - */ - inline bool isZero() const { return nNonZeros() == 0; } - - //-------------------------------------------------------------------------------- - /** - * Returns the number of rows in this SparseMatrix01. - * Non-mutating, O(1). - * - * @retval [size_type >= 0] number of rows - */ - inline const size_type nRows() const { return nrows_; } - - //-------------------------------------------------------------------------------- - /** - * Returns the number of colums in this SparseMatrix01. - * Non-mutating, O(1). - * - * @retval [size_type >= 0] number of columns - */ - inline const size_type nCols() const { return ncols_; } - - //-------------------------------------------------------------------------------- - /** - * Returns the number of non-zeros in this SparseMatrix01. - * Non-mutating, O(nnz). - * This is computed rather than stored and maintained incrementally. - * This is slow, but we can't add the nzr_[i] because there might - * be less non-zeros... 
So far, nobody has seen this method - * on the critical path in a profile. - * - * @retval [size_type >= 0] number of non-zeros - */ - inline const size_type nNonZeros() const - { - size_type nnz = 0, nrows = nRows(); - for (size_type i = 0; i < nrows; ++i) - nnz += nzr_[i]; - return nnz; - } + //-------------------------------------------------------------------------------- + /** + * Destructor. + */ + ~SparseMatrix01() { deallocate_(); } - //-------------------------------------------------------------------------------- - /** - * Returns the number of non-zeros on 'row'-th row. - * Non-mutating, O(nnzr). - * This is slow, we rescan the line rather than returning nzr_[row], - * because there might be small discrepancies. nzr_[row] really - * represents how many elements are allocated, but there might be - * some elements that fall below epsilon in some cases. - * - * @param row [SizeType >= 0 < nrows] index of the row to access - * @retval [size_type >= 0] number of non-zeros on 'row'-th row - * - * @b Exceptions: - * @li If row < 0 || row >= nRows() (assert) - */ - inline const size_type nNonZerosRow(const size_type& row) const - { - { - NTA_ASSERT(0 <= row && row < nRows()) - << "SparseMatrix01::nNonZerosRow(): " - << "Invalid row index: " << row - << " - Should be >= 0 and < " << nRows(); - } - - return nzr_[row]; - } + //-------------------------------------------------------------------------------- + /** + * Whether this matrix is zero or not. + * Non-mutating, O(nrows). + * This is computed, rather than maintained incrementally. + * + * @retval [bool] whether the matrix is zero or not + */ + inline bool isZero() const { return nNonZeros() == 0; } - //-------------------------------------------------------------------------------- - inline bool hasUniqueRows() const - { - return nnzr_ > 0; - } + //-------------------------------------------------------------------------------- + /** + * Returns the number of rows in this SparseMatrix01. + * Non-mutating, O(1). + * + * @retval [size_type >= 0] number of rows + */ + inline const size_type nRows() const { return nrows_; } - //-------------------------------------------------------------------------------- - inline bool isCompact() const + //-------------------------------------------------------------------------------- + /** + * Returns the number of colums in this SparseMatrix01. + * Non-mutating, O(1). + * + * @retval [size_type >= 0] number of columns + */ + inline const size_type nCols() const { return ncols_; } + + //-------------------------------------------------------------------------------- + /** + * Returns the number of non-zeros in this SparseMatrix01. + * Non-mutating, O(nnz). + * This is computed rather than stored and maintained incrementally. + * This is slow, but we can't add the nzr_[i] because there might + * be less non-zeros... So far, nobody has seen this method + * on the critical path in a profile. + * + * @retval [size_type >= 0] number of non-zeros + */ + inline const size_type nNonZeros() const { + size_type nnz = 0, nrows = nRows(); + for (size_type i = 0; i < nrows; ++i) + nnz += nzr_[i]; + return nnz; + } + + //-------------------------------------------------------------------------------- + /** + * Returns the number of non-zeros on 'row'-th row. + * Non-mutating, O(nnzr). + * This is slow, we rescan the line rather than returning nzr_[row], + * because there might be small discrepancies. 
nzr_[row] really + * represents how many elements are allocated, but there might be + * some elements that fall below epsilon in some cases. + * + * @param row [SizeType >= 0 < nrows] index of the row to access + * @retval [size_type >= 0] number of non-zeros on 'row'-th row + * + * @b Exceptions: + * @li If row < 0 || row >= nRows() (assert) + */ + inline const size_type nNonZerosRow(const size_type &row) const { { - return compact_; + NTA_ASSERT(0 <= row && row < nRows()) + << "SparseMatrix01::nNonZerosRow(): " + << "Invalid row index: " << row << " - Should be >= 0 and < " + << nRows(); } - //-------------------------------------------------------------------------------- - /** - * Exports this SparseMatrix01 to pre-allocated dense array of value_types. - * Non-mutating, O(nnz). - * - * WARNING: does not update nnzr_, needed when counting rows - * - * @param dense [value_type** != NULL] array in which to put the values - * - * @b Exceptions: - * @li dense == NULL (check) - * @li dense has NULL pointer (check) - */ - template - inline void toDense(OutIter dense) const - { - size_type nrows = nRows(), ncols = nCols(); + return nzr_[row]; + } - std::fill(dense, dense + nrows*ncols, (value_type)0); + //-------------------------------------------------------------------------------- + inline bool hasUniqueRows() const { return nnzr_ > 0; } - ITER_2(nrows, nzr_[i]) - *(dense + i*ncols + ind_[i][j]) = 1; - } + //-------------------------------------------------------------------------------- + inline bool isCompact() const { return compact_; } - //-------------------------------------------------------------------------------- - /** - * Populates this SparseMatrix01 from a dense array of value_types. - * Mutating, discards the previous state of this SparseMatrix01. - * - * @param nrows [size_type > 0] number of rows in dense array - * @param ncols [size_type > 0] number of columns in dense array - * @param dense [value_type** != NULL] dense array of values - * - * WARNING: does not update nnzr_, needed when counting rows - * - * @b Exceptions: - * @li nrows <= 0 (check) - * @li ncols <= 0 (check) - * @li dense == NULL (check) - * @li dense has NULL pointer (check) - * @li Not enough memory (error) - */ - template - inline void fromDense(const size_type& nrows, const size_type& ncols, InIter dense) - { - { // Pre-conditions - NTA_CHECK(nrows >= 0) - << "SparseMatrix01::fromDense(): " - << "Invalid number of rows: " << nrows - << " - Should be >= 0"; + //-------------------------------------------------------------------------------- + /** + * Exports this SparseMatrix01 to pre-allocated dense array of value_types. + * Non-mutating, O(nnz). + * + * WARNING: does not update nnzr_, needed when counting rows + * + * @param dense [value_type** != NULL] array in which to put the values + * + * @b Exceptions: + * @li dense == NULL (check) + * @li dense has NULL pointer (check) + */ + template inline void toDense(OutIter dense) const { + size_type nrows = nRows(), ncols = nCols(); + + std::fill(dense, dense + nrows * ncols, (value_type)0); - NTA_CHECK(ncols > 0) + ITER_2(nrows, nzr_[i]) + *(dense + i * ncols + ind_[i][j]) = 1; + } + + //-------------------------------------------------------------------------------- + /** + * Populates this SparseMatrix01 from a dense array of value_types. + * Mutating, discards the previous state of this SparseMatrix01. 
+ * + * @param nrows [size_type > 0] number of rows in dense array + * @param ncols [size_type > 0] number of columns in dense array + * @param dense [value_type** != NULL] dense array of values + * + * WARNING: does not update nnzr_, needed when counting rows + * + * @b Exceptions: + * @li nrows <= 0 (check) + * @li ncols <= 0 (check) + * @li dense == NULL (check) + * @li dense has NULL pointer (check) + * @li Not enough memory (error) + */ + template + inline void fromDense(const size_type &nrows, const size_type &ncols, + InIter dense) { + { // Pre-conditions + NTA_CHECK(nrows >= 0) << "SparseMatrix01::fromDense(): " - << "Invalid number of columns: " << ncols - << " - Should be > 0"; - } + << "Invalid number of rows: " << nrows << " - Should be >= 0"; - if (nzr_) - deallocate_(); + NTA_CHECK(ncols > 0) << "SparseMatrix01::fromDense(): " + << "Invalid number of columns: " << ncols + << " - Should be > 0"; + } - allocate_(nrows, ncols); - nrows_ = 0; + if (nzr_) + deallocate_(); + + allocate_(nrows, ncols); + nrows_ = 0; + + for (size_type i = 0; i < nrows; ++i) + addRow(dense + i * ncols); + } - for (size_type i = 0; i < nrows; ++i) - addRow(dense + i*ncols); + //-------------------------------------------------------------------------------- + /** + * Populates this SparseMatrix01 from a stream in csr format. + * The pairs (index, value) can be in any order for each row. + * Mutating, discards the previous state of this SparseMatrix01. + * Can handle large sparse matrices. + * + * @b Format: + * 'csr' nrows ncols nnz + * nnzr1 j1 val1 j2 val2 ... + * nnzr2 ... + * + * The order of the (j, val) tuples doesn't matter. + * + * @param inStream [std::istream] the stream to initialize from + * @retval [std::istream] the stream after the matrix has been read + * + * @b Exceptions: + * @li Bad stream (check) + * @li Stream does not start with 'csr' tag (check) + * @li nrows < 0 in stream (check) + * @li ncols <= 0 in stream (check) + * @li nnz < 0 || nnz > nrows * ncols in stream (check) + * @li nnzr < 0 || nnzr > ncols for any row (check) + * @li column index j < 0 || >= ncols for any row (check) + * @li Not enough memory (error) + */ + inline std::istream &fromCSR(std::istream &inStreamParam) { + const char *where = "SparseMatrix01::fromCSR(): "; + + { // Pre-conditions + NTA_CHECK(inStreamParam.good()) << where << "Bad stream"; } - //-------------------------------------------------------------------------------- - /** - * Populates this SparseMatrix01 from a stream in csr format. - * The pairs (index, value) can be in any order for each row. - * Mutating, discards the previous state of this SparseMatrix01. - * Can handle large sparse matrices. - * - * @b Format: - * 'csr' nrows ncols nnz - * nnzr1 j1 val1 j2 val2 ... - * nnzr2 ... - * - * The order of the (j, val) tuples doesn't matter. 
- * - * @param inStream [std::istream] the stream to initialize from - * @retval [std::istream] the stream after the matrix has been read - * - * @b Exceptions: - * @li Bad stream (check) - * @li Stream does not start with 'csr' tag (check) - * @li nrows < 0 in stream (check) - * @li ncols <= 0 in stream (check) - * @li nnz < 0 || nnz > nrows * ncols in stream (check) - * @li nnzr < 0 || nnzr > ncols for any row (check) - * @li column index j < 0 || >= ncols for any row (check) - * @li Not enough memory (error) - */ - inline std::istream& fromCSR(std::istream& inStreamParam) - { - const char* where = "SparseMatrix01::fromCSR(): "; + std::string tag; + inStreamParam >> tag; + NTA_CHECK(tag == "csr01") << where << "Stream is not in csr format" + << " - Should start with 'csr01' tag"; - { // Pre-conditions - NTA_CHECK(inStreamParam.good()) - << where << "Bad stream"; - } + // Read our stream data into a MemParser object for faster parsing. + long totalBytes; + inStreamParam >> totalBytes; + if (totalBytes < 0) + totalBytes = 0; + MemParser inStream(inStreamParam, totalBytes); - std::string tag; - inStreamParam >> tag; - NTA_CHECK(tag == "csr01") - << where - << "Stream is not in csr format" - << " - Should start with 'csr01' tag"; + size_type i, j, k, nrows, ncols, nnz, nnzr; - // Read our stream data into a MemParser object for faster parsing. - long totalBytes; - inStreamParam >> totalBytes; - if (totalBytes < 0) - totalBytes = 0; - MemParser inStream(inStreamParam, totalBytes); + inStream >> nrows >> ncols >> nnz >> nnzr; - size_type i, j, k, nrows, ncols, nnz, nnzr; + { + NTA_CHECK(nrows >= 0) << where << "Invalid number of rows: " << nrows + << " - Should be >= 0"; - inStream >> nrows >> ncols >> nnz >> nnzr; + NTA_CHECK(ncols > 0) << where << "Invalid number of columns: " << ncols + << " - Should be > 0"; - { - NTA_CHECK(nrows >= 0) - << where - << "Invalid number of rows: " << nrows - << " - Should be >= 0"; - - NTA_CHECK(ncols > 0) - << where - << "Invalid number of columns: " << ncols - << " - Should be > 0"; - - NTA_CHECK(nnz >= 0 && nnz <= nrows * ncols) - << where - << "Invalid number of non-zeros: " << nnz + NTA_CHECK(nnz >= 0 && nnz <= nrows * ncols) + << where << "Invalid number of non-zeros: " << nnz << " - Should be >= 0 && nrows * ncols = " << nrows * ncols; - NTA_CHECK(nnzr >= 0 && nnzr <= ncols) - << where - << "Invalid number of non-zeros per row: " << nnzr + NTA_CHECK(nnzr >= 0 && nnzr <= ncols) + << where << "Invalid number of non-zeros per row: " << nnzr << " - Should be >= 0 && ncols = " << ncols; - } - - if (nzr_) - deallocate_(); - - allocate_(nrows, ncols); - nrows_ = 0; - nnzr_ = nnzr; - - std::vector counts(nrows, 1); - if (hasUniqueRows()) { - size_type count = 1; - for (i = 0; i < nrows; ++i) { - inStream >> count; - counts[i] = count; - } - } + } + if (nzr_) + deallocate_(); + + allocate_(nrows, ncols); + nrows_ = 0; + nnzr_ = nnzr; + + std::vector counts(nrows, 1); + if (hasUniqueRows()) { + size_type count = 1; for (i = 0; i < nrows; ++i) { + inStream >> count; + counts[i] = count; + } + } - size_type *indb_it = indb_; - inStream >> nzr_[i]; + for (i = 0; i < nrows; ++i) { - { - NTA_CHECK(nzr_[i] >= 0 && nzr_[i] <= ncols) - << where - << "Invalid number of non-zeros: " << nzr_[i] + size_type *indb_it = indb_; + inStream >> nzr_[i]; + + { + NTA_CHECK(nzr_[i] >= 0 && nzr_[i] <= ncols) + << where << "Invalid number of non-zeros: " << nzr_[i] << " - Should be >= 0 && < ncols = " << ncols; - } + } - for (k = 0; k < nzr_[i]; ++k) { - - inStream >> j; - - { - 
NTA_CHECK(j >= 0 && j < ncols) - << where - << "Invalid index: " << j - << " - Should be >= 0 and < ncols = " << ncols; - } + for (k = 0; k < nzr_[i]; ++k) { + + inStream >> j; - *indb_it++ = j; + { + NTA_CHECK(j >= 0 && j < ncols) + << where << "Invalid index: " << j + << " - Should be >= 0 and < ncols = " << ncols; } - addRow(size_type(indb_it - indb_), indb_); + *indb_it++ = j; } - typename Counts::iterator it; - for (it = counts_.begin(); it != counts_.end(); ++it) - it->second.second = counts[it->second.first]; + addRow(size_type(indb_it - indb_), indb_); + } + + typename Counts::iterator it; + for (it = counts_.begin(); it != counts_.end(); ++it) + it->second.second = counts[it->second.first]; + + compact(); - compact(); + return inStreamParam; + } - return inStreamParam; + //-------------------------------------------------------------------------------- + /** + * Exports this SparseMatrix01 to a stream in csr format. + * + * @param out [std::ostream] the stream to write this matrix to + * @retval [std::ostream] the stream with the matrix written to it + * + * @b Exceptions: + * @li Bad stream (check) + */ + inline std::ostream &toCSR(std::ostream &out) const { + { // Pre-conditions + NTA_CHECK(out.good()) << "SparseMatrix01:toCSR(): " + << "Bad stream"; } - //-------------------------------------------------------------------------------- - /** - * Exports this SparseMatrix01 to a stream in csr format. - * - * @param out [std::ostream] the stream to write this matrix to - * @retval [std::ostream] the stream with the matrix written to it - * - * @b Exceptions: - * @li Bad stream (check) - */ - inline std::ostream& toCSR(std::ostream& out) const - { - { // Pre-conditions - NTA_CHECK(out.good()) - << "SparseMatrix01:toCSR(): " - << "Bad stream"; - } - - out << "csr01 "; - - OMemStream outStream; - outStream << nRows() << " " - << nCols() << " " - << nNonZeros() << " " - << nnzr_ << " "; - - if (hasUniqueRows()) { - std::vector counts(counts_.size()); - typename Counts::const_iterator it; - for (it = counts_.begin(); it != counts_.end(); ++it) - counts[it->second.first] = it->second.second; - typename std::vector::const_iterator it_v; - for (it_v = counts.begin(); it_v != counts.end(); ++it_v) - outStream << *it_v << " "; - //outStream << it->second.first << " " << it->second.second << " "; - } - - size_type i, nrows = nRows(); - for (i = 0; i < nrows; ++i) { - outStream << nzr_[i] << " "; - size_type* ind_it = ind_[i], *ind_end = ind_it + nzr_[i]; - while (ind_it != ind_end) - outStream << *ind_it++ << " "; - } + out << "csr01 "; + + OMemStream outStream; + outStream << nRows() << " " << nCols() << " " << nNonZeros() << " " << nnzr_ + << " "; - // Write total # of bytes, followed by data. - // This facilitates faster parsing of the - // data directly from a memory buffer in fromCSR() - out << outStream.pcount() << " "; - out.write(outStream.str(), UInt(outStream.pcount())); - return out; + if (hasUniqueRows()) { + std::vector counts(counts_.size()); + typename Counts::const_iterator it; + for (it = counts_.begin(); it != counts_.end(); ++it) + counts[it->second.first] = it->second.second; + typename std::vector::const_iterator it_v; + for (it_v = counts.begin(); it_v != counts.end(); ++it_v) + outStream << *it_v << " "; + // outStream << it->second.first << " " << it->second.second << " "; } - //-------------------------------------------------------------------------------- - /** - * Compatible with SparseMatrix from CSR. 
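
Since toCSR() writes to any std::ostream and fromCSR() reads the same 'csr01' record back from any std::istream, a round trip through a string stream is the simplest way to exercise the pair. A sketch under the same assumed UInt32/Real32 instantiation as above (illustrative only):

#include <sstream>
#include <vector>

#include <nupic/math/SparseMatrix01.hpp>
#include <nupic/types/Types.hpp>

using namespace nupic;

int main() {
  SparseMatrix01<UInt32, Real32> a(4, 8); // 4 columns, hint of 8 rows
  std::vector<Real32> dense = {0.f, 1.f, 0.f, 1.f};
  a.addRow(dense.begin()); // non-zeros at columns 1 and 3

  std::stringstream ss;
  a.toCSR(ss); // writes the 'csr01' record

  SparseMatrix01<UInt32, Real32> b(1); // placeholder shape, replaced by fromCSR
  b.fromCSR(ss);
  return (b.nRows() == a.nRows() && b.nNonZeros() == a.nNonZeros()) ? 0 : 1;
}
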
- */ - inline std::ostream& toCSRFull(std::ostream& out) const - { - { // Pre-conditions - NTA_CHECK(out.good()) - << "SparseMatrix01:toCSR(): " - << "Bad stream"; - } - - out << "csr "; - - OMemStream buf; - buf << nRows() << " " << nCols() << " " << nNonZeros() << " "; + size_type i, nrows = nRows(); + for (i = 0; i < nrows; ++i) { + outStream << nzr_[i] << " "; + size_type *ind_it = ind_[i], *ind_end = ind_it + nzr_[i]; + while (ind_it != ind_end) + outStream << *ind_it++ << " "; + } - size_type i, nrows = nRows(); - for (i = 0; i < nrows; ++i) { - buf << nzr_[i] << " "; - size_type* ind_it = ind_[i], *ind_end = ind_it + nzr_[i]; - while (ind_it != ind_end) - buf << *ind_it++ << " 1 "; - } + // Write total # of bytes, followed by data. + // This facilitates faster parsing of the + // data directly from a memory buffer in fromCSR() + out << outStream.pcount() << " "; + out.write(outStream.str(), UInt(outStream.pcount())); + return out; + } - // Write total # of bytes, followed by data. - // This facilitates faster parsing of the - // data directly from a memory buffer in fromCSR() - out << buf.pcount() << " "; - out.write(buf.str(), UInt(buf.pcount())); - return out; + //-------------------------------------------------------------------------------- + /** + * Compatible with SparseMatrix from CSR. + */ + inline std::ostream &toCSRFull(std::ostream &out) const { + { // Pre-conditions + NTA_CHECK(out.good()) << "SparseMatrix01:toCSR(): " + << "Bad stream"; } - //-------------------------------------------------------------------------------- - /** - * Compacts the memory for this SparseMatrix01. - * This reduces the number of cache misses, and can - * make a sizable runtime difference (up to 30% on shona, - * depending on the operation). - * All the non-zeros are allocated contiguously. - * Non-mutating algorithms can run on the compact representation. - * - * Mutating, O(nnz) - * - * @b Exceptions: - * @li Not eneough memory (error) - */ - inline void compact() - { - if (nRows() == 0) { - compact_ = true; - return; - } - - size_type i, nnz = nNonZeros(), top, *indp = NULL, nrows = nRows(); - size_type* old_ind = ind_[0]; - - // Allocate contiguous storage for the whole - // sparse matrix - indp = new size_type[nnz]; - - // Copy the old row to the new location ... - // ... and delete the old location if we - // were allocated separately (not compact) - for (top = 0, i = 0; i < nrows; ++i) { - - memcpy(indp + top, ind_[i], nzr_[i] * sizeof(size_type)); - - if (hasUniqueRows()) - swapCountKeys_(ind_[i], indp + top); - - if (!isCompact()) - delete [] ind_[i]; - - ind_[i] = indp + top; - top += nzr_[i]; - } + out << "csr "; + + OMemStream buf; + buf << nRows() << " " << nCols() << " " << nNonZeros() << " "; + + size_type i, nrows = nRows(); + for (i = 0; i < nrows; ++i) { + buf << nzr_[i] << " "; + size_type *ind_it = ind_[i], *ind_end = ind_it + nzr_[i]; + while (ind_it != ind_end) + buf << *ind_it++ << " 1 "; + } - if (compact_) - delete [] old_ind; - + // Write total # of bytes, followed by data. + // This facilitates faster parsing of the + // data directly from a memory buffer in fromCSR() + out << buf.pcount() << " "; + out.write(buf.str(), UInt(buf.pcount())); + return out; + } + + //-------------------------------------------------------------------------------- + /** + * Compacts the memory for this SparseMatrix01. + * This reduces the number of cache misses, and can + * make a sizable runtime difference (up to 30% on shona, + * depending on the operation). 
+ * All the non-zeros are allocated contiguously. + * Non-mutating algorithms can run on the compact representation. + * + * Mutating, O(nnz) + * + * @b Exceptions: + * @li Not eneough memory (error) + */ + inline void compact() { + if (nRows() == 0) { compact_ = true; + return; } - - //-------------------------------------------------------------------------------- - /** - * "De-compacts" this SparseMatrix01, that is, each row - * is allocated separately. All the non-zeros inside a given row - * are still allocated contiguously. This is more efficient - * when changing the number of non-zeros on each row (rather than - * reallocating the whole contiguous array of all the non-zeros - * in the SparseMatrix01. We decompact before mutating the non-zeros, - * and we recompact once the non-zeros don't change anymore. - * - * Mutating, O(nnz) - * - * @b Exceptions: - * @li Not enough memory (error) - */ - inline void decompact() - { - size_type i, nnzr, nrows = nRows(); - size_type* old_ind = ind_[0]; - - for (i = 0; i < nrows; ++i) { - - nnzr = nzr_[i]; - auto ind = new size_type[nnzr]; - memcpy(ind, ind_[i], nnzr*sizeof(size_type)); - - if (nnzr_ > 0) { - Row_Count row_val = counts_[ind_[i]]; - counts_.erase(ind_[i]); - counts_[ind] = row_val; - } - if (!isCompact()) - delete [] ind_[i]; - - ind_[i] = ind; - } - - if (isCompact()) - delete [] old_ind; - - compact_ = false; + size_type i, nnz = nNonZeros(), top, *indp = NULL, nrows = nRows(); + size_type *old_ind = ind_[0]; + + // Allocate contiguous storage for the whole + // sparse matrix + indp = new size_type[nnz]; + + // Copy the old row to the new location ... + // ... and delete the old location if we + // were allocated separately (not compact) + for (top = 0, i = 0; i < nrows; ++i) { + + memcpy(indp + top, ind_[i], nzr_[i] * sizeof(size_type)); + + if (hasUniqueRows()) + swapCountKeys_(ind_[i], indp + top); + + if (!isCompact()) + delete[] ind_[i]; + + ind_[i] = indp + top; + top += nzr_[i]; } - //-------------------------------------------------------------------------------- - /** - * Adds a row to this SparseMatrix01. The iterator - * iterates over the values in increasing - * order of indices. The iterator needs to span - * ncols values. - * - * Mutating, can increase the number of non-zeros, O(nnzr + K) - * - * @param x [InIter] input iterator for row values. - * x needs to be a contiguous array in memory, like std::vector. - * - * @b Exceptions: - * @li Not enough memory (error) - */ - template - inline size_type addRow(InIter x) - { - size_type *indb_it = indb_; - InIter x_it = x, x_end = x + nCols(); + if (compact_) + delete[] old_ind; + + compact_ = true; + } + + //-------------------------------------------------------------------------------- + /** + * "De-compacts" this SparseMatrix01, that is, each row + * is allocated separately. All the non-zeros inside a given row + * are still allocated contiguously. This is more efficient + * when changing the number of non-zeros on each row (rather than + * reallocating the whole contiguous array of all the non-zeros + * in the SparseMatrix01. We decompact before mutating the non-zeros, + * and we recompact once the non-zeros don't change anymore. + * + * Mutating, O(nnz) + * + * @b Exceptions: + * @li Not enough memory (error) + */ + inline void decompact() { + size_type i, nnzr, nrows = nRows(); + size_type *old_ind = ind_[0]; - // TODO: what if numbers are negative? 
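
compact() above copies every row into one contiguous allocation to cut cache misses, while decompact() gives each row its own allocation again so individual rows can grow cheaply; addRow() calls decompact() itself when it needs to. In user code the typical pattern is simply to compact once the matrix stops changing. A sketch under the same assumed instantiation as above:

#include <vector>

#include <nupic/math/SparseMatrix01.hpp>
#include <nupic/types/Types.hpp>

using namespace nupic;

int main() {
  SparseMatrix01<UInt32, Real32> m(8, 4);
  std::vector<UInt32> row = {0, 2, 5}; // sorted column indices
  m.addRow(3, row.begin());
  m.addRow(3, row.begin()); // rows may reallocate while the matrix grows
  m.compact();              // one contiguous block for all non-zeros
  return m.isCompact() ? 0 : 1;
}
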
- while (x_it != x_end) { - if (*x_it > 0) - *indb_it++ = size_type(x_it - x); - ++x_it; + for (i = 0; i < nrows; ++i) { + + nnzr = nzr_[i]; + auto ind = new size_type[nnzr]; + memcpy(ind, ind_[i], nnzr * sizeof(size_type)); + + if (nnzr_ > 0) { + Row_Count row_val = counts_[ind_[i]]; + counts_.erase(ind_[i]); + counts_[ind] = row_val; } - return addRow(size_type(indb_it - indb_), indb_); + if (!isCompact()) + delete[] ind_[i]; + + ind_[i] = ind; } - //-------------------------------------------------------------------------------- - /** - * Adds a row to this SparseMatrix01, from an iterator on indices of - * non-zeros. If x is already a row in this SparseMatrix01, - * and this sparse matrix is set up to work with unique rows, it is not - * added, but its count is incremented. Otherwise, x is added to - * this sparse matrix. - * - * x needs to be a contiguous array in memory, like std::vector. - * - * @param nnzr [size_type >= 0] the number of non-zeros in x - * @param x [InIter] itertor to the beginning of a contiguous array - * containing the indices of the non-zeros. - * @retval size_type the index of the row - */ - template - inline size_type addRow(const size_type nnzr, InIter x_begin) - { - size_type j = 0, ncols = nCols(), prev = 0, row_index = 0; - InIter jj = x_begin; + if (isCompact()) + delete[] old_ind; + + compact_ = false; + } + + //-------------------------------------------------------------------------------- + /** + * Adds a row to this SparseMatrix01. The iterator + * iterates over the values in increasing + * order of indices. The iterator needs to span + * ncols values. + * + * Mutating, can increase the number of non-zeros, O(nnzr + K) + * + * @param x [InIter] input iterator for row values. + * x needs to be a contiguous array in memory, like std::vector. + * + * @b Exceptions: + * @li Not enough memory (error) + */ + template inline size_type addRow(InIter x) { + size_type *indb_it = indb_; + InIter x_it = x, x_end = x + nCols(); + + // TODO: what if numbers are negative? + while (x_it != x_end) { + if (*x_it > 0) + *indb_it++ = size_type(x_it - x); + ++x_it; + } - { // Pre-conditions - NTA_ASSERT(nnzr >= 0) + return addRow(size_type(indb_it - indb_), indb_); + } + + //-------------------------------------------------------------------------------- + /** + * Adds a row to this SparseMatrix01, from an iterator on indices of + * non-zeros. If x is already a row in this SparseMatrix01, + * and this sparse matrix is set up to work with unique rows, it is not + * added, but its count is incremented. Otherwise, x is added to + * this sparse matrix. + * + * x needs to be a contiguous array in memory, like std::vector. + * + * @param nnzr [size_type >= 0] the number of non-zeros in x + * @param x [InIter] itertor to the beginning of a contiguous array + * containing the indices of the non-zeros. 
+ * @retval size_type the index of the row + */ + template + inline size_type addRow(const size_type nnzr, InIter x_begin) { + size_type j = 0, ncols = nCols(), prev = 0, row_index = 0; + InIter jj = x_begin; + + { // Pre-conditions + NTA_ASSERT(nnzr >= 0) << "SparseMatrix01::addRow(): " + << "Passed nnzr = " << nnzr << " - Should be >= 0"; + + NTA_ASSERT(nnzr <= nCols()) << "SparseMatrix01::addRow(): " - << "Passed nnzr = " << nnzr - << " - Should be >= 0"; + << "Passed nnzr = " << nnzr << " but there are only " << nCols() + << " columns"; - NTA_ASSERT(nnzr <= nCols()) + /* Too noisy + if (nnzr == 0) + NTA_WARN << "SparseMatrix01::addRow(): " - << "Passed nnzr = " << nnzr - << " but there are only " << nCols() << " columns"; + << "Passed nnzr = 0 - Won't do anything"; + */ - /* Too noisy - if (nnzr == 0) - NTA_WARN + for (j = 0; j < nnzr; ++j, ++jj) { + NTA_ASSERT(0 <= *jj && *jj < ncols) << "SparseMatrix01::addRow(): " - << "Passed nnzr = 0 - Won't do anything"; - */ + << "Invalid column index: " << *jj << " - Should be >= 0 and < " + << ncols; - for (j = 0; j < nnzr; ++j, ++jj) { - NTA_ASSERT(0 <= *jj && *jj < ncols) - << "SparseMatrix01::addRow(): " - << "Invalid column index: " << *jj - << " - Should be >= 0 and < " << ncols; - - if (j > 0) { - NTA_ASSERT(prev < *jj) + if (j > 0) { + NTA_ASSERT(prev < *jj) << "SparseMatrix01::addRow(): " << "Indices need to be in strictly increasing order, " << "found: " << prev << " followed by: " << *jj; - } - prev = *jj; } - } // End pre-conditions + prev = *jj; + } + } // End pre-conditions - if (nnzr_ == 0) { - - row_index = addRow_(nnzr, x_begin); + if (nnzr_ == 0) { - } else { // unique, counted rows - - // TODO: speed up by inserting and looking at returned iterator? - auto it = counts_.find(&*x_begin); - - if (it != counts_.end()) { - ++(it->second.second); - row_index = it->second.first; - } else { - row_index = addRow_(nnzr, x_begin); - counts_[ind_[row_index]] = std::make_pair(row_index, 1); - } - } // end unique, counted rows + row_index = addRow_(nnzr, x_begin); - return row_index; - } + } else { // unique, counted rows - //-------------------------------------------------------------------------------- - /** - * Finds the max value inside each segment defined by the boundaries, - * records a "1" in indb_ at the corresponding position, e.g.: - * - * boundaries = 3 6 9 - * x = .7 .3 .4 .2 .8 .1 .5 .5 .7 - * gives: - * binarized x = 1 0 0 0 1 0 0 0 1 (not actually computed or stored) - * and indb_ = 0 4 8 (the indices of the 1s) - * - * WARNING: this works if this SparseMatrix01 is set up to work - * with unique rows only. - * - * @param boundaries [InIter1] iterator to beginning of a contiguous - * array that contains the boundaries for the filtering operation - * @param x [InIter2] iterator to beginning of input vector - * - * @b Exceptions: - * @li If matrix is not set up to handle unique rows. - * @li If first boundary is zero. - * @li If boundaries are not in strictly increasing order. - * @li If last boundary is not equal to the number of columns. - */ - template - inline void winnerTakesAll(InIter1 boundaries, InIter2 x) - { - { // Pre-conditions - const char* where = "SparseMatrix01::winnerTakesAll(): "; - - // This works only with unique, counted rows (nnzr_ != 0) - NTA_ASSERT(nnzr_ != 0) - << where - << "Attempting to call this method on a SparseMatrix01 " + // TODO: speed up by inserting and looking at returned iterator? 
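//--------------------------------------------------------------------------------
// Editorial sketch (not part of the original patch): in unique-rows mode,
// addRow() keys counts_ by a pointer to the stored row's index array and maps
// it to (row_index, count); the lookup appears to compare the pointed-to
// arrays element-wise, since every row holds exactly nnzr_ indices. The
// hypothetical addOrCount() helper below shows the same add-or-increment
// bookkeeping with plain value-based keys, which may be easier to follow.

#include <cstddef>
#include <map>
#include <utility>
#include <vector>

// Returns the index of the row whose sorted non-zero columns equal 'nz',
// inserting it with a count of 1 if it was never seen before, or bumping
// its count otherwise.
inline std::size_t
addOrCount(std::map<std::vector<std::size_t>,
                    std::pair<std::size_t, std::size_t>> &counts,
           std::vector<std::vector<std::size_t>> &rows,
           const std::vector<std::size_t> &nz) {
  auto it = counts.find(nz);
  if (it != counts.end()) {
    ++it->second.second;     // row already present: bump its count
    return it->second.first; // and return the existing row index
  }
  const std::size_t rowIndex = rows.size();
  rows.push_back(nz);        // genuinely new row
  counts[nz] = std::make_pair(rowIndex, 1);
  return rowIndex;
}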
+ auto it = counts_.find(&*x_begin); + + if (it != counts_.end()) { + ++(it->second.second); + row_index = it->second.first; + } else { + row_index = addRow_(nnzr, x_begin); + counts_[ind_[row_index]] = std::make_pair(row_index, 1); + } + } // end unique, counted rows + + return row_index; + } + + //-------------------------------------------------------------------------------- + /** + * Finds the max value inside each segment defined by the boundaries, + * records a "1" in indb_ at the corresponding position, e.g.: + * + * boundaries = 3 6 9 + * x = .7 .3 .4 .2 .8 .1 .5 .5 .7 + * gives: + * binarized x = 1 0 0 0 1 0 0 0 1 (not actually computed or stored) + * and indb_ = 0 4 8 (the indices of the 1s) + * + * WARNING: this works if this SparseMatrix01 is set up to work + * with unique rows only. + * + * @param boundaries [InIter1] iterator to beginning of a contiguous + * array that contains the boundaries for the filtering operation + * @param x [InIter2] iterator to beginning of input vector + * + * @b Exceptions: + * @li If matrix is not set up to handle unique rows. + * @li If first boundary is zero. + * @li If boundaries are not in strictly increasing order. + * @li If last boundary is not equal to the number of columns. + */ + template + inline void winnerTakesAll(InIter1 boundaries, InIter2 x) { + { // Pre-conditions + const char *where = "SparseMatrix01::winnerTakesAll(): "; + + // This works only with unique, counted rows (nnzr_ != 0) + NTA_ASSERT(nnzr_ != 0) + << where << "Attempting to call this method on a SparseMatrix01 " << "that was not set up to work with unique rows"; - - // First boundary cannot be zero - NTA_ASSERT(*boundaries > 0) - << where - << "Zero is not allowed for first boundary"; - - // Boundaries need to be passed in strictly increasing - // order - for (size_type i = 1; i < nnzr_; ++i) - NTA_ASSERT(boundaries[i-1] < boundaries[i]) - << where - << "Passed invalid boundaries: " << boundaries[i-1] - << " and " << boundaries[i] - << " at " << i-1 << " and " << i + + // First boundary cannot be zero + NTA_ASSERT(*boundaries > 0) + << where << "Zero is not allowed for first boundary"; + + // Boundaries need to be passed in strictly increasing + // order + for (size_type i = 1; i < nnzr_; ++i) + NTA_ASSERT(boundaries[i - 1] < boundaries[i]) + << where << "Passed invalid boundaries: " << boundaries[i - 1] + << " and " << boundaries[i] << " at " << i - 1 << " and " << i << " out of " << nCols() << " - Boundaries need to be passed in strictly increasing order"; - // The last boundary is the number of columns - NTA_ASSERT(nCols() == boundaries[nnzr_-1]) - << where - << "Wrong boundaries passed in, last boundary " + // The last boundary is the number of columns + NTA_ASSERT(nCols() == boundaries[nnzr_ - 1]) + << where << "Wrong boundaries passed in, last boundary " << "should be number of columns (" << nCols() << ") " - << "but found: " << boundaries[nnzr_-1]; - } // End pre-conditions - - /* - size_type i, k, row_index = 0; - value_type val, max_v = 0; //- std::numeric_limits::max(); - - for (i = 0, k = 0; i < nnzr_; ++i) { - indb_[i] = i == 0 ? 0 : boundaries[i-1]; - for (max_v = 0; k < boundaries[i]; ++k, ++x) { - val = *x; - if (val > max_v) { - indb_[i] = k; - max_v = val; - } + << "but found: " << boundaries[nnzr_ - 1]; + } // End pre-conditions + + /* + size_type i, k, row_index = 0; + value_type val, max_v = 0; //- std::numeric_limits::max(); + + for (i = 0, k = 0; i < nnzr_; ++i) { + indb_[i] = i == 0 ? 
0 : boundaries[i-1]; + for (max_v = 0; k < boundaries[i]; ++k, ++x) { + val = *x; + if (val > max_v) { + indb_[i] = k; + max_v = val; } } - */ + } + */ - value_type val, max_v = 0;//- std::numeric_limits::max(); - InIter1 b_it = boundaries, b_end = boundaries + nnzr_; - InIter2 it_x = x, x_end = x; - size_type* indb_it = indb_; - - for (; b_it != b_end; ++b_it, ++indb_it) { - max_v = 0; - x_end = x + *b_it; - *indb_it = size_type(it_x - x); - for (; it_x != x_end; ++it_x) { - val = *it_x; - if (val > max_v) { - *indb_it = size_type(it_x - x); - max_v = val; - } + value_type val, max_v = 0; //- std::numeric_limits::max(); + InIter1 b_it = boundaries, b_end = boundaries + nnzr_; + InIter2 it_x = x, x_end = x; + size_type *indb_it = indb_; + + for (; b_it != b_end; ++b_it, ++indb_it) { + max_v = 0; + x_end = x + *b_it; + *indb_it = size_type(it_x - x); + for (; it_x != x_end; ++it_x) { + val = *it_x; + if (val > max_v) { + *indb_it = size_type(it_x - x); + max_v = val; } } } + } - //-------------------------------------------------------------------------------- - /** - * Filters a row using the given boundaries and adds a new row - * based on that filtered vector if it was not seen before. - * The filtering operation consists in replacing the max value of x - * by 1 and the other values by 0 in the each segment defined by - * the boundaries. - * That is, if x = [.5 .7 .3 .2 .6 .3 0 .9] and boundaries = [4 8], - * the result of filtering is: [0 1 0 0 0 0 0 1]: - * .7 is the max in the segment [0..4) , - * and .9 is the max in in the segment [4..8) - * - * WARNING: this works if this SparseMatrix01 is set up to work - * with unique rows only. - * - * @param boundaries [InIter1] iterator to beginning of a contiguous - * array that contains the boundaries for the filtering operation - * @param x [InIter2] iterator to beginning of input vector - * - * @b Exceptions: - * @li If matrix is not set up to handle unique rows. - * @li If first boundary is zero. - * @li If boundaries are not in strictly increasing order. - * @li If last boundary is not equal to the number of columns. - */ - template - inline size_type addUniqueFilteredRow(InIter1 boundaries, InIter2 x) - { - size_type row_index = 0; + //-------------------------------------------------------------------------------- + /** + * Filters a row using the given boundaries and adds a new row + * based on that filtered vector if it was not seen before. + * The filtering operation consists in replacing the max value of x + * by 1 and the other values by 0 in the each segment defined by + * the boundaries. + * That is, if x = [.5 .7 .3 .2 .6 .3 0 .9] and boundaries = [4 8], + * the result of filtering is: [0 1 0 0 0 0 0 1]: + * .7 is the max in the segment [0..4) , + * and .9 is the max in in the segment [4..8) + * + * WARNING: this works if this SparseMatrix01 is set up to work + * with unique rows only. + * + * @param boundaries [InIter1] iterator to beginning of a contiguous + * array that contains the boundaries for the filtering operation + * @param x [InIter2] iterator to beginning of input vector + * + * @b Exceptions: + * @li If matrix is not set up to handle unique rows. + * @li If first boundary is zero. + * @li If boundaries are not in strictly increasing order. + * @li If last boundary is not equal to the number of columns. 
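//--------------------------------------------------------------------------------
// Editorial sketch (not part of the original patch): the segment-wise
// winner-takes-all binarization documented above is easy to misread, so here
// it is restated on plain std::vectors. The name wtaWinners() and its
// parameters are hypothetical; only the segment-max logic mirrors
// winnerTakesAll() / addUniqueFilteredRow().

#include <cstddef>
#include <vector>

// For each segment [prev, b) delimited by 'boundaries' (strictly increasing,
// last value == x.size()), return the index of the largest value of x in that
// segment; ties keep the first maximum. For boundaries = {3, 6, 9} and
// x = {.7, .3, .4, .2, .8, .1, .5, .5, .7} this yields {0, 4, 8}, matching
// the worked example in the winnerTakesAll() comment.
inline std::vector<std::size_t>
wtaWinners(const std::vector<std::size_t> &boundaries,
           const std::vector<float> &x) {
  std::vector<std::size_t> winners;
  std::size_t prev = 0;
  for (std::size_t b : boundaries) {
    std::size_t arg = prev;
    for (std::size_t j = prev; j < b; ++j)
      if (x[j] > x[arg])
        arg = j;
    winners.push_back(arg);
    prev = b;
  }
  return winners;
}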
+ */ + template + inline size_type addUniqueFilteredRow(InIter1 boundaries, InIter2 x) { + size_type row_index = 0; + + winnerTakesAll(boundaries, x); + + // TODO: make sure we have really found nnzr non-zeros: + // a vector of all zeros will yield for indices of the non-zeros + // the indices of the first elements of each child, + // which would be indistinguishable from the vectors were the + // first positions for each child are really the maxima! + // And we don't want to remember indices of actual zeros. + + // TODO: speed up by inserting and looking at returned iterator? + typename Counts::iterator it = counts_.find(indb_); + + if (it != counts_.end()) { + ++(it->second.second); + row_index = it->second.first; + } else { + row_index = addRow_(nnzr_, indb_); + counts_[ind_[row_index]] = std::make_pair(row_index, 1); + } - winnerTakesAll(boundaries, x); + return row_index; + } - // TODO: make sure we have really found nnzr non-zeros: - // a vector of all zeros will yield for indices of the non-zeros - // the indices of the first elements of each child, - // which would be indistinguishable from the vectors were the - // first positions for each child are really the maxima! - // And we don't want to remember indices of actual zeros. - - // TODO: speed up by inserting and looking at returned iterator? - typename Counts::iterator it = counts_.find(indb_); - - if (it != counts_.end()) { - ++ (it->second.second); - row_index = it->second.first; - } else { - row_index = addRow_(nnzr_, indb_); - counts_[ind_[row_index]] = std::make_pair(row_index, 1); + //-------------------------------------------------------------------------------- + /** + * Finds the closest coincidence to x according to Hamming distance. + * If the closest distance is less than maxDistance, the count of the + * corresponding coincidence is incremented by one, otherwise, + * a new coincidence is inserted in the matrix, with a count of 1. + * + * WARNING: this works if this SparseMatrix01 is set up to work + * with unique rows only. + * + * @param boundaries [InIter1] iterator to beginning of a contiguous + * array that contains the boundaries for the filtering operation + * @param x [InIter2] iterator to beginning of input vector + * @param maxDistance max distance to closest coincidence that will trigger + * the insertion of a new coincidence + * + * @b Exceptions: + * @li If matrix is not set up to handle unique rows. + * @li If first boundary is zero. + * @li If boundaries are not in strictly increasing order. + * @li If last boundary is not equal to the number of columns. + */ + template + inline size_type addMinHamming(InIter1 boundaries, InIter2 x, + const value_type &maxDistance) { + size_type row_index = 0; + size_type hamming, min_hamming, *ind, *ind_end, *indb; + typename Counts::iterator it, it_end, arg_it; + + // Binarize x into indb_ + winnerTakesAll(boundaries, x); + + // Find Hamming-closest row + min_hamming = nnzr_; // std::numeric_limits::max(); + arg_it = it = counts_.begin(); + it_end = counts_.end(); + + while (it != it_end) { + hamming = 0; + ind = it->first; + ind_end = ind + nnzr_; + indb = indb_; + while (ind != ind_end && hamming < min_hamming) { + hamming += *ind != *indb; // this works because nnzr_ = constant + ++ind; + ++indb; + } + if (hamming < min_hamming) { + arg_it = it; + min_hamming = hamming; } - - return row_index; + ++it; } - //-------------------------------------------------------------------------------- - /** - * Finds the closest coincidence to x according to Hamming distance. 
- * If the closest distance is less than maxDistance, the count of the - * corresponding coincidence is incremented by one, otherwise, - * a new coincidence is inserted in the matrix, with a count of 1. - * - * WARNING: this works if this SparseMatrix01 is set up to work - * with unique rows only. - * - * @param boundaries [InIter1] iterator to beginning of a contiguous - * array that contains the boundaries for the filtering operation - * @param x [InIter2] iterator to beginning of input vector - * @param maxDistance max distance to closest coincidence that will trigger - * the insertion of a new coincidence - * - * @b Exceptions: - * @li If matrix is not set up to handle unique rows. - * @li If first boundary is zero. - * @li If boundaries are not in strictly increasing order. - * @li If last boundary is not equal to the number of columns. - */ - template - inline size_type addMinHamming(InIter1 boundaries, InIter2 x, - const value_type& maxDistance) - { - size_type row_index = 0; - size_type hamming, min_hamming, *ind, *ind_end, *indb; - typename Counts::iterator it, it_end, arg_it; - - // Binarize x into indb_ - winnerTakesAll(boundaries, x); - - // Find Hamming-closest row - min_hamming = nnzr_; //std::numeric_limits::max(); - arg_it = it = counts_.begin(); it_end = counts_.end(); - - while (it != it_end) { - hamming = 0; - ind = it->first; ind_end = ind + nnzr_; indb = indb_; - while (ind != ind_end && hamming < min_hamming) { - hamming += *ind != *indb; // this works because nnzr_ = constant - ++ind; ++indb; - } - if (hamming < min_hamming) { - arg_it = it; - min_hamming = hamming; - } - ++it; - } + // So far, we have counted the mismatching segments + // the Hamming distance is twice that number + if (2 * min_hamming <= maxDistance) { + ++(arg_it->second.second); + row_index = arg_it->second.first; + } else { + row_index = addRow_(nnzr_, indb_); + counts_[ind_[row_index]] = std::make_pair(row_index, 1); + } - // So far, we have counted the mismatching segments - // the Hamming distance is twice that number - if (2*min_hamming <= maxDistance) { - ++ (arg_it->second.second); - row_index = arg_it->second.first; - } else { - row_index = addRow_(nnzr_, indb_); - counts_[ind_[row_index]] = std::make_pair(row_index, 1); - } + return row_index; + } - return row_index; - } + //-------------------------------------------------------------------------------- + /* + template + inline size_type addWithThreshold(InIter x, const value_type& threshold) + { + size_type row_index = 0, *indb = indb_; - //-------------------------------------------------------------------------------- - /* - template - inline size_type addWithThreshold(InIter x, const value_type& threshold) - { - size_type row_index = 0, *indb = indb_; - - for (InIter x_end = x + nCols(); x != x_end; ++x) { - value_type val = *x; - if (val > threshold) { - *indb = val; - ++indb; - } + for (InIter x_end = x + nCols(); x != x_end; ++x) { + value_type val = *x; + if (val > threshold) { + *indb = val; + ++indb; } - - typename Counts::iterator it = counts_.find(indb_); - - if (it != counts_.end()) { - ++ (it->second.second); - row_index = it->second.first; - } else { - row_index = addRow_(nnzr_, indb_); - counts_[ind_[row_index]] = std::make_pair(row_index, 1); - } - - return row_index; } - */ - //-------------------------------------------------------------------------------- - /** - * Deletes specified rows. - * The indices of the rows are passed in a range [del..del_end). 
- * The range can be contiguous (std::vector) or not (std::list, std::map). - * The matrix can end up empty if all the rows are removed. - * If the list of rows to remove is empty, the matrix is unchanged. - * - * WARNING: the row indices need to be passed with duplicates, - * in strictly increasing order. - * - * @param del [InIter] iterator to the beginning of the range - * that contains the indices of the rows to be deleted - * @param del_end [InIter] iterator to one past the end of the - * range that contains the indices of the rows to be deleted - * - * @b Exceptions: - * @li If a row index < 0 || >= nRows(). - * @li If row indices are not passed in strictly increasing order. - * @li If del_end - del < 0 || del_end - del > nRows(). - */ - template - inline void deleteRows(InIter del_it, InIter del_end) - { - ptrdiff_t n_del = del_end - del_it; - - // Here because pre-conditions will fail if nRows == 0 - if (n_del <= 0 || nRows() == 0) - return; - - { // Pre-conditions - if (n_del > 0) { - - NTA_ASSERT(n_del <= (ptrdiff_t)nRows()) + typename Counts::iterator it = counts_.find(indb_); + + if (it != counts_.end()) { + ++ (it->second.second); + row_index = it->second.first; + } else { + row_index = addRow_(nnzr_, indb_); + counts_[ind_[row_index]] = std::make_pair(row_index, 1); + } + + return row_index; + } + */ + + //-------------------------------------------------------------------------------- + /** + * Deletes specified rows. + * The indices of the rows are passed in a range [del..del_end). + * The range can be contiguous (std::vector) or not (std::list, std::map). + * The matrix can end up empty if all the rows are removed. + * If the list of rows to remove is empty, the matrix is unchanged. + * + * WARNING: the row indices need to be passed with duplicates, + * in strictly increasing order. + * + * @param del [InIter] iterator to the beginning of the range + * that contains the indices of the rows to be deleted + * @param del_end [InIter] iterator to one past the end of the + * range that contains the indices of the rows to be deleted + * + * @b Exceptions: + * @li If a row index < 0 || >= nRows(). + * @li If row indices are not passed in strictly increasing order. + * @li If del_end - del < 0 || del_end - del > nRows(). 
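//--------------------------------------------------------------------------------
// Editorial sketch (not part of the original patch): the deleteRows() body
// below is a two-pointer stable compaction over a strictly increasing delete
// list. The hypothetical eraseSorted() helper shows the same idea on a
// std::vector, which may make the nzr_/ind_ pointer juggling easier to follow.

#include <cstddef>
#include <vector>

// Removes from 'items' the elements whose positions appear in 'del'
// (strictly increasing); the relative order of survivors is preserved.
template <typename T>
inline void eraseSorted(std::vector<T> &items,
                        const std::vector<std::size_t> &del) {
  std::size_t keep = 0, d = 0;
  for (std::size_t i = 0; i < items.size(); ++i) {
    if (d < del.size() && i == del[d]) {
      ++d;                      // skip, like the "++ind_old; ++nzr_old" branch
    } else {
      items[keep++] = items[i]; // keep, like "*nzr_it++ = *nzr_old++"
    }
  }
  items.resize(keep);           // like "nrows_ = size_type(nzr_it - nzr_)"
}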
+ */ + template + inline void deleteRows(InIter del_it, InIter del_end) { + ptrdiff_t n_del = del_end - del_it; + + // Here because pre-conditions will fail if nRows == 0 + if (n_del <= 0 || nRows() == 0) + return; + + { // Pre-conditions + if (n_del > 0) { + + NTA_ASSERT(n_del <= (ptrdiff_t)nRows()) << "SparseMatrix01::deleteRows(): " << " Passed more indices of rows to delete" << " than there are rows"; - InIter d = del_it, d_next = del_it + 1; - while (d < del_end - 1) { - NTA_ASSERT(0 <= *d && *d < nRows()) + InIter d = del_it, d_next = del_it + 1; + while (d < del_end - 1) { + NTA_ASSERT(0 <= *d && *d < nRows()) << "SparseMatrix01::deleteRows(): " << "Invalid row index: " << *d << " - Row indices should be between 0 and " << nRows(); - NTA_ASSERT(*d < *d_next) + NTA_ASSERT(*d < *d_next) << "SparseMatrix01::deleteRows(): " << "Invalid row indices " << *d << " and " << *d_next << " - Row indices need to be passed " << "in strictly increasing order"; - ++d; ++d_next; - } - - NTA_ASSERT(0 <= *d && *d < nRows()) + ++d; + ++d_next; + } + + NTA_ASSERT(0 <= *d && *d < nRows()) << "SparseMatrix01::deleteRows(): " << "Invalid row index: " << *d << " - Row indices should be between 0 and " << nRows(); - } else if (n_del == 0) { - - /* Too noisy - NTA_WARN - << "SparseMatrix01::deleteRows(): " - << "Nothing to delete"; - */ + } else if (n_del == 0) { - } else if (n_del < 0) { + /* Too noisy + NTA_WARN + << "SparseMatrix01::deleteRows(): " + << "Nothing to delete"; + */ - /* Too noisy - NTA_WARN - << "SparseMatrix01::deleteRows(): " - << "Invalid pointers - Won't do anything"; - */ - } + } else if (n_del < 0) { + + /* Too noisy + NTA_WARN + << "SparseMatrix01::deleteRows(): " + << "Invalid pointers - Won't do anything"; + */ } + } - if (isCompact()) - decompact(); - - size_type *nzr_old = nzr_, *nzr_it = nzr_; - size_type **ind_old = ind_, **ind_it = ind_; - - for (size_type i_old = 0; i_old < nrows_; ++i_old) { - if (del_it != del_end && i_old == *del_it) { - if (hasUniqueRows()) { - counts_.erase(*ind_old); - } - // DON'T delete here: it would require updating nrows_max_ - // and we don't have the time anyway. - //delete [] *ind_old++; - ++ind_old; - ++nzr_old; - ++del_it; - } else { - *nzr_it++ = *nzr_old++; - *ind_it++ = *ind_old++; + if (isCompact()) + decompact(); + + size_type *nzr_old = nzr_, *nzr_it = nzr_; + size_type **ind_old = ind_, **ind_it = ind_; + + for (size_type i_old = 0; i_old < nrows_; ++i_old) { + if (del_it != del_end && i_old == *del_it) { + if (hasUniqueRows()) { + counts_.erase(*ind_old); } + // DON'T delete here: it would require updating nrows_max_ + // and we don't have the time anyway. + // delete [] *ind_old++; + ++ind_old; + ++nzr_old; + ++del_it; + } else { + *nzr_it++ = *nzr_old++; + *ind_it++ = *ind_old++; } - - nrows_ = size_type(nzr_it - nzr_); } - //-------------------------------------------------------------------------------- - /** - * Deletes rows whose count is < threshold. - * - * @b Exceptions: - */ - template - inline void deleteRows(const size_type& threshold, OutIter del_it) + nrows_ = size_type(nzr_it - nzr_); + } + + //-------------------------------------------------------------------------------- + /** + * Deletes rows whose count is < threshold. 
+ * + * @b Exceptions: + */ + template + inline void deleteRows(const size_type &threshold, OutIter del_it) { { - { - NTA_ASSERT(hasUniqueRows()) + NTA_ASSERT(hasUniqueRows()) << "SparseMatrix01::deleteRows(threshold): " << "Sparse matrix needs to be in unique rows mode"; - } + } - size_type offset = 0; - std::vector to_del, row_counts; - row_counts = getRowCountsSorted(); + size_type offset = 0; + std::vector to_del, row_counts; + row_counts = getRowCountsSorted(); - for (size_type i = 0; i < row_counts.size(); ++i) { - if (row_counts[i] < threshold) { - to_del.push_back(i); - *del_it++ = std::make_pair(i, row_counts[i]); - ++offset; - } else { - counts_[ind_[i]].first -= offset; - } + for (size_type i = 0; i < row_counts.size(); ++i) { + if (row_counts[i] < threshold) { + to_del.push_back(i); + *del_it++ = std::make_pair(i, row_counts[i]); + ++offset; + } else { + counts_[ind_[i]].first -= offset; } - - deleteRows(to_del.begin(), to_del.end()); } - //-------------------------------------------------------------------------------- - /** - * Deletes specified columns. - * The indices of the columns are passed in a range [del..del_end). - * The range can be contiguous (std::vector) or not (std::list, std::map). - * The matrix can end up empty if all the columns are removed. - * If the list of columns to remove is empty, the matrix is unchanged. - * - * WARNING: the columns indices need to be passed with duplicates, - * in strictly increasing order. - * - * @param del [InIter] iterator to the beginning of the range - * that contains the indices of the columns to be deleted - * @param del_end [InIter] iterator to one past the end of the - * range that contains the indices of the columns to be deleted - * - * @b Exceptions: - * @li If a column index < 0 || >= nCols(). - * @li If column indices are not passed in strictly increasing order. - * @li If del_end - del < 0 || del_end - del > nCols(). - */ - template - inline void deleteColumns(InIter del_it, InIter del_end) - { - ptrdiff_t n_del = del_end - del_it; - - if (n_del <= 0 || nCols() == 0) - return; - - { // Pre-conditions - if (n_del > 0) { - - NTA_ASSERT(n_del <= (ptrdiff_t)nCols()) + deleteRows(to_del.begin(), to_del.end()); + } + + //-------------------------------------------------------------------------------- + /** + * Deletes specified columns. + * The indices of the columns are passed in a range [del..del_end). + * The range can be contiguous (std::vector) or not (std::list, std::map). + * The matrix can end up empty if all the columns are removed. + * If the list of columns to remove is empty, the matrix is unchanged. + * + * WARNING: the columns indices need to be passed with duplicates, + * in strictly increasing order. + * + * @param del [InIter] iterator to the beginning of the range + * that contains the indices of the columns to be deleted + * @param del_end [InIter] iterator to one past the end of the + * range that contains the indices of the columns to be deleted + * + * @b Exceptions: + * @li If a column index < 0 || >= nCols(). + * @li If column indices are not passed in strictly increasing order. + * @li If del_end - del < 0 || del_end - del > nCols(). 
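//--------------------------------------------------------------------------------
// Editorial sketch (not part of the original patch): deleteColumns() below
// merges each row's sorted column indices with the sorted delete list,
// dropping deleted columns and shifting the survivors left by the number of
// deleted columns seen so far. The hypothetical remapColumns() helper shows
// the same merge on plain vectors.

#include <cstddef>
#include <vector>

// 'row' and 'del' are strictly increasing column indices; returns the row
// with the deleted columns removed and the remaining indices renumbered.
// Example: row = {1, 4, 7}, del = {2, 4} gives {1, 5}.
inline std::vector<std::size_t>
remapColumns(const std::vector<std::size_t> &row,
             const std::vector<std::size_t> &del) {
  std::vector<std::size_t> out;
  std::size_t d = 0, shift = 0;
  for (std::size_t c : row) {
    while (d < del.size() && del[d] < c) { // deleted columns strictly before c
      ++d;
      ++shift;
    }
    if (d < del.size() && del[d] == c) {   // c itself is deleted: drop it
      ++d;
      ++shift;
      continue;
    }
    out.push_back(c - shift);              // survivor, shifted left
  }
  return out;
}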
+ */ + template + inline void deleteColumns(InIter del_it, InIter del_end) { + ptrdiff_t n_del = del_end - del_it; + + if (n_del <= 0 || nCols() == 0) + return; + + { // Pre-conditions + if (n_del > 0) { + + NTA_ASSERT(n_del <= (ptrdiff_t)nCols()) << "SparseMatrix01::deleteColumns(): " << " Passed more indices of rows to delete" << " than there are columns"; - InIter d = del_it, d_next = del_it + 1; - while (d < del_end - 1) { - NTA_ASSERT(0 <= *d && *d < nCols()) + InIter d = del_it, d_next = del_it + 1; + while (d < del_end - 1) { + NTA_ASSERT(0 <= *d && *d < nCols()) << "SparseMatrix01::deleteColumns(): " << "Invalid column index: " << *d << " - Column indices should be between 0 and " << nCols(); - NTA_ASSERT(*d < *d_next) + NTA_ASSERT(*d < *d_next) << "SparseMatrix01::deleteColumns(): " << "Invalid column indices " << *d << " and " << *d_next << " - Column indices need to be passed " << "in strictly increasing order"; - ++d; ++d_next; - } - - NTA_ASSERT(0 <= *d && *d < nCols()) + ++d; + ++d_next; + } + + NTA_ASSERT(0 <= *d && *d < nCols()) << "SparseMatrix01::deleteColumns(): " << "Invalid column index: " << *d << " - Column indices should be between 0 and " << nCols(); - } else if (n_del == 0) { - - /* Too noisy - NTA_WARN - << "SparseMatrix01::deleteColumns(): " - << "Nothing to delete"; - */ + } else if (n_del == 0) { + + /* Too noisy + NTA_WARN + << "SparseMatrix01::deleteColumns(): " + << "Nothing to delete"; + */ - } else if (n_del < 0) { + } else if (n_del < 0) { - /* Too noisy - NTA_WARN - << "SparseMatrix01::deleteColumns(): " - << "Invalid pointers - Won't do anything"; - */ - } + /* Too noisy + NTA_WARN + << "SparseMatrix01::deleteColumns(): " + << "Invalid pointers - Won't do anything"; + */ } + } - InIter d; - size_type i, j, *ind, *ind_old, *ind_end; - - for (i = 0; i < nRows(); ++i) { - - j = 0; - d = del_it; - ind = ind_[i]; ind_old = ind; ind_end = ind + nzr_[i]; - - while (ind_old != ind_end && d != del_end) { - if (*d == *ind_old) { - ++d; ++j; - ++ind_old; - } else if (*d < *ind_old) { - ++d; ++j; - } else { - *ind++ = *ind_old++ - j; - } - } - - while (ind_old != ind_end) + InIter d; + size_type i, j, *ind, *ind_old, *ind_end; + + for (i = 0; i < nRows(); ++i) { + + j = 0; + d = del_it; + ind = ind_[i]; + ind_old = ind; + ind_end = ind + nzr_[i]; + + while (ind_old != ind_end && d != del_end) { + if (*d == *ind_old) { + ++d; + ++j; + ++ind_old; + } else if (*d < *ind_old) { + ++d; + ++j; + } else { *ind++ = *ind_old++ - j; - - nzr_[i] = size_type(ind - ind_[i]); + } } - ncols_ -= std::max(size_type(0), size_type(n_del)); + while (ind_old != ind_end) + *ind++ = *ind_old++ - j; + + nzr_[i] = size_type(ind - ind_[i]); } - //-------------------------------------------------------------------------------- - /** - * The data structure returned by getRowCounts(), that contains row indices and - * counts. The first integer is the index of the row, the second is the number - * of times that row was inserted in the matrix. In "unique rows" mode, rows - * are inserted once if they are not yet in the matrix, and then a running count - * is incremented each the same row is presented for insertion afterwards. The count - * starts with a value of 1 the first time the row is encountered. The index is the - * number of rows of the matrix at the time the row is inserted. 
- */ - typedef std::vector RowCounts; - - //-------------------------------------------------------------------------------- - /** - * This function can be called only if the matrix has been set up to work - * with unique rows (nnzr_ > 0, the matrix was constructed with a fixed - * number of non-zeros per row). - * - * WARNING: the row counts are not in any particular order! - * - * Non-mutating. - * - * @b Exceptions: - * @li If calling but matrix was not initialized properly by declaring - * the number of non-zeros per row (error). - */ - inline RowCounts getRowCounts() const + ncols_ -= std::max(size_type(0), size_type(n_del)); + } + + //-------------------------------------------------------------------------------- + /** + * The data structure returned by getRowCounts(), that contains row indices + * and counts. The first integer is the index of the row, the second is the + * number of times that row was inserted in the matrix. In "unique rows" mode, + * rows are inserted once if they are not yet in the matrix, and then a + * running count is incremented each the same row is presented for insertion + * afterwards. The count starts with a value of 1 the first time the row is + * encountered. The index is the number of rows of the matrix at the time the + * row is inserted. + */ + typedef std::vector RowCounts; + + //-------------------------------------------------------------------------------- + /** + * This function can be called only if the matrix has been set up to work + * with unique rows (nnzr_ > 0, the matrix was constructed with a fixed + * number of non-zeros per row). + * + * WARNING: the row counts are not in any particular order! + * + * Non-mutating. + * + * @b Exceptions: + * @li If calling but matrix was not initialized properly by declaring + * the number of non-zeros per row (error). + */ + inline RowCounts getRowCounts() const { { - { - NTA_ASSERT(nnzr_ > 0) + NTA_ASSERT(nnzr_ > 0) << "SparseMatrix01::getRowCounts(): " << "Called for unique rows, but matrix is not set up to work" << " with unique rows"; - } + } - RowCounts rc; - typename Counts::const_iterator it; + RowCounts rc; + typename Counts::const_iterator it; - for (it = counts_.begin(); it != counts_.end(); ++it) - rc.push_back(it->second); + for (it = counts_.begin(); it != counts_.end(); ++it) + rc.push_back(it->second); - return rc; - } + return rc; + } - //-------------------------------------------------------------------------------- - /** - * This function can be called only if the matrix has been set up to work - * with unique rows (nnzr_ > 0, the matrix was constructed with a fixed - * number of non-zeros per row). - * - * Non-mutating. - * - * @b Exceptions: - * @li If calling but matrix was not initialized properly by declaring - * the number of non-zeros per row (error). - */ - inline std::vector getRowCountsSorted() const + //-------------------------------------------------------------------------------- + /** + * This function can be called only if the matrix has been set up to work + * with unique rows (nnzr_ > 0, the matrix was constructed with a fixed + * number of non-zeros per row). + * + * Non-mutating. + * + * @b Exceptions: + * @li If calling but matrix was not initialized properly by declaring + * the number of non-zeros per row (error). 
+ */ + inline std::vector getRowCountsSorted() const { { - { - NTA_ASSERT(nnzr_ > 0) + NTA_ASSERT(nnzr_ > 0) << "SparseMatrix01::getRowCountsSorted(): " << "Called for unique rows, but matrix is not set up to work" << " with unique rows"; - } - - std::vector rc(counts_.size(), 0); - typename Counts::const_iterator it; + } - for (it = counts_.begin(); it != counts_.end(); ++it) - rc[it->second.first] = it->second.second; + std::vector rc(counts_.size(), 0); + typename Counts::const_iterator it; - return rc; - } + for (it = counts_.begin(); it != counts_.end(); ++it) + rc[it->second.first] = it->second.second; - //-------------------------------------------------------------------------------- - /** - * Stores r-th row of this sparse matrix in provided iterators. - * The iterators need to point to enough storage. - * Non-mutating, O(nnzr) - * - * @param r [0 <= size_type < nrows] row index - * @param indIt [OutIter1] output iterator for indices - * @param nzIt [OutIter2] output iterator for non-zeros - * - * @b Exceptions: - * @li r < 0 || r >= nrows (check) - */ - template - inline void getRowSparse(const size_type& r, OutIter1 indIt) const - { - { // Pre-conditions - NTA_ASSERT(r >= 0 && r < nRows()) - << "SparseMatrix01::getRowSparse(): " - << "Invalid row index: " << r - << " - Should be >= 0 and < " << nRows(); - } - - size_type *ind = ind_[r], *ind_end = ind + nzr_[r]; + return rc; + } - // Won't do anything to the iterators - // if row r has only zeros - while (ind != ind_end) - *indIt++ = *ind++; + //-------------------------------------------------------------------------------- + /** + * Stores r-th row of this sparse matrix in provided iterators. + * The iterators need to point to enough storage. + * Non-mutating, O(nnzr) + * + * @param r [0 <= size_type < nrows] row index + * @param indIt [OutIter1] output iterator for indices + * @param nzIt [OutIter2] output iterator for non-zeros + * + * @b Exceptions: + * @li r < 0 || r >= nrows (check) + */ + template + inline void getRowSparse(const size_type &r, OutIter1 indIt) const { + { // Pre-conditions + NTA_ASSERT(r >= 0 && r < nRows()) + << "SparseMatrix01::getRowSparse(): " + << "Invalid row index: " << r << " - Should be >= 0 and < " + << nRows(); } - //-------------------------------------------------------------------------------- - /** - */ - template - inline void getRow(const size_type& r, OutIter x_begin) const - { - { // Pre-conditions - NTA_ASSERT(r >= 0 && r < nRows()) + size_type *ind = ind_[r], *ind_end = ind + nzr_[r]; + + // Won't do anything to the iterators + // if row r has only zeros + while (ind != ind_end) + *indIt++ = *ind++; + } + + //-------------------------------------------------------------------------------- + /** + */ + template + inline void getRow(const size_type &r, OutIter x_begin) const { + { // Pre-conditions + NTA_ASSERT(r >= 0 && r < nRows()) << "SparseMatrix01::getRow(): " - << "Invalid row index: " << r - << " - Should be >= 0 and < " << nRows(); - } + << "Invalid row index: " << r << " - Should be >= 0 and < " + << nRows(); + } - OutIter it = x_begin, it_end = x_begin + nCols(); - size_type *ind = ind_[r], *ind_end = ind + nzr_[r]; + OutIter it = x_begin, it_end = x_begin + nCols(); + size_type *ind = ind_[r], *ind_end = ind + nzr_[r]; - while (it != it_end) - *it++ = 0; + while (it != it_end) + *it++ = 0; - while (ind != ind_end) - *(x_begin + *ind++) = 1; - } + while (ind != ind_end) + *(x_begin + *ind++) = 1; + } - 
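//--------------------------------------------------------------------------------
// Editorial sketch (not part of the original patch): getRowSparse() returns
// only the indices of the 1s, while getRow() expands them into a dense 0/1
// vector. The hypothetical toDense() helper below spells out that equivalence.

#include <cstddef>
#include <vector>

// Expands a sorted list of non-zero column indices into a dense 0/1 row of
// width ncols, mirroring what getRow() writes through its output iterator.
inline std::vector<unsigned> toDense(const std::vector<std::size_t> &nzIndices,
                                     std::size_t ncols) {
  std::vector<unsigned> row(ncols, 0); // all zeros, like getRow()'s first loop
  for (std::size_t j : nzIndices)
    row[j] = 1;                        // like "*(x_begin + *ind++) = 1"
  return row;
}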
//-------------------------------------------------------------------------------- - /** - * Computes the square of the distance between vector x - * and each row of this SparseMatrix01. Puts the result - * in vector y. - * - * Non-mutating, O(nnz) - * - * @param x [InIter] vector to compute the distance from - * @param y [OutIter] vector of the squared distances to x - * - * @b Exceptions: - * @li None - */ - template - inline void vecDistSquared(InIter x, OutIter y) const - { - size_type i, j, k, nnzr, *ind, nrows = nRows(), ncols = nCols(); - value_type val1, val; - value_type *sq_x = nzb_, Ssq_x = 0; + //-------------------------------------------------------------------------------- + /** + * Computes the square of the distance between vector x + * and each row of this SparseMatrix01. Puts the result + * in vector y. + * + * Non-mutating, O(nnz) + * + * @param x [InIter] vector to compute the distance from + * @param y [OutIter] vector of the squared distances to x + * + * @b Exceptions: + * @li None + */ + template + inline void vecDistSquared(InIter x, OutIter y) const { + size_type i, j, k, nnzr, *ind, nrows = nRows(), ncols = nCols(); + value_type val1, val; + value_type *sq_x = nzb_, Ssq_x = 0; - for (j = 0; j < ncols; ++j) - Ssq_x += sq_x[j] = x[j] * x[j]; + for (j = 0; j < ncols; ++j) + Ssq_x += sq_x[j] = x[j] * x[j]; - for (i = 0; i < nrows; ++i) { + for (i = 0; i < nrows; ++i) { - val = Ssq_x; - nnzr = nzr_[i]; - ind = ind_[i]; - - for (k = 0; k < nnzr; ++k) { - j = ind[k]; - val1 = value_type(1.0 - x[j]); - val += val1 * val1 - sq_x[j]; - } + val = Ssq_x; + nnzr = nzr_[i]; + ind = ind_[i]; - // Accuracy issues because of the subtractions, - // could return negative values - if (val <= nupic::Epsilon) - val = 0; - - { // Post-condition - NTA_ASSERT(val >= 0) - << "SparseMatrix01::vecDistSquare(): " - << "Negative value in post-condition"; - } + for (k = 0; k < nnzr; ++k) { + j = ind[k]; + val1 = value_type(1.0 - x[j]); + val += val1 * val1 - sq_x[j]; + } + + // Accuracy issues because of the subtractions, + // could return negative values + if (val <= nupic::Epsilon) + val = 0; - *y++ = val; + { // Post-condition + NTA_ASSERT(val >= 0) << "SparseMatrix01::vecDistSquare(): " + << "Negative value in post-condition"; } + + *y++ = val; } + } - //-------------------------------------------------------------------------------- - /** - * Computes the Euclidean distance between vector x - * and each row of this SparseMatrix01. Puts the result - * in vector y. - * - * Non-mutating, O(nnz) - * - * @param x [InIter] vector to compute the distance from - * @param y [OutIter] vector of the squared distances to x - * - * @b Exceptions: - * @li None - */ - template - inline void vecDist(InIter x, OutIter y) const - { - nupic::Sqrt s; + //-------------------------------------------------------------------------------- + /** + * Computes the Euclidean distance between vector x + * and each row of this SparseMatrix01. Puts the result + * in vector y. 
+ * + * Non-mutating, O(nnz) + * + * @param x [InIter] vector to compute the distance from + * @param y [OutIter] vector of the squared distances to x + * + * @b Exceptions: + * @li None + */ + template + inline void vecDist(InIter x, OutIter y) const { + nupic::Sqrt s; - vecDistSquared(x, y); - - ITERATE_ON_ALL_ROWS { - *y = s(*y); - ++y; - } + vecDistSquared(x, y); + + ITERATE_ON_ALL_ROWS { + *y = s(*y); + ++y; } + } - //-------------------------------------------------------------------------------- - /** - * Computes the distance between vector x and a given row - * of this SparseMatrix01. Returns the result as a value_type. - * - * Non-mutating, O(ncols + nnzr) !!! - * - * @param row [0 <= size_type < nrows] index of row to compute distance from - * @param x [InIter] vector of the squared distances to x - * @retval [value_type] distance from x to row of index 'row' - * - * @b Exceptions: - * @li row < 0 || row >= nrows (check) - */ - template - inline value_type rowDistSquared(const size_type& row, InIter x) const - { - { // Pre-conditions - NTA_ASSERT(row >= 0 && row < nRows()) + //-------------------------------------------------------------------------------- + /** + * Computes the distance between vector x and a given row + * of this SparseMatrix01. Returns the result as a value_type. + * + * Non-mutating, O(ncols + nnzr) !!! + * + * @param row [0 <= size_type < nrows] index of row to compute distance from + * @param x [InIter] vector of the squared distances to x + * @retval [value_type] distance from x to row of index 'row' + * + * @b Exceptions: + * @li row < 0 || row >= nrows (check) + */ + template + inline value_type rowDistSquared(const size_type &row, InIter x) const { + { // Pre-conditions + NTA_ASSERT(row >= 0 && row < nRows()) << "SparseMatrix01::rowDistSquared(): " << "Invalid row index: " << row << " - Should be >= 0 and < nrows = " << nRows(); - } + } + + size_type j, k, nnzr, *ind, ncols = nCols(); + value_type val1, val; + value_type *sq_x = nzb_, Ssq_x = 0; + + for (j = 0; j < ncols; ++j) + Ssq_x += sq_x[j] = x[j] * x[j]; + + val = Ssq_x; + nnzr = nzr_[row]; + ind = ind_[row]; + + for (k = 0; k < nnzr; ++k) { + j = ind[k]; + val1 = value_type(1.0 - x[j]); + val += val1 * val1 - sq_x[j]; + } + + // Accuracy issues because of the subtractions, + // could return negative values + if (val <= nupic::Epsilon) + val = 0; + + { // Post-condition + NTA_ASSERT(val >= 0) << "SparseMatrix01::rowDistSquared(): " + << "Negative value in post-condition"; + } + + return val; + } + + //-------------------------------------------------------------------------------- + /** + * Computes the Euclidean distance between vector x and + * each row of this SparseMatrix01. Returns the index of the row + * that minimizes the Euclidean distance and the value of this + * distance. 
+ * + * Non-mutating, O(nnz) + * + * @param x [InIter] vector to compute the distance from + * @retval [std::pair] index of the row closest + * to x, and value of the distance between x and that row + * + * @b Exceptions: + * @li None + */ + template + inline std::pair closestEuclidean(InIter x) const { + size_type i, j, k, arg_i, nnzr, *ind, nrows = nRows(), ncols = nCols(); + value_type val, val1, min_v; + value_type *sq_x = nzb_, Ssq_x = 0; - size_type j, k, nnzr, *ind, ncols = nCols(); - value_type val1, val; - value_type *sq_x = nzb_, Ssq_x = 0; + // Pre-computing the sum of the squares of x and + // modifying it with only the non-zeros of each row + // is a huge performance improvement, but shows some + // floating-point accuracy problems when working in float. + // It's ok when working in double. + for (i = 0; i < ncols; ++i) + Ssq_x += sq_x[i] = x[i] * x[i]; - for (j = 0; j < ncols; ++j) - Ssq_x += sq_x[j] = x[j] * x[j]; + arg_i = 0; + min_v = std::numeric_limits::max(); + + for (i = 0; i < nrows; ++i) { val = Ssq_x; - nnzr = nzr_[row]; - ind = ind_[row]; - + nnzr = nzr_[i]; + ind = ind_[i]; + for (k = 0; k < nnzr; ++k) { j = ind[k]; val1 = value_type(1.0 - x[j]); val += val1 * val1 - sq_x[j]; } - // Accuracy issues because of the subtractions, - // could return negative values - if (val <= nupic::Epsilon) - val = 0; - - { // Post-condition - NTA_ASSERT(val >= 0) - << "SparseMatrix01::rowDistSquared(): " - << "Negative value in post-condition"; + if (val < min_v) { + arg_i = i; + min_v = val; } - - return val; } - //-------------------------------------------------------------------------------- - /** - * Computes the Euclidean distance between vector x and - * each row of this SparseMatrix01. Returns the index of the row - * that minimizes the Euclidean distance and the value of this - * distance. - * - * Non-mutating, O(nnz) - * - * @param x [InIter] vector to compute the distance from - * @retval [std::pair] index of the row closest - * to x, and value of the distance between x and that row - * - * @b Exceptions: - * @li None - */ - template - inline std::pair closestEuclidean(InIter x) const - { - size_type i, j, k, arg_i, nnzr, *ind, nrows = nRows(), ncols = nCols(); - value_type val, val1, min_v; - value_type *sq_x = nzb_, Ssq_x = 0; + // Accuracy issues because of the subtractions, + // could return negative values + if (min_v <= nupic::Epsilon) + min_v = 0; - // Pre-computing the sum of the squares of x and - // modifying it with only the non-zeros of each row - // is a huge performance improvement, but shows some - // floating-point accuracy problems when working in float. - // It's ok when working in double. 
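//--------------------------------------------------------------------------------
// Editorial sketch (not part of the original patch): vecDistSquared(),
// rowDistSquared() and closestEuclidean() all rely on the same identity for a
// 0/1 row r whose non-zero columns form the set S:
//
//   ||x - r||^2  =  sum_j x_j^2  +  sum_{j in S} ( (1 - x_j)^2 - x_j^2 )
//
// i.e. the full sum of squares is computed once and then corrected only at
// the non-zero columns. The hypothetical distSquared01() helper spells out
// that computation for a single row.

#include <cstddef>
#include <vector>

inline double distSquared01(const std::vector<double> &x,
                            const std::vector<std::size_t> &nzIndices) {
  double ssq = 0;
  for (double v : x)
    ssq += v * v;               // sum_j x_j^2, computed once (Ssq_x)
  for (std::size_t j : nzIndices) {
    const double d = 1.0 - x[j];
    ssq += d * d - x[j] * x[j]; // correction at each non-zero column
  }
  return ssq < 0 ? 0 : ssq;     // guard against negative round-off; the
                                // original clamps at nupic::Epsilon instead
}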
- for (i = 0; i < ncols; ++i) - Ssq_x += sq_x[i] = x[i] * x[i]; + { // Post-condition + NTA_ASSERT(min_v >= 0) << "SparseMatrix01::closestEuclidean(): " + << "Negative value in post-condition"; + } - arg_i = 0; - min_v = std::numeric_limits::max(); + nupic::Sqrt s; + return std::make_pair(arg_i, s(min_v)); + } - for (i = 0; i < nrows; ++i) { + //-------------------------------------------------------------------------------- + template + inline std::pair closest01(InIter x) const { + size_type i, j, k, arg_i, nnzr, *ind, nrows = nRows(), ncols = nCols(); + value_type val, min_v; + value_type *sq_x = nzb_, Ssq_x = 0; - val = Ssq_x; - nnzr = nzr_[i]; - ind = ind_[i]; + // Pre-computing the sum of the squares of x and + // modifying it with only the non-zeros of each row + // is a huge performance improvement, but shows some + // floating-point accuracy problems when working in float. + // It's ok when working in double. + for (i = 0; i < ncols; ++i) + Ssq_x += sq_x[i] = (x[i] > 0); - for (k = 0; k < nnzr; ++k) { - j = ind[k]; - val1 = value_type(1.0 - x[j]); - val += val1 * val1 - sq_x[j]; - } + arg_i = 0; + min_v = std::numeric_limits::max(); - if (val < min_v) { - arg_i = i; - min_v = val; - } - } + for (i = 0; i < nrows; ++i) { - // Accuracy issues because of the subtractions, - // could return negative values - if (min_v <= nupic::Epsilon) - min_v = 0; + val = Ssq_x; + nnzr = nzr_[i]; + ind = ind_[i]; - { // Post-condition - NTA_ASSERT(min_v >= 0) - << "SparseMatrix01::closestEuclidean(): " - << "Negative value in post-condition"; + for (k = 0; k < nnzr; ++k) { + j = ind[k]; + val += ((x[j] > 0) ? 0 : 1) - sq_x[j]; } - nupic::Sqrt s; - return std::make_pair(arg_i, s(min_v)); + if (val < min_v) { + arg_i = i; + min_v = val; + } } - //-------------------------------------------------------------------------------- - template - inline std::pair closest01(InIter x) const - { - size_type i, j, k, arg_i, nnzr, *ind, nrows = nRows(), ncols = nCols(); - value_type val, min_v; - value_type *sq_x = nzb_, Ssq_x = 0; - - // Pre-computing the sum of the squares of x and - // modifying it with only the non-zeros of each row - // is a huge performance improvement, but shows some - // floating-point accuracy problems when working in float. - // It's ok when working in double. - for (i = 0; i < ncols; ++i) - Ssq_x += sq_x[i] = (x[i] > 0); - - arg_i = 0; - min_v = std::numeric_limits::max(); - - for (i = 0; i < nrows; ++i) { + { // Post-condition + NTA_ASSERT(min_v >= 0) << "SparseMatrix01::closest01(): " + << "Negative value in post-condition"; + } - val = Ssq_x; - nnzr = nzr_[i]; - ind = ind_[i]; + nupic::Sqrt s; + return std::make_pair(arg_i, s(min_v)); + } - for (k = 0; k < nnzr; ++k) { - j = ind[k]; - val += ((x[j] > 0) ? 0 : 1) - sq_x[j]; - } + //-------------------------------------------------------------------------------- + /** + * Computes the "closest-dot" distance between vector x + * and each row in this SparseMatrix01. Returns the index of + * the row that maximizes the dot product as well as the + * value of this dot-product. + * + * Non-mutating, O(nnz) + * + * @param x [InIter] vector to compute the distance from + * @retval [std::pair] index of the row closest + * to x, and value of the distance between x and that row + * + * @b Exceptions: + * @li None + */ + template + inline std::pair closestDot(InIter x) const { + size_type i, k, arg_i, nnzr, *ind, end, nrows = nRows(); + value_type val, max_v; - if (val < min_v) { - arg_i = i; - min_v = val; - } - } + // cache prod of the nz? 
+ // compute prod of all the x only once? - { // Post-condition - NTA_ASSERT(min_v >= 0) - << "SparseMatrix01::closest01(): " - << "Negative value in post-condition"; - } + arg_i = 0; + max_v = -std::numeric_limits::max(); - nupic::Sqrt s; - return std::make_pair(arg_i, s(min_v)); - } + for (i = 0; i < nrows; ++i) { - //-------------------------------------------------------------------------------- - /** - * Computes the "closest-dot" distance between vector x - * and each row in this SparseMatrix01. Returns the index of - * the row that maximizes the dot product as well as the - * value of this dot-product. - * - * Non-mutating, O(nnz) - * - * @param x [InIter] vector to compute the distance from - * @retval [std::pair] index of the row closest - * to x, and value of the distance between x and that row - * - * @b Exceptions: - * @li None - */ - template - inline std::pair closestDot(InIter x) const - { - size_type i, k, arg_i, nnzr, *ind, end, nrows = nRows(); - value_type val, max_v; - - // cache prod of the nz? - // compute prod of all the x only once? - - arg_i = 0; - max_v = - std::numeric_limits::max(); + val = 0; + nnzr = nzr_[i]; + ind = ind_[i]; + end = 4 * (nnzr / 4); - for (i = 0; i < nrows; ++i) { + for (k = 0; k < end; k += 4) + val += x[ind[k]] + x[ind[k + 1]] + x[ind[k + 2]] + x[ind[k + 3]]; - val = 0; - nnzr = nzr_[i]; - ind = ind_[i]; - end = 4 * (nnzr / 4); + for (k = end; k < nnzr; ++k) + val += x[ind[k]]; - for (k = 0; k < end; k += 4) - val += x[ind[k]] + x[ind[k+1]] + x[ind[k+2]] + x[ind[k+3]]; + if (val > max_v) { + arg_i = i; + max_v = val; + } + } - for (k = end; k < nnzr; ++k) - val += x[ind[k]]; + return std::make_pair(arg_i, max_v); + } - if (val > max_v) { - arg_i = i; - max_v = val; - } + //-------------------------------------------------------------------------------- + /** + * Computes the product of vector x by this SparseMatrix01 + * on the right side, and puts the result in vector y. + * + * Non-mutating, O(nnz) + * + * @param x [InIter] input vector + * @param y [OutIter] result of the multiplication + * + * @b Exceptions: + * @li None + */ + template + inline void rightVecProd(InIter x, OutIter y) const { + size_type i, nnzr, *ind, *end1, *end2, nrows = nRows(); + value_type val, a, b; + + // memcpy(nzb_, &*x, nCols()*sizeof(value_type)); + + for (i = 0; i < nrows; ++i, ++y) { + + val = 0; + nnzr = nzr_[i]; + ind = ind_[i]; + end1 = ind + 4 * (nnzr / 4); + end2 = ind + nnzr; + + while (ind != end1) { + a = x[*ind++]; + b = x[*ind++]; + val += a + b; + a = x[*ind++]; + b = x[*ind++]; + val += a + b; } - return std::make_pair(arg_i, max_v); + while (ind != end2) + val += x[*ind++]; + + *y = val; } + } - //-------------------------------------------------------------------------------- - /** - * Computes the product of vector x by this SparseMatrix01 - * on the right side, and puts the result in vector y. - * - * Non-mutating, O(nnz) - * - * @param x [InIter] input vector - * @param y [OutIter] result of the multiplication - * - * @b Exceptions: - * @li None - */ - template - inline void rightVecProd(InIter x, OutIter y) const - { - size_type i, nnzr, *ind, *end1, *end2, nrows = nRows(); - value_type val, a, b; + //-------------------------------------------------------------------------------- + /** + * Computes the max prod between each element of vector x + * and each non-zero of each row of this SparseMatrix01. + * Puts the max for each row in vector y. 
+ * + * Non-mutating, O(nnz) + * + * @param x [InIter] input vector + * @param y [OutIter] result + * + * @b Exceptions: + * @li None + */ + template + inline void vecMaxProd(InIter x, OutIter y) const { + size_type i, *ind, *ind_end, nrows = nRows(); + value_type max_v, p; - //memcpy(nzb_, &*x, nCols()*sizeof(value_type)); + // memcpy(nzb_, &*x, nCols()*sizeof(value_type)); - for (i = 0; i < nrows; ++i, ++y) { + for (i = 0; i < nrows; ++i, ++y) { - val = 0; - nnzr = nzr_[i]; - ind = ind_[i]; - end1 = ind + 4*(nnzr/4); - end2 = ind + nnzr; - - while (ind != end1) { - a = x[*ind++]; - b = x[*ind++]; - val += a + b; - a = x[*ind++]; - b = x[*ind++]; - val += a + b; - } + ind = ind_[i]; + ind_end = ind + nzr_[i]; + max_v = 0; // nnzr == 0 ? 0 : x[ind[0]]; - while (ind != end2) - val += x[*ind++]; - - *y = val; + while (ind != ind_end) { + p = x[*ind++]; + if (p > max_v) + max_v = p; } + + *y = max_v; } + } - //-------------------------------------------------------------------------------- - /** - * Computes the max prod between each element of vector x - * and each non-zero of each row of this SparseMatrix01. - * Puts the max for each row in vector y. - * - * Non-mutating, O(nnz) - * - * @param x [InIter] input vector - * @param y [OutIter] result - * - * @b Exceptions: - * @li None - */ - template - inline void vecMaxProd(InIter x, OutIter y) const - { - size_type i, *ind, *ind_end, nrows = nRows(); - value_type max_v, p; + //-------------------------------------------------------------------------------- + /** + * Computes the index of the max value of x, where that + * value is at the index of a non-zero. Does that for each + * row, stores the resulting index in y for each row. + * + * Non-mutating, O(nnz) + * + * @param x [InIter] input vector + * @param y [OutIter] output vector + * + * @b Exceptions: + * @li None + */ + template + inline void rowMax(InIter x, OutIter y) const { + size_type i, j, nnzr, *ind, *end, arg_j, nrows = nRows(); + value_type val, max_val; - //memcpy(nzb_, &*x, nCols()*sizeof(value_type)); - - for (i = 0; i < nrows; ++i, ++y) { + for (arg_j = 0, i = 0; i < nrows; ++i, ++y) { - ind = ind_[i]; - ind_end = ind + nzr_[i]; - max_v = 0; //nnzr == 0 ? 0 : x[ind[0]]; + nnzr = nzr_[i]; + ind = ind_[i]; + end = ind + nnzr; - while (ind != ind_end) { - p = x[*ind++]; - if (p > max_v) - max_v = p; + max_val = 0; //- std::numeric_limits::max(); + + while (ind != end) { + j = *ind; + val = x[j]; + if (val > max_val) { + arg_j = j; + max_val = val; } - - *y = max_v; + ++ind; } + *y = value_type(arg_j); } + } - //-------------------------------------------------------------------------------- - /** - * Computes the index of the max value of x, where that - * value is at the index of a non-zero. Does that for each - * row, stores the resulting index in y for each row. 
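//--------------------------------------------------------------------------------
// Editorial sketch (not part of the original patch): vecMaxProd() and rowMax()
// both scan only a row's non-zero columns; one keeps the maximum of x there,
// the other the column index where that maximum occurs. The hypothetical
// helper below computes both for a single row on plain vectors; as in the
// originals, the maximum falls back to 0 when every value at a non-zero
// column is <= 0 (the argmax fallback here is simply the first non-zero
// column, a choice of this sketch).

#include <cstddef>
#include <utility>
#include <vector>

// Returns {max value of x over nzIndices, column index achieving it}.
inline std::pair<double, std::size_t>
rowMaxOverNonZeros(const std::vector<double> &x,
                   const std::vector<std::size_t> &nzIndices) {
  double maxVal = 0;
  std::size_t argJ = nzIndices.empty() ? 0 : nzIndices.front();
  for (std::size_t j : nzIndices) {
    if (x[j] > maxVal) {
      maxVal = x[j];
      argJ = j;
    }
  }
  return std::make_pair(maxVal, argJ);
}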
- * - * Non-mutating, O(nnz) - * - * @param x [InIter] input vector - * @param y [OutIter] output vector - * - * @b Exceptions: - * @li None - */ - template - inline void rowMax(InIter x, OutIter y) const - { - size_type i, j, nnzr, *ind, *end, arg_j, nrows = nRows(); - value_type val, max_val; - - for (arg_j = 0, i = 0; i < nrows; ++i, ++y) { - - nnzr = nzr_[i]; - ind = ind_[i]; - end = ind + nnzr; - - max_val = 0; //- std::numeric_limits::max(); - - while (ind != end) { - j = *ind; - val = x[j]; - if (val > max_val) { - arg_j = j; - max_val = val; - } - ++ind; - } - *y = value_type(arg_j); + //-------------------------------------------------------------------------------- + /** + * Computes the product of the values in x corresponding to the non-zeros + * for each row. + * Stores the result in y. + * + * Non-mutating, O(nnz) + * + * @param x [InIter] input vector + * @param y [OutIter] output vector + * + * @b Exceptions: + * @li None + */ + template + inline void rowProd(InIter x, OutIter y) const { + size_type i, *end1, *end2, nnzr, *ind, nrows = nRows(); + value_type val, a, b; + + // Pre-fetch x into nzb_ to minimize cache misses + // Also, copying x into nzb_ on the stack frees a register + // that can be used below for a and b + // memcpy(nzb_, &*x, nCols()*sizeof(value_type)); + + for (i = 0; i < nrows; ++i, ++y) { + + nnzr = nzr_[i]; + ind = ind_[i]; + + end1 = ind + 4 * (nnzr / 4); + end2 = ind + nnzr; + + val = 1.0; + + // Using a and b to use more registers + while (ind != end1) { // != faster than ind < end1 + a = x[*ind++]; + b = x[*ind++]; + val *= a * b; + a = x[*ind++]; + b = x[*ind++]; + val *= a * b; } - } - //-------------------------------------------------------------------------------- - /** - * Computes the product of the values in x corresponding to the non-zeros - * for each row. - * Stores the result in y. 
- * - * Non-mutating, O(nnz) - * - * @param x [InIter] input vector - * @param y [OutIter] output vector - * - * @b Exceptions: - * @li None - */ - template - inline void rowProd(InIter x, OutIter y) const - { - size_type i, *end1, *end2, nnzr, *ind, nrows = nRows(); - value_type val, a, b; - - // Pre-fetch x into nzb_ to minimize cache misses - // Also, copying x into nzb_ on the stack frees a register - // that can be used below for a and b - //memcpy(nzb_, &*x, nCols()*sizeof(value_type)); - - for (i = 0; i < nrows; ++i, ++y) { - - nnzr = nzr_[i]; - ind = ind_[i]; - - end1 = ind + 4*(nnzr/4); - end2 = ind + nnzr; - - val = 1.0; - - // Using a and b to use more registers - while (ind != end1) { // != faster than ind < end1 - a = x[*ind++]; - b = x[*ind++]; - val *= a * b; - a = x[*ind++]; - b = x[*ind++]; - val *= a * b; - } + while (ind != end2) + val *= x[*ind++]; - while (ind != end2) - val *= x[*ind++]; + *y = val; + } - *y = val; - } + /* + for (k = 0; k < end1; k += 4) + val *= x[ind[k]] * x[ind[k+1]] * x[ind[k+2]] * x[ind[k+3]]; - /* - for (k = 0; k < end1; k += 4) - val *= x[ind[k]] * x[ind[k+1]] * x[ind[k+2]] * x[ind[k+3]]; - - for (k = end1; k < nnzr; ++k) - val *= x[ind[k]]; - */ + for (k = end1; k < nnzr; ++k) + val *= x[ind[k]]; + */ - /* Slow - while (ind < end1) { - val *= x[*ind]; ++ind; - val *= x[*ind]; ++ind; - val *= x[*ind]; ++ind; - val *= x[*ind]; ++ind; - } - */ - - /* 30% faster, but wrong result, because of order - while (ind < end1) - val *= x[*ind++] * x[*ind++] * x[*ind++] * x[*ind++]; - */ - } + /* Slow + while (ind < end1) { + val *= x[*ind]; ++ind; + val *= x[*ind]; ++ind; + val *= x[*ind]; ++ind; + val *= x[*ind]; ++ind; + } + */ - //-------------------------------------------------------------------------------- - /** - * This is used in prediction to clamp the results to lb in case of underflow. - */ - template - inline void rowProd(InIter x, OutIter y, const value_type& lb) const - { - size_type i, k, nnzr, end, *ind, nrows = nRows(); - double val; + /* 30% faster, but wrong result, because of order + while (ind < end1) + val *= x[*ind++] * x[*ind++] * x[*ind++] * x[*ind++]; + */ + } - for (i = 0; i < nrows; ++i) { - - nnzr = nzr_[i]; - ind = ind_[i]; - val = 1; - end = 4*(nnzr / 4); - - k = 0; - while (k < end && val > lb) { - val *= x[ind[k]] * x[ind[k+1]] * x[ind[k+2]] * x[ind[k+3]]; - k += 4; - } - - if (val > lb) - for (k = end; k < nnzr; ++k) - val *= x[ind[k]]; - - if (val > lb) - *y++ = (value_type) val; - else - *y++ = lb; + //-------------------------------------------------------------------------------- + /** + * This is used in prediction to clamp the results to lb in case of underflow. 
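//--------------------------------------------------------------------------------
// Editorial sketch (not part of the original patch): the clamped rowProd()
// overload below multiplies the x values at a row's non-zero columns and
// stops as soon as the running product reaches the lower bound lb, assuming
// the values are at most 1 (e.g. probabilities) so the product can only
// shrink. The hypothetical clampedProduct() helper shows the same
// early-exit-and-clamp idea for a single row.

#include <cstddef>
#include <vector>

inline double clampedProduct(const std::vector<double> &x,
                             const std::vector<std::size_t> &nzIndices,
                             double lb) {
  double val = 1.0;
  for (std::size_t j : nzIndices) {
    val *= x[j];
    if (val <= lb) // underflow floor reached: no point in continuing
      return lb;   // clamp, like "*y++ = lb" in the overload below
  }
  return val;
}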
+ */ + template + inline void rowProd(InIter x, OutIter y, const value_type &lb) const { + size_type i, k, nnzr, end, *ind, nrows = nRows(); + double val; + + for (i = 0; i < nrows; ++i) { + + nnzr = nzr_[i]; + ind = ind_[i]; + val = 1; + end = 4 * (nnzr / 4); + + k = 0; + while (k < end && val > lb) { + val *= x[ind[k]] * x[ind[k + 1]] * x[ind[k + 2]] * x[ind[k + 3]]; + k += 4; } + + if (val > lb) + for (k = end; k < nnzr; ++k) + val *= x[ind[k]]; + + if (val > lb) + *y++ = (value_type)val; + else + *y++ = lb; } + } - //-------------------------------------------------------------------------------- - inline void print(std::ostream& outStream) const - { - size_type i, j, k; - for (i = 0; i < nRows(); ++i) { - if (nzr_[i] > 0) { - for (j = 0, k = 0; j < nCols(); ++j) { - if (k < nzr_[i] && ind_[i][k] == j) { - outStream << "1 "; - ++k; - } else { - outStream << "0 "; - } + //-------------------------------------------------------------------------------- + inline void print(std::ostream &outStream) const { + size_type i, j, k; + for (i = 0; i < nRows(); ++i) { + if (nzr_[i] > 0) { + for (j = 0, k = 0; j < nCols(); ++j) { + if (k < nzr_[i] && ind_[i][k] == j) { + outStream << "1 "; + ++k; + } else { + outStream << "0 "; } } - outStream << std::endl; } + outStream << std::endl; } - }; - - //-------------------------------------------------------------------------------- - template - inline std::ostream& operator<<(std::ostream& outStream, - const SparseMatrix01& A) - { - A.print(outStream); - return outStream; } +}; - //-------------------------------------------------------------------------------- +//-------------------------------------------------------------------------------- +template +inline std::ostream &operator<<(std::ostream &outStream, + const SparseMatrix01 &A) { + A.print(outStream); + return outStream; +} + +//-------------------------------------------------------------------------------- } // end namespace nupic #endif // NTA_SPARSE_MATRIX01_HPP - diff --git a/src/nupic/math/SparseMatrixAlgorithms.cpp b/src/nupic/math/SparseMatrixAlgorithms.cpp index 6f03b5ea15..3cf79d71fe 100644 --- a/src/nupic/math/SparseMatrixAlgorithms.cpp +++ b/src/nupic/math/SparseMatrixAlgorithms.cpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * External algorithms for operating on a sparse matrix. */ @@ -34,9 +34,8 @@ namespace nupic { - // The two tables used when approximating logSum and logDiff. - std::vector LogSumApprox::table; - std::vector LogDiffApprox::table; +// The two tables used when approximating logSum and logDiff. +std::vector LogSumApprox::table; +std::vector LogDiffApprox::table; } // end namespace nupic - diff --git a/src/nupic/math/SparseMatrixAlgorithms.hpp b/src/nupic/math/SparseMatrixAlgorithms.hpp index 6fbdd4a6e2..8ceee0b81f 100644 --- a/src/nupic/math/SparseMatrixAlgorithms.hpp +++ b/src/nupic/math/SparseMatrixAlgorithms.hpp @@ -20,2227 +20,2247 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * External algorithms that operate on SparseMatrix. */ #ifndef NTA_SM_ALGORITHMS_HPP #define NTA_SM_ALGORITHMS_HPP -#include #include +#include //-------------------------------------------------------------------------------- namespace nupic { +/** + * A collection of algorithms that operate on SparseMatrix. They are put here + * instead of directly in the SparseMatrix because they are not as general + * as the SparseMatrix methods. 
They are usually tailored for a specific, + * sometimes experimental, algorithm. This struct is a friend of SparseMatrix, + * so that it can access iterators on the indices and values of the non-zeros + * that are not made public in SparseMatrix. In the following methods, template + * parameter "SM" stands for a SparseMatrix type. + */ +struct SparseMatrixAlgorithms { + //-------------------------------------------------------------------------------- /** - * A collection of algorithms that operate on SparseMatrix. They are put here - * instead of directly in the SparseMatrix because they are not as general - * as the SparseMatrix methods. They are usually tailored for a specific, sometimes - * experimental, algorithm. This struct is a friend of SparseMatrix, so that it - * can access iterators on the indices and values of the non-zeros that are not - * made public in SparseMatrix. In the following methods, template parameter "SM" - * stands for a SparseMatrix type. + * Computes the entropy rate of a sparse matrix, along the rows or the + columns. + * This is defined as: + * sum(-nz[i,j] * log2(nz[i,j]) * sum_of_row[i], for all i,j), i.e. + * the usual definition of entropy, but weighted by the probability of the + rows + * or column, i.e. the probability of the conditional distributions. + * + * A copy of the matrix passed in is performed, and the copy is normalized to + * give it the meaning of a joint distribution. This is pretty slow. + * + * TODO: + * + * I don't think the matrix needs to be normalized (which means rowwise + normalization). You've already computed the "row sums", which are the per-row + normalization factor, and you can use this in the entropy calculation. i.e. + if n is the norm of a single row and x[i] is the original value and xn[i] = + x[i]/n is the normalized value then the partial contribution for that row is: + - n * sum xn[i] * ln xn[i] + = - n * sum x[i]/n * ln x[i]/n + = - sum x[i] *( ln x[i] - ln[n]) + = sum x[i] ln[n] - sum x[i] ln x[i] + = n ln [n] - sum x[i] ln x[i] + + In other words, you compute the entropy based on the non-normalized matrix + and then add n ln[n] (There may be an error in this calculation, but in any + case, I'm pretty sure you don't actually need to normalize the matrix) */ - struct SparseMatrixAlgorithms - { - //-------------------------------------------------------------------------------- - /** - * Computes the entropy rate of a sparse matrix, along the rows or the columns. - * This is defined as: - * sum(-nz[i,j] * log2(nz[i,j]) * sum_of_row[i], for all i,j), i.e. - * the usual definition of entropy, but weighted by the probability of the rows - * or column, i.e. the probability of the conditional distributions. - * - * A copy of the matrix passed in is performed, and the copy is normalized to - * give it the meaning of a joint distribution. This is pretty slow. - * - * TODO: - * - * I don't think the matrix needs to be normalized (which means rowwise normalization). - You've already computed the "row sums", which are the per-row normalization factor, - and you can use this in the entropy calculation. i.e. 
if n is the norm of a single row - and x[i] is the original value and xn[i] = x[i]/n is the normalized value then - the partial contribution for that row is: - - n * sum xn[i] * ln xn[i] - = - n * sum x[i]/n * ln x[i]/n - = - sum x[i] *( ln x[i] - ln[n]) - = sum x[i] ln[n] - sum x[i] ln x[i] - = n ln [n] - sum x[i] ln x[i] - - In other words, you compute the entropy based on the non-normalized matrix and then - add n ln[n] (There may be an error in this calculation, but in any case, I'm pretty - sure you don't actually need to normalize the matrix) - */ - template - static typename SM::value_type entropy_rate(const SM& sm) - { - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - - SM m(sm); - - std::vector s(m.nRows()); - - m.rowSums(s); - nupic::normalize(s); - m.normalizeRows(); - - value_type e = 0; - - Log2 log2_f; - - for (size_type i = 0; i != m.nRows(); ++i) { - value_type ee = 0; - const value_type* nz = m.nz_begin_(i); - const value_type* nz_end = m.nz_end_(i); - for (; nz != nz_end; ++nz) - ee += *nz * log2_f(*nz); - e -= s[i] * ee; - } - - return e; + template + static typename SM::value_type entropy_rate(const SM &sm) { + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + + SM m(sm); + + std::vector s(m.nRows()); + + m.rowSums(s); + nupic::normalize(s); + m.normalizeRows(); + + value_type e = 0; + + Log2 log2_f; + + for (size_type i = 0; i != m.nRows(); ++i) { + value_type ee = 0; + const value_type *nz = m.nz_begin_(i); + const value_type *nz_end = m.nz_end_(i); + for (; nz != nz_end; ++nz) + ee += *nz * log2_f(*nz); + e -= s[i] * ee; } - //-------------------------------------------------------------------------------- - /** - * Computes an entropy on a "smoothed" SM, for each row and for each column. - * Smoothes by simply adding 1 to each count as the entropy is calculated. - */ - template - static void matrix_entropy(const SM& sm, OutputIter row_out, OutputIter row_out_end, - OutputIter col_out, OutputIter col_out_end, - typename SM::value_type s = 1.0) - { - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; + return e; + } - { // Pre-conditions - NTA_CHECK((size_type)(row_out_end - row_out) == sm.nRows()) - << "entropy_smooth: Invalid size for output vector: " + //-------------------------------------------------------------------------------- + /** + * Computes an entropy on a "smoothed" SM, for each row and for each column. + * Smoothes by simply adding 1 to each count as the entropy is calculated. 
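+   *
+   * Illustrative call sketch (not part of this reformatting change; sm is any
+   * SparseMatrix type SM accepted by this struct, already populated):
+   *
+   *   std::vector<typename SM::value_type> re(sm.nRows()), ce(sm.nCols());
+   *   SparseMatrixAlgorithms::matrix_entropy(sm, re.begin(), re.end(),
+   *                                          ce.begin(), ce.end(), 1.0);
+   *   // re[i] / ce[j] hold the smoothed entropy of row i / column j.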
+ */ + template + static void matrix_entropy(const SM &sm, OutputIter row_out, + OutputIter row_out_end, OutputIter col_out, + OutputIter col_out_end, + typename SM::value_type s = 1.0) { + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + + { // Pre-conditions + NTA_CHECK((size_type)(row_out_end - row_out) == sm.nRows()) + << "entropy_smooth: Invalid size for output vector: " << (size_type)(row_out_end - row_out) << " - Should be number of rows: " << sm.nRows(); - NTA_CHECK((size_type)(col_out_end - col_out) == sm.nCols()) + NTA_CHECK((size_type)(col_out_end - col_out) == sm.nCols()) << "entropy_smooth: Invalid size for output vector: " << (size_type)(col_out_end - col_out) << " - Should be number of columns: " << sm.nCols(); - } // End pre-conditions - - size_type m = sm.nRows(), n = sm.nCols(); + } // End pre-conditions - std::vector row_sums(m, (value_type) n * s); - std::fill(sm.indb_, sm.indb_ + n, (size_type) 0); - std::fill(sm.nzb_, sm.nzb_ + n, (value_type) m * s); - - for (size_type row = 0; row != m; ++row) { - size_type *ind = sm.ind_[row], *ind_end = ind + sm.nnzr_[row]; - value_type *nz = sm.nz_[row]; - for (; ind != ind_end; ++ind, ++nz) { - row_sums[row] += *nz; - sm.nzb_[*ind] += *nz; - sm.indb_[*ind] += (size_type) 1; - } - } + size_type m = sm.nRows(), n = sm.nCols(); - Log2 log2_f; - - for (size_type c = 0; c != n; ++c) { - value_type v = s / sm.nzb_[c]; - *(col_out + c) = - ((value_type)(m - sm.indb_[c]) * v * log2_f(v)); - } + std::vector row_sums(m, (value_type)n * s); + std::fill(sm.indb_, sm.indb_ + n, (size_type)0); + std::fill(sm.nzb_, sm.nzb_ + n, (value_type)m * s); - for (size_type row = 0; row != m; ++row, ++row_out) { - size_type *ind = sm.ind_[row], *ind_end = ind + sm.nnzr_[row]; - value_type *nz = sm.nz_[row]; - value_type v = s / row_sums[row]; - *row_out = - ((value_type)(n - sm.nnzr_[row]) * v * log2_f(v)); - for (; ind != ind_end; ++ind, ++nz) { - v = *nz + s; - value_type val_row = v / row_sums[row]; - *row_out -= val_row * log2_f(val_row); - value_type val_col = v / sm.nzb_[*ind]; - *(col_out + *ind) -= val_col * log2_f(val_col); - } + for (size_type row = 0; row != m; ++row) { + size_type *ind = sm.ind_[row], *ind_end = ind + sm.nnzr_[row]; + value_type *nz = sm.nz_[row]; + for (; ind != ind_end; ++ind, ++nz) { + row_sums[row] += *nz; + sm.nzb_[*ind] += *nz; + sm.indb_[*ind] += (size_type)1; } } - //-------------------------------------------------------------------------------- - /** - * Multiplies the 'X' matrix by the constant 'a', - * then adds 'b * X * Y' to it, in-place. 
- * for row in [0,nrows): - * for col in [0,ncols): - * X[row,col] = X[row,col] * (a + b * Y[row,col]) - * - * @param a [value_type] a coefficient - * @param b [value_type] b coefficient - * @param B [const SparseMatrix] Y matrix - * - */ - template - static void - aX_plus_bX_elementMultiply_Y(const typename SM1::value_type &a, SM1 &Xoutput, - const typename SM1::value_type &b, const SM2& Y) - { - typedef typename SM1::size_type size_type; - typedef typename SM1::value_type value_type; - - const size_type nrows = Xoutput.nRows(); - - for (size_type row = 0; row != nrows; ++row) { + Log2 log2_f; - value_type *nz_write = Xoutput.nz_begin_(row); - size_type *ind_write = Xoutput.ind_begin_(row); - const size_type *ind_x_begin = ind_write; - - const value_type *nz_x = Xoutput.nz_begin_(row); - const size_type *ind_x = ind_x_begin; - const size_type *ind_x_end = Xoutput.ind_end_(row); - - const typename SM2::value_type *nz_y = Y.nz_begin_(row); - const typename SM2::size_type *ind_y = Y.ind_begin_(row); - const typename SM2::size_type *ind_y_begin = ind_y; - const typename SM2::size_type *ind_y_end = Y.ind_end_(row); - - while (ind_x != ind_x_end) { - - const size_type column = *(ind_x++); - const value_type vx = *(nz_x++); - value_type val = vx * a; - ind_y = std::lower_bound(ind_y, ind_y_end, column); + for (size_type c = 0; c != n; ++c) { + value_type v = s / sm.nzb_[c]; + *(col_out + c) = -((value_type)(m - sm.indb_[c]) * v * log2_f(v)); + } - if (ind_y != ind_y_end && column == *ind_y) { - const value_type vy = (value_type) nz_y[ind_y - ind_y_begin]; - val += vx * vy * b; - ++ind_y; - } + for (size_type row = 0; row != m; ++row, ++row_out) { + size_type *ind = sm.ind_[row], *ind_end = ind + sm.nnzr_[row]; + value_type *nz = sm.nz_[row]; + value_type v = s / row_sums[row]; + *row_out = -((value_type)(n - sm.nnzr_[row]) * v * log2_f(v)); + for (; ind != ind_end; ++ind, ++nz) { + v = *nz + s; + value_type val_row = v / row_sums[row]; + *row_out -= val_row * log2_f(val_row); + value_type val_col = v / sm.nzb_[*ind]; + *(col_out + *ind) -= val_col * log2_f(val_col); + } + } + } - // Could save this check, but it should usually - // be predictable. - if (!Xoutput.isZero_(val)) { - *ind_write++ = column; - *nz_write++ = val; - } + //-------------------------------------------------------------------------------- + /** + * Multiplies the 'X' matrix by the constant 'a', + * then adds 'b * X * Y' to it, in-place. 
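+   *
+   * Illustrative call sketch (not part of this reformatting change; X and Y
+   * are sparse matrices of the same shape, value_type assumed float here):
+   *
+   *   SparseMatrixAlgorithms::aX_plus_bX_elementMultiply_Y(0.9f, X, 0.1f, Y);
+   *
+   * Equivalently, each entry of X is updated in place as: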
+ * for row in [0,nrows): + * for col in [0,ncols): + * X[row,col] = X[row,col] * (a + b * Y[row,col]) + * + * @param a [value_type] a coefficient + * @param b [value_type] b coefficient + * @param B [const SparseMatrix] Y matrix + * + */ + template + static void aX_plus_bX_elementMultiply_Y(const typename SM1::value_type &a, + SM1 &Xoutput, + const typename SM1::value_type &b, + const SM2 &Y) { + typedef typename SM1::size_type size_type; + typedef typename SM1::value_type value_type; + + const size_type nrows = Xoutput.nRows(); + + for (size_type row = 0; row != nrows; ++row) { + + value_type *nz_write = Xoutput.nz_begin_(row); + size_type *ind_write = Xoutput.ind_begin_(row); + const size_type *ind_x_begin = ind_write; + + const value_type *nz_x = Xoutput.nz_begin_(row); + const size_type *ind_x = ind_x_begin; + const size_type *ind_x_end = Xoutput.ind_end_(row); + + const typename SM2::value_type *nz_y = Y.nz_begin_(row); + const typename SM2::size_type *ind_y = Y.ind_begin_(row); + const typename SM2::size_type *ind_y_begin = ind_y; + const typename SM2::size_type *ind_y_end = Y.ind_end_(row); + + while (ind_x != ind_x_end) { + + const size_type column = *(ind_x++); + const value_type vx = *(nz_x++); + value_type val = vx * a; + ind_y = std::lower_bound(ind_y, ind_y_end, column); + + if (ind_y != ind_y_end && column == *ind_y) { + const value_type vy = (value_type)nz_y[ind_y - ind_y_begin]; + val += vx * vy * b; + ++ind_y; } - Xoutput.nnzr_[row] = (size_type)(ind_write - ind_x_begin); + // Could save this check, but it should usually + // be predictable. + if (!Xoutput.isZero_(val)) { + *ind_write++ = column; + *nz_write++ = val; + } } + + Xoutput.nnzr_[row] = (size_type)(ind_write - ind_x_begin); } + } - //-------------------------------------------------------------------------------- - /** - * Used to speed up sparse pooler algorithm. - * - * TODO: describe algo. - */ - template - static void - kthroot_product(const SM& sm, - typename SM::size_type ss, InIter1 x, OutIter y, - const typename SM::value_type& min_input) - { - using namespace std; + //-------------------------------------------------------------------------------- + /** + * Used to speed up sparse pooler algorithm. + * + * TODO: describe algo. + */ + template + static void kthroot_product(const SM &sm, typename SM::size_type ss, + InIter1 x, OutIter y, + const typename SM::value_type &min_input) { + using namespace std; + + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - - { - NTA_ASSERT(sm.nCols() % ss == 0) + { + NTA_ASSERT(sm.nCols() % ss == 0) << "SparseMatrix kthroot_product: " << "Invalid segment size: " << ss << "Needs to be a divisor of nCols() = " << sm.nCols(); - } - - Log log_f; - Exp exp_f; - - const size_type k = sm.nCols() / ss; - const value_type log_min_input = logf(min_input); - - OutIter y_begin = y, y_end = y_begin + sm.nCols(); - - for (size_type row = 0; row != sm.nRows(); ++row) { + } - value_type sum = (value_type) 0.0f; - size_type seg_begin = 0, seg_end = ss; - size_type *ind = sm.ind_begin_(row), *ind_end = sm.ind_end_(row); - for (; seg_begin != sm.nCols(); seg_begin += ss, seg_end += ss) { - if (ind < ind_end && seg_begin <= *ind && *ind < seg_end) { - size_type *c2 = seg_end == sm.nCols() ? 
ind_end : sm.pos_(row, seg_end); - for (; ind != c2; ++ind) { - value_type val = x[*ind]; - if (sm.isZero_(val)) - sum += log_min_input; - else - sum += log_f(val); - } - } else { - value_type max_value = - std::numeric_limits::max(); - for (size_type i = seg_begin; i != seg_end; ++i) - max_value = std::max(x[i], max_value); - sum += log_f(std::max((value_type)1.0f - max_value, min_input)); - ind = seg_end == sm.nCols() ? ind_end : sm.pos_(row, seg_end); + Log log_f; + Exp exp_f; + + const size_type k = sm.nCols() / ss; + const value_type log_min_input = logf(min_input); + + OutIter y_begin = y, y_end = y_begin + sm.nCols(); + + for (size_type row = 0; row != sm.nRows(); ++row) { + + value_type sum = (value_type)0.0f; + size_type seg_begin = 0, seg_end = ss; + size_type *ind = sm.ind_begin_(row), *ind_end = sm.ind_end_(row); + for (; seg_begin != sm.nCols(); seg_begin += ss, seg_end += ss) { + if (ind < ind_end && seg_begin <= *ind && *ind < seg_end) { + size_type *c2 = + seg_end == sm.nCols() ? ind_end : sm.pos_(row, seg_end); + for (; ind != c2; ++ind) { + value_type val = x[*ind]; + if (sm.isZero_(val)) + sum += log_min_input; + else + sum += log_f(val); } + } else { + value_type max_value = -std::numeric_limits::max(); + for (size_type i = seg_begin; i != seg_end; ++i) + max_value = std::max(x[i], max_value); + sum += log_f(std::max((value_type)1.0f - max_value, min_input)); + ind = seg_end == sm.nCols() ? ind_end : sm.pos_(row, seg_end); } - // On x86_64, there is a bug in glibc that makes expf very slow - // (more than it should be), so we continue using exp on that - // platform as a workaround. - // https://bugzilla.redhat.com/show_bug.cgi?id=521190 - // To force the compiler to use exp instead of expf, the return - // type (and not the argument type!) needs to be double. - *y = exp_f(sum / k); - ++y; } - if (positive_less_than(y_begin, y_end, min_input)) - std::fill(y_begin, y_end, (value_type) 0); + // On x86_64, there is a bug in glibc that makes expf very slow + // (more than it should be), so we continue using exp on that + // platform as a workaround. + // https://bugzilla.redhat.com/show_bug.cgi?id=521190 + // To force the compiler to use exp instead of expf, the return + // type (and not the argument type!) needs to be double. + *y = exp_f(sum / k); + ++y; } + if (positive_less_than(y_begin, y_end, min_input)) + std::fill(y_begin, y_end, (value_type)0); + } - //-------------------------------------------------------------------------------- - /** - * Computes the product of a sparse matrix and a sparse vector on the right. - * [x_begin, x_end) is the range of x that contains the non-zeros for x. - * This function skips multiplying by zeros out of [x_begin, x_end). - * This is used only in nupic/math_research/shmm.cpp and direct unit testing - * is missing. 
- * - * TODO: check if we can't remove that and replace by incrementOuterWithNZ - */ - template - static void sparseRightVecProd(const SM& sm, - typename SM::size_type x_begin, - typename SM::size_type x_end, - InputIterator1 x, - InputIterator2 y) - { - { // Pre-conditions - sm.assert_valid_col_range_(x_begin, x_end, "sparseRightVecProd: Invalid range"); - } // End pre-conditions - - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - - for (size_type row = 0; row != sm.nRows(); ++row, ++y) { - size_type nnzr = sm.nNonZerosOnRow(row); - if (nnzr == 0) { - *y = 0; - continue; - } - size_type *ind = sm.ind_begin_(row), *ind_end = sm.ind_end_(row); - size_type *p1 = std::lower_bound(ind, ind_end, x_begin); - if (p1 == ind_end) { - *y = 0; - continue; - } - size_type *p2 = std::lower_bound(p1, ind_end, x_end); - value_type *nz = sm.nz_begin_(row) + (p1 - ind); - value_type val = 0; - for (ind = p1; ind != p2; ++ind, ++nz) - val += *nz * *(x + *ind); - *y = val; + //-------------------------------------------------------------------------------- + /** + * Computes the product of a sparse matrix and a sparse vector on the right. + * [x_begin, x_end) is the range of x that contains the non-zeros for x. + * This function skips multiplying by zeros out of [x_begin, x_end). + * This is used only in nupic/math_research/shmm.cpp and direct unit testing + * is missing. + * + * TODO: check if we can't remove that and replace by incrementOuterWithNZ + */ + template + static void sparseRightVecProd(const SM &sm, typename SM::size_type x_begin, + typename SM::size_type x_end, InputIterator1 x, + InputIterator2 y) { + { // Pre-conditions + sm.assert_valid_col_range_(x_begin, x_end, + "sparseRightVecProd: Invalid range"); + } // End pre-conditions + + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + + for (size_type row = 0; row != sm.nRows(); ++row, ++y) { + size_type nnzr = sm.nNonZerosOnRow(row); + if (nnzr == 0) { + *y = 0; + continue; } + size_type *ind = sm.ind_begin_(row), *ind_end = sm.ind_end_(row); + size_type *p1 = std::lower_bound(ind, ind_end, x_begin); + if (p1 == ind_end) { + *y = 0; + continue; + } + size_type *p2 = std::lower_bound(p1, ind_end, x_end); + value_type *nz = sm.nz_begin_(row) + (p1 - ind); + value_type val = 0; + for (ind = p1; ind != p2; ++ind, ++nz) + val += *nz * *(x + *ind); + *y = val; } + } - //-------------------------------------------------------------------------------- - /** - * Wrapper for iterator-based sparseRightVecProd that takes std::vectors. - * - * TODO: can we remove? 
- */ - /* - template - static void sparseRightVecProd(const SM& sm, - typename SM::size_type x_begin, - typename SM::size_type x_end, - const std::vector& x, - std::vector& y) - { - { // Pre-conditions - NTA_ASSERT(x.size() == sm.nCols()) - << "sparseRightVecProd: Wrong size for x: " << x.size() - << " when sparse matrix has: " << sm.nCols() << " columns"; - NTA_ASSERT(y.size() == sm.nRows()) - << "sparseRightVecProd: Wrong size for y: " << y.size() - << " when sparse matrix has: " << sm.nRows() << " rows"; - sm.assert_valid_col_range_(x_begin, x_end, "sparseRightVecProd: Invalid range"); - } // End pre-conditions - - SparseMatrixAlgorithms::sparseRightVecProd(sm, x_begin, x_end, x.begin(), y.begin()); - } - */ - - //-------------------------------------------------------------------------------- - /** - * Computes a smoothed version of all rows vec max prod, that is: - * - * for row in [0,nrows): - * y[row] = max((this[row,col] + k) * x[col], for col in [0,ncols)) - * - */ - template - static void smoothVecMaxProd(const SM& sm, - typename SM::value_type k, - InputIterator x, InputIterator x_end, - OutputIterator y, OutputIterator y_end) - { - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - - { // Pre-conditions - NTA_ASSERT((size_type)(x_end - x) == sm.nCols()); - NTA_ASSERT((size_type)(y_end - y) == sm.nRows()); - } // End pre-conditions - - // Compute k * x only once, and cache result in sm.nzb_ - for (size_type j = 0; j != sm.nCols(); ++j) - sm.nzb_[j] = k * x[j]; - - for (size_type row = 0; row != sm.nRows(); ++row) { - - value_type max_v = - std::numeric_limits::max(); - size_type *ind = sm.ind_[row], *ind_end = ind + sm.nnzr_[row]; - value_type *nz = sm.nz_[row]; - - for (size_type col = 0; col != sm.nCols(); ++col) { + //-------------------------------------------------------------------------------- + /** + * Wrapper for iterator-based sparseRightVecProd that takes std::vectors. + * + * TODO: can we remove? 
+ */ + /* + template + static void sparseRightVecProd(const SM& sm, + typename SM::size_type x_begin, + typename SM::size_type x_end, + const std::vector& x, + std::vector& y) + { + { // Pre-conditions + NTA_ASSERT(x.size() == sm.nCols()) + << "sparseRightVecProd: Wrong size for x: " << x.size() + << " when sparse matrix has: " << sm.nCols() << " columns"; + NTA_ASSERT(y.size() == sm.nRows()) + << "sparseRightVecProd: Wrong size for y: " << y.size() + << " when sparse matrix has: " << sm.nRows() << " rows"; + sm.assert_valid_col_range_(x_begin, x_end, "sparseRightVecProd: Invalid + range"); } // End pre-conditions + + SparseMatrixAlgorithms::sparseRightVecProd(sm, x_begin, x_end, x.begin(), + y.begin()); + } + */ - value_type p = sm.nzb_[col]; - if (ind != ind_end && col == *ind) - p += *nz++ * x[*ind++]; - if (p > max_v) - max_v = p; - } - - *y++ = max_v; + //-------------------------------------------------------------------------------- + /** + * Computes a smoothed version of all rows vec max prod, that is: + * + * for row in [0,nrows): + * y[row] = max((this[row,col] + k) * x[col], for col in [0,ncols)) + * + */ + template + static void smoothVecMaxProd(const SM &sm, typename SM::value_type k, + InputIterator x, InputIterator x_end, + OutputIterator y, OutputIterator y_end) { + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + + { // Pre-conditions + NTA_ASSERT((size_type)(x_end - x) == sm.nCols()); + NTA_ASSERT((size_type)(y_end - y) == sm.nRows()); + } // End pre-conditions + + // Compute k * x only once, and cache result in sm.nzb_ + for (size_type j = 0; j != sm.nCols(); ++j) + sm.nzb_[j] = k * x[j]; + + for (size_type row = 0; row != sm.nRows(); ++row) { + + value_type max_v = -std::numeric_limits::max(); + size_type *ind = sm.ind_[row], *ind_end = ind + sm.nnzr_[row]; + value_type *nz = sm.nz_[row]; + + for (size_type col = 0; col != sm.nCols(); ++col) { + + value_type p = sm.nzb_[col]; + if (ind != ind_end && col == *ind) + p += *nz++ * x[*ind++]; + if (p > max_v) + max_v = p; } + + *y++ = max_v; } + } - //-------------------------------------------------------------------------------- - /** - * Computes a smoothed version of all rows vec arg max prod, that is: - * - * for row in [0,nrows): - * y[row] = argmax((this[row,col] + k) * x[col], for col in [0,ncols)) - * - */ - template - static void smoothVecArgMaxProd(const SM& sm, - typename SM::value_type k, - InputIterator x, InputIterator x_end, - OutputIterator y, OutputIterator y_end) - { - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - - { // Pre-conditions - NTA_ASSERT((size_type)(x_end - x) == sm.nCols()); - NTA_ASSERT((size_type)(y_end - y) == sm.nRows()); - } // End pre-conditions - - // Compute k * x only once, and cache result in sm.nzb_ - for (size_type j = 0; j != sm.nCols(); ++j) - sm.nzb_[j] = k * x[j]; - - for (size_type row = 0; row != sm.nRows(); ++row) { - - size_type arg_max = 0; - value_type max_v = - std::numeric_limits::max(); - size_type *ind = sm.ind_[row], *ind_end = ind + sm.nnzr_[row]; - value_type *nz = sm.nz_[row]; - - for (size_type col = 0; col != sm.nCols(); ++col) { - - value_type p = sm.nzb_[col]; - if (ind != ind_end && col == *ind) - p += *nz++ * x[*ind++]; - if (p > max_v) { - max_v = p; - arg_max = col; - } + //-------------------------------------------------------------------------------- + /** + * Computes a smoothed version of all rows vec arg max prod, that is: + * + * for row in [0,nrows): + * y[row] = 
argmax((this[row,col] + k) * x[col], for col in [0,ncols)) + * + */ + template + static void smoothVecArgMaxProd(const SM &sm, typename SM::value_type k, + InputIterator x, InputIterator x_end, + OutputIterator y, OutputIterator y_end) { + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + + { // Pre-conditions + NTA_ASSERT((size_type)(x_end - x) == sm.nCols()); + NTA_ASSERT((size_type)(y_end - y) == sm.nRows()); + } // End pre-conditions + + // Compute k * x only once, and cache result in sm.nzb_ + for (size_type j = 0; j != sm.nCols(); ++j) + sm.nzb_[j] = k * x[j]; + + for (size_type row = 0; row != sm.nRows(); ++row) { + + size_type arg_max = 0; + value_type max_v = -std::numeric_limits::max(); + size_type *ind = sm.ind_[row], *ind_end = ind + sm.nnzr_[row]; + value_type *nz = sm.nz_[row]; + + for (size_type col = 0; col != sm.nCols(); ++col) { + + value_type p = sm.nzb_[col]; + if (ind != ind_end && col == *ind) + p += *nz++ * x[*ind++]; + if (p > max_v) { + max_v = p; + arg_max = col; } - - *y++ = arg_max; } - } - //-------------------------------------------------------------------------------- - //-------------------------------------------------------------------------------- - // LBP (Loopy Belief Propagation) - // - // This section contains algorithms that were written to speed-up LBP operations. - // - //-------------------------------------------------------------------------------- - - /** - * Adds a value to the non-zeros of a SparseMatrix. - * If minFloor > 0, values < minFloor are replaced by minFloor. - */ - template - static void addToNZOnly(SM& A, double val, typename SM::value_type minFloor =0) - { - { - NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); - } + *y++ = arg_max; + } + } - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; + //-------------------------------------------------------------------------------- + //-------------------------------------------------------------------------------- + // LBP (Loopy Belief Propagation) + // + // This section contains algorithms that were written to speed-up LBP + // operations. + // + //-------------------------------------------------------------------------------- - size_type M = A.nRows(); + /** + * Adds a value to the non-zeros of a SparseMatrix. + * If minFloor > 0, values < minFloor are replaced by minFloor. 
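+   *
+   * Illustrative call sketch (not part of this reformatting change; A is any
+   * SparseMatrix type accepted by this struct):
+   *
+   *   // Without a floor, entries whose sum reaches 0 are dropped from A:
+   *   SparseMatrixAlgorithms::addToNZOnly(A, -0.25);
+   *   // With a floor, small magnitudes are clamped to 1e-4 instead and the
+   *   // number of non-zeros is unchanged:
+   *   SparseMatrixAlgorithms::addToNZOnly(A, -0.25, 1e-4f);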
+ */ + template + static void addToNZOnly(SM &A, double val, + typename SM::value_type minFloor = 0) { + { NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); } - if (minFloor == 0) { + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; - // Can introduce new zeros - for (size_type row = 0; row != M; ++row) { + size_type M = A.nRows(); - value_type *nz = A.nz_begin_(row); - value_type *nz_end = A.nz_end_(row); - value_type *nz_dst = nz; + if (minFloor == 0) { - for (; nz != nz_end; ++nz) { - value_type v = *nz + val; - if (!A.isZero_(v)) - *nz_dst++ = v; - } + // Can introduce new zeros + for (size_type row = 0; row != M; ++row) { - A.nnzr_[row] = nz_dst - A.nz_begin_(row); - } + value_type *nz = A.nz_begin_(row); + value_type *nz_end = A.nz_end_(row); + value_type *nz_dst = nz; - } else { // if minFloor != 0 - - nupic::Abs abs_f; - - // Doesn't change the number of non-zeros - for (size_type row = 0; row != M; ++row) { - - size_type *ind = A.ind_begin_(row); - size_type *ind_end = ind + A.nnzr_[row]; - value_type *nz = A.nz_begin_(row); - - for (; ind != ind_end; ++ind, ++nz) { - *nz += val; - if (abs_f(*nz) < minFloor) - *nz = minFloor; - } + for (; nz != nz_end; ++nz) { + value_type v = *nz + val; + if (!A.isZero_(v)) + *nz_dst++ = v; } - } - } - - //-------------------------------------------------------------------------------- - /** - * Adds vector to non-zeros only, down the columns. If minFloor is > 0, any - * element that drop below minFloor are replaced by minFloor. - */ - template - static void addToNZDownCols(SM& A, U begin, U end, - typename SM::value_type minFloor =0) - { - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - { - NTA_ASSERT((size_type)(end - begin) == A.nCols()); - NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); + A.nnzr_[row] = nz_dst - A.nz_begin_(row); } - if (minFloor == 0) { + } else { // if minFloor != 0 - for (size_type row = 0; row != A.nRows(); ++row) { - size_type *ind = A.ind_begin_(row); - value_type *nz = A.nz_begin_(row); - value_type *nz_end = A.nz_end_(row); - for (; nz != nz_end; ++ind) { - *nz += *(begin + *ind); - if (!A.isZero_(*nz)) - ++nz; - } - A.nnzr_[row] = (size_type) (nz - A.nz_begin_(row)); - } + nupic::Abs abs_f; - } else { // if minFloor != 0 + // Doesn't change the number of non-zeros + for (size_type row = 0; row != M; ++row) { - nupic::Abs abs_f; + size_type *ind = A.ind_begin_(row); + size_type *ind_end = ind + A.nnzr_[row]; + value_type *nz = A.nz_begin_(row); - for (size_type row = 0; row != A.nRows(); ++row) { - size_type *ind = A.ind_begin_(row); - value_type *nz = A.nz_begin_(row); - value_type *nz_end = A.nz_end_(row); - for (; nz != nz_end; ++ind, ++nz) { - *nz += *(begin + *ind); - if (abs_f(*nz) < minFloor) - *nz = minFloor; - } + for (; ind != ind_end; ++ind, ++nz) { + *nz += val; + if (abs_f(*nz) < minFloor) + *nz = minFloor; } } } + } + + //-------------------------------------------------------------------------------- + /** + * Adds vector to non-zeros only, down the columns. If minFloor is > 0, any + * element that drop below minFloor are replaced by minFloor. + */ + template + static void addToNZDownCols(SM &A, U begin, U end, + typename SM::value_type minFloor = 0) { + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; - //-------------------------------------------------------------------------------- - /** - * Adds vector to non-zeros only, across the rows. 
If minFloor is > 0, any - * element that drop below minFloor are replaced by minFloor. - */ - template - static void addToNZAcrossRows(SM& A, U begin, U end, - typename SM::value_type minFloor =0) { - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; + NTA_ASSERT((size_type)(end - begin) == A.nCols()); + NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); + } - { - NTA_ASSERT((size_type)(end - begin) == A.nRows()); - NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); - } + if (minFloor == 0) { - if (minFloor == 0) { - - // Can introduce new zeros - for (size_type row = 0; row != A.nRows(); ++row) { - size_type *ind = A.ind_begin_(row); - value_type *nz = A.nz_begin_(row); - value_type *nz_end = A.nz_end_(row); - for (; nz != nz_end; ++ind) { - *nz += *begin; - if (!A.isZero_(*nz)) - ++nz; - } - A.nnzr_[row] = (size_type) (nz - A.nz_begin_(row)); - ++begin; + for (size_type row = 0; row != A.nRows(); ++row) { + size_type *ind = A.ind_begin_(row); + value_type *nz = A.nz_begin_(row); + value_type *nz_end = A.nz_end_(row); + for (; nz != nz_end; ++ind) { + *nz += *(begin + *ind); + if (!A.isZero_(*nz)) + ++nz; } + A.nnzr_[row] = (size_type)(nz - A.nz_begin_(row)); + } - } else { // if minFloor != 0 + } else { // if minFloor != 0 - nupic::Abs abs_f; + nupic::Abs abs_f; - // Doesn't change the number of non-zeros - for (size_type row = 0; row != A.nRows(); ++row) { - size_type *ind = A.ind_begin_(row); - value_type *nz = A.nz_begin_(row); - value_type *nz_end = A.nz_end_(row); - for (; nz != nz_end; ++ind, ++nz) { - *nz += *begin; - if (abs_f(*nz) < minFloor) - *nz = minFloor; - } - ++begin; + for (size_type row = 0; row != A.nRows(); ++row) { + size_type *ind = A.ind_begin_(row); + value_type *nz = A.nz_begin_(row); + value_type *nz_end = A.nz_end_(row); + for (; nz != nz_end; ++ind, ++nz) { + *nz += *(begin + *ind); + if (abs_f(*nz) < minFloor) + *nz = minFloor; } } } + } + + //-------------------------------------------------------------------------------- + /** + * Adds vector to non-zeros only, across the rows. If minFloor is > 0, any + * element that drop below minFloor are replaced by minFloor. + */ + template + static void addToNZAcrossRows(SM &A, U begin, U end, + typename SM::value_type minFloor = 0) { + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; - //-------------------------------------------------------------------------------- - /** - * Replaces non-zeros by 1 - non-zero value. This can introduce new zeros, but - * not any new zero. - * - * TODO: clarify. - */ - template - static void NZOneMinus(SM& A) { - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; + NTA_ASSERT((size_type)(end - begin) == A.nRows()); + NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); + } + + if (minFloor == 0) { + // Can introduce new zeros for (size_type row = 0; row != A.nRows(); ++row) { size_type *ind = A.ind_begin_(row); value_type *nz = A.nz_begin_(row); value_type *nz_end = A.nz_end_(row); for (; nz != nz_end; ++ind) { - *nz = (value_type) 1.0 - *nz; + *nz += *begin; if (!A.isZero_(*nz)) ++nz; } - A.nnzr_[row] = (size_type) (nz - A.nz_begin_(row)); - } - } - - //-------------------------------------------------------------------------------- - /** - * Adds the non-zeros of B to the non-zeros of A. Assumes that everywhere B has - * a non-zeros, A has a non-zero. Non-zeros of A that don't match up with a - * non-zero of B are unaffected. 
- * - * [[1 0 2] + [[3 0 0] = [[4 0 2] - * [0 3 0]] [0 1 0]] [0 4 0]] - */ - template - static void addNoAlloc(SM& A, const SM& B, typename SM::value_type minFloor =0) - { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - NTA_ASSERT(B.nonZeroIndicesIncluded(A)); - NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); + A.nnzr_[row] = (size_type)(nz - A.nz_begin_(row)); + ++begin; } - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; + } else { // if minFloor != 0 nupic::Abs abs_f; - size_type M = A.nRows(); - - for (size_type row = 0; row != M; ++row) { - - size_type *ind_a = A.ind_begin_(row); - size_type *ind_b = B.ind_begin_(row); - size_type *ind_b_end = B.ind_end_(row); - value_type *nz_a = A.nz_begin_(row); - value_type *nz_b = B.nz_begin_(row); - - while (ind_b != ind_b_end) { - if (*ind_a == *ind_b) { - value_type a = *nz_a; - value_type b = *nz_b; - a += b; - if (minFloor > 0 && abs_f(a) < minFloor) - a = minFloor; - *nz_a = a; - NTA_ASSERT(!A.isZero_(*nz_a)); - ++ind_a; ++nz_a; - ++ind_b; ++nz_b; - } else if (*ind_a < *ind_b) { - ++ind_a; ++nz_a; - } + // Doesn't change the number of non-zeros + for (size_type row = 0; row != A.nRows(); ++row) { + size_type *ind = A.ind_begin_(row); + value_type *nz = A.nz_begin_(row); + value_type *nz_end = A.nz_end_(row); + for (; nz != nz_end; ++ind, ++nz) { + *nz += *begin; + if (abs_f(*nz) < minFloor) + *nz = minFloor; } + ++begin; } } + } - //-------------------------------------------------------------------------------- - /** - * Subtracts the non-zeros of B from the non-zeros of A. Assumes that everywhere B has - * a non-zeros, A has a non-zero. Non-zeros of A that don't match up with a - * non-zero of B are unaffected. - * - * [[1 0 2] - [[3 0 0] = [[-2 0 2] - * [0 3 0]] [0 1 0]] [0 2 0]] - */ - template - static void subtractNoAlloc(SM& A, const SM& B, typename SM::value_type minFloor =0) - { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - NTA_ASSERT(B.nonZeroIndicesIncluded(A)); - NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); + //-------------------------------------------------------------------------------- + /** + * Replaces non-zeros by 1 - non-zero value. This can introduce new zeros, but + * not any new zero. + * + * TODO: clarify. + */ + template static void NZOneMinus(SM &A) { + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + + for (size_type row = 0; row != A.nRows(); ++row) { + size_type *ind = A.ind_begin_(row); + value_type *nz = A.nz_begin_(row); + value_type *nz_end = A.nz_end_(row); + for (; nz != nz_end; ++ind) { + *nz = (value_type)1.0 - *nz; + if (!A.isZero_(*nz)) + ++nz; } + A.nnzr_[row] = (size_type)(nz - A.nz_begin_(row)); + } + } - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; + //-------------------------------------------------------------------------------- + /** + * Adds the non-zeros of B to the non-zeros of A. Assumes that everywhere B + * has a non-zeros, A has a non-zero. Non-zeros of A that don't match up with + * a non-zero of B are unaffected. 
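+   * No allocation is performed, so B's non-zero pattern must already be
+   * included in A's (see the assertions in the body). Illustrative call
+   * sketch, not part of this reformatting change:
+   *
+   *   SparseMatrixAlgorithms::addNoAlloc(A, B);         // plain add
+   *   SparseMatrixAlgorithms::addNoAlloc(A, B, 1e-4f);  // clamp tiny sums
+   *
+   * Element-wise on the shared non-zeros this computes: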
+ * + * [[1 0 2] + [[3 0 0] = [[4 0 2] + * [0 3 0]] [0 1 0]] [0 4 0]] + */ + template + static void addNoAlloc(SM &A, const SM &B, + typename SM::value_type minFloor = 0) { + { + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + NTA_ASSERT(B.nonZeroIndicesIncluded(A)); + NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); + } - nupic::Abs abs_f; + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; - size_type M = A.nRows(); + nupic::Abs abs_f; - for (size_type row = 0; row != M; ++row) { + size_type M = A.nRows(); - size_type *ind_a = A.ind_begin_(row); - size_type *ind_b = B.ind_begin_(row); - size_type *ind_b_end = B.ind_end_(row); - value_type *nz_a = A.nz_begin_(row); - value_type *nz_b = B.nz_begin_(row); + for (size_type row = 0; row != M; ++row) { - while (ind_b != ind_b_end) { - if (*ind_a == *ind_b) { - value_type a = *nz_a; - value_type b = *nz_b; - a -= b; - if (minFloor > 0 && abs_f(a) < minFloor) - a = minFloor; - *nz_a = a; - NTA_ASSERT(!A.isZero_(*nz_a)); - ++ind_a; ++nz_a; - ++ind_b; ++nz_b; - } else if (*ind_a < *ind_b) { - ++ind_a; ++nz_a; - } + size_type *ind_a = A.ind_begin_(row); + size_type *ind_b = B.ind_begin_(row); + size_type *ind_b_end = B.ind_end_(row); + value_type *nz_a = A.nz_begin_(row); + value_type *nz_b = B.nz_begin_(row); + + while (ind_b != ind_b_end) { + if (*ind_a == *ind_b) { + value_type a = *nz_a; + value_type b = *nz_b; + a += b; + if (minFloor > 0 && abs_f(a) < minFloor) + a = minFloor; + *nz_a = a; + NTA_ASSERT(!A.isZero_(*nz_a)); + ++ind_a; + ++nz_a; + ++ind_b; + ++nz_b; + } else if (*ind_a < *ind_b) { + ++ind_a; + ++nz_a; } } } + } - //-------------------------------------------------------------------------------- - /** - * Copies the values of the non-zeros of B into A, where A and B have a non-zero - * in the same location. Leaves the other non-zeros of A unchanged. - * - * TODO: move to SM copy, with parameter to re-use memory - */ - template - static void assignNoAlloc(SM& A, const SM&B) + //-------------------------------------------------------------------------------- + /** + * Subtracts the non-zeros of B from the non-zeros of A. Assumes that + * everywhere B has a non-zeros, A has a non-zero. Non-zeros of A that don't + * match up with a non-zero of B are unaffected. 
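+   * As with addNoAlloc above, no allocation is performed: B's non-zero
+   * pattern must be included in A's, and a positive minFloor clamps results
+   * whose magnitude falls below it. Illustrative call, not part of this
+   * reformatting change:
+   *
+   *   SparseMatrixAlgorithms::subtractNoAlloc(A, B, 1e-4f);
+   *
+   * Element-wise on the shared non-zeros this computes: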
+ * + * [[1 0 2] - [[3 0 0] = [[-2 0 2] + * [0 3 0]] [0 1 0]] [0 2 0]] + */ + template + static void subtractNoAlloc(SM &A, const SM &B, + typename SM::value_type minFloor = 0) { { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - } + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + NTA_ASSERT(B.nonZeroIndicesIncluded(A)); + NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); + } - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; - const size_type M = A.nRows(); + nupic::Abs abs_f; - for (size_type row = 0; row != M; ++row) { + size_type M = A.nRows(); - size_type *ind_a = A.ind_begin_(row); - size_type *ind_b = B.ind_begin_(row); - value_type *nz_a = A.nz_begin_(row); - value_type *nz_a_end = A.nz_end_(row); - value_type *nz_b = B.nz_begin_(row); - value_type *nz_b_end = B.nz_end_(row); - - while (nz_a != nz_a_end && nz_b != nz_b_end) - if (*ind_a == *ind_b) { - *nz_a = *nz_b; - ++ind_a; ++ind_b; - ++nz_a; ++nz_b; - } else if (*ind_a < *ind_b) { - ++ind_a; ++nz_a; - } else if (*ind_b < *ind_a) { - ++ind_b; ++nz_b; - } + for (size_type row = 0; row != M; ++row) { + + size_type *ind_a = A.ind_begin_(row); + size_type *ind_b = B.ind_begin_(row); + size_type *ind_b_end = B.ind_end_(row); + value_type *nz_a = A.nz_begin_(row); + value_type *nz_b = B.nz_begin_(row); + + while (ind_b != ind_b_end) { + if (*ind_a == *ind_b) { + value_type a = *nz_a; + value_type b = *nz_b; + a -= b; + if (minFloor > 0 && abs_f(a) < minFloor) + a = minFloor; + *nz_a = a; + NTA_ASSERT(!A.isZero_(*nz_a)); + ++ind_a; + ++nz_a; + ++ind_b; + ++nz_b; + } else if (*ind_a < *ind_b) { + ++ind_a; + ++nz_a; + } } } + } - //-------------------------------------------------------------------------------- - /** - * Copies the values of the non-zeros of B into A, only where A and B have a - * non-zero in the same location. The other non-zeros of A are left unchanged. - * SM2 is assumed to be a binary matrix. - * - * TODO: maybe a constructor of SM, or a copy method with an argument to re-use - * the memory. - */ - template - static void assignNoAllocFromBinary(SM& A, const SM01&B) + //-------------------------------------------------------------------------------- + /** + * Copies the values of the non-zeros of B into A, where A and B have a + * non-zero in the same location. Leaves the other non-zeros of A unchanged. 
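+   *
+   * Illustrative call sketch (not part of this reformatting change; A and B
+   * are sparse matrices of the same shape):
+   *
+   *   SparseMatrixAlgorithms::assignNoAlloc(A, B);
+   *   // A keeps its own non-zero pattern, but wherever both A and B have a
+   *   // non-zero, A now holds B's value. No memory is allocated.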
+ * + * TODO: move to SM copy, with parameter to re-use memory + */ + template static void assignNoAlloc(SM &A, const SM &B) { { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - } - - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + } - const size_type M = A.nRows(); + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + + const size_type M = A.nRows(); + + for (size_type row = 0; row != M; ++row) { + + size_type *ind_a = A.ind_begin_(row); + size_type *ind_b = B.ind_begin_(row); + value_type *nz_a = A.nz_begin_(row); + value_type *nz_a_end = A.nz_end_(row); + value_type *nz_b = B.nz_begin_(row); + value_type *nz_b_end = B.nz_end_(row); + + while (nz_a != nz_a_end && nz_b != nz_b_end) + if (*ind_a == *ind_b) { + *nz_a = *nz_b; + ++ind_a; + ++ind_b; + ++nz_a; + ++nz_b; + } else if (*ind_a < *ind_b) { + ++ind_a; + ++nz_a; + } else if (*ind_b < *ind_a) { + ++ind_b; + ++nz_b; + } + } + } - for (size_type row = 0; row != M; ++row) { + //-------------------------------------------------------------------------------- + /** + * Copies the values of the non-zeros of B into A, only where A and B have a + * non-zero in the same location. The other non-zeros of A are left unchanged. + * SM2 is assumed to be a binary matrix. + * + * TODO: maybe a constructor of SM, or a copy method with an argument to + * re-use the memory. + */ + template + static void assignNoAllocFromBinary(SM &A, const SM01 &B) { + { + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + } - size_type *ind_a = A.ind_begin_(row); - typename SM01::Row::const_iterator ind_b = B.ind_begin_(row); - typename SM01::Row::const_iterator ind_b_end = B.ind_end_(row); - value_type *nz_a = A.nz_begin_(row); - value_type *nz_a_end = A.nz_end_(row); - - while (nz_a != nz_a_end && ind_b != ind_b_end) - if (*ind_a == *ind_b) { - *nz_a = (value_type) 1; - ++ind_a; ++ind_b; - ++nz_a; - } else if (*ind_a < *ind_b) { - ++ind_a; ++nz_a; - } else if (*ind_b < *ind_a) { - ++ind_b; + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + + const size_type M = A.nRows(); + + for (size_type row = 0; row != M; ++row) { + + size_type *ind_a = A.ind_begin_(row); + typename SM01::Row::const_iterator ind_b = B.ind_begin_(row); + typename SM01::Row::const_iterator ind_b_end = B.ind_end_(row); + value_type *nz_a = A.nz_begin_(row); + value_type *nz_a_end = A.nz_end_(row); + + while (nz_a != nz_a_end && ind_b != ind_b_end) + if (*ind_a == *ind_b) { + *nz_a = (value_type)1; + ++ind_a; + ++ind_b; + ++nz_a; + } else if (*ind_a < *ind_b) { + ++ind_a; + ++nz_a; + } else if (*ind_b < *ind_a) { + ++ind_b; + } + } + } + //-------------------------------------------------------------------------------- + /** + * Adds a constant value on nonzeros of one SparseMatrix(B) to another (A). 
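+   * B is a binary (0/1) matrix: wherever B has a non-zero, cval is added to
+   * A[row,col], creating a new non-zero in A if needed; results that become
+   * zero are removed.
+   *
+   * Illustrative call sketch (not part of this reformatting change; A is a
+   * SparseMatrix and B a SparseMatrix01 of the same shape):
+   *
+   *   SparseMatrixAlgorithms::addConstantOnNonZeros(A, B, 0.5f);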
+ */ + template + static void addConstantOnNonZeros(SM &A, const SM01 &B, + typename SM::value_type cval) { + { // Pre-conditions + NTA_ASSERT(A.nRows() == B.nRows()) + << "add: Wrong number of rows: " << A.nRows() << " and " << B.nRows(); + NTA_ASSERT(A.nCols() == B.nCols()) + << "add: Wrong number of columns: " << A.nCols() << " and " + << B.nCols(); + } // End pre-conditions + + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + + const size_type nrows = A.nRows(); + for (size_type row = 0; row != nrows; ++row) { + + size_type *ind = A.ind_begin_(row); + size_type *ind_end = A.ind_end_(row); + value_type *nz = A.nz_begin_(row); + + typename SM01::Row::const_iterator ind_b = B.ind_begin_(row); + typename SM01::Row::const_iterator ind_b_end = B.ind_end_(row); + + std::vector indb_; + std::vector nzb_; + + while (ind != ind_end && ind_b != ind_b_end) { + if (*ind == *ind_b) { + value_type val = *nz++ + cval; + if (!A.isZero_(val)) { + indb_.push_back(*ind); + nzb_.push_back(val); + } + ++ind; + ++ind_b; + } else if (*ind < *ind_b) { + indb_.push_back(*ind++); + nzb_.push_back(*nz++); + } else if (*ind_b < *ind) { + if (!A.isZero_(cval)) { + indb_.push_back(*ind_b++); + nzb_.push_back(cval); } + } + } + + while (ind != ind_end) { + indb_.push_back(*ind++); + nzb_.push_back(*nz++); + } + + while (ind_b != ind_b_end) { + if (!A.isZero_(cval)) { + indb_.push_back(*ind_b++); + nzb_.push_back(cval); + } } + + A.setRowFromSparse(row, indb_.begin(), indb_.end(), nzb_.begin()); } - //-------------------------------------------------------------------------------- - /** - * Adds a constant value on nonzeros of one SparseMatrix(B) to another (A). - */ - template - static void addConstantOnNonZeros(SM& A, const SM01& B, - typename SM::value_type cval) - { - { // Pre-conditions - NTA_ASSERT(A.nRows() == B.nRows()) - << "add: Wrong number of rows: " << A.nRows() - << " and " << B.nRows(); - NTA_ASSERT(A.nCols() == B.nCols()) - << "add: Wrong number of columns: " << A.nCols() - << " and " << B.nCols(); - } // End pre-conditions - - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - - const size_type nrows = A.nRows(); - for (size_type row = 0; row != nrows; ++row) { - - size_type *ind = A.ind_begin_(row); - size_type *ind_end = A.ind_end_(row); - value_type *nz = A.nz_begin_(row); - - typename SM01::Row::const_iterator ind_b = B.ind_begin_(row); - typename SM01::Row::const_iterator ind_b_end = B.ind_end_(row); - - std::vector indb_; - std::vector nzb_; - - while (ind != ind_end && ind_b != ind_b_end) { - if (*ind == *ind_b) { - value_type val = *nz++ + cval; - if (!A.isZero_(val)) { - indb_.push_back(*ind); - nzb_.push_back(val); - } - ++ind; ++ind_b; - } else if (*ind < *ind_b) { - indb_.push_back(*ind++); - nzb_.push_back(*nz++); - } else if (*ind_b < *ind) { - if (!A.isZero_(cval)) { - indb_.push_back(*ind_b++); - nzb_.push_back(cval); - } - } - } - - while (ind != ind_end) { - indb_.push_back(*ind++); - nzb_.push_back(*nz++); - } - - while (ind_b != ind_b_end) { - if (!A.isZero_(cval)) { - indb_.push_back(*ind_b++); - nzb_.push_back(cval); - } - } - - A.setRowFromSparse(row, indb_.begin(), indb_.end(), nzb_.begin()); - } - } - - //-------------------------------------------------------------------------------- - /** - * Computes the sum of two SMs that are in log space. - * A = log(exp(A) + exp(B)), but only for the non-zeros of B. - * A and B are already in log space. - * A has non-zeros everywhere that B does. 
- * This assumes that the operation does not introduce new zeros. - * Note: we follow the non-zeros of B, which can be less than the non-zeros - * of A. - * If minFloor > 0, any value that drops below minFloor becomes minFloor. - */ - template - static void logSumNoAlloc(SM& A, const SM& B, typename SM::value_type minFloor =0) + } + + //-------------------------------------------------------------------------------- + /** + * Computes the sum of two SMs that are in log space. + * A = log(exp(A) + exp(B)), but only for the non-zeros of B. + * A and B are already in log space. + * A has non-zeros everywhere that B does. + * This assumes that the operation does not introduce new zeros. + * Note: we follow the non-zeros of B, which can be less than the non-zeros + * of A. + * If minFloor > 0, any value that drops below minFloor becomes minFloor. + */ + template + static void logSumNoAlloc(SM &A, const SM &B, + typename SM::value_type minFloor = 0) { { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - NTA_ASSERT(B.nonZeroIndicesIncluded(A)); - NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); - } + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + NTA_ASSERT(B.nonZeroIndicesIncluded(A)); + NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); + } - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; - nupic::Exp exp_f; - nupic::Log log_f; - nupic::Log1p log1p_f; - nupic::Abs abs_f; + nupic::Exp exp_f; + nupic::Log log_f; + nupic::Log1p log1p_f; + nupic::Abs abs_f; - size_type M = A.nRows(); - value_type minExp = log_f(std::numeric_limits::epsilon()); + size_type M = A.nRows(); + value_type minExp = log_f(std::numeric_limits::epsilon()); - for (size_type row = 0; row != M; ++row) { + for (size_type row = 0; row != M; ++row) { - size_type *ind_a = A.ind_begin_(row); - size_type *ind_b = B.ind_begin_(row); - size_type *ind_b_end = B.ind_end_(row); - value_type *nz_a = A.nz_begin_(row); - value_type *nz_b = B.nz_begin_(row); + size_type *ind_a = A.ind_begin_(row); + size_type *ind_b = B.ind_begin_(row); + size_type *ind_b_end = B.ind_end_(row); + value_type *nz_a = A.nz_begin_(row); + value_type *nz_b = B.nz_begin_(row); - while (ind_b != ind_b_end) { - if (*ind_a == *ind_b) { - value_type a = *nz_a; - value_type b = *nz_b; - if (a < b) - std::swap(a,b); - value_type d = b - a; - if (d >= minExp) { - a += log1p_f(exp_f(d)); - if (minFloor > 0 && abs_f(a) < minFloor) - a = minFloor; - *nz_a = a; - } else { - *nz_a = a; - } - NTA_ASSERT(!A.isZero_(*nz_a)); - ++ind_a; ++nz_a; - ++ind_b; ++nz_b; - } else if (*ind_a < *ind_b) { - ++ind_a; ++nz_a; - } - } - } - } - - //-------------------------------------------------------------------------------- - /** - * Adds a constant to the non-zeros of A in log space. - * Assumes that no new zeros are introduced. 
- */ - template - static void - logAddValNoAlloc(SM& A, - typename SM::value_type val, typename SM::value_type minFloor =0) - { - { - NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); - } - - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - - nupic::Exp exp_f; - nupic::Log log_f; - nupic::Log1p log1p_f; - nupic::Abs abs_f; - - size_type M = A.nRows(); - value_type minExp = log_f(std::numeric_limits::epsilon()); - value_type b; - - for (size_type row = 0; row != M; ++row) { - - size_type *ind_a = A.ind_begin_(row); - size_type *ind_a_end = A.ind_end_(row); - value_type *nz_a = A.nz_begin_(row); - - while (ind_a != ind_a_end) { + while (ind_b != ind_b_end) { + if (*ind_a == *ind_b) { value_type a = *nz_a; - - // Put smaller value in b, larger in a - if (a < val) { - b = a; - a = val; - } else { - b = val; - } + value_type b = *nz_b; + if (a < b) + std::swap(a, b); value_type d = b - a; if (d >= minExp) { a += log1p_f(exp_f(d)); - if (minFloor > 0 && abs_f(a) < minFloor) + if (minFloor > 0 && abs_f(a) < minFloor) a = minFloor; *nz_a = a; - } else + } else { *nz_a = a; + } NTA_ASSERT(!A.isZero_(*nz_a)); - ++ind_a; ++nz_a; - } + ++ind_a; + ++nz_a; + ++ind_b; + ++nz_b; + } else if (*ind_a < *ind_b) { + ++ind_a; + ++nz_a; + } } } + } - //-------------------------------------------------------------------------------- - /** - * Computes the diff of two SMs that are in log space. - * A = log(exp(A) - exp(B)), but only for the non-zeros of B. - * A and B are already in log space. - * A has non-zeros everywhere that B does. - * A > B in all non-zeros. - * This assumes that the operation does not introduce new zeros. - * Note: we follow the non-zeros of B, which can be less than the non-zeros - * of A. - * If minFloor > 0, any value that drops below minFloor becomes minFloor. - */ - template - static void logDiffNoAlloc(SM& A, const SM& B, typename SM::value_type minFloor =0) - { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - NTA_ASSERT(B.nonZeroIndicesIncluded(A)); - NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); + //-------------------------------------------------------------------------------- + /** + * Adds a constant to the non-zeros of A in log space. + * Assumes that no new zeros are introduced. 
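+   * Each non-zero a of A becomes log(exp(a) + exp(val)), computed stably via
+   * log1p, with an optional minFloor clamp on the magnitude of the result.
+   *
+   * Illustrative call sketch (not part of this reformatting change; A holds
+   * log-space values):
+   *
+   *   SparseMatrixAlgorithms::logAddValNoAlloc(A, logf(0.01f));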
+ */ + template + static void logAddValNoAlloc(SM &A, typename SM::value_type val, + typename SM::value_type minFloor = 0) { + { NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); } + + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + + nupic::Exp exp_f; + nupic::Log log_f; + nupic::Log1p log1p_f; + nupic::Abs abs_f; + + size_type M = A.nRows(); + value_type minExp = log_f(std::numeric_limits::epsilon()); + value_type b; + + for (size_type row = 0; row != M; ++row) { + + size_type *ind_a = A.ind_begin_(row); + size_type *ind_a_end = A.ind_end_(row); + value_type *nz_a = A.nz_begin_(row); + + while (ind_a != ind_a_end) { + value_type a = *nz_a; + + // Put smaller value in b, larger in a + if (a < val) { + b = a; + a = val; + } else { + b = val; + } + value_type d = b - a; + if (d >= minExp) { + a += log1p_f(exp_f(d)); + if (minFloor > 0 && abs_f(a) < minFloor) + a = minFloor; + *nz_a = a; + } else + *nz_a = a; + NTA_ASSERT(!A.isZero_(*nz_a)); + ++ind_a; + ++nz_a; } + } + } - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - - // Important to use double here, because in float, there can be - // cancelation in log(1 - exp(b-a)), when a is very close to b. - nupic::Exp exp_f; - nupic::Log log_f; - nupic::Log1p log1p_f; - nupic::Abs abs_f; - - size_type M = A.nRows(); - value_type minExp = log_f(std::numeric_limits::epsilon()); - - // Two log values that are this close to each other should generate a difference - // of 0, which is -inf in log space, which we want to avoid - double minDiff = -std::numeric_limits::epsilon(); - value_type logOfZero = -1.0/std::numeric_limits::epsilon(); - - for (size_type row = 0; row != M; ++row) { + //-------------------------------------------------------------------------------- + /** + * Computes the diff of two SMs that are in log space. + * A = log(exp(A) - exp(B)), but only for the non-zeros of B. + * A and B are already in log space. + * A has non-zeros everywhere that B does. + * A > B in all non-zeros. + * This assumes that the operation does not introduce new zeros. + * Note: we follow the non-zeros of B, which can be less than the non-zeros + * of A. + * If minFloor > 0, any value that drops below minFloor becomes minFloor. + */ + template + static void logDiffNoAlloc(SM &A, const SM &B, + typename SM::value_type minFloor = 0) { + { + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + NTA_ASSERT(B.nonZeroIndicesIncluded(A)); + NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); + } - size_type *ind_a = A.ind_begin_(row); - size_type *ind_b = B.ind_begin_(row); - size_type *ind_b_end = B.ind_end_(row); - value_type *nz_a = A.nz_begin_(row); - value_type *nz_b = B.nz_begin_(row); + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; - while (ind_b != ind_b_end) { - if (*ind_a == *ind_b) { - double a = *nz_a; - double b = *nz_b; - NTA_ASSERT(a >= b); - double d = b - a; - // If the values are too close to each other, generate log of 0 manually - // We know d <= 0 at this point. 
- if (d >= minDiff) - *nz_a = logOfZero; - else if (d >= minExp) { - a += log1p_f(-exp_f(d)); - if (minFloor > 0 && abs_f(a) < minFloor) - a = minFloor; - *nz_a = (value_type) a; - } else { - *nz_a = (value_type) a; - } - NTA_ASSERT(!A.isZero_(*nz_a)); - ++ind_a; ++nz_a; - ++ind_b; ++nz_b; - } else if (*ind_a < *ind_b) { - ++ind_a; ++nz_a; + // Important to use double here, because in float, there can be + // cancelation in log(1 - exp(b-a)), when a is very close to b. + nupic::Exp exp_f; + nupic::Log log_f; + nupic::Log1p log1p_f; + nupic::Abs abs_f; + + size_type M = A.nRows(); + value_type minExp = log_f(std::numeric_limits::epsilon()); + + // Two log values that are this close to each other should generate a + // difference + // of 0, which is -inf in log space, which we want to avoid + double minDiff = -std::numeric_limits::epsilon(); + value_type logOfZero = -1.0 / std::numeric_limits::epsilon(); + + for (size_type row = 0; row != M; ++row) { + + size_type *ind_a = A.ind_begin_(row); + size_type *ind_b = B.ind_begin_(row); + size_type *ind_b_end = B.ind_end_(row); + value_type *nz_a = A.nz_begin_(row); + value_type *nz_b = B.nz_begin_(row); + + while (ind_b != ind_b_end) { + if (*ind_a == *ind_b) { + double a = *nz_a; + double b = *nz_b; + NTA_ASSERT(a >= b); + double d = b - a; + // If the values are too close to each other, generate log of 0 + // manually We know d <= 0 at this point. + if (d >= minDiff) + *nz_a = logOfZero; + else if (d >= minExp) { + a += log1p_f(-exp_f(d)); + if (minFloor > 0 && abs_f(a) < minFloor) + a = minFloor; + *nz_a = (value_type)a; + } else { + *nz_a = (value_type)a; } + NTA_ASSERT(!A.isZero_(*nz_a)); + ++ind_a; + ++nz_a; + ++ind_b; + ++nz_b; + } else if (*ind_a < *ind_b) { + ++ind_a; + ++nz_a; } } } + } - //-------------------------------------------------------------------------------- - /** - * Algorithm to compute piPrime in loopy belief propagation. - * - * The net operation performed is prod(col)/element, but it is performed in log mode - * and the mat argument is assumed to have already been converted to log mode. - * All values within mat are between 0 and 1 in normal space (-inf and -epsilon in log - * space). - * - * This does a sum of each column, then places colSum-element into each - * location, insuring that no new zeros are introduced. Any result that - * would have computed to 0 (within max_floor) will be replaced with max_floor - */ - template - static void LBP_piPrime(SM& mat, typename SM::value_type max_floor) - { - { - NTA_ASSERT(max_floor < 0); - } + //-------------------------------------------------------------------------------- + /** + * Algorithm to compute piPrime in loopy belief propagation. + * + * The net operation performed is prod(col)/element, but it is performed in + * log mode and the mat argument is assumed to have already been converted to + * log mode. All values within mat are between 0 and 1 in normal space (-inf + * and -epsilon in log space). + * + * This does a sum of each column, then places colSum-element into each + * location, insuring that no new zeros are introduced. 
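The logDiffNoAlloc update above has one extra wrinkle: when a and b are nearly equal, exp(a) - exp(b) collapses to zero, whose logarithm is -inf, so the code substitutes the large negative sentinel logOfZero instead. A standalone scalar sketch of that logic; logDiffScalar is an illustrative name.

#include <cassert>
#include <cmath>
#include <limits>

// Scalar core of the log-space difference: a <- log(exp(a) - exp(b)), b <= a.
// Mirrors the inner loop of logDiffNoAlloc; illustrative only.
float logDiffScalar(float a, float b, float minFloor = 0.0f) {
  assert(b <= a);
  const double minExp = std::log((double)std::numeric_limits<float>::epsilon());
  const double minDiff = -std::numeric_limits<double>::epsilon();
  const float logOfZero = -1.0f / std::numeric_limits<float>::epsilon();
  const double d = (double)b - (double)a; // double avoids cancellation in log1p(-exp(d))
  if (d >= minDiff)                       // a and b essentially equal: result is log(~0)
    return logOfZero;
  if (d >= minExp) {
    double r = a + std::log1p(-std::exp(d));
    if (minFloor > 0 && std::fabs(r) < minFloor)
      r = minFloor;
    return (float)r;
  }
  return a; // exp(b) is negligible next to exp(a)
}

int main() {
  const float exact = std::log(std::exp(-1.0f) - std::exp(-2.0f));
  assert(std::fabs(logDiffScalar(-1.0f, -2.0f) - exact) < 1e-5f);
  assert(logDiffScalar(-1.0f, -1.0f) < -1e6f); // near-equal inputs give the sentinel
  return 0;
}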
Any result that + * would have computed to 0 (within max_floor) will be replaced with + * max_floor + */ + template + static void LBP_piPrime(SM &mat, typename SM::value_type max_floor) { + { NTA_ASSERT(max_floor < 0); } - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; - size_type M = mat.nRows(); - size_type N = mat.nCols(); + size_type M = mat.nRows(); + size_type N = mat.nCols(); - nupic::Abs abs_f; + nupic::Abs abs_f; - std::fill(mat.nzb_, mat.nzb_ + N, (value_type) 0); + std::fill(mat.nzb_, mat.nzb_ + N, (value_type)0); - // Compute the column sums, place them into mat.nzb_ - for (size_type row = 0; row != M; ++row) { + // Compute the column sums, place them into mat.nzb_ + for (size_type row = 0; row != M; ++row) { - if (mat.nnzr_[row] == 0) - continue; + if (mat.nnzr_[row] == 0) + continue; - size_type *ind = mat.ind_begin_(row); - size_type *ind_end = mat.ind_end_(row); - value_type *nz = mat.nz_begin_(row); + size_type *ind = mat.ind_begin_(row); + size_type *ind_end = mat.ind_end_(row); + value_type *nz = mat.nz_begin_(row); - for (; ind != ind_end; ++ind, ++nz) - mat.nzb_[*ind] += *nz; - } + for (; ind != ind_end; ++ind, ++nz) + mat.nzb_[*ind] += *nz; + } - // Replace each element with colSum - element - for (size_type row = 0; row != M; ++row) { - - if (mat.nnzr_[row] == 0) - continue; - - size_type *ind = mat.ind_begin_(row); - size_type *ind_end = mat.ind_end_(row); - value_type *nz = mat.nz_begin_(row); - - value_type absFloor = abs_f(max_floor); - - for (; ind != ind_end; ++ind, ++nz) { + // Replace each element with colSum - element + for (size_type row = 0; row != M; ++row) { - value_type v = mat.nzb_[*ind] - *nz; - - if (abs_f(v) < absFloor) - v = max_floor; - - *nz = v; - } - } - } + if (mat.nnzr_[row] == 0) + continue; - //-------------------------------------------------------------------------------- - /** - * Copies the values of the non-zeros of B into A, only where A and B have a - * non-zero in the same location. The other non-zeros of A are left unchanged. 
- */ - template - static void assignNoAlloc(SM& A, const STR3F& B, typename SM::size_type s) - { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - } + size_type *ind = mat.ind_begin_(row); + size_type *ind_end = mat.ind_end_(row); + value_type *nz = mat.nz_begin_(row); - typedef typename SM::size_type size_type; - typedef typename STR3F::col_index_type col_index_type; - typedef typename SM::value_type value_type; + value_type absFloor = abs_f(max_floor); - const size_type M = A.nRows(); + for (; ind != ind_end; ++ind, ++nz) { - for (size_type row = 0; row != M; ++row) { + value_type v = mat.nzb_[*ind] - *nz; - size_type *ind_a = A.ind_begin_(row); - col_index_type *ind_b = B.ind_begin_(row); - value_type *nz_a = A.nz_begin_(row); - value_type *nz_a_end = A.nz_end_(row); - value_type *nz_b = B.nz_begin_(s, row); - value_type *nz_b_end = B.nz_end_(s, row); - - while (nz_a != nz_a_end && nz_b != nz_b_end) - if (*ind_a == (size_type) *ind_b) { - *nz_a = *nz_b; - ++ind_a; ++ind_b; - ++nz_a; ++nz_b; - } else if (*ind_a < (size_type) *ind_b) { - ++ind_a; ++nz_a; - } else if ((size_type) *ind_b < *ind_a) { - ++ind_b; ++nz_b; - } + if (abs_f(v) < absFloor) + v = max_floor; + + *nz = v; } } + } - //-------------------------------------------------------------------------------- - /** - * Computes the sum in log space of A and B, where A is a slice of a STR3F - * and B is a SM. The operation is: - * a = log(exp(a) + exp(b)), - * where a is a non-zero of slice s of A, and b is the corresponding non-zero of B - * (in the same location). - * - * The number of non-zeros of in A is unchanged, and if the absolute value - * of a non-zero would fall below minFloor, it is replaced by minFloor. - * A and B need to have the same dimensions. - */ - template - static void logSumNoAlloc(STR3F& A, typename SM::size_type s, - const SM& B, typename SM::value_type minFloor =0) + //-------------------------------------------------------------------------------- + /** + * Copies the values of the non-zeros of B into A, only where A and B have a + * non-zero in the same location. The other non-zeros of A are left unchanged. 
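In log space, the prod(col)/element operation that LBP_piPrime performs becomes a column sum minus the element, clamped to max_floor so that no stored zero is introduced. A dense toy sketch of that step, ignoring the sparse bookkeeping; piPrimeDense is an illustrative name, not the library API.

#include <cassert>
#include <cmath>
#include <vector>

// Dense sketch of the piPrime step: out[r][c] = colSum[c] - in[r][c],
// clamped to max_floor when the result would be ~0 in log space.
void piPrimeDense(std::vector<std::vector<float>> &m, float max_floor) {
  assert(max_floor < 0);
  const size_t rows = m.size(), cols = m.empty() ? 0 : m[0].size();
  std::vector<float> colSum(cols, 0.0f);
  for (size_t r = 0; r != rows; ++r)       // column sums (log-space products)
    for (size_t c = 0; c != cols; ++c)
      colSum[c] += m[r][c];
  for (size_t r = 0; r != rows; ++r)
    for (size_t c = 0; c != cols; ++c) {
      float v = colSum[c] - m[r][c];       // prod(col)/element, in log space
      if (std::fabs(v) < std::fabs(max_floor))
        v = max_floor;                     // never introduce a stored zero
      m[r][c] = v;
    }
}

int main() {
  std::vector<std::vector<float>> m = {{-1.0f, -2.0f}, {-3.0f, -4.0f}};
  piPrimeDense(m, -1e-6f);
  // Column 0 sums to -4, so the entry that was -1 becomes -4 - (-1) = -3.
  assert(std::fabs(m[0][0] - (-3.0f)) < 1e-6f);
  return 0;
}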
+ */ + template + static void assignNoAlloc(SM &A, const STR3F &B, typename SM::size_type s) { { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); - } - - typedef typename STR3F::col_index_type col_index_type; - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - - nupic::Exp exp_f; - nupic::Log log_f; - nupic::Log1p log1p_f; - nupic::Abs abs_f; - - size_type M = (size_type) A.nRows(); - value_type minExp = log_f(std::numeric_limits::epsilon()); - - if (nupic::Epsilon < minFloor) { - - for (size_type row = 0; row != M; ++row) { - - col_index_type *ind_a = A.ind_begin_(row); - size_type *ind_b = B.ind_begin_(row); - size_type *ind_b_end = B.ind_end_(row); - value_type *nz_a = A.nz_begin_(s, row); - value_type *nz_b = B.nz_begin_(row); - - while (ind_b != ind_b_end) { - if ((size_type) *ind_a == *ind_b) { - value_type a = *nz_a; - value_type b = *nz_b; - if (a < b) - std::swap(a,b); - value_type d = b - a; - if (d >= minExp) { - a += log1p_f(exp_f(d)); - if (abs_f(a) < minFloor) - a = minFloor; - *nz_a = a; - } else { - *nz_a = a; - } - ++ind_a; ++nz_a; - ++ind_b; ++nz_b; - } else if ((size_type) *ind_a < *ind_b) { - ++ind_a; ++nz_a; - } - } - } + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + } - } else { // minFloor <= nupic::Epsilon, i.e. essentially minFloor == 0 - - for (size_type row = 0; row != M; ++row) { - - col_index_type *ind_a = A.ind_begin_(row); - size_type *ind_b = B.ind_begin_(row); - size_type *ind_b_end = B.ind_end_(row); - value_type *nz_a = A.nz_begin_(s, row); - value_type *nz_b = B.nz_begin_(row); - - while (ind_b != ind_b_end) { - if ((size_type) *ind_a == *ind_b) { - value_type a = *nz_a; - value_type b = *nz_b; - if (a < b) - std::swap(a,b); - value_type d = b - a; - if (d >= minExp) { - a += log1p_f(exp_f(d)); - *nz_a = a; - } else { - *nz_a = a; - } - ++ind_a; ++nz_a; - ++ind_b; ++nz_b; - } else if ((size_type) *ind_a < *ind_b) { - ++ind_a; ++nz_a; - } - } + typedef typename SM::size_type size_type; + typedef typename STR3F::col_index_type col_index_type; + typedef typename SM::value_type value_type; + + const size_type M = A.nRows(); + + for (size_type row = 0; row != M; ++row) { + + size_type *ind_a = A.ind_begin_(row); + col_index_type *ind_b = B.ind_begin_(row); + value_type *nz_a = A.nz_begin_(row); + value_type *nz_a_end = A.nz_end_(row); + value_type *nz_b = B.nz_begin_(s, row); + value_type *nz_b_end = B.nz_end_(s, row); + + while (nz_a != nz_a_end && nz_b != nz_b_end) + if (*ind_a == (size_type)*ind_b) { + *nz_a = *nz_b; + ++ind_a; + ++ind_b; + ++nz_a; + ++nz_b; + } else if (*ind_a < (size_type)*ind_b) { + ++ind_a; + ++nz_a; + } else if ((size_type)*ind_b < *ind_a) { + ++ind_b; + ++nz_b; } - } } + } - //-------------------------------------------------------------------------------- - /** - * Computes the diff in log space of A and B, where A is a slice of a STR3F - * and B is a SM. The operation is: - * a = log(exp(a) - exp(b)), - * where a is a non-zero of slice s of A, and b is the corresponding non-zero of B - * (in the same location). - * - * The number of non-zeros of in A is unchanged, and if the absolute value - * of a non-zero would fall below minFloor, it is replaced by minFloor. - * A and B need to have the same dimensions. 
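All of these NoAlloc routines, assignNoAlloc included, walk two sorted index arrays in lock-step and act only where both operands have a non-zero in the same column. Stripped of the matrix internals, the merge pattern looks roughly like this (standalone sketch, illustrative names):

#include <cassert>
#include <cstddef>
#include <vector>

// Two-pointer walk over sorted column indices: wherever A and B share a
// column, copy B's value into A (the scalar analogue of assignNoAlloc).
void assignMatched(const std::vector<std::size_t> &indA, std::vector<float> &nzA,
                   const std::vector<std::size_t> &indB, const std::vector<float> &nzB) {
  std::size_t ia = 0, ib = 0;
  while (ia != indA.size() && ib != indB.size()) {
    if (indA[ia] == indB[ib]) {        // matched non-zero: overwrite A's value
      nzA[ia] = nzB[ib];
      ++ia; ++ib;
    } else if (indA[ia] < indB[ib]) {  // A-only non-zero: leave it unchanged
      ++ia;
    } else {                           // B-only non-zero: skipped by these routines
      ++ib;
    }
  }
}

int main() {
  std::vector<std::size_t> indA = {0, 2, 5}, indB = {2, 3, 5};
  std::vector<float> nzA = {1.f, 2.f, 3.f};
  const std::vector<float> nzB = {20.f, 30.f, 50.f};
  assignMatched(indA, nzA, indB, nzB);
  assert(nzA[0] == 1.f && nzA[1] == 20.f && nzA[2] == 50.f);
  return 0;
}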
- */ - template - static void logDiffNoAlloc(STR3F& A, typename SM::size_type s, - const SM& B, typename SM::value_type minFloor =0) + //-------------------------------------------------------------------------------- + /** + * Computes the sum in log space of A and B, where A is a slice of a STR3F + * and B is a SM. The operation is: + * a = log(exp(a) + exp(b)), + * where a is a non-zero of slice s of A, and b is the corresponding non-zero + * of B (in the same location). + * + * The number of non-zeros of in A is unchanged, and if the absolute value + * of a non-zero would fall below minFloor, it is replaced by minFloor. + * A and B need to have the same dimensions. + */ + template + static void logSumNoAlloc(STR3F &A, typename SM::size_type s, const SM &B, + typename SM::value_type minFloor = 0) { { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); - } + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); + } - typedef typename STR3F::col_index_type col_index_type; - typedef typename SM::size_type size_type; - typedef typename SM::value_type value_type; - - // Important to use double here, because in float, there can be - // cancelation in log(1 - exp(b-a)), when a is very close to b. - nupic::Exp exp_f; - nupic::Log log_f; - nupic::Log1p log1p_f; - nupic::Abs abs_f; - - size_type M = (size_type) A.nRows(); - value_type minExp = log_f(std::numeric_limits::epsilon()); - - // Two log values that are this close to each other should generate a difference - // of 0, which is -inf in log space, which we want to avoid - double minDiff = -std::numeric_limits::epsilon(); - value_type logOfZero = ((value_type)-1.0)/std::numeric_limits::epsilon(); - - if (nupic::Epsilon < minFloor) { - - for (size_type row = 0; row != M; ++row) { - - col_index_type *ind_a = A.ind_begin_(row); - size_type *ind_b = B.ind_begin_(row); - size_type *ind_b_end = B.ind_end_(row); - value_type *nz_a = A.nz_begin_(s, row); - value_type *nz_b = B.nz_begin_(row); - - while (ind_b != ind_b_end) { - if ((size_type) *ind_a == *ind_b) { - double a = *nz_a; - double b = *nz_b; - NTA_ASSERT(a >= b); - double d = b - a; - // If the values are too close to each other, generate log of 0 manually - // We know d <= 0 at this point. - if (d >= minDiff) - *nz_a = logOfZero; - else if (d >= minExp) { - a += log1p_f(-exp_f(d)); - if (abs_f(a) < minFloor) - a = minFloor; - *nz_a = (value_type) a; - } else { - *nz_a = (value_type) a; - } - ++ind_a; ++nz_a; - ++ind_b; ++nz_b; - } else if ((size_type) *ind_a < *ind_b) { - ++ind_a; ++nz_a; - } - } - } + typedef typename STR3F::col_index_type col_index_type; + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; - } else { // minFloor <= nupic::Epsilon, i.e. essentially minFloor == 0 - - for (size_type row = 0; row != M; ++row) { - - col_index_type *ind_a = A.ind_begin_(row); - size_type *ind_b = B.ind_begin_(row); - size_type *ind_b_end = B.ind_end_(row); - value_type *nz_a = A.nz_begin_(s, row); - value_type *nz_b = B.nz_begin_(row); - - while (ind_b != ind_b_end) { - if ((size_type) *ind_a == *ind_b) { - double a = *nz_a; - double b = *nz_b; - NTA_ASSERT(a >= b); - double d = b - a; - // If the values are too close to each other, generate log of 0 manually - // We know d <= 0 at this point. 
- if (d >= minDiff) - *nz_a = logOfZero; - else if (d >= minExp) { - a += log1p_f(-exp_f(d)); - *nz_a = (value_type) a; - } else { - *nz_a = (value_type) a; - } - ++ind_a; ++nz_a; - ++ind_b; ++nz_b; - } else if ((size_type) *ind_a < *ind_b) { - ++ind_a; ++nz_a; - } - } - } - } - } + nupic::Exp exp_f; + nupic::Log log_f; + nupic::Log1p log1p_f; + nupic::Abs abs_f; - //-------------------------------------------------------------------------------- - /** - * Updates A only where A and B have a non-zero in the same location, by copying - * the corresponding non-zero of B. The other non-zeros of A are left unchanged. - */ - template - static void assignNoAlloc(STR3F& A, typename STR3F::size_type slice_a, - const STR3F& B, typename STR3F::size_type slice_b) - { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - } + size_type M = (size_type)A.nRows(); + value_type minExp = log_f(std::numeric_limits::epsilon()); - typedef typename STR3F::row_index_type row_index_type; - typedef typename STR3F::col_index_type col_index_type; - typedef typename STR3F::value_type value_type; + if (nupic::Epsilon < minFloor) { - for (row_index_type row = 0; row != A.nRows(); ++row) { + for (size_type row = 0; row != M; ++row) { col_index_type *ind_a = A.ind_begin_(row); - col_index_type *ind_b = B.ind_begin_(row); - value_type *nz_a = A.nz_begin_(slice_a, row); - value_type *nz_a_end = A.nz_end_(slice_a, row); - value_type *nz_b = B.nz_begin_(slice_b, row); - value_type *nz_b_end = B.nz_end_(slice_b, row); - - while (nz_a != nz_a_end && nz_b != nz_b_end) - if (*ind_a == *ind_b) { - *nz_a = *nz_b; - ++ind_a; ++ind_b; - ++nz_a; ++nz_b; - } else if (*ind_a < *ind_b) { - ++ind_a; ++nz_a; - } else { - ++ind_b; ++nz_b; - } - } - } + size_type *ind_b = B.ind_begin_(row); + size_type *ind_b_end = B.ind_end_(row); + value_type *nz_a = A.nz_begin_(s, row); + value_type *nz_b = B.nz_begin_(row); - //-------------------------------------------------------------------------------- - /** - * Computes the sum in log space of A and B, where A is a slice of a STR3F - * and B is another slice of another STR3F. The operation is: - * a = log(exp(a) + exp(b)), - * where a is a non-zero of slice s of A, and b is the corresponding non-zero of B - * (in the same location). - * - * The number of non-zeros of in A is unchanged, and if the absolute value - * of a non-zero would fall below minFloor, it is replaced by minFloor. - * A and B need to have the same dimensions. 
- */ - template - static void logSumNoAlloc(STR3F& A, typename STR3F::size_type slice_a, - const STR3F& B, typename STR3F::size_type slice_b, - typename STR3F::value_type minFloor =0) - { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); + while (ind_b != ind_b_end) { + if ((size_type)*ind_a == *ind_b) { + value_type a = *nz_a; + value_type b = *nz_b; + if (a < b) + std::swap(a, b); + value_type d = b - a; + if (d >= minExp) { + a += log1p_f(exp_f(d)); + if (abs_f(a) < minFloor) + a = minFloor; + *nz_a = a; + } else { + *nz_a = a; + } + ++ind_a; + ++nz_a; + ++ind_b; + ++nz_b; + } else if ((size_type)*ind_a < *ind_b) { + ++ind_a; + ++nz_a; + } + } } - typedef typename STR3F::col_index_type col_index_type; - typedef typename STR3F::size_type size_type; - typedef typename STR3F::value_type value_type; - - nupic::Exp exp_f; - nupic::Log log_f; - nupic::Log1p log1p_f; - nupic::Abs abs_f; - - size_type M = (size_type) A.nRows(); - value_type minExp = log_f(std::numeric_limits::epsilon()); + } else { // minFloor <= nupic::Epsilon, i.e. essentially minFloor == 0 for (size_type row = 0; row != M; ++row) { col_index_type *ind_a = A.ind_begin_(row); - col_index_type *ind_b = B.ind_begin_(row); - col_index_type *ind_b_end = B.ind_end_(row); - value_type *nz_a = A.nz_begin_(slice_a, row); - value_type *nz_b = B.nz_begin_(slice_b, row); + size_type *ind_b = B.ind_begin_(row); + size_type *ind_b_end = B.ind_end_(row); + value_type *nz_a = A.nz_begin_(s, row); + value_type *nz_b = B.nz_begin_(row); while (ind_b != ind_b_end) { - if (*ind_a == *ind_b) { + if ((size_type)*ind_a == *ind_b) { value_type a = *nz_a; value_type b = *nz_b; if (a < b) - std::swap(a,b); + std::swap(a, b); value_type d = b - a; if (d >= minExp) { a += log1p_f(exp_f(d)); - if (minFloor > 0 && abs_f(a) < minFloor) - a = minFloor; *nz_a = a; } else { *nz_a = a; } - NTA_ASSERT(!A.isZero_(*nz_a)); - ++ind_a; ++nz_a; - ++ind_b; ++nz_b; - } else if ( *ind_a < *ind_b) { - ++ind_a; ++nz_a; + ++ind_a; + ++nz_a; + ++ind_b; + ++nz_b; + } else if ((size_type)*ind_a < *ind_b) { + ++ind_a; + ++nz_a; } } } } + } - //-------------------------------------------------------------------------------- - /** - * Computes the diff in log space of A and B, where A is a slice of a STR3F - * and B is another slice of another STR3F. The operation is: - * a = log(exp(a) - exp(b)), - * where a is a non-zero of slice s of A, and b is the corresponding non-zero of B - * (in the same location). - * - * The number of non-zeros of in A is unchanged, and if the absolute value - * of a non-zero would fall below minFloor, it is replaced by minFloor. - * A and B need to have the same dimensions. - */ - template - static void logDiffNoAlloc(STR3F& A, typename STR3F::size_type slice_a, - const STR3F& B, typename STR3F::size_type slice_b, - typename STR3F::value_type minFloor =0) + //-------------------------------------------------------------------------------- + /** + * Computes the diff in log space of A and B, where A is a slice of a STR3F + * and B is a SM. The operation is: + * a = log(exp(a) - exp(b)), + * where a is a non-zero of slice s of A, and b is the corresponding non-zero + * of B (in the same location). + * + * The number of non-zeros of in A is unchanged, and if the absolute value + * of a non-zero would fall below minFloor, it is replaced by minFloor. + * A and B need to have the same dimensions. 
+ */ + template + static void logDiffNoAlloc(STR3F &A, typename SM::size_type s, const SM &B, + typename SM::value_type minFloor = 0) { { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); - } + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); + } + + typedef typename STR3F::col_index_type col_index_type; + typedef typename SM::size_type size_type; + typedef typename SM::value_type value_type; + + // Important to use double here, because in float, there can be + // cancelation in log(1 - exp(b-a)), when a is very close to b. + nupic::Exp exp_f; + nupic::Log log_f; + nupic::Log1p log1p_f; + nupic::Abs abs_f; + + size_type M = (size_type)A.nRows(); + value_type minExp = log_f(std::numeric_limits::epsilon()); - typedef typename STR3F::col_index_type col_index_type; - typedef typename STR3F::size_type size_type; - typedef typename STR3F::value_type value_type; - - // Important to use double here, because in float, there can be - // cancelation in log(1 - exp(b-a)), when a is very close to b. - nupic::Exp exp_f; - nupic::Log log_f; - nupic::Log1p log1p_f; - nupic::Abs abs_f; - - size_type M = (size_type) A.nRows(); - value_type minExp = log_f(std::numeric_limits::epsilon()); - - // Two log values that are this close to each other should generate a difference - // of 0, which is -inf in log space, which we want to avoid - double minDiff = -std::numeric_limits::epsilon(); - value_type logOfZero = ((value_type)-1.0)/std::numeric_limits::epsilon(); + // Two log values that are this close to each other should generate a + // difference + // of 0, which is -inf in log space, which we want to avoid + double minDiff = -std::numeric_limits::epsilon(); + value_type logOfZero = + ((value_type)-1.0) / std::numeric_limits::epsilon(); + + if (nupic::Epsilon < minFloor) { for (size_type row = 0; row != M; ++row) { col_index_type *ind_a = A.ind_begin_(row); - col_index_type *ind_b = B.ind_begin_(row); - col_index_type *ind_b_end = B.ind_end_(row); - value_type *nz_a = A.nz_begin_(slice_a, row); - value_type *nz_b = B.nz_begin_(slice_b, row); + size_type *ind_b = B.ind_begin_(row); + size_type *ind_b_end = B.ind_end_(row); + value_type *nz_a = A.nz_begin_(s, row); + value_type *nz_b = B.nz_begin_(row); while (ind_b != ind_b_end) { - if (*ind_a == *ind_b) { + if ((size_type)*ind_a == *ind_b) { double a = *nz_a; double b = *nz_b; NTA_ASSERT(a >= b); double d = b - a; - // If the values are too close to each other, generate log of 0 manually - // We know d <= 0 at this point. + // If the values are too close to each other, generate log of 0 + // manually We know d <= 0 at this point. if (d >= minDiff) - *nz_a = logOfZero; + *nz_a = logOfZero; else if (d >= minExp) { a += log1p_f(-exp_f(d)); - if (minFloor > 0 && abs_f(a) < minFloor) + if (abs_f(a) < minFloor) a = minFloor; - *nz_a = (value_type) a; + *nz_a = (value_type)a; } else { - *nz_a = (value_type) a; + *nz_a = (value_type)a; } - NTA_ASSERT(!A.isZero_(*nz_a)); - ++ind_a; ++nz_a; - ++ind_b; ++nz_b; - } else if (*ind_a < *ind_b) { - ++ind_a; ++nz_a; + ++ind_a; + ++nz_a; + ++ind_b; + ++nz_b; + } else if ((size_type)*ind_a < *ind_b) { + ++ind_a; + ++nz_a; } } } - } - //-------------------------------------------------------------------------------- - // END LBP - //-------------------------------------------------------------------------------- + } else { // minFloor <= nupic::Epsilon, i.e. 
essentially minFloor == 0 - //-------------------------------------------------------------------------------- - }; // End class SparseMatrixAlgorithms + for (size_type row = 0; row != M; ++row) { + + col_index_type *ind_a = A.ind_begin_(row); + size_type *ind_b = B.ind_begin_(row); + size_type *ind_b_end = B.ind_end_(row); + value_type *nz_a = A.nz_begin_(s, row); + value_type *nz_b = B.nz_begin_(row); + + while (ind_b != ind_b_end) { + if ((size_type)*ind_a == *ind_b) { + double a = *nz_a; + double b = *nz_b; + NTA_ASSERT(a >= b); + double d = b - a; + // If the values are too close to each other, generate log of 0 + // manually We know d <= 0 at this point. + if (d >= minDiff) + *nz_a = logOfZero; + else if (d >= minExp) { + a += log1p_f(-exp_f(d)); + *nz_a = (value_type)a; + } else { + *nz_a = (value_type)a; + } + ++ind_a; + ++nz_a; + ++ind_b; + ++nz_b; + } else if ((size_type)*ind_a < *ind_b) { + ++ind_a; + ++nz_a; + } + } + } + } + } - //-------------------------------------------------------------------------------- - //-------------------------------------------------------------------------------- - // SUM OF LOGS AND DIFF OF LOGS APPROXIMATIONS - // - // This section contains two classes that allow to approximate addition of numbers - // that are logarithms, in an efficient manner. The operations approximated are: - // z = log(exp(x) + exp(y) and z = log(exp(x) - exp(y). - // There are many pitfalls and tricks that make the implementation not trivial if - // it is to be efficient and accurate. - // - //-------------------------------------------------------------------------------- - // IMPLEMENTATION NOTES: - // ==================== - // - // How to add/subtract in log domain: - // ================================= - // - // We want to compute: log(exp(a) + exp(b)) or log(exp(a) - exp(b)). The first step is: - // - // log(exp(a) + exp(b)) = log(exp(a) * (1 + exp(b-a))) - // = a + log(1 + exp(b-a)) - // - // double logSum(double a, double b) - // { - // if (a < b) - // swap(a,b); - // if (!(a >= b)) - // fprintf(stderr, "ERROR: logSum: %f %f\n", a, b); - // assert(a >= b); - // - // return a + log1p(exp(b-a)); - // } - // - // If your numerical library doesn't have the log1p(x) function, replace it with log(1+x). - // - // This step saves a call to exp() and relies on log1p which is hopefully implemented - // efficiently, maybe even in hardware. However, we are going to speed-up this operation - // further by approximation. - // - // Why a > b: - // ========= - // - // a needs to be > b, and that's because otherwise, in float, the exponential of a large - // number can overflow the range of float. In double, it doesn't matter till the number - // becomes even larger. So, mathematically, it doesn't matter that a > b, but in terms of - // floating point overflow, it matters. - // - // Making logSum/logDiff fast (10X): - // ================================ - // - // In LBP, we approximate the following two functions using a table. These functions, - // both of the form z = f(x,y) are slow: they do a lot of work, and they have multiple - // ifs that can cause the processor pipeline to stall. 
- // - // value_type logSum(value_type a, value_type b) - // { - // if (a < b) - // swap(a,b); - // value_type d = b - a; - // if (d >= minExp) { - // a += log1p(exp(d)); - // if (fabs(a) < minFloor) - // a = minFloor; - // } - // return a; - // } - // - // value_type logDiff(value_type a, value_type b) - // { - // assert(b < a); - // double d = b - a; - // if (d >= minDiff) - // a = logOfZero; - // else if (d >= minExp) { - // a += log1p(-exp(d)); - // if (fabs(a) < minFloor) - // a = minFloor; - // } - // return a; - // } - // - // The domain of f is known to be: [-14,14] x [-14,14]. - // The first idea was to create a 2D table to store a step function - // as an approximation of f. However, this ended up taking too much memory and not - // being precise enough. Looking at the graphs of z, it became apparent that for both - // logSum and logDiff: 1) logSum is symmetric w.r.t. the line y = x, but logDiff is - // defined only for y < x, 2) any two slices for fixed y0 and y1 are related by a - // simple translation. Taking advantage of this last observation, an identity can be - // derived relating two slices. This can be used to store a step approximation for only - // one slice, discretized very finely, while computing the value z = f(x,y) for any other - // slice using the identity. Here is the derivation of the identity: - // - // f(x,y) = log(exp(x) + exp(y)) - // f(x+u,y+u) = log(exp(x+u) + exp(y+u)) - // = log(exp(x) + exp(y)) + u - // = f(x,y) + u - // f(x,y) = f(x+u,y+u) - u - // - // Choosing u = -y: - // - // f(x,y) = f(x-y,0) + y - // - // The same holds for log(exp(x) - exp(y)). - // - // So, we approximate f(a,0) using a step function on [lb,ub], each value in a table - // storing the value of f(k*delta + lb), k integer positive and delta small. Note that - // f(a,0) is log(exp(a) + 1), continuous, with derivative the sigmoid 1/(1+exp(-a)). - // This derivative is strictly positive, with range ]0,1[ (f is monotonically increasing). - // For logDiff(a,0) = log(exp(x) - 1), there is a discontinuity at zero which is harder - // to approximate. - // - // Errors: - // ====== - // The absolute error due to the step-wise approximation is directly bounded above by - // the size of the step delta, because the derivative is in ]0,1[ on the whole domain of f. - // In practice, this is true in double, but in float, the experimental absolute error is - // slightly above delta. - // - // The relative error can be reduced by increasing the size of the table, and experimentally, - // it reduces to arbitrary level as the number of entries in the table is increased. Again, - // logDiff has higher relative error because of its singularity, but it can still be - // approximated reasonably within less than 50 MB. - // - // Implementation tricks: - // ===================== - // - // - when computing the table, use a double precision step, otherwise the steps - // can get out of step, resulting in a wrong approximation - // - when computing the index in the table for a given (x,y), use float for the step, - // otherwise the speed is halved. This float precision is enough when evaluating the index. - // //-------------------------------------------------------------------------------- /** - * Approximates a sum of logs operation we have in LBP with a table for speed. - * A table of values (step function) is computed once, then accessed on subsequent - * calls. The values are used directly, rather than using interpolation. 
- * - * TODO: use asymptotes to reduce the size of the table + * Updates A only where A and B have a non-zero in the same location, by + * copying the corresponding non-zero of B. The other non-zeros of A are left + * unchanged. */ - class LogSumApprox { - - public: - // the values are stored in float to save space (4 bytes per value) - typedef float value_type; - - private: - value_type min_a, max_a; // bounds of domain - value_type step_a; // step along side of domain - static std::vector table; // table of approximant step function - - // Various constants used in the function - value_type minFloor, minExp, logOfZero; - double minDiff; - bool trace; - - LogSumApprox(const LogSumApprox&); - LogSumApprox& operator=(const LogSumApprox&); - - // Portable exp and log functions - nupic::Exp exp_f; - nupic::Log1p log1p_f; - - public: - //-------------------------------------------------------------------------------- - /** - * n is the size of the square domain on which we approximate, and the other - * parameters are the bounds of the domain. n*n values are computed once - * and stored in a table. - * - * Errors: - * ====== - * On darwin86: - * Sum of logs table: 20000000 -28 28 5.6e-06 76MB - * abs=4.03906228374e-06 rel=6.57921289158e-05 - * - * On Windows: - * Sum of logs table: 20000000 -28 28 2.8e-006 76MB - * abs=3.41339533527e-006 rel=0.00028832192106 - */ - inline LogSumApprox(int n_ = 5000000, - value_type min_a_ =-28, value_type max_a_ =28, - bool trace_ =false) - : min_a(min_a_), max_a(max_a_), - step_a((value_type)((max_a - min_a)/n_)), - minFloor((value_type)(1.1 * 1e-6)), - minExp(logf(std::numeric_limits::epsilon())), - logOfZero(((value_type)-1.0)/std::numeric_limits::epsilon()), - minDiff(-std::numeric_limits::epsilon()), - trace(trace_) + template + static void assignNoAlloc(STR3F &A, typename STR3F::size_type slice_a, + const STR3F &B, typename STR3F::size_type slice_b) { { - { // Pre-conditions - NTA_ASSERT(min_a < max_a); - NTA_ASSERT(0 < step_a); - } // End pre-conditions - - if (table.empty()) { - table.resize(n_); - compute_table(); - } - - if (trace) - std::cout << "Sum of logs table: " << table.size() << " " - << min_a << " " << max_a << " " << step_a << " " - << (4*table.size()/(1024*1024)) << "MB" << std::endl; + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); } - //-------------------------------------------------------------------------------- - /** - * This function computes the slice of sum_of_logs(a,b) for b = 0, but when - * the function will be called later, it will always be called with a and b - * far from minFloor. The net result of that is that all the values in the table - * should be greater than minFloor, which we enforce here. We replicate the code - * of sum_of_logs_f (modified for minFloor) because sum_of_logs_f has asserts that - * wouldn't pass for b = 0. The step needs to be a double, or it gets out of sync - * to compute the intervals of the table. 
- */ - inline void compute_table() - { - double a = min_a, step = step_a; - - for (size_t ia = 0; ia < table.size(); ++ia, a += step) - table[ia] = sum_of_logs_f((value_type)a,(value_type)0); + typedef typename STR3F::row_index_type row_index_type; + typedef typename STR3F::col_index_type col_index_type; + typedef typename STR3F::value_type value_type; + + for (row_index_type row = 0; row != A.nRows(); ++row) { + + col_index_type *ind_a = A.ind_begin_(row); + col_index_type *ind_b = B.ind_begin_(row); + value_type *nz_a = A.nz_begin_(slice_a, row); + value_type *nz_a_end = A.nz_end_(slice_a, row); + value_type *nz_b = B.nz_begin_(slice_b, row); + value_type *nz_b_end = B.nz_end_(slice_b, row); + + while (nz_a != nz_a_end && nz_b != nz_b_end) + if (*ind_a == *ind_b) { + *nz_a = *nz_b; + ++ind_a; + ++ind_b; + ++nz_a; + ++nz_b; + } else if (*ind_a < *ind_b) { + ++ind_a; + ++nz_a; + } else { + ++ind_b; + ++nz_b; + } } + } - //-------------------------------------------------------------------------------- - /** - * Computes the index corresponding to a,b in the table. - */ - inline int index(value_type a, value_type b) const + //-------------------------------------------------------------------------------- + /** + * Computes the sum in log space of A and B, where A is a slice of a STR3F + * and B is another slice of another STR3F. The operation is: + * a = log(exp(a) + exp(b)), + * where a is a non-zero of slice s of A, and b is the corresponding non-zero + * of B (in the same location). + * + * The number of non-zeros of in A is unchanged, and if the absolute value + * of a non-zero would fall below minFloor, it is replaced by minFloor. + * A and B need to have the same dimensions. + */ + template + static void logSumNoAlloc(STR3F &A, typename STR3F::size_type slice_a, + const STR3F &B, typename STR3F::size_type slice_b, + typename STR3F::value_type minFloor = 0) { { - return (int)((a - (b + min_a)) / step_a); + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); } - //-------------------------------------------------------------------------------- - private: - /** - * This is the exact function we approximate. It's of the form z = f(a,b). - * Note that in real applications, a and b will be far from minFloor. - * Doesn't check pre-conditions, so it can be called for b = 0 when - * building the table of values. - */ - inline value_type sum_of_logs_f(value_type a, value_type b) const - { - if (a < b) - std::swap(a,b); - value_type d = b - a; - if (d >= minExp) { - a += (value_type) log1p_f(exp_f(d)); - if (fabs(a) < minFloor) - a = minFloor; - } + typedef typename STR3F::col_index_type col_index_type; + typedef typename STR3F::size_type size_type; + typedef typename STR3F::value_type value_type; - return a; - } - - public: - //-------------------------------------------------------------------------------- - /** - * Will crash if (a,b) outside of domain (faster). - * Note that in real applications, a and b will be far from minFloor (see - * pre-conditions). - */ - inline value_type fast_sum_of_logs(value_type a, value_type b) const - { - { // Pre-conditions - NTA_ASSERT(minFloor <= fabs(a)) << a; - NTA_ASSERT(minFloor <= fabs(b)) << b; - } // End pre-conditions - - value_type val = table[index(a,b)] + b; - - if (fabs(val) < minFloor) - val = minFloor; - - return val; - } - - //-------------------------------------------------------------------------------- - /** - * Works with illimited range, but slower. 
- * Note that in real applications, a and b will be far from minFloor (see - * pre-conditions). - */ - inline value_type sum_of_logs(value_type a, value_type b) const - { - { // Pre-conditions - NTA_ASSERT(minFloor <= fabs(a)) << a; - NTA_ASSERT(minFloor <= fabs(b)) << b; - } // End pre-conditions - - value_type val; - - if (-14 <= a && a < 14 && -14 <= b && b < 14) - val = fast_sum_of_logs(a, b); - else - val = sum_of_logs_f(a, b); - - return val; - } - - //-------------------------------------------------------------------------------- - /** - * Computes sum of logs between a STR3F and a SM. This is a piece of the LBP - * algorithm. Values closer to zero than minFloor are replaced by minFloor. - */ - template - inline void logSum(STR3F& A, typename SM::size_type s, const SM& B) - { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - NTA_ASSERT(nupic::Epsilon < minFloor); - } + nupic::Exp exp_f; + nupic::Log log_f; + nupic::Log1p log1p_f; + nupic::Abs abs_f; - typedef typename STR3F::col_index_type col_index_type; - typedef typename SM::size_type size_type; - - size_type M = (size_type) A.nRows(); + size_type M = (size_type)A.nRows(); + value_type minExp = log_f(std::numeric_limits::epsilon()); - for (size_type row = 0; row != M; ++row) { + for (size_type row = 0; row != M; ++row) { - col_index_type *ind_a = A.ind_begin_(row); - const size_type *ind_b = B.row_nz_index_begin(row); - const size_type *ind_b_end = B.row_nz_index_end(row); - value_type *nz_a = A.nz_begin_(s, row); - const value_type *nz_b = B.row_nz_value_begin(row); + col_index_type *ind_a = A.ind_begin_(row); + col_index_type *ind_b = B.ind_begin_(row); + col_index_type *ind_b_end = B.ind_end_(row); + value_type *nz_a = A.nz_begin_(slice_a, row); + value_type *nz_b = B.nz_begin_(slice_b, row); - while (ind_b != ind_b_end) { - if ((size_type) *ind_a == *ind_b) { - *nz_a = sum_of_logs(*nz_a, *nz_b); - ++ind_a; ++nz_a; - ++ind_b; ++nz_b; - } else if ((size_type) *ind_a < *ind_b) { - ++ind_a; ++nz_a; + while (ind_b != ind_b_end) { + if (*ind_a == *ind_b) { + value_type a = *nz_a; + value_type b = *nz_b; + if (a < b) + std::swap(a, b); + value_type d = b - a; + if (d >= minExp) { + a += log1p_f(exp_f(d)); + if (minFloor > 0 && abs_f(a) < minFloor) + a = minFloor; + *nz_a = a; + } else { + *nz_a = a; } + NTA_ASSERT(!A.isZero_(*nz_a)); + ++ind_a; + ++nz_a; + ++ind_b; + ++nz_b; + } else if (*ind_a < *ind_b) { + ++ind_a; + ++nz_a; } } } + } - //-------------------------------------------------------------------------------- - /** - * Computes sum of logs between a SM and a STR3F. Also a piece of LBP. - * Values closer to zero than minFloor are replaced by minFloor. - */ - template - inline void fastLogSum(STR3F& A, typename SM::size_type s, const SM& B) + //-------------------------------------------------------------------------------- + /** + * Computes the diff in log space of A and B, where A is a slice of a STR3F + * and B is another slice of another STR3F. The operation is: + * a = log(exp(a) - exp(b)), + * where a is a non-zero of slice s of A, and b is the corresponding non-zero + * of B (in the same location). + * + * The number of non-zeros of in A is unchanged, and if the absolute value + * of a non-zero would fall below minFloor, it is replaced by minFloor. + * A and B need to have the same dimensions. 
+ */ + template + static void logDiffNoAlloc(STR3F &A, typename STR3F::size_type slice_a, + const STR3F &B, typename STR3F::size_type slice_b, + typename STR3F::value_type minFloor = 0) { { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - NTA_ASSERT(nupic::Epsilon < minFloor); - } + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + NTA_ASSERT(minFloor == 0 || nupic::Epsilon < minFloor); + } - typedef typename STR3F::col_index_type col_index_type; - typedef typename SM::size_type size_type; - - size_type M = (size_type) A.nRows(); + typedef typename STR3F::col_index_type col_index_type; + typedef typename STR3F::size_type size_type; + typedef typename STR3F::value_type value_type; - for (size_type row = 0; row != M; ++row) { - - col_index_type *ind_a = A.ind_begin_(row); - const size_type *ind_b = B.row_nz_index_begin(row); - const size_type *ind_b_end = B.row_nz_index_end(row); - value_type *nz_a = A.nz_begin_(s, row); - const value_type *nz_b = B.row_nz_value_begin(row); - - while (ind_b != ind_b_end) { - if ((size_type) *ind_a == *ind_b) { - *nz_a = fast_sum_of_logs(*nz_a, *nz_b); - ++ind_a; ++nz_a; - ++ind_b; ++nz_b; - } else if ((size_type) *ind_a < *ind_b) { - ++ind_a; ++nz_a; + // Important to use double here, because in float, there can be + // cancelation in log(1 - exp(b-a)), when a is very close to b. + nupic::Exp exp_f; + nupic::Log log_f; + nupic::Log1p log1p_f; + nupic::Abs abs_f; + + size_type M = (size_type)A.nRows(); + value_type minExp = log_f(std::numeric_limits::epsilon()); + + // Two log values that are this close to each other should generate a + // difference + // of 0, which is -inf in log space, which we want to avoid + double minDiff = -std::numeric_limits::epsilon(); + value_type logOfZero = + ((value_type)-1.0) / std::numeric_limits::epsilon(); + + for (size_type row = 0; row != M; ++row) { + + col_index_type *ind_a = A.ind_begin_(row); + col_index_type *ind_b = B.ind_begin_(row); + col_index_type *ind_b_end = B.ind_end_(row); + value_type *nz_a = A.nz_begin_(slice_a, row); + value_type *nz_b = B.nz_begin_(slice_b, row); + + while (ind_b != ind_b_end) { + if (*ind_a == *ind_b) { + double a = *nz_a; + double b = *nz_b; + NTA_ASSERT(a >= b); + double d = b - a; + // If the values are too close to each other, generate log of 0 + // manually We know d <= 0 at this point. + if (d >= minDiff) + *nz_a = logOfZero; + else if (d >= minExp) { + a += log1p_f(-exp_f(d)); + if (minFloor > 0 && abs_f(a) < minFloor) + a = minFloor; + *nz_a = (value_type)a; + } else { + *nz_a = (value_type)a; } + NTA_ASSERT(!A.isZero_(*nz_a)); + ++ind_a; + ++nz_a; + ++ind_b; + ++nz_b; + } else if (*ind_a < *ind_b) { + ++ind_a; + ++nz_a; } } } - }; // End LogSumApprox + } + + //-------------------------------------------------------------------------------- + // END LBP + //-------------------------------------------------------------------------------- //-------------------------------------------------------------------------------- +}; // End class SparseMatrixAlgorithms + +//-------------------------------------------------------------------------------- +//-------------------------------------------------------------------------------- +// SUM OF LOGS AND DIFF OF LOGS APPROXIMATIONS +// +// This section contains two classes that allow to approximate addition of +// numbers that are logarithms, in an efficient manner. The operations +// approximated are: z = log(exp(x) + exp(y) and z = log(exp(x) - exp(y). 
There +// are many pitfalls and tricks that make the implementation not trivial if it +// is to be efficient and accurate. +// +//-------------------------------------------------------------------------------- +// IMPLEMENTATION NOTES: +// ==================== +// +// How to add/subtract in log domain: +// ================================= +// +// We want to compute: log(exp(a) + exp(b)) or log(exp(a) - exp(b)). The first +// step is: +// +// log(exp(a) + exp(b)) = log(exp(a) * (1 + exp(b-a))) +// = a + log(1 + exp(b-a)) +// +// double logSum(double a, double b) +// { +// if (a < b) +// swap(a,b); +// if (!(a >= b)) +// fprintf(stderr, "ERROR: logSum: %f %f\n", a, b); +// assert(a >= b); +// +// return a + log1p(exp(b-a)); +// } +// +// If your numerical library doesn't have the log1p(x) function, replace it with +// log(1+x). +// +// This step saves a call to exp() and relies on log1p which is hopefully +// implemented efficiently, maybe even in hardware. However, we are going to +// speed-up this operation further by approximation. +// +// Why a > b: +// ========= +// +// a needs to be > b, and that's because otherwise, in float, the exponential of +// a large number can overflow the range of float. In double, it doesn't matter +// till the number becomes even larger. So, mathematically, it doesn't matter +// that a > b, but in terms of floating point overflow, it matters. +// +// Making logSum/logDiff fast (10X): +// ================================ +// +// In LBP, we approximate the following two functions using a table. These +// functions, both of the form z = f(x,y) are slow: they do a lot of work, and +// they have multiple ifs that can cause the processor pipeline to stall. +// +// value_type logSum(value_type a, value_type b) +// { +// if (a < b) +// swap(a,b); +// value_type d = b - a; +// if (d >= minExp) { +// a += log1p(exp(d)); +// if (fabs(a) < minFloor) +// a = minFloor; +// } +// return a; +// } +// +// value_type logDiff(value_type a, value_type b) +// { +// assert(b < a); +// double d = b - a; +// if (d >= minDiff) +// a = logOfZero; +// else if (d >= minExp) { +// a += log1p(-exp(d)); +// if (fabs(a) < minFloor) +// a = minFloor; +// } +// return a; +// } +// +// The domain of f is known to be: [-14,14] x [-14,14]. +// The first idea was to create a 2D table to store a step function +// as an approximation of f. However, this ended up taking too much memory and +// not being precise enough. Looking at the graphs of z, it became apparent that +// for both logSum and logDiff: 1) logSum is symmetric w.r.t. the line y = x, +// but logDiff is defined only for y < x, 2) any two slices for fixed y0 and y1 +// are related by a simple translation. Taking advantage of this last +// observation, an identity can be derived relating two slices. This can be used +// to store a step approximation for only one slice, discretized very finely, +// while computing the value z = f(x,y) for any other slice using the identity. +// Here is the derivation of the identity: +// +// f(x,y) = log(exp(x) + exp(y)) +// f(x+u,y+u) = log(exp(x+u) + exp(y+u)) +// = log(exp(x) + exp(y)) + u +// = f(x,y) + u +// f(x,y) = f(x+u,y+u) - u +// +// Choosing u = -y: +// +// f(x,y) = f(x-y,0) + y +// +// The same holds for log(exp(x) - exp(y)). +// +// So, we approximate f(a,0) using a step function on [lb,ub], each value in a +// table storing the value of f(k*delta + lb), k integer positive and delta +// small. Note that f(a,0) is log(exp(a) + 1), continuous, with derivative the +// sigmoid 1/(1+exp(-a)). 
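The identity derived above, f(x,y) = f(x-y,0) + y, is what allows a single, finely discretized 1-D slice to replace the full 2-D table. A quick standalone check of the identity, and of its analogue for the difference:

#include <cassert>
#include <cmath>

// f(x,y) = log(exp(x) + exp(y)); g(x,y) = log(exp(x) - exp(y)), y < x.
static double f(double x, double y) { return std::log(std::exp(x) + std::exp(y)); }
static double g(double x, double y) { return std::log(std::exp(x) - std::exp(y)); }

int main() {
  const double x = 3.7, y = -1.2;
  // Shifting both arguments by u shifts the result by u; with u = -y the
  // second argument becomes 0, so only the slice f(.,0) needs to be tabulated.
  assert(std::fabs(f(x, y) - (f(x - y, 0.0) + y)) < 1e-12);
  assert(std::fabs(g(x, y) - (g(x - y, 0.0) + y)) < 1e-12);
  return 0;
}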
This derivative is strictly positive, with range ]0,1[ +// (f is monotonically increasing). For logDiff(a,0) = log(exp(x) - 1), there is +// a discontinuity at zero which is harder to approximate. +// +// Errors: +// ====== +// The absolute error due to the step-wise approximation is directly bounded +// above by the size of the step delta, because the derivative is in ]0,1[ on +// the whole domain of f. In practice, this is true in double, but in float, the +// experimental absolute error is slightly above delta. +// +// The relative error can be reduced by increasing the size of the table, and +// experimentally, it reduces to arbitrary level as the number of entries in the +// table is increased. Again, logDiff has higher relative error because of its +// singularity, but it can still be approximated reasonably within less than 50 +// MB. +// +// Implementation tricks: +// ===================== +// +// - when computing the table, use a double precision step, otherwise the steps +// can get out of step, resulting in a wrong approximation +// - when computing the index in the table for a given (x,y), use float for the +// step, otherwise the speed is halved. This float precision is enough when +// evaluating the index. +// +//-------------------------------------------------------------------------------- +/** + * Approximates a sum of logs operation we have in LBP with a table for speed. + * A table of values (step function) is computed once, then accessed on + * subsequent calls. The values are used directly, rather than using + * interpolation. + * + * TODO: use asymptotes to reduce the size of the table + */ +class LogSumApprox { + +public: + // the values are stored in float to save space (4 bytes per value) + typedef float value_type; + +private: + value_type min_a, max_a; // bounds of domain + value_type step_a; // step along side of domain + static std::vector table; // table of approximant step function + + // Various constants used in the function + value_type minFloor, minExp, logOfZero; + double minDiff; + bool trace; + + LogSumApprox(const LogSumApprox &); + LogSumApprox &operator=(const LogSumApprox &); + + // Portable exp and log functions + nupic::Exp exp_f; + nupic::Log1p log1p_f; + +public: + //-------------------------------------------------------------------------------- /** - * See comments in LogSumApprox. This is the same idea, except that the function - * approximated is a diff of logs, rather than a sum of logs. + * n is the size of the square domain on which we approximate, and the other + * parameters are the bounds of the domain. n*n values are computed once + * and stored in a table. 
* * Errors: * ====== * On darwin86: - * Diff of logs table: 20000000 1e-06 28 1.4e-06 76MB - * abs=2.56909073832e-05 rel=0.000589275477819 + * Sum of logs table: 20000000 -28 28 5.6e-06 76MB + * abs=4.03906228374e-06 rel=6.57921289158e-05 * * On Windows: - * Diff of logs table: 20000000 1e-006 28 1.4e-006 76MB - * abs=2.56909073832e-005 rel=0.000589275477819 - */ - class LogDiffApprox { - - public: - // The values are stored in float to save space (4 bytes per value) - typedef float value_type; - - private: - value_type min_a, max_a; // bounds of domain - value_type step_a; // step along side of domain - static std::vector table; // the approximating values themselves + * Sum of logs table: 20000000 -28 28 2.8e-006 76MB + * abs=3.41339533527e-006 rel=0.00028832192106 + */ + inline LogSumApprox(int n_ = 5000000, value_type min_a_ = -28, + value_type max_a_ = 28, bool trace_ = false) + : min_a(min_a_), max_a(max_a_), + step_a((value_type)((max_a - min_a) / n_)), + minFloor((value_type)(1.1 * 1e-6)), + minExp(logf(std::numeric_limits::epsilon())), + logOfZero(((value_type)-1.0) / + std::numeric_limits::epsilon()), + minDiff(-std::numeric_limits::epsilon()), trace(trace_) { + { // Pre-conditions + NTA_ASSERT(min_a < max_a); + NTA_ASSERT(0 < step_a); + } // End pre-conditions + + if (table.empty()) { + table.resize(n_); + compute_table(); + } - // Various constants used in the function - value_type minFloor, minExp, logOfZero; - double minDiff; - bool trace; + if (trace) + std::cout << "Sum of logs table: " << table.size() << " " << min_a << " " + << max_a << " " << step_a << " " + << (4 * table.size() / (1024 * 1024)) << "MB" << std::endl; + } - LogDiffApprox(const LogDiffApprox&); - LogDiffApprox& operator=(const LogDiffApprox&); + //-------------------------------------------------------------------------------- + /** + * This function computes the slice of sum_of_logs(a,b) for b = 0, but when + * the function will be called later, it will always be called with a and b + * far from minFloor. The net result of that is that all the values in the + * table should be greater than minFloor, which we enforce here. We replicate + * the code of sum_of_logs_f (modified for minFloor) because sum_of_logs_f has + * asserts that wouldn't pass for b = 0. The step needs to be a double, or it + * gets out of sync to compute the intervals of the table. + */ + inline void compute_table() { + double a = min_a, step = step_a; - nupic::Exp exp_f; - nupic::Log1p log1p_f; + for (size_t ia = 0; ia < table.size(); ++ia, a += step) + table[ia] = sum_of_logs_f((value_type)a, (value_type)0); + } - public: - //-------------------------------------------------------------------------------- - /** - * n is the size of the square domain on which we approximate, and the other - * parameters are the bounds of the domain. n*n values are computed once - * and stored in a table. 
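To tie the quoted trace lines back to the constructor parameters: with n = 20,000,000 entries over [-28, 28] the step is 56 / 2e7 = 2.8e-6 (the Windows figure; the darwin86 line quotes 5.6e-06), and a float-valued table of that size occupies 4 * 2e7 / 2^20, roughly 76 MB, matching the printed "76MB". Because the derivative of f(a,0) lies in ]0,1[, the step also bounds the absolute approximation error, as the notes above point out. A small check of that arithmetic:

#include <cassert>
#include <cmath>
#include <cstddef>

int main() {
  const std::size_t n = 20000000;       // table entries used in the quoted trace
  const double min_a = -28.0, max_a = 28.0;
  const double step = (max_a - min_a) / n;              // discretization step
  const double megabytes = 4.0 * n / (1024.0 * 1024.0); // 4 bytes per float entry
  assert(std::fabs(step - 2.8e-6) < 1e-12);
  assert(megabytes > 76.0 && megabytes < 77.0);         // the "76MB" in the trace
  return 0;
}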
- * - * TODO: use asymptotes to reduce the size of the table - */ - inline LogDiffApprox(int n_ = 5000000, - value_type min_a_ =1e-10, value_type max_a_ =28, - bool trace_ =false) - : min_a(min_a_), max_a(max_a_), - step_a((value_type)((max_a - min_a)/n_)), - minFloor((value_type)(1.1 * 1e-6)), - minExp(logf(std::numeric_limits::epsilon())), - logOfZero(((value_type)-1.0)/std::numeric_limits::epsilon()), - minDiff(-std::numeric_limits::epsilon()), - trace(trace_) - { - { // Pre-conditions - NTA_ASSERT(min_a < max_a); - NTA_ASSERT(0 < step_a); - } // End pre-conditions - - if (table.empty()) { - table.resize(n_); - compute_table(); - } + //-------------------------------------------------------------------------------- + /** + * Computes the index corresponding to a,b in the table. + */ + inline int index(value_type a, value_type b) const { + return (int)((a - (b + min_a)) / step_a); + } - if (trace) - std::cout << "Diff of logs table: " << table.size() << " " - << min_a << " " << max_a << " " << step_a << " " - << (4*table.size()/(1024*1024)) << "MB" << std::endl; + //-------------------------------------------------------------------------------- +private: + /** + * This is the exact function we approximate. It's of the form z = f(a,b). + * Note that in real applications, a and b will be far from minFloor. + * Doesn't check pre-conditions, so it can be called for b = 0 when + * building the table of values. + */ + inline value_type sum_of_logs_f(value_type a, value_type b) const { + if (a < b) + std::swap(a, b); + value_type d = b - a; + if (d >= minExp) { + a += (value_type)log1p_f(exp_f(d)); + if (fabs(a) < minFloor) + a = minFloor; } - //-------------------------------------------------------------------------------- - /** - * See comments for LogSumApprox::compute_table(). - */ - inline void compute_table() - { - double a = min_a, step = step_a; - - for (size_t ia = 0; ia < table.size(); ++ia, a += step) - table[ia] = diff_of_logs_f((value_type)a,(value_type)0); - } + return a; + } + +public: + //-------------------------------------------------------------------------------- + /** + * Will crash if (a,b) outside of domain (faster). + * Note that in real applications, a and b will be far from minFloor (see + * pre-conditions). + */ + inline value_type fast_sum_of_logs(value_type a, value_type b) const { + { // Pre-conditions + NTA_ASSERT(minFloor <= fabs(a)) << a; + NTA_ASSERT(minFloor <= fabs(b)) << b; + } // End pre-conditions + + value_type val = table[index(a, b)] + b; + + if (fabs(val) < minFloor) + val = minFloor; + + return val; + } + + //-------------------------------------------------------------------------------- + /** + * Works with illimited range, but slower. + * Note that in real applications, a and b will be far from minFloor (see + * pre-conditions). + */ + inline value_type sum_of_logs(value_type a, value_type b) const { + { // Pre-conditions + NTA_ASSERT(minFloor <= fabs(a)) << a; + NTA_ASSERT(minFloor <= fabs(b)) << b; + } // End pre-conditions + + value_type val; - //-------------------------------------------------------------------------------- - /** - * Computes the index corresponding to a,b in the table. - */ - inline int index(value_type a, value_type b) const + if (-14 <= a && a < 14 && -14 <= b && b < 14) + val = fast_sum_of_logs(a, b); + else + val = sum_of_logs_f(a, b); + + return val; + } + + //-------------------------------------------------------------------------------- + /** + * Computes sum of logs between a STR3F and a SM. 
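Putting the identity and the table together: fast_sum_of_logs looks up the pre-computed slice at index (a - (b + min_a)) / step_a and adds b back, exactly as the shifted identity prescribes. A standalone miniature of that lookup with a much smaller table; all names here are illustrative:

#include <cassert>
#include <cmath>
#include <vector>

// Miniature of the table-based approximation: tabulate f(x, 0) = log(exp(x) + 1)
// on [min_a, max_a], then approximate f(a, b) as table[(a - b - min_a)/step] + b.
int main() {
  const float min_a = -28.0f, max_a = 28.0f;
  const int n = 1000000;                        // far smaller than the real table
  const double step = (max_a - min_a) / n;      // double, so the grid stays in sync
  std::vector<float> table(n);
  double x = min_a;
  for (int i = 0; i < n; ++i, x += step)
    table[i] = (float)std::log1p(std::exp(x));  // f(x, 0)

  const float a = 2.5f, b = -4.0f;
  const int idx = (int)((a - (b + min_a)) / (float)step); // float index math, as in the code
  const float approx = table[idx] + b;
  const float exact = std::log(std::exp(a) + std::exp(b));
  assert(std::fabs(approx - exact) < 1e-3f);    // coarse table, coarse tolerance
  return 0;
}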
This is a piece of the LBP + * algorithm. Values closer to zero than minFloor are replaced by minFloor. + */ + template + inline void logSum(STR3F &A, typename SM::size_type s, const SM &B) { { - return (int)((a - (b + min_a)) / step_a); + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + NTA_ASSERT(nupic::Epsilon < minFloor); } - private: - //-------------------------------------------------------------------------------- - /** - * This is the exact function we approximate. It's of the form z = f(a,b). - * Note that in real applications, a and b will be far from minFloor (see - * pre-conditions). Doesn't check pre-conditions so we can call it with b = 0 - * when building the table of values. - */ - inline value_type diff_of_logs_f(value_type a, value_type b) const - { - double d = b - a; - if (d >= minDiff) - a = logOfZero; - else if (d >= minExp) { - a += (value_type) log1p_f(-exp_f(d)); - if (fabs(a) < minFloor) - a = minFloor; + typedef typename STR3F::col_index_type col_index_type; + typedef typename SM::size_type size_type; + + size_type M = (size_type)A.nRows(); + + for (size_type row = 0; row != M; ++row) { + + col_index_type *ind_a = A.ind_begin_(row); + const size_type *ind_b = B.row_nz_index_begin(row); + const size_type *ind_b_end = B.row_nz_index_end(row); + value_type *nz_a = A.nz_begin_(s, row); + const value_type *nz_b = B.row_nz_value_begin(row); + + while (ind_b != ind_b_end) { + if ((size_type)*ind_a == *ind_b) { + *nz_a = sum_of_logs(*nz_a, *nz_b); + ++ind_a; + ++nz_a; + ++ind_b; + ++nz_b; + } else if ((size_type)*ind_a < *ind_b) { + ++ind_a; + ++nz_a; + } } + } + } - return a; - } - - public: - //-------------------------------------------------------------------------------- - /** - * Will crash if (a,b) outside of domain (faster). - * Note that in real applications, a and b will be far from minFloor (see - * pre-conditions). - */ - inline value_type fast_diff_of_logs(value_type a, value_type b) const - { - { // Pre-conditions - NTA_ASSERT(b < a); - NTA_ASSERT(minFloor <= fabs(a)) << a; - NTA_ASSERT(minFloor <= fabs(b)) << b; - } // End pre-conditions - - value_type val = table[index(a,b)] + b; - - if (fabs(val) < minFloor) - val = minFloor; - - return val; - } - - //-------------------------------------------------------------------------------- - /** - * Will fall back on calling function if (a,b) outside domain (slower). - * Note that in real applications, a and b will be far from minFloor (see - * pre-conditions). - */ - inline value_type diff_of_logs(value_type a, value_type b) const - { - { // Pre-conditions - NTA_ASSERT(b < a); - NTA_ASSERT(minFloor <= fabs(a)) << a; - NTA_ASSERT(minFloor <= fabs(b)) << b; - } // End pre-conditions - - value_type val; - - if (-14 <= a && a < 14 && -14 <= b && b < 14) - val = fast_diff_of_logs(a, b); - else - val = diff_of_logs_f(a, b); - - return val; - } - - //-------------------------------------------------------------------------------- - /** - * Computes diff of logs between a STR3F and a SM. This is a piece of the LBP - * algorithm. Values closer to zero than minFloor are replaced by minFloor. - */ - template - inline void logDiff(STR3F& A, typename SM::size_type s, const SM& B) + //-------------------------------------------------------------------------------- + /** + * Computes sum of logs between a SM and a STR3F. Also a piece of LBP. + * Values closer to zero than minFloor are replaced by minFloor. 
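+ *
+ * As a rough sketch of the underlying scalar operation (hypothetical values,
+ * not taken from this header): for a = logf(x) and b = logf(y), both far
+ * from minFloor,
+ *
+ *   LogSumApprox approx;                      // shared static table, built once
+ *   value_type s = approx.sum_of_logs(a, b);  // approximately logf(x + y)
+ *
+ * fast_sum_of_logs(a, b) skips the fallback to the exact function and is only
+ * safe while a and b stay inside the tabulated domain; fastLogSum below applies
+ * it element-wise to the matching non-zeros of A and B.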
+ */ + template + inline void fastLogSum(STR3F &A, typename SM::size_type s, const SM &B) { { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - NTA_ASSERT(nupic::Epsilon < minFloor); + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + NTA_ASSERT(nupic::Epsilon < minFloor); + } + + typedef typename STR3F::col_index_type col_index_type; + typedef typename SM::size_type size_type; + + size_type M = (size_type)A.nRows(); + + for (size_type row = 0; row != M; ++row) { + + col_index_type *ind_a = A.ind_begin_(row); + const size_type *ind_b = B.row_nz_index_begin(row); + const size_type *ind_b_end = B.row_nz_index_end(row); + value_type *nz_a = A.nz_begin_(s, row); + const value_type *nz_b = B.row_nz_value_begin(row); + + while (ind_b != ind_b_end) { + if ((size_type)*ind_a == *ind_b) { + *nz_a = fast_sum_of_logs(*nz_a, *nz_b); + ++ind_a; + ++nz_a; + ++ind_b; + ++nz_b; + } else if ((size_type)*ind_a < *ind_b) { + ++ind_a; + ++nz_a; + } } + } + } +}; // End LogSumApprox - typedef typename STR3F::col_index_type col_index_type; - typedef typename SM::size_type size_type; - - size_type M = (size_type) A.nRows(); +//-------------------------------------------------------------------------------- +/** + * See comments in LogSumApprox. This is the same idea, except that the function + * approximated is a diff of logs, rather than a sum of logs. + * + * Errors: + * ====== + * On darwin86: + * Diff of logs table: 20000000 1e-06 28 1.4e-06 76MB + * abs=2.56909073832e-05 rel=0.000589275477819 + * + * On Windows: + * Diff of logs table: 20000000 1e-006 28 1.4e-006 76MB + * abs=2.56909073832e-005 rel=0.000589275477819 + */ +class LogDiffApprox { - for (size_type row = 0; row != M; ++row) { +public: + // The values are stored in float to save space (4 bytes per value) + typedef float value_type; - col_index_type *ind_a = A.ind_begin_(row); - const size_type *ind_b = B.row_nz_index_begin(row); - const size_type *ind_b_end = B.row_nz_index_end(row); - value_type *nz_a = A.nz_begin_(s, row); - const value_type *nz_b = B.row_nz_value_begin(row); +private: + value_type min_a, max_a; // bounds of domain + value_type step_a; // step along side of domain + static std::vector table; // the approximating values themselves - while (ind_b != ind_b_end) { - if ((size_type) *ind_a == *ind_b) { - *nz_a = diff_of_logs(*nz_a, *nz_b); - ++ind_a; ++nz_a; - ++ind_b; ++nz_b; - } else if ((size_type) *ind_a < *ind_b) { - ++ind_a; ++nz_a; - } - } - } + // Various constants used in the function + value_type minFloor, minExp, logOfZero; + double minDiff; + bool trace; + + LogDiffApprox(const LogDiffApprox &); + LogDiffApprox &operator=(const LogDiffApprox &); + + nupic::Exp exp_f; + nupic::Log1p log1p_f; + +public: + //-------------------------------------------------------------------------------- + /** + * n is the size of the square domain on which we approximate, and the other + * parameters are the bounds of the domain. n*n values are computed once + * and stored in a table. 
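+ *
+ * (Derived from index() and fast_diff_of_logs() below: the table holds
+ * diff_of_logs_f(a, 0) on a uniform grid over [min_a, max_a], and a lookup
+ * for (a, b) uses the identity
+ * diff_of_logs_f(a, b) = diff_of_logs_f(a - b, 0) + b,
+ * i.e. table[(a - b - min_a) / step_a] + b.)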
+ * + * TODO: use asymptotes to reduce the size of the table + */ + inline LogDiffApprox(int n_ = 5000000, value_type min_a_ = 1e-10, + value_type max_a_ = 28, bool trace_ = false) + : min_a(min_a_), max_a(max_a_), + step_a((value_type)((max_a - min_a) / n_)), + minFloor((value_type)(1.1 * 1e-6)), + minExp(logf(std::numeric_limits::epsilon())), + logOfZero(((value_type)-1.0) / + std::numeric_limits::epsilon()), + minDiff(-std::numeric_limits::epsilon()), trace(trace_) { + { // Pre-conditions + NTA_ASSERT(min_a < max_a); + NTA_ASSERT(0 < step_a); + } // End pre-conditions + + if (table.empty()) { + table.resize(n_); + compute_table(); } - //-------------------------------------------------------------------------------- - /** - * Computes diff of logs between a SM and a STR3F. Also a part of LBP. - * Values closer to zero than minFloor are replaced by minFloor. - */ - template - inline void fastLogDiff(STR3F& A, typename SM::size_type s, const SM& B) - { - { - NTA_ASSERT(A.nRows() == B.nRows()); - NTA_ASSERT(A.nCols() == B.nCols()); - NTA_ASSERT(nupic::Epsilon < minFloor); - } + if (trace) + std::cout << "Diff of logs table: " << table.size() << " " << min_a << " " + << max_a << " " << step_a << " " + << (4 * table.size() / (1024 * 1024)) << "MB" << std::endl; + } - typedef typename STR3F::col_index_type col_index_type; - typedef typename SM::size_type size_type; - - size_type M = (size_type) A.nRows(); + //-------------------------------------------------------------------------------- + /** + * See comments for LogSumApprox::compute_table(). + */ + inline void compute_table() { + double a = min_a, step = step_a; - for (size_type row = 0; row != M; ++row) { + for (size_t ia = 0; ia < table.size(); ++ia, a += step) + table[ia] = diff_of_logs_f((value_type)a, (value_type)0); + } - col_index_type *ind_a = A.ind_begin_(row); - const size_type *ind_b = B.row_nz_index_begin(row); - const size_type *ind_b_end = B.row_nz_index_end(row); - value_type *nz_a = A.nz_begin_(s, row); - const value_type *nz_b = B.row_nz_value_begin(row); + //-------------------------------------------------------------------------------- + /** + * Computes the index corresponding to a,b in the table. + */ + inline int index(value_type a, value_type b) const { + return (int)((a - (b + min_a)) / step_a); + } - while (ind_b != ind_b_end) { - if ((size_type) *ind_a == *ind_b) { - *nz_a = fast_diff_of_logs(*nz_a, *nz_b); - ++ind_a; ++nz_a; - ++ind_b; ++nz_b; - } else if ((size_type) *ind_a < *ind_b) { - ++ind_a; ++nz_a; - } +private: + //-------------------------------------------------------------------------------- + /** + * This is the exact function we approximate. It's of the form z = f(a,b). + * Note that in real applications, a and b will be far from minFloor (see + * pre-conditions). Doesn't check pre-conditions so we can call it with b = 0 + * when building the table of values. + */ + inline value_type diff_of_logs_f(value_type a, value_type b) const { + double d = b - a; + if (d >= minDiff) + a = logOfZero; + else if (d >= minExp) { + a += (value_type)log1p_f(-exp_f(d)); + if (fabs(a) < minFloor) + a = minFloor; + } + + return a; + } + +public: + //-------------------------------------------------------------------------------- + /** + * Will crash if (a,b) outside of domain (faster). + * Note that in real applications, a and b will be far from minFloor (see + * pre-conditions). 
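+ *
+ * Rough sketch of the quantity being approximated (hypothetical values, not
+ * taken from this header): for a = logf(x) and b = logf(y) with y < x,
+ *
+ *   LogDiffApprox approx;                      // shared static table, built once
+ *   value_type d = approx.diff_of_logs(a, b);  // approximately logf(x - y),
+ *                                              // i.e. a + log1p(-exp(b - a))
+ *
+ * with the result floored at minFloor, and pinned to logOfZero when x and y
+ * are nearly equal.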
+ */ + inline value_type fast_diff_of_logs(value_type a, value_type b) const { + { // Pre-conditions + NTA_ASSERT(b < a); + NTA_ASSERT(minFloor <= fabs(a)) << a; + NTA_ASSERT(minFloor <= fabs(b)) << b; + } // End pre-conditions + + value_type val = table[index(a, b)] + b; + + if (fabs(val) < minFloor) + val = minFloor; + + return val; + } + + //-------------------------------------------------------------------------------- + /** + * Will fall back on calling function if (a,b) outside domain (slower). + * Note that in real applications, a and b will be far from minFloor (see + * pre-conditions). + */ + inline value_type diff_of_logs(value_type a, value_type b) const { + { // Pre-conditions + NTA_ASSERT(b < a); + NTA_ASSERT(minFloor <= fabs(a)) << a; + NTA_ASSERT(minFloor <= fabs(b)) << b; + } // End pre-conditions + + value_type val; + + if (-14 <= a && a < 14 && -14 <= b && b < 14) + val = fast_diff_of_logs(a, b); + else + val = diff_of_logs_f(a, b); + + return val; + } + + //-------------------------------------------------------------------------------- + /** + * Computes diff of logs between a STR3F and a SM. This is a piece of the LBP + * algorithm. Values closer to zero than minFloor are replaced by minFloor. + */ + template + inline void logDiff(STR3F &A, typename SM::size_type s, const SM &B) { + { + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + NTA_ASSERT(nupic::Epsilon < minFloor); + } + + typedef typename STR3F::col_index_type col_index_type; + typedef typename SM::size_type size_type; + + size_type M = (size_type)A.nRows(); + + for (size_type row = 0; row != M; ++row) { + + col_index_type *ind_a = A.ind_begin_(row); + const size_type *ind_b = B.row_nz_index_begin(row); + const size_type *ind_b_end = B.row_nz_index_end(row); + value_type *nz_a = A.nz_begin_(s, row); + const value_type *nz_b = B.row_nz_value_begin(row); + + while (ind_b != ind_b_end) { + if ((size_type)*ind_a == *ind_b) { + *nz_a = diff_of_logs(*nz_a, *nz_b); + ++ind_a; + ++nz_a; + ++ind_b; + ++nz_b; + } else if ((size_type)*ind_a < *ind_b) { + ++ind_a; + ++nz_a; } } } - }; // End LogDiffApprox - + } + //-------------------------------------------------------------------------------- + /** + * Computes diff of logs between a SM and a STR3F. Also a part of LBP. + * Values closer to zero than minFloor are replaced by minFloor. 
+ */ + template + inline void fastLogDiff(STR3F &A, typename SM::size_type s, const SM &B) { + { + NTA_ASSERT(A.nRows() == B.nRows()); + NTA_ASSERT(A.nCols() == B.nCols()); + NTA_ASSERT(nupic::Epsilon < minFloor); + } - + typedef typename STR3F::col_index_type col_index_type; + typedef typename SM::size_type size_type; + + size_type M = (size_type)A.nRows(); + + for (size_type row = 0; row != M; ++row) { + + col_index_type *ind_a = A.ind_begin_(row); + const size_type *ind_b = B.row_nz_index_begin(row); + const size_type *ind_b_end = B.row_nz_index_end(row); + value_type *nz_a = A.nz_begin_(s, row); + const value_type *nz_b = B.row_nz_value_begin(row); + + while (ind_b != ind_b_end) { + if ((size_type)*ind_a == *ind_b) { + *nz_a = fast_diff_of_logs(*nz_a, *nz_b); + ++ind_a; + ++nz_a; + ++ind_b; + ++nz_b; + } else if ((size_type)*ind_a < *ind_b) { + ++ind_a; + ++nz_a; + } + } + } + } +}; // End LogDiffApprox }; // End namespace nupic diff --git a/src/nupic/math/SparseMatrixConnections.cpp b/src/nupic/math/SparseMatrixConnections.cpp index 17a11d858c..abc3114133 100644 --- a/src/nupic/math/SparseMatrixConnections.cpp +++ b/src/nupic/math/SparseMatrixConnections.cpp @@ -28,126 +28,106 @@ using namespace nupic; -SparseMatrixConnections::SparseMatrixConnections( - UInt32 numCells, UInt32 numInputs) - : SegmentMatrixAdapter>(numCells, - numInputs) -{} - - -void SparseMatrixConnections::computeActivity( - const UInt32* activeInputs_begin, const UInt32* activeInputs_end, - Int32* overlaps_begin) const -{ - matrix.rightVecSumAtNZSparse( - activeInputs_begin, activeInputs_end, - overlaps_begin); +SparseMatrixConnections::SparseMatrixConnections(UInt32 numCells, + UInt32 numInputs) + : SegmentMatrixAdapter>( + numCells, numInputs) {} + +void SparseMatrixConnections::computeActivity(const UInt32 *activeInputs_begin, + const UInt32 *activeInputs_end, + Int32 *overlaps_begin) const { + matrix.rightVecSumAtNZSparse(activeInputs_begin, activeInputs_end, + overlaps_begin); } -void SparseMatrixConnections::computeActivity( - const UInt32* activeInputs_begin, const UInt32* activeInputs_end, - Real32 permanenceThreshold, Int32* overlaps_begin) const -{ - matrix.rightVecSumAtNZGteThresholdSparse( - activeInputs_begin, activeInputs_end, - overlaps_begin, permanenceThreshold); +void SparseMatrixConnections::computeActivity(const UInt32 *activeInputs_begin, + const UInt32 *activeInputs_end, + Real32 permanenceThreshold, + Int32 *overlaps_begin) const { + matrix.rightVecSumAtNZGteThresholdSparse(activeInputs_begin, activeInputs_end, + overlaps_begin, permanenceThreshold); } -void SparseMatrixConnections::adjustSynapses( - const UInt32* segments_begin, const UInt32* segments_end, - const UInt32* activeInputs_begin, const UInt32* activeInputs_end, - Real32 activePermanenceDelta, Real32 inactivePermanenceDelta) -{ - matrix.incrementNonZerosOnOuter( - segments_begin, segments_end, - activeInputs_begin, activeInputs_end, - activePermanenceDelta); +void SparseMatrixConnections::adjustSynapses(const UInt32 *segments_begin, + const UInt32 *segments_end, + const UInt32 *activeInputs_begin, + const UInt32 *activeInputs_end, + Real32 activePermanenceDelta, + Real32 inactivePermanenceDelta) { + matrix.incrementNonZerosOnOuter(segments_begin, segments_end, + activeInputs_begin, activeInputs_end, + activePermanenceDelta); matrix.incrementNonZerosOnRowsExcludingCols( - segments_begin, segments_end, - activeInputs_begin, activeInputs_end, - inactivePermanenceDelta); + segments_begin, segments_end, activeInputs_begin, 
activeInputs_end, + inactivePermanenceDelta); clipPermanences(segments_begin, segments_end); } void SparseMatrixConnections::adjustActiveSynapses( - const UInt32* segments_begin, const UInt32* segments_end, - const UInt32* activeInputs_begin, const UInt32* activeInputs_end, - Real32 permanenceDelta) -{ - matrix.incrementNonZerosOnOuter( - segments_begin, segments_end, - activeInputs_begin, activeInputs_end, - permanenceDelta); + const UInt32 *segments_begin, const UInt32 *segments_end, + const UInt32 *activeInputs_begin, const UInt32 *activeInputs_end, + Real32 permanenceDelta) { + matrix.incrementNonZerosOnOuter(segments_begin, segments_end, + activeInputs_begin, activeInputs_end, + permanenceDelta); clipPermanences(segments_begin, segments_end); } void SparseMatrixConnections::adjustInactiveSynapses( - const UInt32* segments_begin, const UInt32* segments_end, - const UInt32* activeInputs_begin, const UInt32* activeInputs_end, - Real32 permanenceDelta) -{ + const UInt32 *segments_begin, const UInt32 *segments_end, + const UInt32 *activeInputs_begin, const UInt32 *activeInputs_end, + Real32 permanenceDelta) { matrix.incrementNonZerosOnRowsExcludingCols( - segments_begin, segments_end, - activeInputs_begin, activeInputs_end, - permanenceDelta); + segments_begin, segments_end, activeInputs_begin, activeInputs_end, + permanenceDelta); clipPermanences(segments_begin, segments_end); } -void SparseMatrixConnections::growSynapses( - const UInt32* segments_begin, const UInt32* segments_end, - const UInt32* inputs_begin, const UInt32* inputs_end, - Real32 initialPermanence) -{ - matrix.setZerosOnOuter( - segments_begin, segments_end, - inputs_begin, inputs_end, - initialPermanence); +void SparseMatrixConnections::growSynapses(const UInt32 *segments_begin, + const UInt32 *segments_end, + const UInt32 *inputs_begin, + const UInt32 *inputs_end, + Real32 initialPermanence) { + matrix.setZerosOnOuter(segments_begin, segments_end, inputs_begin, inputs_end, + initialPermanence); } void SparseMatrixConnections::growSynapsesToSample( - const UInt32* segments_begin, const UInt32* segments_end, - const UInt32* inputs_begin, const UInt32* inputs_end, - Int32 sampleSize, Real32 initialPermanence, nupic::Random& rng) -{ - matrix.setRandomZerosOnOuter( - segments_begin, segments_end, - inputs_begin, inputs_end, - sampleSize, initialPermanence, rng); + const UInt32 *segments_begin, const UInt32 *segments_end, + const UInt32 *inputs_begin, const UInt32 *inputs_end, Int32 sampleSize, + Real32 initialPermanence, nupic::Random &rng) { + matrix.setRandomZerosOnOuter(segments_begin, segments_end, inputs_begin, + inputs_end, sampleSize, initialPermanence, rng); clipPermanences(segments_begin, segments_end); } void SparseMatrixConnections::growSynapsesToSample( - const UInt32* segments_begin, const UInt32* segments_end, - const UInt32* inputs_begin, const UInt32* inputs_end, - const Int32* sampleSizes_begin, const Int32* sampleSizes_end, - Real32 initialPermanence, nupic::Random& rng) -{ + const UInt32 *segments_begin, const UInt32 *segments_end, + const UInt32 *inputs_begin, const UInt32 *inputs_end, + const Int32 *sampleSizes_begin, const Int32 *sampleSizes_end, + Real32 initialPermanence, nupic::Random &rng) { NTA_ASSERT(std::distance(sampleSizes_begin, sampleSizes_end) == std::distance(segments_begin, segments_end)); - matrix.setRandomZerosOnOuter( - segments_begin, segments_end, - inputs_begin, inputs_end, - sampleSizes_begin, sampleSizes_end, - initialPermanence, rng); + matrix.setRandomZerosOnOuter(segments_begin, 
segments_end, inputs_begin, + inputs_end, sampleSizes_begin, sampleSizes_end, + initialPermanence, rng); clipPermanences(segments_begin, segments_end); } -void SparseMatrixConnections::clipPermanences( - const UInt32* segments_begin, const UInt32* segments_end) -{ +void SparseMatrixConnections::clipPermanences(const UInt32 *segments_begin, + const UInt32 *segments_end) { matrix.clipRowsBelowAndAbove(segments_begin, segments_end, 0.0, 1.0); } void SparseMatrixConnections::mapSegmentsToSynapseCounts( - const UInt32* segments_begin, const UInt32* segments_end, - Int32* out_begin) const -{ + const UInt32 *segments_begin, const UInt32 *segments_end, + Int32 *out_begin) const { matrix.nNonZerosPerRow(segments_begin, segments_end, out_begin); } diff --git a/src/nupic/math/SparseMatrixConnections.hpp b/src/nupic/math/SparseMatrixConnections.hpp index eea2f7b2bc..db1e33c688 100644 --- a/src/nupic/math/SparseMatrixConnections.hpp +++ b/src/nupic/math/SparseMatrixConnections.hpp @@ -33,213 +33,218 @@ namespace nupic { +/** + * Wraps the SparseMatrix with an easy-to-read API that stores dendrite + * segments as rows in the matrix. + * + * The internal SparseMatrix is part of the public API. It is exposed via the + * "matrix" member variable. + */ +class SparseMatrixConnections + : public SegmentMatrixAdapter> { + +public: /** - * Wraps the SparseMatrix with an easy-to-read API that stores dendrite - * segments as rows in the matrix. + * SparseMatrixConnections constructor + * + * @param numCells + * The number of cells in this Connections + * + * @param numInputs + * The number of input bits, i.e. the number of columns in the internal + * SparseMatrix + */ + SparseMatrixConnections(UInt32 numCells, UInt32 numInputs); + + /** + * Compute the number of active synapses on each segment. + * + * @param activeInputs + * The active input bits * - * The internal SparseMatrix is part of the public API. It is exposed via the - * "matrix" member variable. + * @param overlaps + * Output buffer that will be filled with a number of active synapses for + * each segment. This number is the "overlap" between the input SDR and the + * SDR formed by each segment's synapses. */ - class SparseMatrixConnections : - public SegmentMatrixAdapter> { - - public: - /** - * SparseMatrixConnections constructor - * - * @param numCells - * The number of cells in this Connections - * - * @param numInputs - * The number of input bits, i.e. the number of columns in the internal - * SparseMatrix - */ - SparseMatrixConnections(UInt32 numCells, UInt32 numInputs); - - /** - * Compute the number of active synapses on each segment. - * - * @param activeInputs - * The active input bits - * - * @param overlaps - * Output buffer that will be filled with a number of active synapses for - * each segment. This number is the "overlap" between the input SDR and the - * SDR formed by each segment's synapses. - */ - void computeActivity( - const UInt32* activeInputs_begin, const UInt32* activeInputs_end, - Int32* overlaps_begin) const; - - /** - * Compute the number of active connected synapses on each segment. - * - * @param activeInputs - * The active input bits - * - * @param permanenceThreshold - * The minimum permanence required for a synapse to be "connected" - * - * @param overlaps - * Output buffer that will be filled with a number of active connected - * synapses for each segment. This number is the "overlap" between the input - * SDR and the SDR formed by each segment's connected synapses. 
- */ - void computeActivity( - const UInt32* activeInputs_begin, const UInt32* activeInputs_end, - Real32 permanenceThreshold, Int32* overlaps_begin) const; - - /** - * For each specified segment, update the permanence of each synapse - * according to whether the synapse would be active given the specified - * active inputs. - * - * @param segments - * The segments to modify - * - * @param activeInputs - * The active inputs. Used to compute the active synapses. - * - * @param activePermanenceDelta - * Additive constant for each active synapse's permanence - * - * @param inactivePermanenceDelta - * Additive constant for each inactive synapse's permanence - */ - void adjustSynapses( - const UInt32* segments_begin, const UInt32* segments_end, - const UInt32* activeInputs_begin, const UInt32* activeInputs_end, - Real32 activePermanenceDelta, Real32 inactivePermanenceDelta); - - /** - * For each specified segment, add a delta to the permanences of the - * synapses that would be active given the specified active inputs. - * - * @param segments - * The segments to modify - * - * @param activeInputs - * The active inputs. Used to compute the active synapses. - * - * @param permanenceDelta - * Additive constant for each active synapse's permanence - */ - void adjustActiveSynapses( - const UInt32* segments_begin, const UInt32* segments_end, - const UInt32* activeInputs_begin, const UInt32* activeInputs_end, - Real32 permanenceDelta); - - /** - * For each specified segment, add a delta to the permanences of the - * synapses that would be inactive given the specified active inputs. - * - * @param segments - * The segments to modify - * - * @param activeInputs - * The active inputs. Used to compute the active synapses. - * - * @param permanenceDelta - * Additive constant for each inactive synapse's permanence - */ - void adjustInactiveSynapses( - const UInt32* segments_begin, const UInt32* segments_end, - const UInt32* activeInputs_begin, const UInt32* activeInputs_end, - Real32 permanenceDelta); - - /** - * For each specified segments, grow synapses to all specified inputs that - * aren't already connected to the segment. - * - * @param segments - * The segments to modify - * - * @param inputs - * The inputs to connect to - * - * @param initialPermanence - * The permanence for each added synapse - */ - void growSynapses( - const UInt32* segments_begin, const UInt32* segments_end, - const UInt32* inputs_begin, const UInt32* inputs_end, - Real32 initialPermanence); - - /** - * For each specified segments, grow synapses to a random subset of the - * inputs that aren't already connected to the segment. - * - * @param segments - * The segments to modify - * - * @param inputs - * The inputs to sample - * - * @param sampleSize - * The number of synapses to attempt to grow per segment - * - * @param initialPermanence - * The permanence for each added synapse - * - * @param rng - * Random number generator - */ - void growSynapsesToSample( - const UInt32* segments_begin, const UInt32* segments_end, - const UInt32* inputs_begin, const UInt32* inputs_end, - Int32 sampleSize, Real32 initialPermanence, nupic::Random& rng); - - /** - * For each specified segments, grow synapses to a random subset of the - * inputs that aren't already connected to the segment. - * - * @param segments - * The segments to modify - * - * @param inputs - * The inputs to sample - * - * @param sampleSizes - * The number of synapses to attempt to grow for each segment. - * This list must be the same length as 'segments'. 
- * - * @param initialPermanence - * The permanence for each added synapse - * - * @param rng - * Random number generator - */ - void growSynapsesToSample( - const UInt32* segments_begin, const UInt32* segments_end, - const UInt32* inputs_begin, const UInt32* inputs_end, - const Int32* sampleSizes_begin, const Int32* sampleSizes_end, - Real32 initialPermanence, nupic::Random& rng); - - /** - * Clip all permanences to a minimum of 0.0 and a maximum of 1.0. - * For any synapse with <= 0.0 permanence, destroy the synapse. - * For any synapse with > 1.0 permanence, set the permanence to 1.0. - * - * @param segments - * The segments to modify - */ - void clipPermanences( - const UInt32* segments_begin, const UInt32* segments_end); - - /** - * Get the number of synapses for each specified segment. - * - * @param segments - * The segments to query - * - * @param out - * An output buffer that will be filled with the counts - */ - void mapSegmentsToSynapseCounts( - const UInt32* segments_begin, const UInt32* segments_end, - Int32* out_begin) const; - }; + void computeActivity(const UInt32 *activeInputs_begin, + const UInt32 *activeInputs_end, + Int32 *overlaps_begin) const; + /** + * Compute the number of active connected synapses on each segment. + * + * @param activeInputs + * The active input bits + * + * @param permanenceThreshold + * The minimum permanence required for a synapse to be "connected" + * + * @param overlaps + * Output buffer that will be filled with a number of active connected + * synapses for each segment. This number is the "overlap" between the input + * SDR and the SDR formed by each segment's connected synapses. + */ + void computeActivity(const UInt32 *activeInputs_begin, + const UInt32 *activeInputs_end, + Real32 permanenceThreshold, Int32 *overlaps_begin) const; + + /** + * For each specified segment, update the permanence of each synapse + * according to whether the synapse would be active given the specified + * active inputs. + * + * @param segments + * The segments to modify + * + * @param activeInputs + * The active inputs. Used to compute the active synapses. + * + * @param activePermanenceDelta + * Additive constant for each active synapse's permanence + * + * @param inactivePermanenceDelta + * Additive constant for each inactive synapse's permanence + */ + void adjustSynapses(const UInt32 *segments_begin, const UInt32 *segments_end, + const UInt32 *activeInputs_begin, + const UInt32 *activeInputs_end, + Real32 activePermanenceDelta, + Real32 inactivePermanenceDelta); + + /** + * For each specified segment, add a delta to the permanences of the + * synapses that would be active given the specified active inputs. + * + * @param segments + * The segments to modify + * + * @param activeInputs + * The active inputs. Used to compute the active synapses. + * + * @param permanenceDelta + * Additive constant for each active synapse's permanence + */ + void adjustActiveSynapses(const UInt32 *segments_begin, + const UInt32 *segments_end, + const UInt32 *activeInputs_begin, + const UInt32 *activeInputs_end, + Real32 permanenceDelta); + + /** + * For each specified segment, add a delta to the permanences of the + * synapses that would be inactive given the specified active inputs. + * + * @param segments + * The segments to modify + * + * @param activeInputs + * The active inputs. Used to compute the active synapses. 
+ * + * @param permanenceDelta + * Additive constant for each inactive synapse's permanence + */ + void adjustInactiveSynapses(const UInt32 *segments_begin, + const UInt32 *segments_end, + const UInt32 *activeInputs_begin, + const UInt32 *activeInputs_end, + Real32 permanenceDelta); + + /** + * For each specified segments, grow synapses to all specified inputs that + * aren't already connected to the segment. + * + * @param segments + * The segments to modify + * + * @param inputs + * The inputs to connect to + * + * @param initialPermanence + * The permanence for each added synapse + */ + void growSynapses(const UInt32 *segments_begin, const UInt32 *segments_end, + const UInt32 *inputs_begin, const UInt32 *inputs_end, + Real32 initialPermanence); + + /** + * For each specified segments, grow synapses to a random subset of the + * inputs that aren't already connected to the segment. + * + * @param segments + * The segments to modify + * + * @param inputs + * The inputs to sample + * + * @param sampleSize + * The number of synapses to attempt to grow per segment + * + * @param initialPermanence + * The permanence for each added synapse + * + * @param rng + * Random number generator + */ + void growSynapsesToSample(const UInt32 *segments_begin, + const UInt32 *segments_end, + const UInt32 *inputs_begin, + const UInt32 *inputs_end, Int32 sampleSize, + Real32 initialPermanence, nupic::Random &rng); + + /** + * For each specified segments, grow synapses to a random subset of the + * inputs that aren't already connected to the segment. + * + * @param segments + * The segments to modify + * + * @param inputs + * The inputs to sample + * + * @param sampleSizes + * The number of synapses to attempt to grow for each segment. + * This list must be the same length as 'segments'. + * + * @param initialPermanence + * The permanence for each added synapse + * + * @param rng + * Random number generator + */ + void growSynapsesToSample(const UInt32 *segments_begin, + const UInt32 *segments_end, + const UInt32 *inputs_begin, + const UInt32 *inputs_end, + const Int32 *sampleSizes_begin, + const Int32 *sampleSizes_end, + Real32 initialPermanence, nupic::Random &rng); + + /** + * Clip all permanences to a minimum of 0.0 and a maximum of 1.0. + * For any synapse with <= 0.0 permanence, destroy the synapse. + * For any synapse with > 1.0 permanence, set the permanence to 1.0. + * + * @param segments + * The segments to modify + */ + void clipPermanences(const UInt32 *segments_begin, + const UInt32 *segments_end); + + /** + * Get the number of synapses for each specified segment. 
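+ *
+ * Rough usage sketch (hypothetical segment ids; `conn` stands for an existing
+ * SparseMatrixConnections instance):
+ *
+ *   std::vector<UInt32> segments = {0, 1, 2};
+ *   std::vector<Int32> counts(segments.size());
+ *   conn.mapSegmentsToSynapseCounts(segments.data(),
+ *                                   segments.data() + segments.size(),
+ *                                   counts.data());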
+ * + * @param segments + * The segments to query + * + * @param out + * An output buffer that will be filled with the counts + */ + void mapSegmentsToSynapseCounts(const UInt32 *segments_begin, + const UInt32 *segments_end, + Int32 *out_begin) const; }; +}; // namespace nupic + #endif // NTA_SPARSE_MATRIX_CONNECTIONS_HPP diff --git a/src/nupic/math/SparseRLEMatrix.hpp b/src/nupic/math/SparseRLEMatrix.hpp index 4d9f9014d3..40034fa15f 100644 --- a/src/nupic/math/SparseRLEMatrix.hpp +++ b/src/nupic/math/SparseRLEMatrix.hpp @@ -20,649 +20,584 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definition and implementation for SparseRLEMatrix */ #ifndef NTA_SPARSE_RLE_MATRIX_HPP #define NTA_SPARSE_RLE_MATRIX_HPP -#include #include +#include #include #include //-------------------------------------------------------------------------------- namespace nupic { - + +//-------------------------------------------------------------------------------- +/** + * A matrix where only the positions and values of runs of non-zeros are stored. + * Optionally compresses values using zlib (off by default). + * + * WATCH OUT! That the Index type doesn't become too small to store parameters + * of the matrix, such as total number of non-zeros. + * + * TODO: run length encode different values, which can be valuable when + * quantizing vector components. This could be another data structure. + */ +template class SparseRLEMatrix { +public: + typedef unsigned long ulong_size_type; + typedef Index size_type; + typedef Value value_type; + +private: + typedef typename std::vector IndexVector; + typedef typename std::vector ValueVector; + typedef typename std::pair Row; + + bool compress_; + std::vector data_; + IndexVector indb_; + ValueVector nzb_; + std::vector c_; + +public: //-------------------------------------------------------------------------------- - /** - * A matrix where only the positions and values of runs of non-zeros are stored. - * Optionally compresses values using zlib (off by default). - * - * WATCH OUT! That the Index type doesn't become too small to store parameters - * of the matrix, such as total number of non-zeros. - * - * TODO: run length encode different values, which can be valuable when - * quantizing vector components. This could be another data structure. 
- */ - template - class SparseRLEMatrix - { - public: - typedef unsigned long ulong_size_type; - typedef Index size_type; - typedef Value value_type; - - private: - typedef typename std::vector IndexVector; - typedef typename std::vector ValueVector; - typedef typename std::pair Row; - - bool compress_; - std::vector data_; - IndexVector indb_; - ValueVector nzb_; - std::vector c_; - - public: - //-------------------------------------------------------------------------------- - inline SparseRLEMatrix() - : compress_(false), - data_(), - indb_(), - nzb_(), - c_() - {} - - //-------------------------------------------------------------------------------- - inline SparseRLEMatrix(std::istream& inStream) - : compress_(false), - data_(), - indb_(), - nzb_(), - c_() - { - fromCSR(inStream); - } + inline SparseRLEMatrix() : compress_(false), data_(), indb_(), nzb_(), c_() {} - //-------------------------------------------------------------------------------- - template - inline SparseRLEMatrix(InputIterator begin, InputIterator end) - : compress_(false), - data_(), - indb_(), - nzb_(), - c_() - { - fromDense(begin, end); - } + //-------------------------------------------------------------------------------- + inline SparseRLEMatrix(std::istream &inStream) + : compress_(false), data_(), indb_(), nzb_(), c_() { + fromCSR(inStream); + } - //-------------------------------------------------------------------------------- - inline ~SparseRLEMatrix() - { - data_.clear(); - indb_.clear(); - nzb_.clear(); - c_.clear(); - } + //-------------------------------------------------------------------------------- + template + inline SparseRLEMatrix(InputIterator begin, InputIterator end) + : compress_(false), data_(), indb_(), nzb_(), c_() { + fromDense(begin, end); + } - //-------------------------------------------------------------------------------- - inline const std::string getVersion() const - { - return std::string("sm_rle_1.0"); - } + //-------------------------------------------------------------------------------- + inline ~SparseRLEMatrix() { + data_.clear(); + indb_.clear(); + nzb_.clear(); + c_.clear(); + } - //-------------------------------------------------------------------------------- - inline ulong_size_type capacity() const - { - ulong_size_type n = 0; - for (size_type i = 0; i != data_.size(); ++i) - n += data_[i].second.capacity(); - return n; - } + //-------------------------------------------------------------------------------- + inline const std::string getVersion() const { + return std::string("sm_rle_1.0"); + } - //-------------------------------------------------------------------------------- - inline ulong_size_type nBytes() const - { - ulong_size_type n = sizeof(SparseRLEMatrix); - n += data_.capacity() * sizeof(Row); - for (size_type i = 0; i != nRows(); ++i) - n += data_[i].first.capacity() * sizeof(size_type) - + data_[i].second.capacity() * sizeof(value_type); - n += indb_.capacity() * sizeof(size_type); - n += nzb_.capacity() * sizeof(value_type); - n += c_.capacity() * sizeof(uLong); - return n; - } + //-------------------------------------------------------------------------------- + inline ulong_size_type capacity() const { + ulong_size_type n = 0; + for (size_type i = 0; i != data_.size(); ++i) + n += data_[i].second.capacity(); + return n; + } - //-------------------------------------------------------------------------------- - inline bool isCompressed() const - { - return compress_; - } + 
//-------------------------------------------------------------------------------- + inline ulong_size_type nBytes() const { + ulong_size_type n = sizeof(SparseRLEMatrix); + n += data_.capacity() * sizeof(Row); + for (size_type i = 0; i != nRows(); ++i) + n += data_[i].first.capacity() * sizeof(size_type) + + data_[i].second.capacity() * sizeof(value_type); + n += indb_.capacity() * sizeof(size_type); + n += nzb_.capacity() * sizeof(value_type); + n += c_.capacity() * sizeof(uLong); + return n; + } - //-------------------------------------------------------------------------------- - inline ulong_size_type nRows() const - { - return data_.size(); - } + //-------------------------------------------------------------------------------- + inline bool isCompressed() const { return compress_; } - //-------------------------------------------------------------------------------- - inline size_type nCols() const - { - return indb_.size(); - } + //-------------------------------------------------------------------------------- + inline ulong_size_type nRows() const { return data_.size(); } - //-------------------------------------------------------------------------------- - inline size_type nNonZerosOnRow(ulong_size_type row) const - { - { // Pre-conditions - NTA_ASSERT(0 <= row && row < nRows()) - << "SparseRLEMatrix::nNonZerosOnRow: " - << "Invalid row index: " << row; - } // End pre-conditions - - size_type n = 0; - for (size_type j = 1; j < data_[row].first.size(); j += 2) - n += data_[row].first[j] - data_[row].first[j-1]; - return n; - } + //-------------------------------------------------------------------------------- + inline size_type nCols() const { return indb_.size(); } - //-------------------------------------------------------------------------------- - inline ulong_size_type nNonZeros() const - { - ulong_size_type n = 0; - for (ulong_size_type i = 0; i != data_.size(); ++i) - n += nNonZerosOnRow(i); - return n; - } + //-------------------------------------------------------------------------------- + inline size_type nNonZerosOnRow(ulong_size_type row) const { + { // Pre-conditions + NTA_ASSERT(0 <= row && row < nRows()) + << "SparseRLEMatrix::nNonZerosOnRow: " + << "Invalid row index: " << row; + } // End pre-conditions + + size_type n = 0; + for (size_type j = 1; j < data_[row].first.size(); j += 2) + n += data_[row].first[j] - data_[row].first[j - 1]; + return n; + } - //-------------------------------------------------------------------------------- - /** - * Adjusts the size of the internal vectors so that their capacity matches - * their size. - */ - inline void compact() - { - if (capacity() == nNonZeros() - && data_.capacity() == data_.size() - && indb_.capacity() == indb_.size()) - return; - - std::stringstream buffer; - toCSR(buffer); - clear(); - fromCSR(buffer); - - NTA_ASSERT(capacity() == nNonZeros()); - } + //-------------------------------------------------------------------------------- + inline ulong_size_type nNonZeros() const { + ulong_size_type n = 0; + for (ulong_size_type i = 0; i != data_.size(); ++i) + n += nNonZerosOnRow(i); + return n; + } - //-------------------------------------------------------------------------------- - /** - * Compress data using compression algorithm. - */ - inline void compressData() - { - if (compress_) - return; + //-------------------------------------------------------------------------------- + /** + * Adjusts the size of the internal vectors so that their capacity matches + * their size. 
+ */ + inline void compact() { + if (capacity() == nNonZeros() && data_.capacity() == data_.size() && + indb_.capacity() == indb_.size()) + return; - if (c_.empty()) - c_.resize(nRows(), 0); + std::stringstream buffer; + toCSR(buffer); + clear(); + fromCSR(buffer); - for (ulong_size_type i = 0; i != nRows(); ++i) - compressRow_(i); + NTA_ASSERT(capacity() == nNonZeros()); + } - compress_ = true; - } + //-------------------------------------------------------------------------------- + /** + * Compress data using compression algorithm. + */ + inline void compressData() { + if (compress_) + return; - //-------------------------------------------------------------------------------- - inline void decompressData() - { - if (!compress_) - return; - - if (c_.empty()) - c_.resize(nRows(), 0); - - for (ulong_size_type i = 0; i != nRows(); ++i) { - uLongf dstLen = decompressRow_(i); - size_type n = dstLen / sizeof(value_type); - ValueVector new_vector(nzb_.begin(), nzb_.begin() + n); - data_[i].second.swap(new_vector); - c_[i] = 0; - } + if (c_.empty()) + c_.resize(nRows(), 0); + + for (ulong_size_type i = 0; i != nRows(); ++i) + compressRow_(i); - compress_ = false; + compress_ = true; + } + + //-------------------------------------------------------------------------------- + inline void decompressData() { + if (!compress_) + return; + + if (c_.empty()) + c_.resize(nRows(), 0); + + for (ulong_size_type i = 0; i != nRows(); ++i) { + uLongf dstLen = decompressRow_(i); + size_type n = dstLen / sizeof(value_type); + ValueVector new_vector(nzb_.begin(), nzb_.begin() + n); + data_[i].second.swap(new_vector); + c_[i] = 0; } - //-------------------------------------------------------------------------------- - /** - * Deallocates memory used by this instance. - */ - inline void clear() - { - std::vector empty; - data_.swap(empty); - IndexVector empty_indb; - indb_.swap(empty_indb); - ValueVector empty_nzb; - nzb_.swap(empty_nzb); - std::vector empty_c; - c_.swap(empty_c); - compress_ = false; - - NTA_ASSERT(nBytes() == sizeof(SparseRLEMatrix)); + compress_ = false; + } + + //-------------------------------------------------------------------------------- + /** + * Deallocates memory used by this instance. + */ + inline void clear() { + std::vector empty; + data_.swap(empty); + IndexVector empty_indb; + indb_.swap(empty_indb); + ValueVector empty_nzb; + nzb_.swap(empty_nzb); + std::vector empty_c; + c_.swap(empty_c); + compress_ = false; + + NTA_ASSERT(nBytes() == sizeof(SparseRLEMatrix)); + } + + //-------------------------------------------------------------------------------- + /** + * Appends a row to this matrix. + */ + template + inline void appendRow(InputIterator begin, InputIterator end) { + { // Pre-conditions + NTA_ASSERT(begin <= end) << "SparseRLEMatrix::appendRow: " + << "Invalid range"; + } // End pre-conditions + + // Resize matrix if needed + if (indb_.size() < (size_type)(end - begin)) { + size_type ncols = (size_type)(end - begin); + indb_.resize(ncols); + nzb_.resize(ncols); } - //-------------------------------------------------------------------------------- - /** - * Appends a row to this matrix. 
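+ *
+ * Only the runs of non-zeros are kept, as [begin, end) column-index pairs
+ * plus their values. Rough usage sketch (hypothetical instantiation and
+ * values):
+ *
+ *   SparseRLEMatrix<nupic::UInt16, nupic::Real32> m;
+ *   std::vector<nupic::Real32> row = {0, 0, 1.5f, 2.0f, 0, 0, 3.0f};
+ *   m.appendRow(row.begin(), row.end());  // stored as runs [2,4) and [6,7)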
- */ - template - inline void - appendRow(InputIterator begin, InputIterator end) - { - { // Pre-conditions - NTA_ASSERT(begin <= end) - << "SparseRLEMatrix::appendRow: " - << "Invalid range"; - } // End pre-conditions - - // Resize matrix if needed - if (indb_.size() < (size_type)(end - begin)) { - size_type ncols = (size_type) (end - begin); - indb_.resize(ncols); - nzb_.resize(ncols); - } - - typename IndexVector::iterator indb = indb_.begin(); - typename ValueVector::iterator nzb = nzb_.begin(); - InputIterator it = begin; - - // Find positions and values of non-zeros - while (it != end) { - while (it != end && nupic::nearlyZero(*it)) - ++it; - if (it != end) { - *indb++ = (size_type)(it - begin); - while (it != end && !nupic::nearlyZero(*it)) - *nzb++ = (value_type) *it++; - *indb++ = (size_type)(it - begin); - } + typename IndexVector::iterator indb = indb_.begin(); + typename ValueVector::iterator nzb = nzb_.begin(); + InputIterator it = begin; + + // Find positions and values of non-zeros + while (it != end) { + while (it != end && nupic::nearlyZero(*it)) + ++it; + if (it != end) { + *indb++ = (size_type)(it - begin); + while (it != end && !nupic::nearlyZero(*it)) + *nzb++ = (value_type)*it++; + *indb++ = (size_type)(it - begin); } + } - data_.resize(data_.size() + 1); - Row& row = data_[data_.size() - 1]; - IndexVector& ind = row.first; - ValueVector& nz = row.second; - size_type ind_size = (size_type)(indb - indb_.begin()); - size_type nz_size = (size_type)(nzb - nzb_.begin()); - ind.reserve(ind_size); - nz.reserve(nz_size); - ind.insert(ind.end(), indb_.begin(), indb_.begin() + ind_size); - nz.insert(nz.end(), nzb_.begin(), nzb_.begin() + nz_size); - c_.push_back(0); + data_.resize(data_.size() + 1); + Row &row = data_[data_.size() - 1]; + IndexVector &ind = row.first; + ValueVector &nz = row.second; + size_type ind_size = (size_type)(indb - indb_.begin()); + size_type nz_size = (size_type)(nzb - nzb_.begin()); + ind.reserve(ind_size); + nz.reserve(nz_size); + ind.insert(ind.end(), indb_.begin(), indb_.begin() + ind_size); + nz.insert(nz.end(), nzb_.begin(), nzb_.begin() + nz_size); + c_.push_back(0); + + if (compress_) + compressRow_(data_.size() - 1); + } - if (compress_) - compressRow_(data_.size()-1); + //-------------------------------------------------------------------------------- + template + inline void getRowToDense(ulong_size_type r, OutputIterator begin, + OutputIterator end) const { + { // Pre-conditions + NTA_ASSERT(0 <= r && r < nRows()) << "SparseRLEMatrix::getRow: " + << "Invalid row index: " << r; + + NTA_ASSERT((size_type)(end - begin) == nCols()) + << "SparseRLEMatrix::getRow: " + << "Not enough memory"; + } // End pre-conditions + + const Row &row = data_[r]; + const IndexVector &ind = row.first; + size_type n = ind.size(); + typename ValueVector::const_iterator nz; + + if (compress_) { + const_cast(*this).decompressRow_(r); + nz = nzb_.begin(); + } else + nz = data_[r].second.begin(); + + size_type j = 0; + for (size_type i = 0; i + 1 < n; i += 2) { + for (; j != ind[i]; ++j) + *(begin + j) = (value_type)0; + for (; j != ind[i + 1]; ++j) + *(begin + j) = *nz++; } + for (; j < nCols(); ++j) + *(begin + j) = (value_type)0; + } + + //-------------------------------------------------------------------------------- + /** + * Returns index of first row within of argument, or nRows() if + * none. 
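+ *
+ * The distance is Euclidean, taken against the dense form of each stored row.
+ * Rough usage sketch (hypothetical query vector, continuing the matrix `m`
+ * from the appendRow example above):
+ *
+ *   std::vector<nupic::Real32> query(m.nCols(), 0.0f);
+ *   ulong_size_type r = m.firstRowCloserThan(query.begin(), query.end(), 0.5f);
+ *   // r == m.nRows() means no stored row is within 0.5 of the query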
+ */ + template + inline ulong_size_type firstRowCloserThan(InputIterator begin, + InputIterator end, + nupic::Real32 distance) const { + { NTA_ASSERT(begin <= end); } + + nupic::Real32 d2 = distance * distance; - //-------------------------------------------------------------------------------- - template - inline void - getRowToDense(ulong_size_type r, OutputIterator begin, OutputIterator end) const - { - { // Pre-conditions - NTA_ASSERT(0 <= r && r < nRows()) - << "SparseRLEMatrix::getRow: " - << "Invalid row index: " << r; - - NTA_ASSERT((size_type)(end - begin) == nCols()) - << "SparseRLEMatrix::getRow: " - << "Not enough memory"; - } // End pre-conditions - - const Row& row = data_[r]; - const IndexVector& ind = row.first; + for (ulong_size_type r = 0; r != nRows(); ++r) { + + const Row &row = data_[r]; + const IndexVector &ind = row.first; size_type n = ind.size(); typename ValueVector::const_iterator nz; if (compress_) { - const_cast(*this).decompressRow_(r); - nz = nzb_.begin(); - } else - nz = data_[r].second.begin(); - - size_type j = 0; - for (size_type i = 0; i+1 < n; i += 2) { - for (; j != ind[i]; ++j) - *(begin + j) = (value_type) 0; - for (; j != ind[i+1]; ++j) - *(begin + j) = *nz++; + const_cast(*this).decompressRow_(r); + nz = nzb_.begin(); + } else { + nz = data_[r].second.begin(); } - for (; j < nCols(); ++j) - *(begin + j) = (value_type) 0; - } - //-------------------------------------------------------------------------------- - /** - * Returns index of first row within of argument, or nRows() if none. - */ - template - inline ulong_size_type - firstRowCloserThan(InputIterator begin, InputIterator end, nupic::Real32 distance) const - { - { - NTA_ASSERT(begin <= end); - } - - nupic::Real32 d2 = distance * distance; - - for (ulong_size_type r = 0; r != nRows(); ++r) { - - const Row& row = data_[r]; - const IndexVector& ind = row.first; - size_type n = ind.size(); - typename ValueVector::const_iterator nz; - - if (compress_) { - const_cast(*this).decompressRow_(r); - nz = nzb_.begin(); - } else { - nz = data_[r].second.begin(); - } - - nupic::Real32 d = 0; - size_type j = 0; - for (size_type i = 0; i+1 < n && d < d2; i += 2) { - for (; j != ind[i]; ++j) - d += *(begin + j) * *(begin + j); - for (; j != ind[i+1] && d < d2; ++j) { - nupic::Real32 v = *(begin + j) - *nz++; - d += v * v; - } - } - for (; j < nCols() && d < d2; ++j) - d += *(begin + j) * *(begin + j); - - if (d < d2) - return r; + nupic::Real32 d = 0; + size_type j = 0; + for (size_type i = 0; i + 1 < n && d < d2; i += 2) { + for (; j != ind[i]; ++j) + d += *(begin + j) * *(begin + j); + for (; j != ind[i + 1] && d < d2; ++j) { + nupic::Real32 v = *(begin + j) - *nz++; + d += v * v; + } } + for (; j < nCols() && d < d2; ++j) + d += *(begin + j) * *(begin + j); - return nRows(); + if (d < d2) + return r; } - //-------------------------------------------------------------------------------- - inline ulong_size_type CSRSize() const - { - char buffer[32]; - - std::stringstream b; - b << getVersion() << " " - << nRows() << " " << nCols() << " " - << (compress_ ? 
"1" : "0") << " "; - - ulong_size_type n = b.str().size(); - - for (ulong_size_type row = 0; row != nRows(); ++row) { - - size_type n1 = data_[row].first.size(); - n += sprintf(buffer, "%d ", n1); - - for (size_type j = 0; j != n1; ++j) - n += sprintf(buffer, "%d ", data_[row].first[j]); - - if (compress_) { - const_cast(*this).decompressRow_(row); - ulong_size_type n2 = nNonZerosOnRow(row); - for (size_type j = 0; j != n2; ++j) - n += sprintf(buffer, "%.15g ", (double) nzb_[j]); - } else { - ulong_size_type n2 = nNonZerosOnRow(row); - for (size_type j = 0; j != n2; ++j) - n += sprintf(buffer, "%.15g ", (double) data_[row].second[j]); - } - } + return nRows(); + } - return n; - } + //-------------------------------------------------------------------------------- + inline ulong_size_type CSRSize() const { + char buffer[32]; - //-------------------------------------------------------------------------------- - inline void toCSR(std::ostream& outStream) const - { - { // Pre-conditions - NTA_ASSERT(outStream.good()) - << "SparseRLEMatrix::toCSR: Bad stream"; - } // End pre-conditions - - outStream << getVersion() << " " - << nRows() << " " << nCols() << " " - << (compress_ ? "1" : "0") << " " - << std::setprecision(15); - - for (ulong_size_type i = 0; i != nRows(); ++i) { - outStream << data_[i].first; - size_type nnzr = nNonZerosOnRow(i); - if (compress_) { - const_cast(*this).decompressRow_(i); - for (size_type k = 0; k != nnzr; ++k) - outStream << (double) nzb_[k] << " "; - } else { - for (size_type k = 0; k != nnzr; ++k) - outStream << (double) data_[i].second[k] << " "; - } - } - } + std::stringstream b; + b << getVersion() << " " << nRows() << " " << nCols() << " " + << (compress_ ? "1" : "0") << " "; - //-------------------------------------------------------------------------------- - inline void fromCSR(std::istream& inStream) - { - { // Pre-conditions - NTA_ASSERT(inStream.good()) - << "SparseRLEMatrix::fromCSR: Bad stream"; - } // End pre-conditions - - std::string version; - inStream >> version; - - NTA_CHECK(version == getVersion()) - << "SparseRLEMatrix::fromCSR: Unknown version: " - << version; - - ulong_size_type nrows = 0; - inStream >> nrows; - data_.resize(nrows); - - size_type ncols = 0; - inStream >> ncols; - indb_.resize(ncols); - nzb_.resize(ncols); + ulong_size_type n = b.str().size(); - int compressVal = 0; - inStream >> compressVal; - compress_ = compressVal == 1; + for (ulong_size_type row = 0; row != nRows(); ++row) { - if (compress_) - c_.resize(nRows(), 0); - - for (ulong_size_type i = 0; i != nrows; ++i) { - inStream >> data_[i].first; - data_[i].second.resize(nNonZerosOnRow(i)); - size_type k2 = 0; - for (size_type k = 0; k < data_[i].first.size(); k += 2) { - for (size_type j = data_[i].first[k]; j != data_[i].first[k+1]; ++j) { - double val; - inStream >> val; - data_[i].second[k2++] = (value_type) val; - } - } - /* - NTA_CHECK(data_[i].first.size() <= 2*nCols()) - << "SparseRLEMatrix::fromCSR: " - << "Too many indices"; - */ - NTA_CHECK(data_[i].second.size() <= nCols()) - << "SparseRLEMatrix::fromCSR: " - << "Too many values"; - for (size_type j = 0; j != data_[i].first.size(); ++j) { - size_type idx = data_[i].first[j]; - NTA_CHECK(idx <= nCols()) - << "SparseRLEMatrix::fromCSR: " - << "Invalid index: " << idx; - if (1 < j) - NTA_CHECK(data_[i].first[j-1] < idx) - << "SparseRLEMatrix::fromCSR: " - << "Invalid index: " << idx - << " - Indices need to be in strictly increasing order"; - } - NTA_CHECK(data_[i].second.size() == nNonZerosOnRow(i)) - << 
"SparseRLEMatrix::fromCSR: " - << "Mismatching number of indices and values"; - if (compress_) - compressRow_(i); + size_type n1 = data_[row].first.size(); + n += sprintf(buffer, "%d ", n1); + + for (size_type j = 0; j != n1; ++j) + n += sprintf(buffer, "%d ", data_[row].first[j]); + + if (compress_) { + const_cast(*this).decompressRow_(row); + ulong_size_type n2 = nNonZerosOnRow(row); + for (size_type j = 0; j != n2; ++j) + n += sprintf(buffer, "%.15g ", (double)nzb_[j]); + } else { + ulong_size_type n2 = nNonZerosOnRow(row); + for (size_type j = 0; j != n2; ++j) + n += sprintf(buffer, "%.15g ", (double)data_[row].second[j]); } } - //-------------------------------------------------------------------------------- - template - inline void toDense(OutputIterator begin, OutputIterator end) const - { - { // Pre-conditions - NTA_ASSERT((size_type)(end - begin) == nRows() * nCols()) - << "SparseRLEMatrix::toDense: " - << "Not enough memory"; - } // End pre-conditions - - for (ulong_size_type row = 0; row != data_.size(); ++row) - getRowToDense(row, begin + row*nCols(), begin + (row+1)*nCols()); - } + return n; + } - //-------------------------------------------------------------------------------- - /** - * Clears this instance and creates a new one from dense. - */ - template - inline void fromDense(ulong_size_type nrows, size_type ncols, - InputIterator begin, InputIterator end) - { - { // Pre-conditions - NTA_ASSERT((ulong_size_type)(end - begin) >= nrows * ncols) - << "SparseRLEMatrix::fromDense: " - << "Not enough memory"; - } // End pre-conditions - - clear(); - - for (ulong_size_type row = 0; row != nrows; ++row) - appendRow(begin + row * ncols, begin + (row+1) * ncols); - } + //-------------------------------------------------------------------------------- + inline void toCSR(std::ostream &outStream) const { + { // Pre-conditions + NTA_ASSERT(outStream.good()) << "SparseRLEMatrix::toCSR: Bad stream"; + } // End pre-conditions + + outStream << getVersion() << " " << nRows() << " " << nCols() << " " + << (compress_ ? 
"1" : "0") << " " << std::setprecision(15); - //-------------------------------------------------------------------------------- - inline void print(std::ostream& outStream) const - { - { // Pre-conditions - NTA_CHECK(outStream.good()) - << "SparseRLEMatrix::print: Bad stream"; - } // End pre-conditions - - std::vector buffer(nCols()); - - for (ulong_size_type row = 0; row != nRows(); ++row) { - getRowToDense(row, buffer.begin(), buffer.end()); - for (size_type col = 0; col != nCols(); ++col) - outStream << buffer[col] << " "; - outStream << std::endl; + for (ulong_size_type i = 0; i != nRows(); ++i) { + outStream << data_[i].first; + size_type nnzr = nNonZerosOnRow(i); + if (compress_) { + const_cast(*this).decompressRow_(i); + for (size_type k = 0; k != nnzr; ++k) + outStream << (double)nzb_[k] << " "; + } else { + for (size_type k = 0; k != nnzr; ++k) + outStream << (double)data_[i].second[k] << " "; } } + } - //-------------------------------------------------------------------------------- - inline void debugPrint() const - { - std::cout << "n rows= " << nRows() - << " n cols= " << nCols() - << " n nz= " << nNonZeros() - << " n bytes= " << nBytes() << std::endl; - std::cout << "this= " << sizeof(SparseRLEMatrix) - << " Row= " << sizeof(Row) - << " size= " << sizeof(size_type) - << " value= " << sizeof(value_type) - << " uLong= " << sizeof(uLong) << std::endl; - std::cout << "data= " << data_.capacity() << " " << data_.size() << std::endl; - std::cout << "indb= " << indb_.capacity() << " " << indb_.size() << std::endl; - std::cout << "nzb= " << nzb_.capacity() << " " << nzb_.size() << std::endl; - std::cout << "c= " << c_.capacity() << " " << c_.size() << std::endl; - for (ulong_size_type i = 0; i != nRows(); ++i) - std::cout << "row " << i << ": first: " - << data_[i].first.capacity() << " " - << data_[i].first.size() << " second: " - << data_[i].second.capacity() << " " - << data_[i].second.size() << std::endl; - for (size_type row = 0; row != nRows(); ++row) { - std::cout << data_[row].first << std::endl; - for (size_type i = 0; i != data_[row].second.size(); ++i) - std::cout << (float) data_[row].second[i] << " "; - std::cout << std::endl; + //-------------------------------------------------------------------------------- + inline void fromCSR(std::istream &inStream) { + { // Pre-conditions + NTA_ASSERT(inStream.good()) << "SparseRLEMatrix::fromCSR: Bad stream"; + } // End pre-conditions + + std::string version; + inStream >> version; + + NTA_CHECK(version == getVersion()) + << "SparseRLEMatrix::fromCSR: Unknown version: " << version; + + ulong_size_type nrows = 0; + inStream >> nrows; + data_.resize(nrows); + + size_type ncols = 0; + inStream >> ncols; + indb_.resize(ncols); + nzb_.resize(ncols); + + int compressVal = 0; + inStream >> compressVal; + compress_ = compressVal == 1; + + if (compress_) + c_.resize(nRows(), 0); + + for (ulong_size_type i = 0; i != nrows; ++i) { + inStream >> data_[i].first; + data_[i].second.resize(nNonZerosOnRow(i)); + size_type k2 = 0; + for (size_type k = 0; k < data_[i].first.size(); k += 2) { + for (size_type j = data_[i].first[k]; j != data_[i].first[k + 1]; ++j) { + double val; + inStream >> val; + data_[i].second[k2++] = (value_type)val; + } } + /* + NTA_CHECK(data_[i].first.size() <= 2*nCols()) + << "SparseRLEMatrix::fromCSR: " + << "Too many indices"; + */ + NTA_CHECK(data_[i].second.size() <= nCols()) + << "SparseRLEMatrix::fromCSR: " + << "Too many values"; + for (size_type j = 0; j != data_[i].first.size(); ++j) { + size_type idx = 
data_[i].first[j]; + NTA_CHECK(idx <= nCols()) << "SparseRLEMatrix::fromCSR: " + << "Invalid index: " << idx; + if (1 < j) + NTA_CHECK(data_[i].first[j - 1] < idx) + << "SparseRLEMatrix::fromCSR: " + << "Invalid index: " << idx + << " - Indices need to be in strictly increasing order"; + } + NTA_CHECK(data_[i].second.size() == nNonZerosOnRow(i)) + << "SparseRLEMatrix::fromCSR: " + << "Mismatching number of indices and values"; + if (compress_) + compressRow_(i); } - - private: - - //-------------------------------------------------------------------------------- - inline void compressRow_(ulong_size_type row) - { - { // Pre-conditions - NTA_ASSERT(0 <= row && row < nRows()) - << "SparseRLEMatrix::compressRow_: " - << "Invalid row index: " << row; - } // End pre-conditions - - nzb_.resize(nCols() + 10); - std::fill(nzb_.begin(), nzb_.end(), (value_type) 0); - - uLongf dstLen = nzb_.size() * sizeof(value_type); - uLong srcLen = data_[row].second.size() * sizeof(value_type); - - // This gives some iterator related failure in debug mode on Windows, - // but it works in release mode. - compress((Bytef*)(&*nzb_.begin()), &dstLen, - (Bytef*)(&*data_[row].second.begin()), srcLen); - - c_[row] = dstLen; - size_type n = dstLen / sizeof(value_type) + 1; - ValueVector new_vector(nzb_.begin(), nzb_.begin() + n); - data_[row].second.swap(new_vector); + } + + //-------------------------------------------------------------------------------- + template + inline void toDense(OutputIterator begin, OutputIterator end) const { + { // Pre-conditions + NTA_ASSERT((size_type)(end - begin) == nRows() * nCols()) + << "SparseRLEMatrix::toDense: " + << "Not enough memory"; + } // End pre-conditions + + for (ulong_size_type row = 0; row != data_.size(); ++row) + getRowToDense(row, begin + row * nCols(), begin + (row + 1) * nCols()); + } + + //-------------------------------------------------------------------------------- + /** + * Clears this instance and creates a new one from dense. + */ + template + inline void fromDense(ulong_size_type nrows, size_type ncols, + InputIterator begin, InputIterator end) { + { // Pre-conditions + NTA_ASSERT((ulong_size_type)(end - begin) >= nrows * ncols) + << "SparseRLEMatrix::fromDense: " + << "Not enough memory"; + } // End pre-conditions + + clear(); + + for (ulong_size_type row = 0; row != nrows; ++row) + appendRow(begin + row * ncols, begin + (row + 1) * ncols); + } + + //-------------------------------------------------------------------------------- + inline void print(std::ostream &outStream) const { + { // Pre-conditions + NTA_CHECK(outStream.good()) << "SparseRLEMatrix::print: Bad stream"; + } // End pre-conditions + + std::vector buffer(nCols()); + + for (ulong_size_type row = 0; row != nRows(); ++row) { + getRowToDense(row, buffer.begin(), buffer.end()); + for (size_type col = 0; col != nCols(); ++col) + outStream << buffer[col] << " "; + outStream << std::endl; } + } - //-------------------------------------------------------------------------------- - inline uLongf decompressRow_(ulong_size_type row) - { - { // Pre-conditions - NTA_ASSERT(0 <= row && row < nRows()) - << "SparseRLEMatrix::decompressRow_: " - << "Invalid row index: " << row; - } // End pre-conditions - - uLongf dstLen = nzb_.size() * sizeof(value_type); - - // This gives some iterator related failure in debug mode on Windows, - // but it works in release mode. 
- uncompress((Bytef*)(&*nzb_.begin()), &dstLen, - (Bytef*)(&*data_[row].second.begin()), (uLong) c_[row]); - - return dstLen; + //-------------------------------------------------------------------------------- + inline void debugPrint() const { + std::cout << "n rows= " << nRows() << " n cols= " << nCols() + << " n nz= " << nNonZeros() << " n bytes= " << nBytes() + << std::endl; + std::cout << "this= " << sizeof(SparseRLEMatrix) << " Row= " << sizeof(Row) + << " size= " << sizeof(size_type) + << " value= " << sizeof(value_type) << " uLong= " << sizeof(uLong) + << std::endl; + std::cout << "data= " << data_.capacity() << " " << data_.size() + << std::endl; + std::cout << "indb= " << indb_.capacity() << " " << indb_.size() + << std::endl; + std::cout << "nzb= " << nzb_.capacity() << " " << nzb_.size() << std::endl; + std::cout << "c= " << c_.capacity() << " " << c_.size() << std::endl; + for (ulong_size_type i = 0; i != nRows(); ++i) + std::cout << "row " << i << ": first: " << data_[i].first.capacity() + << " " << data_[i].first.size() + << " second: " << data_[i].second.capacity() << " " + << data_[i].second.size() << std::endl; + for (size_type row = 0; row != nRows(); ++row) { + std::cout << data_[row].first << std::endl; + for (size_type i = 0; i != data_[row].second.size(); ++i) + std::cout << (float)data_[row].second[i] << " "; + std::cout << std::endl; } + } + +private: + //-------------------------------------------------------------------------------- + inline void compressRow_(ulong_size_type row) { + { // Pre-conditions + NTA_ASSERT(0 <= row && row < nRows()) << "SparseRLEMatrix::compressRow_: " + << "Invalid row index: " << row; + } // End pre-conditions + + nzb_.resize(nCols() + 10); + std::fill(nzb_.begin(), nzb_.end(), (value_type)0); + + uLongf dstLen = nzb_.size() * sizeof(value_type); + uLong srcLen = data_[row].second.size() * sizeof(value_type); + + // This gives some iterator related failure in debug mode on Windows, + // but it works in release mode. + compress((Bytef *)(&*nzb_.begin()), &dstLen, + (Bytef *)(&*data_[row].second.begin()), srcLen); - //-------------------------------------------------------------------------------- - + c_[row] = dstLen; + size_type n = dstLen / sizeof(value_type) + 1; + ValueVector new_vector(nzb_.begin(), nzb_.begin() + n); + data_[row].second.swap(new_vector); + } - SparseRLEMatrix(const SparseRLEMatrix&); - SparseRLEMatrix& operator=(const SparseRLEMatrix&); + //-------------------------------------------------------------------------------- + inline uLongf decompressRow_(ulong_size_type row) { + { // Pre-conditions + NTA_ASSERT(0 <= row && row < nRows()) + << "SparseRLEMatrix::decompressRow_: " + << "Invalid row index: " << row; + } // End pre-conditions + + uLongf dstLen = nzb_.size() * sizeof(value_type); - }; // end class SparseRLEMatrix + // This gives some iterator related failure in debug mode on Windows, + // but it works in release mode. 
+ uncompress((Bytef *)(&*nzb_.begin()), &dstLen, + (Bytef *)(&*data_[row].second.begin()), (uLong)c_[row]); + + return dstLen; + } //-------------------------------------------------------------------------------- + + SparseRLEMatrix(const SparseRLEMatrix &); + SparseRLEMatrix &operator=(const SparseRLEMatrix &); + +}; // end class SparseRLEMatrix + +//-------------------------------------------------------------------------------- } // end namespace nupic #endif // NTA_SPARSE_RLE_MATRIX_HPP diff --git a/src/nupic/math/SparseTensor.hpp b/src/nupic/math/SparseTensor.hpp index 74a5a970cd..6b61f42998 100644 --- a/src/nupic/math/SparseTensor.hpp +++ b/src/nupic/math/SparseTensor.hpp @@ -20,23 +20,21 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definition and implementation for SparseTensors class */ #ifndef NTA_SPARSE_TENSOR_HPP #define NTA_SPARSE_TENSOR_HPP -#include -#include #include -#include #include - +#include +#include +#include //---------------------------------------------------------------------- - /* FAST TENSOR */ /* #include @@ -56,3055 +54,2854 @@ struct HashIndex namespace nupic { +/** + * @b Description + * SparseTensor models a multi-dimensional array, with an arbitrary + * number of dimensions, and arbitrary size for each dimension, + * where only certain elements are not zero. "Not zero" is defined as + * being outside the closed ball [-nupic::Epsilon..nupic::Epsilon]. + * Zero elements are not stored. Non-zero elements are stored in + * a data structure that provides logarithmic insertion and retrieval. + * A number of operations on tensors are implemented as efficiently as + * possible, oftentimes having complexity not worse than the number + * of non-zeros in the tensor. There is no limit to the number of + * dimensions that can be specified for a sparse tensor. + * + * SparseTensor is parameterized on the type of Index used to index + * the non-zeros, and on the type of the non-zeros themselves (Float). + * The numerical type used as the second template parameter needs + * to be functionally equivalent to float, but can be int or double. + * It doesn't work with complex numbers yet (have to modify nearlyZero_ + * to look at the modulus). + * + * The implementation relies on a Unique, Sorted Associative NZ, + * that is map (rather than hash_map, we need the Indices to be sorted). + * + * Examples: + * 1) SparseTensor, float>: + * defines a sparse tensor of dimension 2 (a matrix), storing floats. + * The type of Index is the efficient, compile-time sized Index. + * + * 2) SparseTensor, float>: + * defines the same sparse tensor as 1), but using std::vector + * for the index, which is not as fast. + * + * 3) SparseTensor double>: + * defines a sparse tensor of rank 4 (4 dimensions), storing doubles. + * + * @b Responsibility + * An efficient multi-dimensional sparse data structure + * + * @b Rationale + * Numenta algorithms require very large data structure that are + * sparse, and those data structures cannot be handled efficiently + * with contiguous storage in memory. + * + * @b Resource @Ownership + * SparseTensor owns the keys used to index the non-zeros, as + * well as the values of the non-zeros themselves. + * + * @b Notes + * Note 1: in preliminary testing, using Index was + * about 20 times faster than using std::vector. + * + * Note 2: some operations are very slow, depending on the properties + * of the functors used. Watch out that you are using the + * right one for your functor. 
+ * + * Note 3: SparseTensor is limited to max columns, or rows + * or non-zeros. + * + */ +template class SparseTensor { +public: + typedef Index TensorIndex; + typedef typename Index::value_type UInt; + typedef std::map NZ; + // typedef __gnu_cxx::hash_map > NZ; + // typedef hash_map > NZ; + typedef typename NZ::iterator iterator; + typedef typename NZ::const_iterator const_iterator; + + /** + * SparseTensor constructor from list of bounds. + * The constructed instance is identically zero. + * Each of the integers passed in represents the size of + * this sparse tensor along a given dimension. There + * need to be as many integers passed in as this tensor + * has dimensions. All the integers need to be > 0. + * + * Note: + * This constructor will not work with Index = std::vector + * + * @param ub [UInt >= 0] the size of this tensor along one dimension + */ + explicit inline SparseTensor(UInt ub0, ...) : bounds_(), nz_() { + bounds_[0] = ub0; + va_list indices; + va_start(indices, ub0); + for (UInt k = 1; k < getRank(); ++k) + bounds_[k] = (UInt)va_arg(indices, unsigned int); + va_end(indices); + } + /** - * @b Description - * SparseTensor models a multi-dimensional array, with an arbitrary - * number of dimensions, and arbitrary size for each dimension, - * where only certain elements are not zero. "Not zero" is defined as - * being outside the closed ball [-nupic::Epsilon..nupic::Epsilon]. - * Zero elements are not stored. Non-zero elements are stored in - * a data structure that provides logarithmic insertion and retrieval. - * A number of operations on tensors are implemented as efficiently as - * possible, oftentimes having complexity not worse than the number - * of non-zeros in the tensor. There is no limit to the number of - * dimensions that can be specified for a sparse tensor. - * - * SparseTensor is parameterized on the type of Index used to index - * the non-zeros, and on the type of the non-zeros themselves (Float). - * The numerical type used as the second template parameter needs - * to be functionally equivalent to float, but can be int or double. - * It doesn't work with complex numbers yet (have to modify nearlyZero_ - * to look at the modulus). - * - * The implementation relies on a Unique, Sorted Associative NZ, - * that is map (rather than hash_map, we need the Indices to be sorted). + * SparseTensor constructor from Index that contains the bounds. + * The constructed instance is identically zero. + * The size of the Index becomes the rank of this sparse tensor, + * that is, its number of dimensions. + * The values of each element of the index need to be > 0. + * + * @param bounds [Index] the bounds of each dimension + */ + explicit inline SparseTensor(const Index &bounds) : bounds_(bounds), nz_() {} + + /** + * SparseTensor copy constructor + */ + inline SparseTensor(const SparseTensor &other) + : bounds_(other.getBounds()), nz_() { + this->operator=(other); + } + + /** + * Assignment operator + */ + inline SparseTensor &operator=(const SparseTensor &other) { + if (&other != this) { + bounds_ = other.bounds_; + nz_ = other.nz_; + } + return *this; + } + + /** + * Swaps the contents of two tensors. + * The two tensors need to have the same rank, but they don't + * need to have the same dimensions. + * + * @param B [SparseTensor] the tensor to swap with + */ + inline void swap(SparseTensor &B) { + { NTA_ASSERT(B.getRank() == getRank()); } + + std::swap(bounds_, B.bounds_); + nz_.swap(B.nz_); + } + + /** + * Returns the rank of this tensor. 
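For reviewers who want to sanity-check the reformatted API, here is a minimal construction sketch. It is not taken from this patch: the include paths, the nupic::UInt / nupic::Real aliases, and the instantiation SparseTensor<Index<UInt, 2>, Real> (the fixed-size Index mentioned in the class notes) are assumptions inferred from the surrounding comments.

    #include <nupic/math/Index.hpp>          // assumed header for the fixed-size Index
    #include <nupic/math/SparseTensor.hpp>
    #include <nupic/types/Types.hpp>         // assumed home of nupic::UInt / nupic::Real

    using nupic::Real;
    using nupic::UInt;
    // Hypothetical alias for a rank-2 instantiation, used in the sketches below.
    typedef nupic::SparseTensor<nupic::Index<UInt, 2>, Real> SparseMatrix2D;

    int main() {
      SparseMatrix2D a(3, 4);              // 3x4 tensor, identically zero (variadic constructor)

      nupic::Index<UInt, 2> bounds(3, 4);  // assumed two-argument Index constructor
      SparseMatrix2D b(bounds);            // same bounds via the Index-based constructor

      UInt rank = a.getRank();             // 2
      UInt cols = a.getBound(1);           // 4
      return (rank == 2 && cols == 4) ? 0 : 1;
    }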
+ * The rank is the number of dimensions of this sparse tensor, + * it is an integer >= 1. * * Examples: - * 1) SparseTensor, float>: - * defines a sparse tensor of dimension 2 (a matrix), storing floats. - * The type of Index is the efficient, compile-time sized Index. - * - * 2) SparseTensor, float>: - * defines the same sparse tensor as 1), but using std::vector - * for the index, which is not as fast. - * - * 3) SparseTensor double>: - * defines a sparse tensor of rank 4 (4 dimensions), storing doubles. - * - * @b Responsibility - * An efficient multi-dimensional sparse data structure - * - * @b Rationale - * Numenta algorithms require very large data structure that are - * sparse, and those data structures cannot be handled efficiently - * with contiguous storage in memory. - * - * @b Resource @Ownership - * SparseTensor owns the keys used to index the non-zeros, as - * well as the values of the non-zeros themselves. - * - * @b Notes - * Note 1: in preliminary testing, using Index was - * about 20 times faster than using std::vector. - * - * Note 2: some operations are very slow, depending on the properties - * of the functors used. Watch out that you are using the - * right one for your functor. - * - * Note 3: SparseTensor is limited to max columns, or rows - * or non-zeros. - * - */ - template - class SparseTensor - { - public: - typedef Index TensorIndex; - typedef typename Index::value_type UInt; - typedef std::map NZ; - //typedef __gnu_cxx::hash_map > NZ; - //typedef hash_map > NZ; - typedef typename NZ::iterator iterator; - typedef typename NZ::const_iterator const_iterator; - - /** - * SparseTensor constructor from list of bounds. - * The constructed instance is identically zero. - * Each of the integers passed in represents the size of - * this sparse tensor along a given dimension. There - * need to be as many integers passed in as this tensor - * has dimensions. All the integers need to be > 0. - * - * Note: - * This constructor will not work with Index = std::vector - * - * @param ub [UInt >= 0] the size of this tensor along one dimension - */ - explicit inline SparseTensor(UInt ub0, ...) - : bounds_(), nz_() + * A tensor of rank 0 is a scalar (not possible here). + * A tensor of rank 1 is a vector. + * A tensor of rank 2 is a matrix. + * + * @retval UInt [ > 0 ] the rank of this sparse tensor + */ + inline const UInt getRank() const { return (UInt)bounds_.size(); } + + /** + * Returns the bounds of this tensor, that is the size of this tensor + * along each of its dimensions. + * Tensor indices start at zero along all dimensions. + * The product of the bounds is the total number of elements that + * this sparse tensor can store. + * + * Examples: + * A 3 long vector has bounds Index(3). + * A 10x10 matrix has bounds: Index(10, 10). + * + * @retval Index the upper bound for this sparse tensor + */ + inline const Index getBounds() const { return bounds_; } + + /** + * Returns the upper bound of this sparse tensor along + * the given dimension. + * + * Example: + * A 3x4x5 tensor has: + * - getBound(0) == 3, getBound(1) == 4, getBound(2) == 5. + * + * @param dim [0 <= UInt < getRank()] the dimension + * @retval [UInt >= 0] the upper of this tensor along dim + */ + inline const UInt getBound(const UInt &dim) const { + NTA_ASSERT(0 <= dim && dim < getRank()); + return getBounds()[dim]; + } + + /** + * Returns the domain of this sparse tensor, where the lower bound + * is zero and the upper bound is the upper bound. 
+ * + * Example: + * A 3x2x4 tensor has domain { [0..3), [0..2), [0..4) }. + * + * @retval [Domain] the domain for this tensor + */ + inline Domain getDomain() const { + return Domain(getNewZeroIndex(), getBounds()); + } + + /** + * Returns the total size of this sparse tensor, + * that is, the total number of non-zeros that can be stored. + * It is the product of the bounds. + * + * Example: + * A 3x3 matrix has a size of 9. + * + * @retval UInt [ > 0 ] the size of this sparse tensor + */ + inline const UInt getSizeElts() const { + NTA_ASSERT(!isNull()); + return product(getBounds()); + } + + /** + * Returns the size of a sub-space of this sparse tensor, + * designated by dims. + * + * Example: + * A 3x4 matrix has a size of 4 along the columns and 3 + * along the rows. + */ + template + inline const UInt getSizeElts(const Index2 &dims) const { + { NTA_ASSERT(dims.size() <= getRank()); } + + UInt n = 1; + for (UInt k = 0; k < dims.size(); ++k) + n *= getBound(dims[k]); + return n; + } + + /** + * Returns the number of non-zeros in this sparse tensor. + * + * Invariant: + * getNNonZeros() + getNZeros() == product(getBounds()) + * + * @retval UInt [ >= 0 ] the number of non-zeros in this sparse tensor + */ + inline const UInt getNNonZeros() const { return (UInt)nz_.size(); } + + inline const UInt nNonZeros() const { return (UInt)nz_.size(); } + + /** + * Returns the number of zeros in this sparse tensor. + * + * Invariant: + * getNZeros() + getNNonZeros() == product(getBounds()) + * + * @retval UInt [ >= 0 ] the number of zeros in this sparse tensor + */ + inline const UInt getNZeros() const { return getSizeElts() - getNNonZeros(); } + + /** + * Returns the number of non-zeros in a domain of this sparse tensor. + * Does not work with a domain that has closed dimensions. + * The domain needs to have the same rank as this sparse tensor. + * + * @param dom [Domain] the domain to scan for non-zeros + * @retval UInt [ >= 0 ] the number of non-zeros in dom + */ + inline const UInt getNNonZeros(const Domain &dom) const { { - bounds_[0] = ub0; - va_list indices; - va_start(indices, ub0); - for (UInt k = 1; k < getRank(); ++k) - bounds_[k] = (UInt) va_arg(indices, unsigned int); - va_end(indices); - } - - /** - * SparseTensor constructor from Index that contains the bounds. - * The constructed instance is identically zero. - * The size of the Index becomes the rank of this sparse tensor, - * that is, its number of dimensions. - * The values of each element of the index need to be > 0. - * - * @param bounds [Index] the bounds of each dimension - */ - explicit inline SparseTensor(const Index& bounds) - : bounds_(bounds), nz_() + NTA_ASSERT(!dom.hasClosedDims()); + // NTA_ASSERT(getDomain().includes(dom)); + } + + // I can reduce the domain ub by 1 to find the upper_bound + // but I still have to check for domain inclusion + + UInt nnz = 0; + Index lb = getNewIndex(), ub = getNewIndex(); + + if (dom == getDomain()) + return getNNonZeros(); + + dom.getLB(lb); + dom.getIterationLast(ub); + const_iterator it = begin(); + const_iterator e = end(); + + for (; it != e; ++it) { + if (dom.includes(it->first)) + ++nnz; + } + + return nnz; + } + + /** + * Returns the number of zeros in a domain of this sparse tensor. + * Doens't work if the domain has closed dimensions. + * The domain needs to have the same rank as this sparse tensor. 
+ * + * @param dom [Domain] the domain to scan for zeros + * @retval UInt [ >= 0 ] the number of zeros in dom + */ + inline const UInt getNZeros(const Domain &dom) const { + return dom.size_elts() - getNNonZeros(dom); + } + + /** + * Returns the number of non-zeros in designated sub-spaces of + * this sparse tensor. The sub-spaces are designated by dims. + * The B tensor collects the results. + * + * Complexity: O(number of non-zeros) + * + * Example: + * If A is a 11x13 sparse tensor: + * - A.getNNonZeros(I1(1), B) returns the number of non-zeros + * per row in A, and B is a vector of size 11. + * - A.getNNonZeros(I1(0), B) returns the number of non-zeros + * per column of A, and B is a vector of size 13. + * + * @param dims [Index2] the dimensions along which to count the non-zeros + * @param B [SparseTensor] the sparse tensor of the number + * of non-zeros per sub-space + */ + template + inline void getNNonZeros(const Index2 &dims, + SparseTensor &B) const { + { NTA_ASSERT(dims.size() + B.getRank() == getRank()); } + + B.clear(); + + IndexB compDims = B.getNewIndex(), idxB = B.getNewIndex(); + complement(dims, compDims); + + const_iterator it, e; + for (it = begin(), e = end(); it != e; ++it) { + project(compDims, it->first, idxB); + B.update(idxB, (Float)1, std::plus()); + } + } + + /** + * Returns the number of zeros in designated sub-spaces of + * this sparse tensor. See getNNonZeros doc. + */ + template + inline void getNZeros(const Index2 &dims, + SparseTensor &B) const { + { NTA_ASSERT(dims.size() + B.getRank() == getRank()); } + + IndexB compDims = B.getNewIndex(), idxB = B.getNewIndex(); + complement(dims, compDims); + + B.setAll((Float)getSizeElts(dims)); + + const_iterator it, e; + for (it = begin(), e = end(); it != e; ++it) { + project(compDims, it->first, idxB); + B.update(idxB, (Float)1, std::minus()); + } + } + + /** + * Returns true if this SparseTensor is the "empty" tensor, + * that is, a SparseTensor with no value (like a matrix without + * rows). + */ + inline bool isNull() const { return product(getBounds()) == 0; } + + /** + * Returns true if there is no non-zero in this tensor, false otherwise. + * + * @retval bool whether this sparse tensor is identically zero or not + */ + inline bool isZero() const { return getNNonZeros() == 0; } + + /** + * Returns true if the domain inside this sparse tensor is identically + * zero. + * Doens't work if the domain has closed dimensions. + * The domain needs to have the same rank as this sparse tensor. + * + * @param dom [Domain] the domain to look at + * @retval bool whether this sparse tensor is zero inside dom + */ + inline bool isZero(const Domain &dom) const { + return getNNonZeros(dom) == 0; + } + + /** + * Returns true if there are no zeros in this tensor, false otherwise. + * The tensor is dense if it contains no zero. + * + * @retval bool whether this tensor is dense or not + */ + inline bool isDense() const { return getNNonZeros() == getSizeElts(); } + + /** + * Returns true if the domain inside this sparse tensor is dense. + * Doens't work if the domain has closed dimensions. + * The domain needs to have the same rank as this sparse tensor. + * + * @param dom [Domain] the domain to look at + * @retval bool whether this sparse tensor is dense inside dom + */ + inline bool isDense(const Domain &dom) const { + return getNNonZeros(dom) == dom.size_elts(); + } + + /** + * Returns true if there are zeros in this tensor, false otherwise. + * The tensor is sparse if it contains at least one zero. 
+ * + * @retval bool whether this tensor is sparse or not + */ + inline bool isSparse() const { return getNNonZeros() != getSizeElts(); } + + /** + * Returns true if the domain inside this sparse tensor is sparse. + * Doens't work if the domain has closed dimensions. + * The domain needs to have the same rank as this sparse tensor. + * + * @param dom [Domain] the domain to look at + * @retval bool whether this sparse tensor is sparse inside dom + */ + inline bool isSparse(const Domain &dom) const { + return getNNonZeros(dom) != dom.size_elts(); + } + + /** + * Returns the fill rate for this tensor, that is, the ratio of the + * number of non-zeros to the total number of elements in this tensor. + * + * @retval Float the fill rate + */ + inline const Float getFillRate() const { + return Float(getNNonZeros()) / Float(getSizeElts()); + } + + /** + * Returns the fill rate for this tensor inside the given domain, that is, + * the ratio of the number of non-zeros in the given domain to the + * size of the domain. + * + * @retval Float the fill rate inside the given domain + */ + inline const Float getFillRate(const Domain &dom) const { + return Float(getNNonZeros(dom)) / Float(dom.size_elts()); + } + + /** + * Returns the fill rate for sub-spaces of this sparse tensor. + */ + template + inline void getFillRate(const Index2 &dims, + SparseTensor &B) const { + getNNonZeros(dims, B); + B.element_apply_fast( + bind2nd(std::divides(), (Float)getSizeElts(dims))); + } + + /** + * Returns whether this sparse tensor is positive or not, that is, + * whether all its coefficients are > nupic::Epsilon (there are no + * zeros in this tensor, and all the elements have positive values). + * + * Complexity: O(number of non-zeros) + */ + inline bool isPositive() const { + if (getNZeros() > 0) + return false; + + const_iterator it, e; + for (it = begin(), e = end(); it != e; ++it) + if (strictlyNegative(it->second)) + return false; + return true; + } + + /** + * Returns whether this sparse tensor is non-negative or not, + * that is, whether all its coefficients are >= -nupic::Epsilon + * (there can be zeros in this tensor, but all the non-zeros + * have positive values). + * + * Complexity: O(number of non-zeros) + */ + inline bool isNonNegative() const { + if (nz_.empty()) + return true; + + const_iterator it, e; + for (it = begin(), e = end(); it != e; ++it) + if (strictlyNegative(it->second)) + return false; + return true; + } + + /** + * Returns the set of values in this SparseTensor and how + * many times each of them appears. + * + * Complexity: O(number of non-zeros) with some log for + * the insertion in the result map... + */ + inline std::map values() const { + std::map vals; + + if (!isDense()) + vals[0] = getNZeros(); + + const_iterator it, e; + typename std::map::iterator found; + for (it = begin(), e = end(); it != e; ++it) { + found = vals.find(it->second); + if (found == vals.end()) + vals[it->second] = 1; + else + ++vals[it->second]; + } + + return vals; + } + + /** + * Makes this tensor the tensor zero, that is, all the non-zeros + * are removed. + */ + inline void clear() { nz_.clear(); } + + /** + * Creates a new Index that has the rank of this sparse tensor. + * The initial value of this Index is the bounds of this tensor. + * + * Note: + * To accomodate both Index and std::vector as + * indices, we can't allocate memory ourselves, so when we + * need an index, we create a copy of the bounds, and either + * do nothing, or set it to zero, or set to some specified + * set of values. 
+ * + * @retval Index a new Index, that contains the values of the bounds + * for this sparse tensor + */ + inline Index getNewIndex() const { return getBounds(); } + + /** + * Creates a new Index that has the rank of this sparse tensor + * and sets it to zero (see note in getNewIndex()). + * + * @retval Index a new Index, initialized to zero + */ + inline Index getNewZeroIndex() const { + Index idx = getBounds(); + setToZero(idx); + return idx; + } + + /** + * Creates a new Index that has the rank of this sparse tensor + * and sets it to the specified values (see note in getNewIndex()). + * + * @retval Index a new Index, initialized to the values passed + */ + inline Index getNewIndex(UInt i0, ...) const { + Index idx = getBounds(); + idx[0] = i0; + va_list indices; + va_start(indices, i0); + for (UInt k = 1; k < getRank(); ++k) + idx[k] = (UInt)va_arg(indices, unsigned int); + va_end(indices); + return idx; + } + + /** + * Computes whether this tensor is symmetric or not. + * A tensor is symmetric w.r.t. a permutation of the dimensions iff: + * A[ijkl...] = A[permutation(ijkl...)]. + * This implies that the bounds of the permuted dimensions need to + * be the same. If they are not, the tensor is not symmetric. + * The Index passed in needs to have the same size as the rank + * of this sparse tensor. + * + * Complexity: O(number of non-zeros) + * + * @param perm [Index] the permutation to use to evaluate whether + * this sparse tensor is symmetric or not + * @retval bool whether this sparse tensor is symmetric w.r.t. the + * given permutation + */ + inline bool isSymmetric(const Index &perm) const { { + NTA_ASSERT(perm.size() == getRank()); + NTA_ASSERT(isSet(perm)); + } + + Index idx2 = getNewZeroIndex(); + + nupic::permute(perm, bounds_, idx2); + if (bounds_ != idx2) + return false; + + const_iterator it, e; + + for (it = begin(), e = end(); it != e; ++it) { + nupic::permute(perm, it->first, idx2); + if (!nearlyZero_(it->second - get(idx2))) + return false; } - /** - * SparseTensor copy constructor - */ - inline SparseTensor(const SparseTensor& other) - : bounds_(other.getBounds()), nz_() + return true; + } + + /** + * Computes whether this tensor is anti-symmetric or not. + * A tensor is anti-symmetric w.r.t. to a permutation of the + * dimensions iff: + * A[ijkl...] = -A[permutation(ijkl...)] + * This implies that the upper bounds of the permuted dimensions + * need to be the same, or the tensor is not anti-symmetric. + * The Index passed in needs to have the same size as the rank + * of this sparse tensor. + * + * Complexity: O(number of non-zeros) + * + * @param perm [Index] the permutation to use to evaluate anti-symmetry + * @retval bool whether this sparse tensor is anty-symmetric w.r.t. + * the given permutation or not + */ + inline bool isAntiSymmetric(const Index &perm) const { { - this->operator=(other); + NTA_ASSERT(perm.size() == getRank()); + NTA_ASSERT(isSet(perm)); + } + + Index idx2 = getNewZeroIndex(); + + nupic::permute(perm, bounds_, idx2); + if (bounds_ != idx2) + return false; + + const_iterator it, e; + + for (it = begin(), e = end(); it != e; ++it) { + nupic::permute(perm, it->first, idx2); + if (!nearlyZero_(it->second + get(idx2))) + return false; } - - /** - * Assignment operator - */ - inline SparseTensor& operator=(const SparseTensor& other) + + return true; + } + + /** + * Sets the element at idx to val. Handles zeros by not storing + * them, or by erasing non-zeros that become zeros when val = 0. + * The Index idx needs to be >= 0 and < getBounds(). 
+ * + * Complexity: O(log(number of non-zeros)) + * + * @param idx [Index] the index of the element to set + * @param val [Float] the value to set for the element at index + */ + inline void set(const Index &idx, const Float &val) { { - if (&other != this) { - bounds_ = other.bounds_; - nz_ = other.nz_; + NTA_ASSERT(positiveInBounds(idx, getBounds())) + << "SparseTensor::set(idx, val): " + << "Invalid index: " << idx + << " - Should be >= 0 and strictly less than: " << bounds_; + } + + if (nearlyZero_(val)) { + iterator it = nz_.find(idx); + if (it != end()) + nz_.erase(it); + } else + nz_[idx] = val; + } + + /** + * Sets the element at idx to val. Calls set(Index, Float). + */ + inline void set(UInt i0, ...) { + Index idx = getNewIndex(); + idx[0] = i0; + va_list indices; + va_start(indices, i0); + for (UInt k = 1; k < getRank(); ++k) + idx[k] = (UInt)va_arg(indices, unsigned int); + const Float val = (Float)va_arg(indices, double); + va_end(indices); + set(idx, val); + } + + /** + * Sets all the elements inside the dom to val. + * Handles zeros correctly (i.e. does not store them). + * + * @param dom [Domain] the domain inside which to set values + * @param val [Float] the value to set inside dom + */ + inline void set(const Domain &dom, const Float &val) { + if (nearlyZero_(val)) { + setZero(dom); + } else { + setNonZero(dom, val); + } + } + + /** + * Sets the element at idx to zero, that is, removes it + * from the internal storage. + * + * Complexity: O(log(number of non-zeros)) + * + * @param idx [Index] the index of the element to set to zero + */ + inline void setZero(const Index &idx) { + { + NTA_ASSERT(positiveInBounds(idx, getBounds())) + << "SparseTensor::setZero(idx): " + << "Invalid index: " << idx + << " - Should be >= 0 and strictly less than: " << bounds_; + } + + iterator it = nz_.find(idx); + if (it != end()) + nz_.erase(it); + } + + /** + * Sets the element at idx to zero. Calls setZero(Index). + */ + inline void setZero(UInt i0, ...) { + Index idx = getNewIndex(); + idx[0] = i0; + va_list indices; + va_start(indices, i0); + for (UInt k = 1; k < getRank(); ++k) + idx[k] = (UInt)va_arg(indices, unsigned int); + va_end(indices); + setZero(idx); + } + + /** + * Sets to zero all the elements in Domain dom. + * + * @param dom [Domain] the domain to set to zero + */ + inline void setZero(const Domain &dom) { + { + NTA_ASSERT(getDomain().includes(dom)) + << "SparseTensor::setZero(Domain): " + << "Domain argument: " << dom << " is invalid" + << " - Should be included in: " << getDomain(); + } + + iterator it = begin(), d, e = end(); + while (it != e) { + if (dom.includes(it->first)) { + d = it; + ++it; // increment before erase + nz_.erase(d); + } else { + ++it; } - return *this; + } + } + + /** + * Sets element at idx to val, where |val| > nupic::Epsilon. + * + * Use if you know what you do: even f(non-zero, non-zero) + * can be "zero", if it falls below nupic::Epsilon. 
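To make the zero handling described above concrete, a short sketch follows (same assumed includes and aliases as the earlier construction sketch; with the variadic set() overload the trailing value is read as a double, and the accessors used are the ones documented in this file):

    SparseMatrix2D t(3, 4);

    t.set(0u, 1u, 2.5);           // stores a non-zero at (0, 1)
    t.set(0u, 1u, 0.0);           // value within nupic::Epsilon of zero: the entry is erased
    t.set(2u, 3u, 7.0);

    Real v = t.get(2u, 3u);       // 7.0
    bool z = t.isZero(0u, 1u);    // true, nothing is stored for (0, 1)

    t.setZero(2u, 3u);            // removes the remaining non-zero
    UInt nnz = t.getNNonZeros();  // 0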
+ * + * Complexity: O(log(number of non-zeros)) + * + * @param idx [Index] the index of the element to set to val + * @param val [Float] the value to set for the element at idx + */ + inline void setNonZero(const Index &idx, const Float &val) { + { + NTA_ASSERT(positiveInBounds(idx, getBounds())) + << "SparseTensor::setNonZero(idx, val): " + << "Invalid index: " << idx + << " - Should be >= 0 and strictly less than: " << bounds_; + + NTA_ASSERT(!nearlyZero_(val)) + << "SparseTensor::setNonZero(idx, val): " + << "Invalid zero value: " << val << " at index: " << idx + << " - Should be non-zero (> " << nupic::Epsilon << ")"; } - /** - * Swaps the contents of two tensors. - * The two tensors need to have the same rank, but they don't - * need to have the same dimensions. - * - * @param B [SparseTensor] the tensor to swap with - */ - inline void swap(SparseTensor& B) + nz_[idx] = val; + } + + /** + * Sets all the values inside dom to val. + * Works only if |val| > nupic::Epsilon. + * + * @param dom [Domain] the domain inside which to set values + * @param val [Float] the value to set inside dom + */ + inline void setNonZero(const Domain &dom, const Float &val) { + { NTA_ASSERT(!nearlyZero_(val)); } + + Index lb = getNewIndex(), ub = getNewIndex(), idx = getNewIndex(); + dom.getLB(lb); + dom.getUB(ub); + + idx = lb; + do { + setNonZero(idx, val); + } while (increment(lb, ub, idx)); + } + + /** + * Updates the value of this tensor at idx in place, using f and val: + * A[idx] = f(A[idx], val) (val as second argument). + * + * Handles zeros properly. + * + * Complexity: O(log(number of non-zeros)) + */ + template + inline Float update(const Index &idx, const Float &val, binary_functor f) { { - { - NTA_ASSERT(B.getRank() == getRank()); + NTA_ASSERT(positiveInBounds(idx, getBounds())) + << "SparseTensor::update(idx, val, f(x, y)): " + << "Invalid index: " << idx + << " - Should be >= 0 and strictly less than: " << bounds_; + } + + Float res = 0; + + iterator it = nz_.find(idx); + if (it != end()) { + res = f(it->second, val); + if (nearlyZero_(res)) + nz_.erase(it); + else + it->second = res; + } else { + res = f(0, val); + if (!nearlyZero_(res)) + nz_[idx] = res; + } + + return res; + } + + /** + * TODO: unit test + */ + inline void add(const Index &idx, const Float &val) { + std::pair r = nz_.insert(std::make_pair(idx, val)); + + if (!r.second) + r.first->second += val; + } + + /** + * Sets all the values in this tensor to val. + * Makes this sparse tensor dense if |val| > nupic::Epsilon. + * Otherwise, removes all the values in this sparse tensor + * + * Complexity: O(product of bounds) (worst case, if |val| > nupic::Epsilon) + */ + inline void setAll(const Float &val) { + if (nearlyZero_(val)) { + nz_.clear(); + return; + } + + Index idx = getNewZeroIndex(); + do { + nz_[idx] = val; + } while (increment(bounds_, idx)); + } + + /** + * Returns the value of the element at idx. + * + * Complexity: O(log(number of non-zeros)) + */ + inline Float get(const Index &idx) const { + { + NTA_ASSERT(positiveInBounds(idx, getBounds())) + << "SparseTensor::get(idx): " + << "Invalid index: " << idx + << " - Should be >= 0 and strictly less than: " << bounds_; + } + + const_iterator it = nz_.find(idx); + if (it != nz_.end()) + return it->second; + else + return Float(0); + } + + /** + * Returns the value of the element at idx. + * Calls get(Index). + */ + inline const Float get(UInt i0, ...) 
{ + Index idx = getNewIndex(); + idx[0] = i0; + va_list indices; + va_start(indices, i0); + for (UInt k = 1; k < getRank(); ++k) + idx[k] = (UInt)va_arg(indices, unsigned int); + va_end(indices); + return get(idx); + } + + /** + * Returns the element at idx. + * Calls get(Index). + * not providing the other one, because we need to control for zero, + * we can't just blindly pass a reference + */ + inline const Float operator()(UInt i0, ...) const { + Index idx = getNewIndex(); + idx[0] = i0; + va_list indices; + va_start(indices, i0); + for (UInt k = 1; k < getRank(); ++k) + idx[k] = (UInt)va_arg(indices, unsigned int); + va_end(indices); + return get(idx); + } + + /** + * Extract sub-spaces along dimension dim, and put result + * in B. Only the non-zeros who dim-th coordinate is in ind + * are kept and stored in B. + * + * This operation reduces the size of this SparseTensor + * along dimension dim to the number of elements in ind. + * This operation reduces the size of this SparseTensor + * along dimension dim to the number of elements in ind. + * If ind is full, returns this SparseTensor unmodified. + * If ind is empty, throws an exception, because I can't + * reduce a SparseTensor to have a size along any dimension + * of zero. You need to detect that ind is empty before + * calling reduce. + * + * Returns the null tensor (not a tensor) if ind is empty, + * i.e. + * + * dim = 0 indicates that some rows will be removed... + */ + inline void extract(UInt dim, const std::set &ind, + SparseTensor &B) const { + { + NTA_ASSERT(&B != this) << "SparseTensor::extract(): " + << "Cannot extract to self"; + + NTA_ASSERT(0 <= dim && dim < getRank()) + << "SparseTensor::extract(): " + << "Invalid dimension: " << dim + << " - Should be between 0 and rank = " << getRank(); + + typename std::set::const_iterator i, e; + for (i = ind.begin(), e = ind.end(); i != e; ++i) + NTA_ASSERT(0 <= *i && *i < getBound(dim)) + << "SparseTensor::extract(): " + << "Invalid set member: " << *i + << " - Should be between 0 and bound (" << getBound(dim) + << ") for dim: " << dim; + } + + if (ind.empty()) { + B.bounds_[dim] = 0; + return; + } + + if (ind.size() == getBound(dim)) { + B = *this; + return; + } + + B.clear(); + Index bounds = getNewIndex(); + bounds[dim] = (UInt)ind.size(); + B.resize(bounds); + + std::vector ind_v(getBound(dim)); + { + UInt j = 0; + typename std::set::const_iterator i, e; + for (j = 0, i = ind.begin(), e = ind.end(); i != e; ++i, ++j) + ind_v[*i] = j; + } + + const_iterator i, e; + for (i = begin(), e = end(); i != e; ++i) + if (ind.find(i->first[dim]) != ind.end()) { + Index idx = i->first; + idx[dim] = ind_v[idx[dim]]; + B.setNonZero(idx, i->second); } - - std::swap(bounds_, B.bounds_); - nz_.swap(B.nz_); - } - - /** - * Returns the rank of this tensor. - * The rank is the number of dimensions of this sparse tensor, - * it is an integer >= 1. - * - * Examples: - * A tensor of rank 0 is a scalar (not possible here). - * A tensor of rank 1 is a vector. - * A tensor of rank 2 is a matrix. - * - * @retval UInt [ > 0 ] the rank of this sparse tensor - */ - inline const UInt getRank() const { return (UInt)bounds_.size(); } - - /** - * Returns the bounds of this tensor, that is the size of this tensor - * along each of its dimensions. - * Tensor indices start at zero along all dimensions. - * The product of the bounds is the total number of elements that - * this sparse tensor can store. - * - * Examples: - * A 3 long vector has bounds Index(3). - * A 10x10 matrix has bounds: Index(10, 10). 
- * - * @retval Index the upper bound for this sparse tensor - */ - inline const Index getBounds() const { return bounds_; } - - /** - * Returns the upper bound of this sparse tensor along - * the given dimension. - * - * Example: - * A 3x4x5 tensor has: - * - getBound(0) == 3, getBound(1) == 4, getBound(2) == 5. - * - * @param dim [0 <= UInt < getRank()] the dimension - * @retval [UInt >= 0] the upper of this tensor along dim - */ - inline const UInt getBound(const UInt& dim) const - { - NTA_ASSERT(0 <= dim && dim < getRank()); - return getBounds()[dim]; - } - - /** - * Returns the domain of this sparse tensor, where the lower bound - * is zero and the upper bound is the upper bound. - * - * Example: - * A 3x2x4 tensor has domain { [0..3), [0..2), [0..4) }. - * - * @retval [Domain] the domain for this tensor - */ - inline Domain getDomain() const + } + + /** + * In place (mutating) reduce. + * + * Keeps only the sub-spaces (rows or columns for a matrix) + * whose coordinate is a member of ind. Reduces the size + * of the tensor along dimension dim to the size of ind. + * Yields the null tensor if ind is empty. + * Does not change the tensor if ind is full. + * + * Examples: + * S2.reduce(0, set(1, 3)) keeps rows 1 and 3 of the matrix, + * eliminates the other rows. + * S2.reduce(1, set(1, 3)) keeps columns 1 and 3 of the matrix, + * eliminates the other columns. + */ + inline void reduce(UInt dim, const std::set &ind) { { - return Domain(getNewZeroIndex(), getBounds()); - } - - /** - * Returns the total size of this sparse tensor, - * that is, the total number of non-zeros that can be stored. - * It is the product of the bounds. - * - * Example: - * A 3x3 matrix has a size of 9. - * - * @retval UInt [ > 0 ] the size of this sparse tensor - */ - inline const UInt getSizeElts() const - { - NTA_ASSERT(!isNull()); - return product(getBounds()); - } - - /** - * Returns the size of a sub-space of this sparse tensor, - * designated by dims. - * - * Example: - * A 3x4 matrix has a size of 4 along the columns and 3 - * along the rows. - */ - template - inline const UInt getSizeElts(const Index2& dims) const + NTA_ASSERT(0 <= dim && dim < getRank()) + << "SparseTensor::reduce(): " + << "Invalid dimension: " << dim + << " - Should be between 0 and rank = " << getRank(); + + typename std::set::const_iterator i, e; + for (i = ind.begin(), e = ind.end(); i != e; ++i) + NTA_ASSERT(0 <= *i && *i < getBound(dim)) + << "SparseTensor::reduce(): " + << "Invalid set member: " << *i + << " - Should be between 0 and bound (" << getBound(dim) + << ") for dim: " << dim; + } + + if (ind.empty()) { + bounds_[dim] = 0; + return; + } + + if (ind.size() == getBound(dim)) + return; + + std::vector ind_v(getBound(dim)); { - { - NTA_ASSERT(dims.size() <= getRank()); + UInt j = 0; + typename std::set::const_iterator i, e; + for (j = 0, i = ind.begin(), e = ind.end(); i != e; ++i, ++j) + ind_v[*i] = j; + } + + NZ keep; + + iterator i, e; + for (i = begin(), e = end(); i != e; ++i) + if (ind.find(i->first[dim]) != ind.end()) { + Index idx = i->first; + idx[dim] = ind_v[idx[dim]]; + keep[idx] = i->second; + } + + nz_ = keep; + + Index bounds = getNewIndex(); + bounds[dim] = (UInt)ind.size(); + resize(bounds); + } + + /** + * Extract slices or sub-arrays from this tensor, of + * any dimensions, and within any bounds. Slices can be + * of any dimension <= getRank(). Slices are not allocated + * by this function, so an optional clearYesNo parameter + * is provided to remove all the existing non-zeros from + * the slice (B). 
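A sketch of the extract/reduce pair documented above (same assumed aliases; std::set<UInt> is assumed for the coordinate-set parameter, whose element type is not visible in this hunk, and <set> must be included):

    SparseMatrix2D s2(5, 4);      // 5x4 matrix
    s2.set(1u, 0u, 1.0);
    s2.set(3u, 2u, 2.0);
    s2.set(4u, 2u, 3.0);

    std::set<UInt> rows;          // coordinates to keep along dimension 0
    rows.insert(1);
    rows.insert(3);

    SparseMatrix2D kept(5, 4);    // extract() resizes this to 2x4: rows 1 and 3 of s2
    s2.extract(0, rows, kept);    // non-mutating copy into another tensor

    s2.reduce(0, rows);           // mutating form: s2 itself becomes 2x4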
When the range is not as big as the + * cartesian product of the bounds of this tensor, + * a sub-array tensor is extracted. + * + * Examples: + * If A has dimensions [(0, 10), (0, 20), (0, 30)] + * - slice with Domain = ((0, 0, 0), (10, 20, 30)) gives A + * - slice with Domain = ((0, 0, 10), (10, 20, 30)) gives + * the subarray of A reduced to indices 10 to 29 along + * the third dimension + * - slice with Domain = ((0, 0, 10), (10, 20, 10)) gives + * the matrix (0,10) by (0, 20) obtained when the third index + * is blocked to 10. + * + * Complexity: O(number of non-zeros in slice) + * + * @param range [Domain] the range to extract from this + * tensor. + * @param B [SparseTensor] the resulting + * slice + * @param clearYesNo [bool] whether to clear B before + * slicing or not + */ + template + inline void getSlice(const Domain &range, + SparseTensor &B) const { + { + NTA_ASSERT(range.rank() == getRank()); + + NTA_ASSERT(B.getRank() == range.getNOpenDims()) + << "SparseTensor::slice(): " + << "Invalid range: " << range + << " - Range should have a number of open dims" + << " equal to the rank of the slice (" << B.getRank() << ")"; + } + + // Always clear, so we extract a zero slice + // if we don't hit any non-zero + B.clear(); + + IndexB sliceIndex = B.getNewIndex(), openDims = B.getNewIndex(); + range.getOpenDims(openDims); + const_iterator it = begin(); + const_iterator e = end(); + + for (; it != e; ++it) + if (range.includes(it->first)) { + project(openDims, it->first, sliceIndex); + for (UInt k = 0; k < B.getRank(); ++k) + sliceIndex[k] -= range[openDims[k]].getLB(); + B.set(sliceIndex, it->second); } + } - UInt n = 1; - for (UInt k = 0; k < dims.size(); ++k) - n *= getBound(dims[k]); - return n; - } - - /** - * Returns the number of non-zeros in this sparse tensor. - * - * Invariant: - * getNNonZeros() + getNZeros() == product(getBounds()) - * - * @retval UInt [ >= 0 ] the number of non-zeros in this sparse tensor - */ - inline const UInt getNNonZeros() const - { - return (UInt)nz_.size(); - } - - inline const UInt nNonZeros() const - { - return (UInt)nz_.size(); - } - - /** - * Returns the number of zeros in this sparse tensor. - * - * Invariant: - * getNZeros() + getNNonZeros() == product(getBounds()) - * - * @retval UInt [ >= 0 ] the number of zeros in this sparse tensor - */ - inline const UInt getNZeros() const + template + inline void setSlice(const Domain &range, + const SparseTensor &B) { { - return getSizeElts() - getNNonZeros(); - } - - /** - * Returns the number of non-zeros in a domain of this sparse tensor. - * Does not work with a domain that has closed dimensions. - * The domain needs to have the same rank as this sparse tensor. 
- * - * @param dom [Domain] the domain to scan for non-zeros - * @retval UInt [ >= 0 ] the number of non-zeros in dom - */ - inline const UInt getNNonZeros(const Domain& dom) const - { - { - NTA_ASSERT(!dom.hasClosedDims()); - //NTA_ASSERT(getDomain().includes(dom)); - } - - // I can reduce the domain ub by 1 to find the upper_bound - // but I still have to check for domain inclusion + NTA_ASSERT(range.rank() == getRank()); - UInt nnz = 0; - Index lb = getNewIndex(), ub = getNewIndex(); + NTA_ASSERT(B.getRank() == range.getNOpenDims()) + << "SparseTensor::setSlice(): " + << "Invalid range: " << range + << " - Range should have a number of open dims" + << " equal to the rank of the slice (" << B.getRank() << ")"; + } - if (dom == getDomain()) - return getNNonZeros(); + // If the slice is empty, call setZero on range + // (processing below is based on non-zeros exclusively) + if (B.isZero()) { + setZero(range); + return; + } - dom.getLB(lb); dom.getIterationLast(ub); - const_iterator it = begin(); - const_iterator e = end(); + Index idx = getNewIndex(); + IndexB openDims = B.getNewIndex(); + for (UInt i = 0; i < range.rank(); ++i) + if (range[i].empty()) + idx[range[i].getDim()] = range[i].getLB(); + range.getOpenDims(openDims); + typename SparseTensor::const_iterator it, e; + it = B.begin(); + e = B.end(); - for (; it != e; ++it) { - if (dom.includes(it->first)) - ++nnz; - } + for (; it != e; ++it) { + embed(openDims, it->first, idx); + for (UInt k = 0; k < B.getRank(); ++k) + idx[k] += range[openDims[k]].getLB(); + set(idx, it->second); + } + } - return nnz; + template + inline void toList(OutputIterator1 indices, OutputIterator2 values) const { + for (const_iterator it = begin(); it != end(); ++it) { + *indices = it->first; + *values = it->second; + ++indices; + ++values; } + } - /** - * Returns the number of zeros in a domain of this sparse tensor. - * Doens't work if the domain has closed dimensions. - * The domain needs to have the same rank as this sparse tensor. - * - * @param dom [Domain] the domain to scan for zeros - * @retval UInt [ >= 0 ] the number of zeros in dom - */ - inline const UInt getNZeros(const Domain& dom) const - { - return dom.size_elts() - getNNonZeros(dom); - } - - /** - * Returns the number of non-zeros in designated sub-spaces of - * this sparse tensor. The sub-spaces are designated by dims. - * The B tensor collects the results. - * - * Complexity: O(number of non-zeros) - * - * Example: - * If A is a 11x13 sparse tensor: - * - A.getNNonZeros(I1(1), B) returns the number of non-zeros - * per row in A, and B is a vector of size 11. - * - A.getNNonZeros(I1(0), B) returns the number of non-zeros - * per column of A, and B is a vector of size 13. - * - * @param dims [Index2] the dimensions along which to count the non-zeros - * @param B [SparseTensor] the sparse tensor of the number - * of non-zeros per sub-space - */ - template - inline void getNNonZeros(const Index2& dims, SparseTensor& B) const + /** + * Returns whether the element at idx is zero or not. 
+ * + * Complexity: O(log(number of non-zeros)) + */ + inline bool isZero(const Index &idx) const { { - { - NTA_ASSERT(dims.size() + B.getRank() == getRank()); - } + NTA_ASSERT(positiveInBounds(idx, getBounds())) + << "SparseTensor::isZero(idx): " + << "Invalid index: " << idx + << " - Should be >= 0 and strictly less than: " << bounds_; + } - B.clear(); + return nz_.find(idx) == nz_.end(); + } - IndexB compDims = B.getNewIndex(), idxB = B.getNewIndex(); - complement(dims, compDims); - - const_iterator it, e; - for (it = begin(), e = end(); it != e; ++it) { - project(compDims, it->first, idxB); - B.update(idxB, (Float)1, std::plus()); - } - } + /** + * Returns whether the element at idx is zero or not. + * Calls isZero(Index). + */ + inline bool isZero(UInt i0, ...) const { + Index idx = getNewIndex(); + idx[0] = i0; + va_list indices; + va_start(indices, i0); + for (UInt k = 1; k < getRank(); ++k) + idx[k] = (UInt)va_arg(indices, unsigned int); + va_end(indices); + return isZero(idx); + } - /** - * Returns the number of zeros in designated sub-spaces of - * this sparse tensor. See getNNonZeros doc. - */ - template - inline void getNZeros(const Index2& dims, SparseTensor& B) const - { - { - NTA_ASSERT(dims.size() + B.getRank() == getRank()); - } - - IndexB compDims = B.getNewIndex(), idxB = B.getNewIndex(); - complement(dims, compDims); - - B.setAll((Float)getSizeElts(dims)); + /** + * Copies this sparse tensor to the given array of Floats. + * Sets the array to zero first, then copies only the non-zeros. + * + * Complexity: O(number of non-zeros) + */ + template inline void toDense(OutIter array) const { + { NTA_ASSERT(!isNull()); } - const_iterator it, e; - for (it = begin(), e = end(); it != e; ++it) { - project(compDims, it->first, idxB); - B.update(idxB, (Float)1, std::minus()); - } - } + memset(array, 0, product(getBounds()) * sizeof(Float)); + const_iterator it, e; + for (it = begin(), e = end(); it != e; ++it) + *(array + ordinal(getBounds(), it->first)) = it->second; + } - /** - * Returns true if this SparseTensor is the "empty" tensor, - * that is, a SparseTensor with no value (like a matrix without - * rows). - */ - inline bool isNull() const - { - return product(getBounds()) == 0; - } - - /** - * Returns true if there is no non-zero in this tensor, false otherwise. - * - * @retval bool whether this sparse tensor is identically zero or not - */ - inline bool isZero() const - { - return getNNonZeros() == 0; - } - - /** - * Returns true if the domain inside this sparse tensor is identically - * zero. - * Doens't work if the domain has closed dimensions. - * The domain needs to have the same rank as this sparse tensor. - * - * @param dom [Domain] the domain to look at - * @retval bool whether this sparse tensor is zero inside dom - */ - inline bool isZero(const Domain& dom) const - { - return getNNonZeros(dom) == 0; - } - - /** - * Returns true if there are no zeros in this tensor, false otherwise. - * The tensor is dense if it contains no zero. - * - * @retval bool whether this tensor is dense or not - */ - inline bool isDense() const - { - return getNNonZeros() == getSizeElts(); - } - - /** - * Returns true if the domain inside this sparse tensor is dense. - * Doens't work if the domain has closed dimensions. - * The domain needs to have the same rank as this sparse tensor. 
- * - * @param dom [Domain] the domain to look at - * @retval bool whether this sparse tensor is dense inside dom - */ - inline bool isDense(const Domain& dom) const - { - return getNNonZeros(dom) == dom.size_elts(); - } - - /** - * Returns true if there are zeros in this tensor, false otherwise. - * The tensor is sparse if it contains at least one zero. - * - * @retval bool whether this tensor is sparse or not - */ - inline bool isSparse() const - { - return getNNonZeros() != getSizeElts(); - } - - /** - * Returns true if the domain inside this sparse tensor is sparse. - * Doens't work if the domain has closed dimensions. - * The domain needs to have the same rank as this sparse tensor. - * - * @param dom [Domain] the domain to look at - * @retval bool whether this sparse tensor is sparse inside dom - */ - inline bool isSparse(const Domain& dom) const - { - return getNNonZeros(dom) != dom.size_elts(); - } - - /** - * Returns the fill rate for this tensor, that is, the ratio of the - * number of non-zeros to the total number of elements in this tensor. - * - * @retval Float the fill rate - */ - inline const Float getFillRate() const - { - return Float(getNNonZeros()) / Float(getSizeElts()); - } - - /** - * Returns the fill rate for this tensor inside the given domain, that is, - * the ratio of the number of non-zeros in the given domain to the - * size of the domain. - * - * @retval Float the fill rate inside the given domain - */ - inline const Float getFillRate(const Domain& dom) const - { - return Float(getNNonZeros(dom)) / Float(dom.size_elts()); - } + /** + * Copies the non-zeros from array into this sparse tensor. + * Clears this tensor first if the flag is true. + * + * Complexity: O(size * log(size)) ?? + */ + template + inline void fromDense(InIter array, bool clearYesNo = true) { + { NTA_ASSERT(!isNull()); } - /** - * Returns the fill rate for sub-spaces of this sparse tensor. - */ - template - inline void getFillRate(const Index2& dims, SparseTensor& B) const - { - getNNonZeros(dims, B); - B.element_apply_fast(bind2nd(std::divides(), (Float)getSizeElts(dims))); - } - - /** - * Returns whether this sparse tensor is positive or not, that is, - * whether all its coefficients are > nupic::Epsilon (there are no - * zeros in this tensor, and all the elements have positive values). - * - * Complexity: O(number of non-zeros) - */ - inline bool isPositive() const - { - if (getNZeros() > 0) - return false; + if (clearYesNo) + clear(); - const_iterator it, e; - for (it = begin(), e = end(); it != e; ++it) - if (strictlyNegative(it->second)) - return false; - return true; + Index idx = getNewIndex(); + const UInt M = product(getBounds()); + for (UInt i = 0; i < M; ++i) { + setFromOrdinal(bounds_, i, idx); + set(idx, *array++); } + } - /** - * Returns whether this sparse tensor is non-negative or not, - * that is, whether all its coefficients are >= -nupic::Epsilon - * (there can be zeros in this tensor, but all the non-zeros - * have positive values). - * - * Complexity: O(number of non-zeros) - */ - inline bool isNonNegative() const - { - if (nz_.empty()) - return true; + /** + * Copies the non-zeros from this tensor to the given output iterator. 
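// Hypothetical usage sketch (not from the original header): a dense round
// trip with toDense()/fromDense(). product() is the same helper the
// implementation above uses to size the buffer; <vector> is assumed included.
template <typename Index, typename Float>
void denseRoundTrip(SparseTensor<Index, Float> &A) {
  std::vector<Float> buffer(product(A.getBounds())); // one slot per element
  A.toDense(buffer.data());   // zero-fills the buffer, then writes non-zeros
  buffer[0] += (Float)1;      // mutate the dense copy
  A.fromDense(buffer.data()); // clears A, then set()s every element back
}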
+ * + * Complexity: O(number of non-zeros) + */ + template inline void toIdxVal(OutIter iv) const { + const_iterator it, e; + for (it = begin(), e = end(); it != e; ++it, ++iv) + *iv = std::make_pair(it->first, it->second); + } - const_iterator it, e; - for (it = begin(), e = end(); it != e; ++it) - if (strictlyNegative(it->second)) - return false; - return true; - } + /** + * Copies the values from the input iterator into this sparse tensor. + * Clear this tensor first, optionally. + */ + template + inline void fromIdxVal(const UInt &nz, InIter iv, bool clearYesNo = true) { + if (clearYesNo) + clear(); - /** - * Returns the set of values in this SparseTensor and how - * many times each of them appears. - * - * Complexity: O(number of non-zeros) with some log for - * the insertion in the result map... - */ - inline std::map values() const - { - std::map vals; + for (UInt i = 0; i < nz; ++i, ++iv) + set(iv->first, iv->second); + } - if (!isDense()) - vals[0] = getNZeros(); + /** + * Copies the values from the input iterator into this sparse tensor, + * assuming that only non-zeros are passed. + * Clear this tensor first, optionally. + */ + template + inline void fromIdxVal_nz(const UInt &nz, InIter iv, bool clearYesNo = true) { + if (clearYesNo) + clear(); - const_iterator it, e; - typename std::map::iterator found; - for (it = begin(), e = end(); it != e; ++it) { - found = vals.find(it->second); - if (found == vals.end()) - vals[it->second] = 1; - else - ++ vals[it->second]; - } + for (UInt i = 0; i < nz; ++i, ++iv) + setNonZero(iv->first, iv->second); + } - return vals; - } - - /** - * Makes this tensor the tensor zero, that is, all the non-zeros - * are removed. - */ - inline void clear() { nz_.clear(); } - - /** - * Creates a new Index that has the rank of this sparse tensor. - * The initial value of this Index is the bounds of this tensor. - * - * Note: - * To accomodate both Index and std::vector as - * indices, we can't allocate memory ourselves, so when we - * need an index, we create a copy of the bounds, and either - * do nothing, or set it to zero, or set to some specified - * set of values. - * - * @retval Index a new Index, that contains the values of the bounds - * for this sparse tensor - */ - inline Index getNewIndex() const - { - return getBounds(); - } + /** + * Updates some of the values in this sparse tensor, the indices and + * values to use for the update being passed in the input iterator. + * Uses binary functor f to carry out the update. + */ + template + inline void updateFromIdxVal(const UInt &nz, InIter iv, binary_functor f) { + for (UInt i = 0; i < nz; ++i, ++iv) + update(iv->first, iv->second, f); + } - /** - * Creates a new Index that has the rank of this sparse tensor - * and sets it to zero (see note in getNewIndex()). - * - * @retval Index a new Index, initialized to zero - */ - inline Index getNewZeroIndex() const - { - Index idx = getBounds(); - setToZero(idx); - return idx; - } - - /** - * Creates a new Index that has the rank of this sparse tensor - * and sets it to the specified values (see note in getNewIndex()). - * - * @retval Index a new Index, initialized to the values passed - */ - inline Index getNewIndex(UInt i0, ...) const - { - Index idx = getBounds(); - idx[0] = i0; - va_list indices; - va_start(indices, i0); - for (UInt k = 1; k < getRank(); ++k) - idx[k] = (UInt) va_arg(indices, unsigned int); - va_end(indices); - return idx; - } - - /** - * Computes whether this tensor is symmetric or not. - * A tensor is symmetric w.r.t. 
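// Hypothetical usage sketch (not from the original header): exporting and
// re-importing (index, value) pairs with toIdxVal()/fromIdxVal(). Any output
// or input iterator over std::pair<Index, Float> works; a std::vector is used
// here, and both tensors are assumed to have the same bounds.
template <typename Index, typename Float>
void copyViaIdxVal(const SparseTensor<Index, Float> &src,
                   SparseTensor<Index, Float> &dst) {
  std::vector<std::pair<Index, Float> > iv(src.getNNonZeros());
  src.toIdxVal(iv.begin());                    // one pair per non-zero
  dst.fromIdxVal((UInt)iv.size(), iv.begin()); // clears dst, then set()s each
}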
a permutation of the dimensions iff: - * A[ijkl...] = A[permutation(ijkl...)]. - * This implies that the bounds of the permuted dimensions need to - * be the same. If they are not, the tensor is not symmetric. - * The Index passed in needs to have the same size as the rank - * of this sparse tensor. - * - * Complexity: O(number of non-zeros) - * - * @param perm [Index] the permutation to use to evaluate whether - * this sparse tensor is symmetric or not - * @retval bool whether this sparse tensor is symmetric w.r.t. the - * given permutation - */ - inline bool isSymmetric(const Index& perm) const - { - { - NTA_ASSERT(perm.size() == getRank()); - NTA_ASSERT(isSet(perm)); - } - - Index idx2 = getNewZeroIndex(); - - nupic::permute(perm, bounds_, idx2); - if (bounds_ != idx2) - return false; + /** + * Outputs the non-zeros of this sparse tensor to a stream. + * Only non-zeros are put to the stream. + */ + inline void toStream(std::ostream &outStream) const { + { NTA_ASSERT(outStream.good()); } - const_iterator it, e; + outStream << getRank() << " "; - for (it = begin(), e = end(); it != e; ++it) { - nupic::permute(perm, it->first, idx2); - if (!nearlyZero_(it->second - get(idx2))) - return false; - } + for (UInt i = 0; i < getRank(); ++i) + outStream << getBounds()[i] << " "; - return true; + outStream << getNNonZeros() << " "; + + const_iterator it, e; + for (it = begin(), e = end(); it != e; ++it) { + for (UInt i = 0; i < getRank(); ++i) + outStream << (it->first)[i] << " "; + outStream << it->second << " "; } + } - /** - * Computes whether this tensor is anti-symmetric or not. - * A tensor is anti-symmetric w.r.t. to a permutation of the - * dimensions iff: - * A[ijkl...] = -A[permutation(ijkl...)] - * This implies that the upper bounds of the permuted dimensions - * need to be the same, or the tensor is not anti-symmetric. - * The Index passed in needs to have the same size as the rank - * of this sparse tensor. - * - * Complexity: O(number of non-zeros) - * - * @param perm [Index] the permutation to use to evaluate anti-symmetry - * @retval bool whether this sparse tensor is anty-symmetric w.r.t. - * the given permutation or not - */ - inline bool isAntiSymmetric(const Index& perm) const - { - { - NTA_ASSERT(perm.size() == getRank()); - NTA_ASSERT(isSet(perm)); - } + /** + * Reads values for this sparse tensor from a stream. + * Works even if the stream contains zeros (calls set). + */ + inline void fromStream(std::istream &inStream) { + { NTA_ASSERT(inStream.good()); } - Index idx2 = getNewZeroIndex(); - - nupic::permute(perm, bounds_, idx2); - if (bounds_ != idx2) - return false; + clear(); - const_iterator it, e; + UInt rank, nnz; + Index idx = getNewIndex(); + Float val; - for (it = begin(), e = end(); it != e; ++it) { - nupic::permute(perm, it->first, idx2); - if (!nearlyZero_(it->second + get(idx2))) - return false; - } + inStream >> rank; + NTA_ASSERT(rank > 0); + NTA_ASSERT(rank == bounds_.size()); - return true; - } - - /** - * Sets the element at idx to val. Handles zeros by not storing - * them, or by erasing non-zeros that become zeros when val = 0. - * The Index idx needs to be >= 0 and < getBounds(). 
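// Hypothetical usage sketch (not from the original header): a serialization
// round trip through the text format written by toStream() above (rank,
// bounds, number of non-zeros, then index/value pairs). A std::stringstream
// stands in for a file; dst must already have the same rank as src.
template <typename Index, typename Float>
void saveAndReload(const SparseTensor<Index, Float> &src,
                   SparseTensor<Index, Float> &dst) {
  std::stringstream ss;
  src.toStream(ss);   // writes rank, bounds, nnz, then the non-zeros
  dst.fromStream(ss); // clears dst and reads the same format back
}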
- * - * Complexity: O(log(number of non-zeros)) - * - * @param idx [Index] the index of the element to set - * @param val [Float] the value to set for the element at index - */ - inline void set(const Index& idx, const Float& val) - { - { - NTA_ASSERT(positiveInBounds(idx, getBounds())) - << "SparseTensor::set(idx, val): " - << "Invalid index: " << idx - << " - Should be >= 0 and strictly less than: " << bounds_; - } - - if (nearlyZero_(val)) { - iterator it = nz_.find(idx); - if (it != end()) - nz_.erase(it); - } else - nz_[idx] = val; - } - - /** - * Sets the element at idx to val. Calls set(Index, Float). - */ - inline void set(UInt i0, ...) - { - Index idx = getNewIndex(); - idx[0] = i0; - va_list indices; - va_start(indices, i0); - for (UInt k = 1; k < getRank(); ++k) - idx[k] = (UInt) va_arg(indices, unsigned int); - const Float val = (Float) va_arg(indices, double); - va_end(indices); - set(idx, val); + for (UInt i = 0; i < rank; ++i) { + inStream >> bounds_[i]; + NTA_ASSERT(bounds_[i] > 0); } - /** - * Sets all the elements inside the dom to val. - * Handles zeros correctly (i.e. does not store them). - * - * @param dom [Domain] the domain inside which to set values - * @param val [Float] the value to set inside dom - */ - inline void set(const Domain& dom, const Float& val) - { - if (nearlyZero_(val)) { - setZero(dom); - } else { - setNonZero(dom, val); - } - } + inStream >> nnz; + NTA_ASSERT(nnz >= 0); - /** - * Sets the element at idx to zero, that is, removes it - * from the internal storage. - * - * Complexity: O(log(number of non-zeros)) - * - * @param idx [Index] the index of the element to set to zero - */ - inline void setZero(const Index& idx) - { - { - NTA_ASSERT(positiveInBounds(idx, getBounds())) - << "SparseTensor::setZero(idx): " - << "Invalid index: " << idx - << " - Should be >= 0 and strictly less than: " << bounds_; + for (UInt i = 0; i < nnz; ++i) { + for (UInt j = 0; j < rank; ++j) { + inStream >> idx[j]; + NTA_ASSERT(idx[j] >= 0 && idx[j] < bounds_[j]); } - - iterator it = nz_.find(idx); - if (it != end()) - nz_.erase(it); + inStream >> val; + set(idx, val); } + } - /** - * Sets the element at idx to zero. Calls setZero(Index). - */ - inline void setZero(UInt i0, ...) - { - Index idx = getNewIndex(); - idx[0] = i0; - va_list indices; - va_start(indices, i0); - for (UInt k = 1; k < getRank(); ++k) - idx[k] = (UInt) va_arg(indices, unsigned int); - va_end(indices); - setZero(idx); - } - - /** - * Sets to zero all the elements in Domain dom. - * - * @param dom [Domain] the domain to set to zero - */ - inline void setZero(const Domain& dom) - { - { - NTA_ASSERT(getDomain().includes(dom)) - << "SparseTensor::setZero(Domain): " - << "Domain argument: " << dom - << " is invalid" - << " - Should be included in: " << getDomain(); - } + /** + * Returns an iterator to the beginning of the non-zeros + * in this tensor. Iterator iterate only over the non-zeros. + */ + inline iterator begin() { return nz_.begin(); } - iterator it = begin(), d, e = end(); - while (it != e) { - if (dom.includes(it->first)) { - d = it; - ++it; // increment before erase - nz_.erase(d); - } else { - ++it; - } - } - } + /** + * Returns an iterator to one past the end of the non-zeros + * in this tensor. Iterator iterate only over the non-zeros. + */ + inline iterator end() { return nz_.end(); } - /** - * Sets element at idx to val, where |val| > nupic::Epsilon. - * - * Use if you know what you do: even f(non-zero, non-zero) - * can be "zero", if it falls below nupic::Epsilon. 
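// Hypothetical usage sketch (not from the original header): element-level
// writes with set()/update()/setZero(). Values with |v| <= nupic::Epsilon are
// treated as zero and never stored; <functional> is assumed included.
template <typename Index, typename Float>
void touchOneElement(SparseTensor<Index, Float> &A, const Index &idx) {
  A.set(idx, (Float)3.5);                      // stores a non-zero
  A.update(idx, (Float)1, std::plus<Float>()); // A[idx] = A[idx] + 1
  A.setZero(idx);                              // erases it from the map again
}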
- * - * Complexity: O(log(number of non-zeros)) - * - * @param idx [Index] the index of the element to set to val - * @param val [Float] the value to set for the element at idx - */ - inline void setNonZero(const Index& idx, const Float& val) - { - { - NTA_ASSERT(positiveInBounds(idx, getBounds())) - << "SparseTensor::setNonZero(idx, val): " - << "Invalid index: " << idx - << " - Should be >= 0 and strictly less than: " << bounds_; + /** + * Returns a const iterator to the beginning of the non-zeros + * in this tensor. Iterator iterate only over the non-zeros. + */ + inline const_iterator begin() const { return nz_.begin(); } - NTA_ASSERT(!nearlyZero_(val)) - << "SparseTensor::setNonZero(idx, val): " - << "Invalid zero value: " << val - << " at index: " << idx - << " - Should be non-zero (> " << nupic::Epsilon << ")"; - } - - nz_[idx] = val; - } + /** + * Returns a const iterator to one past the end of the non-zeros + * in this tensor. Iterator iterate only over the non-zeros. + */ + inline const_iterator end() const { return nz_.end(); } - /** - * Sets all the values inside dom to val. - * Works only if |val| > nupic::Epsilon. - * - * @param dom [Domain] the domain inside which to set values - * @param val [Float] the value to set inside dom - */ - inline void setNonZero(const Domain& dom, const Float& val) - { - { - NTA_ASSERT(!nearlyZero_(val)); - } + inline std::pair equal_range(const Index &idx) { + return nz_.equal_range(idx); + } - Index lb = getNewIndex(), ub = getNewIndex(), idx = getNewIndex(); - dom.getLB(lb); dom.getUB(ub); - - idx = lb; - do { - setNonZero(idx, val); - } while (increment(lb, ub, idx)); - } - - /** - * Updates the value of this tensor at idx in place, using f and val: - * A[idx] = f(A[idx], val) (val as second argument). - * - * Handles zeros properly. - * - * Complexity: O(log(number of non-zeros)) - */ - template - inline Float update(const Index& idx, const Float& val, binary_functor f) - { - { - NTA_ASSERT(positiveInBounds(idx, getBounds())) - << "SparseTensor::update(idx, val, f(x, y)): " - << "Invalid index: " << idx - << " - Should be >= 0 and strictly less than: " << bounds_; - } - - Float res = 0; + inline std::pair + equal_range(const Index &idx) const { + return nz_.equal_range(idx); + } - iterator it = nz_.find(idx); - if (it != end()) { - res = f(it->second, val); - if (nearlyZero_(res)) - nz_.erase(it); - else - it->second = res; - } - else { - res = f(0, val); - if (!nearlyZero_(res)) - nz_[idx] = res; - } - - return res; - } + /** + * Permute the dimensions of each element of this tensor. + * + * Complexity: O(number of non-zeros) + */ + inline void permute(const Index &ind) { + { NTA_ASSERT(isSet(ind)); } - /** - * TODO: unit test - */ - inline void add(const Index& idx, const Float& val) - { - std::pair r = nz_.insert(std::make_pair(idx, val)); + Index idx = getNewIndex(), newBounds = getNewIndex(); + nupic::permute(ind, bounds_, newBounds); + + NZ newMap; - if (!r.second) - r.first->second += val; + const_iterator it, e; + for (it = begin(), e = end(); it != e; ++it) { + nupic::permute(ind, it->first, idx); + newMap[idx] = it->second; } - /** - * Sets all the values in this tensor to val. - * Makes this sparse tensor dense if |val| > nupic::Epsilon. 
- * Otherwise, removes all the values in this sparse tensor - * - * Complexity: O(product of bounds) (worst case, if |val| > nupic::Epsilon) - */ - inline void setAll(const Float& val) - { - if (nearlyZero_(val)) { - nz_.clear(); - return; - } + nz_ = newMap; + bounds_ = newBounds; + } - Index idx = getNewZeroIndex(); - do { - nz_[idx] = val; - } while (increment(bounds_, idx)); - } + /** + * Change the bounds of this tensor, while keeping the dimensionality. + */ + inline void resize(const Index &newBounds) { + if (newBounds == bounds_) + return; - /** - * Returns the value of the element at idx. - * - * Complexity: O(log(number of non-zeros)) - */ - inline Float get(const Index& idx) const - { - { - NTA_ASSERT(positiveInBounds(idx, getBounds())) - << "SparseTensor::get(idx): " - << "Invalid index: " << idx - << " - Should be >= 0 and strictly less than: " << bounds_; - } + bool shrink = false; + ITER_1(bounds_.size()) + if (newBounds[i] < bounds_[i]) + shrink = true; - const_iterator it = nz_.find(idx); - if (it != nz_.end()) - return it->second; - else - return Float(0); + if (shrink) { + const_iterator it, e; + for (it = begin(), e = end(); it != e; ++it) + if (!positiveInBounds(it->first, newBounds)) { + const_iterator d = it; + ++it; + nz_.erase(d->first); + } } - /** - * Returns the value of the element at idx. - * Calls get(Index). - */ - inline const Float get(UInt i0, ...) - { - Index idx = getNewIndex(); - idx[0] = i0; - va_list indices; - va_start(indices, i0); - for (UInt k = 1; k < getRank(); ++k) - idx[k] = (UInt) va_arg(indices, unsigned int); - va_end(indices); - return get(idx); - } - - /** - * Returns the element at idx. - * Calls get(Index). - * not providing the other one, because we need to control for zero, - * we can't just blindly pass a reference - */ - inline const Float operator()(UInt i0, ...) const - { - Index idx = getNewIndex(); - idx[0] = i0; - va_list indices; - va_start(indices, i0); - for (UInt k = 1; k < getRank(); ++k) - idx[k] = (UInt) va_arg(indices, unsigned int); - va_end(indices); - return get(idx); - } - - /** - * Extract sub-spaces along dimension dim, and put result - * in B. Only the non-zeros who dim-th coordinate is in ind - * are kept and stored in B. - * - * This operation reduces the size of this SparseTensor - * along dimension dim to the number of elements in ind. - * This operation reduces the size of this SparseTensor - * along dimension dim to the number of elements in ind. - * If ind is full, returns this SparseTensor unmodified. - * If ind is empty, throws an exception, because I can't - * reduce a SparseTensor to have a size along any dimension - * of zero. You need to detect that ind is empty before - * calling reduce. - * - * Returns the null tensor (not a tensor) if ind is empty, - * i.e. - * - * dim = 0 indicates that some rows will be removed... 
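// Hypothetical usage sketch (not from the original header): transposing a 2-D
// tensor with permute() and then growing one dimension with resize(). The
// permutation is built in a rank-sized scratch index from getNewIndex().
template <typename Index, typename Float>
void transposeAndGrow(SparseTensor<Index, Float> &A) {
  Index perm = A.getNewIndex(); // rank-2 scratch index
  perm[0] = 1;
  perm[1] = 0;
  A.permute(perm);                // bounds and non-zero indices are swapped
  A.resize(0, A.getBound(0) + 1); // grow dimension 0 by one; non-zeros kept
}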
- */ - inline void extract(UInt dim, const std::set& ind, SparseTensor& B) const - { - { - NTA_ASSERT(&B != this) - << "SparseTensor::extract(): " - << "Cannot extract to self"; + bounds_ = newBounds; + } - NTA_ASSERT(0 <= dim && dim < getRank()) - << "SparseTensor::extract(): " - << "Invalid dimension: " << dim - << " - Should be between 0 and rank = " << getRank(); + /** + */ + inline void resize(const UInt &dim, const UInt &newSize) { + { NTA_ASSERT(0 <= dim && dim < getRank()); } - typename std::set::const_iterator i, e; - for (i = ind.begin(), e = ind.end(); i != e; ++i) - NTA_ASSERT(0 <= *i && *i < getBound(dim)) - << "SparseTensor::extract(): " - << "Invalid set member: " << *i - << " - Should be between 0 and bound (" << getBound(dim) - << ") for dim: " << dim; - } - - if (ind.empty()) { - B.bounds_[dim] = 0; - return; - } + Index newBounds(getNewIndex()); + newBounds[dim] = newSize; + resize(newBounds); + } - if (ind.size() == getBound(dim)) { - B = *this; - return; - } + /** + * Produces a tensor B that has the same non-zeros as this one + * but the given dimensions (the dimensions of B). Tensor B + * needs to provide as much storage as this sparse tensor. + * + * Complexity: O(number of non-zeros) + * + * @parameter B [SparseTensor] the target + * sparse tensor + */ + template + inline void reshape(SparseTensor &B) const { + { + NTA_ASSERT(indexGtZero(B.getBounds())); + NTA_ASSERT(!isNull()); + NTA_ASSERT(product(B.getBounds()) == product(getBounds())); + NTA_ASSERT((void *)&B != (void *)this); + } - B.clear(); - Index bounds = getNewIndex(); - bounds[dim] = (UInt)ind.size(); - B.resize(bounds); - - std::vector ind_v(getBound(dim)); - { - UInt j = 0; - typename std::set::const_iterator i, e; - for (j = 0, i = ind.begin(), e = ind.end(); i != e; ++i, ++j) - ind_v[*i] = j; - } + B.clear(); - const_iterator i, e; - for (i = begin(), e = end(); i != e; ++i) - if (ind.find(i->first[dim]) != ind.end()) { - Index idx = i->first; - idx[dim] = ind_v[idx[dim]]; - B.setNonZero(idx, i->second); - } + IndexB newBounds = B.getBounds(), idx2 = B.getNewIndex(); + + const_iterator it, e; + for (it = begin(), e = end(); it != e; ++it) { + setFromOrdinal(newBounds, ordinal(getBounds(), it->first), idx2); + B.setNonZero(idx2, it->second); } + } - /** - * In place (mutating) reduce. - * - * Keeps only the sub-spaces (rows or columns for a matrix) - * whose coordinate is a member of ind. Reduces the size - * of the tensor along dimension dim to the size of ind. - * Yields the null tensor if ind is empty. - * Does not change the tensor if ind is full. - * - * Examples: - * S2.reduce(0, set(1, 3)) keeps rows 1 and 3 of the matrix, - * eliminates the other rows. - * S2.reduce(1, set(1, 3)) keeps columns 1 and 3 of the matrix, - * eliminates the other columns. - */ - inline void reduce(UInt dim, const std::set& ind) - { - { - NTA_ASSERT(0 <= dim && dim < getRank()) - << "SparseTensor::reduce(): " - << "Invalid dimension: " << dim - << " - Should be between 0 and rank = " << getRank(); + /** + * A small class to carry information about two non-zeros + * in an intersection or union or sparse tensors of arbitraty + * (possibly different) ranks. If the ranks are different, + * IndexA and IndexB will have different sizes. 
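// Hypothetical usage sketch (not from the original header): reshape() keeps
// the non-zeros but reinterprets them against new bounds, provided the total
// number of elements matches (see the asserts above), e.g. 2x12 -> 4x6.
template <typename IndexA, typename IndexB, typename Float>
void reshape2x12To4x6(const SparseTensor<IndexA, Float> &A,
                      SparseTensor<IndexB, Float> &B) {
  // product(A.getBounds()) == product(B.getBounds()) == 24 must hold.
  A.reshape(B); // clears B, then places each non-zero by its ordinal
}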
+ */ + template class Elt { + public: + inline Elt(const IndexA &ia, const Float a, const IndexB &ib, const Float b) + : index_a_(ia), index_b_(ib), a_(a), b_(b) {} - typename std::set::const_iterator i, e; - for (i = ind.begin(), e = ind.end(); i != e; ++i) - NTA_ASSERT(0 <= *i && *i < getBound(dim)) - << "SparseTensor::reduce(): " - << "Invalid set member: " << *i - << " - Should be between 0 and bound (" << getBound(dim) - << ") for dim: " << dim; - } + inline Elt(const Elt &o) + : index_a_(o.index_a_), index_b_(o.index_b_), a_(o.a_), b_(o.b_) {} - if (ind.empty()) { - bounds_[dim] = 0; - return; - } - - if (ind.size() == getBound(dim)) - return; - - std::vector ind_v(getBound(dim)); - { - UInt j = 0; - typename std::set::const_iterator i, e; - for (j = 0, i = ind.begin(), e = ind.end(); i != e; ++i, ++j) - ind_v[*i] = j; - } - - NZ keep; - - iterator i, e; - for (i = begin(), e = end(); i != e; ++i) - if (ind.find(i->first[dim]) != ind.end()) { - Index idx = i->first; - idx[dim] = ind_v[idx[dim]]; - keep[idx] = i->second; - } + inline Elt &operator=(const Elt &o) { + index_a_ = o.index_a_; + index_b_ = o.index_b_; + a_ = o.a_; + b_ = o.b_; + return *this; + } - nz_ = keep; - - Index bounds = getNewIndex(); - bounds[dim] = (UInt)ind.size(); - resize(bounds); - } - - /** - * Extract slices or sub-arrays from this tensor, of - * any dimensions, and within any bounds. Slices can be - * of any dimension <= getRank(). Slices are not allocated - * by this function, so an optional clearYesNo parameter - * is provided to remove all the existing non-zeros from - * the slice (B). When the range is not as big as the - * cartesian product of the bounds of this tensor, - * a sub-array tensor is extracted. - * - * Examples: - * If A has dimensions [(0, 10), (0, 20), (0, 30)] - * - slice with Domain = ((0, 0, 0), (10, 20, 30)) gives A - * - slice with Domain = ((0, 0, 10), (10, 20, 30)) gives - * the subarray of A reduced to indices 10 to 29 along - * the third dimension - * - slice with Domain = ((0, 0, 10), (10, 20, 10)) gives - * the matrix (0,10) by (0, 20) obtained when the third index - * is blocked to 10. - * - * Complexity: O(number of non-zeros in slice) - * - * @param range [Domain] the range to extract from this - * tensor. 
- * @param B [SparseTensor] the resulting - * slice - * @param clearYesNo [bool] whether to clear B before - * slicing or not - */ - template - inline void getSlice(const Domain& range, - SparseTensor& B) const - { - { - NTA_ASSERT(range.rank() == getRank()); - - NTA_ASSERT(B.getRank() == range.getNOpenDims()) - << "SparseTensor::slice(): " - << "Invalid range: " << range - << " - Range should have a number of open dims" - << " equal to the rank of the slice (" - << B.getRank() << ")"; - } + inline const IndexA getIndexA() const { return index_a_; } + inline const IndexB getIndexB() const { return index_b_; } + inline const Float getValA() const { return a_; } + inline const Float getValB() const { return b_; } - // Always clear, so we extract a zero slice - // if we don't hit any non-zero - B.clear(); - - IndexB sliceIndex = B.getNewIndex(), openDims = B.getNewIndex(); - range.getOpenDims(openDims); - const_iterator it = begin(); - const_iterator e = end(); - - for (; it != e; ++it) - if (range.includes(it->first)) { - project(openDims, it->first, sliceIndex); - for (UInt k = 0; k < B.getRank(); ++k) - sliceIndex[k] -= range[openDims[k]].getLB(); - B.set(sliceIndex, it->second); - } + inline std::ostream &print(std::ostream &outStream) const { + return outStream << getIndexA() << " " << getValA() << " " << getIndexB() + << " " << getValB(); } - template - inline void setSlice(const Domain& range, - const SparseTensor& B) - { - { - NTA_ASSERT(range.rank() == getRank()); - - NTA_ASSERT(B.getRank() == range.getNOpenDims()) - << "SparseTensor::setSlice(): " - << "Invalid range: " << range - << " - Range should have a number of open dims" - << " equal to the rank of the slice (" - << B.getRank() << ")"; - } + private: + IndexA index_a_; + IndexB index_b_; + Float a_, b_; + }; - // If the slice is empty, call setZero on range - // (processing below is based on non-zeros exclusively) - if (B.isZero()) { - setZero(range); - return; - } + /** + * A data structure to hold the non-zero intersection of two tensors + * of different dimensionalities. + */ + template + struct NonZeros : public std::vector> {}; - Index idx = getNewIndex(); - IndexB openDims = B.getNewIndex(); - for (UInt i = 0; i < range.rank(); ++i) - if (range[i].empty()) - idx[range[i].getDim()] = range[i].getLB(); - range.getOpenDims(openDims); - typename SparseTensor::const_iterator it, e; - it = B.begin(); - e = B.end(); - - for (; it != e; ++it) { - embed(openDims, it->first, idx); - for (UInt k = 0; k < B.getRank(); ++k) - idx[k] += range[openDims[k]].getLB(); - set(idx, it->second); + /** + * Computes the set of indices where this tensor and B have + * common non-zeros, when A and B have the same rank. + * + * Complexity: O(smaller number of non-zeros between this and B) + */ + inline void nz_intersection(const SparseTensor &B, + std::vector &inter) const { + inter.clear(); + + const_iterator it1 = begin(), end1 = end(); + const_iterator it2 = B.begin(), end2 = B.end(); + + while (it1 != end1 && it2 != end2) { + if (it1->first == it2->first) { + inter.push_back(it1->first); + ++it1; + ++it2; + } else if (it2->first < it1->first) { + ++it2; + } else { + ++it1; } } + } - template - inline void toList(OutputIterator1 indices, OutputIterator2 values) const + /** + * Computes the set of indices where the projection of this tensor + * on dims and B have common non-zeros. A and B have different Ranks. 
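// Hypothetical usage sketch (not from the original header): extracting a
// sub-array with getSlice(), following the Domain examples above. The
// Domain<UInt>(lb, ub) constructor form is an assumption taken from those
// examples; B must be rank-3 with bounds at least as large as the slice.
template <typename Index3D, typename Index3DSlice, typename Float>
void copyFrontSubArray(const SparseTensor<Index3D, Float> &A,
                       SparseTensor<Index3DSlice, Float> &B) {
  Index3D lb = A.getNewZeroIndex(); // (0, 0, 0)
  Index3D ub = A.getNewIndex();     // the bounds of A
  ub[2] = 10;                       // keep indices 0..9 along dimension 2
  Domain<UInt> range(lb, ub);       // three open dims, so B has rank 3
  A.getSlice(range, B);             // copies the non-zeros that fall in range
}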
+ * + * Complexity: O(number of non-zeros) + */ + template + inline void nz_intersection(const IndexB &dims, + const SparseTensor &B, + NonZeros &inter) const { { - for (const_iterator it = begin(); it != end(); ++it) { - *indices = it->first; - *values = it->second; - ++indices; ++values; - } + NTA_ASSERT(B.getRank() <= getRank()) + << "SparseTensor::nz_intersection(): " + << "Invalid tensor ranks: " << getRank() << " " << B.getRank() + << " - Tensor B's rank needs to be <= this rank: " << getRank(); } - /** - * Returns whether the element at idx is zero or not. - * - * Complexity: O(log(number of non-zeros)) - */ - inline bool isZero(const Index& idx) const - { - { - NTA_ASSERT(positiveInBounds(idx, getBounds())) - << "SparseTensor::isZero(idx): " - << "Invalid index: " << idx - << " - Should be >= 0 and strictly less than: " << bounds_; - } + inter.clear(); - return nz_.find(idx) == nz_.end(); + const_iterator it1, end1; + IndexB idxB = B.getNewIndex(); + + for (it1 = begin(), end1 = end(); it1 != end1; ++it1) { + project(dims, it1->first, idxB); + Float b = B.get(idxB); + if (!nearlyZero_(b)) + inter.push_back(Elt(it1->first, it1->second, idxB, b)); } + } - /** - * Returns whether the element at idx is zero or not. - * Calls isZero(Index). - */ - inline bool isZero(UInt i0, ...) const - { - Index idx = getNewIndex(); - idx[0] = i0; - va_list indices; - va_start(indices, i0); - for (UInt k = 1; k < getRank(); ++k) - idx[k] = (UInt) va_arg(indices, unsigned int); - va_end(indices); - return isZero(idx); - } - - /** - * Copies this sparse tensor to the given array of Floats. - * Sets the array to zero first, then copies only the non-zeros. - * - * Complexity: O(number of non-zeros) - */ - template - inline void toDense(OutIter array) const - { - { - NTA_ASSERT(!isNull()); - } - - memset(array, 0, product(getBounds()) * sizeof(Float)); - const_iterator it, e; - for (it = begin(), e = end(); it != e; ++it) - *(array + ordinal(getBounds(), it->first)) = it->second; - } - - /** - * Copies the non-zeros from array into this sparse tensor. - * Clears this tensor first if the flag is true. - * - * Complexity: O(size * log(size)) ?? - */ - template - inline void fromDense(InIter array, bool clearYesNo =true) - { - { - NTA_ASSERT(!isNull()); - } + /** + * Computes the set of indices where this tensor or B have + * a non-zero. + * + * Complexity: O(sum of number of non-zeros in this and B) + */ + inline void nz_union(const SparseTensor &B, std::vector &u) const { + u.clear(); - if (clearYesNo) - clear(); + const_iterator it1 = begin(), end1 = end(); + const_iterator it2 = B.begin(), end2 = B.end(); - Index idx = getNewIndex(); - const UInt M = product(getBounds()); - for (UInt i = 0; i < M; ++i) { - setFromOrdinal(bounds_, i, idx); - set(idx, *array++); + while (it1 != end1 && it2 != end2) { + if (it1->first == it2->first) { + u.push_back(it1->first); + ++it1; + ++it2; + } else if (it2->first < it1->first) { + u.push_back(it2->first); + ++it2; + } else { + u.push_back(it1->first); + ++it1; } } - /** - * Copies the non-zeros from this tensor to the given output iterator. - * - * Complexity: O(number of non-zeros) - */ - template - inline void toIdxVal(OutIter iv) const - { - const_iterator it, e; - for (it = begin(), e = end(); it != e; ++it, ++iv) - *iv = std::make_pair(it->first, it->second); - } + for (; it1 != end1; ++it1) + u.push_back(it1->first); - /** - * Copies the values from the input iterator into this sparse tensor. - * Clear this tensor first, optionally. 
- */ - template - inline void fromIdxVal(const UInt& nz, InIter iv, bool clearYesNo =true) - { - if (clearYesNo) - clear(); - - for (UInt i = 0; i < nz; ++i, ++iv) - set(iv->first, iv->second); - } - - /** - * Copies the values from the input iterator into this sparse tensor, - * assuming that only non-zeros are passed. - * Clear this tensor first, optionally. - */ - template - inline void fromIdxVal_nz(const UInt& nz, InIter iv, bool clearYesNo =true) - { - if (clearYesNo) - clear(); - - for (UInt i = 0; i < nz; ++i, ++iv) - setNonZero(iv->first, iv->second); - } - - /** - * Updates some of the values in this sparse tensor, the indices and - * values to use for the update being passed in the input iterator. - * Uses binary functor f to carry out the update. - */ - template - inline void updateFromIdxVal(const UInt& nz, InIter iv, binary_functor f) - { - for (UInt i = 0; i < nz; ++i, ++iv) - update(iv->first, iv->second, f); - } + for (; it2 != end2; ++it2) + u.push_back(it2->first); + } - /** - * Outputs the non-zeros of this sparse tensor to a stream. - * Only non-zeros are put to the stream. - */ - inline void toStream(std::ostream& outStream) const + /** + * Computes the set of indices where the projection of this tensor + * on dims and B have at least one non-zero. + * + * Complexity: O(product of bounds) + * + * Note: wish I could find a faster way to compute that union + */ + template + inline void nz_union(const IndexB &dims, const SparseTensor &B, + NonZeros &u) const { { - { - NTA_ASSERT(outStream.good()); - } + NTA_ASSERT(B.getRank() <= getRank()) + << "SparseTensor::nz_union(): " + << "Invalid tensor ranks: " << getRank() << " " << B.getRank() + << " - Tensor B's rank needs to be <= this rank: " << getRank(); + } - outStream << getRank() << " "; + u.clear(); - for (UInt i = 0; i < getRank(); ++i) - outStream << getBounds()[i] << " "; + Index idxa = getNewZeroIndex(); + IndexB idxb = B.getNewIndex(); - outStream << getNNonZeros() << " "; + do { + project(dims, idxa, idxb); + Float a = get(idxa), b = B.get(idxb); + if (!nearlyZero_(a) || !nearlyZero_(b)) + u.push_back(Elt(idxa, a, idxb, b)); + } while (increment(bounds_, idxa)); + } - const_iterator it, e; - for (it = begin(), e = end(); it != e; ++it) { - for (UInt i = 0; i < getRank(); ++i) - outStream << (it->first)[i] << " "; - outStream << it->second << " "; - } + /** + * Applies the unary functor f to each non-zero element + * in this sparse tensor, assuming that no new non-zero + * is introduced. This works for scaling, for example, if + * the scaling value is not zero. + * + * @WARNING: this is pretty dangerous, since it doesn't + * check that f introduces new zeros!!! + */ + template + inline void element_apply_nz(unary_function f) { + { + NTA_ASSERT(f(0) == 0); + NTA_ASSERT(f(1) != 0); + NTA_ASSERT(f(2) != 0); } - /** - * Reads values for this sparse tensor from a stream. - * Works even if the stream contains zeros (calls set). - */ - inline void fromStream(std::istream& inStream) - { - { - NTA_ASSERT(inStream.good()); - } - - clear(); + iterator i, e; + for (i = begin(), e = end(); i != e; ++i) + i->second = f(i->second); + } - UInt rank, nnz; - Index idx = getNewIndex(); - Float val; + /** + * Applies the unary functor f to each non-zero element + * in this tensor: + * + * A[i] = f(A[i]), with f(0) == 0. + * + * Assumes (and checks) that f(0) == 0. The non-zeros + * can change. New non-zeros can be introduced, but this + * function iterates on the non-zeros only. 
+ */ + template + inline void element_apply_fast(unary_function f) { + { + NTA_ASSERT(f(0) == 0) << "SparseTensor::element_apply(unary_functor): " + << "Binary functor should do: f(0) == 0"; + } - inStream >> rank; - NTA_ASSERT(rank > 0); - NTA_ASSERT(rank == bounds_.size()); + // Can introduce new zeros! if we know nothing about + // the functor - for (UInt i = 0 ; i < rank; ++i) { - inStream >> bounds_[i]; - NTA_ASSERT(bounds_[i] > 0); + iterator it, d, e; + it = begin(); + e = end(); + while (it != e) { + Float val = f(it->second); + if (nearlyZero_(val)) { // check zero _after_ applying functor + d = it; + ++it; // increment before erasing! + nz_.erase(d); + } else { + it->second = val; + ++it; } + } + } - inStream >> nnz; - NTA_ASSERT(nnz >= 0); + /** + * Applies the unary functor f to all elements in this tensor, + * as if it were a dense tensor. This is useful when f(0) != 0, but it is + * slow because it doesn't take advantage of the sparsity. + * + * A[i] = f(A[i]) + * + * Complexity: O(product of the bounds) + */ + template inline void element_apply(unary_functor f) { + Index idx = getNewZeroIndex(); + do { + // Writing it that way allows to use boost::lambda expressions + // for f, which is really, really convenient + Float val = get(idx); + val = f(val); + set(idx, val); // doesn't invalidate iterator + } while (increment(bounds_, idx)); + } - for (UInt i = 0; i < nnz; ++i) { - for (UInt j = 0; j < rank; ++j) { - inStream >> idx[j]; - NTA_ASSERT(idx[j] >= 0 && idx[j] < bounds_[j]); - } - inStream >> val; - set(idx, val); - } - } + /** + * Applies the binary functor f to couple of elements from this tensor + * and tensor B, at the same element, where no element of the couple is zero. + * The result is stored in tensor C. + * + * C[i] = f(A[i], B[i]), where A[i] != 0 AND B[i] != 0. + * + * This works for f = multiplication, where f(x, 0) == f(0, x) == 0, for all + * x. It doesn't work for addition. + * + * Complexity: O(smaller number of non-zeros between this and B) + */ + template + inline void element_apply_fast(const SparseTensor &B, SparseTensor &C, + binary_functor f, + bool clearYesNo = true) const { + { + NTA_ASSERT(getBounds() == B.getBounds()) + << "SparseTensor::element_apply_fast(): " + << "A and B have different bounds: " << getBounds() << " and " + << B.getBounds() << " - Bounds need to be the same"; - /** - * Returns an iterator to the beginning of the non-zeros - * in this tensor. Iterator iterate only over the non-zeros. - */ - inline iterator begin() { return nz_.begin(); } + NTA_ASSERT(getBounds() == C.getBounds()) + << "SparseTensor::element_apply_fast(): " + << "A and C have different bounds: " << getBounds() << " and " + << C.getBounds() << " - Bounds need to be the same"; - /** - * Returns an iterator to one past the end of the non-zeros - * in this tensor. Iterator iterate only over the non-zeros. - */ - inline iterator end() { return nz_.end(); } + NTA_ASSERT(f(0, 1) == 0 && f(1, 0) == 0 && f(0, 0) == 0) + << "SparseTensor::element_apply_fast(): " + << "Binary functor should do: f(x, 0) == f(0, x) == 0 for all x"; + } - /** - * Returns a const iterator to the beginning of the non-zeros - * in this tensor. Iterator iterate only over the non-zeros. - */ - inline const_iterator begin() const { return nz_.begin(); } + if (clearYesNo) + C.clear(); - /** - * Returns a const iterator to one past the end of the non-zeros - * in this tensor. Iterator iterate only over the non-zeros. 
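// Hypothetical usage sketch (not from the original header): the three unary
// element-wise variants, each with a functor that satisfies the preconditions
// asserted above. Plain lambdas are assumed to be acceptable functors.
template <typename Index, typename Float>
void unaryVariants(SparseTensor<Index, Float> &A) {
  A.element_apply_nz([](Float x) { return 2 * x; });       // f(0)==0, no new zeros
  A.element_apply_fast([](Float x) { return x - x * x; }); // f(0)==0, may zero out
  A.element_apply([](Float x) { return x + 1; });          // f(0)!=0, dense walk
}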
- */ - inline const_iterator end() const { return nz_.end(); } + const_iterator it1, end1, it2, end2; + it1 = begin(); + end1 = end(); + it2 = B.begin(); + end2 = B.end(); + + while (it1 != end1 && it2 != end2) + if (it1->first == it2->first) { + C.set(it1->first, f(it1->second, it2->second)); + ++it1; + ++it2; + } else if (it2->first < it1->first) { + ++it2; + } else { + ++it1; + } + } - inline std::pair equal_range(const Index& idx) + /** + * Applies the binary functor f to couple of elements from this tensor + * and tensor B, at the same index, assuming that f(0, 0) == 0. + * The result is stored in tensor C. + * + * C[i] = f(A[i], B[i]), where A[i] != 0 OR B[i] != 0 + * + * This works for f = multiplication, and f = addition. + * It does not work if f(0, 0) != 0. + * + * Complexity: O(sum of number of non-zeros between this and B) + */ + template + inline void element_apply_nz(const SparseTensor &B, SparseTensor &C, + binary_functor f, bool clearYesNo = true) const { { - return nz_.equal_range(idx); - } + NTA_ASSERT(getBounds() == B.getBounds()) + << "SparseTensor::element_apply_nz(): " + << "A and B have different bounds: " << getBounds() << " and " + << B.getBounds() << " - Bounds need to be the same"; - inline std::pair - equal_range(const Index& idx) const - { - return nz_.equal_range(idx); + NTA_ASSERT(getBounds() == C.getBounds()) + << "SparseTensor::element_apply_nz(): " + << "A and C have different bounds: " << getBounds() << " and " + << C.getBounds() << " - Bounds need to be the same"; + + NTA_ASSERT(f(0, 0) == 0) << "SparseTensor::element_apply_nz(): " + << "Binary functor should do: f(0, 0) == 0"; } - /** - * Permute the dimensions of each element of this tensor. - * - * Complexity: O(number of non-zeros) - */ - inline void permute(const Index& ind) - { - { - NTA_ASSERT(isSet(ind)); - } + if (clearYesNo) + C.clear(); - Index idx = getNewIndex(), newBounds = getNewIndex(); - nupic::permute(ind, bounds_, newBounds); + const_iterator it1 = begin(), it2 = B.begin(), end1 = end(), end2 = B.end(); - NZ newMap; - - const_iterator it, e; - for (it = begin(), e = end(); it != e; ++it) { - nupic::permute(ind, it->first, idx); - newMap[idx] = it->second; + // Any of the set() can introduce new zeros! + while (it1 != end1 && it2 != end2) + if (it1->first == it2->first) { + C.set(it1->first, f(it1->second, it2->second)); + ++it1; + ++it2; + } else if (it2->first < it1->first) { + C.set(it2->first, f(0, it2->second)); + ++it2; + } else { + C.set(it1->first, f(it1->second, 0)); + ++it1; } - nz_ = newMap; - bounds_ = newBounds; - } + for (; it1 != end1; ++it1) + C.set(it1->first, f(it1->second, 0)); + + for (; it2 != end2; ++it2) + C.set(it2->first, f(0, it2->second)); + } - /** - * Change the bounds of this tensor, while keeping the dimensionality. - */ - inline void resize(const Index& newBounds) + /** + * Applies the binary functor f to couple of elements from this tensor + * and tensor B, at the same index, without assuming anything on f. + * The result is stored in tensor C. + * + * C[i] = f(A[i], B[i]) + * + * This works in all cases, even if f(0, 0) != 0. It does not take + * advantage of the sparsity. 
+ * + * Complexity: O(product of the bounds) + */ + template + inline void element_apply(const SparseTensor &B, SparseTensor &C, + binary_functor f) const { { - if (newBounds == bounds_) - return; - - bool shrink = false; - ITER_1(bounds_.size()) - if (newBounds[i] < bounds_[i]) - shrink = true; - - if (shrink) { - const_iterator it, e; - for (it = begin(), e = end(); it != e; ++it) - if (!positiveInBounds(it->first, newBounds)) { - const_iterator d = it; - ++it; - nz_.erase(d->first); - } - } + NTA_ASSERT(getBounds() == B.getBounds()) + << "SparseTensor::element_apply(): " + << "A and B have different bounds: " << getBounds() << " and " + << B.getBounds() << " - Bounds need to be the same"; - bounds_ = newBounds; + NTA_ASSERT(getBounds() == C.getBounds()) + << "SparseTensor::element_apply(): " + << "A and C have different bounds: " << getBounds() << " and " + << C.getBounds() << " - Bounds need to be the same"; } - /** - */ - inline void resize(const UInt& dim, const UInt& newSize) - { - { - NTA_ASSERT(0 <= dim && dim < getRank()); - } + Index idx = getNewZeroIndex(); + do { + C.set(idx, f(get(idx), B.get(idx))); + } while (increment(bounds_, idx)); + } - Index newBounds(getNewIndex()); - newBounds[dim] = newSize; - resize(newBounds); - } - - /** - * Produces a tensor B that has the same non-zeros as this one - * but the given dimensions (the dimensions of B). Tensor B - * needs to provide as much storage as this sparse tensor. - * - * Complexity: O(number of non-zeros) - * - * @parameter B [SparseTensor] the target - * sparse tensor - */ - template - inline void reshape(SparseTensor& B) const + /** + * In place factor apply (mutating) + * + * A[i] = f(A[i], B[j]), + * where j = projection of i on dims, + * and A[i] != 0 AND B[j] != 0. + * + * This works for multiplication, but not for addition, and not if f(0, 0) == + * 0. + * + * Complexity: O(smaller number of non-zeros between this and B) + */ + template + inline void factor_apply_fast(const IndexB &dims, + const SparseTensor &B, + binary_functor f) { { - { - NTA_ASSERT(indexGtZero(B.getBounds())); - NTA_ASSERT(!isNull()); - NTA_ASSERT(product(B.getBounds()) == product(getBounds())); - NTA_ASSERT((void*)&B != (void*)this); - } + NTA_ASSERT(getRank() > 1) + << "SparseTensor::factor_apply_fast() (in place): " + << "A rank is: " << getRank() << " - Should be > 1"; - B.clear(); + NTA_ASSERT(B.getRank() >= 1) + << "SparseTensor::factor_apply_fast() (in place): " + << "B rank is: " << B.getRank() << " - Should be >= 1"; - IndexB newBounds = B.getBounds(), idx2 = B.getNewIndex(); + NTA_ASSERT(B.getRank() <= getRank()) + << "SparseTensor::factor_apply_fast() (in place): " + << "A rank is: " << getRank() << " B rank is: " << B.getRank() + << " - B rank should <= A rank"; + } - const_iterator it, e; - for (it = begin(), e = end(); it != e; ++it) { - setFromOrdinal(newBounds, ordinal(getBounds(), it->first), idx2); - B.setNonZero(idx2, it->second); - } - } - - /** - * A small class to carry information about two non-zeros - * in an intersection or union or sparse tensors of arbitraty - * (possibly different) ranks. If the ranks are different, - * IndexA and IndexB will have different sizes. 
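// Hypothetical usage sketch (not from the original header): combining two
// tensors of identical bounds element-wise, choosing the variant by what the
// functor does at zero (see the doc comments above). <functional> assumed.
template <typename Index, typename Float>
void combineElementwise(const SparseTensor<Index, Float> &A,
                        const SparseTensor<Index, Float> &B,
                        SparseTensor<Index, Float> &C) {
  A.element_apply_fast(B, C, std::multiplies<Float>()); // f(x,0)==f(0,x)==0
  A.element_apply_nz(B, C, std::plus<Float>());         // only needs f(0,0)==0
  // element_apply(B, C, f) also handles f(0,0) != 0, at O(product of bounds).
}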
- */ - template - class Elt - { - public: - inline Elt(const IndexA& ia, const Float a, const IndexB& ib, const Float b) - : index_a_(ia), index_b_(ib), a_(a), b_(b) - {} - - inline Elt(const Elt& o) - : index_a_(o.index_a_), index_b_(o.index_b_), a_(o.a_), b_(o.b_) - {} - - inline Elt& operator=(const Elt& o) - { - index_a_ = o.index_a_; index_b_ = o.index_b_; - a_ = o.a_; b_ = o.b_; - return *this; - } + NonZeros inter; + nz_intersection(dims, B, inter); - inline const IndexA getIndexA() const { return index_a_; } - inline const IndexB getIndexB() const { return index_b_; } - inline const Float getValA() const { return a_; } - inline const Float getValB() const { return b_; } + // Need to clear everything so that zeros are handled properly, + // but we have all the values of this tensor in the intersection... + clear(); - inline std::ostream& print(std::ostream& outStream) const - { - return outStream << getIndexA() << " " << getValA() << " " - << getIndexB() << " " << getValB(); - } + // We have to call set instead of setNonZero, because + // even the multiplication of two non-zeros can result + // in a "zero" from the point of view of SparseMatrix + // that is, a value such that |value| <= nupic::Epsilon + typename NonZeros::const_iterator it, e; + for (it = inter.begin(), e = inter.end(); it != e; ++it) + set(it->getIndexA(), f(it->getValA(), it->getValB())); + } - private: - IndexA index_a_; - IndexB index_b_; - Float a_, b_; - }; - - /** - * A data structure to hold the non-zero intersection of two tensors - * of different dimensionalities. - */ - template - struct NonZeros : public std::vector > - {}; - - /** - * Computes the set of indices where this tensor and B have - * common non-zeros, when A and B have the same rank. - * - * Complexity: O(smaller number of non-zeros between this and B) - */ - inline void nz_intersection(const SparseTensor& B, std::vector& inter) const - { - inter.clear(); - - const_iterator it1 = begin(), end1 = end(); - const_iterator it2 = B.begin(), end2 = B.end(); - - while (it1 != end1 && it2 != end2) { - if (it1->first == it2->first) { - inter.push_back(it1->first); - ++it1; ++it2; - } else if (it2->first < it1->first) { - ++it2; - } else { - ++it1; - } - } - } - - /** - * Computes the set of indices where the projection of this tensor - * on dims and B have common non-zeros. A and B have different Ranks. - * - * Complexity: O(number of non-zeros) - */ - template - inline void nz_intersection(const IndexB& dims, - const SparseTensor& B, - NonZeros& inter) const + /** + * In place factor apply on non-zeros (mutating). + * Works for addition and multiplication. 
+ */ + template + inline void factor_apply_nz(const IndexB &dims, + const SparseTensor &B, + binary_functor f) { { - { - NTA_ASSERT(B.getRank() <= getRank()) - << "SparseTensor::nz_intersection(): " - << "Invalid tensor ranks: " << getRank() - << " " << B.getRank() - << " - Tensor B's rank needs to be <= this rank: " << getRank(); - } + NTA_ASSERT(getRank() > 1) + << "SparseTensor::factor_apply_nz(): " + << "A rank is: " << getRank() << " - Should be > 1"; - inter.clear(); + NTA_ASSERT(B.getRank() >= 1) + << "SparseTensor::factor_apply_nz(): " + << "B rank is: " << B.getRank() << " - Should be >= 1"; - const_iterator it1, end1; - IndexB idxB = B.getNewIndex(); - - for (it1 = begin(), end1 = end(); it1 != end1; ++it1) { - project(dims, it1->first, idxB); - Float b = B.get(idxB); - if (!nearlyZero_(b)) - inter.push_back(Elt(it1->first, it1->second, idxB, b)); - } - } + NTA_ASSERT(B.getRank() <= getRank()) + << "SparseTensor::factor_apply_nz(): " + << "A rank is: " << getRank() << " B rank is: " << B.getRank() + << " - B rank should <= A rank"; - /** - * Computes the set of indices where this tensor or B have - * a non-zero. - * - * Complexity: O(sum of number of non-zeros in this and B) - */ - inline void nz_union(const SparseTensor& B, std::vector& u) const - { - u.clear(); - - const_iterator it1 = begin(), end1 = end(); - const_iterator it2 = B.begin(), end2 = B.end(); - - while (it1 != end1 && it2 != end2) { - if (it1->first == it2->first) { - u.push_back(it1->first); - ++it1; ++it2; - } else if (it2->first < it1->first) { - u.push_back(it2->first); - ++it2; - } else { - u.push_back(it1->first); - ++it1; - } - } + NTA_ASSERT(f(0, 0) == 0) << "SparseTensor::factor_apply_nz(): " + << "Binary functor should do: f(0, 0) == 0"; + } - for (; it1 != end1; ++it1) - u.push_back(it1->first); - - for (; it2 != end2; ++it2) - u.push_back(it2->first); - } - - /** - * Computes the set of indices where the projection of this tensor - * on dims and B have at least one non-zero. - * - * Complexity: O(product of bounds) - * - * Note: wish I could find a faster way to compute that union - */ - template - inline void nz_union(const IndexB& dims, - const SparseTensor& B, - NonZeros& u) const - { - { - NTA_ASSERT(B.getRank() <= getRank()) - << "SparseTensor::nz_union(): " - << "Invalid tensor ranks: " << getRank() - << " " << B.getRank() - << " - Tensor B's rank needs to be <= this rank: " << getRank(); - } + // This is unfortunately quite slow, because projection + // is a surjection... + NonZeros u; + nz_union(dims, B, u); - u.clear(); + // Calling set because f(a, b) can fall below nupic::Epsilon + typename NonZeros::const_iterator it, e; + for (it = u.begin(), e = u.end(); it != e; ++it) + set(it->getIndexA(), f(it->getValA(), it->getValB())); + } - Index idxa = getNewZeroIndex(); - IndexB idxb = B.getNewIndex(); - - do { - project(dims, idxa, idxb); - Float a = get(idxa), b = B.get(idxb); - if (! nearlyZero_(a) || ! nearlyZero_(b)) - u.push_back(Elt(idxa, a, idxb, b)); - } while (increment(bounds_, idxa)); - } - - /** - * Applies the unary functor f to each non-zero element - * in this sparse tensor, assuming that no new non-zero - * is introduced. This works for scaling, for example, if - * the scaling value is not zero. - * - * @WARNING: this is pretty dangerous, since it doesn't - * check that f introduces new zeros!!! 
- */ - template - inline void element_apply_nz(unary_function f) + /** + * Binary factor apply (non-mutating) + * + * C[i] = f(A[i], B[j]), + * where j = projection of i on dims, + * and A[i] != 0 AND B[j] != 0. + * + * This works for multiplication, but not for addition, and not if f(0, 0) == + * 0. + * + * Complexity: O(smaller number of non-zeros between this and B) + */ + template + inline void factor_apply_fast(const IndexB &dims, + const SparseTensor &B, + SparseTensor &C, binary_functor f, + bool clearYesNo = true) const { { - { - NTA_ASSERT(f(0) == 0); - NTA_ASSERT(f(1) != 0); - NTA_ASSERT(f(2) != 0); - } + NTA_ASSERT(getRank() > 1) + << "SparseTensor::factor_apply_fast(): " + << "A rank is: " << getRank() << " - Should be > 1"; - iterator i, e; - for (i = begin(), e = end(); i != e; ++i) - i->second = f(i->second); - } - - /** - * Applies the unary functor f to each non-zero element - * in this tensor: - * - * A[i] = f(A[i]), with f(0) == 0. - * - * Assumes (and checks) that f(0) == 0. The non-zeros - * can change. New non-zeros can be introduced, but this - * function iterates on the non-zeros only. - */ - template - inline void element_apply_fast(unary_function f) - { - { - NTA_ASSERT(f(0) == 0) - << "SparseTensor::element_apply(unary_functor): " - << "Binary functor should do: f(0) == 0"; - } + NTA_ASSERT(B.getRank() >= 1) + << "SparseTensor::factor_apply_fast(): " + << "B rank is: " << B.getRank() << " - Should be >= 1"; - // Can introduce new zeros! if we know nothing about - // the functor - - iterator it, d, e; - it = begin(); e = end(); - while (it != e) { - Float val = f(it->second); - if (nearlyZero_(val)) { // check zero _after_ applying functor - d = it; - ++it; // increment before erasing! - nz_.erase(d); - } - else { - it->second = val; - ++it; - } - } - } - - /** - * Applies the unary functor f to all elements in this tensor, - * as if it were a dense tensor. This is useful when f(0) != 0, but it is - * slow because it doesn't take advantage of the sparsity. - * - * A[i] = f(A[i]) - * - * Complexity: O(product of the bounds) - */ - template - inline void element_apply(unary_functor f) - { - Index idx = getNewZeroIndex(); - do { - // Writing it that way allows to use boost::lambda expressions - // for f, which is really, really convenient - Float val = get(idx); - val = f(val); - set(idx, val); // doesn't invalidate iterator - } while (increment(bounds_, idx)); - } - - /** - * Applies the binary functor f to couple of elements from this tensor - * and tensor B, at the same element, where no element of the couple is zero. - * The result is stored in tensor C. - * - * C[i] = f(A[i], B[i]), where A[i] != 0 AND B[i] != 0. - * - * This works for f = multiplication, where f(x, 0) == f(0, x) == 0, for all x. - * It doesn't work for addition. 
- * - * Complexity: O(smaller number of non-zeros between this and B) - */ - template - inline void element_apply_fast(const SparseTensor& B, SparseTensor& C, - binary_functor f, bool clearYesNo =true) const - { - { - NTA_ASSERT(getBounds() == B.getBounds()) - << "SparseTensor::element_apply_fast(): " - << "A and B have different bounds: " - << getBounds() << " and " << B.getBounds() - << " - Bounds need to be the same"; + NTA_ASSERT(B.getRank() <= getRank()) + << "SparseTensor::factor_apply_fast(): " + << "A rank is: " << getRank() << " B rank is: " << B.getRank() + << " - B rank should <= A rank"; - NTA_ASSERT(getBounds() == C.getBounds()) - << "SparseTensor::element_apply_fast(): " - << "A and C have different bounds: " - << getBounds() << " and " << C.getBounds() - << " - Bounds need to be the same"; + NTA_ASSERT(f(0, 1) == 0 && f(1, 0) == 0 && f(0, 0) == 0) + << "SparseTensor::factor_apply_fast(): " + << "Binary functor should do: f(0, x) == f(x, 0) == 0 for all x"; + } - NTA_ASSERT(f(0, 1) == 0 && f(1, 0) == 0 && f(0, 0) == 0) - << "SparseTensor::element_apply_fast(): " - << "Binary functor should do: f(x, 0) == f(0, x) == 0 for all x"; - } + if (clearYesNo) + C.clear(); - if (clearYesNo) - C.clear(); - - const_iterator it1, end1, it2, end2; - it1 = begin(); end1 = end(); - it2 = B.begin(); end2 = B.end(); - - while (it1 != end1 && it2 != end2) - if (it1->first == it2->first) { - C.set(it1->first, f(it1->second, it2->second)); - ++it1; ++it2; - } else if (it2->first < it1->first) { - ++it2; - } else { - ++it1; - } - } + NonZeros inter; + nz_intersection(dims, B, inter); + + // Calling set because f(a, b) can fall below nupic::Epsilon + typename NonZeros::const_iterator it, e; + for (it = inter.begin(), e = inter.end(); it != e; ++it) + C.set(it->getIndexA(), f(it->getValA(), it->getValB())); + } - /** - * Applies the binary functor f to couple of elements from this tensor - * and tensor B, at the same index, assuming that f(0, 0) == 0. - * The result is stored in tensor C. - * - * C[i] = f(A[i], B[i]), where A[i] != 0 OR B[i] != 0 - * - * This works for f = multiplication, and f = addition. - * It does not work if f(0, 0) != 0. - * - * Complexity: O(sum of number of non-zeros between this and B) - */ - template - inline void element_apply_nz(const SparseTensor& B, SparseTensor& C, - binary_functor f, bool clearYesNo =true) const + /** + * C[i] = f(A[i], B[j]), + * where j = projection of i on dims, + * and A[i] != 0 OR B[j] != 0. + * + * This works for addition, but not if f(0, 0) != 0. 
+ * + * Complexity: O(sum of number of non-zeros between this and B) + */ + template + inline void factor_apply_nz(const IndexB &dims, + const SparseTensor &B, + SparseTensor &C, binary_functor f, + bool clearYesNo = true) const { { - { - NTA_ASSERT(getBounds() == B.getBounds()) - << "SparseTensor::element_apply_nz(): " - << "A and B have different bounds: " - << getBounds() << " and " << B.getBounds() - << " - Bounds need to be the same"; + NTA_ASSERT(getRank() > 1) + << "SparseTensor::factor_apply_nz(): " + << "A rank is: " << getRank() << " - Should be > 1"; - NTA_ASSERT(getBounds() == C.getBounds()) - << "SparseTensor::element_apply_nz(): " - << "A and C have different bounds: " - << getBounds() << " and " << C.getBounds() - << " - Bounds need to be the same"; + NTA_ASSERT(B.getRank() >= 1) + << "SparseTensor::factor_apply_nz(): " + << "B rank is: " << B.getRank() << " - Should be >= 1"; - NTA_ASSERT(f(0, 0) == 0) - << "SparseTensor::element_apply_nz(): " - << "Binary functor should do: f(0, 0) == 0"; - } + NTA_ASSERT(B.getRank() <= getRank()) + << "SparseTensor::factor_apply_nz(): " + << "A rank is: " << getRank() << " B rank is: " << B.getRank() + << " - B rank should <= A rank"; - if (clearYesNo) - C.clear(); - - const_iterator - it1 = begin(), it2 = B.begin(), - end1 = end(), end2 = B.end(); - - // Any of the set() can introduce new zeros! - while (it1 != end1 && it2 != end2) - if (it1->first == it2->first) { - C.set(it1->first, f(it1->second, it2->second)); - ++it1; ++it2; - } else if (it2->first < it1->first) { - C.set(it2->first, f(0, it2->second)); - ++it2; - } else { - C.set(it1->first, f(it1->second, 0)); - ++it1; - } - - for (; it1 != end1; ++it1) - C.set(it1->first, f(it1->second, 0)); + NTA_ASSERT(&C != this) << "SparseTensor::factor_apply_nz(): " + << "Can't store result in A"; - for (; it2 != end2; ++it2) - C.set(it2->first, f(0, it2->second)); + NTA_ASSERT(f(0, 0) == 0) << "SparseTensor::factor_apply_nz(): " + << "Binary functor should do: f(0, 0) == 0"; } - /** - * Applies the binary functor f to couple of elements from this tensor - * and tensor B, at the same index, without assuming anything on f. - * The result is stored in tensor C. - * - * C[i] = f(A[i], B[i]) - * - * This works in all cases, even if f(0, 0) != 0. It does not take - * advantage of the sparsity. - * - * Complexity: O(product of the bounds) - */ - template - inline void element_apply(const SparseTensor& B, SparseTensor& C, - binary_functor f) const - { - { - NTA_ASSERT(getBounds() == B.getBounds()) - << "SparseTensor::element_apply(): " - << "A and B have different bounds: " - << getBounds() << " and " << B.getBounds() - << " - Bounds need to be the same"; + if (clearYesNo) + C.clear(); - NTA_ASSERT(getBounds() == C.getBounds()) - << "SparseTensor::element_apply(): " - << "A and C have different bounds: " - << getBounds() << " and " << C.getBounds() - << " - Bounds need to be the same"; - } + // This is unfortunately quite slow, because projection + // is a surjection... + NonZeros u; + nz_union(dims, B, u); - Index idx = getNewZeroIndex(); - do { - C.set(idx, f(get(idx), B.get(idx))); - } while (increment(bounds_, idx)); - } - - /** - * In place factor apply (mutating) - * - * A[i] = f(A[i], B[j]), - * where j = projection of i on dims, - * and A[i] != 0 AND B[j] != 0. - * - * This works for multiplication, but not for addition, and not if f(0, 0) == 0. 
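// Hypothetical usage sketch (not from the original header): row-wise scaling
// with the non-mutating factor_apply_fast(). For a 2-D tensor A and a 1-D
// tensor of per-row factors, C[r, c] = A[r, c] * factors[r]; the `dims` index
// {0} names the dimension of A that indexes `factors`.
template <typename Index2D, typename Index1D, typename Float>
void scaleRows(const SparseTensor<Index2D, Float> &A,
               const SparseTensor<Index1D, Float> &factors,
               SparseTensor<Index2D, Float> &C) {
  Index1D dims = factors.getNewIndex();
  dims[0] = 0; // project each (row, col) onto its row to look up the factor
  A.factor_apply_fast(dims, factors, C, std::multiplies<Float>());
}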
- * - * Complexity: O(smaller number of non-zeros between this and B) - */ - template - inline void factor_apply_fast(const IndexB& dims, - const SparseTensor& B, - binary_functor f) + // Calling set because f(a, b) can fall below nupic::Epsilon + typename NonZeros::const_iterator it, e; + for (it = u.begin(), e = u.end(); it != e; ++it) + C.set(it->getIndexA(), f(it->getValA(), it->getValB())); + } + + /** + * C[i] = f(A[i], B[j]), + * where j = projection of i on dims. + * + * There is no restriction on f, it works even if f(0, 0) != 0. + * Doesn't take advantage of the sparsity. + * + * Complexity: O(product of bounds) + */ + template + inline void + factor_apply(const IndexB &dims, const SparseTensor &B, + SparseTensor &C, binary_functor f) const { { - { - NTA_ASSERT(getRank() > 1) - << "SparseTensor::factor_apply_fast() (in place): " - << "A rank is: " << getRank() - << " - Should be > 1"; + NTA_ASSERT(getRank() > 1) + << "SparseTensor::factor_apply(): " + << "A rank is: " << getRank() << " - Should be > 1"; - NTA_ASSERT(B.getRank() >= 1) - << "SparseTensor::factor_apply_fast() (in place): " - << "B rank is: " << B.getRank() - << " - Should be >= 1"; + NTA_ASSERT(B.getRank() >= 1) + << "SparseTensor::factor_apply(): " + << "B rank is: " << B.getRank() << " - Should be >= 1"; - NTA_ASSERT(B.getRank() <= getRank()) - << "SparseTensor::factor_apply_fast() (in place): " - << "A rank is: " << getRank() - << " B rank is: " << B.getRank() + NTA_ASSERT(B.getRank() <= getRank()) + << "SparseTensor::factor_apply(): " + << "A rank is: " << getRank() << " B rank is: " << B.getRank() << " - B rank should <= A rank"; - } + } - NonZeros inter; - nz_intersection(dims, B, inter); + Index idx = getNewZeroIndex(); + IndexB idxB = B.getNewIndex(); - // Need to clear everything so that zeros are handled properly, - // but we have all the values of this tensor in the intersection... - clear(); - - // We have to call set instead of setNonZero, because - // even the multiplication of two non-zeros can result - // in a "zero" from the point of view of SparseMatrix - // that is, a value such that |value| <= nupic::Epsilon - typename NonZeros::const_iterator it, e; - for (it = inter.begin(), e = inter.end(); it != e; ++it) - set(it->getIndexA(), f(it->getValA(), it->getValB())); - } - - /** - * In place factor apply on non-zeros (mutating). - * Works for addition and multiplication. - */ - template - inline void factor_apply_nz(const IndexB& dims, - const SparseTensor& B, - binary_functor f) - { - { - NTA_ASSERT(getRank() > 1) - << "SparseTensor::factor_apply_nz(): " - << "A rank is: " << getRank() - << " - Should be > 1"; + // Calling set because f(a, b) can fall below nupic::Epsilon + do { + project(dims, idx, idxB); + C.set(idx, f(get(idx), B.get(idxB))); + } while (increment(bounds_, idx)); + } - NTA_ASSERT(B.getRank() >= 1) - << "SparseTensor::factor_apply_nz(): " - << "B rank is: " << B.getRank() - << " - Should be >= 1"; + /** + * C[j] = f(C[j], A[i]), + * where j = projection of i on L dims. + * + * Works only on the non-zeros, assumes f(0, 0) = 0 ?? + * Use this version AND init = 1 for multiplication. + * + * Complexity: O(number of non-zeros) + * + * Examples: + * If s2 is a 2D sparse tensor with dimensions (4, 5), + * and s1 a 1D sparse tensor (vector), then: + * - accumulate_nz(I1(0), s1, plus(), 0) + * accumulates vertically, and s1 has size 5. + * - accumulate_nz(I1(1), s1, plus(), 0) + * accumulates horizontally, and s1 has size 4. 
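The accumulate_nz example quoted above (accumulating a (4, 5) tensor "vertically" into a vector of size 5) boils down to projecting each non-zero index onto the kept dimension and folding its value into the result. A self-contained sketch with a std::map of (row, column) pairs standing in for the 2D sparse tensor (not the real tensor types):

    #include <cstddef>
    #include <iostream>
    #include <map>
    #include <utility>
    #include <vector>

    // Accumulate a sparse (4, 5) tensor along dimension 0: project each
    // stored index (i, j) onto the kept dimension j and fold the value into
    // a size-5 result with f = plus and init = 0.
    int main() {
      std::map<std::pair<std::size_t, std::size_t>, double> a;
      a[{0, 1}] = 2.0;
      a[{2, 1}] = 3.0;
      a[{3, 4}] = 1.5;

      std::vector<double> colSums(5, 0.0);
      for (const auto &nz : a)
        colSums[nz.first.second] += nz.second;  // O(number of non-zeros)

      for (double s : colSums)
        std::cout << s << " ";                  // 0 5 0 0 1.5
      std::cout << "\n";
    }

Only the stored entries are touched, which is where the O(number of non-zeros) bound quoted above comes from.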
+ */ + template + inline void accumulate_nz(const Index2 &dims, SparseTensor &B, + binary_functor f, const Float &init = 0) const { + { + NTA_ASSERT(dims.size() == getRank() - B.getRank()); - NTA_ASSERT(B.getRank() <= getRank()) - << "SparseTensor::factor_apply_nz(): " - << "A rank is: " << getRank() - << " B rank is: " << B.getRank() + NTA_ASSERT(B.getRank() < getRank()) + << "SparseTensor::accumulate_nz(): " + << "A rank is: " << getRank() << " B rank is: " << B.getRank() << " - B rank should <= A rank"; + } - NTA_ASSERT(f(0, 0) == 0) - << "SparseTensor::factor_apply_nz(): " - << "Binary functor should do: f(0, 0) == 0"; - } + B.setAll(init); + + IndexB compDims = B.getNewIndex(), idxB = B.getNewIndex(); + complement(dims, compDims); - // This is unfortunately quite slow, because projection - // is a surjection... - NonZeros u; - nz_union(dims, B, u); - - // Calling set because f(a, b) can fall below nupic::Epsilon - typename NonZeros::const_iterator it, e; - for (it = u.begin(), e = u.end(); it != e; ++it) - set(it->getIndexA(), f(it->getValA(), it->getValB())); - } - - /** - * Binary factor apply (non-mutating) - * - * C[i] = f(A[i], B[j]), - * where j = projection of i on dims, - * and A[i] != 0 AND B[j] != 0. - * - * This works for multiplication, but not for addition, and not if f(0, 0) == 0. - * - * Complexity: O(smaller number of non-zeros between this and B) - */ - template - inline void factor_apply_fast(const IndexB& dims, - const SparseTensor& B, - SparseTensor& C, - binary_functor f, - bool clearYesNo =true) const + const_iterator it, e; + for (it = begin(), e = end(); it != e; ++it) { + project(compDims, it->first, idxB); + B.update(idxB, it->second, f); + } + } + + /** + * B[j] = f(B[j], A[i]), + * where j = projection of i on L dims. + * + * Works on all the values, including the zeros, so it is + * inappropriate for multiplication, since the zeros will + * produce zeros in the output, even if init != 0. + * + * No restriction on f, doesn't take advantage of the sparsity. 
+ * + * Complexity: O(product of bounds) + */ + template + inline void accumulate(const Index2 &dims, SparseTensor &B, + binary_functor f, const Float &init = 0) const { { - { - NTA_ASSERT(getRank() > 1) - << "SparseTensor::factor_apply_fast(): " - << "A rank is: " << getRank() - << " - Should be > 1"; + NTA_ASSERT(dims.size() == getRank() - B.getRank()); + + NTA_ASSERT(B.getRank() < getRank()) + << "SparseTensor::accumulate(): " + << "A rank is: " << getRank() << " B rank is: " << B.getRank() + << " - B rank should < A rank"; + } - NTA_ASSERT(B.getRank() >= 1) - << "SparseTensor::factor_apply_fast(): " - << "B rank is: " << B.getRank() - << " - Should be >= 1"; + B.setAll(init); - NTA_ASSERT(B.getRank() <= getRank()) - << "SparseTensor::factor_apply_fast(): " - << "A rank is: " << getRank() - << " B rank is: " << B.getRank() - << " - B rank should <= A rank"; + Index idx = getNewZeroIndex(); + IndexB compDims = B.getNewIndex(), idxB = B.getNewIndex(); + complement(dims, compDims); - NTA_ASSERT(f(0, 1) == 0 && f(1, 0) == 0 && f(0, 0) == 0) - << "SparseTensor::factor_apply_fast(): " - << "Binary functor should do: f(0, x) == f(x, 0) == 0 for all x"; - } + do { + project(compDims, idx, idxB); + B.update(idxB, get(idx), f); + } while (increment(bounds_, idx)); + } - if (clearYesNo) - C.clear(); - - NonZeros inter; - nz_intersection(dims, B, inter); - - // Calling set because f(a, b) can fall below nupic::Epsilon - typename NonZeros::const_iterator it, e; - for (it = inter.begin(), e = inter.end(); it != e; ++it) - C.set(it->getIndexA(), f(it->getValA(), it->getValB())); - } - - /** - * C[i] = f(A[i], B[j]), - * where j = projection of i on dims, - * and A[i] != 0 OR B[j] != 0. - * - * This works for addition, but not if f(0, 0) != 0. - * - * Complexity: O(sum of number of non-zeros between this and B) - */ - template - inline void factor_apply_nz(const IndexB& dims, - const SparseTensor& B, - SparseTensor& C, - binary_functor f, - bool clearYesNo =true) const + /** + * In place (mutating) normalize. + * + * Examples: + * S2.normalize(I1(UInt(0))): normalize vertically + * S2.normalize(I1(UInt(1))): normalize horizontally + */ + template inline void normalize(const Index2 &dims) { { - { - NTA_ASSERT(getRank() > 1) - << "SparseTensor::factor_apply_nz(): " - << "A rank is: " << getRank() - << " - Should be > 1"; + NTA_ASSERT(dims.size() < getRank()) << "SparseTensor::normalize(Index): " + << " - Wrong ranks"; + } - NTA_ASSERT(B.getRank() >= 1) - << "SparseTensor::factor_apply_nz(): " - << "B rank is: " << B.getRank() - << " - Should be >= 1"; + std::vector compDims(getRank() - dims.size()); + complement(dims, compDims); - NTA_ASSERT(B.getRank() <= getRank()) - << "SparseTensor::factor_apply_nz(): " - << "A rank is: " << getRank() - << " B rank is: " << B.getRank() - << " - B rank should <= A rank"; + std::vector compBounds(getRank() - dims.size()); + project(compDims, getBounds(), compBounds); - NTA_ASSERT(&C != this) - << "SparseTensor::factor_apply_nz(): " - << "Can't store result in A"; + SparseTensor, Float> C(compBounds); - NTA_ASSERT(f(0, 0) == 0) - << "SparseTensor::factor_apply_nz(): " - << "Binary functor should do: f(0, 0) == 0"; - } + accumulate_nz(dims, C, std::plus(), 0); + // factor_apply_fast works only on the non-zeros, so it won't attempt + // to divide by zero! + factor_apply_fast(compDims, C, std::divides()); + } - if (clearYesNo) - C.clear(); - - // This is unfortunately quite slow, because projection - // is a surjection... 
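normalize(dims), defined above, composes these building blocks: accumulate the sums over the kept dimensions, then divide every non-zero by the sum its index projects onto. A standalone column-normalization sketch on a std::map stand-in (assumed layout (row, column) -> value, not the real tensor types):

    #include <cstddef>
    #include <iostream>
    #include <map>
    #include <utility>
    #include <vector>

    // Column-normalize a sparse (row, column) -> value map: accumulate the
    // per-column sums over the stored entries, then divide each stored entry
    // by the sum of its column.
    int main() {
      std::map<std::pair<std::size_t, std::size_t>, double> a;
      a[{0, 1}] = 2.0;
      a[{2, 1}] = 6.0;
      a[{3, 4}] = 1.5;

      std::vector<double> colSum(5, 0.0);
      for (const auto &nz : a)
        colSum[nz.first.second] += nz.second;

      for (auto &nz : a)
        nz.second /= colSum[nz.first.second];

      for (const auto &nz : a)
        std::cout << "(" << nz.first.first << "," << nz.first.second
                  << ") = " << nz.second << "\n";  // 0.25, 0.75 and 1
    }

Empty columns are never visited because they contribute no non-zeros, which is the same reason the member function can rely on the _fast division without risking a division by zero.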
- NonZeros u; - nz_union(dims, B, u); - - // Calling set because f(a, b) can fall below nupic::Epsilon - typename NonZeros::const_iterator it, e; - for (it = u.begin(), e = u.end(); it != e; ++it) - C.set(it->getIndexA(), f(it->getValA(), it->getValB())); - } - - /** - * C[i] = f(A[i], B[j]), - * where j = projection of i on dims. - * - * There is no restriction on f, it works even if f(0, 0) != 0. - * Doesn't take advantage of the sparsity. - * - * Complexity: O(product of bounds) - */ - template - inline void factor_apply(const IndexB& dims, - const SparseTensor& B, - SparseTensor& C, - binary_functor f) const + /** + * Computes the outer product of this sparse tensor and B, puts the result + * in C: + * + * C[i.j] = f(A[i], B[j]). + * Cijkpq = f(Aijk, Bpq) + * + * Works only the non-zeros, assumes f(0, 0) = f(x, 0) = f(0, x) = 0. + * Works for multiplication, but not for addition. + * + * Complexity: O(square of total number of non-zeros) + */ + template + inline void outer_product_nz(const SparseTensor &B, + SparseTensor &C, + binary_functor f) const { { - { - NTA_ASSERT(getRank() > 1) - << "SparseTensor::factor_apply(): " - << "A rank is: " << getRank() - << " - Should be > 1"; - - NTA_ASSERT(B.getRank() >= 1) - << "SparseTensor::factor_apply(): " - << "B rank is: " << B.getRank() - << " - Should be >= 1"; + NTA_ASSERT(C.getRank() == B.getRank() + getRank()); - NTA_ASSERT(B.getRank() <= getRank()) - << "SparseTensor::factor_apply(): " - << "A rank is: " << getRank() - << " B rank is: " << B.getRank() - << " - B rank should <= A rank"; - } + NTA_ASSERT(f(0, 0) == 0) << "SparseTensor::outer_product_nz(): " + << "Binary functor should do: f(0, 0) = 0"; + } - Index idx = getNewZeroIndex(); - IndexB idxB = B.getNewIndex(); + C.clear(); - // Calling set because f(a, b) can fall below nupic::Epsilon - do { - project(dims, idx, idxB); - C.set(idx, f(get(idx), B.get(idxB))); - } while (increment(bounds_, idx)); - } - - /** - * C[j] = f(C[j], A[i]), - * where j = projection of i on L dims. - * - * Works only on the non-zeros, assumes f(0, 0) = 0 ?? - * Use this version AND init = 1 for multiplication. - * - * Complexity: O(number of non-zeros) - * - * Examples: - * If s2 is a 2D sparse tensor with dimensions (4, 5), - * and s1 a 1D sparse tensor (vector), then: - * - accumulate_nz(I1(0), s1, plus(), 0) - * accumulates vertically, and s1 has size 5. - * - accumulate_nz(I1(1), s1, plus(), 0) - * accumulates horizontally, and s1 has size 4. - */ - template - inline void accumulate_nz(const Index2& dims, - SparseTensor& B, - binary_functor f, const Float& init =0) const - { - { - NTA_ASSERT(dims.size() == getRank() - B.getRank()); + const_iterator it1, end1; + typename SparseTensor::const_iterator it2, end2; - NTA_ASSERT(B.getRank() < getRank()) - << "SparseTensor::accumulate_nz(): " - << "A rank is: " << getRank() - << " B rank is: " << B.getRank() - << " - B rank should <= A rank"; - } + end1 = end(); + end2 = B.end(); - B.setAll(init); + for (it1 = begin(); it1 != end1; ++it1) + for (it2 = B.begin(); it2 != end2; ++it2) + C.set(concatenate(it1->first, it2->first), f(it1->second, it2->second)); + } - IndexB compDims = B.getNewIndex(), idxB = B.getNewIndex(); - complement(dims, compDims); - - const_iterator it, e; - for (it = begin(), e = end(); it != e; ++it) { - project(compDims, it->first, idxB); - B.update(idxB, it->second, f); - } - } + /** + * Computes the outer product of this sparse tensor and B, puts the result + * in C: + * + * C[i.j] = f(A[i], B[j]). 
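outer_product_nz, defined above, pairs every stored entry of A with every stored entry of B and writes f of the two values at the concatenated index, so the work is proportional to nnz(A) * nnz(B). A minimal sketch for two sparse vectors (std::map stand-ins, multiplication as the functor):

    #include <cstddef>
    #include <iostream>
    #include <map>
    #include <utility>

    // Outer product restricted to the stored entries: every non-zero of A
    // paired with every non-zero of B produces one entry of C at the
    // concatenated index.
    int main() {
      std::map<std::size_t, double> a{{0, 2.0}, {3, 4.0}};
      std::map<std::size_t, double> b{{1, 0.5}, {2, 3.0}};

      std::map<std::pair<std::size_t, std::size_t>, double> c;
      for (const auto &x : a)
        for (const auto &y : b)
          c[{x.first, y.first}] = x.second * y.second;  // f = multiplication

      for (const auto &nz : c)
        std::cout << "C(" << nz.first.first << "," << nz.first.second
                  << ") = " << nz.second << "\n";  // 4 entries, nnz(A)*nnz(B) work
    }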
+ * + * Doesn't assume anything on f, works in all cases, but remarkably slow. + * + * Complexity: O(square of product of bounds) + */ + template + inline void outer_product(const SparseTensor &B, + SparseTensor &C, + binary_functor f) const { + { NTA_ASSERT(getRank() + B.getRank() == C.getRank()); } - /** - * B[j] = f(B[j], A[i]), - * where j = projection of i on L dims. - * - * Works on all the values, including the zeros, so it is - * inappropriate for multiplication, since the zeros will - * produce zeros in the output, even if init != 0. - * - * No restriction on f, doesn't take advantage of the sparsity. - * - * Complexity: O(product of bounds) - */ - template - inline void accumulate(const Index2& dims, - SparseTensor& B, - binary_functor f, const Float& init =0) const - { - { - NTA_ASSERT(dims.size() == getRank() - B.getRank()); + C.clear(); - NTA_ASSERT(B.getRank() < getRank()) - << "SparseTensor::accumulate(): " - << "A rank is: " << getRank() - << " B rank is: " << B.getRank() - << " - B rank should < A rank"; - } + Index idxA = getNewZeroIndex(), ubA = getBounds(); + IndexB idxB = B.getNewZeroIndex(), ubB = B.getBounds(); - B.setAll(init); - - Index idx = getNewZeroIndex(); - IndexB compDims = B.getNewIndex(), idxB = B.getNewIndex(); - complement(dims, compDims); - + do { + setToZero(idxB); do { - project(compDims, idx, idxB); - B.update(idxB, get(idx), f); - } while (increment(bounds_, idx)); - } - - /** - * In place (mutating) normalize. - * - * Examples: - * S2.normalize(I1(UInt(0))): normalize vertically - * S2.normalize(I1(UInt(1))): normalize horizontally - */ - template - inline void normalize(const Index2& dims) - { - { - NTA_ASSERT(dims.size() < getRank()) - << "SparseTensor::normalize(Index): " - << " - Wrong ranks"; - } + C.set(concatenate(idxA, idxB), f(get(idxA), B.get(idxB))); + } while (increment(ubB, idxB)); + } while (increment(ubA, idxA)); + } - std::vector compDims(getRank() - dims.size()); - complement(dims, compDims); - - std::vector compBounds(getRank() - dims.size()); - project(compDims, getBounds(), compBounds); - - SparseTensor, Float> C(compBounds); - - accumulate_nz(dims, C, std::plus(), 0); - // factor_apply_fast works only on the non-zeros, so it won't attempt - // to divide by zero! - factor_apply_fast(compDims, C, std::divides()); - } - - /** - * Computes the outer product of this sparse tensor and B, puts the result - * in C: - * - * C[i.j] = f(A[i], B[j]). - * Cijkpq = f(Aijk, Bpq) - * - * Works only the non-zeros, assumes f(0, 0) = f(x, 0) = f(0, x) = 0. - * Works for multiplication, but not for addition. - * - * Complexity: O(square of total number of non-zeros) - */ - template - inline void outer_product_nz(const SparseTensor& B, - SparseTensor& C, - binary_functor f) const - { - { - NTA_ASSERT(C.getRank() == B.getRank() + getRank()); - - NTA_ASSERT(f(0, 0) == 0) - << "SparseTensor::outer_product_nz(): " - << "Binary functor should do: f(0, 0) = 0"; - } - - C.clear(); - - const_iterator it1, end1; - typename SparseTensor::const_iterator it2, end2; - - end1 = end(); end2 = B.end(); - - for (it1 = begin(); it1 != end1; ++it1) - for (it2 = B.begin(); it2 != end2; ++it2) - C.set(concatenate(it1->first, it2->first), f(it1->second, it2->second)); - } - - /** - * Computes the outer product of this sparse tensor and B, puts the result - * in C: - * - * C[i.j] = f(A[i], B[j]). - * - * Doesn't assume anything on f, works in all cases, but remarkably slow. 
- * - * Complexity: O(square of product of bounds) - */ - template - inline void outer_product(const SparseTensor& B, - SparseTensor& C, - binary_functor f) const - { - { - NTA_ASSERT(getRank() + B.getRank() == C.getRank()); - } - - C.clear(); - - Index idxA = getNewZeroIndex(), ubA = getBounds(); - IndexB idxB = B.getNewZeroIndex(), ubB = B.getBounds(); - - do { - setToZero(idxB); - do { - C.set(concatenate(idxA, idxB), f(get(idxA), B.get(idxB))); - } while (increment(ubB, idxB)); - } while (increment(ubA, idxA)); - } - - /** - * Computes the contraction of this sparse tensor along the two - * given dimensions: - * - * B[ikl...] = accumulate using f(j, A[ijkl...j...]), - * where j shows at positions dim1 and dim2 of A. - * Cikq = f(Aiuk, Buq) - * - * Works only on the non-zeros, assumes f(0, 0) = 0 ?? - * - * Complexity: O(number of non-zeros) - */ - template - inline void contract_nz(const UInt dim1, const UInt dim2, - SparseTensor& B, - binary_functor f, const Float& init =0) const - { - { // Pre-conditions - NTA_ASSERT(B.getRank() == getRank() - 2) + /** + * Computes the contraction of this sparse tensor along the two + * given dimensions: + * + * B[ikl...] = accumulate using f(j, A[ijkl...j...]), + * where j shows at positions dim1 and dim2 of A. + * Cikq = f(Aiuk, Buq) + * + * Works only on the non-zeros, assumes f(0, 0) = 0 ?? + * + * Complexity: O(number of non-zeros) + */ + template + inline void contract_nz(const UInt dim1, const UInt dim2, + SparseTensor &B, binary_functor f, + const Float &init = 0) const { + { // Pre-conditions + NTA_ASSERT(B.getRank() == getRank() - 2) << "SparseTensor::contract_nz(): " << "Tensor B has rank: " << B.getRank() << " - B needs to have rank: " << getRank() - 2; - - NTA_ASSERT(getRank() > 2) + + NTA_ASSERT(getRank() > 2) << "SparseTensor::contract_nz(): " << "Trying to contract tensor of rank: " << getRank() << " - Can contract only tensors or rank > 2"; - - NTA_ASSERT(dim1 < getRank() && dim2 < getRank() && dim1 != dim2) + + NTA_ASSERT(dim1 < getRank() && dim2 < getRank() && dim1 != dim2) << "SparseTensor::contract_nz(): " << "Trying to contract along dimensions: " << dim1 << " and " << dim2 - << " - Dimensions must be different and less than tensor rank= " + << " - Dimensions must be different and less than tensor rank= " << getRank(); - - NTA_ASSERT(bounds_[dim1] == bounds_[dim2]) + + NTA_ASSERT(bounds_[dim1] == bounds_[dim2]) << "SparseTensor::contract_nz(): " - << "Using dim: " << dim1 - << " and dim: " << dim2 - << " but they have different sizes: " << bounds_[dim1] - << " and " << bounds_[dim2] + << "Using dim: " << dim1 << " and dim: " << dim2 + << " but they have different sizes: " << bounds_[dim1] << " and " + << bounds_[dim2] << " - Can contract only dimensions that have the same size"; - - NTA_ASSERT(f(0, 1) == 0 && f(1, 0) == 0 && f(0, 0) == 0) + + NTA_ASSERT(f(0, 1) == 0 && f(1, 0) == 0 && f(0, 0) == 0) << "SparseTensor::contract_nz(): " << "Binary functor should do: f(0, x) == f(x, 0) == 0 for all x"; + } + + IndexB compDims = B.getNewIndex(), idxB = B.getNewIndex(); + std::vector dims(2); + dims[0] = dim1; + dims[1] = dim2; + complement(dims, compDims); + + B.clear(); + + // Can't use setAll, because of if + const_iterator it, e; + for (it = begin(), e = end(); it != e; ++it) { + if (it->first[dim1] == it->first[dim2]) { + project(compDims, it->first, idxB); + B.set(idxB, init); } - - IndexB compDims = B.getNewIndex(), idxB = B.getNewIndex(); - std::vector dims(2); dims[0] = dim1; dims[1] = dim2; - complement(dims, compDims); - - 
B.clear(); - - // Can't use setAll, because of if - const_iterator it, e; - for (it = begin(), e = end(); it != e; ++it) { - if (it->first[dim1] == it->first[dim2]) { - project(compDims, it->first, idxB); - B.set(idxB, init); - } - } - - for (it = begin(), e = end(); it != e; ++it) { - if (it->first[dim1] == it->first[dim2]) { - project(compDims, it->first, idxB); - B.update(idxB, it->second, f); - } + } + + for (it = begin(), e = end(); it != e; ++it) { + if (it->first[dim1] == it->first[dim2]) { + project(compDims, it->first, idxB); + B.update(idxB, it->second, f); } } - - /** - * Computes the contraction of this sparse tensor along the two - * given dimensions: - * - * B[ikl...] = accumulate using f(j, A[ijkl...j...]), - * where j shows at positions dim1 and dim2 of A. - * - * No assumption on f. - * - * Complexity: O(product of bounds) - */ - template - inline void contract(const UInt dim1, const UInt dim2, - SparseTensor& B, - binary_functor f, const Float& init =0) const + } + + /** + * Computes the contraction of this sparse tensor along the two + * given dimensions: + * + * B[ikl...] = accumulate using f(j, A[ijkl...j...]), + * where j shows at positions dim1 and dim2 of A. + * + * No assumption on f. + * + * Complexity: O(product of bounds) + */ + template + inline void contract(const UInt dim1, const UInt dim2, + SparseTensor &B, binary_functor f, + const Float &init = 0) const { { - { - NTA_ASSERT(B.getRank() == getRank() - 2); + NTA_ASSERT(B.getRank() == getRank() - 2); - NTA_ASSERT(getRank() > 2) + NTA_ASSERT(getRank() > 2) << "SparseTensor::contract(): " << "Trying to contract tensor of rank: " << getRank() << " - Can contract only tensors or rank > 2"; - NTA_ASSERT(dim1 < getRank() && dim2 < getRank() && dim1 != dim2) + NTA_ASSERT(dim1 < getRank() && dim2 < getRank() && dim1 != dim2) << "SparseTensor::contract_nz(): " << "Trying to contract along dimensions: " << dim1 << " and " << dim2 - << " - Dimensions must be different and less than tensor rank: " + << " - Dimensions must be different and less than tensor rank: " << getRank(); - NTA_ASSERT(bounds_[dim1] == bounds_[dim2]) + NTA_ASSERT(bounds_[dim1] == bounds_[dim2]) << "SparseTensor::contract(): " - << "Using dim: " << dim1 - << " and dim: " << dim2 - << " but they have different size: " << bounds_[dim1] - << " and " << bounds_[dim2] + << "Using dim: " << dim1 << " and dim: " << dim2 + << " but they have different size: " << bounds_[dim1] << " and " + << bounds_[dim2] << " - Can contract only dimensions that have the same size"; + } + + Index idx = getNewZeroIndex(); + IndexB compDims = B.getNewIndex(), it2 = B.getNewIndex(); + std::vector dims(2); + dims[0] = dim1; + dims[1] = dim2; + complement(dims, compDims); + + B.setAll(init); + + do { + if (idx[dim1] == idx[dim2]) { + project(compDims, idx, it2); + B.update(it2, get(idx), f); } - - Index idx = getNewZeroIndex(); - IndexB compDims = B.getNewIndex(), it2 = B.getNewIndex(); - std::vector dims(2); dims[0] = dim1; dims[1] = dim2; - complement(dims, compDims); - - B.setAll(init); + } while (increment(bounds_, idx)); + } - do { - if (idx[dim1] == idx[dim2]) { - project(compDims, idx, it2); - B.update(it2, get(idx), f); - } - } while (increment(bounds_, idx)); - } - - /** - * Computes the inner product of this sparse tensor and B, put the result - * in C: - * - * C[k] = accumulate using g(product using f of B[i], C[j]) - * - * Works only on the non-zeros. 
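contract_nz, completed above, keeps only the entries whose indices agree on the two contracted dimensions and folds them into the remaining dimensions, a sparse generalization of taking a trace. A standalone rank-3 sketch (std::array keys standing in for tensor indices; the real code projects with complement() rather than indexing by hand):

    #include <array>
    #include <cstddef>
    #include <iostream>
    #include <map>
    #include <vector>

    // Contract a rank-3 sparse tensor over dimensions 0 and 2: only stored
    // entries whose first and last indices agree contribute, and they are
    // folded (with + here) into the remaining middle dimension.
    int main() {
      std::map<std::array<std::size_t, 3>, double> a;
      a[{0, 1, 0}] = 2.0;  // dims 0 and 2 agree -> contributes to b[1]
      a[{2, 0, 2}] = 3.0;  // contributes to b[0]
      a[{1, 1, 2}] = 7.0;  // dims 0 and 2 differ -> ignored

      std::vector<double> b(2, 0.0);     // result indexed by dimension 1
      for (const auto &nz : a)
        if (nz.first[0] == nz.first[2])  // the "diagonal" condition
          b[nz.first[1]] += nz.second;

      std::cout << b[0] << " " << b[1] << "\n";  // prints 3 2
    }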
- * - * Complexity: O(square of number of non-zeros in one dim) - */ - template - inline void inner_product_nz(const UInt dim1, const UInt dim2, - const SparseTensor& B, - SparseTensor& C, - binary_functor1 f, binary_functor2 g, - const Float& init =0) const + /** + * Computes the inner product of this sparse tensor and B, put the result + * in C: + * + * C[k] = accumulate using g(product using f of B[i], C[j]) + * + * Works only on the non-zeros. + * + * Complexity: O(square of number of non-zeros in one dim) + */ + template + inline void inner_product_nz(const UInt dim1, const UInt dim2, + const SparseTensor &B, + SparseTensor &C, + binary_functor1 f, binary_functor2 g, + const Float &init = 0) const { { - { - NTA_ASSERT(B.getRank() + getRank() - 2 == C.getRank()); + NTA_ASSERT(B.getRank() + getRank() - 2 == C.getRank()); - NTA_ASSERT(getRank() + B.getRank() > 2) + NTA_ASSERT(getRank() + B.getRank() > 2) << "SparseTensor::inner_product_nz(): " - << "Trying to take inner product of two tensors of rank : " + << "Trying to take inner product of two tensors of rank : " << getRank() << " and: " << B.getRank() << " - But need sum of ranks > 2"; - NTA_ASSERT(dim1 < getRank()) + NTA_ASSERT(dim1 < getRank()) << "SparseTensor::inner_product_nz(): " - << " - Dimension 1 must be less than tensor A's rank: " - << getRank(); + << " - Dimension 1 must be less than tensor A's rank: " << getRank(); - NTA_ASSERT(dim2 < B.getRank()) + NTA_ASSERT(dim2 < B.getRank()) << "SparseTensor::inner_product_nz(): " - << " - Dimension 2 must be less than tensor B's rank: " + << " - Dimension 2 must be less than tensor B's rank: " << B.getRank(); - NTA_ASSERT(bounds_[dim1] == B.getBounds()[dim2]) + NTA_ASSERT(bounds_[dim1] == B.getBounds()[dim2]) << "SparseTensor::inner_product_nz(): " - << "Using dim: " << dim1 - << " and dim: " << dim2 - << " but they have different size: " << bounds_[dim1] - << " and " << B.getBounds()[dim2] - << " - Can take inner product only along dimensions that have the same size"; - } + << "Using dim: " << dim1 << " and dim: " << dim2 + << " but they have different size: " << bounds_[dim1] << " and " + << B.getBounds()[dim2] + << " - Can take inner product only along dimensions that have the " + "same size"; + } - std::vector - pit1(getRank()-1, 0), pit2(B.getRank()-1, 0), - d1(1, dim1), d2(1, dim2), - compDims1(getRank()-1), compDims2(B.getRank()-1); - - complement(d1, compDims1); - complement(d2, compDims2); + std::vector pit1(getRank() - 1, 0), pit2(B.getRank() - 1, 0), + d1(1, dim1), d2(1, dim2), compDims1(getRank() - 1), + compDims2(B.getRank() - 1); - C.clear(); - - const_iterator it1, e1; - typename SparseTensor::const_iterator it2, e2; - - for (it1 = begin(), e1 = end(); it1 != e1; ++it1) - for (it2 = B.begin(), e2 = B.end(); it2 != e2; ++it2) { - if (it1->first[dim1] == it2->first[dim2]) { - C.set(concatenate(pit1, pit2), init); - } + complement(d1, compDims1); + complement(d2, compDims2); + + C.clear(); + + const_iterator it1, e1; + typename SparseTensor::const_iterator it2, e2; + + for (it1 = begin(), e1 = end(); it1 != e1; ++it1) + for (it2 = B.begin(), e2 = B.end(); it2 != e2; ++it2) { + if (it1->first[dim1] == it2->first[dim2]) { + C.set(concatenate(pit1, pit2), init); } + } - for (it1 = begin(), e1 = end(); it1 != e1; ++it1) - for (it2 = B.begin(), e2 = B.end(); it2 != e2; ++it2) { - if (it1->first[dim1] == it2->first[dim2]) { - project(compDims1, it1->first, pit1); - project(compDims2, it2->first, pit2); - C.update(concatenate(pit1, pit2), f(it1->second, it2->second), g); - } + 
for (it1 = begin(), e1 = end(); it1 != e1; ++it1) + for (it2 = B.begin(), e2 = B.end(); it2 != e2; ++it2) { + if (it1->first[dim1] == it2->first[dim2]) { + project(compDims1, it1->first, pit1); + project(compDims2, it2->first, pit2); + C.update(concatenate(pit1, pit2), f(it1->second, it2->second), g); } - } + } + } - /** - * Computes the inner product of this sparse tensor and B, put the result - * in C: - * - * C[k] = accumulate using g(product using f of B[i], C[j]) - * Aijk, Bpq, i, p - * Tijkpq = f(Aijk, Bpq) - * Cikq = g(Tiukuq) - * - * Complexity: O( ?? ) - */ - template - inline void inner_product(const UInt dim1, const UInt dim2, - const SparseTensor& B, - SparseTensor& C, - binary_functor1 f, binary_functor2 g, - const Float& init =0) const + /** + * Computes the inner product of this sparse tensor and B, put the result + * in C: + * + * C[k] = accumulate using g(product using f of B[i], C[j]) + * Aijk, Bpq, i, p + * Tijkpq = f(Aijk, Bpq) + * Cikq = g(Tiukuq) + * + * Complexity: O( ?? ) + */ + template + inline void inner_product(const UInt dim1, const UInt dim2, + const SparseTensor &B, + SparseTensor &C, binary_functor1 f, + binary_functor2 g, const Float &init = 0) const { { - { - NTA_ASSERT(getRank() + B.getRank() - 2 == C.getRank()); + NTA_ASSERT(getRank() + B.getRank() - 2 == C.getRank()); - NTA_ASSERT(getRank() + B.getRank() > 2) + NTA_ASSERT(getRank() + B.getRank() > 2) << "SparseTensor::inner_product(): " - << "Trying to take inner product of two tensors of rank : " + << "Trying to take inner product of two tensors of rank : " << getRank() << " and: " << B.getRank() << " - But need sum of ranks > 2"; - NTA_ASSERT(dim1 < getRank()) + NTA_ASSERT(dim1 < getRank()) << "SparseTensor::inner_product(): " - << " - Dimension 1 must be less than tensor A's rank: " - << getRank(); + << " - Dimension 1 must be less than tensor A's rank: " << getRank(); - NTA_ASSERT(dim2 < B.getRank()) + NTA_ASSERT(dim2 < B.getRank()) << "SparseTensor::inner_product(): " - << " - Dimension 2 must be less than tensor B's rank: " + << " - Dimension 2 must be less than tensor B's rank: " << B.getRank(); - NTA_ASSERT(bounds_[dim1] == B.getBounds()[dim2]) + NTA_ASSERT(bounds_[dim1] == B.getBounds()[dim2]) << "SparseTensor::inner_product(): " - << "Using dim: " << dim1 - << " and dim: " << dim2 - << " but they have different size: " << bounds_[dim1] - << " and " << B.getBounds()[dim2] - << " - Can take inner product only along dimensions that have the same size"; - } + << "Using dim: " << dim1 << " and dim: " << dim2 + << " but they have different size: " << bounds_[dim1] << " and " + << B.getBounds()[dim2] + << " - Can take inner product only along dimensions that have the " + "same size"; + } + + Index idx1 = getNewZeroIndex(); + IndexB idx2 = B.getNewZeroIndex(); - Index idx1 = getNewZeroIndex(); - IndexB idx2 = B.getNewZeroIndex(); + std::vector pit1(getRank() - 1, 0), pit2(B.getRank() - 1, 0), + d1(1, dim1), d2(1, dim2), compDims1(getRank() - 1), + compDims2(B.getRank() - 1); - std::vector - pit1(getRank()-1, 0), pit2(B.getRank()-1, 0), - d1(1, dim1), d2(1, dim2), - compDims1(getRank()-1), compDims2(B.getRank()-1); - - complement(d1, compDims1); - complement(d2, compDims2); + complement(d1, compDims1); + complement(d2, compDims2); - C.setAll(init); + C.setAll(init); + do { + setToZero(idx2); do { - setToZero(idx2); - do { - if (idx1[dim1] == idx2[dim2]) { - project(compDims1, idx1, pit1); - project(compDims2, idx2, pit2); - C.update((concatenate(pit1, pit2)), f(get(idx1), B.get(idx2)), g); - } - } 
while (increment(B.getBounds(), idx2)); - } while (increment(bounds_, idx1)); - } - - /** - * Another type of product. - */ - template - inline void product3(const Index1A& dimsA, const Index1A& dimsB, - const SparseTensor& B, - SparseTensor& C, - binary_functor1 f) const + if (idx1[dim1] == idx2[dim2]) { + project(compDims1, idx1, pit1); + project(compDims2, idx2, pit2); + C.update((concatenate(pit1, pit2)), f(get(idx1), B.get(idx2)), g); + } + } while (increment(B.getBounds(), idx2)); + } while (increment(bounds_, idx1)); + } + + /** + * Another type of product. + */ + template + inline void product3(const Index1A &dimsA, const Index1A &dimsB, + const SparseTensor &B, + SparseTensor &C, + binary_functor1 f) const { { - { - NTA_ASSERT(dimsA.size() == dimsB.size()); - NTA_ASSERT(dimsA.size() <= getRank()); - NTA_ASSERT(dimsB.size() <= B.getRank()); - } + NTA_ASSERT(dimsA.size() == dimsB.size()); + NTA_ASSERT(dimsA.size() <= getRank()); + NTA_ASSERT(dimsB.size() <= B.getRank()); + } - std::vector - idx1a(dimsA.size()), - bounds1A(dimsA.size()), - lbA(getRank(), 0), - ubA(getBounds().begin(), getBounds().end()), - lbB(B.getRank(), 0), - ubB(B.getBounds().begin(), B.getBounds().end()), + std::vector idx1a(dimsA.size()), bounds1A(dimsA.size()), + lbA(getRank(), 0), ubA(getBounds().begin(), getBounds().end()), + lbB(B.getRank(), 0), ubB(B.getBounds().begin(), B.getBounds().end()), dimsSliceA(getRank() - dimsA.size()), dimsSliceB(B.getRank() - dimsB.size()), boundsSliceA(getRank() - dimsA.size()), boundsSliceB(B.getRank() - dimsB.size()), boundsRes(boundsSliceA.size() + boundsSliceB.size()); - complement(dimsA, dimsSliceA); - complement(dimsB, dimsSliceB); - project(dimsA, getBounds(), bounds1A); - project(dimsSliceA, getBounds(), boundsSliceA); - project(dimsSliceB, B.getBounds(), boundsSliceB); - boundsRes = concatenate(boundsSliceA, boundsSliceB); - - SparseTensor, Float> - sliceA(boundsSliceA), - sliceB(boundsSliceB), - res(boundsRes); + complement(dimsA, dimsSliceA); + complement(dimsB, dimsSliceB); + project(dimsA, getBounds(), bounds1A); + project(dimsSliceA, getBounds(), boundsSliceA); + project(dimsSliceB, B.getBounds(), boundsSliceB); + boundsRes = concatenate(boundsSliceA, boundsSliceB); + + SparseTensor, Float> sliceA(boundsSliceA), + sliceB(boundsSliceB), res(boundsRes); + + setToZero(idx1a); + + do { + + for (UInt k = 0; k < dimsA.size(); ++k) { // embed + lbA[dimsA[k]] = ubA[dimsA[k]] = idx1a[k]; + lbB[dimsB[k]] = ubB[dimsB[k]] = idx1a[k]; + } + + slice(Domain(lbA, ubA), sliceA); + B.slice(Domain(lbB, ubB), sliceB); + outer_product_nz(sliceA, sliceB, res, f); - setToZero(idx1a); + std::vector idxRes(dimsSliceA.size() + dimsSliceB.size(), 0); do { - - for (UInt k = 0; k < dimsA.size(); ++k) { // embed - lbA[dimsA[k]] = ubA[dimsA[k]] = idx1a[k]; - lbB[dimsB[k]] = ubB[dimsB[k]] = idx1a[k]; - } + IndexC idxC = C.getNewZeroIndex(); - slice(Domain(lbA, ubA), sliceA); - B.slice(Domain(lbB, ubB), sliceB); - outer_product_nz(sliceA, sliceB, res, f); - - std::vector - idxRes(dimsSliceA.size() + dimsSliceB.size(), 0); - - do { - IndexC idxC = C.getNewZeroIndex(); - - embed(dimsSliceA, idxRes, idxC); - embed(dimsSliceB, idxRes, idxC); - embed(dimsA, idx1a, idxC); - C.set(idxC, res.get(idxC)); - - } while (increment(boundsRes, idxRes)); - - } while (increment(bounds1A, idx1a)); - } - - /** - * Streaming operator. - * See print(). - */ - template - NTA_HIDDEN friend std::ostream& operator<<(std::ostream&, const SparseTensor&); - - /** - * Whether two sparse tensors are equal or not. 
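For two sparse matrices, the inner_product_nz routine above reduces to ordinary matrix multiplication restricted to the stored entries: pairs of non-zeros that share the contracted index are combined with f and folded into C with g. A standalone sketch with f = multiplication and g = addition (std::map stand-ins, not the SparseTensor API):

    #include <cstddef>
    #include <iostream>
    #include <map>
    #include <utility>

    // Inner product of two sparse matrices over A's columns and B's rows:
    // stored entries that agree on the contracted index are multiplied and
    // the products accumulated into C.
    int main() {
      using Idx = std::pair<std::size_t, std::size_t>;
      std::map<Idx, double> a{{{0, 1}, 2.0}, {{1, 2}, 3.0}};
      std::map<Idx, double> b{{{1, 0}, 4.0}, {{2, 3}, 5.0}};

      std::map<Idx, double> c;
      for (const auto &x : a)
        for (const auto &y : b)
          if (x.first.second == y.first.first)  // shared (contracted) index
            c[{x.first.first, y.first.second}] += x.second * y.second;

      for (const auto &nz : c)
        std::cout << "C(" << nz.first.first << "," << nz.first.second
                  << ") = " << nz.second << "\n";  // C(0,0) = 8, C(1,3) = 15
    }

The double loop over the two non-zero sets makes the quadratic cost noted in the comments explicit.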
- * To be equal, they need to have the same number of dimensions, - * the same size along each dimensions, and the same non-zeros. - * Equality of floating point numbers is controlled by nupic::Epsilon. - */ - template - NTA_HIDDEN friend bool operator==(const SparseTensor&, const SparseTensor&); - - /** - * Whether two sparse tensors are different or not. - * See operator==. - */ - template - NTA_HIDDEN friend bool operator!=(const SparseTensor&, const SparseTensor&); - - /** - * Prints out this tensor to a stream. - * There are special formats for dim 1, 2 and 3, and beyond that - * only the non-zeros are printed out, with their indices. - */ - inline void print(std::ostream& outStream) const - { - if (getRank() == 1) { - for (UInt i = 0; i < bounds_[0]; ++i) { - const_iterator it = nz_.find(getNewIndex(i)); + embed(dimsSliceA, idxRes, idxC); + embed(dimsSliceB, idxRes, idxC); + embed(dimsA, idx1a, idxC); + C.set(idxC, res.get(idxC)); + + } while (increment(boundsRes, idxRes)); + + } while (increment(bounds1A, idx1a)); + } + + /** + * Streaming operator. + * See print(). + */ + template + NTA_HIDDEN friend std::ostream &operator<<(std::ostream &, + const SparseTensor &); + + /** + * Whether two sparse tensors are equal or not. + * To be equal, they need to have the same number of dimensions, + * the same size along each dimensions, and the same non-zeros. + * Equality of floating point numbers is controlled by nupic::Epsilon. + */ + template + NTA_HIDDEN friend bool operator==(const SparseTensor &, + const SparseTensor &); + + /** + * Whether two sparse tensors are different or not. + * See operator==. + */ + template + NTA_HIDDEN friend bool operator!=(const SparseTensor &, + const SparseTensor &); + + /** + * Prints out this tensor to a stream. + * There are special formats for dim 1, 2 and 3, and beyond that + * only the non-zeros are printed out, with their indices. + */ + inline void print(std::ostream &outStream) const { + if (getRank() == 1) { + for (UInt i = 0; i < bounds_[0]; ++i) { + const_iterator it = nz_.find(getNewIndex(i)); + outStream << (it != end() ? it->second : 0) << " "; + } + outStream << std::endl; + return; + } + + if (getRank() == 2) { + for (UInt i = 0; i < bounds_[0]; ++i) { + for (UInt j = 0; j < bounds_[1]; ++j) { + const_iterator it = nz_.find(getNewIndex(i, j)); outStream << (it != end() ? it->second : 0) << " "; } outStream << std::endl; - return; } + return; + } - if (getRank() == 2) { - for (UInt i = 0; i < bounds_[0]; ++i) { - for (UInt j = 0; j < bounds_[1]; ++j) { - const_iterator it = nz_.find(getNewIndex(i, j)); + if (getRank() == 3) { + for (UInt i = 0; i < bounds_[0]; ++i) { + for (UInt j = 0; j < bounds_[1]; ++j) { + for (UInt k = 0; k < bounds_[2]; ++k) { + const_iterator it = nz_.find(getNewIndex(i, j, k)); outStream << (it != end() ? it->second : 0) << " "; } outStream << std::endl; } - return; - } - - if (getRank() == 3) { - for (UInt i = 0; i < bounds_[0]; ++i) { - for (UInt j = 0; j < bounds_[1]; ++j) { - for (UInt k = 0; k < bounds_[2]; ++k) { - const_iterator it = nz_.find(getNewIndex(i, j, k)); - outStream << (it != end() ? it->second : 0) << " "; - } - outStream << std::endl; - } - outStream << std::endl; - } - return; + outStream << std::endl; } - - for (const_iterator it = begin(); it != end(); ++it) - outStream << it->first << ": " << it->second << std::endl; - } - - //-------------------------------------------------------------------------------- - /** - * Find the max of some sub-space of this sparse tensor. 
- * - * Complexity: O(number of non-zeros) - * - * Examples: - * If s2 is a 2D sparse tensor of size (4, 5), and s1 a 1D, then: - * - s2.max(I1(0), s1) finds the max of each column of s2 and puts - * it in the corresponding element of s1. s1 has size 5. - * - s2.max(I1(1), s1) finds the max of each row of s2 and puts it - * in the correspondin element of s1. s1 has size 4. - */ - template - inline void max(const Index2& dims, SparseTensor& B) const - { - accumulate_nz(dims, B, nupic::Max(), 0); + return; } - /** - * Finds the max of this sparse tensor, and the index - * of this min. - * This funcion needed because SparseTensor doesn't - * specialize to a scalar properly. - */ - inline const std::pair max() const - { - if (isZero()) - return std::make_pair(getNewZeroIndex(), Float(0)); - - const_iterator min_it = - std::max_element(begin(), end(), - predicate_compose, - nupic::select2nd > >()); - - return std::make_pair(min_it->first, min_it->second); - } - - /** - * Returns the sum of all the non-zeros in this sparse tensor. - * This funcion needed because SparseTensor doesn't - * specialize to a scalar properly. - */ - inline const Float sum() const - { - Float sum = 0; - const_iterator i, e; - for (i = begin(), e = end(); i != e; ++i) - sum += i->second; - return sum; - } - - /** - * Wrapper for accumulate with plus. - */ - template - inline void sum(const Index2& dims, SparseTensor& B) const - { - accumulate_nz(dims, B, std::plus()); - } + for (const_iterator it = begin(); it != end(); ++it) + outStream << it->first << ": " << it->second << std::endl; + } - /** - * Adds a slice to another. - */ - inline void addSlice(UInt which, UInt src, UInt dst) - { - TensorIndex lb(getBounds()), ub(getBounds()); - lb[which] = ub[which] = src; - TensorIndex srcIndex = getNewZeroIndex(); - srcIndex[which] = src; + //-------------------------------------------------------------------------------- + /** + * Find the max of some sub-space of this sparse tensor. + * + * Complexity: O(number of non-zeros) + * + * Examples: + * If s2 is a 2D sparse tensor of size (4, 5), and s1 a 1D, then: + * - s2.max(I1(0), s1) finds the max of each column of s2 and puts + * it in the corresponding element of s1. s1 has size 5. + * - s2.max(I1(1), s1) finds the max of each row of s2 and puts it + * in the correspondin element of s1. s1 has size 4. + */ + template + inline void max(const Index2 &dims, SparseTensor &B) const { + accumulate_nz(dims, B, nupic::Max(), 0); + } - do { - TensorIndex dstIndex(srcIndex); - dstIndex[which] = dst; - set(dstIndex, get(dstIndex) + get(srcIndex)); - } while (increment(lb, ub, srcIndex)); - } - - /** - * Adds two sparse tensors of the same rank and dimensions. - * This is an element-wise addition. - */ - inline void axby(const Float& a, const SparseTensor& B, const Float& b, - SparseTensor& C) const - { - C.clear(); + /** + * Finds the max of this sparse tensor, and the index + * of this min. + * This funcion needed because SparseTensor doesn't + * specialize to a scalar properly. 
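The scalar max() documented above (and implemented just below) only inspects the stored non-zeros, returning the zero index and the value 0 for an all-zero tensor. A standalone sketch of the same idea on a sparse vector stand-in:

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <map>
    #include <utility>

    // Largest stored value of a sparse vector together with its index; an
    // all-zero tensor is handled up front, as in the member function.
    int main() {
      std::map<std::size_t, double> a{{2, 1.5}, {5, 4.0}, {9, 0.5}};

      if (a.empty()) {
        std::cout << "all zero\n";
        return 0;
      }

      typedef std::pair<const std::size_t, double> Entry;
      auto it = std::max_element(a.begin(), a.end(),
                                 [](const Entry &x, const Entry &y) {
                                   return x.second < y.second;
                                 });
      std::cout << "max " << it->second << " at index " << it->first << "\n";
    }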
+ */ + inline const std::pair max() const { + if (isZero()) + return std::make_pair(getNewZeroIndex(), Float(0)); - const_iterator - it1 = begin(), it2 = B.begin(), - end1 = end(), end2 = B.end(); - - while (it1 != end1 && it2 != end2) - if (it1->first == it2->first) { - C.set(it1->first, a*it1->second + b*it2->second); - ++it1; ++it2; - } else if (it2->first < it1->first) { - C.set(it2->first, b*it2->second); - ++it2; - } else { - C.set(it1->first, a*it1->second); - ++it1; - } - - for (; it1 != end1; ++it1) - C.set(it1->first, a*it1->second); - - for (; it2 != end2; ++it2) - C.set(it2->first, b*it2->second); - } + const_iterator min_it = std::max_element( + begin(), end(), + predicate_compose, + nupic::select2nd>>()); - inline void add(const SparseTensor& B) - { - if (B.isZero()) - return; + return std::make_pair(min_it->first, min_it->second); + } - element_apply_nz(B, *this, std::plus(), false); - } + /** + * Returns the sum of all the non-zeros in this sparse tensor. + * This funcion needed because SparseTensor doesn't + * specialize to a scalar properly. + */ + inline const Float sum() const { + Float sum = 0; + const_iterator i, e; + for (i = begin(), e = end(); i != e; ++i) + sum += i->second; + return sum; + } - /** - * Scales this sparse tensor by an arbitrary scalar a. - */ - inline void multiply(const Float& a) - { - if (a == 1.0) - return; + /** + * Wrapper for accumulate with plus. + */ + template + inline void sum(const Index2 &dims, SparseTensor &B) const { + accumulate_nz(dims, B, std::plus()); + } - element_apply_fast(std::bind2nd(std::multiplies(), a)); - } + /** + * Adds a slice to another. + */ + inline void addSlice(UInt which, UInt src, UInt dst) { + TensorIndex lb(getBounds()), ub(getBounds()); + lb[which] = ub[which] = src; + TensorIndex srcIndex = getNewZeroIndex(); + srcIndex[which] = src; + + do { + TensorIndex dstIndex(srcIndex); + dstIndex[which] = dst; + set(dstIndex, get(dstIndex) + get(srcIndex)); + } while (increment(lb, ub, srcIndex)); + } - /** - * Scales this sparse tensor and put the result in B, leaving - * this spare tensor unchanged. - */ - inline void multiply(const Float& a, SparseTensor& B) const - { - if (a == 1.0) { - B = *this; - return; + /** + * Adds two sparse tensors of the same rank and dimensions. + * This is an element-wise addition. 
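A worked illustration of the axby contract described above: C = a*A + b*B element-wise, so C's non-zeros come from the union of the operands' non-zeros. The member function below does this in a single merge pass over the two sorted maps; this sketch uses two simple accumulation passes over std::map stand-ins, which gives the same result:

    #include <cstddef>
    #include <iostream>
    #include <map>

    // C = a*A + b*B element-wise; entries appearing in only one operand are
    // simply scaled by the corresponding coefficient.
    int main() {
      const double a = 2.0, b = -1.0;
      std::map<std::size_t, double> A{{1, 2.0}, {4, 3.0}};
      std::map<std::size_t, double> B{{4, 5.0}, {7, 1.0}};
      std::map<std::size_t, double> C;

      for (const auto &nz : A) C[nz.first] += a * nz.second;
      for (const auto &nz : B) C[nz.first] += b * nz.second;

      for (const auto &nz : C)
        std::cout << nz.first << ": " << nz.second << "\n";  // 1: 4, 4: 1, 7: -1
    }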
+ */ + inline void axby(const Float &a, const SparseTensor &B, const Float &b, + SparseTensor &C) const { + C.clear(); + + const_iterator it1 = begin(), it2 = B.begin(), end1 = end(), end2 = B.end(); + + while (it1 != end1 && it2 != end2) + if (it1->first == it2->first) { + C.set(it1->first, a * it1->second + b * it2->second); + ++it1; + ++it2; + } else if (it2->first < it1->first) { + C.set(it2->first, b * it2->second); + ++it2; + } else { + C.set(it1->first, a * it1->second); + ++it1; } - - B.clear(); - const_iterator i = begin(), e = end(); - while (i != e) { - B.set(i->first, a * i->second); - ++i; - } - } + for (; it1 != end1; ++it1) + C.set(it1->first, a * it1->second); - template - inline void factor_multiply(const IndexB& dims, - const SparseTensor& B, - SparseTensor& C) const - { - factor_apply_fast(dims, B, C, std::multiplies(), true); - } + for (; it2 != end2; ++it2) + C.set(it2->first, b * it2->second); + } - template - inline void outer_multiply(const SparseTensor B, - SparseTensor& C) const - { - outer_product_nz(B, C, std::multiplies()); - } + inline void add(const SparseTensor &B) { + if (B.isZero()) + return; - template - inline void marginalize(const Index2& dims, SparseTensor& B) const - { - accumulate_nz(dims, B, std::plus(), Float(0)); - } + element_apply_nz(B, *this, std::plus(), false); + } - /** - * Normalize by adding up the non-zeros across the whole tensor. - * Doesn't do anything if the sum of the tensor non-zeros adds up - * to 0. - */ - inline void normalize(const Float& tolerance =nupic::Epsilon) - { - Float s = sum(); - - if (s > tolerance) - multiply(Real(1./s)); - else - setAll(0.0); - } + /** + * Scales this sparse tensor by an arbitrary scalar a. + */ + inline void multiply(const Float &a) { + if (a == 1.0) + return; - //-------------------------------------------------------------------------------- - private: - Index bounds_; - NZ nz_; + element_apply_fast(std::bind2nd(std::multiplies(), a)); + } - inline bool nearlyZero_(const Float& val) const - { - return nearlyZero(val); + /** + * Scales this sparse tensor and put the result in B, leaving + * this spare tensor unchanged. 
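The in-place multiply(a) above scales the non-zeros through element_apply_fast(std::bind2nd(std::multiplies(), a)). std::bind2nd was deprecated in C++11 and removed in C++17, so if this code were modernized a lambda would be the natural replacement; a hedged sketch of the same scaling applied to a std::map stand-in:

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <map>
    #include <utility>

    // Scale the stored entries in place with a lambda, the modern equivalent
    // of binding the second argument of std::multiplies.
    int main() {
      const double a = 0.5;
      std::map<std::size_t, double> t{{1, 2.0}, {4, 3.0}};

      if (a != 1.0)  // same early-out as multiply(a)
        std::for_each(t.begin(), t.end(),
                      [a](std::pair<const std::size_t, double> &nz) {
                        nz.second *= a;
                      });

      for (const auto &nz : t)
        std::cout << nz.first << ": " << nz.second << "\n";  // 1: 1, 4: 1.5
    }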
+ */ + inline void multiply(const Float &a, SparseTensor &B) const { + if (a == 1.0) { + B = *this; + return; } - // I need at least the bounds at construction time - SparseTensor(); - - friend class SparseTensorUnitTest; - }; + B.clear(); - //-------------------------------------------------------------------------------- - template - inline std::ostream& operator<<(std::ostream& outStream, const SparseTensor& s) - { - s.print(outStream); - return outStream; + const_iterator i = begin(), e = end(); + while (i != e) { + B.set(i->first, a * i->second); + ++i; + } } - template - inline bool operator==(const SparseTensor& A, const SparseTensor& B) - { - if (A.getBounds() != B.getBounds()) - return false; + template + inline void factor_multiply(const IndexB &dims, + const SparseTensor &B, + SparseTensor &C) const { + factor_apply_fast(dims, B, C, std::multiplies(), true); + } - if (A.getNNonZeros() != B.getNNonZeros()) - return false; - - typename SparseTensor::const_iterator it1, it2; - for (it1 = A.begin(), it2 = B.begin(); it1 != A.end(); ++it1, ++it2) - if (!nearlyEqual(it1->second, it2->second)) - return false; + template + inline void outer_multiply(const SparseTensor B, + SparseTensor &C) const { + outer_product_nz(B, C, std::multiplies()); + } - return true; + template + inline void marginalize(const Index2 &dims, + SparseTensor &B) const { + accumulate_nz(dims, B, std::plus(), Float(0)); } - template - inline bool operator!=(const SparseTensor& A, const SparseTensor& B) - { - return ! (A == B); + /** + * Normalize by adding up the non-zeros across the whole tensor. + * Doesn't do anything if the sum of the tensor non-zeros adds up + * to 0. + */ + inline void normalize(const Float &tolerance = nupic::Epsilon) { + Float s = sum(); + + if (s > tolerance) + multiply(Real(1. / s)); + else + setAll(0.0); } //-------------------------------------------------------------------------------- +private: + Index bounds_; + NZ nz_; -} // end namespace nupic + inline bool nearlyZero_(const Float &val) const { return nearlyZero(val); } -#endif // NTA_SPARSE_TENSOR_HPP + // I need at least the bounds at construction time + SparseTensor(); + + friend class SparseTensorUnitTest; +}; + +//-------------------------------------------------------------------------------- +template +inline std::ostream &operator<<(std::ostream &outStream, + const SparseTensor &s) { + s.print(outStream); + return outStream; +} + +template +inline bool operator==(const SparseTensor &A, + const SparseTensor &B) { + if (A.getBounds() != B.getBounds()) + return false; + + if (A.getNNonZeros() != B.getNNonZeros()) + return false; + + typename SparseTensor::const_iterator it1, it2; + for (it1 = A.begin(), it2 = B.begin(); it1 != A.end(); ++it1, ++it2) + if (!nearlyEqual(it1->second, it2->second)) + return false; + + return true; +} +template +inline bool operator!=(const SparseTensor &A, + const SparseTensor &B) { + return !(A == B); +} +//-------------------------------------------------------------------------------- + +} // end namespace nupic + +#endif // NTA_SPARSE_TENSOR_HPP diff --git a/src/nupic/math/StlIo.cpp b/src/nupic/math/StlIo.cpp index 64b75aa988..ad67cb1da7 100644 --- a/src/nupic/math/StlIo.cpp +++ b/src/nupic/math/StlIo.cpp @@ -20,14 +20,15 @@ * --------------------------------------------------------------------- */ -/** @file STL IO - * This file contains functions to print out and save/load various STL data structures. 
+/** @file STL IO + * This file contains functions to print out and save/load various STL data + * structures. */ #include namespace nupic { - IOControl io_control; +IOControl io_control; } // end namespace nupic diff --git a/src/nupic/math/StlIo.hpp b/src/nupic/math/StlIo.hpp index 5d13c8ea6d..a585694c91 100644 --- a/src/nupic/math/StlIo.hpp +++ b/src/nupic/math/StlIo.hpp @@ -20,21 +20,22 @@ * --------------------------------------------------------------------- */ -/** @file STL IO - * This file contains functions to print out and save/load various STL data structures. +/** @file STL IO + * This file contains functions to print out and save/load various STL data + * structures. */ #ifndef NTA_STL_IO_HPP #define NTA_STL_IO_HPP -#include #include #include #include +#include -#include -#include #include +#include +#include #include @@ -42,618 +43,561 @@ namespace nupic { - //-------------------------------------------------------------------------------- - // IO CONTROL AND MANIPULATORS - //-------------------------------------------------------------------------------- - typedef enum { CSR =0, CSR_01, BINARY, AS_DENSE } SPARSE_IO_TYPE; - - struct IOControl - { - int abbr; // shorten long vectors output - bool output_n_elts; // output vector size at beginning - - bool pair_paren; // put parens around pairs in vector of pairs - const char* pair_sep; // put separator between pair.first and pair.second - - int convert_to_sparse; // convert dense vector to pos. of non-zeros - int convert_from_sparse; // convert from pos. of non-zero to dense 0/1 vector - - SPARSE_IO_TYPE sparse_io; // do sparse io according to SPARSE_IO_TYPE - - bool bit_vector; // output 0/1 vector compactly - - inline IOControl(int a =-1, bool s =true, bool pp =false, const char* psep =" ", - SPARSE_IO_TYPE smio =CSR, - bool cts =false, - bool cfs =false, - bool bv =false) - : abbr(a), - output_n_elts(s), - pair_paren(pp), - pair_sep(psep), - convert_to_sparse(cts), - convert_from_sparse(cfs), - sparse_io(smio), - bit_vector(bv) - {} - - inline void reset() - { - abbr = -1; - output_n_elts = true; - pair_paren = false; - pair_sep = " "; - convert_to_sparse = false; - convert_from_sparse = false; - sparse_io = CSR; - bit_vector = false; - } - }; - - extern IOControl io_control; - - template - inline std::basic_ostream& - operator,(std::basic_ostream& out_stream, const T& a) - { - return out_stream << ' ' << a; - } - - template - inline std::basic_istream& - operator,(std::basic_istream& in_stream, T& a) - { - return in_stream >> a; - } - - template - inline std::basic_ostream& - operator,(std::basic_ostream& out_stream, - std::basic_ostream& (*pf)(std::basic_ostream&)) - { - pf(out_stream); - return out_stream; - } - - template - inline std::basic_ostream& - p_paren(std::basic_ostream& out_stream) - { - io_control.pair_paren = true; - return out_stream; - } - - template - inline std::basic_ostream& - psep_comma(std::basic_ostream& out_stream) - { - io_control.pair_sep = ","; - return out_stream; - } - - template - inline std::basic_ostream& - psep_dot(std::basic_ostream& out_stream) - { - io_control.pair_sep = "."; - return out_stream; - } - - struct abbr - { - int n; - inline abbr(int _n) : n(_n) {} - }; - - template - inline std::basic_ostream& - operator<<(std::basic_ostream& out_stream, abbr s) - { - io_control.abbr = s.n; - return out_stream; - } - - struct debug - { - int n; - inline debug(int _n =-1) : n(_n) {} - }; - - template - inline std::basic_ostream& - operator<<(std::basic_ostream& out_stream, debug d) - { - 
io_control.abbr = d.n; - io_control.output_n_elts = false; - io_control.pair_sep = ","; - io_control.pair_paren = true; - return out_stream; - } - - template - inline std::basic_istream& - from_csr_01(std::basic_istream& in_stream) - { - io_control.convert_from_sparse = CSR_01; - return in_stream; - } - - template - inline std::basic_ostream& - to_csr_01(std::basic_ostream& out_stream) - { - io_control.convert_to_sparse = CSR_01; - return out_stream; - } +//-------------------------------------------------------------------------------- +// IO CONTROL AND MANIPULATORS +//-------------------------------------------------------------------------------- +typedef enum { CSR = 0, CSR_01, BINARY, AS_DENSE } SPARSE_IO_TYPE; + +struct IOControl { + int abbr; // shorten long vectors output + bool output_n_elts; // output vector size at beginning + + bool pair_paren; // put parens around pairs in vector of pairs + const char *pair_sep; // put separator between pair.first and pair.second + + int convert_to_sparse; // convert dense vector to pos. of non-zeros + int convert_from_sparse; // convert from pos. of non-zero to dense 0/1 vector + + SPARSE_IO_TYPE sparse_io; // do sparse io according to SPARSE_IO_TYPE + + bool bit_vector; // output 0/1 vector compactly + + inline IOControl(int a = -1, bool s = true, bool pp = false, + const char *psep = " ", SPARSE_IO_TYPE smio = CSR, + bool cts = false, bool cfs = false, bool bv = false) + : abbr(a), output_n_elts(s), pair_paren(pp), pair_sep(psep), + convert_to_sparse(cts), convert_from_sparse(cfs), sparse_io(smio), + bit_vector(bv) {} + + inline void reset() { + abbr = -1; + output_n_elts = true; + pair_paren = false; + pair_sep = " "; + convert_to_sparse = false; + convert_from_sparse = false; + sparse_io = CSR; + bit_vector = false; + } +}; + +extern IOControl io_control; + +template +inline std::basic_ostream &operator,( + std::basic_ostream &out_stream, const T &a) { + return out_stream << ' ' << a; +} + +template +inline std::basic_istream &operator,( + std::basic_istream &in_stream, T &a) { + return in_stream >> a; +} + +template +inline std::basic_ostream &operator,( + std::basic_ostream &out_stream, + std::basic_ostream &(*pf)( + std::basic_ostream &)) { + pf(out_stream); + return out_stream; +} + +template +inline std::basic_ostream & +p_paren(std::basic_ostream &out_stream) { + io_control.pair_paren = true; + return out_stream; +} + +template +inline std::basic_ostream & +psep_comma(std::basic_ostream &out_stream) { + io_control.pair_sep = ","; + return out_stream; +} + +template +inline std::basic_ostream & +psep_dot(std::basic_ostream &out_stream) { + io_control.pair_sep = "."; + return out_stream; +} + +struct abbr { + int n; + inline abbr(int _n) : n(_n) {} +}; + +template +inline std::basic_ostream & +operator<<(std::basic_ostream &out_stream, abbr s) { + io_control.abbr = s.n; + return out_stream; +} + +struct debug { + int n; + inline debug(int _n = -1) : n(_n) {} +}; + +template +inline std::basic_ostream & +operator<<(std::basic_ostream &out_stream, debug d) { + io_control.abbr = d.n; + io_control.output_n_elts = false; + io_control.pair_sep = ","; + io_control.pair_paren = true; + return out_stream; +} + +template +inline std::basic_istream & +from_csr_01(std::basic_istream &in_stream) { + io_control.convert_from_sparse = CSR_01; + return in_stream; +} + +template +inline std::basic_ostream & +to_csr_01(std::basic_ostream &out_stream) { + io_control.convert_to_sparse = CSR_01; + return out_stream; +} + +template +inline 
std::basic_istream & +bit_vector(std::basic_istream &in_stream) { + io_control.bit_vector = true; + return in_stream; +} + +template +inline std::basic_ostream & +bit_vector(std::basic_ostream &out_stream) { + io_control.bit_vector = true; + return out_stream; +} + +template +inline std::basic_istream & +general_vector(std::basic_istream &in_stream) { + io_control.bit_vector = false; + return in_stream; +} + +template +inline std::basic_ostream & +general_vector(std::basic_ostream &out_stream) { + io_control.bit_vector = false; + return out_stream; +} + +//-------------------------------------------------------------------------------- +// SM IO CONTROL +//-------------------------------------------------------------------------------- +struct sparse_format_class { + SPARSE_IO_TYPE format; + + inline sparse_format_class(SPARSE_IO_TYPE f) : format(f) {} +}; + +inline sparse_format_class sparse_format(SPARSE_IO_TYPE f) { + return sparse_format_class(f); +} + +template +inline std::basic_ostream & +operator<<(std::basic_ostream &out_stream, + sparse_format_class s) { + io_control.sparse_io = s.format; + return out_stream; +} + +template +inline std::basic_istream & +operator>>(std::basic_istream &in_stream, + sparse_format_class s) { + io_control.sparse_io = s.format; + return in_stream; +} + +template +inline std::basic_ostream & +as_dense(std::basic_ostream &out_stream) { + io_control.sparse_io = AS_DENSE; + return out_stream; +} + +template +inline std::basic_ostream & +as_binary(std::basic_ostream &out_stream) { + io_control.sparse_io = BINARY; + return out_stream; +} + +//-------------------------------------------------------------------------------- +// CHECKERS +//-------------------------------------------------------------------------------- +template struct is_positive_checker { + T1 &var; + + inline is_positive_checker(T1 &v) : var(v) {} template - inline std::basic_istream& - bit_vector(std::basic_istream& in_stream) - { - io_control.bit_vector = true; - return in_stream; - } - - template - inline std::basic_ostream& - bit_vector(std::basic_ostream& out_stream) - { - io_control.bit_vector = true; - return out_stream; - } - - template - inline std::basic_istream& - general_vector(std::basic_istream& in_stream) - { - io_control.bit_vector = false; - return in_stream; - } - - template - inline std::basic_ostream& - general_vector(std::basic_ostream& out_stream) - { - io_control.bit_vector = false; - return out_stream; - } - - //-------------------------------------------------------------------------------- - // SM IO CONTROL - //-------------------------------------------------------------------------------- - struct sparse_format_class - { - SPARSE_IO_TYPE format; - - inline sparse_format_class(SPARSE_IO_TYPE f) : format(f) {} - }; - - inline sparse_format_class - sparse_format(SPARSE_IO_TYPE f) { return sparse_format_class(f); } + inline void do_check(std::basic_istream &in_stream) { + double value = 0; + in_stream >> value; + if (value < 0) { + std::cout << "Value out of range: " << value + << " - Expected positive or zero value" << std::endl; + exit(-1); + } + var = (T1)value; + } +}; + +template +inline std::basic_istream & +operator>>(std::basic_istream &in_stream, + is_positive_checker cp) { + cp.do_check(in_stream); + return in_stream; +} + +template inline is_positive_checker assert_positive(T1 &var) { + return is_positive_checker(var); +} + +//-------------------------------------------------------------------------------- +// BINARY PERSISTENCE 
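The binary_save / binary_load helpers that follow simply stream the raw bytes of a contiguous range and expect the caller to have sized the destination. A standalone round-trip sketch of that pattern through a std::stringstream (it mirrors the idea rather than calling the nupic helpers):

    #include <iostream>
    #include <sstream>
    #include <vector>

    // Round-trip a vector of floats through raw-byte binary IO: write
    // sizeof(value_type) bytes per element, then read them back into an
    // already-sized destination.
    int main() {
      std::vector<float> src{1.5f, -2.0f, 3.25f};
      std::stringstream buf(std::ios::in | std::ios::out | std::ios::binary);

      buf.write(reinterpret_cast<const char *>(src.data()),
                static_cast<std::streamsize>(src.size() * sizeof(float)));

      std::vector<float> dst(src.size());  // destination sized up front
      buf.read(reinterpret_cast<char *>(dst.data()),
               static_cast<std::streamsize>(dst.size() * sizeof(float)));

      for (float x : dst)
        std::cout << x << " ";             // 1.5 -2 3.25
      std::cout << "\n";
    }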
+//-------------------------------------------------------------------------------- +template +inline void binary_save(std::ostream &out_stream, It begin, It end) { + typedef typename std::iterator_traits::value_type value_type; + size_t size = (size_t)(end - begin); + if (size > 0) { + char *ptr = (char *)&*begin; + out_stream.write(ptr, (std::streamsize)size * sizeof(value_type)); + } +} + +//-------------------------------------------------------------------------------- +template +inline void binary_load(std::istream &in_stream, It begin, It end) { + typedef typename std::iterator_traits::value_type value_type; + size_t size = (size_t)(end - begin); + if (size > 0) { + char *ptr = (char *)&*begin; + in_stream.read(ptr, (std::streamsize)size * sizeof(value_type)); + } +} + +//-------------------------------------------------------------------------------- +template +inline void binary_save(std::ostream &out_stream, const std::vector &v) { + nupic::binary_save(out_stream, v.begin(), v.end()); +} + +//-------------------------------------------------------------------------------- +template +inline void binary_load(std::istream &in_stream, std::vector &v) { + nupic::binary_load(in_stream, v.begin(), v.end()); +} + +//-------------------------------------------------------------------------------- +// STL STREAMING OPERATORS +//-------------------------------------------------------------------------------- + +//-------------------------------------------------------------------------------- +// std::pair +//-------------------------------------------------------------------------------- +template +inline std::ostream &operator<<(std::ostream &out_stream, + const std::pair &p) { + if (io_control.pair_paren) + out_stream << "("; + out_stream << p.first; + out_stream << io_control.pair_sep; + out_stream << p.second; + if (io_control.pair_paren) + out_stream << ")"; + return out_stream; +} + +//-------------------------------------------------------------------------------- +template +inline std::istream &operator>>(std::istream &in_stream, std::pair &p) { + in_stream >> p.first >> p.second; + return in_stream; +} + +//-------------------------------------------------------------------------------- +// std::vector +//-------------------------------------------------------------------------------- +template struct vector_loader { + inline void load(size_t, std::istream &, std::vector &); +}; + +//-------------------------------------------------------------------------------- +/** + * Partial specialization of above functor for primitive types. 
+ */ +template struct vector_loader { + inline void load(size_t n, std::istream &in_stream, std::vector &v) { + if (io_control.convert_from_sparse == CSR_01) { - template - inline std::basic_ostream& - operator<<(std::basic_ostream& out_stream, sparse_format_class s) - { - io_control.sparse_io = s.format; - return out_stream; - } + std::fill(v.begin(), v.end(), (T)0); - template - inline std::basic_istream& - operator>>(std::basic_istream& in_stream, sparse_format_class s) - { - io_control.sparse_io = s.format; - return in_stream; - } - - template - inline std::basic_ostream& - as_dense(std::basic_ostream& out_stream) - { - io_control.sparse_io = AS_DENSE; - return out_stream; - } + for (size_t i = 0; i != n; ++i) { + int index = 0; + in_stream >> index; + v[index] = (T)1; + } - template - inline std::basic_ostream& - as_binary(std::basic_ostream& out_stream) - { - io_control.sparse_io = BINARY; - return out_stream; - } + } else if (io_control.bit_vector) { - //-------------------------------------------------------------------------------- - // CHECKERS - //-------------------------------------------------------------------------------- - template - struct is_positive_checker - { - T1& var; - - inline is_positive_checker(T1& v) : var(v) {} - - template - inline void do_check(std::basic_istream& in_stream) - { - double value = 0; - in_stream >> value; - if (value < 0) { - std::cout << "Value out of range: " << value - << " - Expected positive or zero value" - << std::endl; - exit(-1); + for (size_t i = 0; i != n; ++i) { + float x = 0; + in_stream >> x; + if (x) + v[i] = 1; + else + v[i] = 0; } - var = (T1) value; - } - }; - - template - inline std::basic_istream& - operator>>(std::basic_istream& in_stream, is_positive_checker cp) - { - cp.do_check(in_stream); - return in_stream; - } - template - inline is_positive_checker assert_positive(T1& var) - { - return is_positive_checker(var); - } - - //-------------------------------------------------------------------------------- - // BINARY PERSISTENCE - //-------------------------------------------------------------------------------- - template - inline void binary_save(std::ostream& out_stream, It begin, It end) - { - typedef typename std::iterator_traits::value_type value_type; - size_t size = (size_t) (end - begin); - if (size > 0) { - char* ptr = (char*) & *begin; - out_stream.write(ptr, (std::streamsize) size*sizeof(value_type)); - } - } - - //-------------------------------------------------------------------------------- - template - inline void binary_load(std::istream& in_stream, It begin, It end) - { - typedef typename std::iterator_traits::value_type value_type; - size_t size = (size_t) (end - begin); - if (size > 0) { - char* ptr = (char*) & *begin; - in_stream.read(ptr, (std::streamsize) size*sizeof(value_type)); + } else { + for (size_t i = 0; i != n; ++i) + in_stream >> v[i]; } } +}; - //-------------------------------------------------------------------------------- - template - inline void binary_save(std::ostream& out_stream, const std::vector& v) - { - nupic::binary_save(out_stream, v.begin(), v.end()); +// declartion of >> which is used in the following function. Avoid lookup error +template +inline std::istream &operator>>(std::istream &in_stream, std::vector &v); +//-------------------------------------------------------------------------------- +/** + * Partial specialization for non-primitive types. 
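The CSR_01 branch of the loader above fills an already-sized dense vector from a list of non-zero positions. A minimal sketch driving it through vector_load() and the from_csr_01 manipulator (include path assumed; the global io_control must be defined in a linked translation unit):

#include <sstream>
#include <vector>
#include "nupic/math/StlIo.hpp" // assumed filename for this header

void csr01_load_demo() {
  using namespace nupic;

  io_control.reset();
  std::istringstream in("1 4 5 9"); // positions of the non-zeros

  std::vector<int> v(10, 0);        // dense destination, already sized
  in >> from_csr_01;                // sets io_control.convert_from_sparse = CSR_01
  vector_load(4, in, v);            // reads 4 positions, sets each of them to 1

  // v is now {0,1,0,0,1,1,0,0,0,1}
  io_control.reset();
}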
+ */ +template struct vector_loader { + inline void load(size_t n, std::istream &in_stream, std::vector &v) { + for (size_t i = 0; i != n; ++i) + in_stream >> v[i]; } +}; - //-------------------------------------------------------------------------------- - template - inline void binary_load(std::istream& in_stream, std::vector& v) - { - nupic::binary_load(in_stream, v.begin(), v.end()); - } +//-------------------------------------------------------------------------------- +/** + * Factory that will instantiate the right functor to call depending on whether + * T is a primitive type or not. + */ +template +inline void vector_load(size_t n, std::istream &in_stream, std::vector &v) { + vector_loader::value> loader; + loader.load(n, in_stream, v); +} + +//-------------------------------------------------------------------------------- +template struct vector_saver { + inline void save(size_t n, std::ostream &out_stream, const std::vector &v); +}; + +//-------------------------------------------------------------------------------- +/** + * Partial specialization for primitive types. + */ +template struct vector_saver { + inline void save(size_t n, std::ostream &out_stream, + const std::vector &v) { + if (io_control.output_n_elts) + out_stream << n << ' '; - //-------------------------------------------------------------------------------- - // STL STREAMING OPERATORS - //-------------------------------------------------------------------------------- - - //-------------------------------------------------------------------------------- - // std::pair - //-------------------------------------------------------------------------------- - template - inline std::ostream& operator<<(std::ostream& out_stream, const std::pair& p) - { - if (io_control.pair_paren) - out_stream << "("; - out_stream << p.first; - out_stream << io_control.pair_sep; - out_stream << p.second; - if (io_control.pair_paren) - out_stream << ")"; - return out_stream; - } + if (io_control.abbr > 0) + n = std::min((size_t)io_control.abbr, n); - //-------------------------------------------------------------------------------- - template - inline std::istream& operator>>(std::istream& in_stream, std::pair& p) - { - in_stream >> p.first >> p.second; - return in_stream; - } + if (io_control.convert_to_sparse == CSR_01) { - //-------------------------------------------------------------------------------- - // std::vector - //-------------------------------------------------------------------------------- - template - struct vector_loader - { - inline void load(size_t, std::istream&, std::vector&); - }; - - //-------------------------------------------------------------------------------- - /** - * Partial specialization of above functor for primitive types. - */ - template - struct vector_loader - { - inline void load(size_t n, std::istream& in_stream, std::vector& v) - { - if (io_control.convert_from_sparse == CSR_01) { - - std::fill(v.begin(), v.end(), (T) 0); - - for (size_t i = 0; i != n; ++i) { - int index = 0; - in_stream >> index; - v[index] = (T) 1; - } + for (size_t i = 0; i != n; ++i) + if (!is_zero(v[i])) + out_stream << i << ' '; - } else if (io_control.bit_vector) { + } else if (io_control.bit_vector) { - for (size_t i = 0; i != n; ++i) { - float x = 0; - in_stream >> x; - if (x) - v[i] = 1; - else - v[i] = 0; + size_t k = 7; + for (size_t i = 0; i != v.size(); ++i) { + out_stream << (is_zero(v[i]) ? 
'0' : '1'); + if (i == k) { + out_stream << ' '; + k += 8; } - - } else { - for (size_t i = 0; i != n; ++i) - in_stream >> v[i]; } - } - }; - - // declartion of >> which is used in the following function. Avoid lookup error - template inline std::istream& operator>>(std::istream& in_stream, std::vector& v); - //-------------------------------------------------------------------------------- - /** - * Partial specialization for non-primitive types. - */ - template - struct vector_loader - { - inline void load(size_t n, std::istream& in_stream, std::vector& v) - { - for (size_t i = 0; i != n; ++i) - in_stream >> v[i]; - } - }; - - //-------------------------------------------------------------------------------- - /** - * Factory that will instantiate the right functor to call depending on whether - * T is a primitive type or not. - */ - template - inline void vector_load(size_t n, std::istream& in_stream, std::vector& v) - { - vector_loader::value > loader; - loader.load(n, in_stream, v); - } - - //-------------------------------------------------------------------------------- - template - struct vector_saver - { - inline void save(size_t n, std::ostream& out_stream, const std::vector& v); - }; - - //-------------------------------------------------------------------------------- - /** - * Partial specialization for primitive types. - */ - template - struct vector_saver - { - inline void save(size_t n, std::ostream& out_stream, const std::vector& v) - { - if (io_control.output_n_elts) - out_stream << n << ' '; - - if (io_control.abbr > 0) - n = std::min((size_t) io_control.abbr, n); - - if (io_control.convert_to_sparse == CSR_01) { - - for (size_t i = 0; i != n; ++i) - if (!is_zero(v[i])) - out_stream << i << ' '; - - } else if (io_control.bit_vector) { - - size_t k = 7; - for (size_t i = 0; i != v.size(); ++i) { - out_stream << (is_zero(v[i]) ? '0' : '1'); - if (i == k) { - out_stream << ' '; - k += 8; - } - } - } else { + } else { - for (size_t i = 0; i != n; ++i) - out_stream << v[i] << ' '; - } - - if (io_control.abbr > 0 && n < v.size()) { - size_t rest = v.size() - n; - out_stream << "[+" << rest << "/" << count_non_zeros(v) << "]"; - } - } - }; - - // declartion of << which is used in the following function. Avoid lookup error. - template inline std::ostream& operator<<(std::ostream& out_stream, const std::vector& v); - //-------------------------------------------------------------------------------- - /** - * Partial specialization for non-primitive types. - */ - template - struct vector_saver - { - inline void save(size_t n, std::ostream& out_stream, const std::vector& v) - { - if (io_control.output_n_elts) - out_stream << n << ' '; - - if (io_control.abbr > 0) - n = std::min((size_t) io_control.abbr, n); - - for (size_t i = 0; i != n; ++i) + for (size_t i = 0; i != n; ++i) out_stream << v[i] << ' '; - - if (io_control.abbr > 0 && n < v.size()) { - size_t rest = v.size() - n; - out_stream << "[+" << rest << "/" << count_non_zeros(v) << "]"; - } } - }; - - //-------------------------------------------------------------------------------- - /** - * Factory that will instantiate the right functor to call depending on whether - * T is a primitive type or not. - */ - template - inline void vector_save(size_t n, std::ostream& out_stream, const std::vector& v) - { - vector_saver::value> saver; - saver.save(n, out_stream, v); - } - - //-------------------------------------------------------------------------------- - /** - * Saves the size of the vector. 
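The saver in this hunk is what the output manipulators ultimately drive. A sketch of their effect on operator<< for a vector of a primitive type (include path assumed; count_non_zeros() and is_zero() are the helpers referenced above):

#include <iostream>
#include <vector>
#include "nupic/math/StlIo.hpp" // assumed filename for this header

void vector_output_demo() {
  using namespace nupic;

  std::vector<int> v = {0, 1, 0, 0, 1, 1, 0, 0, 0, 1};

  io_control.reset();
  std::cout << v << '\n';               // size then elements: "10 0 1 0 0 1 1 0 0 0 1"

  std::cout << to_csr_01 << v << '\n';  // size then non-zero positions: "10 1 4 5 9"
  io_control.reset();

  std::cout << bit_vector << v << '\n'; // one digit per element: "10 01001100 01"
  io_control.reset();

  std::vector<int> w(100, 0);
  w[2] = 7;
  std::cout << abbr(5) << w << '\n';    // truncated: "100 0 0 7 0 0 [+95/1]"
  io_control.reset();
}

The debug manipulator combines the same truncation with output_n_elts = false and parenthesized, comma-separated pairs.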
- */ - template - inline std::ostream& operator<<(std::ostream& out_stream, const std::vector& v) - { - vector_save(v.size(), out_stream, v); - return out_stream; - } - - //-------------------------------------------------------------------------------- - /** - * Reads in size of the vector, and redimensions it, except if we are reading - * a sparse binary vector. - */ - template - inline std::istream& - operator>>(std::istream& in_stream, std::vector& v) - { - size_t n = 0; - in_stream >> n; - v.resize(n); - vector_load(n, in_stream, v); - return in_stream; - } - - //-------------------------------------------------------------------------------- - /** - * Doesn't save the size of the buffer itself. - */ - template - inline std::ostream& operator<<(std::ostream& out_stream, const Buffer& b) - { - vector_save(b.nnz, out_stream, static_cast&>(b)); - return out_stream; - } - - //-------------------------------------------------------------------------------- - /** - * Doesn't set the size of the buffer itself. - */ - template - inline std::istream& operator>>(std::istream& in_stream, Buffer& b) - { - in_stream >> b.nnz; - NTA_ASSERT(b.nnz <= b.size()); - vector_load(b.nnz, in_stream, static_cast&>(b)); - return in_stream; - } - //-------------------------------------------------------------------------------- - // std::set - //-------------------------------------------------------------------------------- - template - inline std::ostream& operator<<(std::ostream& out_stream, const std::set& m) - { - typename std::set::const_iterator - it = m.begin(), end = m.end(); - - while (it != end) { - out_stream << *it << ' '; - ++it; + if (io_control.abbr > 0 && n < v.size()) { + size_t rest = v.size() - n; + out_stream << "[+" << rest << "/" << count_non_zeros(v) << "]"; } - - return out_stream; } +}; - //-------------------------------------------------------------------------------- - // std::map - //-------------------------------------------------------------------------------- - template - inline std::ostream& operator<<(std::ostream& out_stream, const std::map& m) - { - out_stream << m.size() << " "; +// declartion of << which is used in the following function. Avoid lookup error. +template +inline std::ostream &operator<<(std::ostream &out_stream, + const std::vector &v); +//-------------------------------------------------------------------------------- +/** + * Partial specialization for non-primitive types. 
+ */ +template struct vector_saver { + inline void save(size_t n, std::ostream &out_stream, + const std::vector &v) { + if (io_control.output_n_elts) + out_stream << n << ' '; - typename std::map::const_iterator - it = m.begin(), end = m.end(); + if (io_control.abbr > 0) + n = std::min((size_t)io_control.abbr, n); - while (it != end) { - out_stream << it->first << ' ' << it->second << ' '; - ++it; - } + for (size_t i = 0; i != n; ++i) + out_stream << v[i] << ' '; - return out_stream; - } - - //-------------------------------------------------------------------------------- - template - inline std::istream& operator>>(std::istream& in_stream, std::map& m) - { - int size = 0; - in_stream >> size; - - for (int i = 0; i != size; ++i) { - T1 k; T2 v; - in_stream >> k >> v; - m.insert(std::make_pair(k, v)); + if (io_control.abbr > 0 && n < v.size()) { + size_t rest = v.size() - n; + out_stream << "[+" << rest << "/" << count_non_zeros(v) << "]"; } - - return in_stream; } +}; - //-------------------------------------------------------------------------------- - // MISCELLANEOUS - //-------------------------------------------------------------------------------- - template - inline void show_all_differences(const std::vector& x, const std::vector& y) - { - std::vector diffs; - find_all_differences(x, y, diffs); - std::cout << diffs.size() << " differences: " << std::endl; - for (size_t i = 0; i != diffs.size(); ++i) - std::cout << "(at:" << diffs[i] - << " y=" << x[diffs[i]] - << ", ans=" << y[diffs[i]] << ")"; - std::cout << std::endl; - } - - //-------------------------------------------------------------------------------- +//-------------------------------------------------------------------------------- +/** + * Factory that will instantiate the right functor to call depending on whether + * T is a primitive type or not. + */ +template +inline void vector_save(size_t n, std::ostream &out_stream, + const std::vector &v) { + vector_saver::value> saver; + saver.save(n, out_stream, v); +} + +//-------------------------------------------------------------------------------- +/** + * Saves the size of the vector. + */ +template +inline std::ostream &operator<<(std::ostream &out_stream, + const std::vector &v) { + vector_save(v.size(), out_stream, v); + return out_stream; +} + +//-------------------------------------------------------------------------------- +/** + * Reads in size of the vector, and redimensions it, except if we are reading + * a sparse binary vector. + */ +template +inline std::istream &operator>>(std::istream &in_stream, std::vector &v) { + size_t n = 0; + in_stream >> n; + v.resize(n); + vector_load(n, in_stream, v); + return in_stream; +} + +//-------------------------------------------------------------------------------- +/** + * Doesn't save the size of the buffer itself. + */ +template +inline std::ostream &operator<<(std::ostream &out_stream, const Buffer &b) { + vector_save(b.nnz, out_stream, static_cast &>(b)); + return out_stream; +} + +//-------------------------------------------------------------------------------- +/** + * Doesn't set the size of the buffer itself. 
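Since the Buffer stream operators above neither write the capacity nor resize on read, a Buffer must be pre-sized before extraction. A minimal sketch (Buffer<T> is defined in nupic/math/Types.hpp later in this patch; the include paths are assumptions):

#include <sstream>
#include "nupic/math/StlIo.hpp" // assumed filename for this header
#include "nupic/math/Types.hpp" // Buffer<T>, as modified later in this patch

void buffer_stream_demo() {
  using namespace nupic;

  Buffer<int> src(8); // capacity 8, nnz == 0
  src.push_back(3);
  src.push_back(5);   // nnz == 2

  std::stringstream ss;
  ss << src;          // writes "2 3 5 " -- nnz followed by the first nnz values

  Buffer<int> dst(8); // must already be large enough: operator>> asserts nnz <= size()
  ss >> dst;          // dst.nnz == 2, dst[0] == 3, dst[1] == 5
}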
+ */ +template +inline std::istream &operator>>(std::istream &in_stream, Buffer &b) { + in_stream >> b.nnz; + NTA_ASSERT(b.nnz <= b.size()); + vector_load(b.nnz, in_stream, static_cast &>(b)); + return in_stream; +} + +//-------------------------------------------------------------------------------- +// std::set +//-------------------------------------------------------------------------------- +template +inline std::ostream &operator<<(std::ostream &out_stream, + const std::set &m) { + typename std::set::const_iterator it = m.begin(), end = m.end(); + + while (it != end) { + out_stream << *it << ' '; + ++it; + } + + return out_stream; +} + +//-------------------------------------------------------------------------------- +// std::map +//-------------------------------------------------------------------------------- +template +inline std::ostream &operator<<(std::ostream &out_stream, + const std::map &m) { + out_stream << m.size() << " "; + + typename std::map::const_iterator it = m.begin(), end = m.end(); + + while (it != end) { + out_stream << it->first << ' ' << it->second << ' '; + ++it; + } + + return out_stream; +} + +//-------------------------------------------------------------------------------- +template +inline std::istream &operator>>(std::istream &in_stream, std::map &m) { + int size = 0; + in_stream >> size; + + for (int i = 0; i != size; ++i) { + T1 k; + T2 v; + in_stream >> k >> v; + m.insert(std::make_pair(k, v)); + } + + return in_stream; +} + +//-------------------------------------------------------------------------------- +// MISCELLANEOUS +//-------------------------------------------------------------------------------- +template +inline void show_all_differences(const std::vector &x, + const std::vector &y) { + std::vector diffs; + find_all_differences(x, y, diffs); + std::cout << diffs.size() << " differences: " << std::endl; + for (size_t i = 0; i != diffs.size(); ++i) + std::cout << "(at:" << diffs[i] << " y=" << x[diffs[i]] + << ", ans=" << y[diffs[i]] << ")"; + std::cout << std::endl; +} + +//-------------------------------------------------------------------------------- } // end namespace nupic #endif // NTA_STL_IO_HPP diff --git a/src/nupic/math/Topology.cpp b/src/nupic/math/Topology.cpp index 93d60f2b1d..17c77740d7 100644 --- a/src/nupic/math/Topology.cpp +++ b/src/nupic/math/Topology.cpp @@ -31,88 +31,69 @@ using std::vector; using namespace nupic; using namespace nupic::math::topology; -namespace nupic -{ - namespace math - { - namespace topology - { - - vector coordinatesFromIndex( - UInt index, - const vector& dimensions) - { - vector coordinates(dimensions.size(), 0); - - UInt shifted = index; - for (size_t i = dimensions.size() - 1; i > 0; i--) - { - coordinates[i] = shifted % dimensions[i]; - shifted = shifted / dimensions[i]; - } - - NTA_ASSERT(shifted < dimensions[0]); - coordinates[0] = shifted; - - return coordinates; - } - - UInt indexFromCoordinates(const vector& coordinates, - const vector& dimensions) - { - NTA_ASSERT(coordinates.size() == dimensions.size()); - - UInt index = 0; - for (size_t i = 0; i < dimensions.size(); i++) - { - NTA_ASSERT(coordinates[i] < dimensions[i]); - index *= dimensions[i]; - index += coordinates[i]; - } - - return index; - } - - } // end namespace topology - } // end namespace algorithms -} // end namespace nupic +namespace nupic { +namespace math { +namespace topology { + +vector coordinatesFromIndex(UInt index, const vector &dimensions) { + vector coordinates(dimensions.size(), 0); + + UInt shifted = index; 
+ for (size_t i = dimensions.size() - 1; i > 0; i--) { + coordinates[i] = shifted % dimensions[i]; + shifted = shifted / dimensions[i]; + } + + NTA_ASSERT(shifted < dimensions[0]); + coordinates[0] = shifted; + + return coordinates; +} + +UInt indexFromCoordinates(const vector &coordinates, + const vector &dimensions) { + NTA_ASSERT(coordinates.size() == dimensions.size()); + + UInt index = 0; + for (size_t i = 0; i < dimensions.size(); i++) { + NTA_ASSERT(coordinates[i] < dimensions[i]); + index *= dimensions[i]; + index += coordinates[i]; + } + + return index; +} +} // end namespace topology +} // namespace math +} // end namespace nupic // ============================================================================ // NEIGHBORHOOD // ============================================================================ Neighborhood::Neighborhood(UInt centerIndex, UInt radius, - const vector& dimensions) - : centerPosition_(coordinatesFromIndex(centerIndex, dimensions)), - dimensions_(dimensions), - radius_(radius) -{ -} - -Neighborhood::Iterator::Iterator(const Neighborhood& neighborhood, bool end) - : neighborhood_(neighborhood), - offset_(neighborhood.dimensions_.size(), -neighborhood.radius_), - finished_(end) -{ + const vector &dimensions) + : centerPosition_(coordinatesFromIndex(centerIndex, dimensions)), + dimensions_(dimensions), radius_(radius) {} + +Neighborhood::Iterator::Iterator(const Neighborhood &neighborhood, bool end) + : neighborhood_(neighborhood), + offset_(neighborhood.dimensions_.size(), -neighborhood.radius_), + finished_(end) { // Choose the first offset that has positive resulting coordinates. - for (size_t i = 0; i < offset_.size(); i++) - { - offset_[i] = std::max(offset_[i], - -(Int)neighborhood_.centerPosition_[i]); + for (size_t i = 0; i < offset_.size(); i++) { + offset_[i] = std::max(offset_[i], -(Int)neighborhood_.centerPosition_[i]); } } -bool Neighborhood::Iterator::operator!=(const Iterator& other) const -{ +bool Neighborhood::Iterator::operator!=(const Iterator &other) const { return finished_ != other.finished_; } -UInt Neighborhood::Iterator::operator*() const -{ +UInt Neighborhood::Iterator::operator*() const { UInt index = 0; - for (size_t i = 0; i < neighborhood_.dimensions_.size(); i++) - { + for (size_t i = 0; i < neighborhood_.dimensions_.size(); i++) { const Int coordinate = neighborhood_.centerPosition_[i] + offset_[i]; NTA_ASSERT(coordinate >= 0); @@ -125,101 +106,74 @@ UInt Neighborhood::Iterator::operator*() const return index; } -const Neighborhood::Iterator& Neighborhood::Iterator::operator++() -{ +const Neighborhood::Iterator &Neighborhood::Iterator::operator++() { advance_(); return *this; } -void Neighborhood::Iterator::advance_() -{ +void Neighborhood::Iterator::advance_() { // When it overflows, we need to "carry the 1" to the next dimension. bool overflowed = true; - for (Int i = offset_.size() - 1; i >= 0; i--) - { + for (Int i = offset_.size() - 1; i >= 0; i--) { offset_[i]++; - overflowed = - offset_[i] > (Int)neighborhood_.radius_ || - (((Int)neighborhood_.centerPosition_[i] + offset_[i]) >= - (Int)neighborhood_.dimensions_[i]); + overflowed = offset_[i] > (Int)neighborhood_.radius_ || + (((Int)neighborhood_.centerPosition_[i] + offset_[i]) >= + (Int)neighborhood_.dimensions_[i]); - if (overflowed) - { + if (overflowed) { // Choose the first offset that has a positive resulting coordinate. offset_[i] = std::max(-(Int)neighborhood_.radius_, -(Int)neighborhood_.centerPosition_[i]); - } - else - { + } else { // There's no overflow. 
The remaining coordinates don't need to change. break; } } // When the final coordinate overflows, we're done. - if (overflowed) - { + if (overflowed) { finished_ = true; } } -Neighborhood::Iterator Neighborhood::begin() const -{ - return {*this, false}; -} - -Neighborhood::Iterator Neighborhood::end() const -{ - return {*this, true}; -} +Neighborhood::Iterator Neighborhood::begin() const { return {*this, false}; } +Neighborhood::Iterator Neighborhood::end() const { return {*this, true}; } // ============================================================================ // WRAPPING NEIGHBORHOOD // ============================================================================ -WrappingNeighborhood::WrappingNeighborhood( - UInt centerIndex, - UInt radius, - const vector& dimensions) - : centerPosition_(coordinatesFromIndex(centerIndex, dimensions)), - dimensions_(dimensions), - radius_(radius) -{ -} +WrappingNeighborhood::WrappingNeighborhood(UInt centerIndex, UInt radius, + const vector &dimensions) + : centerPosition_(coordinatesFromIndex(centerIndex, dimensions)), + dimensions_(dimensions), radius_(radius) {} WrappingNeighborhood::Iterator::Iterator( - const WrappingNeighborhood& neighborhood, bool end) - : neighborhood_(neighborhood), - offset_(neighborhood.dimensions_.size(), -neighborhood.radius_), - finished_(end) -{ -} + const WrappingNeighborhood &neighborhood, bool end) + : neighborhood_(neighborhood), + offset_(neighborhood.dimensions_.size(), -neighborhood.radius_), + finished_(end) {} -bool WrappingNeighborhood::Iterator::operator!=(const Iterator& other) const -{ +bool WrappingNeighborhood::Iterator::operator!=(const Iterator &other) const { return finished_ != other.finished_; } -UInt WrappingNeighborhood::Iterator::operator*() const -{ +UInt WrappingNeighborhood::Iterator::operator*() const { UInt index = 0; - for (size_t i = 0; i < neighborhood_.dimensions_.size(); i++) - { + for (size_t i = 0; i < neighborhood_.dimensions_.size(); i++) { Int coordinate = neighborhood_.centerPosition_[i] + offset_[i]; // With a large radius, it may have wrapped around multiple times, so use // `while`, not `if`. - while (coordinate < 0) - { + while (coordinate < 0) { coordinate += neighborhood_.dimensions_[i]; } - while (coordinate >= (Int)neighborhood_.dimensions_[i]) - { + while (coordinate >= (Int)neighborhood_.dimensions_[i]) { coordinate -= neighborhood_.dimensions_[i]; } @@ -230,53 +184,45 @@ UInt WrappingNeighborhood::Iterator::operator*() const return index; } -const WrappingNeighborhood::Iterator& WrappingNeighborhood::Iterator::operator++() -{ +const WrappingNeighborhood::Iterator &WrappingNeighborhood::Iterator:: +operator++() { advance_(); return *this; } -void WrappingNeighborhood::Iterator::advance_() -{ +void WrappingNeighborhood::Iterator::advance_() { // When it overflows, we need to "carry the 1" to the next dimension. bool overflowed = true; - for (Int i = offset_.size() - 1; i >= 0; i--) - { + for (Int i = offset_.size() - 1; i >= 0; i--) { offset_[i]++; // If the offset has moved by more than the dimension size, i.e. if // offset_[i] - (-radius) is greater than the dimension size, then we're // about to run into points that we've already seen. This happens when given // small dimensions, a large radius, and wrap-around. 
- overflowed = - offset_[i] > (Int)neighborhood_.radius_ || - offset_[i] + (Int)neighborhood_.radius_ >= (Int)neighborhood_.dimensions_[i]; + overflowed = offset_[i] > (Int)neighborhood_.radius_ || + offset_[i] + (Int)neighborhood_.radius_ >= + (Int)neighborhood_.dimensions_[i]; - if (overflowed) - { + if (overflowed) { offset_[i] = -neighborhood_.radius_; - } - else - { + } else { // There's no overflow. The remaining coordinates don't need to change. break; } } // When the final coordinate overflows, we're done. - if (overflowed) - { + if (overflowed) { finished_ = true; } } -WrappingNeighborhood::Iterator WrappingNeighborhood::begin() const -{ +WrappingNeighborhood::Iterator WrappingNeighborhood::begin() const { return {*this, /*end*/ false}; } -WrappingNeighborhood::Iterator WrappingNeighborhood::end() const -{ +WrappingNeighborhood::Iterator WrappingNeighborhood::end() const { return {*this, /*end*/ true}; } diff --git a/src/nupic/math/Topology.hpp b/src/nupic/math/Topology.hpp index 79c0c48af7..f1d5b6422f 100644 --- a/src/nupic/math/Topology.hpp +++ b/src/nupic/math/Topology.hpp @@ -31,187 +31,180 @@ #include -namespace nupic -{ - namespace math - { - namespace topology - { - - /** - * Translate an index into coordinates, using the given coordinate system. - * - * @param index - * The index of the point. The coordinates are expressed as a single index - * by using the dimensions as a mixed radix definition. For example, in - * dimensions 42x10, the point [1, 4] is index 1*420 + 4*10 = 460. - * - * @param dimensions - * The coordinate system. - * - * @returns - * A vector of coordinates of length dimensions.size(). - */ - std::vector coordinatesFromIndex( - UInt index, - const std::vector& dimensions); - - /** - * Translate coordinates into an index, using the given coordinate system. - * - * @param coordinates - * A vector of coordinates of length dimensions.size(). - * - * @param dimensions - * The coordinate system. - * - * @returns - * The index of the point. The coordinates are expressed as a single index - * by using the dimensions as a mixed radix definition. For example, in - * dimensions 42x10, the point [1, 4] is index 1*420 + 4*10 = 460. - */ - UInt indexFromCoordinates( - const std::vector& coordinates, - const std::vector& dimensions); - - /** - * A class that lets you iterate over all points within the neighborhood - * of a point. - * - * Usage: - * UInt center = 42; - * for (UInt neighbor : Neighborhood(center, 10, {100, 100})) - * { - * if (neighbor == center) - * { - * // Note that the center is included in the neighborhood! - * } - * else - * { - * // Do something with the neighbor. - * } - * } - * - * A point's neighborhood is the n-dimensional hypercube with sides - * ranging [center - radius, center + radius], inclusive. For example, - * if there are two dimensions and the radius is 3, the neighborhood is - * 6x6. Neighborhoods are truncated when they are near an edge. - * - * Dimensions aren't copied -- a reference is saved. Make sure the - * dimensions don't get overwritten while this Neighborhood instance - * exists. - * - * This is designed to be fast. It walks the list of points in the - * neighborhood without ever creating a list of points. - * - * This still could be faster. Because it handles an arbitrary number of - * dimensions, it has to allocate vectors. 
It would be faster to have a - * Neighborhood1D, Neighborhood2D, etc., so that all computation could - * occur on the stack, but this would put a burden on callers to handle - * different dimensions counts. Or it would require using polymorphism, - * using pointers/references and putting the Neighborhood on the heap, - * which defeats the purpose of avoiding the vector allocations. - * - * @param centerIndex - * The center of this neighborhood. The coordinates are expressed as a - * single index by using the dimensions as a mixed radix definition. For - * example, in dimensions 42x10, the point [1, 4] is index 1*420 + 4*10 = - * 460. - * - * @param radius - * The radius of this neighborhood about the centerIndex. - * - * @param dimensions - * The dimensions of the world outside this neighborhood. - * - * @returns - * An object which supports C++ range-based for loops. Each iteration of - * the loop returns a point in the neighborhood. Each point is expressed - * as a single index. - */ - class Neighborhood - { - public: - Neighborhood(UInt centerIndex, UInt radius, - const std::vector& dimensions); - - class Iterator { - public: - Iterator(const Neighborhood& neighborhood, bool end); - bool operator !=(const Iterator& other) const; - UInt operator*() const; - const Iterator& operator++(); - - private: - void advance_(); - - const Neighborhood& neighborhood_; - std::vector offset_; - bool finished_; - }; - - Iterator begin() const; - Iterator end() const; - - private: - const std::vector centerPosition_; - const std::vector& dimensions_; - const UInt radius_; - }; - - /** - * Like the Neighborhood class, except that the neighborhood isn't - * truncated when it's near an edge. It wraps around to the other side. - * - * @param centerIndex - * The center of this neighborhood. The coordinates are expressed as a - * single index by using the dimensions as a mixed radix definition. For - * example, in dimensions 42x10, the point [1, 4] is index 1*420 + 4*10 = - * 460. - * - * @param radius - * The radius of this neighborhood about the centerIndex. - * - * @param dimensions - * The dimensions of the world outside this neighborhood. - * - * @returns - * An object which supports C++ range-based for loops. Each iteration of - * the loop returns a point in the neighborhood. Each point is expressed - * as a single index. - */ - class WrappingNeighborhood - { - public: - WrappingNeighborhood(UInt centerIndex, UInt radius, - const std::vector& dimensions); - - class Iterator { - public: - Iterator(const WrappingNeighborhood& neighborhood, bool end); - bool operator !=(const Iterator& other) const; - UInt operator*() const; - const Iterator& operator++(); - - private: - void advance_(); - - const WrappingNeighborhood& neighborhood_; - std::vector offset_; - bool finished_; - }; - - Iterator begin() const; - Iterator end() const; - - private: - const std::vector centerPosition_; - const std::vector& dimensions_; - const UInt radius_; - }; - - } // end namespace topology - } // end namespace algorithms +namespace nupic { +namespace math { +namespace topology { + +/** + * Translate an index into coordinates, using the given coordinate system. + * + * @param index + * The index of the point. The coordinates are expressed as a single index + * by using the dimensions as a mixed radix definition. For example, in + * dimensions 42x10, the point [1, 4] is index 1*420 + 4*10 = 460. + * + * @param dimensions + * The coordinate system. + * + * @returns + * A vector of coordinates of length dimensions.size(). 
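As a concrete check against the implementation in Topology.cpp above (the last dimension varies fastest, so for two dimensions index = c0 * dims[1] + c1), a minimal worked example; the include path follows the file location in this patch:

#include <cassert>
#include <vector>
#include "nupic/math/Topology.hpp"

void mixed_radix_demo() {
  using namespace nupic;
  using namespace nupic::math::topology;

  const std::vector<UInt> dims = {42, 10}; // a 42 x 10 space

  // {1, 4} -> 1*10 + 4 = 14, and back again.
  const std::vector<UInt> point = {1, 4};
  const UInt index = indexFromCoordinates(point, dims);
  assert(index == 14);

  const std::vector<UInt> coords = coordinatesFromIndex(index, dims);
  assert(coords[0] == 1 && coords[1] == 4);
}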
+ */ +std::vector coordinatesFromIndex(UInt index, + const std::vector &dimensions); + +/** + * Translate coordinates into an index, using the given coordinate system. + * + * @param coordinates + * A vector of coordinates of length dimensions.size(). + * + * @param dimensions + * The coordinate system. + * + * @returns + * The index of the point. The coordinates are expressed as a single index + * by using the dimensions as a mixed radix definition. For example, in + * dimensions 42x10, the point [1, 4] is index 1*420 + 4*10 = 460. + */ +UInt indexFromCoordinates(const std::vector &coordinates, + const std::vector &dimensions); + +/** + * A class that lets you iterate over all points within the neighborhood + * of a point. + * + * Usage: + * UInt center = 42; + * for (UInt neighbor : Neighborhood(center, 10, {100, 100})) + * { + * if (neighbor == center) + * { + * // Note that the center is included in the neighborhood! + * } + * else + * { + * // Do something with the neighbor. + * } + * } + * + * A point's neighborhood is the n-dimensional hypercube with sides + * ranging [center - radius, center + radius], inclusive. For example, + * if there are two dimensions and the radius is 3, the neighborhood is + * 6x6. Neighborhoods are truncated when they are near an edge. + * + * Dimensions aren't copied -- a reference is saved. Make sure the + * dimensions don't get overwritten while this Neighborhood instance + * exists. + * + * This is designed to be fast. It walks the list of points in the + * neighborhood without ever creating a list of points. + * + * This still could be faster. Because it handles an arbitrary number of + * dimensions, it has to allocate vectors. It would be faster to have a + * Neighborhood1D, Neighborhood2D, etc., so that all computation could + * occur on the stack, but this would put a burden on callers to handle + * different dimensions counts. Or it would require using polymorphism, + * using pointers/references and putting the Neighborhood on the heap, + * which defeats the purpose of avoiding the vector allocations. + * + * @param centerIndex + * The center of this neighborhood. The coordinates are expressed as a + * single index by using the dimensions as a mixed radix definition. For + * example, in dimensions 42x10, the point [1, 4] is index 1*420 + 4*10 = + * 460. + * + * @param radius + * The radius of this neighborhood about the centerIndex. + * + * @param dimensions + * The dimensions of the world outside this neighborhood. + * + * @returns + * An object which supports C++ range-based for loops. Each iteration of + * the loop returns a point in the neighborhood. Each point is expressed + * as a single index. + */ +class Neighborhood { +public: + Neighborhood(UInt centerIndex, UInt radius, + const std::vector &dimensions); + + class Iterator { + public: + Iterator(const Neighborhood &neighborhood, bool end); + bool operator!=(const Iterator &other) const; + UInt operator*() const; + const Iterator &operator++(); + + private: + void advance_(); + + const Neighborhood &neighborhood_; + std::vector offset_; + bool finished_; + }; + + Iterator begin() const; + Iterator end() const; + +private: + const std::vector centerPosition_; + const std::vector &dimensions_; + const UInt radius_; +}; + +/** + * Like the Neighborhood class, except that the neighborhood isn't + * truncated when it's near an edge. It wraps around to the other side. + * + * @param centerIndex + * The center of this neighborhood. 
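To make the difference from the plain Neighborhood concrete: near an edge the plain neighborhood is truncated, while the wrapping one (declared just below) folds around. A minimal sketch, with the include path taken from this patch; note the dimensions vector must outlive the loop, because only a reference to it is stored:

#include <cstddef>
#include <iostream>
#include <vector>
#include "nupic/math/Topology.hpp"

void neighborhood_demo() {
  using namespace nupic;
  using namespace nupic::math::topology;

  const std::vector<UInt> dims = {4, 4};
  const UInt corner = 0;  // coordinates {0, 0}
  const UInt radius = 1;

  // Truncated at the edge: {0,0},{0,1},{1,0},{1,1} -> 4 points (center included).
  std::size_t plain = 0;
  for (UInt neighbor : Neighborhood(corner, radius, dims)) {
    (void)neighbor;
    ++plain;
  }

  // Wrapped around the edges: a full 3x3 block -> 9 points.
  std::size_t wrapping = 0;
  for (UInt neighbor : WrappingNeighborhood(corner, radius, dims)) {
    (void)neighbor;
    ++wrapping;
  }

  std::cout << plain << " vs " << wrapping << std::endl; // "4 vs 9"
}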
The coordinates are expressed as a + * single index by using the dimensions as a mixed radix definition. For + * example, in dimensions 42x10, the point [1, 4] is index 1*420 + 4*10 = + * 460. + * + * @param radius + * The radius of this neighborhood about the centerIndex. + * + * @param dimensions + * The dimensions of the world outside this neighborhood. + * + * @returns + * An object which supports C++ range-based for loops. Each iteration of + * the loop returns a point in the neighborhood. Each point is expressed + * as a single index. + */ +class WrappingNeighborhood { +public: + WrappingNeighborhood(UInt centerIndex, UInt radius, + const std::vector &dimensions); + + class Iterator { + public: + Iterator(const WrappingNeighborhood &neighborhood, bool end); + bool operator!=(const Iterator &other) const; + UInt operator*() const; + const Iterator &operator++(); + + private: + void advance_(); + + const WrappingNeighborhood &neighborhood_; + std::vector offset_; + bool finished_; + }; + + Iterator begin() const; + Iterator end() const; + +private: + const std::vector centerPosition_; + const std::vector &dimensions_; + const UInt radius_; +}; + +} // end namespace topology +} // namespace math } // end namespace nupic #endif // NTA_TOPOLOGY_HPP diff --git a/src/nupic/math/Types.hpp b/src/nupic/math/Types.hpp index 227488df9c..976e9d2f9c 100644 --- a/src/nupic/math/Types.hpp +++ b/src/nupic/math/Types.hpp @@ -27,10 +27,10 @@ #ifndef NTA_MATH_TYPES_HPP #define NTA_MATH_TYPES_HPP +#include // sort #include -#include #include -#include // sort +#include #include #include @@ -39,323 +39,279 @@ namespace nupic { - //-------------------------------------------------------------------------------- - // BYTE VECTOR - //-------------------------------------------------------------------------------- - /** - * This is a good compromise between speed and memory for the use cases we have. - * Going to a real vector of bits is slower when accessing the individual bits, - * but this vector of bytes can still be fed to the SSE with good results. - */ - struct ByteVector : public std::vector - { - inline ByteVector(size_t n =0) - : std::vector(n, (nupic::Byte)0) - {} - - /** - * Use these two functions when converting with a vector of int or float - * since the byte represenation of the elements in a byte vector is _not_ - * the same as the byte representation of ints and floats. - */ - template - inline ByteVector(It begin, size_t n) - : std::vector(n, 0) - { - for (size_t i = 0; i != this->size(); ++i) - (*this)[i] = *begin++ != 0; - } - - template - inline void toDense(It begin, It end) - { - for (size_t i = 0; i != this->size(); ++i) - *begin++ = (*this)[i] != 0; - } - }; - +//-------------------------------------------------------------------------------- +// BYTE VECTOR +//-------------------------------------------------------------------------------- +/** + * This is a good compromise between speed and memory for the use cases we have. + * Going to a real vector of bits is slower when accessing the individual bits, + * but this vector of bytes can still be fed to the SSE with good results. + */ +struct ByteVector : public std::vector { + inline ByteVector(size_t n = 0) + : std::vector(n, (nupic::Byte)0) {} - //-------------------------------------------------------------------------------- - // Buffer - //-------------------------------------------------------------------------------- /** - * Allocated once, but only the first n positions are valid (std::vector does that!) 
- * DON'T USE ANYMORE, but keeping it because a lot of code already depends on it. + * Use these two functions when converting with a vector of int or float + * since the byte represenation of the elements in a byte vector is _not_ + * the same as the byte representation of ints and floats. */ - template - struct Buffer : public std::vector - { - typedef size_t size_type; - typedef T value_type; - - size_type nnz; + template + inline ByteVector(It begin, size_t n) : std::vector(n, 0) { + for (size_t i = 0; i != this->size(); ++i) + (*this)[i] = *begin++ != 0; + } + + template inline void toDense(It begin, It end) { + for (size_t i = 0; i != this->size(); ++i) + *begin++ = (*this)[i] != 0; + } +}; + +//-------------------------------------------------------------------------------- +// Buffer +//-------------------------------------------------------------------------------- +/** + * Allocated once, but only the first n positions are valid (std::vector does + * that!) DON'T USE ANYMORE, but keeping it because a lot of code already + * depends on it. + */ +template struct Buffer : public std::vector { + typedef size_t size_type; + typedef T value_type; - inline Buffer(size_type _s =0) - : std::vector(_s), - nnz(0) - {} + size_type nnz; - inline void clear() - { - nnz = 0; - } + inline Buffer(size_type _s = 0) : std::vector(_s), nnz(0) {} - inline void adjust_nnz(size_t n) // call resize? - { - nnz = std::min(nnz, n); - } + inline void clear() { nnz = 0; } - inline bool empty() const - { - return nnz == 0; - } - - inline void push_back(const T& x) - { - (*this)[nnz++] = x; - } - - inline typename std::vector::iterator nnz_end() - { - return this->begin() + nnz; - } - - inline typename std::vector::const_iterator nnz_end() const - { - return this->begin() + nnz; - } - }; - - //-------------------------------------------------------------------------------- - // Direct access with fast erase - //-------------------------------------------------------------------------------- - // Records who has been set, so that resetting to zero is fast. Usage pattern - // is to clear the board, do a bunch of sets, look at the board (to test for - // membership for example), then reset the board in the next iteration. - // It trades memory for speed. T is adjustable to be bool (uses vector, - // 1 bit per element), or int, or ushort, or long, or even float. For - // a membership board (set), on darwin86, unsigned short is fastest on 8/11/2010. - // The clear() method provides a kind of incremental reset. - // Assumes the elements that are set are sparse, that there aren't many compared - // to the size of the board. - template - struct DirectAccess + inline void adjust_nnz(size_t n) // call resize? { - typedef I size_type; - typedef T value_type; - - std::vector board; - std::vector who; - - inline void resize(size_type m, size_type n =0) - { - //m = 4 * (m / 4); - board.resize(m, T()); - who.reserve(n == 0 ? 
m : n); - - assert(who.size() == 0); - } - - inline void set(size_type w, const T& v = T(1)) - { - assert(w < board.size()); - assert(v != T()); - assert(who.size() < who.capacity()); - - if (board[w] == T()) { // that if doesn't doest much at all (verified) - who.push_back(w); - //assert(std::set(who.begin(),who.end()).size() == who.size()); - } - - board[w] = v; - } - - inline T get(size_type w) const - { - assert(w < board.size()); - - return board[w]; - } - - // Only const operator because the non-const has annoying side-effects - // that are easily unintended - inline const T& operator[](size_type w) const - { - return board[w]; - } - - inline void increment(size_type w) - { - assert(w < board.size()); - - if (board[w] == T()) { - who.push_back(w); - //assert(std::set(who.begin(),who.end()).size() == who.size()); - } - - ++ board[w]; + nnz = std::min(nnz, n); + } + + inline bool empty() const { return nnz == 0; } + + inline void push_back(const T &x) { (*this)[nnz++] = x; } + + inline typename std::vector::iterator nnz_end() { + return this->begin() + nnz; + } + + inline typename std::vector::const_iterator nnz_end() const { + return this->begin() + nnz; + } +}; + +//-------------------------------------------------------------------------------- +// Direct access with fast erase +//-------------------------------------------------------------------------------- +// Records who has been set, so that resetting to zero is fast. Usage pattern +// is to clear the board, do a bunch of sets, look at the board (to test for +// membership for example), then reset the board in the next iteration. +// It trades memory for speed. T is adjustable to be bool (uses vector, +// 1 bit per element), or int, or ushort, or long, or even float. For +// a membership board (set), on darwin86, unsigned short is fastest on +// 8/11/2010. The clear() method provides a kind of incremental reset. Assumes +// the elements that are set are sparse, that there aren't many compared to the +// size of the board. +template struct DirectAccess { + typedef I size_type; + typedef T value_type; + + std::vector board; + std::vector who; + + inline void resize(size_type m, size_type n = 0) { + // m = 4 * (m / 4); + board.resize(m, T()); + who.reserve(n == 0 ? m : n); + + assert(who.size() == 0); + } + + inline void set(size_type w, const T &v = T(1)) { + assert(w < board.size()); + assert(v != T()); + assert(who.size() < who.capacity()); + + if (board[w] == T()) { // that if doesn't doest much at all (verified) + who.push_back(w); + // assert(std::set(who.begin(),who.end()).size() == + // who.size()); } - // If board[w] becomes T() again, we need to update who - inline void decrement(size_type w) - { - assert(w < board.size()); - - if (board[w] == T()) { - who.push_back(w); - //assert(std::set(who.begin(),who.end()).size() == who.size()); - } - - -- board[w]; - - // To make sure we keep the uniqueness invariant, - // might be costly if not very sparse? - if (board[w] == T()) { - size_type i = 0; - while (who[i] != w) - ++i; - std::swap(who[i], who[who.size()-1]); - who.pop_back(); - } - } + board[w] = v; + } - // v can be anything, < 0, == 0, or > 0 - // If board[w] becomes T() again, we need to update who - inline void update(size_type w, const value_type& v) - { - assert(w < board.size()); - - if (board[w] == T()) { - who.push_back(w); - //assert(std::set(who.begin(),who.end()).size() == who.size()); - } - - board[w] += v; - - // To make sure we keep the uniqueness invariant, - // might be costly if not very sparse? 
- if (board[w] == T()) { - size_type i = 0; - while (who[i] != w) - ++i; - std::swap(who[i], who[who.size()-1]); - who.pop_back(); - } - } - - // Clear by 4 is a little bit faster, but works only - // if T() takes exactly 4 bytes. - inline void clear() - { - size_type* w = &who[0], *w_end = w + who.size(); - //size_type* p = (size_type*) &board[0]; - while (w != w_end) - //p[*w++] = 0; - board[*w++] = T(); - who.resize(0); - } + inline T get(size_type w) const { + assert(w < board.size()); - // Keep only the value above a certain threshold. - // Resort the who array optionally. - // TODO: unit test more - inline void threshold(const T& t, bool sorted =false) - { - int n = who.size(); - int i = 0; + return board[w]; + } - while (i < n) - if (board[who[i]] < t) - std::swap(who[i], who[--n]); - else - ++i; + // Only const operator because the non-const has annoying side-effects + // that are easily unintended + inline const T &operator[](size_type w) const { return board[w]; } - who.resize(n); + inline void increment(size_type w) { + assert(w < board.size()); - if (sorted) - std::sort(who.begin(), who.end()); + if (board[w] == T()) { + who.push_back(w); + // assert(std::set(who.begin(),who.end()).size() == + // who.size()); } - }; - //-------------------------------------------------------------------------------- - // Avoids cost of clearing the board by using multiple colors. Clears only - // every 255 iterations. - // Doesn't keep list of who's on for fast iteration like DirecAccess does. - template - struct Indicator; + ++board[w]; + } - template - struct Indicator - { - typedef I size_type; - - std::vector board; - unsigned short color; + // If board[w] becomes T() again, we need to update who + inline void decrement(size_type w) { + assert(w < board.size()); - inline void resize(size_type m) - { - color = 0; - board.resize(m, color); + if (board[w] == T()) { + who.push_back(w); + // assert(std::set(who.begin(),who.end()).size() == + // who.size()); } - inline void set(size_type w) - { - NTA_ASSERT(w < board.size()); + --board[w]; - board[w] = color; + // To make sure we keep the uniqueness invariant, + // might be costly if not very sparse? + if (board[w] == T()) { + size_type i = 0; + while (who[i] != w) + ++i; + std::swap(who[i], who[who.size() - 1]); + who.pop_back(); } + } - inline bool is_on(size_type w) const - { - NTA_ASSERT(w < board.size()); + // v can be anything, < 0, == 0, or > 0 + // If board[w] becomes T() again, we need to update who + inline void update(size_type w, const value_type &v) { + assert(w < board.size()); - return board[w] == color; + if (board[w] == T()) { + who.push_back(w); + // assert(std::set(who.begin(),who.end()).size() == + // who.size()); } - inline bool operator[](size_type w) const - { - NTA_ASSERT(w < board.size()); + board[w] += v; - return is_on(w); + // To make sure we keep the uniqueness invariant, + // might be costly if not very sparse? + if (board[w] == T()) { + size_type i = 0; + while (who[i] != w) + ++i; + std::swap(who[i], who[who.size() - 1]); + who.pop_back(); } - - inline void clear() - { - if (color < std::numeric_limits::max()) - ++color; - else { - color = 0; - std::fill(board.begin(), board.end(), color); - } + } + + // Clear by 4 is a little bit faster, but works only + // if T() takes exactly 4 bytes. 
+ inline void clear() { + size_type *w = &who[0], *w_end = w + who.size(); + // size_type* p = (size_type*) &board[0]; + while (w != w_end) + // p[*w++] = 0; + board[*w++] = T(); + who.resize(0); + } + + // Keep only the value above a certain threshold. + // Resort the who array optionally. + // TODO: unit test more + inline void threshold(const T &t, bool sorted = false) { + int n = who.size(); + int i = 0; + + while (i < n) + if (board[who[i]] < t) + std::swap(who[i], who[--n]); + else + ++i; + + who.resize(n); + + if (sorted) + std::sort(who.begin(), who.end()); + } +}; + +//-------------------------------------------------------------------------------- +// Avoids cost of clearing the board by using multiple colors. Clears only +// every 255 iterations. +// Doesn't keep list of who's on for fast iteration like DirecAccess does. +template struct Indicator; + +template struct Indicator { + typedef I size_type; + + std::vector board; + unsigned short color; + + inline void resize(size_type m) { + color = 0; + board.resize(m, color); + } + + inline void set(size_type w) { + NTA_ASSERT(w < board.size()); + + board[w] = color; + } + + inline bool is_on(size_type w) const { + NTA_ASSERT(w < board.size()); + + return board[w] == color; + } + + inline bool operator[](size_type w) const { + NTA_ASSERT(w < board.size()); + + return is_on(w); + } + + inline void clear() { + if (color < std::numeric_limits::max()) + ++color; + else { + color = 0; + std::fill(board.begin(), board.end(), color); } + } - template - inline void set_from_sparse(It begin, It end) - { - NTA_ASSERT(begin <= end); + template inline void set_from_sparse(It begin, It end) { + NTA_ASSERT(begin <= end); - this->clear(); - while (begin != end) - this->set(*begin++); - } - }; + this->clear(); + while (begin != end) + this->set(*begin++); + } +}; - //-------------------------------------------------------------------------------- - /** - * The first element of each pair is the index of a non-zero, and the second element - * is the value of the non-zero. - */ - template - struct SparseVector : public Buffer > - { - typedef T1 size_type; - typedef T2 value_type; +//-------------------------------------------------------------------------------- +/** + * The first element of each pair is the index of a non-zero, and the second + * element is the value of the non-zero. 
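The membership-board pattern described above (resize once, set a sparse handful of cells, query, then clear cheaply) can be written against any instantiation. The template parameter lists of DirectAccess are not visible in this extract, so the sketch below deliberately leaves the concrete instantiation to the caller rather than naming one; the board is assumed to have been resize()'d large enough, since resize() is also what reserves the `who` list.

#include <vector>
#include "nupic/math/Types.hpp" // path as in this patch

// Board is any DirectAccess instantiation; only members shown above are used.
template <typename Board>
bool mark_and_test(Board &board,
                   const std::vector<typename Board::size_type> &active,
                   typename Board::size_type probe) {
  for (auto w : active)
    board.set(w);              // the first set() of a cell records it in `who`
  const bool hit = board.get(probe) != typename Board::value_type();
  board.clear();               // resets only the cells recorded in `who`
  return hit;
}

Indicator supports the same set/test/clear cycle, but its clear() usually just increments `color` instead of touching the board at all.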
+ */ +template +struct SparseVector : public Buffer> { + typedef T1 size_type; + typedef T2 value_type; - inline SparseVector(size_type s =0) - : Buffer >(s) - {} - }; + inline SparseVector(size_type s = 0) : Buffer>(s) {} +}; - //-------------------------------------------------------------------------------- +//-------------------------------------------------------------------------------- }; // end namespace nupic #endif // NTA_MATH_TYPES_HPP diff --git a/src/nupic/math/Utils.hpp b/src/nupic/math/Utils.hpp index 7b4171f4fc..b7ef5a6c88 100644 --- a/src/nupic/math/Utils.hpp +++ b/src/nupic/math/Utils.hpp @@ -20,27 +20,27 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definitions for various utility functions */ #ifndef NTA_UTILS_HPP #define NTA_UTILS_HPP +#include #include -#include +#include #include -#include -#include #include #include -#include -#include -#include #include #include +#include +#include +#include #include -#include +#include +#include #include @@ -48,333 +48,337 @@ #include namespace nupic { - - //-------------------------------------------------------------------------------- - /** - * Computes the amount of padding required to align two adjacent blocks of memory. - * If the first block has 17 bytes, and the second is a "vector" of 4 elements - * of 4 bytes each, we need to align the start of the "vector" on a 4 bytes - * boundary. The amount of padding required after the 17 bytes of the first - * block is: 3 bytes, and 3 = 4 - 17 % 4, that is: - * padding = second elem size - first total size % second elem size. - * - * Special case: if the first block of memory ends on a boundary of the second - * block, no padding required. Example, first block has 16 bytes and second vector of - * 4 bytes each: 16 % 4 = 0. - */ - template - inline const SizeType padding(const SizeType& s1, const SizeType& s2) - { - if(s2) { - SizeType extra = s1 % s2; - return extra == 0 ? 0 : s2 - extra; - } - else return 0; - } + +//-------------------------------------------------------------------------------- +/** + * Computes the amount of padding required to align two adjacent blocks of + * memory. If the first block has 17 bytes, and the second is a "vector" of 4 + * elements of 4 bytes each, we need to align the start of the "vector" on a 4 + * bytes boundary. The amount of padding required after the 17 bytes of the + * first block is: 3 bytes, and 3 = 4 - 17 % 4, that is: padding = second elem + * size - first total size % second elem size. + * + * Special case: if the first block of memory ends on a boundary of the second + * block, no padding required. Example, first block has 16 bytes and second + * vector of 4 bytes each: 16 % 4 = 0. + */ +template +inline const SizeType padding(const SizeType &s1, const SizeType &s2) { + if (s2) { + SizeType extra = s1 % s2; + return extra == 0 ? 
0 : s2 - extra; + } else + return 0; +} /* - the following code is known to cause -Wstrict-aliasing warning, so silence it here + the following code is known to cause -Wstrict-aliasing warning, so silence it + here */ #if !defined(NTA_OS_WINDOWS) - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" #endif - inline bool isSystemLittleEndian() - { static const char test[2] = { 1, 0 }; return (*(short *) test) == 1; } +inline bool isSystemLittleEndian() { + static const char test[2] = {1, 0}; + return (*(short *)test) == 1; +} #if !defined(NTA_OS_WINDOWS) - #pragma GCC diagnostic pop // return back to defaults +#pragma GCC diagnostic pop // return back to defaults #endif - template - inline void swapBytesInPlace(T *pxIn, Size n) - { - union SwapType { T x; unsigned char b[sizeof(T)]; }; - SwapType *px = reinterpret_cast(pxIn); - SwapType *pxend = px + n; - const int stop = sizeof(T) / 2; - for(; px!=pxend; ++px) { - for(int j=0; jb[j], px->b[sizeof(T)-j-1]); - } - } - - template - inline void swapBytes(T *pxOut, Size n, const T *pxIn) - { - NTA_ASSERT(pxOut != pxIn) << "Use swapBytesInPlace() instead."; - NTA_ASSERT(!(((pxOut > pxIn) && (pxOut < (pxIn+n))) || - ((pxIn > pxOut) && (pxIn < (pxOut+n))))) << "Overlapping ranges not supported."; - - union SwapType { T x; unsigned char b[sizeof(T)]; }; - const SwapType *px0 = reinterpret_cast(pxIn); - const SwapType *pxend = px0 + n; - SwapType *px1 = reinterpret_cast(pxOut); - for(; px0!=pxend; ++px0, ++px1) { - for(int j=0; jb[j] = px0->b[sizeof(T)-j-1]; - } +template inline void swapBytesInPlace(T *pxIn, Size n) { + union SwapType { + T x; + unsigned char b[sizeof(T)]; + }; + SwapType *px = reinterpret_cast(pxIn); + SwapType *pxend = px + n; + const int stop = sizeof(T) / 2; + for (; px != pxend; ++px) { + for (int j = 0; j < stop; ++j) + std::swap(px->b[j], px->b[sizeof(T) - j - 1]); } +} - /** - * Calculates sizeof() types named by string names of types in nupic/types/types. - * Throws if the requested type cannot be found. - * - * Supported type names include: - * bool, - * char, wchar_t, - * NTA_Char, NTA_WChar, NTA_Byte, - * float, double, - * NTA_Real32, NTA_Real64, NTA_Real, - * int, size_t, - * NTA_Int32, NTA_UInt32, NTA_Int64, NTA_UInt64, NTA_Size - * - * @param name (string) Name of type to calculate sizeof() for. - * @param isNumeric (bool&) set to true on exit if type name is a number. - * @retval Number of bytes per element of the specified type. - */ - extern size_t GetTypeSize(const std::string &name, bool& isNumeric); - - /** - * Calculates sizeof() types named by string names of types in nupic/types/types. - * Throws if the requested type cannot be found. - * - * Supported type names include: - * bool, - * char, wchar_t, - * NTA_Char, NTA_WChar, NTA_Byte, - * float, double, - * NTA_Real32, NTA_Real64, NTA_Real, - * int, size_t, - * NTA_Int32, NTA_UInt32, NTA_Int64, NTA_UInt64, NTA_Size - * - * @param name (string) Name of type to calculate sizeof() for. - * @param isNumeric (bool&) set to true on exit if type name is a number. - * @retval Number of bytes per element of the specified type. 
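The padding() helper and the byte-swapping routines above are easiest to read against concrete values; a rough usage sketch that just exercises the declarations from this header (the wrapper function name is illustrative):

#include <cassert>
#include <cstddef>
#include <cstdint>

void utilsAlignmentSketch() {
  // From the doc comment: a 17-byte block followed by 4-byte elements needs
  // 3 bytes of padding, while a 16-byte block already ends on the boundary.
  assert(nupic::padding<std::size_t>(17, 4) == 3);
  assert(nupic::padding<std::size_t>(16, 4) == 0);

  // Swapping byte order twice restores the original values, regardless of
  // whether isSystemLittleEndian() reports true or false on this machine.
  std::uint32_t v[2] = {0x11223344u, 0xAABBCCDDu};
  nupic::swapBytesInPlace(v, 2);
  nupic::swapBytesInPlace(v, 2);
  assert(v[0] == 0x11223344u && v[1] == 0xAABBCCDDu);
}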
- */ - extern size_t GetTypeSize(NTA_BasicType type, bool& isNumeric); - - /** - * Return a string representation of an NTA_BasicType - * - * @param type the NTA_BasicType enum - * @retval name of the type as a string - */ - extern std::string GetTypeName(NTA_BasicType type); - - /** - * Utility routine used by PrintVariableArray to print array of a certain type - */ - template - inline void utilsPrintArray_(std::ostream& out, const void* theBeginP, - const void* theEndP) - { - const T* beginP = (T*)theBeginP; - const T* endP = (T*)theEndP; - - for ( ; beginP != endP; ++beginP) - out << *beginP << " "; - } +template inline void swapBytes(T *pxOut, Size n, const T *pxIn) { + NTA_ASSERT(pxOut != pxIn) << "Use swapBytesInPlace() instead."; + NTA_ASSERT(!(((pxOut > pxIn) && (pxOut < (pxIn + n))) || + ((pxIn > pxOut) && (pxIn < (pxOut + n))))) + << "Overlapping ranges not supported."; - /** - * Utility routine for setting an array in memory of a certain type from a stream - * - * @param in the stream with values to put into the array - * @param theBeginP pointer to start of array in memory - * @param theEndP pointer to end of array in memory - * @retval true if successfully set all values - * - */ - template - inline void utilsSetArray_(std::istream& in, void* theBeginP, void* theEndP) - { - T* beginP = (T*)theBeginP; - T* endP = (T*)theEndP; - - for ( ; beginP != endP && in.good(); ++beginP) - in >> *beginP; - if (beginP != endP && !in.eof()) - NTA_THROW << "UtilsSetArray() - error reading stream of values"; + union SwapType { + T x; + unsigned char b[sizeof(T)]; + }; + const SwapType *px0 = reinterpret_cast(pxIn); + const SwapType *pxend = px0 + n; + SwapType *px1 = reinterpret_cast(pxOut); + for (; px0 != pxend; ++px0, ++px1) { + for (int j = 0; j < sizeof(T); ++j) + px1->b[j] = px0->b[sizeof(T) - j - 1]; } +} + +/** + * Calculates sizeof() types named by string names of types in + * nupic/types/types. Throws if the requested type cannot be found. + * + * Supported type names include: + * bool, + * char, wchar_t, + * NTA_Char, NTA_WChar, NTA_Byte, + * float, double, + * NTA_Real32, NTA_Real64, NTA_Real, + * int, size_t, + * NTA_Int32, NTA_UInt32, NTA_Int64, NTA_UInt64, NTA_Size + * + * @param name (string) Name of type to calculate sizeof() for. + * @param isNumeric (bool&) set to true on exit if type name is a number. + * @retval Number of bytes per element of the specified type. + */ +extern size_t GetTypeSize(const std::string &name, bool &isNumeric); + +/** + * Calculates sizeof() types named by string names of types in + * nupic/types/types. Throws if the requested type cannot be found. + * + * Supported type names include: + * bool, + * char, wchar_t, + * NTA_Char, NTA_WChar, NTA_Byte, + * float, double, + * NTA_Real32, NTA_Real64, NTA_Real, + * int, size_t, + * NTA_Int32, NTA_UInt32, NTA_Int64, NTA_UInt64, NTA_Size + * + * @param name (string) Name of type to calculate sizeof() for. + * @param isNumeric (bool&) set to true on exit if type name is a number. + * @retval Number of bytes per element of the specified type. 
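GetTypeSize() and GetTypeName() form a small run-time introspection API over the NTA_BasicType names listed above; a short usage sketch (the include of this header is assumed, and the wrapper function name is illustrative):

#include <cstddef>
#include <iostream>

void typeInfoSketch() {
  bool isNumeric = false;

  // Size lookup by type name, using one of the names enumerated in the comment.
  std::size_t bytes = nupic::GetTypeSize("NTA_Real32", isNumeric);
  std::cout << "NTA_Real32: " << bytes << " bytes, numeric=" << isNumeric
            << std::endl;

  // String name from the enum value.
  std::cout << nupic::GetTypeName(NTA_BasicType_Real32) << std::endl;
}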
+ */ +extern size_t GetTypeSize(NTA_BasicType type, bool &isNumeric); + +/** + * Return a string representation of an NTA_BasicType + * + * @param type the NTA_BasicType enum + * @retval name of the type as a string + */ +extern std::string GetTypeName(NTA_BasicType type); + +/** + * Utility routine used by PrintVariableArray to print array of a certain type + */ +template +inline void utilsPrintArray_(std::ostream &out, const void *theBeginP, + const void *theEndP) { + const T *beginP = (T *)theBeginP; + const T *endP = (T *)theEndP; + + for (; beginP != endP; ++beginP) + out << *beginP << " "; +} + +/** + * Utility routine for setting an array in memory of a certain type from a + * stream + * + * @param in the stream with values to put into the array + * @param theBeginP pointer to start of array in memory + * @param theEndP pointer to end of array in memory + * @retval true if successfully set all values + * + */ +template +inline void utilsSetArray_(std::istream &in, void *theBeginP, void *theEndP) { + T *beginP = (T *)theBeginP; + T *endP = (T *)theEndP; + + for (; beginP != endP && in.good(); ++beginP) + in >> *beginP; + if (beginP != endP && !in.eof()) + NTA_THROW << "UtilsSetArray() - error reading stream of values"; +} - /** - * Streams the contents of a variable array cast as the given type. - * - * This is used by the NodeProcessor when returing the value of an node's outputs in - * response to the "nodeOPrint" supervisor command, and also when returning the value of - * a node's output or parameters to the tools in response to a watch request. - * - * The caller must pass in either a dataType, elemSize, or both. If both are specified, - * then this routine will assert that the elemSize agrees with the given dataType. If - * dataType is not specified, then this routine will pick a most likely dataType given - * the elemSize. - * - * @param outStream [std::ostream] the stream to print to - * @param beginP [Byte*] pointer to the start of the variable - * @param endP [Byte*] pointer to first byte past the end of the variable - * @param dataType [std::string] the data type to print as (optional) - * @retval [std::string] the actual type the variable was printed as. This will - * always be dataType, unless the dataType was unrecognized. - * - * @b Exceptions: - * @li None. - */ - extern std::string PrintVariableArray (std::ostream& outStream, const Byte* beginP, - const Byte* endP, const std::string& dataType=""); - - /** - * Sets the contents of a variable array cast as the given type. - * - * This is used by the NodeProcessor when setting the value of an node's outputs in - * response to the "nodeOSet" supervisor command. - * - * @param inStream [std::istream] the stream to fetch the values from - * @param beginP [Byte*] pointer to the start of the variable - * @param endP [Byte*] pointer to first byte past the end of the variable - * @param dataType [std::string] the data type to set as - * - * @b Exceptions: - * @li None. 
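utilsSetArray_() and utilsPrintArray_() above are symmetric low-level text helpers: one parses whitespace-separated values into a raw buffer of a given element type, the other streams them back out. A small round-trip sketch (the wrapper function name is illustrative):

#include <sstream>

void arrayTextRoundTrip() {
  float data[3] = {0.0f, 0.0f, 0.0f};

  // Parse three whitespace-separated values into the raw buffer...
  std::istringstream in("0.5 1.5 2.5");
  nupic::utilsSetArray_<float>(in, data, data + 3);

  // ...then stream them back out, producing "0.5 1.5 2.5 ".
  std::ostringstream out;
  nupic::utilsPrintArray_<float>(out, data, data + 3);
}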
- */ - extern void SetVariableArray (std::istream& inStream, Byte* beginP, - Byte* endP, const std::string& dataType); - - //-------------------------------------------------------------------------------- - // Defines, used as code generators, to make the code more readable - -#define NO_DEFAULTS(X) private: X(); X(const X&); X& operator=(const X&); - - /** - * Puts Y in current scope - * Iterates on whole Z, which must have begin() and end() - */ -#define LOOP(X, Y, Z) \ - X::iterator Y; \ - X::iterator Y##beginXX = (Z).begin(); \ - X::iterator Y##endXX = (Z).end(); \ - for (Y = Y##beginXX; Y != Y##endXX; ++Y) - - /** - * Puts Y in current scope - * Z must have begin() - * Iterates on partial Z, between Z.begin() and Z.begin() + L - */ -#define PARTIAL_LOOP(X, Y, Z, L) \ - X::iterator Y; \ - X::iterator Y##beginXX = (Z).begin(); \ - X::iterator Y##endXX = (Z).begin() + (L); \ - for (Y = Y##beginXX; Y != Y##endXX; ++Y) - - /** - * Puts Y in current scope - * Iterates on whole Z, with a const_iterator - */ -#define CONST_LOOP(X, Y, Z) \ - X::const_iterator Y; \ - X::const_iterator Y##beginXX = (Z).begin(); \ - X::const_iterator Y##endXX = (Z).end(); \ - for (Y = Y##beginXX; Y != Y##endXX; ++Y) - - /** - * Puts Y in current scope - * Iterates from Y to Z by steps of 1 - */ -#define ITER(X, Y, Z) \ - Size X##minXX = (Y), X##maxXX = (Z); \ - for (Size X = X##minXX; X < X##maxXX; ++X) - - /** - * Puts Y1 and Y2 in current scope - * Iterates X1 from Y1 to Z1 and X2 from Y2 to Z2 - * X2 is the inner index - */ -#define ITER2(X1, X2, Y1, Y2, Z1, Z2) \ - UInt X1##minXX = (Y1), X1##maxXX = (Z1), \ - X2##minXX = (Y2), X2##maxXX = (Z2); \ - for (Size X1 = X1##minXX; X1 < X1##maxXX; ++X1) \ - for (Size X2 = X2##minXX; X2 < X2##maxXX; ++X2) - - /** - * Iterates with a single index, from 0 to M. - */ -#define ITER_1(M) \ - for (UInt i = 0; i < M; ++i) - - /** - * Iterates over 2 indices, from 0 to M, and 0 to N. - */ -#define ITER_2(M, N) \ - for (UInt i = 0; i < M; ++i) \ - for (UInt j = 0; j < N; ++j) - - /** - * Iterates over 3 indices, from 0 to M, 0 to N, and 0 to P. - */ -#define ITER_3(M, N, P) \ - for (UInt i = 0; i < M; ++i) \ - for (UInt j = 0; j < N; ++j) \ +/** + * Streams the contents of a variable array cast as the given type. + * + * This is used by the NodeProcessor when returing the value of an node's + * outputs in response to the "nodeOPrint" supervisor command, and also when + * returning the value of a node's output or parameters to the tools in response + * to a watch request. + * + * The caller must pass in either a dataType, elemSize, or both. If both are + * specified, then this routine will assert that the elemSize agrees with the + * given dataType. If dataType is not specified, then this routine will pick a + * most likely dataType given the elemSize. + * + * @param outStream [std::ostream] the stream to print to + * @param beginP [Byte*] pointer to the start of the variable + * @param endP [Byte*] pointer to first byte past the end of the variable + * @param dataType [std::string] the data type to print as (optional) + * @retval [std::string] the actual type the variable was printed as. This will + * always be dataType, unless the dataType was unrecognized. + * + * @b Exceptions: + * @li None. + */ +extern std::string PrintVariableArray(std::ostream &outStream, + const Byte *beginP, const Byte *endP, + const std::string &dataType = ""); + +/** + * Sets the contents of a variable array cast as the given type. 
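PrintVariableArray() is the string-typed front end over those helpers: it takes a raw byte range plus a type name and dispatches to the matching element printer. A sketch that prints a Real32 buffer (assuming "NTA_Real32" is accepted, per the type-name list documented for GetTypeSize() above; the wrapper function name is illustrative):

#include <iostream>
#include <string>

void printVariableArraySketch() {
  nupic::Real32 values[4] = {0.5f, 1.5f, 2.5f, 3.5f};

  const nupic::Byte *begin = reinterpret_cast<const nupic::Byte *>(values);
  const nupic::Byte *end = begin + sizeof(values);

  // Streams the four Real32 values; the return value reports the type that
  // was actually used for printing.
  std::string used =
      nupic::PrintVariableArray(std::cout, begin, end, "NTA_Real32");
  std::cout << "\nprinted as " << used << std::endl;
}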
+ * + * This is used by the NodeProcessor when setting the value of an node's outputs + * in response to the "nodeOSet" supervisor command. + * + * @param inStream [std::istream] the stream to fetch the values from + * @param beginP [Byte*] pointer to the start of the variable + * @param endP [Byte*] pointer to first byte past the end of the variable + * @param dataType [std::string] the data type to set as + * + * @b Exceptions: + * @li None. + */ +extern void SetVariableArray(std::istream &inStream, Byte *beginP, Byte *endP, + const std::string &dataType); + +//-------------------------------------------------------------------------------- +// Defines, used as code generators, to make the code more readable + +#define NO_DEFAULTS(X) \ +private: \ + X(); \ + X(const X &); \ + X &operator=(const X &); + +/** + * Puts Y in current scope + * Iterates on whole Z, which must have begin() and end() + */ +#define LOOP(X, Y, Z) \ + X::iterator Y; \ + X::iterator Y##beginXX = (Z).begin(); \ + X::iterator Y##endXX = (Z).end(); \ + for (Y = Y##beginXX; Y != Y##endXX; ++Y) + +/** + * Puts Y in current scope + * Z must have begin() + * Iterates on partial Z, between Z.begin() and Z.begin() + L + */ +#define PARTIAL_LOOP(X, Y, Z, L) \ + X::iterator Y; \ + X::iterator Y##beginXX = (Z).begin(); \ + X::iterator Y##endXX = (Z).begin() + (L); \ + for (Y = Y##beginXX; Y != Y##endXX; ++Y) + +/** + * Puts Y in current scope + * Iterates on whole Z, with a const_iterator + */ +#define CONST_LOOP(X, Y, Z) \ + X::const_iterator Y; \ + X::const_iterator Y##beginXX = (Z).begin(); \ + X::const_iterator Y##endXX = (Z).end(); \ + for (Y = Y##beginXX; Y != Y##endXX; ++Y) + +/** + * Puts Y in current scope + * Iterates from Y to Z by steps of 1 + */ +#define ITER(X, Y, Z) \ + Size X##minXX = (Y), X##maxXX = (Z); \ + for (Size X = X##minXX; X < X##maxXX; ++X) + +/** + * Puts Y1 and Y2 in current scope + * Iterates X1 from Y1 to Z1 and X2 from Y2 to Z2 + * X2 is the inner index + */ +#define ITER2(X1, X2, Y1, Y2, Z1, Z2) \ + UInt X1##minXX = (Y1), X1##maxXX = (Z1), X2##minXX = (Y2), X2##maxXX = (Z2); \ + for (Size X1 = X1##minXX; X1 < X1##maxXX; ++X1) \ + for (Size X2 = X2##minXX; X2 < X2##maxXX; ++X2) + +/** + * Iterates with a single index, from 0 to M. + */ +#define ITER_1(M) for (UInt i = 0; i < M; ++i) + +/** + * Iterates over 2 indices, from 0 to M, and 0 to N. + */ +#define ITER_2(M, N) \ + for (UInt i = 0; i < M; ++i) \ + for (UInt j = 0; j < N; ++j) + +/** + * Iterates over 3 indices, from 0 to M, 0 to N, and 0 to P. + */ +#define ITER_3(M, N, P) \ + for (UInt i = 0; i < M; ++i) \ + for (UInt j = 0; j < N; ++j) \ for (UInt k = 0; k < P; ++k) - - /** - * Iterates over 4 indices, from 0 to M, 0 to N, 0 to P and 0 to Q. - */ -#define ITER_4(M, N, P, Q) \ - for (UInt i = 0; i < M; ++i) \ - for (UInt j = 0; j < N; ++j) \ - for (UInt k = 0; k < P; ++k) \ + +/** + * Iterates over 4 indices, from 0 to M, 0 to N, 0 to P and 0 to Q. + */ +#define ITER_4(M, N, P, Q) \ + for (UInt i = 0; i < M; ++i) \ + for (UInt j = 0; j < N; ++j) \ + for (UInt k = 0; k < P; ++k) \ for (UInt l = 0; l < Q; ++l) - /** - * Iterates over 5 indices. - */ -#define ITER_5(M, N, P, Q, R) \ - for (UInt i = 0; i < M; ++i) \ - for (UInt j = 0; j < N; ++j) \ - for (UInt k = 0; k < P; ++k) \ - for (UInt l = 0; l < Q; ++l) \ +/** + * Iterates over 5 indices. 
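The LOOP and ITER_* macros above expand to plain nested counted loops with fixed index names (i, j, k, ...), so a braced statement written after the macro becomes the innermost loop body. A sketch using ITER_2 (sumMatrix is an illustrative name):

#include <vector>

using nupic::UInt; // the macros declare their indices as UInt

inline UInt sumMatrix(const std::vector<std::vector<UInt>> &m, UInt rows,
                      UInt cols) {
  UInt total = 0;
  // ITER_2(rows, cols) expands to:
  //   for (UInt i = 0; i < rows; ++i)
  //     for (UInt j = 0; j < cols; ++j)
  ITER_2(rows, cols) { total += m[i][j]; }
  return total;
}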
+ */ +#define ITER_5(M, N, P, Q, R) \ + for (UInt i = 0; i < M; ++i) \ + for (UInt j = 0; j < N; ++j) \ + for (UInt k = 0; k < P; ++k) \ + for (UInt l = 0; l < Q; ++l) \ for (UInt m = 0; m < R; ++m) - /** - * Iterates over 6 indices. - */ -#define ITER_6(M, N, P, Q, R, S) \ - for (UInt i = 0; i < M; ++i) \ - for (UInt j = 0; j < N; ++j) \ - for (UInt k = 0; k < P; ++k) \ - for (UInt l = 0; l < Q; ++l) \ - for (UInt m = 0; m < R; ++m) \ +/** + * Iterates over 6 indices. + */ +#define ITER_6(M, N, P, Q, R, S) \ + for (UInt i = 0; i < M; ++i) \ + for (UInt j = 0; j < N; ++j) \ + for (UInt k = 0; k < P; ++k) \ + for (UInt l = 0; l < Q; ++l) \ + for (UInt m = 0; m < R; ++m) \ for (UInt n = 0; n < S; ++n) - - /** - * Function object that takes a single argument, a pair (or at least - * a class with the same interface as pair), and returns the pair's - * first element. This is not part of the C++ standard, but usually - * provided by implementations of STL. - */ - template - struct select1st - : public std::unary_function - { - inline const typename Pair::first_type& operator()(const Pair& x) const - { - return x.first; - } - }; - - /** - * Function object that takes a single argument, a pair (or at least - * a class with the same interface as pair), and returns the pair's - * second element. This is not part of the C++ standard, but usually - * provided by implementations of STL. - */ - template - struct select2nd - : public std::unary_function - { - inline const typename Pair::second_type& operator()(const Pair& x) const - { - return x.second; - } - }; -}; // namespace std +/** + * Function object that takes a single argument, a pair (or at least + * a class with the same interface as pair), and returns the pair's + * first element. This is not part of the C++ standard, but usually + * provided by implementations of STL. + */ +template +struct select1st : public std::unary_function { + inline const typename Pair::first_type &operator()(const Pair &x) const { + return x.first; + } +}; + +/** + * Function object that takes a single argument, a pair (or at least + * a class with the same interface as pair), and returns the pair's + * second element. This is not part of the C++ standard, but usually + * provided by implementations of STL. 
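select1st pairs naturally with the classic STL algorithms, for example copying the keys out of a std::map with std::transform (pre-C++17 only, since std::unary_function has since been removed; mapKeys is an illustrative name):

#include <algorithm>
#include <iterator>
#include <map>
#include <string>
#include <vector>

inline std::vector<std::string> mapKeys(const std::map<std::string, int> &m) {
  std::vector<std::string> keys;
  std::transform(m.begin(), m.end(), std::back_inserter(keys),
                 nupic::select1st<std::map<std::string, int>::value_type>());
  return keys;
}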
+ */ +template +struct select2nd + : public std::unary_function { + inline const typename Pair::second_type &operator()(const Pair &x) const { + return x.second; + } +}; -#endif // NTA_UTILS_HPP +}; // namespace nupic +#endif // NTA_UTILS_HPP diff --git a/src/nupic/ntypes/Array.hpp b/src/nupic/ntypes/Array.hpp index dd1c8b973d..9e51d69ea5 100644 --- a/src/nupic/ntypes/Array.hpp +++ b/src/nupic/ntypes/Array.hpp @@ -23,7 +23,7 @@ // --- // // Definitions for the Array class -// +// // It is a sub-class of ArrayBase that owns its buffer // // --- @@ -34,34 +34,27 @@ #include #include -namespace nupic -{ - class Array : public ArrayBase - { - public: - Array(NTA_BasicType type, void * buffer, size_t count) : - ArrayBase(type, buffer, count) - { - } - - explicit Array(NTA_BasicType type) : ArrayBase(type) - { - } +namespace nupic { +class Array : public ArrayBase { +public: + Array(NTA_BasicType type, void *buffer, size_t count) + : ArrayBase(type, buffer, count) {} - //Array(const Array & other) : ArrayBase(other) - //{ - //} + explicit Array(NTA_BasicType type) : ArrayBase(type) {} - void invariant() - { - if (!own_) - NTA_THROW << "Array must own its buffer"; - } - private: - // Hide base class method (invalid for Array) - void setBuffer(void * buffer, size_t count); - }; -} + // Array(const Array & other) : ArrayBase(other) + //{ + //} -#endif + void invariant() { + if (!own_) + NTA_THROW << "Array must own its buffer"; + } +private: + // Hide base class method (invalid for Array) + void setBuffer(void *buffer, size_t count); +}; +} // namespace nupic + +#endif diff --git a/src/nupic/ntypes/ArrayBase.cpp b/src/nupic/ntypes/ArrayBase.cpp index d959e14d28..0aeec16b00 100644 --- a/src/nupic/ntypes/ArrayBase.cpp +++ b/src/nupic/ntypes/ArrayBase.cpp @@ -27,9 +27,9 @@ #include // for ostream #include // for size_t -#include -#include #include +#include +#include #include using namespace nupic; @@ -39,12 +39,11 @@ using namespace nupic; * NuPIC always copies data into this buffer * Caller frees buffer when no longer needed. */ -ArrayBase::ArrayBase(NTA_BasicType type, void* buffer, size_t count) : - buffer_((char*)buffer), count_(count), type_(type), own_(false) -{ - if(!BasicType::isValid(type)) - { - NTA_THROW << "Invalid NTA_BasicType " << type << " used in array constructor"; +ArrayBase::ArrayBase(NTA_BasicType type, void *buffer, size_t count) + : buffer_((char *)buffer), count_(count), type_(type), own_(false) { + if (!BasicType::isValid(type)) { + NTA_THROW << "Invalid NTA_BasicType " << type + << " used in array constructor"; } } @@ -53,12 +52,11 @@ ArrayBase::ArrayBase(NTA_BasicType type, void* buffer, size_t count) : * Nupic will either provide a buffer via setBuffer or * ask the ArrayBase to allocate a buffer via allocateBuffer. */ -ArrayBase::ArrayBase(NTA_BasicType type) : - buffer_(nullptr), count_(0), type_(type), own_(false) -{ - if(!BasicType::isValid(type)) - { - NTA_THROW << "Invalid NTA_BasicType " << type << " used in array constructor"; +ArrayBase::ArrayBase(NTA_BasicType type) + : buffer_(nullptr), count_(0), type_(type), own_(false) { + if (!BasicType::isValid(type)) { + NTA_THROW << "Invalid NTA_BasicType " << type + << " used in array constructor"; } } @@ -66,46 +64,36 @@ ArrayBase::ArrayBase(NTA_BasicType type) : * The destructor calls releaseBuffer() to make sure the ArrayBase * doesn't leak. 
*/ -ArrayBase::~ArrayBase() -{ - releaseBuffer(); -} +ArrayBase::~ArrayBase() { releaseBuffer(); } /** * Ask ArrayBase to allocate its buffer */ -void -ArrayBase::allocateBuffer(size_t count) -{ - if (buffer_ != nullptr) - { - NTA_THROW << "allocateBuffer -- buffer already set. Use releaseBuffer first"; +void ArrayBase::allocateBuffer(size_t count) { + if (buffer_ != nullptr) { + NTA_THROW + << "allocateBuffer -- buffer already set. Use releaseBuffer first"; } count_ = count; - //Note that you can allocate a buffer of size zero. - //The C++ spec (5.3.4/7) requires such a new request to return - //a non-NULL value which is safe to delete. This allows us to - //disambiguate uninitialized ArrayBases and ArrayBases initialized with - //size zero. + // Note that you can allocate a buffer of size zero. + // The C++ spec (5.3.4/7) requires such a new request to return + // a non-NULL value which is safe to delete. This allows us to + // disambiguate uninitialized ArrayBases and ArrayBases initialized with + // size zero. buffer_ = new char[count_ * BasicType::getSize(type_)]; own_ = true; } -void -ArrayBase::setBuffer(void *buffer, size_t count) -{ - if (buffer_ != nullptr) - { +void ArrayBase::setBuffer(void *buffer, size_t count) { + if (buffer_ != nullptr) { NTA_THROW << "setBuffer -- buffer already set. Use releaseBuffer first"; } - buffer_ = (char*)buffer; + buffer_ = (char *)buffer; count_ = count; own_ = false; } -void -ArrayBase::releaseBuffer() -{ +void ArrayBase::releaseBuffer() { if (buffer_ == nullptr) return; if (own_) @@ -114,87 +102,65 @@ ArrayBase::releaseBuffer() count_ = 0; } -void* -ArrayBase::getBuffer() const -{ - return buffer_; -} +void *ArrayBase::getBuffer() const { return buffer_; } // number of elements of given type in the buffer -size_t -ArrayBase::getCount() const -{ - return count_; -}; - -NTA_BasicType -ArrayBase::getType() const -{ - return type_; -}; - - - -namespace nupic -{ - std::ostream& operator<<(std::ostream& outStream, const ArrayBase& a) - { - auto const inbuf = a.getBuffer(); - auto const numElements = a.getCount(); - auto const elementType = a.getType(); - - switch (elementType) - { - case NTA_BasicType_Byte: - ArrayBase::_templatedStreamBuffer(outStream, inbuf, +size_t ArrayBase::getCount() const { return count_; }; + +NTA_BasicType ArrayBase::getType() const { return type_; }; + +namespace nupic { +std::ostream &operator<<(std::ostream &outStream, const ArrayBase &a) { + auto const inbuf = a.getBuffer(); + auto const numElements = a.getCount(); + auto const elementType = a.getType(); + + switch (elementType) { + case NTA_BasicType_Byte: + ArrayBase::_templatedStreamBuffer(outStream, inbuf, numElements); + break; + case NTA_BasicType_Int16: + ArrayBase::_templatedStreamBuffer(outStream, inbuf, numElements); + break; + case NTA_BasicType_UInt16: + ArrayBase::_templatedStreamBuffer(outStream, inbuf, + numElements); + break; + case NTA_BasicType_Int32: + ArrayBase::_templatedStreamBuffer(outStream, inbuf, numElements); + break; + case NTA_BasicType_UInt32: + ArrayBase::_templatedStreamBuffer(outStream, inbuf, numElements); - break; - case NTA_BasicType_Int16: - ArrayBase::_templatedStreamBuffer(outStream, inbuf, - numElements); - break; - case NTA_BasicType_UInt16: - ArrayBase::_templatedStreamBuffer(outStream, inbuf, - numElements); - break; - case NTA_BasicType_Int32: - ArrayBase::_templatedStreamBuffer(outStream, inbuf, - numElements); - break; - case NTA_BasicType_UInt32: - ArrayBase::_templatedStreamBuffer(outStream, inbuf, - numElements); - break; 
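Two usage styles fall out of the buffer methods above: the caller can hand in pre-existing storage (zero-copy, own_ stays false), or ask the object to allocate storage that it frees in its destructor. A short sketch (include paths and the wrapper function name are assumptions):

#include <cstddef>
#include <nupic/ntypes/Array.hpp>
#include <nupic/ntypes/ArrayBase.hpp>

void arrayOwnershipSketch() {
  using namespace nupic;

  // Caller-owned storage: the ArrayBase only records the pointer.
  Real32 external[4] = {0.0f, 1.0f, 2.0f, 3.0f};
  ArrayBase view(NTA_BasicType_Real32, external, 4);

  // Library-owned storage: Array allocates, and releases it when destroyed.
  Array owned(NTA_BasicType_Real32);
  owned.allocateBuffer(4);
  Real32 *buf = static_cast<Real32 *>(owned.getBuffer());
  for (std::size_t i = 0; i < owned.getCount(); ++i)
    buf[i] = static_cast<Real32>(i);
}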
- case NTA_BasicType_Int64: - ArrayBase::_templatedStreamBuffer(outStream, inbuf, - numElements); - break; - case NTA_BasicType_UInt64: - ArrayBase::_templatedStreamBuffer(outStream, inbuf, - numElements); - break; - case NTA_BasicType_Real32: - ArrayBase::_templatedStreamBuffer(outStream, inbuf, - numElements); - break; - case NTA_BasicType_Real64: - ArrayBase::_templatedStreamBuffer(outStream, inbuf, - numElements); - break; - case NTA_BasicType_Handle: - ArrayBase::_templatedStreamBuffer(outStream, inbuf, - numElements); - break; - case NTA_BasicType_Bool: - ArrayBase::_templatedStreamBuffer(outStream, inbuf, - numElements); - break; - default: - NTA_THROW << "Unexpected Element Type: " << elementType; - break; - } - - return outStream; + break; + case NTA_BasicType_Int64: + ArrayBase::_templatedStreamBuffer(outStream, inbuf, numElements); + break; + case NTA_BasicType_UInt64: + ArrayBase::_templatedStreamBuffer(outStream, inbuf, + numElements); + break; + case NTA_BasicType_Real32: + ArrayBase::_templatedStreamBuffer(outStream, inbuf, + numElements); + break; + case NTA_BasicType_Real64: + ArrayBase::_templatedStreamBuffer(outStream, inbuf, + numElements); + break; + case NTA_BasicType_Handle: + ArrayBase::_templatedStreamBuffer(outStream, inbuf, + numElements); + break; + case NTA_BasicType_Bool: + ArrayBase::_templatedStreamBuffer(outStream, inbuf, numElements); + break; + default: + NTA_THROW << "Unexpected Element Type: " << elementType; + break; } + return outStream; +} + } // namespace nupic diff --git a/src/nupic/ntypes/ArrayBase.hpp b/src/nupic/ntypes/ArrayBase.hpp index 009dfa6b40..fd1c51f197 100644 --- a/src/nupic/ntypes/ArrayBase.hpp +++ b/src/nupic/ntypes/ArrayBase.hpp @@ -22,15 +22,15 @@ /** @file * Definitions for the ArrayBase class - * - * An ArrayBase object contains a memory buffer that is used for - * implementing zero-copy and one-copy operations in NuPIC. - * An ArrayBase contains: - * - a pointer to a buffer - * - a length - * - a type - * - a flag indicating whether or not the object owns the buffer. - */ + * + * An ArrayBase object contains a memory buffer that is used for + * implementing zero-copy and one-copy operations in NuPIC. + * An ArrayBase contains: + * - a pointer to a buffer + * - a length + * - a type + * - a flag indicating whether or not the object owns the buffer. + */ #ifndef NTA_ARRAY_BASE_HPP #define NTA_ARRAY_BASE_HPP @@ -41,106 +41,92 @@ #include -namespace nupic -{ +namespace nupic { +/** + * An ArrayBase is used for passing arrays of data back and forth between + * a client application and NuPIC, minimizing copying. It facilitates + * both zero-copy and one-copy operations. + */ +class ArrayBase { +public: + /** + * Caller provides a buffer to use. + * NuPIC always copies data into this buffer + * Caller frees buffer when no longer needed. + */ + ArrayBase(NTA_BasicType type, void *buffer, size_t count); + + /** + * Caller does not provide a buffer -- + * Nupic will either provide a buffer via setBuffer or + * ask the ArrayBase to allocate a buffer via allocateBuffer. + */ + explicit ArrayBase(NTA_BasicType type); + + /** + * The destructor ensures the array doesn't leak its buffer (if + * it owns it). + */ + virtual ~ArrayBase(); + /** - * An ArrayBase is used for passing arrays of data back and forth between - * a client application and NuPIC, minimizing copying. It facilitates - * both zero-copy and one-copy operations. + * Ask ArrayBase to allocate its buffer */ - class ArrayBase - { - public: - /** - * Caller provides a buffer to use. 
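Because operator<< above dispatches on getType(), an array can be dumped straight to a stream for diagnostics. A sketch (include path and the wrapper function name assumed; the parenthesized, comma-separated output format follows _templatedStreamBuffer shown just below):

#include <iostream>
#include <nupic/ntypes/Array.hpp>

void printArraySketch() {
  nupic::Array a(NTA_BasicType_UInt32);
  a.allocateBuffer(3);

  auto *p = static_cast<nupic::UInt32 *>(a.getBuffer());
  p[0] = 1;
  p[1] = 2;
  p[2] = 3;

  std::cout << a << std::endl; // prints "(1, 2, 3)"
}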
- * NuPIC always copies data into this buffer - * Caller frees buffer when no longer needed. - */ - ArrayBase(NTA_BasicType type, void* buffer, size_t count); - - /** - * Caller does not provide a buffer -- - * Nupic will either provide a buffer via setBuffer or - * ask the ArrayBase to allocate a buffer via allocateBuffer. - */ - explicit ArrayBase(NTA_BasicType type); - - /** - * The destructor ensures the array doesn't leak its buffer (if - * it owns it). - */ - virtual ~ArrayBase(); - - - /** - * Ask ArrayBase to allocate its buffer - */ - void - allocateBuffer(size_t count); - - void - setBuffer(void *buffer, size_t count); - - void - releaseBuffer(); - - void* - getBuffer() const; - - // number of elements of given type in the buffer - size_t - getCount() const; - - NTA_BasicType - getType() const; - - protected: - // buffer_ is typed so that we can use new/delete - // cast to/from void* as necessary - char* buffer_; - size_t count_; - NTA_BasicType type_; - bool own_; - - private: - /** - * Element-type-specific templated function for streaming elements to - * ostream. Elements are comma+space-separated and enclosed in braces. - * - * @param outStream output stream - * @param inbuf input buffer - * @param numElements number of elements to use from the beginning of buffer - */ - template - static void _templatedStreamBuffer(std::ostream& outStream, - const void* inbuf, - size_t numElements) - { - outStream << "("; - - // Stream the elements - auto it = (const SourceElementT*)inbuf; - auto const end = it + numElements; - if (it < end) - { - for (; it < end - 1; ++it) - { - outStream << *it << ", "; - } - - outStream << *it; // final element without the comma + void allocateBuffer(size_t count); + + void setBuffer(void *buffer, size_t count); + + void releaseBuffer(); + + void *getBuffer() const; + + // number of elements of given type in the buffer + size_t getCount() const; + + NTA_BasicType getType() const; + +protected: + // buffer_ is typed so that we can use new/delete + // cast to/from void* as necessary + char *buffer_; + size_t count_; + NTA_BasicType type_; + bool own_; + +private: + /** + * Element-type-specific templated function for streaming elements to + * ostream. Elements are comma+space-separated and enclosed in braces. 
+ * + * @param outStream output stream + * @param inbuf input buffer + * @param numElements number of elements to use from the beginning of buffer + */ + template + static void _templatedStreamBuffer(std::ostream &outStream, const void *inbuf, + size_t numElements) { + outStream << "("; + + // Stream the elements + auto it = (const SourceElementT *)inbuf; + auto const end = it + numElements; + if (it < end) { + for (; it < end - 1; ++it) { + outStream << *it << ", "; } - outStream << ")"; + outStream << *it; // final element without the comma } - friend std::ostream& operator<<(std::ostream&, const ArrayBase&); - }; + outStream << ")"; + } - // Serialization for diagnostic purposes - std::ostream& operator<<(std::ostream&, const ArrayBase&); + friend std::ostream &operator<<(std::ostream &, const ArrayBase &); +}; -} +// Serialization for diagnostic purposes +std::ostream &operator<<(std::ostream &, const ArrayBase &); -#endif +} // namespace nupic +#endif diff --git a/src/nupic/ntypes/ArrayRef.hpp b/src/nupic/ntypes/ArrayRef.hpp index 07dbd51ea2..2d606ee51d 100644 --- a/src/nupic/ntypes/ArrayRef.hpp +++ b/src/nupic/ntypes/ArrayRef.hpp @@ -23,7 +23,7 @@ // --- // // Definitions for the ArrayRef class -// +// // It is a sub-class of ArrayBase that doesn't own its buffer // // --- @@ -34,34 +34,26 @@ #include #include -namespace nupic -{ - class ArrayRef : public ArrayBase - { - public: - ArrayRef(NTA_BasicType type, void * buffer, size_t count) : ArrayBase(type) - { - setBuffer(buffer, count); - } - - explicit ArrayRef(NTA_BasicType type) : ArrayBase(type) - { - } +namespace nupic { +class ArrayRef : public ArrayBase { +public: + ArrayRef(NTA_BasicType type, void *buffer, size_t count) : ArrayBase(type) { + setBuffer(buffer, count); + } - ArrayRef(const ArrayRef & other) : ArrayBase(other) - { - } - - void invariant() - { - if (own_) - NTA_THROW << "ArrayRef mmust not own its buffer"; - } - private: - // Hide base class method (invalid for ArrayRef) - void allocateBuffer(void * buffer, size_t count); - }; -} + explicit ArrayRef(NTA_BasicType type) : ArrayBase(type) {} -#endif + ArrayRef(const ArrayRef &other) : ArrayBase(other) {} + + void invariant() { + if (own_) + NTA_THROW << "ArrayRef mmust not own its buffer"; + } +private: + // Hide base class method (invalid for ArrayRef) + void allocateBuffer(void *buffer, size_t count); +}; +} // namespace nupic + +#endif diff --git a/src/nupic/ntypes/Buffer.cpp b/src/nupic/ntypes/Buffer.cpp index 06dd20feba..a929c8ccd8 100644 --- a/src/nupic/ntypes/Buffer.cpp +++ b/src/nupic/ntypes/Buffer.cpp @@ -20,804 +20,696 @@ * --------------------------------------------------------------------- */ -/** @file -*/ +/** @file + */ +#include +#include #include #include #include -#include -#include -namespace nupic -{ - - // ----------------------------------------- - // - // R E A D B U F F E R - // - // ----------------------------------------- - - NTA_Size staticReadBufferGetSize(NTA_ReadBufferHandle handle) - { - NTA_CHECK(handle != nullptr); - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->getSize(); - } +namespace nupic { - const NTA_Byte * staticGetData(NTA_ReadBufferHandle handle) - { - NTA_CHECK(handle != nullptr); - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->getData(); - } +// ----------------------------------------- +// +// R E A D B U F F E R +// +// ----------------------------------------- - void staticReset(NTA_ReadBufferHandle handle) - { - NTA_CHECK(handle != nullptr); - - ReadBuffer * rb = reinterpret_cast(handle); 
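ArrayRef above is the non-owning counterpart to Array: it records a caller-provided buffer via setBuffer() and its invariant() insists it never ends up owning it. A sketch (include path and the wrapper function name assumed):

#include <nupic/ntypes/ArrayRef.hpp>

void arrayRefSketch() {
  nupic::Real32 data[2] = {0.5f, 1.5f};

  // Wraps the caller's storage; the ArrayRef never frees it.
  nupic::ArrayRef ref(NTA_BasicType_Real32, data, 2);
  ref.invariant(); // would throw if the ref somehow owned its buffer
}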
- return rb->reset(); - } - - static NTA_Int32 staticReadByte(NTA_ReadBufferHandle handle, NTA_Byte * value) - { - if (!handle || !value) - return -1; - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->read(*value); - } +NTA_Size staticReadBufferGetSize(NTA_ReadBufferHandle handle) { + NTA_CHECK(handle != nullptr); - NTA_Int32 staticReadByteArray(NTA_ReadBufferHandle handle, NTA_Byte * value, NTA_Size * size) - { - if (!handle || !value || !size || *size <= 0) - return -1; - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->read(value, *size); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->getSize(); +} +const NTA_Byte *staticGetData(NTA_ReadBufferHandle handle) { + NTA_CHECK(handle != nullptr); - NTA_Int32 staticReadString(NTA_ReadBufferHandle handle, - NTA_Byte ** value, - NTA_UInt32 * size, - NTA_Byte *(*fAlloc)(NTA_UInt32), - void (*fDealloc)(NTA_Byte *) - ) - { - if (!handle || !value) { - return -1; - } - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->readString(*value, *size, fAlloc, fDealloc); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->getData(); +} +void staticReset(NTA_ReadBufferHandle handle) { + NTA_CHECK(handle != nullptr); - static NTA_Int32 staticReadUInt32(NTA_ReadBufferHandle handle, NTA_UInt32 * value) - { - if (!handle || !value) - return -1; - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->read(*value); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->reset(); +} - NTA_Int32 staticReadUInt32Array(NTA_ReadBufferHandle handle, NTA_UInt32 * value, NTA_Size size) - { - if (!handle || !value || size <= 0) - return -1; - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->read(value, size); - } - - static NTA_Int32 staticReadInt32(NTA_ReadBufferHandle handle, NTA_Int32 * value) - { - if (!handle || !value) - return -1; - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->read(*value); - } +static NTA_Int32 staticReadByte(NTA_ReadBufferHandle handle, NTA_Byte *value) { + if (!handle || !value) + return -1; - NTA_Int32 staticReadInt32Array(NTA_ReadBufferHandle handle, NTA_Int32 * value, NTA_Size size) - { - if (!handle || !value || size <= 0) - return -1; - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->read(value, size); - } - - static NTA_Int32 staticReadUInt64(NTA_ReadBufferHandle handle, NTA_UInt64 * value) - { - if (!handle || !value) - return -1; - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->read(*value); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->read(*value); +} - NTA_Int32 staticReadUInt64Array(NTA_ReadBufferHandle handle, NTA_UInt64 * value, NTA_Size size) - { - if (!handle || !value || size <= 0) - return -1; - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->read(value, size); - } +NTA_Int32 staticReadByteArray(NTA_ReadBufferHandle handle, NTA_Byte *value, + NTA_Size *size) { + if (!handle || !value || !size || *size <= 0) + return -1; - static NTA_Int32 staticReadInt64(NTA_ReadBufferHandle handle, NTA_Int64 * value) - { - if (!handle || !value) - return -1; - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->read(*value); - } - - NTA_Int32 staticReadInt64Array(NTA_ReadBufferHandle handle, NTA_Int64 * value, NTA_Size size) - { - if (!handle || !value || size <= 0) - return -1; - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->read(value, size); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->read(value, *size); +} - static NTA_Int32 staticReadReal32(NTA_ReadBufferHandle 
handle, NTA_Real32 * value) - { - if (!handle || !value) - return -1; - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->read(*value); - } - - NTA_Int32 staticReadReal32Array(NTA_ReadBufferHandle handle, NTA_Real32 * value, NTA_Size size) - { - if (!handle || !value || size <= 0) - return -1; - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->read(value, size); +NTA_Int32 staticReadString(NTA_ReadBufferHandle handle, NTA_Byte **value, + NTA_UInt32 *size, NTA_Byte *(*fAlloc)(NTA_UInt32), + void (*fDealloc)(NTA_Byte *)) { + if (!handle || !value) { + return -1; } - static NTA_Int32 staticReadReal64(NTA_ReadBufferHandle handle, NTA_Real64 * value) - { - if (!handle || !value) - return -1; - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->read(*value); - } - - NTA_Int32 staticReadReal64Array(NTA_ReadBufferHandle handle, NTA_Real64 * value, NTA_Size size) - { - if (!handle || !value || size <= 0) - return -1; - - ReadBuffer * rb = reinterpret_cast(handle); - return rb->read(value, size); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->readString(*value, *size, fAlloc, fDealloc); +} - ReadBuffer::ReadBuffer(const char * bytes, Size size, bool copy) : - bytes_(copy ? new Byte[size] : nullptr), - memStream_(copy ? bytes_.get() : bytes, size) - { - // Copy the buffer to the internal bytes_ array (because - // MemStream needs persistent external storage if copy==true - if (copy) - ::memcpy(bytes_.get(), bytes, size); - - // Turn on exceptions for memStream_ - memStream_.exceptions(std::ostream::failbit | std::ostream::badbit); - - // Initialize the NTA_Readbuffer struct - handle = reinterpret_cast(this); - NTA_ReadBuffer::reset = staticReset; - NTA_ReadBuffer::getSize = staticReadBufferGetSize; - NTA_ReadBuffer::getData = staticGetData; - - readByte = staticReadByte; - readByteArray = staticReadByteArray; - readAsString = staticReadString; - - readInt32 = staticReadInt32; - readInt32Array = staticReadInt32Array; - readUInt32 = staticReadUInt32; - readUInt32Array = staticReadUInt32Array; - - readInt64 = staticReadInt64; - readInt64Array = staticReadInt64Array; - readUInt64 = staticReadUInt64; - readUInt64Array = staticReadUInt64Array; - - readReal32 = staticReadReal32; - readReal32Array = staticReadReal32Array; - readReal64 = staticReadReal64; - readReal64Array = staticReadReal64Array; - } - - ReadBuffer::ReadBuffer(const ReadBuffer & other) - { - assign(other); - } - - ReadBuffer & ReadBuffer::operator=(const ReadBuffer & other) - { - assign(other); - return *this; - } - - void ReadBuffer::assign(const ReadBuffer & other) - { - handle = reinterpret_cast(this); - NTA_ReadBuffer::reset = staticReset; - NTA_ReadBuffer::getSize = staticReadBufferGetSize; - NTA_ReadBuffer::getData = staticGetData; - - readByte = staticReadByte; - readByteArray = staticReadByteArray; - readAsString = staticReadString; - - readInt32 = staticReadInt32; - readInt32Array = staticReadInt32Array; - readUInt32 = staticReadUInt32; - readUInt32Array = staticReadUInt32Array; - - readInt64 = staticReadInt64; - readInt64Array = staticReadInt64Array; - readUInt64 = staticReadUInt64; - readUInt64Array = staticReadUInt64Array; - - readReal32 = staticReadReal32; - readReal32Array = staticReadReal32Array; - readReal64 = staticReadReal64; - readReal64Array = staticReadReal64Array; - - bytes_ = other.bytes_; - memStream_.str(bytes_.get(), other.getSize()); - } +static NTA_Int32 staticReadUInt32(NTA_ReadBufferHandle handle, + NTA_UInt32 *value) { + if (!handle || !value) + return -1; - void 
ReadBuffer::reset() const - { - IMemStream::memStreamBufType_ * s = static_cast(memStream_.rdbuf()); - s->setg(bytes_.get(), bytes_.get(), bytes_.get()+memStream_.pcount()); - memStream_.clear(); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->read(*value); +} - Size ReadBuffer::getSize() const - { - return (Size)memStream_.pcount(); - } +NTA_Int32 staticReadUInt32Array(NTA_ReadBufferHandle handle, NTA_UInt32 *value, + NTA_Size size) { + if (!handle || !value || size <= 0) + return -1; - const char * ReadBuffer::getData() const - { - return memStream_.str(); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->read(value, size); +} - Int32 ReadBuffer::read(Byte & value) const - { - return readT(value); - } +static NTA_Int32 staticReadInt32(NTA_ReadBufferHandle handle, + NTA_Int32 *value) { + if (!handle || !value) + return -1; - Int32 ReadBuffer::read(Byte * bytes, Size & size) const - { - ReadBuffer * r = const_cast(this); - try - { - size = r->memStream_.readsome(bytes, size); - return 0; - } - catch (...) - { - size = 0; - return -1; - } - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->read(*value); +} - Int32 ReadBuffer::read(Int32 & value) const - { - return readT(value); - } - - Int32 ReadBuffer::read(Int32 * value, Size size) const - { - return readT(value, size); - } +NTA_Int32 staticReadInt32Array(NTA_ReadBufferHandle handle, NTA_Int32 *value, + NTA_Size size) { + if (!handle || !value || size <= 0) + return -1; - Int32 ReadBuffer::read(UInt32 & value) const - { - return readT(value); - } - - Int32 ReadBuffer::read(UInt32 * value, Size size) const - { - return readT(value, size); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->read(value, size); +} - Int32 ReadBuffer::read(Int64 & value) const - { - return readT(value); - } - - Int32 ReadBuffer::read(Int64 * value, Size size) const - { - return readT(value, size); - } +static NTA_Int32 staticReadUInt64(NTA_ReadBufferHandle handle, + NTA_UInt64 *value) { + if (!handle || !value) + return -1; - Int32 ReadBuffer::read(UInt64 & value) const - { - return readT(value); - } - - Int32 ReadBuffer::read(UInt64 * value, Size size) const - { - return readT(value, size); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->read(*value); +} - Int32 ReadBuffer::read(Real32 & value) const - { - return readT(value); - } +NTA_Int32 staticReadUInt64Array(NTA_ReadBufferHandle handle, NTA_UInt64 *value, + NTA_Size size) { + if (!handle || !value || size <= 0) + return -1; - Int32 ReadBuffer::read(Real32 * value, Size size) const - { - return readT(value, size); - } - - Int32 ReadBuffer::read(Real64 & value) const - { - return readT(value); - } - - Int32 ReadBuffer::read(Real64 * value, Size size) const - { - return readT(value, size); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->read(value, size); +} - Int32 ReadBuffer::read(bool & value) const - { - return readT(value); - } +static NTA_Int32 staticReadInt64(NTA_ReadBufferHandle handle, + NTA_Int64 *value) { + if (!handle || !value) + return -1; - Int32 ReadBuffer::read(bool * value, Size size) const - { - return readT(value, size); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->read(*value); +} - inline Int32 findWithLeadingWhitespace(const ReadBuffer &r, char c, int maxSearch) - { - char dummy; - Int32 result; - for(int i=0; i(handle); + return rb->read(value, size); +} - typedef NTA_Byte *(*fp_alloc)(NTA_UInt32); - typedef void (*fp_dealloc)(NTA_Byte *); - - Int32 ReadBuffer::readString( - NTA_Byte * &value, 
- NTA_UInt32 &size, - fp_alloc fAlloc, - fp_dealloc fDealloc - ) const - { - NTA_ASSERT(fDealloc || !fAlloc); // Assume new/delete if unspecified. - value = nullptr; - size = 0; - Int32 result = findWithLeadingWhitespace(*this, "", 1); - } +static NTA_Int32 staticReadReal32(NTA_ReadBufferHandle handle, + NTA_Real32 *value) { + if (!handle || !value) + return -1; - // ------------------------------------------ - // - // R E A D B U F F E R I T E R A T O R - // - // -----------------------------------------= - static const NTA_ReadBuffer * staticNext(NTA_ReadBufferIteratorHandle handle) - { - NTA_CHECK(handle != nullptr); - - ReadBufferIterator * rbi = static_cast(reinterpret_cast(handle)); - return static_cast(rbi->next()); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->read(*value); +} - static void staticReset(NTA_ReadBufferIteratorHandle handle) - { - NTA_CHECK(handle != nullptr); - - ReadBufferIterator * rbi = static_cast(reinterpret_cast(handle)); - return rbi->reset(); - } +NTA_Int32 staticReadReal32Array(NTA_ReadBufferHandle handle, NTA_Real32 *value, + NTA_Size size) { + if (!handle || !value || size <= 0) + return -1; - ReadBufferIterator::ReadBufferIterator(ReadBufferVec & rbv) : - readBufferVec_(rbv), - index_(0) - { - // Initialize the NTA_ReadbufferIterator struct - NTA_ReadBufferIterator::handle = reinterpret_cast(static_cast(this)); - NTA_ReadBufferIterator::next = staticNext; - NTA_ReadBufferIterator::reset = staticReset; - } - - const IReadBuffer * ReadBufferIterator::next() - { - if (index_ == readBufferVec_.size()) - return nullptr; - - return readBufferVec_[index_++]; - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->read(value, size); +} - void ReadBufferIterator::reset() - { - index_ = 0; - } - // ----------------------------------------- - // - // W R I T E B U F F E R - // - // ----------------------------------------- - NTA_Int32 staticWriteUInt32(NTA_WriteBufferHandle handle, NTA_UInt32 value) - { - NTA_CHECK(handle != nullptr); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->write(value); - } - - NTA_Int32 staticWriteUInt32Array(NTA_WriteBufferHandle handle, const NTA_UInt32 * value, NTA_Size size) - { - NTA_CHECK(handle != nullptr); - NTA_CHECK(value != nullptr); - NTA_CHECK(size > 0); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->write(value, size); - } +static NTA_Int32 staticReadReal64(NTA_ReadBufferHandle handle, + NTA_Real64 *value) { + if (!handle || !value) + return -1; - NTA_Int32 staticWriteInt32(NTA_WriteBufferHandle handle, NTA_Int32 value) - { - NTA_CHECK(handle != nullptr); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->write(value); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->read(*value); +} - NTA_Int32 staticWriteInt32Array(NTA_WriteBufferHandle handle, const NTA_Int32 * value, NTA_Size size) - { - NTA_CHECK(handle != nullptr); - NTA_CHECK(value != nullptr); - NTA_CHECK(size > 0); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->write(value, size); - } +NTA_Int32 staticReadReal64Array(NTA_ReadBufferHandle handle, NTA_Real64 *value, + NTA_Size size) { + if (!handle || !value || size <= 0) + return -1; - NTA_Int32 staticWriteInt64(NTA_WriteBufferHandle handle, NTA_Int64 value) - { - NTA_CHECK(handle != nullptr); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->write(value); - } - - NTA_Int32 staticWriteInt64Array(NTA_WriteBufferHandle handle, const NTA_Int64 * value, NTA_Size size) - { - NTA_CHECK(handle != nullptr); - NTA_CHECK(value 
!= nullptr); - NTA_CHECK(size > 0); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->write(value, size); - } + ReadBuffer *rb = reinterpret_cast(handle); + return rb->read(value, size); +} - NTA_Int32 staticWriteUInt64(NTA_WriteBufferHandle handle, NTA_UInt64 value) - { - NTA_CHECK(handle != nullptr); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->write(value); - } - - NTA_Int32 staticWriteUInt64Array(NTA_WriteBufferHandle handle, const NTA_UInt64 * value, NTA_Size size) - { - NTA_CHECK(handle != nullptr); - NTA_CHECK(value != nullptr); - NTA_CHECK(size > 0); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->write(value, size); - } +ReadBuffer::ReadBuffer(const char *bytes, Size size, bool copy) + : bytes_(copy ? new Byte[size] : nullptr), + memStream_(copy ? bytes_.get() : bytes, size) { + // Copy the buffer to the internal bytes_ array (because + // MemStream needs persistent external storage if copy==true + if (copy) + ::memcpy(bytes_.get(), bytes, size); + + // Turn on exceptions for memStream_ + memStream_.exceptions(std::ostream::failbit | std::ostream::badbit); + + // Initialize the NTA_Readbuffer struct + handle = reinterpret_cast(this); + NTA_ReadBuffer::reset = staticReset; + NTA_ReadBuffer::getSize = staticReadBufferGetSize; + NTA_ReadBuffer::getData = staticGetData; + + readByte = staticReadByte; + readByteArray = staticReadByteArray; + readAsString = staticReadString; + + readInt32 = staticReadInt32; + readInt32Array = staticReadInt32Array; + readUInt32 = staticReadUInt32; + readUInt32Array = staticReadUInt32Array; + + readInt64 = staticReadInt64; + readInt64Array = staticReadInt64Array; + readUInt64 = staticReadUInt64; + readUInt64Array = staticReadUInt64Array; + + readReal32 = staticReadReal32; + readReal32Array = staticReadReal32Array; + readReal64 = staticReadReal64; + readReal64Array = staticReadReal64Array; +} - NTA_Int32 staticWriteReal32(NTA_WriteBufferHandle handle, NTA_Real32 value) - { - NTA_CHECK(handle != nullptr); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->write(value); - } - - NTA_Int32 staticWriteReal32Array(NTA_WriteBufferHandle handle, const NTA_Real32 * value, NTA_Size size) - { - NTA_CHECK(handle != nullptr); - NTA_CHECK(value != nullptr); - NTA_CHECK(size > 0); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->write(value, size); - } +ReadBuffer::ReadBuffer(const ReadBuffer &other) { assign(other); } - NTA_Int32 staticWriteReal64(NTA_WriteBufferHandle handle, NTA_Real64 value) - { - NTA_CHECK(handle != nullptr); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->write(value); - } +ReadBuffer &ReadBuffer::operator=(const ReadBuffer &other) { + assign(other); + return *this; +} - NTA_Int32 staticWriteReal64Array(NTA_WriteBufferHandle handle, const NTA_Real64 * value, NTA_Size size) - { - NTA_CHECK(handle != nullptr); - NTA_CHECK(value != nullptr); - NTA_CHECK(size > 0); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->write(value, size); - } +void ReadBuffer::assign(const ReadBuffer &other) { + handle = reinterpret_cast(this); + NTA_ReadBuffer::reset = staticReset; + NTA_ReadBuffer::getSize = staticReadBufferGetSize; + NTA_ReadBuffer::getData = staticGetData; + + readByte = staticReadByte; + readByteArray = staticReadByteArray; + readAsString = staticReadString; + + readInt32 = staticReadInt32; + readInt32Array = staticReadInt32Array; + readUInt32 = staticReadUInt32; + readUInt32Array = staticReadUInt32Array; + + readInt64 = staticReadInt64; + readInt64Array = 
staticReadInt64Array; + readUInt64 = staticReadUInt64; + readUInt64Array = staticReadUInt64Array; + + readReal32 = staticReadReal32; + readReal32Array = staticReadReal32Array; + readReal64 = staticReadReal64; + readReal64Array = staticReadReal64Array; + + bytes_ = other.bytes_; + memStream_.str(bytes_.get(), other.getSize()); +} - NTA_Int32 staticWriteByte(NTA_WriteBufferHandle handle, NTA_Byte value) - { - NTA_CHECK(handle != nullptr); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->write(value); - } +void ReadBuffer::reset() const { + IMemStream::memStreamBufType_ *s = + static_cast(memStream_.rdbuf()); + s->setg(bytes_.get(), bytes_.get(), bytes_.get() + memStream_.pcount()); + memStream_.clear(); +} - NTA_Int32 staticWriteByteArray(NTA_WriteBufferHandle handle, const NTA_Byte * value, NTA_Size size) - { - NTA_CHECK(handle != nullptr); - NTA_CHECK(value != nullptr); - NTA_CHECK(size > 0); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->write(value, size); - } +Size ReadBuffer::getSize() const { return (Size)memStream_.pcount(); } - NTA_Int32 staticWriteString(NTA_WriteBufferHandle handle, const NTA_Byte * value, NTA_Size size) - { - NTA_CHECK(handle != nullptr); - NTA_CHECK(value != nullptr); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->writeString(value, size); - } +const char *ReadBuffer::getData() const { return memStream_.str(); } - const Byte * staticGetData(NTA_WriteBufferHandle handle) - { - NTA_CHECK(handle != nullptr); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->getData(); - } - - NTA_Size staticWriteBufferGetSize(NTA_WriteBufferHandle handle) - { - NTA_CHECK(handle != nullptr); - - WriteBuffer * wb = reinterpret_cast(handle); - return wb->getSize(); - } +Int32 ReadBuffer::read(Byte &value) const { return readT(value); } - WriteBuffer::WriteBuffer() - { - handle = reinterpret_cast(this); - NTA_WriteBuffer::getData = staticGetData; - NTA_WriteBuffer::getSize = staticWriteBufferGetSize; - - writeByte = staticWriteByte; - writeByteArray = staticWriteByteArray; - writeAsString = staticWriteString; - - writeInt32 = staticWriteInt32; - writeInt32Array = staticWriteInt32Array; - writeUInt32 = staticWriteUInt32; - writeUInt32Array = staticWriteUInt32Array; - - writeInt64 = staticWriteInt64; - writeInt64Array = staticWriteInt64Array; - writeUInt64 = staticWriteUInt64; - writeUInt64Array = staticWriteUInt64Array; - - writeReal32 = staticWriteReal32; - writeReal32Array = staticWriteReal32Array; - writeReal64 = staticWriteReal64; - writeReal64Array = staticWriteReal64Array; - - OMemStream::exceptions(std::ostream::failbit | std::ostream::badbit); +Int32 ReadBuffer::read(Byte *bytes, Size &size) const { + ReadBuffer *r = const_cast(this); + try { + size = r->memStream_.readsome(bytes, size); + return 0; + } catch (...) 
{ + size = 0; + return -1; } +} - WriteBuffer::~WriteBuffer() - { - } +Int32 ReadBuffer::read(Int32 &value) const { return readT(value); } - Int32 WriteBuffer::write(Byte value) - { - return writeT(value); - } - - Int32 WriteBuffer::write(const Byte * bytes, Size size) - { - try - { - OMemStream::write(bytes, (std::streamsize)size); +Int32 ReadBuffer::read(Int32 *value, Size size) const { + return readT(value, size); +} + +Int32 ReadBuffer::read(UInt32 &value) const { return readT(value); } + +Int32 ReadBuffer::read(UInt32 *value, Size size) const { + return readT(value, size); +} + +Int32 ReadBuffer::read(Int64 &value) const { return readT(value); } + +Int32 ReadBuffer::read(Int64 *value, Size size) const { + return readT(value, size); +} + +Int32 ReadBuffer::read(UInt64 &value) const { return readT(value); } + +Int32 ReadBuffer::read(UInt64 *value, Size size) const { + return readT(value, size); +} + +Int32 ReadBuffer::read(Real32 &value) const { return readT(value); } + +Int32 ReadBuffer::read(Real32 *value, Size size) const { + return readT(value, size); +} + +Int32 ReadBuffer::read(Real64 &value) const { return readT(value); } + +Int32 ReadBuffer::read(Real64 *value, Size size) const { + return readT(value, size); +} + +Int32 ReadBuffer::read(bool &value) const { return readT(value); } + +Int32 ReadBuffer::read(bool *value, Size size) const { + return readT(value, size); +} + +inline Int32 findWithLeadingWhitespace(const ReadBuffer &r, char c, + int maxSearch) { + char dummy; + Int32 result; + for (int i = 0; i < maxSearch; ++i) { + dummy = 0; + result = r.readT(dummy); + if (result != 0) + return result; + if (dummy == c) return 0; - } - catch (...) - { + else if (!::isspace(dummy)) { return -1; - } + } else + NTA_CHECK(::isspace(dummy)); } + return -1; +} - Int32 WriteBuffer::write(Int32 value) - { - return writeT(value); - } +inline Int32 findWithLeadingWhitespace(const ReadBuffer &r, const char *s, + int maxSearch) { + Int32 result = 0; + while (*s) { + result = findWithLeadingWhitespace(r, *s, maxSearch); + if (result != 0) + return result; + ++s; + maxSearch = 1; + } + return 0; +} - Int32 WriteBuffer::write(const Int32 * value, Size size) - { - return writeT(value, size); - } +typedef NTA_Byte *(*fp_alloc)(NTA_UInt32); +typedef void (*fp_dealloc)(NTA_Byte *); - Int32 WriteBuffer::write(UInt32 value) - { - return writeT(value); +Int32 ReadBuffer::readString(NTA_Byte *&value, NTA_UInt32 &size, + fp_alloc fAlloc, fp_dealloc fDealloc) const { + NTA_ASSERT(fDealloc || !fAlloc); // Assume new/delete if unspecified. 
+ value = nullptr; + size = 0; + Int32 result = findWithLeadingWhitespace(*this, "", 1); +} - Int32 WriteBuffer::write(const UInt32 * value, Size size) - { - return writeT(value, size); - } - - Int32 WriteBuffer::write(Int64 value) - { - return writeT(value); - } +// ------------------------------------------ +// +// R E A D B U F F E R I T E R A T O R +// +// -----------------------------------------= +static const NTA_ReadBuffer *staticNext(NTA_ReadBufferIteratorHandle handle) { + NTA_CHECK(handle != nullptr); + + ReadBufferIterator *rbi = static_cast( + reinterpret_cast(handle)); + return static_cast(rbi->next()); +} - Int32 WriteBuffer::write(const Int64 * value, Size size) - { - return writeT(value, size); - } - - Int32 WriteBuffer::write(UInt64 value) - { - return writeT(value); - } - - Int32 WriteBuffer::write(const UInt64 * value, Size size) - { - return writeT(value, size); - } - - Int32 WriteBuffer::write(Real32 value) - { - return writeT(value); - } +static void staticReset(NTA_ReadBufferIteratorHandle handle) { + NTA_CHECK(handle != nullptr); - Int32 WriteBuffer::write(const Real32 * value, Size size) - { - return writeT(value, size); - } - - Int32 WriteBuffer::write(Real64 value) - { - return writeT(value); - } + ReadBufferIterator *rbi = static_cast( + reinterpret_cast(handle)); + return rbi->reset(); +} - Int32 WriteBuffer::write(const Real64 * value, Size size) - { - return writeT(value, size); - } +ReadBufferIterator::ReadBufferIterator(ReadBufferVec &rbv) + : readBufferVec_(rbv), index_(0) { + // Initialize the NTA_ReadbufferIterator struct + NTA_ReadBufferIterator::handle = + reinterpret_cast( + static_cast(this)); + NTA_ReadBufferIterator::next = staticNext; + NTA_ReadBufferIterator::reset = staticReset; +} - Int32 WriteBuffer::write(bool value) - { - return writeT(value); - } +const IReadBuffer *ReadBufferIterator::next() { + if (index_ == readBufferVec_.size()) + return nullptr; + + return readBufferVec_[index_++]; +} + +void ReadBufferIterator::reset() { index_ = 0; } +// ----------------------------------------- +// +// W R I T E B U F F E R +// +// ----------------------------------------- +NTA_Int32 staticWriteUInt32(NTA_WriteBufferHandle handle, NTA_UInt32 value) { + NTA_CHECK(handle != nullptr); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->write(value); +} + +NTA_Int32 staticWriteUInt32Array(NTA_WriteBufferHandle handle, + const NTA_UInt32 *value, NTA_Size size) { + NTA_CHECK(handle != nullptr); + NTA_CHECK(value != nullptr); + NTA_CHECK(size > 0); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->write(value, size); +} + +NTA_Int32 staticWriteInt32(NTA_WriteBufferHandle handle, NTA_Int32 value) { + NTA_CHECK(handle != nullptr); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->write(value); +} + +NTA_Int32 staticWriteInt32Array(NTA_WriteBufferHandle handle, + const NTA_Int32 *value, NTA_Size size) { + NTA_CHECK(handle != nullptr); + NTA_CHECK(value != nullptr); + NTA_CHECK(size > 0); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->write(value, size); +} + +NTA_Int32 staticWriteInt64(NTA_WriteBufferHandle handle, NTA_Int64 value) { + NTA_CHECK(handle != nullptr); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->write(value); +} + +NTA_Int32 staticWriteInt64Array(NTA_WriteBufferHandle handle, + const NTA_Int64 *value, NTA_Size size) { + NTA_CHECK(handle != nullptr); + NTA_CHECK(value != nullptr); + NTA_CHECK(size > 0); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->write(value, size); +} 
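The static trampolines above, together with the remaining ones below, exist so the C plugin API can drive a C++ WriteBuffer through the NTA_WriteBuffer function-pointer table that WriteBuffer's constructor fills in. The following is a minimal sketch of the intended C-side call pattern, not part of this patch: the struct fields used (handle, writeInt32, writeReal32, getSize) are the ones assigned in the constructor, while cStyleClient() and main() are hypothetical, and the sketch assumes Buffer.hpp pulls in the NTA_WriteBuffer definition as it does in this tree.

// Sketch only: exercises the NTA_WriteBuffer function-pointer table.
#include <nupic/ntypes/Buffer.hpp>

using namespace nupic;

// A C-style client never sees the C++ object; it calls through the table,
// handing the opaque handle back on every call.
static void cStyleClient(NTA_WriteBuffer *wb) {
  wb->writeInt32(wb->handle, 42);
  wb->writeReal32(wb->handle, 3.5f);
}

int main() {
  WriteBuffer buffer;     // constructor installs the handle and callbacks
  cStyleClient(&buffer);  // WriteBuffer is-a NTA_WriteBuffer
  return buffer.getSize() > 0 ? 0 : 1;  // bytes written via the C interface
}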
+ +NTA_Int32 staticWriteUInt64(NTA_WriteBufferHandle handle, NTA_UInt64 value) { + NTA_CHECK(handle != nullptr); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->write(value); +} + +NTA_Int32 staticWriteUInt64Array(NTA_WriteBufferHandle handle, + const NTA_UInt64 *value, NTA_Size size) { + NTA_CHECK(handle != nullptr); + NTA_CHECK(value != nullptr); + NTA_CHECK(size > 0); - Int32 WriteBuffer::write(const bool * value, Size size) - { - return writeT(value, size); + WriteBuffer *wb = reinterpret_cast(handle); + return wb->write(value, size); +} + +NTA_Int32 staticWriteReal32(NTA_WriteBufferHandle handle, NTA_Real32 value) { + NTA_CHECK(handle != nullptr); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->write(value); +} + +NTA_Int32 staticWriteReal32Array(NTA_WriteBufferHandle handle, + const NTA_Real32 *value, NTA_Size size) { + NTA_CHECK(handle != nullptr); + NTA_CHECK(value != nullptr); + NTA_CHECK(size > 0); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->write(value, size); +} + +NTA_Int32 staticWriteReal64(NTA_WriteBufferHandle handle, NTA_Real64 value) { + NTA_CHECK(handle != nullptr); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->write(value); +} + +NTA_Int32 staticWriteReal64Array(NTA_WriteBufferHandle handle, + const NTA_Real64 *value, NTA_Size size) { + NTA_CHECK(handle != nullptr); + NTA_CHECK(value != nullptr); + NTA_CHECK(size > 0); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->write(value, size); +} + +NTA_Int32 staticWriteByte(NTA_WriteBufferHandle handle, NTA_Byte value) { + NTA_CHECK(handle != nullptr); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->write(value); +} + +NTA_Int32 staticWriteByteArray(NTA_WriteBufferHandle handle, + const NTA_Byte *value, NTA_Size size) { + NTA_CHECK(handle != nullptr); + NTA_CHECK(value != nullptr); + NTA_CHECK(size > 0); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->write(value, size); +} + +NTA_Int32 staticWriteString(NTA_WriteBufferHandle handle, const NTA_Byte *value, + NTA_Size size) { + NTA_CHECK(handle != nullptr); + NTA_CHECK(value != nullptr); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->writeString(value, size); +} + +const Byte *staticGetData(NTA_WriteBufferHandle handle) { + NTA_CHECK(handle != nullptr); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->getData(); +} + +NTA_Size staticWriteBufferGetSize(NTA_WriteBufferHandle handle) { + NTA_CHECK(handle != nullptr); + + WriteBuffer *wb = reinterpret_cast(handle); + return wb->getSize(); +} + +WriteBuffer::WriteBuffer() { + handle = reinterpret_cast(this); + NTA_WriteBuffer::getData = staticGetData; + NTA_WriteBuffer::getSize = staticWriteBufferGetSize; + + writeByte = staticWriteByte; + writeByteArray = staticWriteByteArray; + writeAsString = staticWriteString; + + writeInt32 = staticWriteInt32; + writeInt32Array = staticWriteInt32Array; + writeUInt32 = staticWriteUInt32; + writeUInt32Array = staticWriteUInt32Array; + + writeInt64 = staticWriteInt64; + writeInt64Array = staticWriteInt64Array; + writeUInt64 = staticWriteUInt64; + writeUInt64Array = staticWriteUInt64Array; + + writeReal32 = staticWriteReal32; + writeReal32Array = staticWriteReal32Array; + writeReal64 = staticWriteReal64; + writeReal64Array = staticWriteReal64Array; + + OMemStream::exceptions(std::ostream::failbit | std::ostream::badbit); +} + +WriteBuffer::~WriteBuffer() {} + +Int32 WriteBuffer::write(Byte value) { return writeT(value); } + +Int32 WriteBuffer::write(const Byte *bytes, Size 
size) { + try { + OMemStream::write(bytes, (std::streamsize)size); + return 0; + } catch (...) { + return -1; } +} - NTA_Int32 WriteBuffer::writeString(const NTA_Byte *value, NTA_Size size) - { - NTA_Int32 result = write("", 4); +Int32 WriteBuffer::write(Int32 value) { return writeT(value); } + +Int32 WriteBuffer::write(const Int32 *value, Size size) { + return writeT(value, size); +} + +Int32 WriteBuffer::write(UInt32 value) { return writeT(value); } + +Int32 WriteBuffer::write(const UInt32 *value, Size size) { + return writeT(value, size); +} + +Int32 WriteBuffer::write(Int64 value) { return writeT(value); } + +Int32 WriteBuffer::write(const Int64 *value, Size size) { + return writeT(value, size); +} + +Int32 WriteBuffer::write(UInt64 value) { return writeT(value); } + +Int32 WriteBuffer::write(const UInt64 *value, Size size) { + return writeT(value, size); +} + +Int32 WriteBuffer::write(Real32 value) { return writeT(value); } + +Int32 WriteBuffer::write(const Real32 *value, Size size) { + return writeT(value, size); +} + +Int32 WriteBuffer::write(Real64 value) { return writeT(value); } + +Int32 WriteBuffer::write(const Real64 *value, Size size) { + return writeT(value, size); +} + +Int32 WriteBuffer::write(bool value) { return writeT(value); } + +Int32 WriteBuffer::write(const bool *value, Size size) { + return writeT(value, size); +} + +NTA_Int32 WriteBuffer::writeString(const NTA_Byte *value, NTA_Size size) { + NTA_Int32 result = write("", 4); + return result; +} - Size WriteBuffer::getSize() - { - return (Size)OMemStream::pcount(); +const Byte *WriteBuffer::getData() { + try { + return OMemStream::str(); + } catch (...) { + return nullptr; } } +Size WriteBuffer::getSize() { return (Size)OMemStream::pcount(); } +} // namespace nupic diff --git a/src/nupic/ntypes/Buffer.hpp b/src/nupic/ntypes/Buffer.hpp index 25c289ace1..9fd8a53d98 100644 --- a/src/nupic/ntypes/Buffer.hpp +++ b/src/nupic/ntypes/Buffer.hpp @@ -20,8 +20,8 @@ * --------------------------------------------------------------------- */ -/** @file -*/ +/** @file + */ #ifndef NTA_BUFFER_HPP #define NTA_BUFFER_HPP @@ -35,220 +35,190 @@ #include #include -namespace nupic -{ - typedef std::vector ReadBufferVec; - - /** - * ReadBuffer is a class that stores arbitrary binary data in memory. - * It has a very simple interface that allows linear writing. - * You can reset it to the beginning but no random seeking. - * Very simple. It implements the IReadBuffer interface and - * NTA_ReadBuffer interface. - * - * @b Responsibility: - * Provide efficient write access of arbitrary binary data - * from the buffer. The interface is simple enough so it can be - * easilly ported to C (so no streams) - * - * @b Rationale: - * Several methods of the plugin API require an arbitrary binary - * data store. This is it. The interface is intentionally simple - * so it can used for the C plugin API. - * - * @b Resource/Ownerships: - * A vector of bytes that represent the binary data - * an IMemeStream to internally stream the data. - * - * @b Invariants: - * index_ must be in the range [0, element count). - * When the buffer is empty it should be 0. 
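WriteBuffer and ReadBuffer, implemented in the Buffer.cpp hunk above and declared in the header hunk that follows, pair up for a simple write-then-read round trip: values are written as whitespace-separated text and parsed back in the same order. A short sketch of that usage, not part of the patch; the variable names and values are made up:

// Round-trip sketch: serialize with WriteBuffer, then parse the same
// bytes back with ReadBuffer.
#include <nupic/ntypes/Buffer.hpp>

using namespace nupic;

int main() {
  WriteBuffer out;
  out.write((UInt32)7);
  out.write((Real64)2.5);

  // ReadBuffer copies the bytes by default (copy == true).
  ReadBuffer in(out.getData(), out.getSize());

  UInt32 count = 0;
  Real64 scale = 0.0;
  in.read(count);  // returns 0 on success, 1 on EOF, -1 on parse error
  in.read(scale);
  return (count == 7 && scale == 2.5) ? 0 : 1;
}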
- * - * @b Notes: - * see IReadBuffer documentation in nupic/plugin/object_model.hpp for - * further details - */ - class ReadBuffer : - public IReadBuffer, - public NTA_ReadBuffer - { - public: - ReadBuffer(const Byte * value, Size size, bool copy=true); - ReadBuffer(const ReadBuffer &); - ReadBuffer & operator=(const ReadBuffer &); - void assign(const ReadBuffer &); - void reset() const override; - Size getSize() const override; - const Byte * getData() const override; - - Int32 read(Byte & value) const override; - Int32 read(Byte * value, Size & size) const override; - Int32 read(Int32 & value) const override; - Int32 read(Int32 * value, Size size) const override; - Int32 read(UInt32 & value) const override; - Int32 read(UInt32 * value, Size size) const override; - Int32 read(Int64 & value) const override; - Int32 read(Int64 * value, Size size) const override; - Int32 read(UInt64 & value) const override; - Int32 read(UInt64 * value, Size size) const override; - Int32 read(Real32 & value) const override; - Int32 read(Real32 * value, Size size) const override; - Int32 read(Real64 & value) const override; - Int32 read(Real64 * value, Size size) const override; - Int32 read(bool & value) const override; - Int32 read(bool * value, Size size) const override; - Int32 readString( - NTA_Byte * &value, - NTA_UInt32 &size, - NTA_Byte *(*fAlloc)(NTA_UInt32 size)=nullptr, - void (*fDealloc)(NTA_Byte *)=nullptr - ) const override; - - template - Int32 readT(T & value) const - { - ReadBuffer * r = const_cast(this); +namespace nupic { +typedef std::vector ReadBufferVec; + +/** + * ReadBuffer is a class that stores arbitrary binary data in memory. + * It has a very simple interface that allows linear writing. + * You can reset it to the beginning but no random seeking. + * Very simple. It implements the IReadBuffer interface and + * NTA_ReadBuffer interface. + * + * @b Responsibility: + * Provide efficient write access of arbitrary binary data + * from the buffer. The interface is simple enough so it can be + * easilly ported to C (so no streams) + * + * @b Rationale: + * Several methods of the plugin API require an arbitrary binary + * data store. This is it. The interface is intentionally simple + * so it can used for the C plugin API. + * + * @b Resource/Ownerships: + * A vector of bytes that represent the binary data + * an IMemeStream to internally stream the data. + * + * @b Invariants: + * index_ must be in the range [0, element count). + * When the buffer is empty it should be 0. 
+ * + * @b Notes: + * see IReadBuffer documentation in nupic/plugin/object_model.hpp for + * further details + */ +class ReadBuffer : public IReadBuffer, public NTA_ReadBuffer { +public: + ReadBuffer(const Byte *value, Size size, bool copy = true); + ReadBuffer(const ReadBuffer &); + ReadBuffer &operator=(const ReadBuffer &); + void assign(const ReadBuffer &); + void reset() const override; + Size getSize() const override; + const Byte *getData() const override; + + Int32 read(Byte &value) const override; + Int32 read(Byte *value, Size &size) const override; + Int32 read(Int32 &value) const override; + Int32 read(Int32 *value, Size size) const override; + Int32 read(UInt32 &value) const override; + Int32 read(UInt32 *value, Size size) const override; + Int32 read(Int64 &value) const override; + Int32 read(Int64 *value, Size size) const override; + Int32 read(UInt64 &value) const override; + Int32 read(UInt64 *value, Size size) const override; + Int32 read(Real32 &value) const override; + Int32 read(Real32 *value, Size size) const override; + Int32 read(Real64 &value) const override; + Int32 read(Real64 *value, Size size) const override; + Int32 read(bool &value) const override; + Int32 read(bool *value, Size size) const override; + Int32 readString(NTA_Byte *&value, NTA_UInt32 &size, + NTA_Byte *(*fAlloc)(NTA_UInt32 size) = nullptr, + void (*fDealloc)(NTA_Byte *) = nullptr) const override; + + template Int32 readT(T &value) const { + ReadBuffer *r = const_cast(this); + if (memStream_.eof()) + return 1; + + try { + r->memStream_ >> value; + return 0; + } catch (...) { if (memStream_.eof()) return 1; - - try - { - r->memStream_ >> value; - return 0; - } - catch (...) - { - if (memStream_.eof()) - return 1; - else - return -1; - } - } - - template - Int32 readT(T * value, Size size) const - { - ReadBuffer * r = const_cast(this); - try - { - for (Size i = 0; i < size; ++i) - r->read(value[i]); - return 0; - } - catch (...) - { + else return -1; - } } - private: - boost::shared_array bytes_; - mutable IMemStream memStream_; - }; - - class ReadBufferIterator : - public IReadBufferIterator, - public NTA_ReadBufferIterator - { - public: - ReadBufferIterator(ReadBufferVec & rbv); - const IReadBuffer * next() override; - void reset() override; - private: - ReadBufferVec & readBufferVec_; - Size index_; - }; - - /** - * WriteBuffer is a class that stores arbitrary binary data in memory. - * It has a very simple interface that allows linear writing. - * You can get the entire buffer using getData(). - * Very simple. It implements the IWriteBuffer interface and - * NTA_WriteBuffer interface. - * - * @b Responsibility: - * Provide efficient write access of arbitrary binary data - * to the buffer. The interface is simple enough so it can be - * easilly ported to C (so no streams) - * - * @b Rationale: - * Several methods of the plugin API require an arbitrary binary - * data store. This is it. The interface is intentionally simple - * so it can used for the C plugin API. - * - * @b Resource/Ownerships: - * The OMemeStream private base class manages the actual data - * - * @b Invariants: - * index_ must be in the range [0, element count). - * When the buffer is empty it should be 0. 
- * - * @b Notes: - * see IWriteBuffer documentation in nupic/plugin/object_model.hpp for - * further details - */ - class WriteBuffer : - public IWriteBuffer, - public NTA_WriteBuffer, - private OMemStream - { - public: - WriteBuffer(); - virtual ~WriteBuffer(); - Int32 write(Byte value) override; - Int32 write(const Byte * value, Size size) override; - Int32 write(Int32 value) override; - Int32 write(const Int32 * value, Size size) override; - Int32 write(UInt32 value) override; - Int32 write(const UInt32 * value, Size size) override; - Int32 write(Int64 value) override; - Int32 write(const Int64 * value, Size size) override; - Int32 write(UInt64 value) override; - Int32 write(const UInt64 * value, Size size) override; - Int32 write(Real32 value) override; - Int32 write(const Real32 * value, Size size) override; - Int32 write(Real64 value) override; - Int32 write(const Real64 * value, Size size) override; - Int32 write(bool value) override; - Int32 write(const bool * value, Size size) override; - Int32 writeString(const Byte * value, Size size) override; - - Size getSize() override; - const Byte * getData() override; - - template - Int32 writeT(T value, const char *sep=" ") - { - try - { - if (sep && (getSize() > 0)) - *this << ' '; - *this << value; - return 0; - } - catch (...) - { - return -1; - } + } + + template Int32 readT(T *value, Size size) const { + ReadBuffer *r = const_cast(this); + try { + for (Size i = 0; i < size; ++i) + r->read(value[i]); + return 0; + } catch (...) { + return -1; } - - template - Int32 writeT(const T * value, Size size) - { - try - { - for (Size i = 0; i < size; ++i) - { - const T & val = value[i]; - write(val); - } - return 0; - } - catch (...) - { - return -1; + } + +private: + boost::shared_array bytes_; + mutable IMemStream memStream_; +}; + +class ReadBufferIterator : public IReadBufferIterator, + public NTA_ReadBufferIterator { +public: + ReadBufferIterator(ReadBufferVec &rbv); + const IReadBuffer *next() override; + void reset() override; + +private: + ReadBufferVec &readBufferVec_; + Size index_; +}; + +/** + * WriteBuffer is a class that stores arbitrary binary data in memory. + * It has a very simple interface that allows linear writing. + * You can get the entire buffer using getData(). + * Very simple. It implements the IWriteBuffer interface and + * NTA_WriteBuffer interface. + * + * @b Responsibility: + * Provide efficient write access of arbitrary binary data + * to the buffer. The interface is simple enough so it can be + * easilly ported to C (so no streams) + * + * @b Rationale: + * Several methods of the plugin API require an arbitrary binary + * data store. This is it. The interface is intentionally simple + * so it can used for the C plugin API. + * + * @b Resource/Ownerships: + * The OMemeStream private base class manages the actual data + * + * @b Invariants: + * index_ must be in the range [0, element count). + * When the buffer is empty it should be 0. 
+ * + * @b Notes: + * see IWriteBuffer documentation in nupic/plugin/object_model.hpp for + * further details + */ +class WriteBuffer : public IWriteBuffer, + public NTA_WriteBuffer, + private OMemStream { +public: + WriteBuffer(); + virtual ~WriteBuffer(); + Int32 write(Byte value) override; + Int32 write(const Byte *value, Size size) override; + Int32 write(Int32 value) override; + Int32 write(const Int32 *value, Size size) override; + Int32 write(UInt32 value) override; + Int32 write(const UInt32 *value, Size size) override; + Int32 write(Int64 value) override; + Int32 write(const Int64 *value, Size size) override; + Int32 write(UInt64 value) override; + Int32 write(const UInt64 *value, Size size) override; + Int32 write(Real32 value) override; + Int32 write(const Real32 *value, Size size) override; + Int32 write(Real64 value) override; + Int32 write(const Real64 *value, Size size) override; + Int32 write(bool value) override; + Int32 write(const bool *value, Size size) override; + Int32 writeString(const Byte *value, Size size) override; + + Size getSize() override; + const Byte *getData() override; + + template Int32 writeT(T value, const char *sep = " ") { + try { + if (sep && (getSize() > 0)) + *this << ' '; + *this << value; + return 0; + } catch (...) { + return -1; + } + } + + template Int32 writeT(const T *value, Size size) { + try { + for (Size i = 0; i < size; ++i) { + const T &val = value[i]; + write(val); } + return 0; + } catch (...) { + return -1; } - }; -} + } +}; +} // namespace nupic #endif // NTA_BUFFER_HPP diff --git a/src/nupic/ntypes/BundleIO.cpp b/src/nupic/ntypes/BundleIO.cpp index 70319b3db0..0e68452176 100644 --- a/src/nupic/ntypes/BundleIO.cpp +++ b/src/nupic/ntypes/BundleIO.cpp @@ -25,92 +25,78 @@ #include #include -namespace nupic -{ - BundleIO::BundleIO(const std::string& bundlePath, const std::string& label, - std::string regionName, bool isInput) : - isInput_(isInput), - bundlePath_(bundlePath), - regionName_(std::move(regionName)), - ostream_(nullptr), - istream_(nullptr) - { - if (! 
Path::exists(bundlePath_)) - NTA_THROW << "Network bundle " << bundlePath << " does not exist"; - - filePrefix_ = Path::join(bundlePath, label + "-"); - } - - BundleIO::~BundleIO() - { - if (istream_) - { - if (istream_->is_open()) - istream_->close(); - delete istream_; - istream_ = nullptr; - } - if (ostream_) - { - if (ostream_->is_open()) - ostream_->close(); - delete ostream_; - ostream_ = nullptr; - } - } +namespace nupic { +BundleIO::BundleIO(const std::string &bundlePath, const std::string &label, + std::string regionName, bool isInput) + : isInput_(isInput), bundlePath_(bundlePath), + regionName_(std::move(regionName)), ostream_(nullptr), istream_(nullptr) { + if (!Path::exists(bundlePath_)) + NTA_THROW << "Network bundle " << bundlePath << " does not exist"; - std::ofstream& BundleIO::getOutputStream(const std::string& name) const - { - NTA_CHECK(!isInput_); - - checkStreams_(); - - ostream_ = new OFStream(getPath(name).c_str(), std::ios::out | std::ios::binary); - if (!ostream_->is_open()) - { - NTA_THROW << "getOutputStream - Unable to open bundle file " << name - << " for region " << regionName_ << " in network bundle " - << bundlePath_; - } - - return *ostream_; + filePrefix_ = Path::join(bundlePath, label + "-"); +} + +BundleIO::~BundleIO() { + if (istream_) { + if (istream_->is_open()) + istream_->close(); + delete istream_; + istream_ = nullptr; } - - std::ifstream& BundleIO::getInputStream(const std::string& name) const - { - NTA_CHECK(isInput_); - - checkStreams_(); - - istream_ = new IFStream(getPath(name).c_str(), std::ios::in | std::ios::binary); - if (!istream_->is_open()) - { - NTA_THROW << "getInputStream - Unable to open bundle file " << name - << " for region " << regionName_ << " in network bundle " - << bundlePath_; - } - - return *istream_; + if (ostream_) { + if (ostream_->is_open()) + ostream_->close(); + delete ostream_; + ostream_ = nullptr; } - - std::string BundleIO::getPath(const std::string& name) const - { - return filePrefix_ + name; +} + +std::ofstream &BundleIO::getOutputStream(const std::string &name) const { + NTA_CHECK(!isInput_); + + checkStreams_(); + + ostream_ = + new OFStream(getPath(name).c_str(), std::ios::out | std::ios::binary); + if (!ostream_->is_open()) { + NTA_THROW << "getOutputStream - Unable to open bundle file " << name + << " for region " << regionName_ << " in network bundle " + << bundlePath_; } - - - // Before a request for a new stream, - // there should be no open streams. - void BundleIO::checkStreams_() const - { - // Catch implementation errors and make it easier to - // support direct serialization to/from archives - if (isInput_ && istream_ != nullptr && istream_->is_open()) - NTA_THROW << "Internal Error: istream_ has not been closed"; - - if (!isInput_ && ostream_ != nullptr && ostream_->is_open()) - NTA_THROW << "Internal Error: ostream_ has not been closed"; + + return *ostream_; +} + +std::ifstream &BundleIO::getInputStream(const std::string &name) const { + NTA_CHECK(isInput_); + + checkStreams_(); + + istream_ = + new IFStream(getPath(name).c_str(), std::ios::in | std::ios::binary); + if (!istream_->is_open()) { + NTA_THROW << "getInputStream - Unable to open bundle file " << name + << " for region " << regionName_ << " in network bundle " + << bundlePath_; } -} // namespace nupic + return *istream_; +} + +std::string BundleIO::getPath(const std::string &name) const { + return filePrefix_ + name; +} +// Before a request for a new stream, +// there should be no open streams. 
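BundleIO, as implemented above, hands out one stream at a time under a label-derived file prefix; the checkStreams_ helper just below enforces that. A usage sketch follows, assuming the bundle directory already exists; the label, region, and stream names here are invented:

// Sketch only: writes <bundlePath>/R1-weights through a BundleIO output stream.
#include <nupic/ntypes/BundleIO.hpp>
#include <fstream>
#include <string>

using namespace nupic;

void saveWeights(const std::string &bundlePath) {
  BundleIO bundle(bundlePath, "R1", "myRegion", /*isInput=*/false);
  std::ofstream &f = bundle.getOutputStream("weights");  // opens R1-weights
  f << 0.25 << " " << 0.75;
  f.close();  // callers close explicitly; that is why ofstream is exposed
}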
+void BundleIO::checkStreams_() const { + // Catch implementation errors and make it easier to + // support direct serialization to/from archives + if (isInput_ && istream_ != nullptr && istream_->is_open()) + NTA_THROW << "Internal Error: istream_ has not been closed"; + + if (!isInput_ && ostream_ != nullptr && ostream_->is_open()) + NTA_THROW << "Internal Error: ostream_ has not been closed"; +} + +} // namespace nupic diff --git a/src/nupic/ntypes/BundleIO.hpp b/src/nupic/ntypes/BundleIO.hpp index 9e2567c030..62fa5ad5d1 100644 --- a/src/nupic/ntypes/BundleIO.hpp +++ b/src/nupic/ntypes/BundleIO.hpp @@ -23,55 +23,52 @@ #ifndef NTA_BUNDLEIO_HPP #define NTA_BUNDLEIO_HPP -#include #include +#include -namespace nupic -{ - class BundleIO - { - public: - BundleIO(const std::string& bundlePath, const std::string& label, - std::string regionName, bool isInput); - - ~BundleIO(); +namespace nupic { +class BundleIO { +public: + BundleIO(const std::string &bundlePath, const std::string &label, + std::string regionName, bool isInput); - // These are {o,i}fstream instead of {o,i}stream so that - // the node can explicitly close() them. - std::ofstream& getOutputStream(const std::string& name) const; + ~BundleIO(); - std::ifstream& getInputStream(const std::string& name) const; + // These are {o,i}fstream instead of {o,i}stream so that + // the node can explicitly close() them. + std::ofstream &getOutputStream(const std::string &name) const; - std::string getPath(const std::string& name) const; + std::ifstream &getInputStream(const std::string &name) const; - private: + std::string getPath(const std::string &name) const; - // Before a request for a new stream, - // there should be no open streams. - void checkStreams_() const; +private: + // Before a request for a new stream, + // there should be no open streams. + void checkStreams_() const; - // Should never read and write at the same time -- this helps - // to enforce. - bool isInput_; + // Should never read and write at the same time -- this helps + // to enforce. + bool isInput_; - // We only need the file prefix, but store the bundle path - // for error messages - std::string bundlePath_; + // We only need the file prefix, but store the bundle path + // for error messages + std::string bundlePath_; - // Store the whole prefix instead of just the label - std::string filePrefix_; + // Store the whole prefix instead of just the label + std::string filePrefix_; - // Store the region name for debugging - std::string regionName_; + // Store the region name for debugging + std::string regionName_; - // We own the streams -- helps with finding errors - // and with enforcing one-stream-at-a-time - // These are mutable because the bundle doesn't conceptually - // change when you serialize/deserialize. - mutable std::ofstream *ostream_; - mutable std::ifstream *istream_; + // We own the streams -- helps with finding errors + // and with enforcing one-stream-at-a-time + // These are mutable because the bundle doesn't conceptually + // change when you serialize/deserialize. + mutable std::ofstream *ostream_; + mutable std::ifstream *istream_; - }; // class BundleIO +}; // class BundleIO } // namespace nupic #endif // NTA_BUNDLEIO_HPP diff --git a/src/nupic/ntypes/Collection.cpp b/src/nupic/ntypes/Collection.cpp index 6c109de459..79ec6089f9 100644 --- a/src/nupic/ntypes/Collection.cpp +++ b/src/nupic/ntypes/Collection.cpp @@ -25,57 +25,44 @@ #include #include -namespace nupic -{ +namespace nupic { /* - * Implementation of the templated Collection class. 
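Collection<T>, whose implementation and header follow, is an insertion-ordered vector of (name, item) pairs with name-based lookup. Below is a small sketch of the interface described in the comments that follow; it assumes an explicit instantiation for the chosen item type is available, since (as noted in this file) the instantiations live in nupic/engine/Collections.cpp rather than here:

// Sketch only: assumes Collection<int> is explicitly instantiated somewhere.
#include <nupic/ntypes/Collection.hpp>
#include <string>
#include <utility>

using namespace nupic;

int exampleCollection() {
  Collection<int> c;
  c.add("alpha", 1);               // throws if the name already exists
  c.add("beta", 2);
  bool has = c.contains("alpha");  // true
  int b = c.getByName("beta");     // name lookup -> 2
  const std::pair<std::string, int> &first = c.getByIndex(0);  // insertion order
  int sum = b + first.second;      // 2 + 1
  c.remove("alpha");               // throws if the name is unknown
  return has ? sum + (int)c.getCount() : 0;  // 3 + 1
}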
+ * Implementation of the templated Collection class. * This code is used to create explicit instantiations - * of the Collection class. + * of the Collection class. * It is not compiled into the types library because - * we instantiate for classes outside of the types library. - * For example, Collection is built in the - * net library where OutputSpec is defined. + * we instantiate for classes outside of the types library. + * For example, Collection is built in the + * net library where OutputSpec is defined. * See nupic/engine/Collections.cpp, which is where the * Collection classes are instantiated. */ -template -Collection::Collection() -{ -} - -template -Collection::~Collection() -{ -} - -template -size_t Collection::getCount() const -{ +template Collection::Collection() {} + +template Collection::~Collection() {} + +template size_t Collection::getCount() const { return vec_.size(); } -template const -std::pair& Collection::getByIndex(size_t index) const -{ +template +const std::pair &Collection::getByIndex(size_t index) const { NTA_CHECK(index < vec_.size()); return vec_[index]; } -template -std::pair& Collection::getByIndex(size_t index) -{ +template +std::pair &Collection::getByIndex(size_t index) { NTA_CHECK(index < vec_.size()); return vec_[index]; } -template -bool Collection::contains(const std::string & name) const -{ +template +bool Collection::contains(const std::string &name) const { typename CollectionStorage::const_iterator i; - for (i = vec_.begin(); i != vec_.end(); i++) - { + for (i = vec_.begin(); i != vec_.end(); i++) { if (i->first == name) return true; } @@ -83,26 +70,21 @@ bool Collection::contains(const std::string & name) const } template -T Collection::getByName(const std::string & name) const -{ +T Collection::getByName(const std::string &name) const { typename CollectionStorage::const_iterator i; - for (i = vec_.begin(); i != vec_.end(); i++) - { + for (i = vec_.begin(); i != vec_.end(); i++) { if (i->first == name) return i->second; - } + } NTA_THROW << "No item named: " << name; } template -void Collection::add(const std::string & name, const T & item) -{ +void Collection::add(const std::string &name, const T &item) { // make sure we don't already have something with this name typename CollectionStorage::const_iterator i; - for (i = vec_.begin(); i != vec_.end(); i++) - { - if (i->first == name) - { + for (i = vec_.begin(); i != vec_.end(); i++) { + if (i->first == name) { NTA_THROW << "Unable to add item '" << name << "' to collection " << "because it already exists"; } @@ -112,13 +94,9 @@ void Collection::add(const std::string & name, const T & item) vec_.push_back(std::make_pair(name, item)); } - -template -void Collection::remove(const std::string & name) -{ +template void Collection::remove(const std::string &name) { typename CollectionStorage::iterator i; - for (i = vec_.begin(); i != vec_.end(); i++) - { + for (i = vec_.begin(); i != vec_.end(); i++) { if (i->first == name) break; } @@ -128,5 +106,4 @@ void Collection::remove(const std::string & name) vec_.erase(i); } -} - +} // namespace nupic diff --git a/src/nupic/ntypes/Collection.hpp b/src/nupic/ntypes/Collection.hpp index 11e96e63d4..eefc44c5ca 100644 --- a/src/nupic/ntypes/Collection.hpp +++ b/src/nupic/ntypes/Collection.hpp @@ -26,49 +26,44 @@ #include #include -namespace nupic -{ - // A collection is a templated class that contains items of type t. - // It supports lookup by name and by index. 
The items are stored in a map - // and copies are also stored in a vector (it's Ok to use pointers). - // You can add items using the add() method. - // - template - class Collection - { - public: - Collection(); - virtual ~Collection(); - - size_t getCount() const; +namespace nupic { +// A collection is a templated class that contains items of type t. +// It supports lookup by name and by index. The items are stored in a map +// and copies are also stored in a vector (it's Ok to use pointers). +// You can add items using the add() method. +// +template class Collection { +public: + Collection(); + virtual ~Collection(); + + size_t getCount() const; - // This method provides access by index to the contents of the collection - // The indices are in insertion order. - // + // This method provides access by index to the contents of the collection + // The indices are in insertion order. + // - const std::pair& getByIndex(size_t index) const; - - bool contains(const std::string & name) const; + const std::pair &getByIndex(size_t index) const; - T getByName(const std::string & name) const; + bool contains(const std::string &name) const; - // TODO: move add/remove to a ModifiableCollection subclass - // This method should be internal but is currently tested - // in net_test.py in test_node_spec - void add(const std::string & name, const T & item); + T getByName(const std::string &name) const; - void remove(const std::string& name); + // TODO: move add/remove to a ModifiableCollection subclass + // This method should be internal but is currently tested + // in net_test.py in test_node_spec + void add(const std::string &name, const T &item); + void remove(const std::string &name); #ifdef NTA_INTERNAL - std::pair& getByIndex(size_t index); + std::pair &getByIndex(size_t index); #endif - private: - typedef std::vector > CollectionStorage; - CollectionStorage vec_; - }; -} +private: + typedef std::vector> CollectionStorage; + CollectionStorage vec_; +}; +} // namespace nupic #endif - diff --git a/src/nupic/ntypes/Dimensions.cpp b/src/nupic/ntypes/Dimensions.cpp index 8771346e50..3332298b63 100644 --- a/src/nupic/ntypes/Dimensions.cpp +++ b/src/nupic/ntypes/Dimensions.cpp @@ -21,37 +21,31 @@ * --------------------------------------------------------------------- */ -#include #include +#include #include using namespace nupic; -Dimensions::Dimensions() {}; +Dimensions::Dimensions(){}; -Dimensions::Dimensions(std::vector v) : std::vector(std::move(v)) {}; +Dimensions::Dimensions(std::vector v) + : std::vector(std::move(v)){}; -Dimensions::Dimensions(size_t x) -{ - push_back(x); -} +Dimensions::Dimensions(size_t x) { push_back(x); } -Dimensions::Dimensions(size_t x, size_t y) -{ +Dimensions::Dimensions(size_t x, size_t y) { push_back(x); push_back(y); } -Dimensions::Dimensions(size_t x, size_t y, size_t z) -{ +Dimensions::Dimensions(size_t x, size_t y, size_t z) { push_back(x); push_back(y); push_back(z); } -size_t -Dimensions::getCount() const -{ +size_t Dimensions::getCount() const { if (isUnspecified() || isDontcare()) NTA_THROW << "Attempt to get count from dimensions " << toString(); size_t count = 1; @@ -62,54 +56,32 @@ Dimensions::getCount() const return count; } -size_t -Dimensions::getDimensionCount() const -{ - return size(); -} +size_t Dimensions::getDimensionCount() const { return size(); } - -size_t -Dimensions::getDimension(size_t index) const -{ - if (index >= size()) - { - NTA_THROW << "Bad request for dimension " << index - << " on " << toString(); +size_t Dimensions::getDimension(size_t 
index) const { + if (index >= size()) { + NTA_THROW << "Bad request for dimension " << index << " on " << toString(); } return at(index); } -bool -Dimensions::isDontcare() const -{ - return (size() == 1 && at(0) == 0); -} +bool Dimensions::isDontcare() const { return (size() == 1 && at(0) == 0); } -bool -Dimensions::isUnspecified() const -{ - return size() == 0; -} +bool Dimensions::isUnspecified() const { return size() == 0; } -bool -Dimensions::isOnes() const -{ +bool Dimensions::isOnes() const { if (size() == 0) return false; - for (size_t i = 0; i < size(); i++) - { + for (size_t i = 0; i < size(); i++) { if (at(i) != 1) return false; } return true; } -bool -Dimensions::isValid() const -{ +bool Dimensions::isValid() const { if (isDontcare() || isUnspecified()) return true; @@ -118,35 +90,24 @@ Dimensions::isValid() const return false; return true; } - -bool -Dimensions::isSpecified() const -{ + +bool Dimensions::isSpecified() const { return isValid() && !isUnspecified() && !isDontcare(); } - - - // internal helper method -static std::string vecToString(std::vector vec) -{ +static std::string vecToString(std::vector vec) { std::stringstream ss; - for (size_t i = 0; i < vec.size(); i++) - { + for (size_t i = 0; i < vec.size(); i++) { ss << vec[i]; - if (i != vec.size()-1) + if (i != vec.size() - 1) ss << " "; } return ss.str(); } - -std::string -Dimensions::toString(bool humanReadable) const -{ - if (humanReadable) - { +std::string Dimensions::toString(bool humanReadable) const { + if (humanReadable) { if (isUnspecified()) return "[unspecified]"; if (isDontcare()) @@ -162,11 +123,8 @@ Dimensions::toString(bool humanReadable) const return s; } -size_t -Dimensions::getIndex(const Coordinate& coordinate) const -{ - if(coordinate.size() != size()) - { +size_t Dimensions::getIndex(const Coordinate &coordinate) const { + if (coordinate.size() != size()) { NTA_THROW << "Invalid coordinate [" << vecToString(coordinate) << "] for Dimensions " << toString(); } @@ -176,53 +134,43 @@ Dimensions::getIndex(const Coordinate& coordinate) const // We need to return an index based on x major ordering. We can't simply use // an unsigned or size_t because vector<>::size_type varies between - // implementations (it is only required to be an unsigned type, not a + // implementations (it is only required to be an unsigned type, not a // specific bit-depth). 
- for(Coordinate::size_type dim = 0; - dim != coordinate.size(); - dim++) - { + for (Coordinate::size_type dim = 0; dim != coordinate.size(); dim++) { size_t thisdim = at(dim); - if (coordinate[dim] >= thisdim) - { - NTA_THROW << "Invalid coordinate index " << dim - << " of " << coordinate[dim] - << " is too large for region dimensions " << toString(); + if (coordinate[dim] >= thisdim) { + NTA_THROW << "Invalid coordinate index " << dim << " of " + << coordinate[dim] << " is too large for region dimensions " + << toString(); } index += factor * coordinate[dim]; - factor *= thisdim; + factor *= thisdim; } return index; } -Coordinate -Dimensions::getCoordinate(const size_t index) const -{ +Coordinate Dimensions::getCoordinate(const size_t index) const { Coordinate coordinate; size_t x = index; size_t product = 1; - for(size_type i = 0; i < size(); i++) - { + for (size_type i = 0; i < size(); i++) { product *= at(i); } - for(size_type i = size()-1; i != (size_type)-1; i--) - { - product/=at(i); - coordinate.insert(coordinate.begin(), x/product); - x%=product; + for (size_type i = size() - 1; i != (size_type)-1; i--) { + product /= at(i); + coordinate.insert(coordinate.begin(), x / product); + x %= product; } return coordinate; } -void -Dimensions::promote(size_t newDimensionality) -{ - if (! isOnes()) - { - NTA_THROW << "Dimensions::promote -- must be all ones for Dimensions " << toString(); +void Dimensions::promote(size_t newDimensionality) { + if (!isOnes()) { + NTA_THROW << "Dimensions::promote -- must be all ones for Dimensions " + << toString(); } if (size() == newDimensionality) return; @@ -232,31 +180,24 @@ Dimensions::promote(size_t newDimensionality) push_back(1); } -bool -Dimensions::operator == (const Dimensions& dims2) const -{ +bool Dimensions::operator==(const Dimensions &dims2) const { if ((std::vector)(*this) == (std::vector)dims2) return true; if (isOnes() && dims2.isOnes()) return true; - + return false; } -bool -Dimensions::operator != (const Dimensions& dims2) const -{ - return ! operator==(dims2); +bool Dimensions::operator!=(const Dimensions &dims2) const { + return !operator==(dims2); } - -namespace nupic -{ - std::ostream& operator<<(std::ostream& f, const Dimensions& d) - { - // temporary -- this might be hard to de-serialize - f << d.toString(/* humanReadable: */ false); - return f; - } +namespace nupic { +std::ostream &operator<<(std::ostream &f, const Dimensions &d) { + // temporary -- this might be hard to de-serialize + f << d.toString(/* humanReadable: */ false); + return f; } +} // namespace nupic diff --git a/src/nupic/ntypes/Dimensions.hpp b/src/nupic/ntypes/Dimensions.hpp index 6919b4b18e..944c8d141c 100644 --- a/src/nupic/ntypes/Dimensions.hpp +++ b/src/nupic/ntypes/Dimensions.hpp @@ -21,370 +21,354 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Interface for the Dimensions class */ #ifndef NTA_DIMENSIONS_HPP #define NTA_DIMENSIONS_HPP -#include #include +#include + +namespace nupic { +/** + * @typedef Coordinate + * + * A Coordinate is the location of a single cell in an n-dimensional + * grid described by a Dimensions object. + * + * It's a direct @c typedef, so it has the exactly the same interface as + * @c std::vector . A value with the index of `i` in the vector + * represents the location of the cell along the `i`th dimension. + * + * @note It must have the same number of dimensions as its corresponding + * Dimensions object. 
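getIndex() and getCoordinate(), implemented just above, use lower-major ordering (the first dimension varies fastest), matching the [2,3] mapping table documented later in this header. A worked sketch of that round trip, not part of the patch; the assertions simply restate the table:

// Lower-major coordinate/index mapping for dimensions [2,3].
#include <nupic/ntypes/Dimensions.hpp>
#include <cassert>

using namespace nupic;

int main() {
  Dimensions dims(2, 3);          // 2 x 3 grid, 6 cells in total
  assert(dims.getCount() == 6);

  Coordinate c = {1, 2};          // x = 1, y = 2
  assert(dims.getIndex(c) == 5);  // index = x + 2 * y = 1 + 2 * 2

  Coordinate back = dims.getCoordinate(5);
  assert(back[0] == 1 && back[1] == 2);
  return 0;
}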
+ * + * @internal + * + * Because a vector of a basic type can be directly wrapped + * by swig, we do not need a separate class. + * + * @endinternal + */ +typedef std::vector Coordinate; -namespace nupic -{ +/** + * Represents the dimensions of a Region. + * + * A Dimensions object is an n-dimensional grid, consists of many cells, and + * each dimension has a size, i.e. how many cells can there be along this + * dimension. + * + * A node within a Region is represented by a cell of a n-dimensional grid, + * identified by a Coordinate. + * + * It's implemented by a @c vector of @c size_t plus a few methods for + * convenience and for wrapping. + * + * @nosubgrouping + * + */ +class Dimensions : public std::vector { +public: /** - * @typedef Coordinate - * - * A Coordinate is the location of a single cell in an n-dimensional - * grid described by a Dimensions object. * - * It's a direct @c typedef, so it has the exactly the same interface as - * @c std::vector . A value with the index of `i` in the vector - * represents the location of the cell along the `i`th dimension. + * @name Constructors * - * @note It must have the same number of dimensions as its corresponding - * Dimensions object. + * @{ * - * @internal - * - * Because a vector of a basic type can be directly wrapped - * by swig, we do not need a separate class. + */ + + /** + * Create a new Dimensions object. + * + * @note Default dimensions are unspecified, see isUnspecified() * - * @endinternal */ - typedef std::vector Coordinate; + Dimensions(); /** - * Represents the dimensions of a Region. + * Create a new Dimensions object from a @c std::vector. * - * A Dimensions object is an n-dimensional grid, consists of many cells, and - * each dimension has a size, i.e. how many cells can there be along this dimension. - * - * A node within a Region is represented by a cell of a n-dimensional grid, - * identified by a Coordinate. + * @param v + * A @c std::vector of @c size_t, the value with the index of @a n + * is the size of the @a n th dimension * - * It's implemented by a @c vector of @c size_t plus a few methods for - * convenience and for wrapping. + */ + Dimensions(std::vector v); + + /** Create a new 1-dimension Dimensions object. + + * @param x + * The size of the 1st dimension * - * @nosubgrouping + */ + Dimensions(size_t x); + + /** + * Create a new 2-dimension Dimensions. * + * @param x + * The size of the 1st dimension + * @param y + * The size of the 2nd dimension */ - class Dimensions : public std::vector - { - public: - /** - * - * @name Constructors - * - * @{ - * - */ - - /** - * Create a new Dimensions object. - * - * @note Default dimensions are unspecified, see isUnspecified() - * - */ - Dimensions(); - - /** - * Create a new Dimensions object from a @c std::vector. - * - * @param v - * A @c std::vector of @c size_t, the value with the index of @a n - * is the size of the @a n th dimension - * - */ - Dimensions(std::vector v); - - /** Create a new 1-dimension Dimensions object. - - * @param x - * The size of the 1st dimension - * - */ - Dimensions(size_t x); - - /** - * Create a new 2-dimension Dimensions. - * - * @param x - * The size of the 1st dimension - * @param y - * The size of the 2nd dimension - */ - Dimensions(size_t x, size_t y); - - /** - * Create a new 3-dimension Dimensions. 
- * - * @param x - * The size of the 1st dimension - * @param y - * The size of the 2nd dimension - * @param z - * The size of the 3rd dimension - */ - Dimensions(size_t x, size_t y, size_t z); - - /** - * - * @} - * - * @name Properties - * - * @{ - * - */ - - /** - * Get the count of cells in the grid, which is the product of the sizes of - * the dimensions. - * - * @returns - * The count of cells in the grid. - */ - size_t - getCount() const; - - /** - * - * Get the number of dimensions. - * - * @returns number of dimensions - * - */ - size_t - getDimensionCount() const; - - /** - * Get the size of a dimension. - * - * @param index - * The index of the dimension - * - * @returns - * The size of the dimension with the index of @a index - * - * @note Do not confuse @a index with "linear index" as in getIndex() - */ - size_t - getDimension(size_t index) const; - - /** - * - * @} - * - * @name Boolean properties - * - * There are two "special" values for dimensions: - * - * * Dimensions of `[]` (`dims.size()==0`) means "not yet known" aka - * "unspecified", see isUnspecified() - * * Dimensions of `[0]` (`dims.size()==1 && dims[0] == 0`) means - * "don't care", see isDontcare() - * - * @{ - * - */ - - /** - * Tells whether the Dimensions object is "unspecified". - * - * @returns - * Whether the Dimensions object is "unspecified" - * - * @see isSpecified() - */ - bool - isUnspecified() const; - - /** - * - * Tells whether the Dimensions object is "don't care". - * - * @returns - * Whether the Dimensions object is "don't care" - */ - bool - isDontcare() const; - - /** - * Tells whether the Dimensions object is "specified". - * - * A "specified" Dimensions object satisfies all following conditions: - * - * * "valid" - * * NOT "unspecified" - * * NOT "don't care" - * - * @returns - * Whether the Dimensions object is "specified" - * - * @note It's not the opposite of isUnspecified()! - */ - bool - isSpecified() const; - - /** - * Tells whether the sizes of all dimensions are 1. - * - * @returns - * Whether the sizes of all dimensions are 1, e.g. [1], [1 1], [1 1 1], etc. - */ - bool - isOnes() const; - - /** - * Tells whether Dimensions is "valid". - * - * A Dimensions object is valid if it specifies actual dimensions, i.e. all - * dimensions have a size greater than 0, or is a special value - * ("unspecified"/"don't care"). - * - * A Dimensions object is invalid if any dimensions are 0 (except for "don't care") - * - * @returns - * Whether Dimensions is "valid" - */ - bool - isValid() const; - - /** - * - * @} - * - * @name Coordinate<->index mapping - * - * Coordinate<->index mapping is in lower-major order, i.e. - * for Region with dimensions `[2,3]`: - * - * [0,0] -> index 0 - * [1,0] -> index 1 - * [0,1] -> index 2 - * [1,1] -> index 3 - * [0,2] -> index 4 - * [1,2] -> index 5 - * - * @{ - * - */ - - /** - * Convert a Coordinate to a linear index (in lower-major order). - * - * @param coordinate - * The coordinate to be converted - * - * @returns - * The linear index corresponding to @a coordinate - */ - size_t - getIndex(const Coordinate& coordinate) const; - - /** - * Convert a linear index (in lower-major order) to a Coordinate. - * - * @param index - * The linear index to be converted - * - * @returns - * The Coordinate corresponding to @a index - */ - Coordinate - getCoordinate(const size_t index) const; - - /** - * - * @} - * - * @name Misc - * - * @{ - * - */ - - /** - * - * Convert the Dimensions object to string representation. 
- * - * In most cases, we want a human-readable string, but for - * serialization we want only the actual dimension values - * - * @param humanReadable - * The default is @c true, make the string human-readable, - * set to @c false for serialization - * - * @returns - * The string representation of the Dimensions object - */ - std::string - toString(bool humanReadable=true) const; - - /** - * Promote the Dimensions object to a new dimensionality. - * - * @param newDimensionality - * The new dimensionality to promote to, it can be greater than, - * smaller than or equal to current dimensionality - * - * @note The sizes of all dimensions must be 1( i.e. isOnes() returns true), - * or an exception will be thrown. - */ - void - promote(size_t newDimensionality); - - /** - * The equivalence operator. - * - * Two Dimensions objects will be considered equivalent, if any of the - * following satisfies: - * - * * They have the same number of dimensions and the same size for every - * dimension. - * * Both of them have the size of 1 for everything dimensions, despite of - * how many dimensions they have, i.e. isOnes() returns @c true for both - * of them. Some linking scenarios require us to treat [1] equivalent to [1 1] etc. - * - * @param dims2 - * The Dimensions object being compared - * - * @returns - * Whether this Dimensions object is equivalent to @a dims2. - * - */ - bool - operator == (const Dimensions& dims2) const; - - /** - * The in-equivalence operator, the opposite of operator==(). - * - * @param dims2 - * The Dimensions object being compared - * - * @returns - * Whether this Dimensions object is not equivalent to @a dims2. - */ - bool - operator != (const Dimensions& dims2) const; - - /** - * - * @} - * - */ + Dimensions(size_t x, size_t y); + /** + * Create a new 3-dimension Dimensions. + * + * @param x + * The size of the 1st dimension + * @param y + * The size of the 2nd dimension + * @param z + * The size of the 3rd dimension + */ + Dimensions(size_t x, size_t y, size_t z); -#ifdef NTA_INTERNAL - friend std::ostream& operator<<(std::ostream& f, const Dimensions&); -#endif + /** + * + * @} + * + * @name Properties + * + * @{ + * + */ + /** + * Get the count of cells in the grid, which is the product of the sizes of + * the dimensions. + * + * @returns + * The count of cells in the grid. + */ + size_t getCount() const; + + /** + * + * Get the number of dimensions. + * + * @returns number of dimensions + * + */ + size_t getDimensionCount() const; + + /** + * Get the size of a dimension. + * + * @param index + * The index of the dimension + * + * @returns + * The size of the dimension with the index of @a index + * + * @note Do not confuse @a index with "linear index" as in getIndex() + */ + size_t getDimension(size_t index) const; + + /** + * + * @} + * + * @name Boolean properties + * + * There are two "special" values for dimensions: + * + * * Dimensions of `[]` (`dims.size()==0`) means "not yet known" aka + * "unspecified", see isUnspecified() + * * Dimensions of `[0]` (`dims.size()==1 && dims[0] == 0`) means + * "don't care", see isDontcare() + * + * @{ + * + */ + + /** + * Tells whether the Dimensions object is "unspecified". + * + * @returns + * Whether the Dimensions object is "unspecified" + * + * @see isSpecified() + */ + bool isUnspecified() const; - }; + /** + * + * Tells whether the Dimensions object is "don't care". 
+ * + * @returns + * Whether the Dimensions object is "don't care" + */ + bool isDontcare() const; -} + /** + * Tells whether the Dimensions object is "specified". + * + * A "specified" Dimensions object satisfies all following conditions: + * + * * "valid" + * * NOT "unspecified" + * * NOT "don't care" + * + * @returns + * Whether the Dimensions object is "specified" + * + * @note It's not the opposite of isUnspecified()! + */ + bool isSpecified() const; + + /** + * Tells whether the sizes of all dimensions are 1. + * + * @returns + * Whether the sizes of all dimensions are 1, e.g. [1], [1 1], [1 1 1], + * etc. + */ + bool isOnes() const; + + /** + * Tells whether Dimensions is "valid". + * + * A Dimensions object is valid if it specifies actual dimensions, i.e. all + * dimensions have a size greater than 0, or is a special value + * ("unspecified"/"don't care"). + * + * A Dimensions object is invalid if any dimensions are 0 (except for "don't + * care") + * + * @returns + * Whether Dimensions is "valid" + */ + bool isValid() const; + + /** + * + * @} + * + * @name Coordinate<->index mapping + * + * Coordinate<->index mapping is in lower-major order, i.e. + * for Region with dimensions `[2,3]`: + * + * [0,0] -> index 0 + * [1,0] -> index 1 + * [0,1] -> index 2 + * [1,1] -> index 3 + * [0,2] -> index 4 + * [1,2] -> index 5 + * + * @{ + * + */ + + /** + * Convert a Coordinate to a linear index (in lower-major order). + * + * @param coordinate + * The coordinate to be converted + * + * @returns + * The linear index corresponding to @a coordinate + */ + size_t getIndex(const Coordinate &coordinate) const; + + /** + * Convert a linear index (in lower-major order) to a Coordinate. + * + * @param index + * The linear index to be converted + * + * @returns + * The Coordinate corresponding to @a index + */ + Coordinate getCoordinate(const size_t index) const; + + /** + * + * @} + * + * @name Misc + * + * @{ + * + */ + + /** + * + * Convert the Dimensions object to string representation. + * + * In most cases, we want a human-readable string, but for + * serialization we want only the actual dimension values + * + * @param humanReadable + * The default is @c true, make the string human-readable, + * set to @c false for serialization + * + * @returns + * The string representation of the Dimensions object + */ + std::string toString(bool humanReadable = true) const; + + /** + * Promote the Dimensions object to a new dimensionality. + * + * @param newDimensionality + * The new dimensionality to promote to, it can be greater than, + * smaller than or equal to current dimensionality + * + * @note The sizes of all dimensions must be 1( i.e. isOnes() returns true), + * or an exception will be thrown. + */ + void promote(size_t newDimensionality); + + /** + * The equivalence operator. + * + * Two Dimensions objects will be considered equivalent, if any of the + * following satisfies: + * + * * They have the same number of dimensions and the same size for every + * dimension. + * * Both of them have the size of 1 for everything dimensions, despite of + * how many dimensions they have, i.e. isOnes() returns @c true for both + * of them. Some linking scenarios require us to treat [1] equivalent to [1 1] + * etc. + * + * @param dims2 + * The Dimensions object being compared + * + * @returns + * Whether this Dimensions object is equivalent to @a dims2. + * + */ + bool operator==(const Dimensions &dims2) const; + + /** + * The in-equivalence operator, the opposite of operator==(). 
+ * + * @param dims2 + * The Dimensions object being compared + * + * @returns + * Whether this Dimensions object is not equivalent to @a dims2. + */ + bool operator!=(const Dimensions &dims2) const; + + /** + * + * @} + * + */ + +#ifdef NTA_INTERNAL + friend std::ostream &operator<<(std::ostream &f, const Dimensions &); +#endif +}; +} // namespace nupic #endif // NTA_DIMENSIONS_HPP diff --git a/src/nupic/ntypes/MemParser.cpp b/src/nupic/ntypes/MemParser.cpp index c2dcb3b510..c73c6a7f0d 100644 --- a/src/nupic/ntypes/MemParser.cpp +++ b/src/nupic/ntypes/MemParser.cpp @@ -20,28 +20,24 @@ * --------------------------------------------------------------------- */ -/** @file - * - */ +/** @file + * + */ -#include "nupic/utils/Log.hpp" #include "nupic/ntypes/MemParser.hpp" #include "nupic/ntypes/MemStream.hpp" -#include +#include "nupic/utils/Log.hpp" #include - - +#include using namespace std; namespace nupic { - + //////////////////////////////////////////////////////////////////////////// // MemParser constructor ////////////////////////////////////////////////////////////////////////////// -MemParser::MemParser(std::istream& in, UInt32 bytes) -{ - if (bytes == 0) - { +MemParser::MemParser(std::istream &in, UInt32 bytes) { + if (bytes == 0) { // ----------------------------------------------------------------------------- // Read all available bytes from the stream // ----------------------------------------------------------------------------- @@ -50,155 +46,139 @@ MemParser::MemParser(std::istream& in, UInt32 bytes) auto chunkP = new char[chunkSize]; while (!in.eof()) { in.read(chunkP, chunkSize); - NTA_CHECK (in.good() || in.eof()) + NTA_CHECK(in.good() || in.eof()) << "MemParser::MemParser() - error reading data from stream"; - data.append (chunkP, in.gcount()); + data.append(chunkP, in.gcount()); } - + bytes_ = (UInt32)data.size(); - bufP_ = new char[bytes_+1]; - NTA_CHECK (bufP_ != nullptr) << "MemParser::MemParser() - out of memory"; - ::memmove ((void*)bufP_, data.data(), bytes_); - ((char*)bufP_)[bytes_] = 0; - + bufP_ = new char[bytes_ + 1]; + NTA_CHECK(bufP_ != nullptr) << "MemParser::MemParser() - out of memory"; + ::memmove((void *)bufP_, data.data(), bytes_); + ((char *)bufP_)[bytes_] = 0; + delete[] chunkP; - } - else - { + } else { // ----------------------------------------------------------------------------- // Read given # of bytes from the stream // ----------------------------------------------------------------------------- bytes_ = bytes; - bufP_ = new char[bytes_+1]; - NTA_CHECK (bufP_ != nullptr) << "MemParser::MemParser() - out of memory"; - - in.read ((char*)bufP_, bytes); - ((char*)bufP_)[bytes] = 0; - NTA_CHECK (in.good()) << "MemParser::MemParser() - error reading data from stream"; + bufP_ = new char[bytes_ + 1]; + NTA_CHECK(bufP_ != nullptr) << "MemParser::MemParser() - out of memory"; + + in.read((char *)bufP_, bytes); + ((char *)bufP_)[bytes] = 0; + NTA_CHECK(in.good()) + << "MemParser::MemParser() - error reading data from stream"; } // Setup start and end pointers startP_ = bufP_; - endP_ = startP_ + bytes_; -} + endP_ = startP_ + bytes_; +} //////////////////////////////////////////////////////////////////////////// // Destructor ////////////////////////////////////////////////////////////////////////////// -MemParser::~MemParser() -{ - delete[] bufP_; -} - +MemParser::~MemParser() { delete[] bufP_; } //////////////////////////////////////////////////////////////////////////// // Read an unsigned integer number out 
////////////////////////////////////////////////////////////////////////////// -void MemParser::get(unsigned long& val) -{ - const char* prefix = "MemParser::get(unsigned long&) - "; - char* endP; - - NTA_CHECK (startP_ < endP_) << prefix << "EOF"; - - val = ::strtoul (startP_, &endP, 0); - - NTA_CHECK (endP != startP_ && endP <= endP_) << prefix - << "parse error, not a valid integer"; - +void MemParser::get(unsigned long &val) { + const char *prefix = "MemParser::get(unsigned long&) - "; + char *endP; + + NTA_CHECK(startP_ < endP_) << prefix << "EOF"; + + val = ::strtoul(startP_, &endP, 0); + + NTA_CHECK(endP != startP_ && endP <= endP_) + << prefix << "parse error, not a valid integer"; + startP_ = endP; -} +} //////////////////////////////////////////////////////////////////////////// // Read an unsigned long long number out ////////////////////////////////////////////////////////////////////////////// -void MemParser::get(unsigned long long& val) -{ - const char* prefix = "MemParser::get(unsigned long long&) - "; - char* endP; - - NTA_CHECK (startP_ < endP_) << prefix << "EOF"; - - val = ::strtoul (startP_, &endP, 0); - - NTA_CHECK (endP != startP_ && endP <= endP_) << prefix - << "parse error, not a valid integer"; - +void MemParser::get(unsigned long long &val) { + const char *prefix = "MemParser::get(unsigned long long&) - "; + char *endP; + + NTA_CHECK(startP_ < endP_) << prefix << "EOF"; + + val = ::strtoul(startP_, &endP, 0); + + NTA_CHECK(endP != startP_ && endP <= endP_) + << prefix << "parse error, not a valid integer"; + startP_ = endP; -} +} //////////////////////////////////////////////////////////////////////////// // Read an signed integer number out ////////////////////////////////////////////////////////////////////////////// -void MemParser::get(long& val) -{ - const char* prefix = "MemParser::get(long&) - "; - char* endP; - - NTA_CHECK (startP_ < endP_) << prefix << "EOF"; - - val = ::strtol (startP_, &endP, 0); - - NTA_CHECK (endP != startP_ && endP <= endP_) << prefix - << "parse error, not a valid integer"; - +void MemParser::get(long &val) { + const char *prefix = "MemParser::get(long&) - "; + char *endP; + + NTA_CHECK(startP_ < endP_) << prefix << "EOF"; + + val = ::strtol(startP_, &endP, 0); + + NTA_CHECK(endP != startP_ && endP <= endP_) + << prefix << "parse error, not a valid integer"; + startP_ = endP; -} +} //////////////////////////////////////////////////////////////////////////// // Read a double-precision float out ////////////////////////////////////////////////////////////////////////////// -void MemParser::get(double& val) -{ - const char* prefix = "MemParser::get(double&) - "; - char* endP; - - NTA_CHECK (startP_ < endP_) << prefix << "EOF"; - - val = ::strtod (startP_, &endP); - - NTA_CHECK (endP != startP_ && endP <= endP_) << prefix - << "parse error, not a valid floating point value"; - +void MemParser::get(double &val) { + const char *prefix = "MemParser::get(double&) - "; + char *endP; + + NTA_CHECK(startP_ < endP_) << prefix << "EOF"; + + val = ::strtod(startP_, &endP); + + NTA_CHECK(endP != startP_ && endP <= endP_) + << prefix << "parse error, not a valid floating point value"; + startP_ = endP; -} +} //////////////////////////////////////////////////////////////////////////// // Read a single-precision float out ////////////////////////////////////////////////////////////////////////////// -void MemParser::get(float& val) -{ +void MemParser::get(float &val) { double f; get(f); val = (float)f; -} - +} 
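The numeric MemParser::get() overloads above all share one pattern: call strtoul/strtol/strtod at the current position, check that the returned end pointer both advanced and stayed inside the buffer, then make it the new position. A standalone sketch of that pattern, using only the C standard library (the buffer contents and error handling here are illustrative, not nupic code):

// Sketch of the parse-and-advance pattern used by MemParser::get().
#include <cstdlib>
#include <iostream>

int main() {
  const char buf[] = "42 3.14 7";
  const char *startP = buf;
  const char *endOfBuf = buf + sizeof(buf) - 1; // points at the terminating NUL

  char *endP = nullptr;
  unsigned long u = std::strtoul(startP, &endP, 0);
  if (endP == startP || endP > endOfBuf) {
    std::cerr << "parse error, not a valid integer" << std::endl;
    return 1;
  }
  startP = endP; // advance past the parsed token, as get() does

  double d = std::strtod(startP, &endP);
  if (endP == startP || endP > endOfBuf) {
    std::cerr << "parse error, not a valid floating point value" << std::endl;
    return 1;
  }
  startP = endP;

  std::cout << u << " and " << d << std::endl; // prints "42 and 3.14"
  return 0;
}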
//////////////////////////////////////////////////////////////////////////// // Read string out ////////////////////////////////////////////////////////////////////////////// -void MemParser::get(std::string& val) -{ - const char* prefix = "MemParser::get(string&) - "; - +void MemParser::get(std::string &val) { + const char *prefix = "MemParser::get(string&) - "; + // First, skip leading white space - const char* cP = startP_; + const char *cP = startP_; while (cP < endP_) { char c = *cP; if (c != 0 && c != ' ' && c != '\t' && c != '\n' && c != '\r') break; cP++; } - NTA_CHECK (cP < endP_) << prefix << "EOF"; - - size_t len = strcspn(cP, " \t\n\r"); - NTA_CHECK (len > 0) << prefix - << "parse error, not a valid string"; - - val.assign (cP, len); - startP_ = cP + len; -} + NTA_CHECK(cP < endP_) << prefix << "EOF"; + size_t len = strcspn(cP, " \t\n\r"); + NTA_CHECK(len > 0) << prefix << "parse error, not a valid string"; + val.assign(cP, len); + startP_ = cP + len; +} - } // namespace nupic diff --git a/src/nupic/ntypes/MemParser.hpp b/src/nupic/ntypes/MemParser.hpp index f9b721a030..48644c7924 100644 --- a/src/nupic/ntypes/MemParser.hpp +++ b/src/nupic/ntypes/MemParser.hpp @@ -25,46 +25,44 @@ #ifndef NTA_MEM_PARSER2_HPP #define NTA_MEM_PARSER2_HPP -#include #include +#include #include "nupic/types/Types.hpp" - namespace nupic { - + //////////////////////////////////////////////////////////////////////////// -/// Class for parsing numbers and strings out of a memory buffer. +/// Class for parsing numbers and strings out of a memory buffer. /// /// This provides a significant performance advantage over using the standard -/// C++ stream input operators operating on a stringstream in memory. +/// C++ stream input operators operating on a stringstream in memory. /// /// @b Responsibility /// - provide high level parsing functions for extracing numbers and strings /// from a memory buffer /// /// @b Resources/Ownerships: -/// - Owns a memory buffer that it allocates in it's constructor. +/// - Owns a memory buffer that it allocates in it's constructor. /// /// @b Notes: -/// To use this class, you pass in an input stream and a total # of bytes to the -/// constructor. The constructor will then read that number of bytes from the stream -/// into an internal buffer maintained by the MemParser object. Subsequent calls to -/// MemParser::get() will extract numbers/strings from the internal buffer. +/// To use this class, you pass in an input stream and a total # of bytes to +/// the constructor. The constructor will then read that number of bytes from +/// the stream into an internal buffer maintained by the MemParser object. +/// Subsequent calls to MemParser::get() will extract numbers/strings from the +/// internal buffer. /// ////////////////////////////////////////////////////////////////////////////// -class MemParser -{ +class MemParser { public: - ///////////////////////////////////////////////////////////////////////////////////// /// Constructor /// /// @param in The input stream to get characters from. - /// @param bytes The number of bytes to extract from the stream for parsing. - /// 0 means extract all bytes + /// @param bytes The number of bytes to extract from the stream for parsing. 
+ /// 0 means extract all bytes /////////////////////////////////////////////////////////////////////////////////// - MemParser(std::istream& in, UInt32 bytes=0); + MemParser(std::istream &in, UInt32 bytes = 0); ///////////////////////////////////////////////////////////////////////////////////// /// Destructor @@ -72,112 +70,99 @@ class MemParser /// Free the MemParser object /////////////////////////////////////////////////////////////////////////////////// virtual ~MemParser(); - + ///////////////////////////////////////////////////////////////////////////////////// /// Read an unsigned integer out of the stream /// /////////////////////////////////////////////////////////////////////////////////// - void get(unsigned long& val); + void get(unsigned long &val); ///////////////////////////////////////////////////////////////////////////////////// /// Read an unsigned long long out of the stream /// /////////////////////////////////////////////////////////////////////////////////// - void get(unsigned long long& val); - + void get(unsigned long long &val); + ///////////////////////////////////////////////////////////////////////////////////// /// Read an signed integer out of the stream /// /////////////////////////////////////////////////////////////////////////////////// - void get(long& val); - + void get(long &val); + ///////////////////////////////////////////////////////////////////////////////////// /// Read a double precision floating point number out of the stream /// /////////////////////////////////////////////////////////////////////////////////// - void get(double& val); - + void get(double &val); + ///////////////////////////////////////////////////////////////////////////////////// /// Read a double precision floating point number out of the stream /// /////////////////////////////////////////////////////////////////////////////////// - void get(float& val); + void get(float &val); ///////////////////////////////////////////////////////////////////////////////////// /// Read a string out of the stream /// /////////////////////////////////////////////////////////////////////////////////// - void get(std::string& val); - - + void get(std::string &val); + ///////////////////////////////////////////////////////////////////////////////////// /// >> operator's /////////////////////////////////////////////////////////////////////////////////// - friend MemParser& operator>>(MemParser& in, unsigned long& val) - { + friend MemParser &operator>>(MemParser &in, unsigned long &val) { in.get(val); return in; } - friend MemParser& operator>>(MemParser& in, unsigned long long& val) - { + friend MemParser &operator>>(MemParser &in, unsigned long long &val) { in.get(val); return in; } - friend MemParser& operator>>(MemParser& in, long& val) - { + friend MemParser &operator>>(MemParser &in, long &val) { in.get(val); return in; } - friend MemParser& operator>>(MemParser& in, unsigned int& val) - { + friend MemParser &operator>>(MemParser &in, unsigned int &val) { unsigned long lval; in.get(lval); val = lval; return in; } - friend MemParser& operator>>(MemParser& in, int& val) - { + friend MemParser &operator>>(MemParser &in, int &val) { long lval; in.get(lval); val = lval; return in; } - friend MemParser& operator>>(MemParser& in, double& val) - { + friend MemParser &operator>>(MemParser &in, double &val) { in.get(val); return in; } - friend MemParser& operator>>(MemParser& in, float& val) - { + friend MemParser &operator>>(MemParser &in, float &val) { in.get(val); return in; } - friend 
MemParser& operator>>(MemParser& in, std::string& val) - { + friend MemParser &operator>>(MemParser &in, std::string &val) { in.get(val); return in; } - private: - std::string str_; - const char* bufP_; - UInt32 bytes_; - - const char* startP_; - const char* endP_; + std::string str_; + const char *bufP_; + UInt32 bytes_; + const char *startP_; + const char *endP_; }; - - } // namespace nupic #endif // NTA_MEM_PARSER2_HPP diff --git a/src/nupic/ntypes/MemStream.hpp b/src/nupic/ntypes/MemStream.hpp index 0e90166e60..c91ddd3281 100644 --- a/src/nupic/ntypes/MemStream.hpp +++ b/src/nupic/ntypes/MemStream.hpp @@ -20,27 +20,26 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definitions for the MemStream classes - * + * * These classes implement a stream that uses memory buffer for reading/writing. - * It is more efficient than using stringstream because it doesn't require making a - * copy of the data when setting up an input stream for reading or getting the - * contents of an output stream after it has been written to. + * It is more efficient than using stringstream because it doesn't require + * making a copy of the data when setting up an input stream for reading or + * getting the contents of an output stream after it has been written to. * - * These classes operate much like the c++ std deprecated strstream class, and have the same - * member functions for getting to the buffered data (str(), pcount()) to make them - * drop-in replacements. + * These classes operate much like the c++ std deprecated strstream class, and + * have the same member functions for getting to the buffered data (str(), + * pcount()) to make them drop-in replacements. */ - #ifndef NTA_MEM_STREAM_HPP #define NTA_MEM_STREAM_HPP -#include #include //EOF -#include #include +#include +#include namespace nupic { @@ -50,309 +49,295 @@ void dbgbreak(); /// BasicIMemStreamBuf /// /// @b Responsibility -/// -/// The basic input stream buffer used by BasicIMemStream. -/// +/// +/// The basic input stream buffer used by BasicIMemStream. +/// /// @b Description /// -/// This class simply sets up the buffer for the streambuf to be the caller's buffer. -/// +/// This class simply sets up the buffer for the streambuf to be the caller's +/// buffer. +/// /// @b Resource @b Ownership /// -/// None. This class does not take ownership of the input buffer. It is the caller's -/// responsibility to free it after this class is destroyed. +/// None. This class does not take ownership of the input buffer. It is the +/// caller's +/// responsibility to free it after this class is destroyed. 
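With the full MemParser interface now visible, a plausible usage sketch looks like this; the sample input and variable names are invented, and it relies on the header's default bytes = 0 behaviour of reading the whole stream into the internal buffer.

// Hypothetical MemParser usage, based only on the interface shown above.
#include <iostream>
#include <sstream>
#include <string>

#include "nupic/ntypes/MemParser.hpp"

int main() {
  std::istringstream in("128 0.25 hello");

  // bytes == 0 (the default) makes the constructor slurp the entire stream.
  nupic::MemParser parser(in);

  unsigned long count = 0;
  double ratio = 0.0;
  std::string tag;
  parser >> count >> ratio >> tag; // each >> forwards to a get() overload

  std::cout << count << " " << ratio << " " << tag << std::endl;
  return 0;
}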
/// ///////////////////////////////////////////////////////////////////////////////////// -template -class BasicIMemStreamBuf : public std::basic_streambuf -{ +template +class BasicIMemStreamBuf : public std::basic_streambuf { public: - BasicIMemStreamBuf (const charT* bufP, size_t bufSize) - { - setg ((charT*)bufP, (charT*)bufP, (charT*)bufP+bufSize); + BasicIMemStreamBuf(const charT *bufP, size_t bufSize) { + setg((charT *)bufP, (charT *)bufP, (charT *)bufP + bufSize); bufP_ = bufP; size_ = bufSize; } - + ////////////////////////////////////////////////////////////////////////// /// Set the input buffer for this stream ///////////////////////////////////////////////////////////////////////// - void str(const charT* bufP, size_t bufSize) - { - setg ((charT*)bufP, (charT*)bufP, (charT*)bufP+bufSize); + void str(const charT *bufP, size_t bufSize) { + setg((charT *)bufP, (charT *)bufP, (charT *)bufP + bufSize); bufP_ = bufP; size_ = bufSize; } - + /////////////////////////////////////////////////////////////////////////////////// /// Return a pointer to beginning of the memory stream buffer /// /// @retval pointer to input buffer maintained by this class //////////////////////////////////////////////////////////////////////////////////// - const charT* str() {return bufP_;} + const charT *str() { return bufP_; } /////////////////////////////////////////////////////////////////////////////////// /// Return size of the input data /// /// @retval size of the input data in the buffer //////////////////////////////////////////////////////////////////////////////////// - size_t pcount() {return size_;} + size_t pcount() { return size_; } /////////////////////////////////////////////////////////////////////////////////// /// Return size of the input data /// /// @retval size of the input data in the buffer //////////////////////////////////////////////////////////////////////////////////// - void setg(charT * p1, charT * p2, charT * p3) {std::basic_streambuf::setg(p1,p2,p3);} -private: - const charT* bufP_; - size_t size_; - -}; // end class BasicIMemStreamBuf + void setg(charT *p1, charT *p2, charT *p3) { + std::basic_streambuf::setg(p1, p2, p3); + } +private: + const charT *bufP_; + size_t size_; +}; // end class BasicIMemStreamBuf /////////////////////////////////////////////////////////////////////////////////////// /// BasicOMemStreamBuf /// /// @b Responsibility -/// +/// /// The basic output stream buffer used by BasicOMemStream -/// +/// /// @b Description /// /// This class uses an internal basic_string to manage the storage of characters -/// written to the streambuf. It sets the streambuf's buffer to be the capacity of -/// the string and grows the capacity of the string when overflow() is called. -/// +/// written to the streambuf. It sets the streambuf's buffer to be the capacity +/// of the string and grows the capacity of the string when overflow() is +/// called. +/// /// @b Resource @b Ownership /// -/// The internal buffer used to hold the stream data. +/// The internal buffer used to hold the stream data. 
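The essential trick in BasicIMemStreamBuf above is the setg() call, which points the streambuf's get area directly at a caller-owned buffer so that reads involve no copy. A minimal standalone illustration of that technique (ZeroCopyInBuf and the demo data are invented; the real class is a template over the character and traits types):

// Minimal illustration of the setg()-over-an-existing-buffer technique.
#include <cstddef>
#include <iostream>
#include <streambuf>
#include <string>

class ZeroCopyInBuf : public std::streambuf {
public:
  ZeroCopyInBuf(const char *bufP, size_t bufSize) {
    // streambuf wants non-const pointers; the buffer is never written to.
    char *p = const_cast<char *>(bufP);
    setg(p, p, p + bufSize); // begin, current, end of the get area
  }
};

int main() {
  const char data[] = "3 blind mice";
  ZeroCopyInBuf sbuf(data, sizeof(data) - 1);
  std::istream in(&sbuf); // reads come straight from 'data'; no copy is made

  int n = 0;
  std::string w1, w2;
  in >> n >> w1 >> w2;
  std::cout << n << " " << w1 << " " << w2 << std::endl; // "3 blind mice"
  return 0;
}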
/// ///////////////////////////////////////////////////////////////////////////////////// -template -class BasicOMemStreamBuf : public std::basic_streambuf -{ -public: - typedef std::basic_string stringType_; - typedef typename traitsT::int_type int_type; - +template +class BasicOMemStreamBuf : public std::basic_streambuf { +public: + typedef std::basic_string stringType_; + typedef typename traitsT::int_type int_type; + private: stringType_ data_; static const size_t growByMin_ = 512; public: - BasicOMemStreamBuf () - { - data_.reserve (growByMin_); - char* bufP = (char*)(data_.data()); - this->setp (bufP, bufP + data_.capacity()); + BasicOMemStreamBuf() { + data_.reserve(growByMin_); + char *bufP = (char *)(data_.data()); + this->setp(bufP, bufP + data_.capacity()); } - - virtual int_type overflow (int_type c) - { - size_t growBy; - - if (c == EOF) return c; - + + virtual int_type overflow(int_type c) { + size_t growBy; + + if (c == EOF) + return c; + // Remember the current size size_t curSize = pcount(); - + // Grow by 12% (1/8th) of the current size, or the minGrowBy growBy = curSize >> 3; if (growBy < growByMin_) growBy = growByMin_; - + // Grow the buffer. We do this by allocating a tmp string of the new size // and using assign() to transfer the existing data into it. We can't simply - // call reserve() on the existing string because it's size() is not set and because - // of this, reserve() won't copy existing characters into the newly grown buffer for us. + // call reserve() on the existing string because it's size() is not set and + // because of this, reserve() won't copy existing characters into the newly + // grown buffer for us. stringType_ tmp; try { - tmp.reserve (data_.capacity() + growBy); + tmp.reserve(data_.capacity() + growBy); } catch (...) { - NTA_THROW << "MemStream::write() - request of " << data_.capacity() + growBy - << " bytes (" << data_.capacity() + growBy + NTA_THROW << "MemStream::write() - request of " + << data_.capacity() + growBy << " bytes (" + << data_.capacity() + growBy << ") exceeds the maximum allowable memory block size."; - // Note: above used std::hex but Visual C++ Express can't handle that for some reason + // Note: above used std::hex but Visual C++ Express can't handle that for + // some reason } - tmp.assign (data_.data(), curSize); - data_.swap (tmp); - - char* bufP = (char*)(data_.data()); - this->setp (bufP, bufP + data_.capacity()); - + tmp.assign(data_.data(), curSize); + data_.swap(tmp); + + char *bufP = (char *)(data_.data()); + this->setp(bufP, bufP + data_.capacity()); + // Restore the current write position after setting the new buffer - this->pbump ((int)curSize); + this->pbump((int)curSize); // Store the new character in the bigger buffer - return this->sputc (c); + return this->sputc(c); } - + /////////////////////////////////////////////////////////////////////////////////// /// Return a pointer to the output data /// - /// This call does not transfer ownership of the output buffer to the caller. The - /// output buffer pointer is only valid until the next write operation to the stream. + /// This call does not transfer ownership of the output buffer to the caller. + /// The output buffer pointer is only valid until the next write operation to + /// the stream. 
/// /// @retval pointer to output buffer maintained by this class //////////////////////////////////////////////////////////////////////////////////// - inline const charT* str() {return data_.data();} + inline const charT *str() { return data_.data(); } /////////////////////////////////////////////////////////////////////////////////// /// Return size of the output data /// /// @retval size of the output data in the buffer //////////////////////////////////////////////////////////////////////////////////// - inline size_t pcount() {return this->pptr() - this->pbase(); } + inline size_t pcount() { return this->pptr() - this->pbase(); } }; // end class BasicOMemStreamBuf - - /////////////////////////////////////////////////////////////////////////////////////// /// BasicIMemStream /// /// @b Responsibility -/// -/// An input stream which allows the caller to specify which data buffer to use. -/// +/// +/// An input stream which allows the caller to specify which data buffer to use. +/// /// @b Description /// /// The caller constructs the input stream by passing in a buffer and size. All -/// input operations from the stream then extract data from this buffer. The stream -/// does *not* take ownership of the buffer. It is the caller's responsibility to free -/// the buffer after deleting this class. -/// +/// input operations from the stream then extract data from this buffer. The +/// stream does *not* take ownership of the buffer. It is the caller's +/// responsibility to free the buffer after deleting this class. +/// /// @b Resource @b Ownership /// /// This class does not take ownership of the input buffer. It is the caller's -/// responsibility to free it after the class is destroyed. +/// responsibility to free it after the class is destroyed. 
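BasicOMemStreamBuf above grows its storage inside overflow(), reserving a larger string and re-establishing the put area with setp()/pbump(). The simplified sketch below keeps only the customization point: with no put area set up, every character is routed through overflow() into a std::string, so it shows where the growth hook lives, not the actual reserve/assign growth policy used above. GrowingOutBuf and the demo output are invented names.

// Simplified stand-in for the overflow()-based growth idea.
#include <iostream>
#include <streambuf>
#include <string>

class GrowingOutBuf : public std::streambuf {
public:
  const char *str() const { return data_.data(); }
  size_t count() const { return data_.size(); }

protected:
  // No put area is installed, so every character arrives here; std::string
  // does the growing instead of the manual reserve/assign dance above.
  int_type overflow(int_type ch) override {
    if (ch == traits_type::eof())
      return ch;
    data_.push_back(static_cast<char>(ch));
    return ch;
  }

private:
  std::string data_;
};

int main() {
  GrowingOutBuf buf;
  std::ostream out(&buf);
  out << "total = " << 42;
  std::cout << std::string(buf.str(), buf.count()) << std::endl; // "total = 42"
  return 0;
}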
/// ///////////////////////////////////////////////////////////////////////////////////// -template -class BasicIMemStream : public std::basic_istream -{ -public: - typedef BasicIMemStreamBuf memStreamBufType_; - +template +class BasicIMemStream : public std::basic_istream { +public: + typedef BasicIMemStreamBuf memStreamBufType_; + private: - memStreamBufType_ streamBuf_; - + memStreamBufType_ streamBuf_; + public: - BasicIMemStream(const charT* bufP=0, size_t bufSize=0) : std::istream(&streamBuf_), - streamBuf_(bufP, bufSize) - { - this->rdbuf (&streamBuf_); + BasicIMemStream(const charT *bufP = 0, size_t bufSize = 0) + : std::istream(&streamBuf_), streamBuf_(bufP, bufSize) { + this->rdbuf(&streamBuf_); } - + ////////////////////////////////////////////////////////////////////////// /// Set the input buffer for this stream ///////////////////////////////////////////////////////////////////////// - void str(const charT* bufP, size_t bufSize) - { - streamBuf_.str (bufP, bufSize); - } - + void str(const charT *bufP, size_t bufSize) { streamBuf_.str(bufP, bufSize); } + /////////////////////////////////////////////////////////////////////////////////// /// Return a pointer to beginning of the memory stream buffer /// /// @retval pointer to input buffer maintained by this class //////////////////////////////////////////////////////////////////////////////////// - const charT* str() {return streamBuf_.str();} + const charT *str() { return streamBuf_.str(); } /////////////////////////////////////////////////////////////////////////////////// /// Return size of the input data /// /// @retval size of the input data in the buffer //////////////////////////////////////////////////////////////////////////////////// - size_t pcount() {return streamBuf_.pcount();} - -}; // end class BasicIMemStream + size_t pcount() { return streamBuf_.pcount(); } +}; // end class BasicIMemStream /////////////////////////////////////////////////////////////////////////////////////// /// BasicOMemStream /// /// @b Responsibility -/// -/// An output stream that appends data to an internal dynamically grown buffer. -/// +/// +/// An output stream that appends data to an internal dynamically grown buffer. +/// /// @b Description /// -/// At any time, the caller can get a pointer to the internal buffer and it's current size -/// through the str() and pcount() member functions. This information is valid until the -/// next write operation to the stream. -/// +/// At any time, the caller can get a pointer to the internal buffer and it's +/// current size through the str() and pcount() member functions. This +/// information is valid until the next write operation to the stream. +/// /// @b Resource @b Ownership /// -/// The internal buffer used to hold the stream data. +/// The internal buffer used to hold the stream data. 
/// ///////////////////////////////////////////////////////////////////////////////////// -template -class BasicOMemStream : public std::basic_ostream -{ -public: - typedef BasicOMemStreamBuf memStreamBufType_; - +template +class BasicOMemStream : public std::basic_ostream { +public: + typedef BasicOMemStreamBuf memStreamBufType_; + private: - memStreamBufType_ streamBuf_; - + memStreamBufType_ streamBuf_; + public: - BasicOMemStream() : std::ostream(&streamBuf_), - streamBuf_() - { - this->rdbuf (&streamBuf_); + BasicOMemStream() : std::ostream(&streamBuf_), streamBuf_() { + this->rdbuf(&streamBuf_); } - + /////////////////////////////////////////////////////////////////////////////////// - /// freeze - does nothing in this class, provided only so this class can be a + /// freeze - does nothing in this class, provided only so this class can be a /// drop-in replacement for strstream //////////////////////////////////////////////////////////////////////////////////// - void freeze (bool f) {} + void freeze(bool f) {} /////////////////////////////////////////////////////////////////////////////////// /// Return a pointer to the output data /// - /// This call does not transfer ownership of the output buffer to the caller. The - /// output buffer pointer is only valid until the next write operation to the stream. + /// This call does not transfer ownership of the output buffer to the caller. + /// The output buffer pointer is only valid until the next write operation to + /// the stream. /// /// @retval pointer to output buffer maintained by this class //////////////////////////////////////////////////////////////////////////////////// - const charT* str() {return streamBuf_.str();} + const charT *str() { return streamBuf_.str(); } /////////////////////////////////////////////////////////////////////////////////// /// Return size of the output data /// /// @retval size of the output data in the buffer //////////////////////////////////////////////////////////////////////////////////// - size_t pcount() {return streamBuf_.pcount();} - -}; // end class BasicOMemStream + size_t pcount() { return streamBuf_.pcount(); } +}; // end class BasicOMemStream /////////////////////////////////////////////////////////////////////////////////////// // Convenience typedefs /////////////////////////////////////////////////////////////////////////////////////// -typedef BasicIMemStream,std::allocator > - IMemStream; -typedef BasicIMemStream,std::allocator > - WIMemStream; - -typedef BasicOMemStream,std::allocator > - OMemStream; -typedef BasicOMemStream,std::allocator > - WOMemStream; - +typedef BasicIMemStream, std::allocator> + IMemStream; +typedef BasicIMemStream, + std::allocator> + WIMemStream; + +typedef BasicOMemStream, std::allocator> + OMemStream; +typedef BasicOMemStream, + std::allocator> + WOMemStream; } // end namespace nupic - - - - #endif // NTA_MEM_STREAM_HPP - - - diff --git a/src/nupic/ntypes/NodeSet.hpp b/src/nupic/ntypes/NodeSet.hpp index c97ccba657..4216134859 100644 --- a/src/nupic/ntypes/NodeSet.hpp +++ b/src/nupic/ntypes/NodeSet.hpp @@ -27,86 +27,66 @@ #include -namespace nupic -{ - /** - * A NodeSet represents the set of currently-enabled nodes in a Region - * It is just a set of indexes, with the ability to add/remove an index, and the - * ability to iterate through enabled nodes. - * - * There are many ways to represent such a set, and the best way to represent - * it depends on what nodes are typically enabled through this mechanism. 
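Putting the IMemStream/OMemStream typedefs above together, the intended zero-copy round trip is roughly the following; this is a hypothetical usage sketch based on the str()/pcount() members documented in this header, not code taken from the repository.

// Hypothetical round trip through OMemStream / IMemStream (see MemStream.hpp).
#include <iostream>
#include <string>

#include "nupic/ntypes/MemStream.hpp"

int main() {
  // Write into the internally managed, growable buffer.
  nupic::OMemStream out;
  out << 12345 << " " << 2.5 << " hello";

  // str()/pcount() expose that buffer directly; the pointer is only valid
  // until the next write, and ownership stays with the stream.
  const char *dataP = out.str();
  size_t nbytes = out.pcount();

  // Read it back without copying: IMemStream just wraps the caller's buffer.
  nupic::IMemStream in(dataP, nbytes);
  int i = 0;
  double d = 0.0;
  std::string word;
  in >> i >> d >> word;

  std::cout << i << " " << d << " " << word << std::endl; // "12345 2.5 hello"
  return 0;
}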
- * In NuPIC 1 we used an IndexRangeList, which is a list of index ranges. - * This is natural, for example, in the original pictures app, where in - * training level N+1 we would enable a square patch of nodes at level N. - * (which is a list of ranges). In the NuPIC 1 API such ranges were initially - * specified with ranges ("1-4"). With new algorithms and new training paradigms - * I think we may always enable nodes individually. - * - * So for NuPIC 2 we're starting with the simplest possible solution (a set) and - * might switch to something else (e.g. a range list) if needed. - * - * TODO: split into hpp/cpp - */ - class NodeSet - { - public: - NodeSet(size_t nnodes) : nnodes_(nnodes) - { - set_.clear(); - } - - typedef std::set::const_iterator const_iterator; - - const_iterator begin() const - { - return set_.begin(); - }; +namespace nupic { +/** + * A NodeSet represents the set of currently-enabled nodes in a Region + * It is just a set of indexes, with the ability to add/remove an index, and the + * ability to iterate through enabled nodes. + * + * There are many ways to represent such a set, and the best way to represent + * it depends on what nodes are typically enabled through this mechanism. + * In NuPIC 1 we used an IndexRangeList, which is a list of index ranges. + * This is natural, for example, in the original pictures app, where in + * training level N+1 we would enable a square patch of nodes at level N. + * (which is a list of ranges). In the NuPIC 1 API such ranges were initially + * specified with ranges ("1-4"). With new algorithms and new training paradigms + * I think we may always enable nodes individually. + * + * So for NuPIC 2 we're starting with the simplest possible solution (a set) and + * might switch to something else (e.g. a range list) if needed. 
+ * + * TODO: split into hpp/cpp + */ +class NodeSet { +public: + NodeSet(size_t nnodes) : nnodes_(nnodes) { set_.clear(); } - const_iterator end() const - { - return set_.end(); - } - - void allOn() - { - for (size_t i = 0; i < nnodes_; i++) - { - set_.insert(i); - } - } + typedef std::set::const_iterator const_iterator; - void allOff() - { - set_.clear(); - } + const_iterator begin() const { return set_.begin(); }; + + const_iterator end() const { return set_.end(); } - void add(size_t index) - { - if (index > nnodes_) - { - NTA_THROW << "Attempt to enable node with index " << index << " which is larger than the number of nodes " << nnodes_; - } - set_.insert(index); + void allOn() { + for (size_t i = 0; i < nnodes_; i++) { + set_.insert(i); } + } - void remove(size_t index) - { - auto f = set_.find(index); - if (f == set_.end()) - return; - set_.erase(f); + void allOff() { set_.clear(); } + + void add(size_t index) { + if (index > nnodes_) { + NTA_THROW << "Attempt to enable node with index " << index + << " which is larger than the number of nodes " << nnodes_; } + set_.insert(index); + } - private: - typedef std::set::iterator iterator; - NodeSet(); - size_t nnodes_; - std::set set_; - }; - -} // namespace nupic + void remove(size_t index) { + auto f = set_.find(index); + if (f == set_.end()) + return; + set_.erase(f); + } +private: + typedef std::set::iterator iterator; + NodeSet(); + size_t nnodes_; + std::set set_; +}; +} // namespace nupic #endif // NTA_NODESET_HPP diff --git a/src/nupic/ntypes/ObjectModel.h b/src/nupic/ntypes/ObjectModel.h index 0a6c677f95..ee64dc3cfd 100644 --- a/src/nupic/ntypes/ObjectModel.h +++ b/src/nupic/ntypes/ObjectModel.h @@ -20,18 +20,17 @@ * ---------------------------------------------------------------------- */ -/** @file -* -* -* -*/ +/** @file + * + * + * + */ /* * TEMPORARY for NUPIC 2 development * Included because included by object_model.hpp */ - #ifndef NTA_OBJECT_MODEL_H #define NTA_OBJECT_MODEL_H @@ -49,37 +48,40 @@ extern "C" { * 1. Defines the C API for the runtime engine object model. * * @b Rationale: - * The plugin API supports C as a lowest common denominator. Plugins that publish a C API - * need to access the runtime services through a C API. This is it. + * The plugin API supports C as a lowest common denominator. Plugins that + * publish a C API need to access the runtime services through a C API. This is + * it. * * @b Resource/Ownerships: * None. just an interface * * @b General API information - * - * An important goal of the C object model is to immitate the C++ object model - * (nupic/plugin/object_model.hpp). The C++ object model is implemented as a set of - * interfaces (pure abstract classes). The C object model has a one to one mapping - * to each interface. Each C interface consists of a struct with function - * pointers that correspond to the virtual table of the C++ interface. - * The implicit this pointer of a C++ interface is emulated with - * an explicit opaque handle that you must pass to every C function in the + * + * An important goal of the C object model is to immitate the C++ object model + * (nupic/plugin/object_model.hpp). The C++ object model is implemented as a + * set of interfaces (pure abstract classes). The C object model has a one to + * one mapping to each interface. Each C interface consists of a struct with + * function pointers that correspond to the virtual table of the C++ interface. 
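For completeness, a small hypothetical usage sketch of the NodeSet interface above; the region size and the enabled indexes are chosen arbitrarily.

// Hypothetical NodeSet usage; see src/nupic/ntypes/NodeSet.hpp above.
#include <iostream>

#include "nupic/ntypes/NodeSet.hpp"

int main() {
  nupic::NodeSet enabled(8); // a region with 8 nodes, none enabled initially

  enabled.add(2);
  enabled.add(5);
  enabled.remove(5); // removing an index that is absent is a no-op
  enabled.remove(7);

  for (auto it = enabled.begin(); it != enabled.end(); ++it)
    std::cout << *it << " "; // prints "2"
  std::cout << std::endl;

  enabled.allOn();  // enable every node: 0..7
  enabled.allOff(); // back to the empty set
  return 0;
}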
+ * The implicit this pointer of a C++ interface is emulated with + * an explicit opaque handle that you must pass to every C function in the * C object model API. * - * The reason the 1-1 mapping is so important is that the concrete runtime object model - * is composed of objects that implement both interfaces and expose a dual C/C++ facade. - * Having a 1-1 mapping between the interfaces allows reusing the same implementation - * with minimal forwarding inside each runtime object. + * The reason the 1-1 mapping is so important is that the concrete runtime + * object model is composed of objects that implement both interfaces and expose + * a dual C/C++ facade. Having a 1-1 mapping between the interfaces allows + * reusing the same implementation with minimal forwarding inside each runtime + * object. + * + * The naming convention of mapped interfaces is that the C++ + * nupic::I corresponds to the NTA_. The nupic + * namespace is translated to an NTA_ prefix and the I is dropped. * - * The naming convention of mapped interfaces is that the C++ nupic::I corresponds to the - * NTA_. The nupic namespace is translated to an NTA_ prefix and the I is dropped. - * * @b Invariants: - * + * * @b Notes: - * 1. There is a compatible C++ object model [object_model.hpp] that documents in detail - * every interface and struct. Please check out this file for detailed documentation on - * the corresponding C runtime object interface. + * 1. There is a compatible C++ object model [object_model.hpp] that documents + * in detail every interface and struct. Please check out this file for detailed + * documentation on the corresponding C runtime object interface. */ /** ---------------------- @@ -93,39 +95,44 @@ extern "C" { * know the internal format or you can read it as a byte * array. The internal representation is stringified * so it works properly on different platforms. 
- */ -typedef struct _NTA_ReadBufferHandle { char c; } * NTA_ReadBufferHandle; -typedef struct NTA_ReadBuffer -{ + */ +typedef struct _NTA_ReadBufferHandle { + char c; +} * NTA_ReadBufferHandle; +typedef struct NTA_ReadBuffer { /* functions */ void (*reset)(NTA_ReadBufferHandle handle); NTA_Size (*getSize)(NTA_ReadBufferHandle handle); - const NTA_Byte * (*getData)(NTA_ReadBufferHandle handle); - NTA_Int32 (*readByte)(NTA_ReadBufferHandle handle, NTA_Byte * value); - NTA_Int32 (*readByteArray)(NTA_ReadBufferHandle handle, NTA_Byte * value, NTA_Size * size); - NTA_Int32 (*readAsString)(NTA_ReadBufferHandle handle, - NTA_Byte ** value, - NTA_UInt32 * size, - NTA_Byte *(*fAlloc)(NTA_UInt32), - void (*fDealloc)(NTA_Byte *) - ); - - NTA_Int32 (*readInt32)(NTA_ReadBufferHandle handle, NTA_Int32 * value); - NTA_Int32 (*readInt32Array)(NTA_ReadBufferHandle handle, NTA_Int32 * value, NTA_Size size); - NTA_Int32 (*readUInt32)(NTA_ReadBufferHandle handle, NTA_UInt32 * value); - NTA_Int32 (*readUInt32Array)(NTA_ReadBufferHandle handle, NTA_UInt32 * value, NTA_Size size); - NTA_Int32 (*readInt64)(NTA_ReadBufferHandle handle, NTA_Int64 * value); - NTA_Int32 (*readInt64Array)(NTA_ReadBufferHandle handle, NTA_Int64 * value, NTA_Size size); - NTA_Int32 (*readUInt64)(NTA_ReadBufferHandle handle, NTA_UInt64 * value); - NTA_Int32 (*readUInt64Array)(NTA_ReadBufferHandle handle, NTA_UInt64 * value, NTA_Size size); - NTA_Int32 (*readReal32)(NTA_ReadBufferHandle handle, NTA_Real32 * value); - NTA_Int32 (*readReal32Array)(NTA_ReadBufferHandle handle, NTA_Real32 * value, NTA_Size size); - NTA_Int32 (*readReal64)(NTA_ReadBufferHandle handle, NTA_Real64 * value); - NTA_Int32 (*readReal64Array)(NTA_ReadBufferHandle handle, NTA_Real64 * value, NTA_Size size); - + const NTA_Byte *(*getData)(NTA_ReadBufferHandle handle); + NTA_Int32 (*readByte)(NTA_ReadBufferHandle handle, NTA_Byte *value); + NTA_Int32 (*readByteArray)(NTA_ReadBufferHandle handle, NTA_Byte *value, + NTA_Size *size); + NTA_Int32 (*readAsString)(NTA_ReadBufferHandle handle, NTA_Byte **value, + NTA_UInt32 *size, NTA_Byte *(*fAlloc)(NTA_UInt32), + void (*fDealloc)(NTA_Byte *)); + + NTA_Int32 (*readInt32)(NTA_ReadBufferHandle handle, NTA_Int32 *value); + NTA_Int32 (*readInt32Array)(NTA_ReadBufferHandle handle, NTA_Int32 *value, + NTA_Size size); + NTA_Int32 (*readUInt32)(NTA_ReadBufferHandle handle, NTA_UInt32 *value); + NTA_Int32 (*readUInt32Array)(NTA_ReadBufferHandle handle, NTA_UInt32 *value, + NTA_Size size); + NTA_Int32 (*readInt64)(NTA_ReadBufferHandle handle, NTA_Int64 *value); + NTA_Int32 (*readInt64Array)(NTA_ReadBufferHandle handle, NTA_Int64 *value, + NTA_Size size); + NTA_Int32 (*readUInt64)(NTA_ReadBufferHandle handle, NTA_UInt64 *value); + NTA_Int32 (*readUInt64Array)(NTA_ReadBufferHandle handle, NTA_UInt64 *value, + NTA_Size size); + NTA_Int32 (*readReal32)(NTA_ReadBufferHandle handle, NTA_Real32 *value); + NTA_Int32 (*readReal32Array)(NTA_ReadBufferHandle handle, NTA_Real32 *value, + NTA_Size size); + NTA_Int32 (*readReal64)(NTA_ReadBufferHandle handle, NTA_Real64 *value); + NTA_Int32 (*readReal64Array)(NTA_ReadBufferHandle handle, NTA_Real64 *value, + NTA_Size size); + /* data members */ NTA_ReadBufferHandle handle; - + } NTA_ReadBuffer; /** --------------------------------------- @@ -137,17 +144,19 @@ typedef struct NTA_ReadBuffer * This struct represents an iterator over a collection * of read buffers. . 
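The rationale comment above describes the shape shared by every NTA_* struct in this header: a block of function pointers standing in for a C++ vtable, plus an opaque handle passed back explicitly as the "this" argument. The toy example below illustrates only that calling convention; Counter, CounterImpl and the rest are invented and have nothing to do with the real NTA_ReadBuffer implementation.

// Toy illustration of the handle + function-pointer convention used above.
#include <iostream>

typedef struct _CounterHandle { char c; } *CounterHandle; // opaque handle type

typedef struct Counter {
  /* functions: the explicit 'this' is the handle argument */
  void (*increment)(CounterHandle handle);
  int (*get)(CounterHandle handle);

  /* data members */
  CounterHandle handle;
} Counter;

/* A concrete implementation hides its real state behind the handle. */
struct CounterImpl { int value; };

static void counterIncrement(CounterHandle h) {
  reinterpret_cast<CounterImpl *>(h)->value++;
}
static int counterGet(CounterHandle h) {
  return reinterpret_cast<CounterImpl *>(h)->value;
}

int main() {
  CounterImpl impl = {0};
  Counter c = {counterIncrement, counterGet,
               reinterpret_cast<CounterHandle>(&impl)};

  // Callers always pass the struct's own handle back in, C-API style.
  c.increment(c.handle);
  c.increment(c.handle);
  std::cout << c.get(c.handle) << std::endl; // prints 2
  return 0;
}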
* It has a next() function to get the next buffer in the colection - * and a reset() function that sets the internal pointer to the first range again. + * and a reset() function that sets the internal pointer to the first range + * again. */ -typedef struct _NTA_ReadBufferIteratorHandle { char c; } * NTA_ReadBufferIteratorHandle; -typedef struct NTA_ReadBufferIterator -{ - /* functions */ - void (*reset)(NTA_ReadBufferIteratorHandle handle); - const NTA_ReadBuffer * (*next)(NTA_ReadBufferIteratorHandle handle); - - /* data members */ - NTA_ReadBufferIteratorHandle handle; +typedef struct _NTA_ReadBufferIteratorHandle { + char c; +} * NTA_ReadBufferIteratorHandle; +typedef struct NTA_ReadBufferIterator { + /* functions */ + void (*reset)(NTA_ReadBufferIteratorHandle handle); + const NTA_ReadBuffer *(*next)(NTA_ReadBufferIteratorHandle handle); + + /* data members */ + NTA_ReadBufferIteratorHandle handle; } NTA_ReadBufferIterator; /** ------------------------ @@ -161,37 +170,46 @@ typedef struct NTA_ReadBufferIterator * know the internal format or you can write it as a byte * array. The internal representation is stringified * so it works properly on different platforms. - */ -typedef struct _NTA_WriteBufferHandle { char c; } * NTA_WriteBufferHandle; -typedef struct NTA_WriteBuffer -{ + */ +typedef struct _NTA_WriteBufferHandle { + char c; +} * NTA_WriteBufferHandle; +typedef struct NTA_WriteBuffer { /* functions */ NTA_Size (*getSize)(NTA_WriteBufferHandle handle); - const NTA_Byte * (*getData)(NTA_WriteBufferHandle handle); + const NTA_Byte *(*getData)(NTA_WriteBufferHandle handle); NTA_Int32 (*writeByte)(NTA_WriteBufferHandle handle, NTA_Byte value); - NTA_Int32 (*writeByteArray)(NTA_WriteBufferHandle handle, const NTA_Byte * value, NTA_Size size); - NTA_Int32 (*writeAsString)(NTA_WriteBufferHandle handle, const NTA_Byte * value, NTA_Size size); + NTA_Int32 (*writeByteArray)(NTA_WriteBufferHandle handle, + const NTA_Byte *value, NTA_Size size); + NTA_Int32 (*writeAsString)(NTA_WriteBufferHandle handle, + const NTA_Byte *value, NTA_Size size); NTA_Int32 (*writeInt32)(NTA_WriteBufferHandle handle, NTA_Int32 value); - NTA_Int32 (*writeInt32Array)(NTA_WriteBufferHandle handle, const NTA_Int32 * value, NTA_Size size); + NTA_Int32 (*writeInt32Array)(NTA_WriteBufferHandle handle, + const NTA_Int32 *value, NTA_Size size); NTA_Int32 (*writeUInt32)(NTA_WriteBufferHandle handle, NTA_UInt32 value); - NTA_Int32 (*writeUInt32Array)(NTA_WriteBufferHandle handle, const NTA_UInt32 * value, NTA_Size size); + NTA_Int32 (*writeUInt32Array)(NTA_WriteBufferHandle handle, + const NTA_UInt32 *value, NTA_Size size); NTA_Int32 (*writeInt64)(NTA_WriteBufferHandle handle, NTA_Int64 value); - NTA_Int32 (*writeInt64Array)(NTA_WriteBufferHandle handle, const NTA_Int64 * value, NTA_Size size); + NTA_Int32 (*writeInt64Array)(NTA_WriteBufferHandle handle, + const NTA_Int64 *value, NTA_Size size); NTA_Int32 (*writeUInt64)(NTA_WriteBufferHandle handle, NTA_UInt64 value); - NTA_Int32 (*writeUInt64Array)(NTA_WriteBufferHandle handle, const NTA_UInt64 * value, NTA_Size size); + NTA_Int32 (*writeUInt64Array)(NTA_WriteBufferHandle handle, + const NTA_UInt64 *value, NTA_Size size); NTA_Int32 (*writeReal32)(NTA_WriteBufferHandle handle, NTA_Real32 value); - NTA_Int32 (*writeReal32Array)(NTA_WriteBufferHandle handle, const NTA_Real32 * value, NTA_Size size); + NTA_Int32 (*writeReal32Array)(NTA_WriteBufferHandle handle, + const NTA_Real32 *value, NTA_Size size); NTA_Int32 (*writeReal64)(NTA_WriteBufferHandle handle, NTA_Real64 
value); - NTA_Int32 (*writeReal64Array)(NTA_WriteBufferHandle handle, const NTA_Real64 * value, NTA_Size size); - + NTA_Int32 (*writeReal64Array)(NTA_WriteBufferHandle handle, + const NTA_Real64 *value, NTA_Size size); + /* data members */ NTA_WriteBufferHandle handle; - + } NTA_WriteBuffer; /** ----------------------- * - * I N P U T R A N G E + * I N P U T R A N G E * * ------------------------ * @@ -201,44 +219,44 @@ typedef struct NTA_WriteBuffer * (elementSize) and number of elements in the range * (elementCount). */ -typedef struct _NTA_InputRangeHandle { char c; } * NTA_InputRangeHandle; -typedef struct NTA_InputRange -{ +typedef struct _NTA_InputRangeHandle { + char c; +} * NTA_InputRangeHandle; +typedef struct NTA_InputRange { /* functions */ - const NTA_Byte * (*begin)(NTA_InputRangeHandle handle); - const NTA_Byte * (*end)(NTA_InputRangeHandle handle); + const NTA_Byte *(*begin)(NTA_InputRangeHandle handle); + const NTA_Byte *(*end)(NTA_InputRangeHandle handle); NTA_Size (*getElementCount)(NTA_InputRangeHandle handle); NTA_Size (*getElementSize)(NTA_InputRangeHandle handle); /* data members */ NTA_InputRangeHandle handle; - -} NTA_InputRange; +} NTA_InputRange; /** ----------------------------------------- * - * I N P U T R A N G E M A P E N T R Y + * I N P U T R A N G E M A P E N T R Y * * ------------------------------------------ * * This struct represents a single entry in an input range map * It stores a name, a list of input ranges and the number * of input ranges in this entry - */ -typedef struct _NTA_InputRangeMapEntryHandle { char c; } * NTA_InputRangeMapEntryHandle; -typedef struct NTA_InputRangeMapEntry -{ + */ +typedef struct _NTA_InputRangeMapEntryHandle { + char c; +} * NTA_InputRangeMapEntryHandle; +typedef struct NTA_InputRangeMapEntry { /* functions */ void (*reset)(NTA_InputRangeMapEntryHandle handle); - const NTA_InputRange * (*next)(NTA_InputRangeMapEntryHandle handle); + const NTA_InputRange *(*next)(NTA_InputRangeMapEntryHandle handle); /* data members */ - const NTA_Byte * name; + const NTA_Byte *name; NTA_InputRangeMapEntryHandle handle; - -} NTA_InputRangeMapEntry; +} NTA_InputRangeMapEntry; /** ----------------------------- * @@ -250,36 +268,36 @@ typedef struct NTA_InputRangeMapEntry * It contains entries and provides iterator-like accessor * as well as lookup by name accessor. */ -typedef struct _NTA_InputRangeMapHandle { char c; } * NTA_InputRangeMapHandle; -typedef struct NTA_InputRangeMap -{ +typedef struct _NTA_InputRangeMapHandle { + char c; +} * NTA_InputRangeMapHandle; +typedef struct NTA_InputRangeMap { /* functions */ void (*reset)(NTA_InputRangeMapHandle handle); - const NTA_InputRangeMapEntry * (*next)(NTA_InputRangeMapHandle handle); - const NTA_InputRangeMapEntry * (*lookup)(NTA_InputRangeMapHandle handle, const NTA_Byte * name); + const NTA_InputRangeMapEntry *(*next)(NTA_InputRangeMapHandle handle); + const NTA_InputRangeMapEntry *(*lookup)(NTA_InputRangeMapHandle handle, + const NTA_Byte *name); /* data members */ NTA_InputRangeMapHandle handle; - + } NTA_InputRangeMap; /** ----------------------------------------- * - * I N D E X R A N G E + * I N D E X R A N G E * * ------------------------------------------ * * This struct represents a chunk of an input range. - * This is used to represent each internal link of a multi-node. - */ -typedef struct NTA_IndexRange -{ + * This is used to represent each internal link of a multi-node. 
+ */ +typedef struct NTA_IndexRange { /* data members */ - NTA_UInt32 begin; // begin offset - NTA_UInt32 size; // number of elements - -} NTA_IndexRange; + NTA_UInt32 begin; // begin offset + NTA_UInt32 size; // number of elements +} NTA_IndexRange; /** ----------------------------------------- * @@ -288,19 +306,18 @@ typedef struct NTA_IndexRange * ------------------------------------------ * * This struct represents a list of NTA_IndexRanges. It encapsulates all - * the connections for a specific baby node in a multi-node. - */ -typedef struct NTA_IndexRangeList -{ + * the connections for a specific baby node in a multi-node. + */ +typedef struct NTA_IndexRangeList { /* data members */ - NTA_Size rangeCount; // number of elements in the ranges array - NTA_IndexRange * ranges; // array of rangeCount NTA_IndexRange's - + NTA_Size rangeCount; // number of elements in the ranges array + NTA_IndexRange *ranges; // array of rangeCount NTA_IndexRange's + } NTA_IndexRangeList; /** ----------------------- * - * O U T P U T R A N G E + * O U T P U T R A N G E * * ------------------------ * @@ -310,36 +327,35 @@ typedef struct NTA_IndexRangeList * (elementSize) and number of elements in the range * (elementCount). */ -typedef struct _NTA_OutputRangeHandle { char c; } * NTA_OutputRangeHandle; -typedef struct NTA_OutputRange -{ +typedef struct _NTA_OutputRangeHandle { + char c; +} * NTA_OutputRangeHandle; +typedef struct NTA_OutputRange { /* functions */ - NTA_Byte * (*begin)(NTA_OutputRangeHandle handle); - NTA_Byte * (*end)(NTA_OutputRangeHandle handle); + NTA_Byte *(*begin)(NTA_OutputRangeHandle handle); + NTA_Byte *(*end)(NTA_OutputRangeHandle handle); NTA_Size (*getElementCount)(NTA_OutputRangeHandle handle); NTA_Size (*getElementSize)(NTA_OutputRangeHandle handle); /* data memebers */ NTA_OutputRangeHandle handle; - + } NTA_OutputRange; /** ------------------------------------------- * - * O U T P U T R A N G E M A P E N T R Y + * O U T P U T R A N G E M A P E N T R Y * * -------------------------------------------- * * This struct represents a single entry in an output range map * It stores a name and an output range. */ -typedef struct NTA_OutputRangeMapEntry -{ - const NTA_Byte * name; - NTA_OutputRange * range; - -} NTA_OutputRangeMapEntry; +typedef struct NTA_OutputRangeMapEntry { + const NTA_Byte *name; + NTA_OutputRange *range; +} NTA_OutputRangeMapEntry; /** ------------------------------- * @@ -351,22 +367,24 @@ typedef struct NTA_OutputRangeMapEntry * It contains entries and provides iterator-like accessor * as well as lookup by name accessor. 
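A short hypothetical example of filling in the NTA_IndexRange / NTA_IndexRangeList structs above, describing two input chunks for one baby node; the values are invented, and it assumes ObjectModel.h brings in the NTA_Size/NTA_UInt32 typedefs it relies on.

// Hypothetical NTA_IndexRangeList construction (values for illustration only).
#include <cstdio>

#include "nupic/ntypes/ObjectModel.h"

int main() {
  NTA_IndexRange ranges[2];
  ranges[0].begin = 0;  // first chunk: elements [0, 16)
  ranges[0].size = 16;
  ranges[1].begin = 64; // second chunk: elements [64, 96)
  ranges[1].size = 32;

  NTA_IndexRangeList list;
  list.rangeCount = 2;
  list.ranges = ranges;

  for (NTA_Size i = 0; i < list.rangeCount; i++)
    std::printf("range %lu: begin=%lu size=%lu\n", (unsigned long)i,
                (unsigned long)list.ranges[i].begin,
                (unsigned long)list.ranges[i].size);
  return 0;
}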
*/ -typedef struct _NTA_OutputRangeMapHandle { char c; } * NTA_OutputRangeMapHandle; -typedef struct NTA_OutputRangeMap -{ +typedef struct _NTA_OutputRangeMapHandle { + char c; +} * NTA_OutputRangeMapHandle; +typedef struct NTA_OutputRangeMap { /* functions */ void (*reset)(NTA_OutputRangeMapHandle handle); - NTA_OutputRangeMapEntry * (*next)(NTA_OutputRangeMapHandle handle); - NTA_OutputRange * (*lookup)(NTA_OutputRangeMapHandle handle, const NTA_Byte * name); + NTA_OutputRangeMapEntry *(*next)(NTA_OutputRangeMapHandle handle); + NTA_OutputRange *(*lookup)(NTA_OutputRangeMapHandle handle, + const NTA_Byte *name); /* data members */ NTA_OutputRangeMapHandle handle; - + } NTA_OutputRangeMap; /** -------------------------------------- * - * P A R A M E T E R M A P E N T R Y + * P A R A M E T E R M A P E N T R Y * * --------------------------------------- * @@ -374,13 +392,11 @@ typedef struct NTA_OutputRangeMap * It stores a name, a list of output ranges and the number * of output ranges in this entry */ -typedef struct NTA_ParameterMapEntry -{ - const NTA_Byte * name; - const NTA_ReadBuffer * value; - -} NTA_ParameterMapEntry; +typedef struct NTA_ParameterMapEntry { + const NTA_Byte *name; + const NTA_ReadBuffer *value; +} NTA_ParameterMapEntry; /** --------------------------- * @@ -392,70 +408,70 @@ typedef struct NTA_ParameterMapEntry * It contains various parameters and provides iterator-like accessor * as well as lookup by name accessor. */ -typedef struct _NTA_ParameterMapHandle { char c; } * NTA_ParameterMapHandle; -typedef struct NTA_ParameterMap -{ +typedef struct _NTA_ParameterMapHandle { + char c; +} * NTA_ParameterMapHandle; +typedef struct NTA_ParameterMap { /* functions */ void (*reset)(NTA_ParameterMapHandle handle); - const NTA_ParameterMapEntry * (*next)(NTA_ParameterMapHandle handle); - const NTA_ReadBuffer * (*lookup)(NTA_ParameterMapHandle handle, const NTA_Byte * name); + const NTA_ParameterMapEntry *(*next)(NTA_ParameterMapHandle handle); + const NTA_ReadBuffer *(*lookup)(NTA_ParameterMapHandle handle, + const NTA_Byte *name); /* data members */ NTA_ParameterMapHandle handle; - -} NTA_ParameterMap; +} NTA_ParameterMap; /** ------------------------------- * - * I N P U T + * I N P U T * * -------------------------------- * * This struct represents a flattened input of a node. */ -typedef struct _NTA_InputHandle { char c; } * NTA_InputHandle; -typedef struct NTA_Input -{ +typedef struct _NTA_InputHandle { + char c; +} * NTA_InputHandle; +typedef struct NTA_Input { /* functions */ - const NTA_Byte * (*begin)(NTA_InputHandle handle, NTA_Int32 nodeIdx, - const NTA_Byte* sentinelP); - const NTA_Byte * (*end)(NTA_InputHandle handle, NTA_Int32 nodeIdx); + const NTA_Byte *(*begin)(NTA_InputHandle handle, NTA_Int32 nodeIdx, + const NTA_Byte *sentinelP); + const NTA_Byte *(*end)(NTA_InputHandle handle, NTA_Int32 nodeIdx); NTA_Size (*getElementCount)(NTA_InputHandle handle, NTA_Int32 nodeIdx); NTA_Size (*getElementSize)(NTA_InputHandle handle); - - NTA_Size * (*getLinkBoundaries)(NTA_InputHandle handle, NTA_Int32 nodeIdx); + + NTA_Size *(*getLinkBoundaries)(NTA_InputHandle handle, NTA_Int32 nodeIdx); NTA_Size (*getLinkCount)(NTA_InputHandle handle, NTA_Int32 nodeIdx); /* data members */ NTA_InputHandle handle; - -} NTA_Input; +} NTA_Input; /** ------------------------------- * - * O U T P U T + * O U T P U T * * -------------------------------- * * This struct represents a flattened output of a node. 
*/ -typedef struct _NTA_OutputHandle { char c; } * NTA_OutputHandle; -typedef struct NTA_Output -{ +typedef struct _NTA_OutputHandle { + char c; +} * NTA_OutputHandle; +typedef struct NTA_Output { /* functions */ - NTA_Byte * (*begin)(NTA_OutputHandle handle, NTA_Int32 nodeIdx); - NTA_Byte * (*end)(NTA_OutputHandle handle, NTA_Int32 nodeIdx); + NTA_Byte *(*begin)(NTA_OutputHandle handle, NTA_Int32 nodeIdx); + NTA_Byte *(*end)(NTA_OutputHandle handle, NTA_Int32 nodeIdx); NTA_Size (*getElementCount)(NTA_OutputHandle handle, NTA_Int32 nodeIdx); NTA_Size (*getElementSize)(NTA_OutputHandle handle); - + /* data members */ NTA_OutputHandle handle; - -} NTA_Output; - +} NTA_Output; /** --------------------------- * @@ -463,31 +479,33 @@ typedef struct NTA_Output * * ---------------------------- * - * This struct contains all the initial information + * This struct contains all the initial information * that a node needs: inputs, outputs, parameters and state */ -typedef struct _NTA_NodeInfoHandle { char c; } * NTA_NodeInfoHandle; -typedef struct _NTA_NodeInfo -{ +typedef struct _NTA_NodeInfoHandle { + char c; +} * NTA_NodeInfoHandle; +typedef struct _NTA_NodeInfo { /* functions */ NTA_UInt64 (*getID)(NTA_NodeInfoHandle handle); - const NTA_Byte * (*getType)(NTA_NodeInfoHandle handle); + const NTA_Byte *(*getType)(NTA_NodeInfoHandle handle); NTA_LogLevel (*getLogLevel)(NTA_NodeInfoHandle handle); - NTA_Input * (*getInput)(NTA_NodeInfoHandle handle, const NTA_Byte* varName); - NTA_Output * (*getOutput)(NTA_NodeInfoHandle handle, const NTA_Byte* varName); - NTA_InputRangeMap * (*getInputs)(NTA_NodeInfoHandle handle); - NTA_OutputRangeMap * (*getOutputs)(NTA_NodeInfoHandle handle); - NTA_ParameterMap * (*getParameters)(NTA_NodeInfoHandle handle); - NTA_ReadBuffer * (*getState)(NTA_NodeInfoHandle handle); + NTA_Input *(*getInput)(NTA_NodeInfoHandle handle, const NTA_Byte *varName); + NTA_Output *(*getOutput)(NTA_NodeInfoHandle handle, const NTA_Byte *varName); + NTA_InputRangeMap *(*getInputs)(NTA_NodeInfoHandle handle); + NTA_OutputRangeMap *(*getOutputs)(NTA_NodeInfoHandle handle); + NTA_ParameterMap *(*getParameters)(NTA_NodeInfoHandle handle); + NTA_ReadBuffer *(*getState)(NTA_NodeInfoHandle handle); NTA_Size (*getMNNodeCount)(NTA_NodeInfoHandle handle); - const NTA_IndexRangeList * (*getMNInputLists)(NTA_NodeInfoHandle handle, const NTA_Byte* varName); - const NTA_Size * (*getMNOutputSizes)(NTA_NodeInfoHandle handle, const NTA_Byte* varName); + const NTA_IndexRangeList *(*getMNInputLists)(NTA_NodeInfoHandle handle, + const NTA_Byte *varName); + const NTA_Size *(*getMNOutputSizes)(NTA_NodeInfoHandle handle, + const NTA_Byte *varName); /* data members */ NTA_NodeInfoHandle handle; - -} NTA_NodeInfo; +} NTA_NodeInfo; /**------------------------------ * @@ -495,45 +513,46 @@ typedef struct _NTA_NodeInfo * * ------------------------------- * - * This struct contains the additional initial information - * that a multi-node needs: number of baby nodes, index ranges of baby nodes inputs - * and the output size of each baby node + * This struct contains the additional initial information + * that a multi-node needs: number of baby nodes, index ranges of baby nodes + * inputs and the output size of each baby node */ -typedef struct _NTA_MultiNodeInfoHandle { char c; } * NTA_MultiNodeInfoHandle; -typedef struct _NTA_MultiNodeInfo -{ +typedef struct _NTA_MultiNodeInfoHandle { + char c; +} * NTA_MultiNodeInfoHandle; +typedef struct _NTA_MultiNodeInfo { /* functions */ NTA_Size 
(*getNodeCount)(NTA_MultiNodeInfoHandle handle); - const NTA_IndexRangeList * (*getInputList)(NTA_MultiNodeInfoHandle handle, const NTA_Byte* varName); - const NTA_Size * (*getOutputSizes)(NTA_NodeInfoHandle handle, const NTA_Byte* varName); + const NTA_IndexRangeList *(*getInputList)(NTA_MultiNodeInfoHandle handle, + const NTA_Byte *varName); + const NTA_Size *(*getOutputSizes)(NTA_NodeInfoHandle handle, + const NTA_Byte *varName); /* data members */ NTA_MultiNodeInfoHandle handle; - -} NTA_MultiNodeInfo; - +} NTA_MultiNodeInfo; /** ----------------------------------------- * - * I N P U T S I Z E M A P E N T R Y + * I N P U T S I Z E M A P E N T R Y * * ------------------------------------------ * * This struct represents a single entry in an input size map * It stores a name and a list of input sizes (one per each input range) - */ -typedef struct _NTA_InputSizeMapEntryHandle { char c; } * NTA_InputSizeMapEntryHandle; -typedef struct NTA_InputSizeMapEntry -{ + */ +typedef struct _NTA_InputSizeMapEntryHandle { + char c; +} * NTA_InputSizeMapEntryHandle; +typedef struct NTA_InputSizeMapEntry { /* data members */ - const NTA_Byte * name; - NTA_UInt32 count; - NTA_UInt32 * sizes; - -} NTA_InputSizeMapEntry; + const NTA_Byte *name; + NTA_UInt32 count; + NTA_UInt32 *sizes; +} NTA_InputSizeMapEntry; /** ----------------------------- * @@ -545,35 +564,35 @@ typedef struct NTA_InputSizeMapEntry * It contains entries and provides iterator-like accessor * as well as lookup by name accessor. */ -typedef struct _NTA_InputSizeMapHandle { char c; } * NTA_InputSizeMapHandle; -typedef struct NTA_InputSizeMap -{ +typedef struct _NTA_InputSizeMapHandle { + char c; +} * NTA_InputSizeMapHandle; +typedef struct NTA_InputSizeMap { /* functions */ void (*reset)(NTA_InputSizeMapHandle handle); - const NTA_InputSizeMapEntry * (*next)(NTA_InputSizeMapHandle handle); - const NTA_InputSizeMapEntry * (*lookup)(NTA_InputSizeMapHandle handle, const NTA_Byte * name); + const NTA_InputSizeMapEntry *(*next)(NTA_InputSizeMapHandle handle); + const NTA_InputSizeMapEntry *(*lookup)(NTA_InputSizeMapHandle handle, + const NTA_Byte *name); /* data members */ NTA_InputSizeMapHandle handle; - + } NTA_InputSizeMap; /** ------------------------------------------- * - * O U T P U T S I Z E M A P E N T R Y + * O U T P U T S I Z E M A P E N T R Y * * -------------------------------------------- * * This struct represents a single entry in an output size map * It stores an output name and the size of this output. */ -typedef struct NTA_OutputSizeMapEntry -{ - const NTA_Byte * name; - NTA_UInt32 size; - -} NTA_OutputSizeMapEntry; +typedef struct NTA_OutputSizeMapEntry { + const NTA_Byte *name; + NTA_UInt32 size; +} NTA_OutputSizeMapEntry; /** ------------------------------- * @@ -585,17 +604,18 @@ typedef struct NTA_OutputSizeMapEntry * It contains entries and provides iterator-like accessor * as well as lookup by name accessor. 
*/ -typedef struct _NTA_OutputSizeMapHandle { char c; } * NTA_OutputSizeMapHandle; -typedef struct NTA_OutputSizeMap -{ +typedef struct _NTA_OutputSizeMapHandle { + char c; +} * NTA_OutputSizeMapHandle; +typedef struct NTA_OutputSizeMap { /* functions */ void (*reset)(NTA_OutputSizeMapHandle handle); - NTA_OutputSizeMapEntry * (*next)(NTA_OutputSizeMapHandle handle); - NTA_UInt32 (*lookup)(NTA_OutputSizeMapHandle handle, const NTA_Byte * name); + NTA_OutputSizeMapEntry *(*next)(NTA_OutputSizeMapHandle handle); + NTA_UInt32 (*lookup)(NTA_OutputSizeMapHandle handle, const NTA_Byte *name); /* data members */ NTA_OutputSizeMapHandle handle; - + } NTA_OutputSizeMap; /** ------------------------------------ @@ -604,28 +624,28 @@ typedef struct NTA_OutputSizeMap * * ------------------------------------- * - * This struct contains all the Information that + * This struct contains all the Information that * NTA_CreateInitialState needs: input sizes, output sizes * and a map of the initial parameters. */ -typedef struct _NTA_InitialStateInfoHandle { char c; } * NTA_InitialStateInfoHandle; -typedef struct _NTA_InitialStateInfo -{ +typedef struct _NTA_InitialStateInfoHandle { + char c; +} * NTA_InitialStateInfoHandle; +typedef struct _NTA_InitialStateInfo { /* functions */ - const NTA_Byte * (*getNodeType)(NTA_InitialStateInfoHandle handle); - const NTA_InputSizeMap * (*getInputSizes)(NTA_InitialStateInfoHandle handle); - const NTA_OutputSizeMap * (*getOutputSizes)(NTA_InitialStateInfoHandle handle); - const NTA_ParameterMap * (*getParameters)(NTA_InitialStateInfoHandle handle); - const NTA_MultiNodeInfo * (*getMultiNodeInfo)(NTA_InitialStateInfoHandle handle); + const NTA_Byte *(*getNodeType)(NTA_InitialStateInfoHandle handle); + const NTA_InputSizeMap *(*getInputSizes)(NTA_InitialStateInfoHandle handle); + const NTA_OutputSizeMap *(*getOutputSizes)(NTA_InitialStateInfoHandle handle); + const NTA_ParameterMap *(*getParameters)(NTA_InitialStateInfoHandle handle); + const NTA_MultiNodeInfo *(*getMultiNodeInfo)( + NTA_InitialStateInfoHandle handle); /* data members */ NTA_InitialStateInfoHandle handle; - -} NTA_InitialStateInfo; +} NTA_InitialStateInfo; -#ifdef __cplusplus +#ifdef __cplusplus } #endif #endif /* NTA_OBJECT_MODEL_H */ - diff --git a/src/nupic/ntypes/ObjectModel.hpp b/src/nupic/ntypes/ObjectModel.hpp index 56c7eac3a2..d5c45781e9 100644 --- a/src/nupic/ntypes/ObjectModel.hpp +++ b/src/nupic/ntypes/ObjectModel.hpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Interfaces for C++ runtime objects used by INode */ @@ -29,1230 +29,1204 @@ * Included because IWriteBuffer/IReadBuffer */ - #ifndef NTA_OBJECT_MODEL_HPP #define NTA_OBJECT_MODEL_HPP -#include #include -#include +#include #include +#include + +namespace nupic { +//--------------------------------------------------------------------------- +// +// I R E A D B U F F E R +// +//--------------------------------------------------------------------------- +/** + * @b Responsibility: + * Interface for reading values from a binary buffer + */ +//--------------------------------------------------------------------------- +struct IReadBuffer { + /** + * Virtual destructor. Required for abstract classes. + */ + virtual ~IReadBuffer() {} + + /** + * Reset the internal pointer to point to the beginning of the buffer. + * + * This is necessary if you want to read from the same buffer multiple + * times. 
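Every struct in this C object model follows the same convention: an opaque handle plus a set of function pointers that take that handle as their first argument. As an illustrative usage sketch only (not from the patch), and assuming the C declarations above are reachable through the project's ObjectModel headers, enumerating the input sizes handed to NTA_CreateInitialState might look like this:

#include <cstdio>
#include "nupic/ntypes/ObjectModel.hpp" // assumed to pull in the C declarations above

// Hypothetical helper; 'info' is supplied by the tools-side runtime.
static void printInputSizes(const NTA_InitialStateInfo *info) {
  const NTA_InputSizeMap *inputs = info->getInputSizes(info->handle);
  inputs->reset(inputs->handle); // rewind the iterator
  const NTA_InputSizeMapEntry *entry;
  while ((entry = inputs->next(inputs->handle)) != nullptr) {
    std::printf("input '%s': %u range(s)\n", (const char *)entry->name,
                (unsigned)entry->count);
    for (NTA_UInt32 i = 0; i < entry->count; ++i)
      std::printf("  range %u: %u elements\n", (unsigned)i,
                  (unsigned)entry->sizes[i]);
  }
}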
+ */ + virtual void reset() const = 0; + + /** + * Returns the size in bytes of the buffer's contents + * + * This is useful if you want to copy the entire buffer + * as a byte array. + * + * @retval number of bytes in the buffer + */ + virtual Size getSize() const = 0; + + /** + * Returns a pointer to the buffer's contents + * + * This is useful if you want to access the bytes directly. + * The returned buffer is not related in any way to the + * internal advancing pointer used in the read() methods. + * + * @retval pointer to beginning of the buffer + */ + virtual const Byte *getData() const = 0; + + /** + * Read a single byte into 'value' and advance the internal pointer. + * + * @param value the output byte. + * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(Byte &value) const = 0; + + /** + * Read 'size' bytes into the 'value' array and advance + * the internal pointer. + * If the buffer contains less than 'size' bytes it will read + * as much as possible and write the number of bytes actually read + * into the 'size' argument. + * + * @param value the output buffer. Must not be NULL + * @param size the size of the output buffer. Must be >0. Receives + * the actual number of bytes read if success or 0 + * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(Byte *value, Size &size) const = 0; + + /** + * Read a string into the string object provided in 'value'. + * The string object should be empty before begin passed in, + * or the result will be undefined. + * The string must have been written to the buffer using the + * IWriteBuffer::write(const std::string &) interface. + * The value of the string upon completion will be undefined on failure. + * Note that reading and writing a string is slightly different from + * reading or writing an arbitrary binary structure in the + * the following ways: + * * Reading/writing a 0-length string is a sensible operation. + * * The length of the string is almost never known ahead of time. + * + * @retval value A reference to a character array pointer (initially null). + * This array will point to a new buffer allocated with the + * provided allocator upon success. + * The caller is responsible for this memory. + * @retval size A reference to a size that will be filled in with the + * string length. + * @param fAlloc A function pointer that will be called to + * perform necessary allocation of value buffer. + * @param fDealloc A function pointer that will be called if a failure + * occurs after the value array has been allocated to + * cleanup allocated memory. + * @return 0 for success, -1 for failure + */ + virtual NTA_Int32 readString(NTA_Byte *&value, NTA_UInt32 &size, + NTA_Byte *(fAlloc)(NTA_UInt32 size), + void(fDealloc)(NTA_Byte *)) const = 0; + + /** + * Read a single integer (32 bits) into 'value' + * and advance the internal pointer. + * + * @param value the output integer. + * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(Int32 &value) const = 0; + + /** + * Read 'size' Int32 elements into the 'value' array and advance + * the internal pointer. + * + * If the remaining buffer isn't long enough to contain 'size' elements, read + * as much as possible. + * + * @param value the output buffer. Must not be NULL + * @param size the size of the output buffer. Must be >0. 
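The readString() contract above is the least obvious part of IReadBuffer: the caller supplies an allocator and a deallocator, receives ownership of the buffer on success, and the deallocator is only invoked internally if a failure happens after allocation. A minimal caller, sketched here for illustration only and not taken from the patch:

#include "nupic/ntypes/ObjectModel.hpp"

static NTA_Byte *allocBytes(NTA_UInt32 size) { return new NTA_Byte[size]; }
static void deallocBytes(NTA_Byte *p) { delete[] p; }

// 'buf' is assumed to be provided by the runtime.
void consumeString(const nupic::IReadBuffer &buf) {
  NTA_Byte *value = nullptr;
  NTA_UInt32 size = 0;
  if (buf.readString(value, size, allocBytes, deallocBytes) == 0) {
    // 'value' now holds 'size' bytes and belongs to the caller.
    delete[] value;
  }
}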
+ * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(Int32 *value, Size size) const = 0; + + /** + * Read a single unsigned integer (32 bits) into 'value' + * and advance the internal pointer. + * + * @param value the output unsigned integer. + * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(UInt32 &value) const = 0; + + /** + * Read 'size' UInt32 elements into the 'value' array and advance + * the internal pointer. + * + * If the remaining buffer isn't long enough to contain 'size' elements, read + * as much as possible. + * + * @param value the output buffer. Must not be NULL + * @param size the size of the output buffer. Must be >0. + * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(UInt32 *value, Size size) const = 0; + + /** + * Read a single integer (64 bits) into 'value' + * and advance the internal pointer. + * + * @param value the output integer. + * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(Int64 &value) const = 0; + + /** + * Read 'size' Int64 elements into the 'value' array and advance + * the internal pointer. + * + * If the remaining buffer isn't long enough to contain 'size' elements, read + * as much as possible. + * + * @param value the output buffer. Must not be NULL + * @param size the size of the output buffer. Must be >0. + * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(Int64 *value, Size size) const = 0; + + /** + * Read a single unsigned integer (64 bits) into 'value' + * and advance the internal pointer. + * + * @param value the output unsigned integer. + * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(UInt64 &value) const = 0; + + /** + * Read 'size' UInt64 elements into the 'value' array and advance + * the internal pointer. + * + * If the remaining buffer isn't long enough to contain 'size' elements, read + * as much as possible. + * + * @param value the output buffer. Must not be NULL + * @param size the size of the output buffer. Must be >0. + * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(UInt64 *value, Size size) const = 0; + + /** + * Read a single 32-bit real number (float) + * into 'value' and advance the internal pointer. + * + * @param value the output real number. + * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(Real32 &value) const = 0; + + /** + * Read 'size' Real32 elements into the 'value' array and advance + * the internal pointer. + * + * If the remaining buffer isn't long enough to contain 'size' elements, read + * as much as possible. + * + * @param value the output buffer. Must not be NULL + * @param size the size of the output buffer. Must be >0. + * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(Real32 *value, Size size) const = 0; + + /** + * Read a single 64-bit real number (double) + * into 'value' and advance the internal pointer. + * + * @param value the output real number. + * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(Real64 &value) const = 0; + + /** + * Read 'size' Real64 elements into the 'value' array and advance + * the internal pointer. + * + * If the remaining buffer isn't long enough to contain 'size' elements, read + * as much as possible. + * + * @param value the output buffer. Must not be NULL + * @param size the size of the output buffer. Must be >0. 
+ * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(Real64 *value, Size size) const = 0; + + /** + * Read a single bool (size is compiler-defined) into 'value' and advance the + * internal pointer. + * + * @param value the output bool. + * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(bool &value) const = 0; + + /** + * Read 'size' bool elements into the 'value' array and advance + * the internal pointer. + * + * If the remaining buffer isn't long enough to contain 'size' elements, read + * as much as possible. + * + * @param value the output buffer. Must not be NULL + * @param size the size of the output buffer. Must be >0. + * @retval 0 for success, -1 for failure, 1 for EOF + */ + virtual Int32 read(bool *value, Size size) const = 0; +}; + +//--------------------------------------------------------------------------- +// +// I R E A D B U F F E R I T E R A T O R +// +//--------------------------------------------------------------------------- +/** + * @b Responsibility: + * Interface for iterating over a collection of IReadBuffer objects + */ +//--------------------------------------------------------------------------- +struct IReadBufferIterator { + /** + * Virtual destructor. Required for abstract classes. + */ + virtual ~IReadBufferIterator() {} + + /** + * Reset the internal pointer to point to the beginning of the iterator. + * + * The following next() will return the first IReadBuffer in the collection + * or NULL if the collection is empty. Multiple consecutive calls are allowed + * but have no effect. + */ + virtual void reset() = 0; + + /** + * Get the next buffer in the collection + * + * This method returns the buffer pointed to by the internal pointer and + * advances the pointer to the next buffer in the collection. If the + * collection is empty or previous call to next() returned the last buffer, + * further calls to next() will return NULL. + * + * @retval [IReadBuffer *] next buffer or NULL + */ + virtual const IReadBuffer *next() = 0; +}; + +//--------------------------------------------------------------------------- +// +// I W R I T E B U F F E R +// +//--------------------------------------------------------------------------- +/** + * @b Responsibility: + * Interface for writing values to a binary buffer + */ +//--------------------------------------------------------------------------- +struct IWriteBuffer { + /** + * Virtual destructor. Required for abstract classes. + */ + virtual ~IWriteBuffer() {} + + /** + * Write a single byte into + * the internal buffer. + * + * @param value the input byte. + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(Byte value) = 0; + + /** + * Write a byte array into + * the internal buffer. + * + * @param value the input array. + * @param size how many bytes to write + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(const Byte *value, Size size) = 0; + + /** + * Write the contents of a string into the stream. + * The string may be of 0 length and may contain any 8-bit characters + * (the string must be 1-byte encoded). + * Note that reading and writing a string is slightly different from + * reading or writing an arbitrary binary structure in the + * the following ways: + * * Reading/writing a 0-length string is a sensible operation. + * + * @param value the input array. 
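Taken together, the iterator and the scalar read() overloads above make consuming a collection of buffers fairly mechanical: rewind, pull buffers until next() returns NULL, and check each read's return code. A short sketch, illustrative only and not part of the patch:

#include "nupic/ntypes/ObjectModel.hpp"

// Sums the first Int32 of every buffer in the collection.
nupic::Int64 sumFirstInts(nupic::IReadBufferIterator &it) {
  nupic::Int64 total = 0;
  it.reset(); // the next call to next() returns the first buffer
  while (const nupic::IReadBuffer *buf = it.next()) {
    nupic::Int32 value = 0;
    if (buf->read(value) == 0) // 0 = success, -1 = failure, 1 = EOF
      total += value;
  }
  return total;
}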
+ * @retval 0 for success, -1 for failure + */ + virtual Int32 writeString(const Byte *value, Size size) = 0; + + /** + * Write a single integer (32 bits) into + * the internal buffer. + * + * @param value the input integer. + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(Int32 value) = 0; + + /** + * Write array of Int32 elements into + * the internal buffer. + * + * @param value the input array. + * @param size how many bytes to write + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(const Int32 *value, Size size) = 0; + + /** + * Write a single unsigned integer (32 bits) into + * the internal buffer. + * + * @param value the input unsigned integer. + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(UInt32 value) = 0; + + /** + * Write array of UInt32 elements into + * the internal buffer. + * + * @param value the input array. + * @param size how many bytes to write + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(const UInt32 *value, Size size) = 0; + + /** + * Write a single integer (64 bits) into + * the internal buffer. + * + * @param value the input integer. + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(Int64 value) = 0; + + /** + * Write array of Int64 elements into + * the internal buffer. + * + * @param value the input array. + * @param size how many bytes to write + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(const Int64 *value, Size size) = 0; + + /** + * Write a single unsigned integer (64 bits) into + * the internal buffer. + * + * @param value the input unsigned integer. + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(UInt64 value) = 0; + + /** + * Write array of UInt64 elements into + * the internal buffer. + * + * @param value the input array. + * @param size how many bytes to write + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(const UInt64 *value, Size size) = 0; + + /** + * Write a single precision real (32 bits) into + * the internal buffer. + * + * @param value the input real number. + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(Real32 value) = 0; + + /** + * Write array of Real32 elements into + * the internal buffer. + * + * @param value the input array. + * @param size how many bytes to write + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(const Real32 *value, Size size) = 0; + + /** + * Write a double precision real (64 bits) into + * the internal buffer. + * + * @param value the input real number. + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(Real64 value) = 0; -namespace nupic -{ - //--------------------------------------------------------------------------- - // - // I R E A D B U F F E R - // - //--------------------------------------------------------------------------- - /** - * @b Responsibility: - * Interface for reading values from a binary buffer - */ - //--------------------------------------------------------------------------- - struct IReadBuffer - { - /** - * Virtual destructor. Required for abstract classes. - */ - virtual ~IReadBuffer() {} - - /** - * Reset the internal pointer to point to the beginning of the buffer. - * - * This is necessary if you want to read from the same buffer multiple - * times. - */ - virtual void reset() const = 0; - - /** - * Returns the size in bytes of the buffer's contents - * - * This is useful if you want to copy the entire buffer - * as a byte array. 
- * - * @retval number of bytes in the buffer - */ - virtual Size getSize() const = 0; - - /** - * Returns a pointer to the buffer's contents - * - * This is useful if you want to access the bytes directly. - * The returned buffer is not related in any way to the - * internal advancing pointer used in the read() methods. - * - * @retval pointer to beginning of the buffer - */ - virtual const Byte * getData() const = 0; - - /** - * Read a single byte into 'value' and advance the internal pointer. - * - * @param value the output byte. - * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(Byte & value) const = 0; - - /** - * Read 'size' bytes into the 'value' array and advance - * the internal pointer. - * If the buffer contains less than 'size' bytes it will read - * as much as possible and write the number of bytes actually read - * into the 'size' argument. - * - * @param value the output buffer. Must not be NULL - * @param size the size of the output buffer. Must be >0. Receives - * the actual number of bytes read if success or 0 - * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(Byte * value, Size & size) const = 0; - - /** - * Read a string into the string object provided in 'value'. - * The string object should be empty before begin passed in, - * or the result will be undefined. - * The string must have been written to the buffer using the - * IWriteBuffer::write(const std::string &) interface. - * The value of the string upon completion will be undefined on failure. - * Note that reading and writing a string is slightly different from - * reading or writing an arbitrary binary structure in the - * the following ways: - * * Reading/writing a 0-length string is a sensible operation. - * * The length of the string is almost never known ahead of time. - * - * @retval value A reference to a character array pointer (initially null). - * This array will point to a new buffer allocated with the - * provided allocator upon success. - * The caller is responsible for this memory. - * @retval size A reference to a size that will be filled in with the - * string length. - * @param fAlloc A function pointer that will be called to - * perform necessary allocation of value buffer. - * @param fDealloc A function pointer that will be called if a failure - * occurs after the value array has been allocated to - * cleanup allocated memory. - * @return 0 for success, -1 for failure - */ - virtual NTA_Int32 readString( - NTA_Byte * &value, - NTA_UInt32 &size, - NTA_Byte *(fAlloc)(NTA_UInt32 size), - void (fDealloc)(NTA_Byte *) - ) const = 0; - - /** - * Read a single integer (32 bits) into 'value' - * and advance the internal pointer. - * - * @param value the output integer. - * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(Int32 & value) const = 0; - - /** - * Read 'size' Int32 elements into the 'value' array and advance - * the internal pointer. - * - * If the remaining buffer isn't long enough to contain 'size' elements, read - * as much as possible. - * - * @param value the output buffer. Must not be NULL - * @param size the size of the output buffer. Must be >0. - * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(Int32 * value, Size size) const = 0; - - /** - * Read a single unsigned integer (32 bits) into 'value' - * and advance the internal pointer. - * - * @param value the output unsigned integer. 
- * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(UInt32 & value) const = 0; - - /** - * Read 'size' UInt32 elements into the 'value' array and advance - * the internal pointer. - * - * If the remaining buffer isn't long enough to contain 'size' elements, read - * as much as possible. - * - * @param value the output buffer. Must not be NULL - * @param size the size of the output buffer. Must be >0. - * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(UInt32 * value, Size size) const = 0; - - /** - * Read a single integer (64 bits) into 'value' - * and advance the internal pointer. - * - * @param value the output integer. - * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(Int64 & value) const = 0; - - /** - * Read 'size' Int64 elements into the 'value' array and advance - * the internal pointer. - * - * If the remaining buffer isn't long enough to contain 'size' elements, read - * as much as possible. - * - * @param value the output buffer. Must not be NULL - * @param size the size of the output buffer. Must be >0. - * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(Int64 * value, Size size) const = 0; - - /** - * Read a single unsigned integer (64 bits) into 'value' - * and advance the internal pointer. - * - * @param value the output unsigned integer. - * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(UInt64 & value) const = 0; - - /** - * Read 'size' UInt64 elements into the 'value' array and advance - * the internal pointer. - * - * If the remaining buffer isn't long enough to contain 'size' elements, read - * as much as possible. - * - * @param value the output buffer. Must not be NULL - * @param size the size of the output buffer. Must be >0. - * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(UInt64 * value, Size size) const = 0; - - - /** - * Read a single 32-bit real number (float) - * into 'value' and advance the internal pointer. - * - * @param value the output real number. - * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(Real32 & value) const = 0; - - /** - * Read 'size' Real32 elements into the 'value' array and advance - * the internal pointer. - * - * If the remaining buffer isn't long enough to contain 'size' elements, read - * as much as possible. - * - * @param value the output buffer. Must not be NULL - * @param size the size of the output buffer. Must be >0. - * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(Real32 * value, Size size) const = 0; - - /** - * Read a single 64-bit real number (double) - * into 'value' and advance the internal pointer. - * - * @param value the output real number. - * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(Real64 & value) const = 0; - - /** - * Read 'size' Real64 elements into the 'value' array and advance - * the internal pointer. - * - * If the remaining buffer isn't long enough to contain 'size' elements, read - * as much as possible. - * - * @param value the output buffer. Must not be NULL - * @param size the size of the output buffer. Must be >0. - * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(Real64 * value, Size size) const = 0; - - /** - * Read a single bool (size is compiler-defined) into 'value' and advance the - * internal pointer. - * - * @param value the output bool. 
- * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(bool & value) const = 0; - - /** - * Read 'size' bool elements into the 'value' array and advance - * the internal pointer. - * - * If the remaining buffer isn't long enough to contain 'size' elements, read - * as much as possible. - * - * @param value the output buffer. Must not be NULL - * @param size the size of the output buffer. Must be >0. - * @retval 0 for success, -1 for failure, 1 for EOF - */ - virtual Int32 read(bool * value, Size size) const = 0; - }; - - //--------------------------------------------------------------------------- - // - // I R E A D B U F F E R I T E R A T O R - // - //--------------------------------------------------------------------------- - /** - * @b Responsibility: - * Interface for iterating over a collection of IReadBuffer objects - */ - //--------------------------------------------------------------------------- - struct IReadBufferIterator - { - /** - * Virtual destructor. Required for abstract classes. - */ - virtual ~IReadBufferIterator() {} - - /** - * Reset the internal pointer to point to the beginning of the iterator. - * - * The following next() will return the first IReadBuffer in the collection - * or NULL if the collection is empty. Multiple consecutive calls are allowed - * but have no effect. - */ - virtual void reset() = 0; - - /** - * Get the next buffer in the collection - * - * This method returns the buffer pointed to by the internal pointer and advances - * the pointer to the next buffer in the collection. If the collection is empty - * or previous call to next() returned the last buffer, further calls to next() - * will return NULL. - * - * @retval [IReadBuffer *] next buffer or NULL - */ - virtual const IReadBuffer * next() = 0; - }; - - - //--------------------------------------------------------------------------- - // - // I W R I T E B U F F E R - // - //--------------------------------------------------------------------------- - /** - * @b Responsibility: - * Interface for writing values to a binary buffer - */ - //--------------------------------------------------------------------------- - struct IWriteBuffer - { - /** - * Virtual destructor. Required for abstract classes. - */ - virtual ~IWriteBuffer() {} - - /** - * Write a single byte into - * the internal buffer. - * - * @param value the input byte. - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(Byte value) = 0; - - /** - * Write a byte array into - * the internal buffer. - * - * @param value the input array. - * @param size how many bytes to write - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(const Byte * value, Size size) = 0; - - /** - * Write the contents of a string into the stream. - * The string may be of 0 length and may contain any 8-bit characters - * (the string must be 1-byte encoded). - * Note that reading and writing a string is slightly different from - * reading or writing an arbitrary binary structure in the - * the following ways: - * * Reading/writing a 0-length string is a sensible operation. - * - * @param value the input array. - * @retval 0 for success, -1 for failure - */ - virtual Int32 writeString(const Byte * value, Size size) = 0; - - /** - * Write a single integer (32 bits) into - * the internal buffer. - * - * @param value the input integer. - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(Int32 value) = 0; - - /** - * Write array of Int32 elements into - * the internal buffer. 
- * - * @param value the input array. - * @param size how many bytes to write - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(const Int32 * value, Size size) = 0; - - /** - * Write a single unsigned integer (32 bits) into - * the internal buffer. - * - * @param value the input unsigned integer. - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(UInt32 value) = 0; - - /** - * Write array of UInt32 elements into - * the internal buffer. - * - * @param value the input array. - * @param size how many bytes to write - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(const UInt32 * value, Size size) = 0; - - /** - * Write a single integer (64 bits) into - * the internal buffer. - * - * @param value the input integer. - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(Int64 value) = 0; - - /** - * Write array of Int64 elements into - * the internal buffer. - * - * @param value the input array. - * @param size how many bytes to write - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(const Int64 * value, Size size) = 0; - - /** - * Write a single unsigned integer (64 bits) into - * the internal buffer. - * - * @param value the input unsigned integer. - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(UInt64 value) = 0; - - /** - * Write array of UInt64 elements into - * the internal buffer. - * - * @param value the input array. - * @param size how many bytes to write - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(const UInt64 * value, Size size) = 0; - - /** - * Write a single precision real (32 bits) into - * the internal buffer. - * - * @param value the input real number. - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(Real32 value) = 0; - - /** - * Write array of Real32 elements into - * the internal buffer. - * - * @param value the input array. - * @param size how many bytes to write - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(const Real32 * value, Size size) = 0; - - /** - * Write a double precision real (64 bits) into - * the internal buffer. - * - * @param value the input real number. - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(Real64 value) = 0; - - /** - * Write array of Real64 elements into - * the internal buffer. - * - * @param value the input array. - * @param size how many bytes to write - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(const Real64 * value, Size size) = 0; - - /** - * Write a bool (size is compiler-defined) into the internal buffer. - * - * @param value the input bool. - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(bool value) = 0; - - /** - * Write array of bool elements into the internal buffer. - * - * @param value the input array. - * @param size how many bytes to write - * @retval 0 for success, -1 for failure - */ - virtual Int32 write(const bool * value, Size size) = 0; - - /** - * Get the size in bytes of the contents of the internal - * buffer. - * - * @retval [Size] size in bytes of the buffer contents - */ - virtual Size getSize() = 0; - - /** - * A pointer to the internal buffer. - * - * The buffer is guarantueed to be contiguous. 
- * - * @retval [Byte *] internal buffer - */ - virtual const Byte * getData() = 0; - }; - - - //--------------------------------------------------------------------------- - // - // I R A N G E - // - //--------------------------------------------------------------------------- - /** - * - * @b Responsibility - * A base interface that defines the common operations for input - * and output ranges. Exposes the numer of elements (elementCount) - * and the size of each element in bytes (elementSize). - * - * @b Rationale - * Plain old reuse. both IInputRange and IOutputRange are derived - * from IRange. - * - */ - //--------------------------------------------------------------------------- - struct IRange - { - /** - * Virtual destructor. Required for abstract classes. - */ - virtual ~IRange() {} - - /** - * Get the number of elements in a range - * - * @retval [Size] number of elements - */ - virtual Size getElementCount() const = 0; - - /** - * Get the size of a single element in a range - * - * All elements in a range have the same size - * - * @retval [Size] size in bytes of a range element. - */ - virtual Size getElementSize() const = 0; - }; - - //--------------------------------------------------------------------------- - // - // I I N P U T R A N G E - // - //--------------------------------------------------------------------------- - /** - * - * @b Responsibility - * The input range interface. Provides access to a couple of iterator-like - * pointers to the beginning and end of the raange. The lag argument - * controls the offset (from which buffer to extract the pointers). - * @b Note - * begin() and end() return a const Byte *. It is the responsibility of the - * caller to cast it to te correct type. The memory is not suppposed to - * be modified only read. - */ - //--------------------------------------------------------------------------- - struct IInputRange : public IRange - { - /** - * Get the beginning pointer to the range's byte array - * - * @retval [Byte *] pointer to internal byte array. - */ - virtual const Byte * begin() const = 0; - - /** - * Get the end pointer to the range's byte array - * - * The end pointer is pointing to the byte immediately - * following the last byte in the internal byte array - * - * @retval [Byte *] the end pointer of the internal byte array. - */ - virtual const Byte * end() const = 0; - }; - - //--------------------------------------------------------------------------- - // - // I O U T P U T R A N G E - // - //--------------------------------------------------------------------------- - /** - * @b Responsibility - * The output range interface. Provides access to a couple of iterator-like - * pointers to the beginning and end of the raange. - * - * @b Note - * begin() and end() return a Byte *. It is the responsibility of the - * caller to cast it to te correct type. The memory can be written to of course. - */ - //--------------------------------------------------------------------------- - struct IOutputRange : public IRange - { - /** - * Get the beginning pointer to the range's byte array - * - * @retval [Byte *] pointer to internal byte array. - */ - virtual Byte * begin() = 0; - - /** - * Get the end pointer to the range's byte array - * - * The end pointer is pointing to the byte immediately - * following the last byte in the internal byte array - * - * @retval [Byte *] the end pointer of the internal byte array. 
- */ - virtual Byte * end() = 0; - }; - - - //--------------------------------------------------------------------------- - // - // I I N P U T R A N G E M A P E N T R Y - // - //--------------------------------------------------------------------------- - /** - * @b Responsibility - * The input range map entry interface. Each entry has a name and an input - * range iterator. That means that a whole collection on input ranges are - * accessible via the same name. - */ - //--------------------------------------------------------------------------- - struct IInputRangeMapEntry - { - /** - * Virtual destructor. Required for abstract classes. - */ - virtual ~IInputRangeMapEntry() {} - - /** - * Reset the internal pointer to point to the beginning of the input range iterator. - * - * The following next() will return the first IReadBuffer in the collection - * or NULL if the collection is empty. Multiple consecutive calls are allowed - * but have no effect. - */ - virtual void reset() const = 0; - - /** - * Get the next input range in the map entry - * - * This method returns the input range pointed to by the internal pointer and advances - * the pointer to the next input range in the map entry. If the collection is empty - * or previous call to next() returned the last input range, further calls to next() - * will return NULL. - * - * @retval [const IInputRange *] next input range or NULL - */ - virtual const IInputRange * next() const = 0; - - /** - * The name of the input range - */ - const Byte * name; - }; - - - //--------------------------------------------------------------------------- - // - // I I N P U T R A N G E M A P - // - //--------------------------------------------------------------------------- - /** - * @b Responsibility - * The input range map interface. Stores a collection of IInputMapRangeEntry - * objects. It provides lookup by name as well an iterator - * to iterate over all the entries. - */ - //--------------------------------------------------------------------------- - struct IInputRangeMap - { - /** - * Virtual destructor. Required for abstract classes. - */ - virtual ~IInputRangeMap() {} - - /** - * Reset the internal pointer to point to the first InputRangeMap entry. - * - * The following next() will return the first entry in the map or NULL - * if the collection is empty. Multiple consecutive calls are allowed - * but have no effect. - */ - virtual void reset() const = 0; - - /** - * Get the next InputRangeMap entry - * - * This method returns the InputRangeMap entry pointed to by the internal pointer - * and advances the pointer to the next entry. If the collection is empty - * or previous call to next() returned the last entry, further calls to next() - * will return NULL. - * - * @retval [const IInputRangeMapEntry *] next map entry or NULL - */ - virtual const IInputRangeMapEntry * next() const = 0; - - /** - * Get an InputRangeMap entry by name - * - * This method returns the InputRangeMap entry whose name matches the input name - * or NULL if an entry with this name is not in the map. lookup() calls - * don't affect the internal iterator pointer. - * - * @param [const Byte *] entry name to lookup. 
- * @retval [const IInputRangeMapEntry *] map entry or NULL - */ - virtual const IInputRangeMapEntry * lookup(const Byte * name) const = 0; - }; - - //--------------------------------------------------------------------------- - // - // I O U T P U T R A N G E M A P E N T R Y - // - //--------------------------------------------------------------------------- - /** - * @b Responsibility - * The output range map entry is just a named output range - */ - //--------------------------------------------------------------------------- - struct IOutputRangeMapEntry - { - /** - * The name of the output range - */ - const Byte * name; - - /** - * The output range - */ - IOutputRange * range; - }; - - - //--------------------------------------------------------------------------- - // - // I O U T P U T R A N G E M A P - // - //--------------------------------------------------------------------------- - /** - * @b Responsibility - * The output range map interface. Stores pairs of output [name, range]. - * It provides lookup by name as well as iterator-like - * methods (begin(), end()) to iterate over all entries - */ - //--------------------------------------------------------------------------- - struct IOutputRangeMap - { - /** - * Virtual destructor. Required for abstract classes. - */ - virtual ~IOutputRangeMap() {} - - /** - * Reset the internal pointer to point to the first OutputRangeMap entry. - * - * The following next() will return the first entry in the map or NULL - * if the collection is empty. Multiple consecutive calls are allowed - * but have no effect. - */ - virtual void reset() = 0; - - /** - * Get the next OutputRangeMap entry - * - * This method returns the OutputRangeMap entry pointed to by the internal pointer - * and advances the pointer to the next entry. If the collection is empty - * or previous call to next() returned the last entry, further calls to next() - * will return NULL. - * - * @retval [const IOutputRangeMapEntry *] next map entry or NULL - */ - virtual IOutputRangeMapEntry * next() = 0; - - /** - * Get an OutputRangeMap entry by name - * - * This method returns the OutputRangeMap entry whose name matches the - * requested name or NULL if an entry with this name is not in the map. - * lookup() calls don't affect the internal iterator pointer. - * - * @param [const Byte *] entry name to lookup. - * @retval [const IOutputRangeMapEntry *] map entry or NULL - */ - virtual IOutputRange * lookup(const Byte * name) = 0; - }; - - //--------------------------------------------------------------------------- - // - // I I N P U T - // - //--------------------------------------------------------------------------- - /** - * @b Responsibility - * The flattened input accessor. Provides easy access to the flattened input - * of a node or a specific baby node within a multi-node. - */ - //--------------------------------------------------------------------------- - struct IInput - { - enum {allNodes = -1}; - - /** - * Virtual destructor. Required for abstract classes. - */ - virtual ~IInput() {} - - /** - * Get the beginning pointer to the input's byte array - * - * @param nodeIdx [Int32] baby node index, or allNodes for entire input - * @param sentinelP [Byte*] pointer to default value to insert for elements of - * the node input that are outside the actual input bounds. - * @retval [Byte *] pointer to internal byte array. 
- */ - virtual const Byte * begin(Int32 nodeIdx=allNodes, const Byte* sentinelP=nullptr) = 0; - - /** - * Get the end pointer to the input's byte array - * - * The end pointer is pointing to the byte immediately - * following the last byte in the internal byte array - * - * @param nodeIdx [Int32] baby node index, or allNodes for entire input - * @retval [Byte *] the end pointer of the internal byte array. - */ - virtual const Byte * end(Int32 nodeIdx=allNodes) = 0; - - /** - * Get the number of elements in an input - * - * @param nodeIdx [Int32] baby node index, or allNodes for entire input - * @retval [Size] number of elements - */ - virtual Size getElementCount(Int32 nodeIdx=allNodes) = 0; - - /** - * Get the size of a single element in an input - * - * All elements in a range have the same size - * - * @retval [Size] size in bytes of a range element. - */ - virtual Size getElementSize() = 0; - - /** - * Get the number of links into a specific node - * - * @param nodeIdx [Int32] baby node index, or allNodes for entire input - * @retval [Size] number of links - */ - virtual Size getLinkCount(Int32 nodeIdx=allNodes) = 0; - - /** - * Get pointer to the link boundaries - * - * @param nodeIdx [Int32] baby node index, or allNodes for entire input - * @retval [Size] pointer to array of link boundaries - */ - virtual Size * getLinkBoundaries(Int32 nodeIdx=allNodes) = 0; - - }; - - - //--------------------------------------------------------------------------- - // - // I O U T P U T - // - //--------------------------------------------------------------------------- - /** - * @b Responsibility - * The easy output accessor. Provides easy access to the output - * of a node or a specific baby node within a multi-node. - */ - //--------------------------------------------------------------------------- - struct IOutput - { - enum {allNodes = -1}; - - /** - * Virtual destructor. Required for abstract classes. - */ - virtual ~IOutput() {} - - /** - * Get the beginning pointer to the input's byte array - * - * @param nodeIdx [Int32] baby node index, or allNodes for entire input - * @retval [Byte *] pointer to internal byte array. - */ - virtual Byte * begin(Int32 nodeIdx=allNodes) = 0; - - /** - * Get the end pointer to the input's byte array - * - * The end pointer is pointing to the byte immediately - * following the last byte in the internal byte array - * - * @param nodeIdx [Int32] baby node index, or allNodes for entire input - * @retval [Byte *] the end pointer of the internal byte array. - */ - virtual Byte * end(Int32 nodeIdx=allNodes) = 0; - - /** - * Get the number of elements in a range - * - * @param nodeIdx [Int32] baby node index, or allNodes for entire input - * @retval [Size] number of elements - */ - virtual Size getElementCount(Int32 nodeIdx=allNodes) = 0; - - /** - * Get the size of a single element in a range - * - * All elements in a range have the same size - * - * @retval [Size] size in bytes of a range element. - */ - virtual Size getElementSize() = 0; - - }; - - //--------------------------------------------------------------------------- - /** - * @b Responsibility - * Aggregate all the information a node needs for initialization. It includes - * id, name, logLevel, inputs, outputs and state. This struct is passed to - * INode::init() during the initialization of nodes. Note that multi-nodes - * (nodes that represent multiple "baby" nodes require additional information - * that is provided by the IMultiNodeInfo interface (see bellow). 
- */ - //--------------------------------------------------------------------------- - struct INodeInfo - { - /** - * Virtual destructor. Required for abstract classes. - */ - virtual ~INodeInfo() {} - - /** - * Return the type of the node - * - * A node is created dynamically by the runtime engine based on - * its type (used when registering a node type with the plugin manager). - * Making this type available to the node via INodeInfo - * saves the node from storing its type internally. It also ensures - * that there will not be a conflict between the registered type of - * a node and what the node thinks its type is. - * - * @retval [const Byte *] node type - */ - virtual const Byte * getType() = 0; - - /** - * Return the current log level of the node - * - * The node should take into account the log level every time - * it is about to log something. The log level can be modified - * externally by the user. The node should exercise judgment - * and call INodeInfo->getLogLevel() frequently - * (before every log statement or at the beginning of each compute()) - * - * @retval [const Byte *] node type - */ - virtual LogLevel getLogLevel() = 0; - - /** - * Return an object used to access the flattened input of a node. - * - * This method can be used to get easy access to a flattened version of any node input. - * It is much easier to use than the more primitive getInputs() call which can potentially - * return multiple input ranges that comprise the input. - * - * In addition, the IInput object allows you to easily get a pointer to the portion of - * the flattened input that corresponds to any particular baby node. - * - * @retval [IInput *] flattened node input object - */ - virtual IInput * getInput(const NTA_Byte* varName) = 0; - - /** - * Return an object used to access the output of a node. - * - * This method can be used to get easy access to any node output. - * It is much easier to use than the more primitive getOutputs() call for multi-nodes - * since it allows you to easily get a pointer to the portion of the output that - * corresponds to any particular baby node. - * - * @retval [IOutput *] node output object - */ - virtual IOutput * getOutput(const NTA_Byte* varName) = 0; - - - /** - * Return the inputs of the node - * - * The inputs are garantueed to be persistent over the lifetime - * of the node. That means that the node may call getInputs() - * multiple times and will always get the same answer. - * The contents of the inputs may change of course between - * calls to compute(), but the number of inputs, names and all - * other structural properties (including the memory area) - * are all fixed. - * - * @retval [IInputRangeMap &] node inputs - */ - virtual IInputRangeMap & getInputs() = 0; - - /** - * Return the outputs of the node - * - * The outputs are garantueed to be persistent over the lifetime - * of the node. That means that the node may call getOutputs() - * multiple times and will always get the same answer. - * The contents of the outputs may change of course as the node - * modifies them in compute(), but the number of outputs, names - * and all other structural properties (including the memory area) - * are all fixed. - * - * @retval [IOutputRangeMap &] node outputs - */ - virtual IOutputRangeMap & getOutputs() = 0; - - /** - * Return the serialized state of the node - * - * The state is used to initialize a node on the runtime side - * to an initial state. The initial state is created on the tools - * side and stored in serialized form in the network file. 
- * - * @retval [IReadBuffer &] initial node state - */ - virtual IReadBuffer & getState() = 0; - - /** - * Return the number of baby nodes in a multi-node. - * - * This method is only used for multi-nodes. - * It returns the number of baby nodes for this multi-node. - * - * @retval [NTA_Size] number of baby nodes in this multi-node. - */ - virtual NTA_Size getMNNodeCount() = 0; - - /** - * Return the Multi-node input list for a given input variable in a multi-node - * - * This method is only used for multi-nodes. - * It returns a pointer to an array of NTA_IndexRangeLists's, one per baby node. Each - * NTA_IndexRangeList contains a count and an array of NTA_IndexRange's. - * Each NTA_IndexRange has an offset and size, specifying the offset within the - * input variable 'varName', and the number of elements. - * - * @retval [NTA_IndexRangeList *] array of NTA_IndexRangeLists, one for each baby node - */ - virtual const NTA_IndexRangeList * getMNInputLists(const NTA_Byte* varName) = 0; - - /** - * Return the Multi-node output sizes for a given output variable of a multi-node - * - * This method is only used for multi-nodes. - * It returns a pointer to an array of sizes, one per baby node of the multi-node. - * - * @retval [IReadBuffer &] serialized SparseMatrix01 - */ - virtual const NTA_Size * getMNOutputSizes(const NTA_Byte* varName) = 0; - }; - - - //--------------------------------------------------------------------------- - // - // I I N P U T S I Z E M A P - // - //--------------------------------------------------------------------------- - /** - * This interface provides access to the input sizes of a node. - * It contains entries and provides iterator-like accessor - * as well as lookup by name accessor. - */ - //--------------------------------------------------------------------------- - struct IInputSizeMap - { - virtual ~IInputSizeMap() {} - virtual void reset()= 0; - virtual const NTA_InputSizeMapEntry * next() = 0; - virtual const NTA_InputSizeMapEntry * lookup(const Byte * name) = 0; - }; - - - //--------------------------------------------------------------------------- - // - // I O U T P U T S I Z E M A P - // - //--------------------------------------------------------------------------- - /** - * This interface provides access to the output sizes of a node. - * It contains entries and provides iterator-like accessor - * as well as lookup by name accessor. - */ - //--------------------------------------------------------------------------- - struct IOutputSizeMap - { - virtual ~IOutputSizeMap() {} - virtual void reset()= 0; - virtual const NTA_OutputSizeMapEntry * next() = 0; - virtual const NTA_OutputSizeMapEntry * lookup(const Byte * name) = 0; - }; - - - //--------------------------------------------------------------------------- - // - // I P A R A M E T E R M A P E N T R Y - // - //--------------------------------------------------------------------------- - /** - * @b Responsibility - * The parameter map entry interface. Each entry has a name and read buffer - * that contains the value of the parameter. 
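The INodeInfo accessors listed above are unchanged in substance by this reformatting; a node typically fetches its IInput/IOutput objects once, since their structural properties are fixed for the node's lifetime. An illustrative sketch only (not from the patch; the variable names "bottomUpIn" and "bottomUpOut" are hypothetical):

#include "nupic/ntypes/ObjectModel.hpp"

void wireUp(nupic::INodeInfo &info) {
  nupic::IInput *in = info.getInput("bottomUpIn");   // flattened input accessor
  nupic::IOutput *out = info.getOutput("bottomUpOut");
  if (in == nullptr || out == nullptr)
    return; // defensive check; null-on-unknown-name is an assumption

  // Element geometry does not change between compute() calls.
  nupic::Size inCount = in->getElementCount();
  nupic::Size outCount = out->getElementCount();
  (void)inCount;
  (void)outCount;
}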
- */ - //--------------------------------------------------------------------------- - struct IParameterMapEntry - { - /** - * The parameter name - */ - const Byte * name; - - /** - * The parameter value - */ - const IReadBuffer * value; - }; - - //--------------------------------------------------------------------------- - // - // I P A R A M E T E R M A P - // - //--------------------------------------------------------------------------- - /** - * @b Responsibility - * The parameter map interface. Stores pairs of [name, parameter] - * It provides lookup by name as well as iterator-like - * methods (begin(), end()) to iterate over all entries - */ - //--------------------------------------------------------------------------- - struct IParameterMap - { - /** - * Virtual destructor. Required for abstract classes. - */ - virtual ~IParameterMap() {} - - /** - * Reset the internal pointer to point to the first parameter entry. - * - * The following next() will return the first entry in the map or NULL - * if the map is empty. Multiple consecutive calls are allowed - * but have no effect. - */ - virtual void reset() const = 0; - - /** - * Get the next parameter entry - * - * This method returns the parameter entry pointed to by the internal pointer - * and advances the pointer to the next entry. If the collection is empty - * or previous call to next() returned the last entry, further calls to next() - * will return NULL. - * - * @retval [const IParameterMapEntry *] next map entry or NULL - */ - virtual const IParameterMapEntry * next() const = 0; - - /** - * Get a parameter by name - * - * This method returns the parameter whose name matches the - * requested name or NULL if an entry with this name is not in the map. - * lookup() calls don't affect the internal iterator pointer. - * - * @param [const Byte *] parameter name to lookup. - * @retval [const IReadBuffer *] map entry or NULL - */ - virtual const IReadBuffer * lookup(const Byte * name) const = 0; - }; - - - /** ------------------------------------ - * - * I N I T I A L S T A T E I N F O - * - * ------------------------------------- - * - * This struct contains all the Information that - * NTA_CreateInitialState needs: input sizes, output sizes - * and a map of the initial parameters. - */ - struct IInitialStateInfo - { - virtual ~IInitialStateInfo() {} - virtual const Byte * getNodeType() = 0; - virtual const IInputSizeMap & getInputSizes() = 0; - virtual const IOutputSizeMap & getOutputSizes() = 0; - virtual const IParameterMap & getParameters() = 0; - }; + /** + * Write array of Real64 elements into + * the internal buffer. + * + * @param value the input array. + * @param size how many bytes to write + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(const Real64 *value, Size size) = 0; + + /** + * Write a bool (size is compiler-defined) into the internal buffer. + * + * @param value the input bool. + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(bool value) = 0; + + /** + * Write array of bool elements into the internal buffer. + * + * @param value the input array. + * @param size how many bytes to write + * @retval 0 for success, -1 for failure + */ + virtual Int32 write(const bool *value, Size size) = 0; + + /** + * Get the size in bytes of the contents of the internal + * buffer. + * + * @retval [Size] size in bytes of the buffer contents + */ + virtual Size getSize() = 0; + + /** + * A pointer to the internal buffer. + * + * The buffer is guarantueed to be contiguous. 
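IParameterMap and IInitialStateInfo above are the tools-side pieces used when building a node's initial state; a parameter is simply a named IReadBuffer. A sketch of pulling one value out, for illustration only (the parameter name "clusterCount" is hypothetical):

#include "nupic/ntypes/ObjectModel.hpp"

nupic::UInt32 readClusterCount(nupic::IInitialStateInfo &info) {
  const nupic::IParameterMap &params = info.getParameters();
  const nupic::IReadBuffer *value = params.lookup("clusterCount");
  nupic::UInt32 count = 0;
  if (value != nullptr && value->read(count) == 0)
    return count;
  return 0; // parameter missing or unreadable
}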
+ * + * @retval [Byte *] internal buffer + */ + virtual const Byte *getData() = 0; +}; + +//--------------------------------------------------------------------------- +// +// I R A N G E +// +//--------------------------------------------------------------------------- +/** + * + * @b Responsibility + * A base interface that defines the common operations for input + * and output ranges. Exposes the numer of elements (elementCount) + * and the size of each element in bytes (elementSize). + * + * @b Rationale + * Plain old reuse. both IInputRange and IOutputRange are derived + * from IRange. + * + */ +//--------------------------------------------------------------------------- +struct IRange { + /** + * Virtual destructor. Required for abstract classes. + */ + virtual ~IRange() {} + + /** + * Get the number of elements in a range + * + * @retval [Size] number of elements + */ + virtual Size getElementCount() const = 0; + + /** + * Get the size of a single element in a range + * + * All elements in a range have the same size + * + * @retval [Size] size in bytes of a range element. + */ + virtual Size getElementSize() const = 0; +}; + +//--------------------------------------------------------------------------- +// +// I I N P U T R A N G E +// +//--------------------------------------------------------------------------- +/** + * + * @b Responsibility + * The input range interface. Provides access to a couple of iterator-like + * pointers to the beginning and end of the raange. The lag argument + * controls the offset (from which buffer to extract the pointers). + * @b Note + * begin() and end() return a const Byte *. It is the responsibility of the + * caller to cast it to te correct type. The memory is not suppposed to + * be modified only read. + */ +//--------------------------------------------------------------------------- +struct IInputRange : public IRange { + /** + * Get the beginning pointer to the range's byte array + * + * @retval [Byte *] pointer to internal byte array. + */ + virtual const Byte *begin() const = 0; + + /** + * Get the end pointer to the range's byte array + * + * The end pointer is pointing to the byte immediately + * following the last byte in the internal byte array + * + * @retval [Byte *] the end pointer of the internal byte array. + */ + virtual const Byte *end() const = 0; +}; + +//--------------------------------------------------------------------------- +// +// I O U T P U T R A N G E +// +//--------------------------------------------------------------------------- +/** + * @b Responsibility + * The output range interface. Provides access to a couple of iterator-like + * pointers to the beginning and end of the raange. + * + * @b Note + * begin() and end() return a Byte *. It is the responsibility of the + * caller to cast it to te correct type. The memory can be written to of + * course. + */ +//--------------------------------------------------------------------------- +struct IOutputRange : public IRange { + /** + * Get the beginning pointer to the range's byte array + * + * @retval [Byte *] pointer to internal byte array. + */ + virtual Byte *begin() = 0; + + /** + * Get the end pointer to the range's byte array + * + * The end pointer is pointing to the byte immediately + * following the last byte in the internal byte array + * + * @retval [Byte *] the end pointer of the internal byte array. 
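With getSize() and getData() above, IWriteBuffer doubles as a simple serialization sink: write a length prefix, then the payload, then hand off the contiguous bytes wholesale. A reduced sketch, illustrative only and not part of the patch:

#include "nupic/ntypes/ObjectModel.hpp"

bool serialize(nupic::IWriteBuffer &wb, const nupic::Byte *blob,
               nupic::Size nBytes, bool learningEnabled) {
  bool ok = true;
  ok = ok && wb.write((nupic::UInt32)nBytes) == 0; // length prefix
  ok = ok && wb.write(blob, nBytes) == 0;          // raw bytes
  ok = ok && wb.write(learningEnabled) == 0;
  if (!ok)
    return false;
  // The serialized bytes are now available as one contiguous block.
  const nupic::Byte *data = wb.getData();
  nupic::Size size = wb.getSize();
  (void)data;
  (void)size;
  return true;
}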
+ */ + virtual Byte *end() = 0; +}; + +//--------------------------------------------------------------------------- +// +// I I N P U T R A N G E M A P E N T R Y +// +//--------------------------------------------------------------------------- +/** + * @b Responsibility + * The input range map entry interface. Each entry has a name and an input + * range iterator. That means that a whole collection on input ranges are + * accessible via the same name. + */ +//--------------------------------------------------------------------------- +struct IInputRangeMapEntry { + /** + * Virtual destructor. Required for abstract classes. + */ + virtual ~IInputRangeMapEntry() {} + + /** + * Reset the internal pointer to point to the beginning of the input range + * iterator. + * + * The following next() will return the first IReadBuffer in the collection + * or NULL if the collection is empty. Multiple consecutive calls are allowed + * but have no effect. + */ + virtual void reset() const = 0; + + /** + * Get the next input range in the map entry + * + * This method returns the input range pointed to by the internal pointer and + * advances the pointer to the next input range in the map entry. If the + * collection is empty or previous call to next() returned the last input + * range, further calls to next() will return NULL. + * + * @retval [const IInputRange *] next input range or NULL + */ + virtual const IInputRange *next() const = 0; + + /** + * The name of the input range + */ + const Byte *name; +}; + +//--------------------------------------------------------------------------- +// +// I I N P U T R A N G E M A P +// +//--------------------------------------------------------------------------- +/** + * @b Responsibility + * The input range map interface. Stores a collection of IInputMapRangeEntry + * objects. It provides lookup by name as well an iterator + * to iterate over all the entries. + */ +//--------------------------------------------------------------------------- +struct IInputRangeMap { + /** + * Virtual destructor. Required for abstract classes. + */ + virtual ~IInputRangeMap() {} + + /** + * Reset the internal pointer to point to the first InputRangeMap entry. + * + * The following next() will return the first entry in the map or NULL + * if the collection is empty. Multiple consecutive calls are allowed + * but have no effect. + */ + virtual void reset() const = 0; + + /** + * Get the next InputRangeMap entry + * + * This method returns the InputRangeMap entry pointed to by the internal + * pointer and advances the pointer to the next entry. If the collection is + * empty or previous call to next() returned the last entry, further calls to + * next() will return NULL. + * + * @retval [const IInputRangeMapEntry *] next map entry or NULL + */ + virtual const IInputRangeMapEntry *next() const = 0; + + /** + * Get an InputRangeMap entry by name + * + * This method returns the InputRangeMap entry whose name matches the input + * name or NULL if an entry with this name is not in the map. lookup() calls + * don't affect the internal iterator pointer. + * + * @param [const Byte *] entry name to lookup. 
+ * @retval [const IInputRangeMapEntry *] map entry or NULL + */ + virtual const IInputRangeMapEntry *lookup(const Byte *name) const = 0; +}; + +//--------------------------------------------------------------------------- +// +// I O U T P U T R A N G E M A P E N T R Y +// +//--------------------------------------------------------------------------- +/** + * @b Responsibility + * The output range map entry is just a named output range + */ +//--------------------------------------------------------------------------- +struct IOutputRangeMapEntry { + /** + * The name of the output range + */ + const Byte *name; + + /** + * The output range + */ + IOutputRange *range; +}; + +//--------------------------------------------------------------------------- +// +// I O U T P U T R A N G E M A P +// +//--------------------------------------------------------------------------- +/** + * @b Responsibility + * The output range map interface. Stores pairs of output [name, range]. + * It provides lookup by name as well as iterator-like + * methods (begin(), end()) to iterate over all entries + */ +//--------------------------------------------------------------------------- +struct IOutputRangeMap { + /** + * Virtual destructor. Required for abstract classes. + */ + virtual ~IOutputRangeMap() {} + + /** + * Reset the internal pointer to point to the first OutputRangeMap entry. + * + * The following next() will return the first entry in the map or NULL + * if the collection is empty. Multiple consecutive calls are allowed + * but have no effect. + */ + virtual void reset() = 0; + + /** + * Get the next OutputRangeMap entry + * + * This method returns the OutputRangeMap entry pointed to by the internal + * pointer and advances the pointer to the next entry. If the collection is + * empty or previous call to next() returned the last entry, further calls to + * next() will return NULL. + * + * @retval [const IOutputRangeMapEntry *] next map entry or NULL + */ + virtual IOutputRangeMapEntry *next() = 0; + + /** + * Get an OutputRangeMap entry by name + * + * This method returns the OutputRangeMap entry whose name matches the + * requested name or NULL if an entry with this name is not in the map. + * lookup() calls don't affect the internal iterator pointer. + * + * @param [const Byte *] entry name to lookup. + * @retval [const IOutputRangeMapEntry *] map entry or NULL + */ + virtual IOutputRange *lookup(const Byte *name) = 0; +}; + +//--------------------------------------------------------------------------- +// +// I I N P U T +// +//--------------------------------------------------------------------------- +/** + * @b Responsibility + * The flattened input accessor. Provides easy access to the flattened input + * of a node or a specific baby node within a multi-node. + */ +//--------------------------------------------------------------------------- +struct IInput { + enum { allNodes = -1 }; + + /** + * Virtual destructor. Required for abstract classes. + */ + virtual ~IInput() {} + + /** + * Get the beginning pointer to the input's byte array + * + * @param nodeIdx [Int32] baby node index, or allNodes for entire input + * @param sentinelP [Byte*] pointer to default value to insert for elements + * of the node input that are outside the actual input bounds. + * @retval [Byte *] pointer to internal byte array. 
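The reset()/next()/lookup() pattern shared by these map interfaces can be driven as in the minimal sketch below. It only uses calls declared above; the input name "bottomUpIn" and the function name dumpInputs are illustrative, not part of the API.

// Sketch only: walk every named input, count the ranges behind it, then
// look one input up directly by name.  Declarations above assumed in scope.
#include <iostream>

static void dumpInputs(const nupic::IInputRangeMap &inputs) {
  inputs.reset();                              // rewind the map iterator
  while (const nupic::IInputRangeMapEntry *entry = inputs.next()) {
    entry->reset();                            // rewind this entry's range iterator
    nupic::Size nRanges = 0;
    while (entry->next() != nullptr)           // count the ranges behind this input
      ++nRanges;
    std::cout << entry->name << ": " << nRanges << " range(s)" << std::endl;
  }
  // Direct lookup by name; returns NULL when the input does not exist.
  const nupic::IInputRangeMapEntry *bu = inputs.lookup("bottomUpIn"); // name illustrative
  (void)bu;
}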
+ */ + virtual const Byte *begin(Int32 nodeIdx = allNodes, + const Byte *sentinelP = nullptr) = 0; + + /** + * Get the end pointer to the input's byte array + * + * The end pointer is pointing to the byte immediately + * following the last byte in the internal byte array + * + * @param nodeIdx [Int32] baby node index, or allNodes for entire input + * @retval [Byte *] the end pointer of the internal byte array. + */ + virtual const Byte *end(Int32 nodeIdx = allNodes) = 0; + + /** + * Get the number of elements in an input + * + * @param nodeIdx [Int32] baby node index, or allNodes for entire input + * @retval [Size] number of elements + */ + virtual Size getElementCount(Int32 nodeIdx = allNodes) = 0; + + /** + * Get the size of a single element in an input + * + * All elements in a range have the same size + * + * @retval [Size] size in bytes of a range element. + */ + virtual Size getElementSize() = 0; + + /** + * Get the number of links into a specific node + * + * @param nodeIdx [Int32] baby node index, or allNodes for entire input + * @retval [Size] number of links + */ + virtual Size getLinkCount(Int32 nodeIdx = allNodes) = 0; + + /** + * Get pointer to the link boundaries + * + * @param nodeIdx [Int32] baby node index, or allNodes for entire input + * @retval [Size] pointer to array of link boundaries + */ + virtual Size *getLinkBoundaries(Int32 nodeIdx = allNodes) = 0; +}; + +//--------------------------------------------------------------------------- +// +// I O U T P U T +// +//--------------------------------------------------------------------------- +/** + * @b Responsibility + * The easy output accessor. Provides easy access to the output + * of a node or a specific baby node within a multi-node. + */ +//--------------------------------------------------------------------------- +struct IOutput { + enum { allNodes = -1 }; + + /** + * Virtual destructor. Required for abstract classes. + */ + virtual ~IOutput() {} + + /** + * Get the beginning pointer to the input's byte array + * + * @param nodeIdx [Int32] baby node index, or allNodes for entire input + * @retval [Byte *] pointer to internal byte array. + */ + virtual Byte *begin(Int32 nodeIdx = allNodes) = 0; + + /** + * Get the end pointer to the input's byte array + * + * The end pointer is pointing to the byte immediately + * following the last byte in the internal byte array + * + * @param nodeIdx [Int32] baby node index, or allNodes for entire input + * @retval [Byte *] the end pointer of the internal byte array. + */ + virtual Byte *end(Int32 nodeIdx = allNodes) = 0; + + /** + * Get the number of elements in a range + * + * @param nodeIdx [Int32] baby node index, or allNodes for entire input + * @retval [Size] number of elements + */ + virtual Size getElementCount(Int32 nodeIdx = allNodes) = 0; + + /** + * Get the size of a single element in a range + * + * All elements in a range have the same size + * + * @retval [Size] size in bytes of a range element. + */ + virtual Size getElementSize() = 0; +}; + +//--------------------------------------------------------------------------- +/** + * @b Responsibility + * Aggregate all the information a node needs for initialization. It includes + * id, name, logLevel, inputs, outputs and state. This struct is passed to + * INode::init() during the initialization of nodes. Note that multi-nodes + * (nodes that represent multiple "baby" nodes require additional information + * that is provided by the IMultiNodeInfo interface (see bellow). 
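To show how the flattened IInput/IOutput accessors above are meant to be used from a multi-node's compute step, here is a sketch that copies one baby node's Real32 input slice to its output slice. It assumes both buffers hold Real32 elements and that a zero sentinel is an acceptable default for out-of-bounds elements; the function name copyNode is illustrative.

// Sketch only: per-baby-node copy from the flattened input to the output.
// Assumes Real32 elements on both sides; declarations above in scope.
#include <algorithm>

static void copyNode(nupic::IInput &in, nupic::IOutput &out, nupic::Int32 nodeIdx) {
  const nupic::Real32 zero = 0; // sentinel for elements outside the input bounds
  const auto *src = reinterpret_cast<const nupic::Real32 *>(
      in.begin(nodeIdx, reinterpret_cast<const nupic::Byte *>(&zero)));
  auto *dst = reinterpret_cast<nupic::Real32 *>(out.begin(nodeIdx));
  const nupic::Size n =
      std::min(in.getElementCount(nodeIdx), out.getElementCount(nodeIdx));
  std::copy(src, src + n, dst);
}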
+ */ +//--------------------------------------------------------------------------- +struct INodeInfo { + /** + * Virtual destructor. Required for abstract classes. + */ + virtual ~INodeInfo() {} + + /** + * Return the type of the node + * + * A node is created dynamically by the runtime engine based on + * its type (used when registering a node type with the plugin manager). + * Making this type available to the node via INodeInfo + * saves the node from storing its type internally. It also ensures + * that there will not be a conflict between the registered type of + * a node and what the node thinks its type is. + * + * @retval [const Byte *] node type + */ + virtual const Byte *getType() = 0; + + /** + * Return the current log level of the node + * + * The node should take into account the log level every time + * it is about to log something. The log level can be modified + * externally by the user. The node should exercise judgment + * and call INodeInfo->getLogLevel() frequently + * (before every log statement or at the beginning of each compute()) + * + * @retval [const Byte *] node type + */ + virtual LogLevel getLogLevel() = 0; + + /** + * Return an object used to access the flattened input of a node. + * + * This method can be used to get easy access to a flattened version of any + * node input. It is much easier to use than the more primitive getInputs() + * call which can potentially return multiple input ranges that comprise the + * input. + * + * In addition, the IInput object allows you to easily get a pointer to the + * portion of the flattened input that corresponds to any particular baby + * node. + * + * @retval [IInput *] flattened node input object + */ + virtual IInput *getInput(const NTA_Byte *varName) = 0; + + /** + * Return an object used to access the output of a node. + * + * This method can be used to get easy access to any node output. + * It is much easier to use than the more primitive getOutputs() call for + * multi-nodes since it allows you to easily get a pointer to the portion of + * the output that corresponds to any particular baby node. + * + * @retval [IOutput *] node output object + */ + virtual IOutput *getOutput(const NTA_Byte *varName) = 0; + + /** + * Return the inputs of the node + * + * The inputs are garantueed to be persistent over the lifetime + * of the node. That means that the node may call getInputs() + * multiple times and will always get the same answer. + * The contents of the inputs may change of course between + * calls to compute(), but the number of inputs, names and all + * other structural properties (including the memory area) + * are all fixed. + * + * @retval [IInputRangeMap &] node inputs + */ + virtual IInputRangeMap &getInputs() = 0; + + /** + * Return the outputs of the node + * + * The outputs are garantueed to be persistent over the lifetime + * of the node. That means that the node may call getOutputs() + * multiple times and will always get the same answer. + * The contents of the outputs may change of course as the node + * modifies them in compute(), but the number of outputs, names + * and all other structural properties (including the memory area) + * are all fixed. + * + * @retval [IOutputRangeMap &] node outputs + */ + virtual IOutputRangeMap &getOutputs() = 0; + + /** + * Return the serialized state of the node + * + * The state is used to initialize a node on the runtime side + * to an initial state. The initial state is created on the tools + * side and stored in serialized form in the network file. 
+ * + * @retval [IReadBuffer &] initial node state + */ + virtual IReadBuffer &getState() = 0; + + /** + * Return the number of baby nodes in a multi-node. + * + * This method is only used for multi-nodes. + * It returns the number of baby nodes for this multi-node. + * + * @retval [NTA_Size] number of baby nodes in this multi-node. + */ + virtual NTA_Size getMNNodeCount() = 0; + + /** + * Return the Multi-node input list for a given input variable in a multi-node + * + * This method is only used for multi-nodes. + * It returns a pointer to an array of NTA_IndexRangeLists's, one per baby + * node. Each NTA_IndexRangeList contains a count and an array of + * NTA_IndexRange's. Each NTA_IndexRange has an offset and size, specifying + * the offset within the input variable 'varName', and the number of elements. + * + * @retval [NTA_IndexRangeList *] array of NTA_IndexRangeLists, one for each + * baby node + */ + virtual const NTA_IndexRangeList * + getMNInputLists(const NTA_Byte *varName) = 0; + + /** + * Return the Multi-node output sizes for a given output variable of a + * multi-node + * + * This method is only used for multi-nodes. + * It returns a pointer to an array of sizes, one per baby node of the + * multi-node. + * + * @retval [IReadBuffer &] serialized SparseMatrix01 + */ + virtual const NTA_Size *getMNOutputSizes(const NTA_Byte *varName) = 0; +}; + +//--------------------------------------------------------------------------- +// +// I I N P U T S I Z E M A P +// +//--------------------------------------------------------------------------- +/** + * This interface provides access to the input sizes of a node. + * It contains entries and provides iterator-like accessor + * as well as lookup by name accessor. + */ +//--------------------------------------------------------------------------- +struct IInputSizeMap { + virtual ~IInputSizeMap() {} + virtual void reset() = 0; + virtual const NTA_InputSizeMapEntry *next() = 0; + virtual const NTA_InputSizeMapEntry *lookup(const Byte *name) = 0; +}; + +//--------------------------------------------------------------------------- +// +// I O U T P U T S I Z E M A P +// +//--------------------------------------------------------------------------- +/** + * This interface provides access to the output sizes of a node. + * It contains entries and provides iterator-like accessor + * as well as lookup by name accessor. + */ +//--------------------------------------------------------------------------- +struct IOutputSizeMap { + virtual ~IOutputSizeMap() {} + virtual void reset() = 0; + virtual const NTA_OutputSizeMapEntry *next() = 0; + virtual const NTA_OutputSizeMapEntry *lookup(const Byte *name) = 0; +}; + +//--------------------------------------------------------------------------- +// +// I P A R A M E T E R M A P E N T R Y +// +//--------------------------------------------------------------------------- +/** + * @b Responsibility + * The parameter map entry interface. Each entry has a name and read buffer + * that contains the value of the parameter. + */ +//--------------------------------------------------------------------------- +struct IParameterMapEntry { + /** + * The parameter name + */ + const Byte *name; + + /** + * The parameter value + */ + const IReadBuffer *value; +}; + +//--------------------------------------------------------------------------- +// +// I P A R A M E T E R M A P +// +//--------------------------------------------------------------------------- +/** + * @b Responsibility + * The parameter map interface. 
Stores pairs of [name, parameter] + * It provides lookup by name as well as iterator-like + * methods (begin(), end()) to iterate over all entries + */ +//--------------------------------------------------------------------------- +struct IParameterMap { + /** + * Virtual destructor. Required for abstract classes. + */ + virtual ~IParameterMap() {} + + /** + * Reset the internal pointer to point to the first parameter entry. + * + * The following next() will return the first entry in the map or NULL + * if the map is empty. Multiple consecutive calls are allowed + * but have no effect. + */ + virtual void reset() const = 0; + + /** + * Get the next parameter entry + * + * This method returns the parameter entry pointed to by the internal pointer + * and advances the pointer to the next entry. If the collection is empty + * or previous call to next() returned the last entry, further calls to next() + * will return NULL. + * + * @retval [const IParameterMapEntry *] next map entry or NULL + */ + virtual const IParameterMapEntry *next() const = 0; + + /** + * Get a parameter by name + * + * This method returns the parameter whose name matches the + * requested name or NULL if an entry with this name is not in the map. + * lookup() calls don't affect the internal iterator pointer. + * + * @param [const Byte *] parameter name to lookup. + * @retval [const IReadBuffer *] map entry or NULL + */ + virtual const IReadBuffer *lookup(const Byte *name) const = 0; +}; + +/** ------------------------------------ + * + * I N I T I A L S T A T E I N F O + * + * ------------------------------------- + * + * This struct contains all the Information that + * NTA_CreateInitialState needs: input sizes, output sizes + * and a map of the initial parameters. + */ +struct IInitialStateInfo { + virtual ~IInitialStateInfo() {} + virtual const Byte *getNodeType() = 0; + virtual const IInputSizeMap &getInputSizes() = 0; + virtual const IOutputSizeMap &getOutputSizes() = 0; + virtual const IParameterMap &getParameters() = 0; +}; // //--------------------------------------------------------------------------- // /** // * @b Responsibility -// * Aggregate the additional information a multi-node needs for initialization. -// * It includes the number of baby nodes and index ranges of each baby node into +// * Aggregate the additional information a multi-node needs for +// initialization. +// * It includes the number of baby nodes and index ranges of each baby node +// into // * the multi-node inputs and outputs. This struct is passed to // * INode::init() during the initialization of nodes. // */ @@ -1263,53 +1237,59 @@ namespace nupic // * Virtual destructor. Required for abstract classes. // */ // virtual ~IMultiNodeInfo() {} -// +// // /** -// * Return the number of baby nodes in a multi-node. +// * Return the number of baby nodes in a multi-node. // * -// * This method is only used for multi-nodes. -// * It returns the number of baby nodes for this multi-node. +// * This method is only used for multi-nodes. +// * It returns the number of baby nodes for this multi-node. // * -// * @retval [NTA_Size] number of baby nodes in this multi-node. -// */ +// * @retval [NTA_Size] number of baby nodes in this multi-node. +// */ // virtual NTA_Size getNodeCount() = 0; -// +// // /** -// * Return the Multi-node input list for a given input variable in a multi-node +// * Return the Multi-node input list for a given input variable in a +// multi-node // * -// * This method is only used for multi-nodes. 
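A short sketch of how NTA_CreateInitialState-style code might read a parameter through the IParameterMap and IInitialStateInfo interfaces above, decoding it with the ReadStringFromBuffer() helper defined later in this header. The parameter name "cacheDir" and the function name getStringParam are illustrative.

// Sketch only: look up a named parameter and decode it as a string.
// lookup() returns NULL when the parameter was not supplied.
#include <string>

static std::string getStringParam(nupic::IInitialStateInfo &info) {
  const nupic::IReadBuffer *buf = info.getParameters().lookup("cacheDir"); // name illustrative
  if (buf == nullptr)
    return std::string();                      // parameter not supplied
  return nupic::ReadStringFromBuffer(*buf);    // helper defined below in this header
}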
-// * It returns a pointer to an array of NTA_IndexRangeLists's, one per baby node. Each -// * NTA_IndexRangeList contains a count and an array of NTA_IndexRange's. -// * Each NTA_IndexRange has an offset and size, specifying the offset within the -// * input variable 'varName', and the number of elements. +// * This method is only used for multi-nodes. +// * It returns a pointer to an array of NTA_IndexRangeLists's, one per baby +// node. Each +// * NTA_IndexRangeList contains a count and an array of NTA_IndexRange's. +// * Each NTA_IndexRange has an offset and size, specifying the offset within +// the +// * input variable 'varName', and the number of elements. // * -// * @retval [NTA_IndexRangeList *] array of NTA_IndexRangeLists, one for each baby node -// */ -// virtual const NTA_IndexRangeList * getInputList(const NTA_Byte* varName) = 0; -// +// * @retval [NTA_IndexRangeList *] array of NTA_IndexRangeLists, one for +// each baby node +// */ +// virtual const NTA_IndexRangeList * getInputList(const NTA_Byte* varName) = +// 0; +// // /** -// * Return the Multi-node output sizes for a given output variable of a multi-node +// * Return the Multi-node output sizes for a given output variable of a +// multi-node // * -// * This method is only used for multi-nodes. -// * It returns a pointer to an array of sizes, one per baby node of the multi-node. +// * This method is only used for multi-nodes. +// * It returns a pointer to an array of sizes, one per baby node of the +// multi-node. // * // * @retval [IReadBuffer &] serialized SparseMatrix01 -// */ +// */ // virtual const NTA_Size * getOutputSizes(const NTA_Byte* varName) = 0; // }; -inline NTA_Byte *_ReadString_alloc(NTA_UInt32 size) - { return new NTA_Byte[size]; } -inline void _ReadString_dealloc(NTA_Byte *p) - { delete[] p; } +inline NTA_Byte *_ReadString_alloc(NTA_UInt32 size) { + return new NTA_Byte[size]; +} +inline void _ReadString_dealloc(NTA_Byte *p) { delete[] p; } -inline std::string ReadStringFromBuffer(const IReadBuffer &buf) -{ +inline std::string ReadStringFromBuffer(const IReadBuffer &buf) { NTA_Byte *value = nullptr; NTA_UInt32 size = 0; - NTA_Int32 result = buf.readString(value, size, - _ReadString_alloc, _ReadString_dealloc); - if(result != 0) + NTA_Int32 result = + buf.readString(value, size, _ReadString_alloc, _ReadString_dealloc); + if (result != 0) throw std::runtime_error("Failed to read string from stream."); std::string toReturn(value, size); // Real fps must be provided to use delete here. @@ -1317,11 +1297,6 @@ inline std::string ReadStringFromBuffer(const IReadBuffer &buf) return toReturn; } - - } // end namespace nupic #endif // NTA_OBJECT_MODEL_HPP - - - diff --git a/src/nupic/ntypes/Scalar.cpp b/src/nupic/ntypes/Scalar.cpp index c9e7efedf7..b4e041313c 100644 --- a/src/nupic/ntypes/Scalar.cpp +++ b/src/nupic/ntypes/Scalar.cpp @@ -20,11 +20,11 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Implementation of the Scalar class - * + * * A Scalar object is an instance of an NTA_BasicType -- essentially a union - * It is used internally in the conversion of YAML strings to C++ objects. + * It is used internally in the conversion of YAML strings to C++ objects. 
*/ #include @@ -32,79 +32,59 @@ using namespace nupic; -Scalar::Scalar(NTA_BasicType theTypeParam) -{ +Scalar::Scalar(NTA_BasicType theTypeParam) { theType_ = theTypeParam; value.uint64 = 0; } -NTA_BasicType -Scalar::getType() -{ - return theType_; -} +NTA_BasicType Scalar::getType() { return theType_; } -// gcc 4.2 complains about the template specializations +// gcc 4.2 complains about the template specializations // in a different namespace if we don't include this namespace nupic { - - template <> Handle Scalar::getValue() const - { - NTA_CHECK(theType_ == NTA_BasicType_Handle); - return value.handle; - } - template <> Byte Scalar::getValue() const - { - NTA_CHECK(theType_ == NTA_BasicType_Byte); - return value.byte; - } - template <> UInt16 Scalar::getValue() const - { - NTA_CHECK(theType_ == NTA_BasicType_UInt16); - return value.uint16; - } - template <> Int16 Scalar::getValue() const - { - NTA_CHECK(theType_ == NTA_BasicType_Int16); - return value.int16; - } - template <> UInt32 Scalar::getValue() const - { - NTA_CHECK(theType_ == NTA_BasicType_UInt32); - return value.uint32; - } - template <> Int32 Scalar::getValue() const - { - NTA_CHECK(theType_ == NTA_BasicType_Int32); - return value.int32; - } - template <> UInt64 Scalar::getValue() const - { - NTA_CHECK(theType_ == NTA_BasicType_UInt64); - return value.uint64; - } - template <> Int64 Scalar::getValue() const - { - NTA_CHECK(theType_ == NTA_BasicType_Int64); - return value.int64; - } - template <> Real32 Scalar::getValue() const - { - NTA_CHECK(theType_ == NTA_BasicType_Real32); - return value.real32; - } - template <> Real64 Scalar::getValue() const - { - NTA_CHECK(theType_ == NTA_BasicType_Real64); - return value.real64; - } - template <> bool Scalar::getValue() const - { - NTA_CHECK(theType_ == NTA_BasicType_Bool); - return value.boolean; - } -} - - - +template <> Handle Scalar::getValue() const { + NTA_CHECK(theType_ == NTA_BasicType_Handle); + return value.handle; +} +template <> Byte Scalar::getValue() const { + NTA_CHECK(theType_ == NTA_BasicType_Byte); + return value.byte; +} +template <> UInt16 Scalar::getValue() const { + NTA_CHECK(theType_ == NTA_BasicType_UInt16); + return value.uint16; +} +template <> Int16 Scalar::getValue() const { + NTA_CHECK(theType_ == NTA_BasicType_Int16); + return value.int16; +} +template <> UInt32 Scalar::getValue() const { + NTA_CHECK(theType_ == NTA_BasicType_UInt32); + return value.uint32; +} +template <> Int32 Scalar::getValue() const { + NTA_CHECK(theType_ == NTA_BasicType_Int32); + return value.int32; +} +template <> UInt64 Scalar::getValue() const { + NTA_CHECK(theType_ == NTA_BasicType_UInt64); + return value.uint64; +} +template <> Int64 Scalar::getValue() const { + NTA_CHECK(theType_ == NTA_BasicType_Int64); + return value.int64; +} +template <> Real32 Scalar::getValue() const { + NTA_CHECK(theType_ == NTA_BasicType_Real32); + return value.real32; +} +template <> Real64 Scalar::getValue() const { + NTA_CHECK(theType_ == NTA_BasicType_Real64); + return value.real64; +} +template <> bool Scalar::getValue() const { + NTA_CHECK(theType_ == NTA_BasicType_Bool); + return value.boolean; +} +} // namespace nupic diff --git a/src/nupic/ntypes/Scalar.hpp b/src/nupic/ntypes/Scalar.hpp index eb02365ce0..897aa795e5 100644 --- a/src/nupic/ntypes/Scalar.hpp +++ b/src/nupic/ntypes/Scalar.hpp @@ -20,11 +20,11 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definitions for the Scalar class - * + * * A Scalar object is an instance of an 
NTA_BasicType -- essentially a union - * It is used internally in the conversion of YAML strings to C++ objects. + * It is used internally in the conversion of YAML strings to C++ objects. */ #ifndef NTA_SCALAR_HPP @@ -34,41 +34,33 @@ #include // temporary, while implementation is in hpp #include -namespace nupic -{ - class Scalar - { - public: - Scalar(NTA_BasicType theTypeParam); - - NTA_BasicType getType(); - - template T getValue() const; - +namespace nupic { +class Scalar { +public: + Scalar(NTA_BasicType theTypeParam); - union { - NTA_Handle handle; - NTA_Byte byte; - NTA_Int16 int16; - NTA_UInt16 uint16; - NTA_Int32 int32; - NTA_UInt32 uint32; - NTA_Int64 int64; - NTA_UInt64 uint64; - NTA_Real32 real32; - NTA_Real64 real64; - bool boolean; - } value; + NTA_BasicType getType(); + template T getValue() const; - private: - NTA_BasicType theType_; + union { + NTA_Handle handle; + NTA_Byte byte; + NTA_Int16 int16; + NTA_UInt16 uint16; + NTA_Int32 int32; + NTA_UInt32 uint32; + NTA_Int64 int64; + NTA_UInt64 uint64; + NTA_Real32 real32; + NTA_Real64 real64; + bool boolean; + } value; - }; +private: + NTA_BasicType theType_; +}; -} +} // namespace nupic #endif // NTA_SCALAR_HPP - - - diff --git a/src/nupic/ntypes/Value.cpp b/src/nupic/ntypes/Value.cpp index 102ec58195..2abd237489 100644 --- a/src/nupic/ntypes/Value.cpp +++ b/src/nupic/ntypes/Value.cpp @@ -20,56 +20,38 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Implementation of the Value class */ - #include #include - using namespace nupic; - - -Value::Value(boost::shared_ptr& s) -{ +Value::Value(boost::shared_ptr &s) { category_ = scalarCategory; scalar_ = s; } -Value::Value(boost::shared_ptr& a) -{ +Value::Value(boost::shared_ptr &a) { category_ = arrayCategory; array_ = a; } -Value::Value(boost::shared_ptr& s) -{ +Value::Value(boost::shared_ptr &s) { category_ = stringCategory; string_ = s; } -bool Value::isScalar() const -{ - return category_ == scalarCategory; -} +bool Value::isScalar() const { return category_ == scalarCategory; } -bool Value::isArray() const -{ - return category_ == arrayCategory; -} +bool Value::isArray() const { return category_ == arrayCategory; } -bool Value::isString() const -{ - return category_ == stringCategory; -} +bool Value::isString() const { return category_ == stringCategory; } -NTA_BasicType Value::getType() const -{ - switch (category_) - { +NTA_BasicType Value::getType() const { + switch (category_) { case scalarCategory: return scalar_->getType(); break; @@ -83,57 +65,50 @@ NTA_BasicType Value::getType() const } } -boost::shared_ptr Value::getScalar() const -{ +boost::shared_ptr Value::getScalar() const { NTA_CHECK(category_ == scalarCategory); return scalar_; } -boost::shared_ptr Value::getArray() const -{ +boost::shared_ptr Value::getArray() const { NTA_CHECK(category_ == arrayCategory); return array_; } -boost::shared_ptr Value::getString() const -{ +boost::shared_ptr Value::getString() const { NTA_CHECK(category_ == stringCategory); return string_; } -template T Value::getScalarT() const -{ +template T Value::getScalarT() const { NTA_CHECK(category_ == scalarCategory); - if (BasicType::getType() != scalar_->getType()) - { - NTA_THROW << "Attempt to access scalar of type " - << BasicType::getName(scalar_->getType()) - << " as type " << BasicType::getName(); + if (BasicType::getType() != scalar_->getType()) { + NTA_THROW << "Attempt to access scalar of type " + << BasicType::getName(scalar_->getType()) << " as type " + << 
BasicType::getName(); } return scalar_->getValue(); } -const std::string Value::getDescription() const -{ - switch(category_) - { +const std::string Value::getDescription() const { + switch (category_) { case stringCategory: return std::string("string") + " (" + *string_ + ")"; break; case scalarCategory: - return std::string("Scalar of type ") + BasicType::getName(scalar_->getType()); + return std::string("Scalar of type ") + + BasicType::getName(scalar_->getType()); break; case arrayCategory: - return std::string("Array of type ") + BasicType::getName(array_->getType()); + return std::string("Array of type ") + + BasicType::getName(array_->getType()); break; } return "NOT REACHED"; } -void ValueMap::add(const std::string& key, const Value& value) -{ - if (map_.find(key) != map_.end()) - { +void ValueMap::add(const std::string &key, const Value &value) { + if (map_.find(key) != map_.end()) { NTA_THROW << "Key '" << key << "' specified twice"; } auto vp = new Value(value); @@ -141,166 +116,122 @@ void ValueMap::add(const std::string& key, const Value& value) map_.insert(std::make_pair(key, vp)); } +Value::Category Value::getCategory() const { return category_; } +ValueMap::const_iterator ValueMap::begin() const { return map_.begin(); } -Value::Category Value::getCategory() const -{ - return category_; -} - - -ValueMap::const_iterator ValueMap::begin() const -{ - return map_.begin(); -} - -ValueMap::const_iterator ValueMap::end() const -{ - return map_.end(); -} - - - - +ValueMap::const_iterator ValueMap::end() const { return map_.end(); } // specializations of getValue() // gcc 4.2 complains if they are not inside the namespace declaration -namespace nupic -{ - template Byte Value::getScalarT() const; - template Int16 Value::getScalarT() const; - template Int32 Value::getScalarT() const; - template Int64 Value::getScalarT() const; - template UInt16 Value::getScalarT() const; - template UInt32 Value::getScalarT() const; - template UInt64 Value::getScalarT() const; - template Real32 Value::getScalarT() const; - template Real64 Value::getScalarT() const; - template Handle Value::getScalarT() const; - template bool Value::getScalarT() const; -} - -ValueMap::ValueMap() -{ -}; - -ValueMap::~ValueMap() -{ - for (auto & elem : map_) - { +namespace nupic { +template Byte Value::getScalarT() const; +template Int16 Value::getScalarT() const; +template Int32 Value::getScalarT() const; +template Int64 Value::getScalarT() const; +template UInt16 Value::getScalarT() const; +template UInt32 Value::getScalarT() const; +template UInt64 Value::getScalarT() const; +template Real32 Value::getScalarT() const; +template Real64 Value::getScalarT() const; +template Handle Value::getScalarT() const; +template bool Value::getScalarT() const; +} // namespace nupic + +ValueMap::ValueMap(){}; + +ValueMap::~ValueMap() { + for (auto &elem : map_) { delete elem.second; elem.second = nullptr; } map_.clear(); } -ValueMap::ValueMap(const ValueMap& rhs) -{ - for (auto & elem : map_) - { +ValueMap::ValueMap(const ValueMap &rhs) { + for (auto &elem : map_) { delete elem.second; elem.second = nullptr; } map_.clear(); - for(const auto & rh : rhs) - { + for (const auto &rh : rhs) { auto vp = new Value(*(rh.second)); map_.insert(std::make_pair(rh.first, vp)); } } -void ValueMap::dump() const -{ +void ValueMap::dump() const { NTA_DEBUG << "===== Value Map:"; - for (const auto & elem : map_) - { + for (const auto &elem : map_) { std::string key = elem.first; - Value* value = elem.second; + Value *value = elem.second; NTA_DEBUG << 
"key: " << key - << " datatype: " << BasicType::getName(value->getType()) + << " datatype: " << BasicType::getName(value->getType()) << " category: " << value->getCategory(); } NTA_DEBUG << "===== End of Value Map"; } - -bool ValueMap::contains(const std::string& key) const -{ +bool ValueMap::contains(const std::string &key) const { return (map_.find(key) != map_.end()); } - -Value& ValueMap::getValue(const std::string& key) const -{ +Value &ValueMap::getValue(const std::string &key) const { auto item = map_.find(key); - if (item == map_.end()) - { + if (item == map_.end()) { NTA_THROW << "No value '" << key << "' found in Value Map"; } return *(item->second); } - -template T ValueMap::getScalarT(const std::string& key, T defaultValue) const -{ +template +T ValueMap::getScalarT(const std::string &key, T defaultValue) const { auto item = map_.find(key); - if (item == map_.end()) - { + if (item == map_.end()) { return defaultValue; - } - else - { + } else { return getScalarT(key); } } - - -template T ValueMap::getScalarT(const std::string& key) const -{ +template T ValueMap::getScalarT(const std::string &key) const { boost::shared_ptr s = getScalar(key); - if (s->getType() != BasicType::getType()) - { - NTA_THROW << "Invalid attempt to access parameter '" << key - << "' of type " << BasicType::getName(s->getType()) - << " as a scalar of type " << BasicType::getName(); + if (s->getType() != BasicType::getType()) { + NTA_THROW << "Invalid attempt to access parameter '" << key << "' of type " + << BasicType::getName(s->getType()) << " as a scalar of type " + << BasicType::getName(); } return s->getValue(); } -boost::shared_ptr ValueMap::getArray(const std::string& key) const -{ - Value& v = getValue(key); - if (! v.isArray()) - { - NTA_THROW << "Attempt to access element '" << key +boost::shared_ptr ValueMap::getArray(const std::string &key) const { + Value &v = getValue(key); + if (!v.isArray()) { + NTA_THROW << "Attempt to access element '" << key << "' of value map as an array but it is a '" << v.getDescription(); } return v.getArray(); } -boost::shared_ptr ValueMap::getScalar(const std::string& key) const -{ - Value& v = getValue(key); - if (! v.isScalar()) - { - NTA_THROW << "Attempt to access element '" << key +boost::shared_ptr ValueMap::getScalar(const std::string &key) const { + Value &v = getValue(key); + if (!v.isScalar()) { + NTA_THROW << "Attempt to access element '" << key << "' of value map as an array but it is a '" << v.getDescription(); } return v.getScalar(); } -boost::shared_ptr ValueMap::getString(const std::string& key) const -{ - Value& v = getValue(key); - if (! 
v.isString()) - { - NTA_THROW << "Attempt to access element '" << key +boost::shared_ptr +ValueMap::getString(const std::string &key) const { + Value &v = getValue(key); + if (!v.isString()) { + NTA_THROW << "Attempt to access element '" << key << "' of value map as a string but it is a '" << v.getDescription(); } @@ -308,29 +239,39 @@ boost::shared_ptr ValueMap::getString(const std::string& key) const } // explicit instantiations of getScalarT -namespace nupic -{ - template Byte ValueMap::getScalarT(const std::string& key, Byte defaultValue) const; - template UInt16 ValueMap::getScalarT(const std::string& key, UInt16 defaultValue) const; - template Int16 ValueMap::getScalarT(const std::string& key, Int16 defaultValue) const; - template UInt32 ValueMap::getScalarT(const std::string& key, UInt32 defaultValue) const; - template Int32 ValueMap::getScalarT(const std::string& key, Int32 defaultValue) const; - template UInt64 ValueMap::getScalarT(const std::string& key, UInt64 defaultValue) const; - template Int64 ValueMap::getScalarT(const std::string& key, Int64 defaultValue) const; - template Real32 ValueMap::getScalarT(const std::string& key, Real32 defaultValue) const; - template Real64 ValueMap::getScalarT(const std::string& key, Real64 defaultValue) const; - template Handle ValueMap::getScalarT(const std::string& key, Handle defaultValue) const; - template bool ValueMap::getScalarT(const std::string& key, bool defaultValue) const; - - template Byte ValueMap::getScalarT(const std::string& key) const; - template UInt16 ValueMap::getScalarT(const std::string& key) const; - template Int16 ValueMap::getScalarT(const std::string& key) const; - template UInt32 ValueMap::getScalarT(const std::string& key) const; - template Int32 ValueMap::getScalarT(const std::string& key) const; - template UInt64 ValueMap::getScalarT(const std::string& key) const; - template Int64 ValueMap::getScalarT(const std::string& key) const; - template Real32 ValueMap::getScalarT(const std::string& key) const; - template Real64 ValueMap::getScalarT(const std::string& key) const; - template Handle ValueMap::getScalarT(const std::string& key) const; - template bool ValueMap::getScalarT(const std::string& key) const; -} +namespace nupic { +template Byte ValueMap::getScalarT(const std::string &key, + Byte defaultValue) const; +template UInt16 ValueMap::getScalarT(const std::string &key, + UInt16 defaultValue) const; +template Int16 ValueMap::getScalarT(const std::string &key, + Int16 defaultValue) const; +template UInt32 ValueMap::getScalarT(const std::string &key, + UInt32 defaultValue) const; +template Int32 ValueMap::getScalarT(const std::string &key, + Int32 defaultValue) const; +template UInt64 ValueMap::getScalarT(const std::string &key, + UInt64 defaultValue) const; +template Int64 ValueMap::getScalarT(const std::string &key, + Int64 defaultValue) const; +template Real32 ValueMap::getScalarT(const std::string &key, + Real32 defaultValue) const; +template Real64 ValueMap::getScalarT(const std::string &key, + Real64 defaultValue) const; +template Handle ValueMap::getScalarT(const std::string &key, + Handle defaultValue) const; +template bool ValueMap::getScalarT(const std::string &key, + bool defaultValue) const; + +template Byte ValueMap::getScalarT(const std::string &key) const; +template UInt16 ValueMap::getScalarT(const std::string &key) const; +template Int16 ValueMap::getScalarT(const std::string &key) const; +template UInt32 ValueMap::getScalarT(const std::string &key) const; +template Int32 
ValueMap::getScalarT(const std::string &key) const; +template UInt64 ValueMap::getScalarT(const std::string &key) const; +template Int64 ValueMap::getScalarT(const std::string &key) const; +template Real32 ValueMap::getScalarT(const std::string &key) const; +template Real64 ValueMap::getScalarT(const std::string &key) const; +template Handle ValueMap::getScalarT(const std::string &key) const; +template bool ValueMap::getScalarT(const std::string &key) const; +} // namespace nupic diff --git a/src/nupic/ntypes/Value.hpp b/src/nupic/ntypes/Value.hpp index 6576b536f3..5726bdd79e 100644 --- a/src/nupic/ntypes/Value.hpp +++ b/src/nupic/ntypes/Value.hpp @@ -20,127 +20,120 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definitions for the Value class - * + * * A Value object holds a Scalar or an Array * A ValueMap is essentially a map - * It is used internally in the conversion of YAML strings to C++ objects. - * The API and implementation are geared towards clarify rather than performance, - * since it is expected to be used only during network construction. + * It is used internally in the conversion of YAML strings to C++ objects. + * The API and implementation are geared towards clarify rather than + * performance, since it is expected to be used only during network + * construction. */ #ifndef NTA_VALUE_HPP #define NTA_VALUE_HPP -#include -#include -#include +#include +#include #include +#include +#include +#include #include -#include -#include - -namespace nupic -{ - - /** - * The Value class is used to store construction parameters - * for regions and links. A YAML string specified by the user - * is parsed and converted into a set of Values. - * - * A Value is essentially a union of Scalar/Array/string. - * In turn, a Scalar is a union of NTA_BasicType types, - * and an Array is an array of such types. - * - * A string is similar to an Array of NTA_BasicType_Byte, but - * is handled differently, so it is separated in the API. - * - * The Value API uses boost::shared_ptr instead of directly - * using the underlying objects, to avoid copying, and because - * Array may not be copied. - */ - class Value - { - public: - Value(boost::shared_ptr& s); - Value(boost::shared_ptr& a); - Value(boost::shared_ptr& s); - enum Category {scalarCategory, arrayCategory, stringCategory}; +namespace nupic { - bool isArray() const; - bool isString() const; - bool isScalar() const; - Category getCategory() const; +/** + * The Value class is used to store construction parameters + * for regions and links. A YAML string specified by the user + * is parsed and converted into a set of Values. + * + * A Value is essentially a union of Scalar/Array/string. + * In turn, a Scalar is a union of NTA_BasicType types, + * and an Array is an array of such types. + * + * A string is similar to an Array of NTA_BasicType_Byte, but + * is handled differently, so it is separated in the API. + * + * The Value API uses boost::shared_ptr instead of directly + * using the underlying objects, to avoid copying, and because + * Array may not be copied. 
+ */ +class Value { +public: + Value(boost::shared_ptr &s); + Value(boost::shared_ptr &a); + Value(boost::shared_ptr &s); - NTA_BasicType getType() const; + enum Category { scalarCategory, arrayCategory, stringCategory }; - boost::shared_ptr getScalar() const; + bool isArray() const; + bool isString() const; + bool isScalar() const; + Category getCategory() const; - boost::shared_ptr getArray() const; + NTA_BasicType getType() const; - boost::shared_ptr getString() const; + boost::shared_ptr getScalar() const; - template T getScalarT() const; - - const std::string getDescription() const; + boost::shared_ptr getArray() const; - private: - // Default constructor would not be useful - Value(); - Category category_; - boost::shared_ptr scalar_; - boost::shared_ptr array_; - boost::shared_ptr string_; - }; + boost::shared_ptr getString() const; + template T getScalarT() const; - class ValueMap - { - public: - ValueMap(); - ValueMap(const ValueMap& rhs); - ~ValueMap(); - void add(const std::string& key, const Value& value); + const std::string getDescription() const; - // map.find(key) != map.end() - bool contains(const std::string& key) const; +private: + // Default constructor would not be useful + Value(); + Category category_; + boost::shared_ptr scalar_; + boost::shared_ptr array_; + boost::shared_ptr string_; +}; - // map.find(key) + exception if not found - Value& getValue(const std::string& key) const; +class ValueMap { +public: + ValueMap(); + ValueMap(const ValueMap &rhs); + ~ValueMap(); + void add(const std::string &key, const Value &value); - // Method below are for convenience, bypassing the Value - boost::shared_ptr getArray(const std::string& key) const; - boost::shared_ptr getScalar(const std::string& key) const; - boost::shared_ptr getString(const std::string& key) const; + // map.find(key) != map.end() + bool contains(const std::string &key) const; - // More convenience methods, bypassing the Value and the contained Scalar + // map.find(key) + exception if not found + Value &getValue(const std::string &key) const; - // use default value if not specified in map - template T getScalarT(const std::string& key, T defaultValue) const; + // Method below are for convenience, bypassing the Value + boost::shared_ptr getArray(const std::string &key) const; + boost::shared_ptr getScalar(const std::string &key) const; + boost::shared_ptr getString(const std::string &key) const; - // raise exception if value is not specified in map - template T getScalarT(const std::string& key) const; + // More convenience methods, bypassing the Value and the contained Scalar - void dump() const; + // use default value if not specified in map + template + T getScalarT(const std::string &key, T defaultValue) const; - typedef std::map::const_iterator const_iterator; - const_iterator begin() const; - const_iterator end() const; + // raise exception if value is not specified in map + template T getScalarT(const std::string &key) const; - private: - // must be a Value* since Value doesn't have a default constructor - // We own all the items in the map and must delete them in our destructor - typedef std::map::iterator iterator; - std::map map_; - }; + void dump() const; + typedef std::map::const_iterator const_iterator; + const_iterator begin() const; + const_iterator end() const; -} +private: + // must be a Value* since Value doesn't have a default constructor + // We own all the items in the map and must delete them in our destructor + typedef std::map::iterator iterator; + std::map map_; +}; +} // namespace 
nupic #endif // NTA_VALUE_HPP - - - diff --git a/src/nupic/os/Directory.cpp b/src/nupic/os/Directory.cpp index 3509610bad..65fb5d57a5 100644 --- a/src/nupic/os/Directory.cpp +++ b/src/nupic/os/Directory.cpp @@ -20,311 +20,276 @@ * --------------------------------------------------------------------- */ -/** @file -*/ +/** @file + */ -#include #include +#include +#include #include -#include #include +#include #include -#include -#include +#include #if defined(NTA_OS_WINDOWS) - #include - #include +#include +#include #else - #include - #include +#include +#include #endif -namespace nupic -{ - namespace Directory - { - bool exists(const std::string & path) - { - return Path::exists(path); - } - - std::string getCWD() - { - #if defined(NTA_OS_WINDOWS) - wchar_t wcwd[APR_PATH_MAX]; - DWORD res = ::GetCurrentDirectoryW(APR_PATH_MAX, wcwd); - NTA_CHECK(res > 0) << "Couldn't get current working directory. OS msg: " - << OS::getErrorMessage(); - std::string cwd = Path::unicodeToUtf8(std::wstring(wcwd)); - return cwd; - #else - char cwd[APR_PATH_MAX]; - cwd[0] = '\0'; - char * res = ::getcwd(cwd, APR_PATH_MAX); - NTA_CHECK(res != nullptr) << "Couldn't get current working directory. OS num: " << errno; - return std::string(cwd); - #endif - } +namespace nupic { +namespace Directory { +bool exists(const std::string &path) { return Path::exists(path); } - bool empty(const std::string & path) - { - Entry dummy; - return Iterator(path).next(dummy) == nullptr; - } - - void setCWD(const std::string & path) - { - int res = 0; - #if defined(NTA_OS_WINDOWS) - std::wstring wpath(Path::utf8ToUnicode(path)); - res = ::SetCurrentDirectoryW(wpath.c_str()) ? 0 : -1; - #else - res = ::chdir(path.c_str()); - #endif - - NTA_CHECK(res == 0) << "setCWD: " << OS::getErrorMessage(); - } +std::string getCWD() { +#if defined(NTA_OS_WINDOWS) + wchar_t wcwd[APR_PATH_MAX]; + DWORD res = ::GetCurrentDirectoryW(APR_PATH_MAX, wcwd); + NTA_CHECK(res > 0) << "Couldn't get current working directory. OS msg: " + << OS::getErrorMessage(); + std::string cwd = Path::unicodeToUtf8(std::wstring(wcwd)); + return cwd; +#else + char cwd[APR_PATH_MAX]; + cwd[0] = '\0'; + char *res = ::getcwd(cwd, APR_PATH_MAX); + NTA_CHECK(res != nullptr) + << "Couldn't get current working directory. OS num: " << errno; + return std::string(cwd); +#endif +} - static bool removeEmptyDir(const std::string & path, bool noThrow) - { - int res = 0; - #if defined(NTA_OS_WINDOWS) - std::wstring wpath(Path::utf8ToUnicode(path)); - res = ::RemoveDirectoryW(wpath.c_str()) != FALSE ? 0 : -1; - #else - res = ::rmdir(path.c_str()); - #endif - if(!noThrow) { - NTA_CHECK(res == 0) << "removeEmptyDir: " << OS::getErrorMessage(); - } - return (res == 0); - } +bool empty(const std::string &path) { + Entry dummy; + return Iterator(path).next(dummy) == nullptr; +} - void copyTree(const std::string & source, const std::string & destination) - { - NTA_CHECK(Path::isDirectory(source)); - std::string baseSource(Path::getBasename(source)); - std::string dest(destination); - dest = Path::join(dest, baseSource); - if (!Path::exists(dest)) - Directory::create(dest, false, true); - NTA_CHECK(Path::isDirectory(dest)); - - Iterator i(source); - Entry e; - while (i.next(e)) - { - std::string fullSource(source); - fullSource = Path::join(fullSource, e.path); - Path::copy(fullSource, dest); - } - } +void setCWD(const std::string &path) { + int res = 0; +#if defined(NTA_OS_WINDOWS) + std::wstring wpath(Path::utf8ToUnicode(path)); + res = ::SetCurrentDirectoryW(wpath.c_str()) ? 
0 : -1; +#else + res = ::chdir(path.c_str()); +#endif - - bool removeTree(const std::string & path, bool noThrow) - { - bool success = true; - NTA_CHECK(!path.empty()) << "Can't remove directory with no name"; - { - // The scope is necessary to make sure the destructor - // of the Iterator releases the directory so that - // removeEmptyDir() will succeed. - Iterator i(path); - Entry e; - while (i.next(e)) - { - Path fullPath = Path(path) + Path(e.path); - if (e.type == Entry::DIRECTORY) { - bool subResult = removeTree(std::string(fullPath), noThrow); - success = success && subResult; - } - else - { - apr_status_t st = ::apr_file_remove(fullPath, nullptr); - if(st != APR_SUCCESS) { - if(noThrow) success = false; - else { - NTA_THROW - << "Directory::removeTree() failed. " - << "Unable to remove the file'" << fullPath << "'. " - << "OS msg: " << OS::getErrorMessage(); - } - } + NTA_CHECK(res == 0) << "setCWD: " << OS::getErrorMessage(); +} + +static bool removeEmptyDir(const std::string &path, bool noThrow) { + int res = 0; +#if defined(NTA_OS_WINDOWS) + std::wstring wpath(Path::utf8ToUnicode(path)); + res = ::RemoveDirectoryW(wpath.c_str()) != FALSE ? 0 : -1; +#else + res = ::rmdir(path.c_str()); +#endif + if (!noThrow) { + NTA_CHECK(res == 0) << "removeEmptyDir: " << OS::getErrorMessage(); + } + return (res == 0); +} + +void copyTree(const std::string &source, const std::string &destination) { + NTA_CHECK(Path::isDirectory(source)); + std::string baseSource(Path::getBasename(source)); + std::string dest(destination); + dest = Path::join(dest, baseSource); + if (!Path::exists(dest)) + Directory::create(dest, false, true); + NTA_CHECK(Path::isDirectory(dest)); + + Iterator i(source); + Entry e; + while (i.next(e)) { + std::string fullSource(source); + fullSource = Path::join(fullSource, e.path); + Path::copy(fullSource, dest); + } +} + +bool removeTree(const std::string &path, bool noThrow) { + bool success = true; + NTA_CHECK(!path.empty()) << "Can't remove directory with no name"; + { + // The scope is necessary to make sure the destructor + // of the Iterator releases the directory so that + // removeEmptyDir() will succeed. + Iterator i(path); + Entry e; + while (i.next(e)) { + Path fullPath = Path(path) + Path(e.path); + if (e.type == Entry::DIRECTORY) { + bool subResult = removeTree(std::string(fullPath), noThrow); + success = success && subResult; + } else { + apr_status_t st = ::apr_file_remove(fullPath, nullptr); + if (st != APR_SUCCESS) { + if (noThrow) + success = false; + else { + NTA_THROW << "Directory::removeTree() failed. " + << "Unable to remove the file'" << fullPath << "'. " + << "OS msg: " << OS::getErrorMessage(); } } } - - bool subResult = removeEmptyDir(path, noThrow); - success = success && subResult; - // Check 3 times the directory is really gone - // (needed for unreliable file systems) - for (int i = 0; i < 3; ++i) - { - if (!Directory::exists(path)) - return success; - // sleep for a second - if (i < 2) - ::apr_sleep(1000 * 1000); - } - if(!noThrow) { - NTA_THROW << "Directory::removeTree() failed. " - << "Unable to remove empty dir: " - << "\"" << path << "\""; - } - return false; } - - // Create directory recursively (creates parent if doesn't exist) - // Helper function for create(.., recursive=true) - static std::string createRecursive(const std::string & path, bool otherAccess) - { - /// TODO: When the directory exists, confirm or update its permissions. 
- - NTA_CHECK(!path.empty()) << "Can't create directory with no name"; - std::string p = Path::makeAbsolute(path); - - if (Path::exists(p)) - { - if (! Path::isDirectory(p)) - { - NTA_THROW << "Directory::create -- path " << path << " already exists but is not a directory"; - } - // Empty string return terminates the recursive call because "" has no parent - return ""; - } + } - std::string result(p); - std::string parent = Path::getParent(p); - if (!Directory::exists(parent)) - { - result = createRecursive(parent, otherAccess); - } - - create(p, otherAccess, false); - return result; + bool subResult = removeEmptyDir(path, noThrow); + success = success && subResult; + // Check 3 times the directory is really gone + // (needed for unreliable file systems) + for (int i = 0; i < 3; ++i) { + if (!Directory::exists(path)) + return success; + // sleep for a second + if (i < 2) + ::apr_sleep(1000 * 1000); + } + if (!noThrow) { + NTA_THROW << "Directory::removeTree() failed. " + << "Unable to remove empty dir: " + << "\"" << path << "\""; + } + return false; +} + +// Create directory recursively (creates parent if doesn't exist) +// Helper function for create(.., recursive=true) +static std::string createRecursive(const std::string &path, bool otherAccess) { + /// TODO: When the directory exists, confirm or update its permissions. + + NTA_CHECK(!path.empty()) << "Can't create directory with no name"; + std::string p = Path::makeAbsolute(path); + + if (Path::exists(p)) { + if (!Path::isDirectory(p)) { + NTA_THROW << "Directory::create -- path " << path + << " already exists but is not a directory"; } + // Empty string return terminates the recursive call because "" has no + // parent + return ""; + } - void create(const std::string& path, bool otherAccess, bool recursive) - { - /// TODO: When the directory exists, confirm or update its permissions. + std::string result(p); + std::string parent = Path::getParent(p); + if (!Directory::exists(parent)) { + result = createRecursive(parent, otherAccess); + } - if (recursive) - { - createRecursive(path, otherAccess); - return; - } + create(p, otherAccess, false); + return result; +} - // non-recursive case - bool success = true; - #if defined(NTA_OS_WINDOWS) - std::wstring wPath = Path::utf8ToUnicode(path); - success = ::CreateDirectoryW(wPath.c_str(), NULL) != FALSE; - if (!success) - { - if (GetLastError() == ERROR_ALREADY_EXISTS) { - // Not a hard error, due to potential race conditions. - std::cerr << "Path '" << path << "' exists. " - "Possible race condition." - << std::endl; - success = Path::isDirectory(path); - } - } +void create(const std::string &path, bool otherAccess, bool recursive) { + /// TODO: When the directory exists, confirm or update its permissions. - #else - int permissions = S_IRWXU; - if(otherAccess) { - permissions |= (S_IRWXG | S_IROTH | S_IXOTH); - } - int res = ::mkdir(path.c_str(), permissions); - if(res != 0) { - if(errno == EEXIST) { - // Not a hard error, due to potential race conditions. - std::cerr << "Path '" << path << "' exists. " - "Possible race condition." 
- << std::endl; - success = Path::isDirectory(path); - } - else { - success = false; - } - } - else success = true; - #endif + if (recursive) { + createRecursive(path, otherAccess); + return; + } - if (!success) - { - NTA_THROW << "Directory::create -- failed to create directory \"" << path << "\".\n" - << "OS msg: " << OS::getErrorMessage(); - } + // non-recursive case + bool success = true; +#if defined(NTA_OS_WINDOWS) + std::wstring wPath = Path::utf8ToUnicode(path); + success = ::CreateDirectoryW(wPath.c_str(), NULL) != FALSE; + if (!success) { + if (GetLastError() == ERROR_ALREADY_EXISTS) { + // Not a hard error, due to potential race conditions. + std::cerr << "Path '" << path + << "' exists. " + "Possible race condition." + << std::endl; + success = Path::isDirectory(path); } - + } - Iterator::Iterator(const Path & path) - { - init(std::string(path)); - } - - Iterator::Iterator(const std::string & path) - { - init(path); +#else + int permissions = S_IRWXU; + if (otherAccess) { + permissions |= (S_IRWXG | S_IROTH | S_IXOTH); + } + int res = ::mkdir(path.c_str(), permissions); + if (res != 0) { + if (errno == EEXIST) { + // Not a hard error, due to potential race conditions. + std::cerr << "Path '" << path + << "' exists. " + "Possible race condition." + << std::endl; + success = Path::isDirectory(path); + } else { + success = false; } + } else + success = true; +#endif - void Iterator::init(const std::string & path) - { - apr_status_t res = ::apr_pool_create(&pool_, nullptr); - NTA_CHECK(res == 0) << "Can't create pool"; - std::string absolutePath = Path::makeAbsolute(path); - res = ::apr_dir_open(&handle_, absolutePath.c_str(), pool_); - NTA_CHECK(res == 0) << "Can't open directory " << path - << ". OS num: " << APR_TO_OS_ERROR(res); - } - - Iterator::~Iterator() - { - apr_status_t res = ::apr_dir_close(handle_); - ::apr_pool_destroy(pool_); - NTA_CHECK(res == 0) << "Couldn't close directory." - << " OS num: " << APR_TO_OS_ERROR(res); - } - - void Iterator::reset() - { - apr_status_t res = ::apr_dir_rewind(handle_); - NTA_CHECK(res == 0) - << "Couldn't reset directory iterator." - << " OS num: " << APR_TO_OS_ERROR(res); - } - - Entry * Iterator::next(Entry & e) - { - apr_int32_t wanted = APR_FINFO_LINK | APR_FINFO_NAME | APR_FINFO_TYPE; - apr_status_t res = ::apr_dir_read(&e, wanted, handle_); - - // No more entries - if (APR_STATUS_IS_ENOENT(res)) - return nullptr; - - if (res != 0) - { - NTA_CHECK(res == APR_INCOMPLETE) - << "Couldn't read next dir entry." - << " OS num: " << APR_TO_OS_ERROR(res); - NTA_CHECK(((e.valid & wanted) | APR_FINFO_LINK) == wanted) - << "Couldn't retrieve all fields. Valid mask=" << e.valid; - } - - - e.type = (e.filetype == APR_DIR) ? Directory::Entry::DIRECTORY - : Directory::Entry::FILE; - e.path = e.name; - - // Skip '.' and '..' 
directories - if (e.type == Directory::Entry::DIRECTORY && - (e.name == std::string(".") || e.name == std::string(".."))) - return next(e); - else - return &e; - } + if (!success) { + NTA_THROW << "Directory::create -- failed to create directory \"" << path + << "\".\n" + << "OS msg: " << OS::getErrorMessage(); + } +} + +Iterator::Iterator(const Path &path) { init(std::string(path)); } + +Iterator::Iterator(const std::string &path) { init(path); } + +void Iterator::init(const std::string &path) { + apr_status_t res = ::apr_pool_create(&pool_, nullptr); + NTA_CHECK(res == 0) << "Can't create pool"; + std::string absolutePath = Path::makeAbsolute(path); + res = ::apr_dir_open(&handle_, absolutePath.c_str(), pool_); + NTA_CHECK(res == 0) << "Can't open directory " << path + << ". OS num: " << APR_TO_OS_ERROR(res); +} + +Iterator::~Iterator() { + apr_status_t res = ::apr_dir_close(handle_); + ::apr_pool_destroy(pool_); + NTA_CHECK(res == 0) << "Couldn't close directory." + << " OS num: " << APR_TO_OS_ERROR(res); +} + +void Iterator::reset() { + apr_status_t res = ::apr_dir_rewind(handle_); + NTA_CHECK(res == 0) << "Couldn't reset directory iterator." + << " OS num: " << APR_TO_OS_ERROR(res); +} + +Entry *Iterator::next(Entry &e) { + apr_int32_t wanted = APR_FINFO_LINK | APR_FINFO_NAME | APR_FINFO_TYPE; + apr_status_t res = ::apr_dir_read(&e, wanted, handle_); + + // No more entries + if (APR_STATUS_IS_ENOENT(res)) + return nullptr; + + if (res != 0) { + NTA_CHECK(res == APR_INCOMPLETE) << "Couldn't read next dir entry." + << " OS num: " << APR_TO_OS_ERROR(res); + NTA_CHECK(((e.valid & wanted) | APR_FINFO_LINK) == wanted) + << "Couldn't retrieve all fields. Valid mask=" << e.valid; } + + e.type = (e.filetype == APR_DIR) ? Directory::Entry::DIRECTORY + : Directory::Entry::FILE; + e.path = e.name; + + // Skip '.' and '..' directories + if (e.type == Directory::Entry::DIRECTORY && + (e.name == std::string(".") || e.name == std::string(".."))) + return next(e); + else + return &e; } +} // namespace Directory +} // namespace nupic diff --git a/src/nupic/os/Directory.hpp b/src/nupic/os/Directory.hpp old mode 100755 new mode 100644 index 177329bfa6..58d8709c0c --- a/src/nupic/os/Directory.hpp +++ b/src/nupic/os/Directory.hpp @@ -22,7 +22,6 @@ /** @file */ - #ifndef NTA_DIRECTORY_HPP #define NTA_DIRECTORY_HPP @@ -34,77 +33,72 @@ //---------------------------------------------------------------------- -namespace nupic -{ - class Path; - - namespace Directory - { - // check if a directory exists - bool exists(const std::string & path); - - bool empty(const std::string & path); - - // get current working directory - std::string getCWD(); - - // set current working directories - void setCWD(const std::string & path); - - // Copy directory tree rooted in 'source' to 'destination' - void copyTree(const std::string & source, const std::string & destination); - - // Remove directory tree rooted in 'path' - bool removeTree(const std::string & path, bool noThrow=false); - - // Create directory 'path' including all parent directories if missing - // returns the first directory that was actually created. 
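// Sketch of the iteration idiom used by removeTree() and copyTree() above
// (illustrative caller code, not part of the patch); the directory argument
// is hypothetical.  next() already skips the "." and ".." entries.
#include "nupic/os/Directory.hpp"
#include <iostream>
#include <string>

void listDirectorySketch(const std::string &dir) {
  nupic::Directory::Iterator it(dir);
  nupic::Directory::Entry e;
  while (it.next(e) != nullptr) {
    bool isDir = (e.type == nupic::Directory::Entry::DIRECTORY);
    std::cout << (isDir ? "[dir]  " : "[file] ") << e.path << std::endl;
  }
}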
- // - // For example if path is /A/B/C/D - // if /A/B/C/D exists it returns "" - // if /A/B exists it returns /A/B/C - // if /A doesn't exist it returns /A/B/C/D - // - // Failures will throw an exception - void create(const std::string & path, bool otherAccess=false, bool recursive=false); - - std::string createTemporary(const std::string &templatePath); - - struct Entry : public apr_finfo_t - { - enum Type { FILE, DIRECTORY, LINK }; - - Type type; - std::string path; - }; - - class Iterator - { - public: - - Iterator(const Path & path); - Iterator(const std::string & path); - ~Iterator(); - - // Resets directory to start. Subsequent call to next() - // will retrieve the first entry - void reset(); - // get next directory entry - Entry * next(Entry & e); - - private: - Iterator(); - Iterator(const Iterator &); - - void init(const std::string & path); - private: - std::string path_; - apr_dir_t * handle_; - apr_pool_t * pool_; - }; - } -} +namespace nupic { +class Path; -#endif // NTA_DIRECTORY_HPP +namespace Directory { +// check if a directory exists +bool exists(const std::string &path); + +bool empty(const std::string &path); + +// get current working directory +std::string getCWD(); + +// set current working directories +void setCWD(const std::string &path); + +// Copy directory tree rooted in 'source' to 'destination' +void copyTree(const std::string &source, const std::string &destination); + +// Remove directory tree rooted in 'path' +bool removeTree(const std::string &path, bool noThrow = false); + +// Create directory 'path' including all parent directories if missing +// returns the first directory that was actually created. +// +// For example if path is /A/B/C/D +// if /A/B/C/D exists it returns "" +// if /A/B exists it returns /A/B/C +// if /A doesn't exist it returns /A/B/C/D +// +// Failures will throw an exception +void create(const std::string &path, bool otherAccess = false, + bool recursive = false); +std::string createTemporary(const std::string &templatePath); +struct Entry : public apr_finfo_t { + enum Type { FILE, DIRECTORY, LINK }; + + Type type; + std::string path; +}; + +class Iterator { +public: + Iterator(const Path &path); + Iterator(const std::string &path); + ~Iterator(); + + // Resets directory to start. 
Subsequent call to next() + // will retrieve the first entry + void reset(); + // get next directory entry + Entry *next(Entry &e); + +private: + Iterator(); + Iterator(const Iterator &); + + void init(const std::string &path); + +private: + std::string path_; + apr_dir_t *handle_; + apr_pool_t *pool_; +}; +} // namespace Directory +} // namespace nupic + +#endif // NTA_DIRECTORY_HPP diff --git a/src/nupic/os/DynamicLibrary.cpp b/src/nupic/os/DynamicLibrary.cpp index f5ea40e5ee..bc9bf29ddc 100644 --- a/src/nupic/os/DynamicLibrary.cpp +++ b/src/nupic/os/DynamicLibrary.cpp @@ -24,97 +24,87 @@ //---------------------------------------------------------------------- +#include #include #include #include -#include //---------------------------------------------------------------------- -namespace nupic -{ +namespace nupic { - DynamicLibrary::DynamicLibrary(void * handle) : handle_(handle) - { - } +DynamicLibrary::DynamicLibrary(void *handle) : handle_(handle) {} - DynamicLibrary::~DynamicLibrary() - { - #if defined(NTA_OS_WINDOWS) - ::FreeLibrary((HMODULE)handle_); - #else - ::dlclose(handle_); - #endif - } +DynamicLibrary::~DynamicLibrary() { +#if defined(NTA_OS_WINDOWS) + ::FreeLibrary((HMODULE)handle_); +#else + ::dlclose(handle_); +#endif +} - DynamicLibrary * DynamicLibrary::load(const std::string & name, std::string &errorString) - { - #if defined(NTA_OS_WINDOWS) - return load(name, 0, errorString); - #else - // LOCAL/NOW make more sense. In NuPIC 2 we currently need GLOBAL/LAZY - // See comments in RegionImplFactory.cpp - // return load(name, LOCAL | NOW, errorString); - return load(name, GLOBAL | LAZY, errorString); - #endif - } +DynamicLibrary *DynamicLibrary::load(const std::string &name, + std::string &errorString) { +#if defined(NTA_OS_WINDOWS) + return load(name, 0, errorString); +#else + // LOCAL/NOW make more sense. In NuPIC 2 we currently need GLOBAL/LAZY + // See comments in RegionImplFactory.cpp + // return load(name, LOCAL | NOW, errorString); + return load(name, GLOBAL | LAZY, errorString); +#endif +} - DynamicLibrary * DynamicLibrary::load(const std::string & name, UInt32 mode, std::string & errorString) - { - if (name.empty()) - { - errorString = "Empty path."; - return nullptr; - } +DynamicLibrary *DynamicLibrary::load(const std::string &name, UInt32 mode, + std::string &errorString) { + if (name.empty()) { + errorString = "Empty path."; + return nullptr; + } - //if (!Path::exists(name)) - //{ - // errorString = "Dynamic library doesn't exist."; - // return NULL; - //} - - void * handle = nullptr; - - #if defined(NTA_OS_WINDOWS) - #if !defined(NTA_COMPILER_GNU) - mode; // ignore on Windows - #endif - handle = ::LoadLibraryA(name.c_str()); - if (handle == NULL) - { - DWORD errorCode = ::GetLastError(); - std::stringstream ss; - ss << std::string("LoadLibrary(") << name - << std::string(") Failed. 
errorCode: ") - << errorCode; - errorString = ss.str(); - return NULL; - } - #else - handle = ::dlopen(name.c_str(), mode); - if (!handle) - { - std::string dlErrorString; - const char *zErrorString = ::dlerror(); - if (zErrorString) - dlErrorString = zErrorString; - errorString += "Failed to load \"" + name + '"'; - if(dlErrorString.size()) - errorString += ": " + dlErrorString; - return nullptr; - } + // if (!Path::exists(name)) + //{ + // errorString = "Dynamic library doesn't exist."; + // return NULL; + //} - #endif - return new DynamicLibrary(handle); + void *handle = nullptr; +#if defined(NTA_OS_WINDOWS) +#if !defined(NTA_COMPILER_GNU) + mode; // ignore on Windows +#endif + handle = ::LoadLibraryA(name.c_str()); + if (handle == NULL) { + DWORD errorCode = ::GetLastError(); + std::stringstream ss; + ss << std::string("LoadLibrary(") << name + << std::string(") Failed. errorCode: ") << errorCode; + errorString = ss.str(); + return NULL; } - - void * DynamicLibrary::getSymbol(const std::string & symbol) - { - #if defined(NTA_OS_WINDOWS) - return (void*)::GetProcAddress((HMODULE)handle_, symbol.c_str()); - #else - return ::dlsym(handle_, symbol.c_str()); - #endif +#else + handle = ::dlopen(name.c_str(), mode); + if (!handle) { + std::string dlErrorString; + const char *zErrorString = ::dlerror(); + if (zErrorString) + dlErrorString = zErrorString; + errorString += "Failed to load \"" + name + '"'; + if (dlErrorString.size()) + errorString += ": " + dlErrorString; + return nullptr; } + +#endif + return new DynamicLibrary(handle); +} + +void *DynamicLibrary::getSymbol(const std::string &symbol) { +#if defined(NTA_OS_WINDOWS) + return (void *)::GetProcAddress((HMODULE)handle_, symbol.c_str()); +#else + return ::dlsym(handle_, symbol.c_str()); +#endif } +} // namespace nupic diff --git a/src/nupic/os/DynamicLibrary.hpp b/src/nupic/os/DynamicLibrary.hpp index b7e251edb8..803a3e1369 100644 --- a/src/nupic/os/DynamicLibrary.hpp +++ b/src/nupic/os/DynamicLibrary.hpp @@ -22,120 +22,127 @@ /** @file */ - #ifndef NTA_DYNAMIC_LIBRARY_HPP #define NTA_DYNAMIC_LIBRARY_HPP //---------------------------------------------------------------------- #if defined(NTA_OS_WINDOWS) - #include +#include #else - #include +#include #endif -#include #include +#include //---------------------------------------------------------------------- -namespace nupic -{ - /** - * @b Responsibility: - * 1. Proivde a cross-platform dynamic library load/unload/getSymbol functionality +namespace nupic { +/** + * @b Responsibility: + * 1. Proivde a cross-platform dynamic library load/unload/getSymbol + functionality + * + * @b Rationale: + * Numenta needs to load code dynamically on multiple platforms. It makes + sense to + * encapsulate this core capability in a nice object-oriented C++ class. + + * @b Resource/Ownerships: + * 1. An opaque library handle (released automatically by destructor) + * + * @b Invariants: + * 1. handle_ is never NULL after construction. This invariant is guarantueed + by + * the class design. The handle_ variable is private. The constructor that + * sets it is private. The load() factory method is the only method that + invokes + * this constructor. The user has no chance to mess things up. The + destructor + * cleans up by unloading the library. + * + * @b Notes: + * The load() static factory method is overloaded to provide default loading + * or loading based on an integer flag. The reason I didn't use a default + * argument is that the flag is an implememntation detail. 
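// Sketch of the intended call pattern for DynamicLibrary (illustrative caller
// code, not part of the patch); the library name "libregion.so" and the
// symbol name "initialize" are hypothetical.
#include "nupic/os/DynamicLibrary.hpp"
#include <iostream>
#include <string>

void loadPluginSketch() {
  std::string errorString;
  nupic::DynamicLibrary *lib =
      nupic::DynamicLibrary::load("libregion.so", errorString);
  if (lib == nullptr) {                // failures are reported via NULL + errorString
    std::cerr << "load failed: " << errorString << std::endl;
    return;
  }
  typedef int (*InitFn)();
  InitFn init = reinterpret_cast<InitFn>(lib->getSymbol("initialize"));
  if (init != nullptr)
    init();
  delete lib;                          // the destructor unloads the library
}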
An alternative + * approach is to define an enum with various flags that will be + * platform-independent and will be interpreted in the specific + implementation. + * + * The error handling strategy is to return error NULLs and not to throw + exceptions. + * The reason is that it is a very generic low-level class that should not be + aware + * and depend on the runtime's error handling policy. It may be used in + different + * contexts like tools and utilities that may utilize a different error + handling + * strategy. It is also a common idiom to return NULL from a failed factory + method. + * + */ +class DynamicLibrary { +public: + enum Mode { +#if defined(NTA_OS_WINDOWS) + LAZY, + GLOBAL, + LOCAL, + NOW +#else + LAZY = RTLD_LAZY, + GLOBAL = RTLD_GLOBAL, + LOCAL = RTLD_LOCAL, + NOW = RTLD_NOW +#endif + }; + + /** + * Loads a dynamic library file, stores the handle in a heap-allocated + * DynamicLibrary instance and returns a pointer to it. Returns NULL + * if something goes wrong. * - * @b Rationale: - * Numenta needs to load code dynamically on multiple platforms. It makes sense to - * encapsulate this core capability in a nice object-oriented C++ class. - - * @b Resource/Ownerships: - * 1. An opaque library handle (released automatically by destructor) - * - * @b Invariants: - * 1. handle_ is never NULL after construction. This invariant is guarantueed by - * the class design. The handle_ variable is private. The constructor that - * sets it is private. The load() factory method is the only method that invokes - * this constructor. The user has no chance to mess things up. The destructor - * cleans up by unloading the library. - * - * @b Notes: - * The load() static factory method is overloaded to provide default loading - * or loading based on an integer flag. The reason I didn't use a default - * argument is that the flag is an implememntation detail. An alternative - * approach is to define an enum with various flags that will be - * platform-independent and will be interpreted in the specific implementation. + * @param path [std::string] the absolute path to the dynamic library file + * @param mode [UInt32] a bitmap of loading modes with platform-specific + * meaning + */ + static DynamicLibrary *load(const std::string &path, UInt32 mode, + std::string &errorString); + + /** + * Loads a dynamic library file, stores the handle in a heap-allocated + * DynamicLibrary instance and returns a pointer to it. Returns NULL + * if something goes wrong. * - * The error handling strategy is to return error NULLs and not to throw exceptions. - * The reason is that it is a very generic low-level class that should not be aware - * and depend on the runtime's error handling policy. It may be used in different - * contexts like tools and utilities that may utilize a different error handling - * strategy. It is also a common idiom to return NULL from a failed factory method. - * + * @param path [std::string] the absolute path to the dynamic library file + * @param errorString [std::string] error message if load failed + * @retval the DynamicLibrary pointer on success or NULL on failure */ - class DynamicLibrary - { - public: - enum Mode - { - #if defined(NTA_OS_WINDOWS) - LAZY, - GLOBAL, - LOCAL, - NOW - #else - LAZY = RTLD_LAZY, - GLOBAL = RTLD_GLOBAL, - LOCAL = RTLD_LOCAL, - NOW = RTLD_NOW - #endif - }; - - /** - * Loads a dynamic library file, stores the handle in a heap-allocated - * DynamicLibrary instance and returns a pointer to it. Returns NULL - * if something goes wrong. 
- * - * @param path [std::string] the absolute path to the dynamic library file - * @param mode [UInt32] a bitmap of loading modes with platform-specific meaning - */ - static DynamicLibrary * load(const std::string & path, - UInt32 mode, - std::string &errorString); - - /** - * Loads a dynamic library file, stores the handle in a heap-allocated - * DynamicLibrary instance and returns a pointer to it. Returns NULL - * if something goes wrong. - * - * @param path [std::string] the absolute path to the dynamic library file - * @param errorString [std::string] error message if load failed - * @retval the DynamicLibrary pointer on success or NULL on failure - */ - static DynamicLibrary * load(const std::string & path, - std::string &errorString); - ~DynamicLibrary(); - - /** - * Gets a symbols from a loaded dynamic library. - * Returns the symbol (usually a function pointer) as - * a void *. The caller is responsible for casting to the - * right type. Returns NULL if something goes wrong. - * - * @param name [std::string] the requested symbol name. - */ - void * getSymbol(const std::string & name); - - private: - DynamicLibrary(); - - DynamicLibrary(void * handle); - DynamicLibrary(const DynamicLibrary &); - - private: - void * handle_; - }; + static DynamicLibrary *load(const std::string &path, + std::string &errorString); + ~DynamicLibrary(); + + /** + * Gets a symbols from a loaded dynamic library. + * Returns the symbol (usually a function pointer) as + * a void *. The caller is responsible for casting to the + * right type. Returns NULL if something goes wrong. + * + * @param name [std::string] the requested symbol name. + */ + void *getSymbol(const std::string &name); + +private: + DynamicLibrary(); + + DynamicLibrary(void *handle); + DynamicLibrary(const DynamicLibrary &); + +private: + void *handle_; +}; -} +} // namespace nupic #endif // NTA_DYNAMIC_LIBRARY_HPP diff --git a/src/nupic/os/Env.cpp b/src/nupic/os/Env.cpp index b9003191c1..de51743e1e 100644 --- a/src/nupic/os/Env.cpp +++ b/src/nupic/os/Env.cpp @@ -20,38 +20,39 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file Environment Implementation */ -#include -#include -#include +#include // std::transform #include +#include #include // toupper -#include // std::transform +#include +#include using namespace nupic; -bool Env::get(const std::string& name, std::string& value) -{ +bool Env::get(const std::string &name, std::string &value) { // @todo remove apr initialization when we have global initialization apr_status_t status = apr_initialize(); if (status != APR_SUCCESS) { - NTA_THROW << "Env::get -- Unable to initialize APR" << " name = " << name; + NTA_THROW << "Env::get -- Unable to initialize APR" + << " name = " << name; return false; } - + // This is annoying. apr_env_get doesn't actually use the memory // pool it is given. But we have to set it up because the API - // requires it and might use it in the future. + // requires it and might use it in the future. 
apr_pool_t *poolP; status = apr_pool_create(&poolP, nullptr); if (status != APR_SUCCESS) { - NTA_THROW << "Env::get -- Unable to create a pool" << " name = " << name; + NTA_THROW << "Env::get -- Unable to create a pool" + << " name = " << name; return false; } - + char *cvalue; bool returnvalue = false; status = apr_env_get(&cvalue, name.c_str(), poolP); @@ -63,55 +64,54 @@ bool Env::get(const std::string& name, std::string& value) } apr_pool_destroy(poolP); return returnvalue; - } -void Env::set(const std::string& name, const std::string& value) -{ +void Env::set(const std::string &name, const std::string &value) { // @todo remove apr initialization when we have global initialization apr_status_t status = apr_initialize(); if (status != APR_SUCCESS) { - NTA_THROW << "Env::set -- Unable to initialize APR" << " name = " << name << - " value = " << value; + NTA_THROW << "Env::set -- Unable to initialize APR" + << " name = " << name << " value = " << value; // ok to return. Haven't created a pool yet return; } - + apr_pool_t *poolP; status = apr_pool_create(&poolP, nullptr); if (status != APR_SUCCESS) { - NTA_THROW << "Env::set -- Unable to create a pool." << " name = " << name << - " value = " << value; - // ok to return. Haven't created a pool yet. + NTA_THROW << "Env::set -- Unable to create a pool." + << " name = " << name << " value = " << value; + // ok to return. Haven't created a pool yet. return; } - + status = apr_env_set(name.c_str(), value.c_str(), poolP); if (status != APR_SUCCESS) { - NTA_THROW << "Env::set -- Unable to set variable " << name << " to " << value; - } - + NTA_THROW << "Env::set -- Unable to set variable " << name << " to " + << value; + } + apr_pool_destroy(poolP); return; - } -void Env::unset(const std::string& name) -{ +void Env::unset(const std::string &name) { // @todo remove apr initialization when we have global initialization apr_status_t status = apr_initialize(); if (status != APR_SUCCESS) { - NTA_THROW << "Env::unset -- Unable to initialize APR." << " name = " << name; + NTA_THROW << "Env::unset -- Unable to initialize APR." + << " name = " << name; return; } - + apr_pool_t *poolP; status = apr_pool_create(&poolP, nullptr); if (status != APR_SUCCESS) { - NTA_THROW << "Env::unset -- Unable to create a pool." << " name = " << name; + NTA_THROW << "Env::unset -- Unable to create a pool." 
+ << " name = " << name; return; } - + status = apr_env_delete(name.c_str(), poolP); if (status != APR_SUCCESS) { // not a fatal error because may not exist @@ -119,52 +119,46 @@ void Env::unset(const std::string& name) } apr_pool_destroy(poolP); return; - } -char ** Env::environ_ = nullptr; +char **Env::environ_ = nullptr; #if defined(NTA_OS_DARWIN) - #include +#include #else - extern char **environ; +extern char **environ; #endif - -char **Env::getenv() -{ +char **Env::getenv() { if (environ_ != nullptr) return environ_; #if defined(NTA_OS_DARWIN) environ_ = *_NSGetEnviron(); -#else +#else environ_ = environ; #endif return environ_; } - -static std::string _getOptionEnvironmentVariable(const std::string& optionName) -{ - std::string result="NTA_"; +static std::string +_getOptionEnvironmentVariable(const std::string &optionName) { + std::string result = "NTA_"; result += optionName; std::transform(result.begin(), result.end(), result.begin(), toupper); return result; } - -bool Env::isOptionSet(const std::string& optionName) -{ +bool Env::isOptionSet(const std::string &optionName) { std::string envName = _getOptionEnvironmentVariable(optionName); std::string value; bool found = get(envName, value); return found; } -std::string Env::getOption(const std::string& optionName, std::string defaultValue) -{ +std::string Env::getOption(const std::string &optionName, + std::string defaultValue) { std::string envName = _getOptionEnvironmentVariable(optionName); std::string value; bool found = get(envName, value); @@ -173,4 +167,3 @@ std::string Env::getOption(const std::string& optionName, std::string defaultVal else return value; } - diff --git a/src/nupic/os/Env.hpp b/src/nupic/os/Env.hpp index b888a6aa04..28fffc51f7 100644 --- a/src/nupic/os/Env.hpp +++ b/src/nupic/os/Env.hpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file Environment Interface */ @@ -31,56 +31,54 @@ Environment Interface namespace nupic { - class Env { - public: +class Env { +public: + /** + * get the named environment variable from the environment. + * @param name Name of environment variable + * @param value Value of environment variable. Set only if variable is found + * @retval true if variable was found; false if not found. + * If false, then value parameter is not set + **/ + static bool get(const std::string &name, std::string &value); - /** - * get the named environment variable from the environment. - * @param name Name of environment variable - * @param value Value of environment variable. Set only if variable is found - * @retval true if variable was found; false if not found. - * If false, then value parameter is not set - **/ - static bool get(const std::string& name, std::string& value); - - /** - * Set the named environment variable. - * @param name Name of environment variable - * @param value Value to which environment variable is set - */ - static void set(const std::string& name, const std::string& value); + /** + * Set the named environment variable. + * @param name Name of environment variable + * @param value Value to which environment variable is set + */ + static void set(const std::string &name, const std::string &value); - /** - * Unset the named environment variable - * @param name Name of environment variable to unset - * If variable is not previously set, no error is returned. 
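// Sketch of the Env helpers and the NTA_XXX "option" convention implemented
// above (illustrative caller code, not part of the patch); the option name
// "file_logging" (environment variable NTA_FILE_LOGGING) is only an example.
#include "nupic/os/Env.hpp"
#include <iostream>
#include <string>

void envSketch() {
  nupic::Env::set("NTA_FILE_LOGGING", "1");
  if (nupic::Env::isOptionSet("file_logging"))       // canonicalized to NTA_FILE_LOGGING
    std::cout << "file logging: "
              << nupic::Env::getOption("file_logging", "off") << std::endl;

  std::string home;
  if (nupic::Env::get("HOME", home))                 // plain variables are looked up as-is
    std::cout << "HOME=" << home << std::endl;
  nupic::Env::unset("NTA_FILE_LOGGING");
}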
- */ - static void unset(const std::string& name); + /** + * Unset the named environment variable + * @param name Name of environment variable to unset + * If variable is not previously set, no error is returned. + */ + static void unset(const std::string &name); - /** - * Get the environment as an array of strings - */ - static char** getenv(); + /** + * Get the environment as an array of strings + */ + static char **getenv(); - /** - * An "option" is an environment variable of the form NTA_XXX. - * The canonical form for an option name is all uppercase characters. - * These are convenience routines for using options. They canonicalize - * the name and search the environment. - */ - static bool isOptionSet(const std::string& optionName); - - /** - * Get the value of the NTA_XXX environment variable. - */ - static std::string getOption(const std::string& optionName, std::string defaultValue=""); + /** + * An "option" is an environment variable of the form NTA_XXX. + * The canonical form for an option name is all uppercase characters. + * These are convenience routines for using options. They canonicalize + * the name and search the environment. + */ + static bool isOptionSet(const std::string &optionName); + /** + * Get the value of the NTA_XXX environment variable. + */ + static std::string getOption(const std::string &optionName, + std::string defaultValue = ""); - private: - static char** environ_; - }; - -} +private: + static char **environ_; +}; -#endif // NTA_ENV_HPP +} // namespace nupic +#endif // NTA_ENV_HPP diff --git a/src/nupic/os/FStream.cpp b/src/nupic/os/FStream.cpp index ea520f7376..434c46f88a 100644 --- a/src/nupic/os/FStream.cpp +++ b/src/nupic/os/FStream.cpp @@ -20,76 +20,70 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definitions for the FStream classes - * - * These classes are versions of ifstream and ofstream that accept platform independent - * (i.e. windows or unix) utf-8 path specifiers for their constructor and open() methods. * - * The native ifstream and ofstream classes on unix already accept UTF-8, but on windows, - * we must convert the utf-8 path to unicode and then pass it to the 'w' version of - * ifstream or ofstream + * These classes are versions of ifstream and ofstream that accept platform + * independent (i.e. windows or unix) utf-8 path specifiers for their + * constructor and open() methods. + * + * The native ifstream and ofstream classes on unix already accept UTF-8, but on + * windows, we must convert the utf-8 path to unicode and then pass it to the + * 'w' version of ifstream or ofstream */ - -#include "nupic/utils/Log.hpp" -#include "nupic/os/Path.hpp" #include "nupic/os/FStream.hpp" #include "nupic/os/Directory.hpp" #include "nupic/os/Env.hpp" -#include +#include "nupic/os/Path.hpp" +#include "nupic/utils/Log.hpp" #include +#include #if defined(NTA_OS_WINDOWS) && !defined(NTA_COMPILER_GNU) - #include - #include +#include +#include #else - #include +#include #endif #include - using namespace nupic; ////////////////////////////////////////////////////////////////////////// /// Print out diagnostic information when a file open fails ///////////////////////////////////////////////////////////////////////// -void IFStream::diagnostics(const char* filename) -{ +void IFStream::diagnostics(const char *filename) { bool forceLog = false; - // We occasionally get error 116(ESTALE) "Stale NFS file handle" (TOO-402) when creating - // a file using OFStream::open() on a shared drive on unix systems. 
We found that - // if we perform a directory listing after encountering the error, that a retry - // immediately after is successful. So.... we log this information if we - // get errno==ESTALE OR if NTA_FILE_LOGGING is set. + // We occasionally get error 116(ESTALE) "Stale NFS file handle" (TOO-402) + // when creating + // a file using OFStream::open() on a shared drive on unix systems. We found + // that if we perform a directory listing after encountering the error, that + // a retry immediately after is successful. So.... we log this information if + // we get errno==ESTALE OR if NTA_FILE_LOGGING is set. #ifdef ESTALE if (errno == ESTALE) forceLog = true; #endif - + if (forceLog || ::getenv("NTA_FILE_LOGGING")) { NTA_DEBUG << "FStream::open() failed opening file " << filename - << "; errno = " << errno - << "; errmsg = " << strerror(errno) + << "; errno = " << errno << "; errmsg = " << strerror(errno) << "; cwd = " << Directory::getCWD(); - + Directory::Iterator di(Directory::getCWD()); Directory::Entry e; while (di.next(e)) { NTA_DEBUG << "FStream::open() ls: " << e.path; } } - } - - ////////////////////////////////////////////////////////////////////////// /// open the given file by name ///////////////////////////////////////////////////////////////////////// -void IFStream::open(const char * filename, ios_base::openmode mode) -{ +void IFStream::open(const char *filename, ios_base::openmode mode) { #if defined(NTA_OS_WINDOWS) && !defined(NTA_COMPILER_GNU) std::wstring pathW = Path::utf8ToUnicode(filename); std::ifstream::open(pathW.c_str(), mode); @@ -100,21 +94,19 @@ void IFStream::open(const char * filename, ios_base::openmode mode) // Check for error if (!is_open()) { IFStream::diagnostics(filename); - // On unix, running nfs, we occasionally get errors opening a file on an nfs drive - // and it seems that simply doing a retry makes it successful - #if !defined(NTA_OS_WINDOWS) - std::ifstream::clear(); - std::ifstream::open(filename, mode); - #endif +// On unix, running nfs, we occasionally get errors opening a file on an nfs +// drive and it seems that simply doing a retry makes it successful +#if !defined(NTA_OS_WINDOWS) + std::ifstream::clear(); + std::ifstream::open(filename, mode); +#endif } } - - + ////////////////////////////////////////////////////////////////////////// /// open the given file by name ///////////////////////////////////////////////////////////////////////// -void OFStream::open(const char * filename, ios_base::openmode mode) -{ +void OFStream::open(const char *filename, ios_base::openmode mode) { #if defined(NTA_OS_WINDOWS) && !defined(NTA_COMPILER_GNU) std::wstring pathW = Path::utf8ToUnicode(filename); std::ofstream::open(pathW.c_str(), mode); @@ -125,41 +117,40 @@ void OFStream::open(const char * filename, ios_base::openmode mode) // Check for error if (!is_open()) { IFStream::diagnostics(filename); - // On unix, running nfs, we occasionally get errors opening a file on an nfs drive - // and it seems that simply doing a retry makes it successful - #if !(defined(NTA_OS_WINDOWS) && !defined(NTA_COMPILER_GNU)) - std::ofstream::clear(); - std::ofstream::open(filename, mode); - #endif +// On unix, running nfs, we occasionally get errors opening a file on an nfs +// drive and it seems that simply doing a retry makes it successful +#if !(defined(NTA_OS_WINDOWS) && !defined(NTA_COMPILER_GNU)) + std::ofstream::clear(); + std::ofstream::open(filename, mode); +#endif } } - + void *ZLib::fopen(const std::string &filename, const std::string &mode, - std::string 
*errorMessage) -{ - if(mode.empty()) throw std::invalid_argument("Mode may not be empty."); + std::string *errorMessage) { + if (mode.empty()) + throw std::invalid_argument("Mode may not be empty."); #if defined(NTA_OS_WINDOWS) && !defined(NTA_COMPILER_GNU) std::wstring wfilename(Path::utf8ToUnicode(filename)); int cflags = _O_BINARY; int pflags = 0; - if(mode[0] == 'r') { + if (mode[0] == 'r') { cflags |= _O_RDONLY; - } - else if(mode[0] == 'w') { - cflags |= _O_TRUNC | _O_CREAT | _O_WRONLY; + } else if (mode[0] == 'w') { + cflags |= _O_TRUNC | _O_CREAT | _O_WRONLY; pflags |= _S_IREAD | _S_IWRITE; - } - else if(mode[0] == 'a') { + } else if (mode[0] == 'a') { cflags |= _O_APPEND | _O_CREAT | _O_WRONLY; pflags |= _S_IREAD | _S_IWRITE; + } else { + throw std::invalid_argument("Mode must start with 'r', 'w' or 'a'."); } - else { throw std::invalid_argument("Mode must start with 'r', 'w' or 'a'."); } int fd = _wopen(wfilename.c_str(), cflags, pflags); gzFile fs = gzdopen(fd, mode.c_str()); - if(fs == 0) { + if (fs == 0) { // TODO: Build an error message for Windows. } @@ -169,38 +160,48 @@ void *ZLib::fopen(const std::string &filename, const std::string &mode, int attempts = 0; const int maxAttempts = 1; int lastError = 0; - while(1) { + while (1) { fs = gzopen(filename.c_str(), mode.c_str()); - if(fs) break; + if (fs) + break; int error = errno; - if(error != lastError) { + if (error != lastError) { std::string message("Unknown error."); // lastError = error; - switch(error) { - case Z_STREAM_ERROR: message = "Zlib stream error."; break; - case Z_DATA_ERROR: message = "Zlib data error."; break; - case Z_MEM_ERROR: message = "Zlib memory error."; break; - case Z_BUF_ERROR: message = "Zlib buffer error."; break; - case Z_VERSION_ERROR: message = "Zlib version error."; break; - default: message = ::strerror(error); break; - } - if(errorMessage) { - *errorMessage = message; + switch (error) { + case Z_STREAM_ERROR: + message = "Zlib stream error."; + break; + case Z_DATA_ERROR: + message = "Zlib data error."; + break; + case Z_MEM_ERROR: + message = "Zlib memory error."; + break; + case Z_BUF_ERROR: + message = "Zlib buffer error."; + break; + case Z_VERSION_ERROR: + message = "Zlib version error."; + break; + default: + message = ::strerror(error); + break; } - else if(maxAttempts > 1) { // If we will try again, warn about failure. - std::cerr << "Warning: Failed to open file '" - << filename << "': " << message << std::endl; + if (errorMessage) { + *errorMessage = message; + } else if (maxAttempts > + 1) { // If we will try again, warn about failure. + std::cerr << "Warning: Failed to open file '" << filename + << "': " << message << std::endl; } - } - if((++attempts) >= maxAttempts) break; + } + if ((++attempts) >= maxAttempts) + break; ::usleep(10000); - } + } } #endif return fs; } - - - - diff --git a/src/nupic/os/FStream.hpp b/src/nupic/os/FStream.hpp index e9bcba466c..900b4b579c 100644 --- a/src/nupic/os/FStream.hpp +++ b/src/nupic/os/FStream.hpp @@ -20,18 +20,18 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definitions for the FStream classes - * - * These classes are versions of ifstream and ofstream that accept platform independent - * (i.e. windows or unix) utf-8 path specifiers for their constructor and open() methods. 
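// Sketch of using the ZLib::fopen() wrapper shown above together with the
// plain zlib API (illustrative caller code, not part of the patch); the file
// name is hypothetical.  The wrapper returns a gzFile as void *, so the
// caller casts it back.
#include "nupic/os/FStream.hpp"
#include <zlib.h>
#include <iostream>
#include <string>

void gzipReadSketch() {
  std::string errorMessage;
  gzFile f = (gzFile)nupic::ZLib::fopen("data.gz", "rb", &errorMessage);
  if (f == nullptr) {
    std::cerr << "open failed: " << errorMessage << std::endl;
    return;
  }
  char buffer[256];
  int n = gzread(f, buffer, sizeof(buffer));   // bytes decompressed, -1 on error
  std::cout << "read " << n << " bytes" << std::endl;
  gzclose(f);
}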
* - * The native ifstream and ofstream classes on unix already accept UTF-8, but on windows, - * we must convert the utf-8 path to unicode and then pass it to the 'w' version of - * ifstream or ofstream + * These classes are versions of ifstream and ofstream that accept platform + * independent (i.e. windows or unix) utf-8 path specifiers for their + * constructor and open() methods. + * + * The native ifstream and ofstream classes on unix already accept UTF-8, but on + * windows, we must convert the utf-8 path to unicode and then pass it to the + * 'w' version of ifstream or ofstream */ - #ifndef NTA_F_STREAM_HPP #define NTA_F_STREAM_HPP @@ -43,40 +43,40 @@ namespace nupic { /// IFStream /// /// @b Responsibility -/// +/// /// Open a file for reading -/// +/// /// @b Description /// -/// This class overrides the open() and constructor methods of the standard ifstream to -/// handle utf-8 paths. +/// This class overrides the open() and constructor methods of the standard +/// ifstream to handle utf-8 paths. /// ///////////////////////////////////////////////////////////////////////////////////// -class IFStream : public std::ifstream -{ +class IFStream : public std::ifstream { public: ////////////////////////////////////////////////////////////////////////// /// Construct an OFStream ///////////////////////////////////////////////////////////////////////// - IFStream () : std::ifstream() {} + IFStream() : std::ifstream() {} /////////////////////////////////////////////////////////////////////////////////// - /// WARNING: the std library does not declare a virtual destructor for std::basic_ofstream - /// or std::basic_ifstream, which we sub-class. Therefore, the destructor for this class - /// will NOT be called and therefore it should not allocate any data members that need - /// to be deleted at destruction time. + /// WARNING: the std library does not declare a virtual destructor for + /// std::basic_ofstream + /// or std::basic_ifstream, which we sub-class. Therefore, the destructor for + /// this class will NOT be called and therefore it should not allocate any + /// data members that need to be deleted at destruction time. 
/////////////////////////////////////////////////////////////////////////////////// virtual ~IFStream() {} - + ////////////////////////////////////////////////////////////////////////// /// Construct an IFStream /// /// @param filename the name of the file to open /// @param mode the open mode ///////////////////////////////////////////////////////////////////////// - IFStream (const char * filename, ios_base::openmode mode = ios_base::in ) : std::ifstream() - { + IFStream(const char *filename, ios_base::openmode mode = ios_base::in) + : std::ifstream() { open(filename, mode); } @@ -86,55 +86,53 @@ class IFStream : public std::ifstream /// @param filename the name of the file to open /// @param mode the open mode ///////////////////////////////////////////////////////////////////////// - void open(const char * filename, ios_base::openmode mode = ios_base::in ); - + void open(const char *filename, ios_base::openmode mode = ios_base::in); + ////////////////////////////////////////////////////////////////////////// /// print out diagnostic information on a failed open ///////////////////////////////////////////////////////////////////////// - static void diagnostics(const char* filename); - -}; // end class IFStream + static void diagnostics(const char *filename); +}; // end class IFStream /////////////////////////////////////////////////////////////////////////////////////// /// OFStream /// /// @b Responsibility -/// +/// /// Open a file for writing -/// +/// /// @b Description /// -/// This class overrides the open() and constructor methods of the standard ofstream to -/// handle utf-8 paths. -/// +/// This class overrides the open() and constructor methods of the standard +/// ofstream to handle utf-8 paths. +/// ///////////////////////////////////////////////////////////////////////////////////// -class OFStream : public std::ofstream -{ +class OFStream : public std::ofstream { public: ////////////////////////////////////////////////////////////////////////// /// Construct an OFStream ///////////////////////////////////////////////////////////////////////// - OFStream () : std::ofstream() {} + OFStream() : std::ofstream() {} /////////////////////////////////////////////////////////////////////////////////// - /// WARNING: the std library does not declare a virtual destructor for std::basic_ofstream - /// or std::basic_ifstream, which we sub-class. Therefore, the destructor for this class - /// will NOT be called and therefore it should not allocate any data members that need - /// to be deleted at destruction time. + /// WARNING: the std library does not declare a virtual destructor for + /// std::basic_ofstream + /// or std::basic_ifstream, which we sub-class. Therefore, the destructor for + /// this class will NOT be called and therefore it should not allocate any + /// data members that need to be deleted at destruction time. 
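// Sketch of the intended usage of the utf-8 aware stream classes declared
// above (illustrative caller code, not part of the patch); the file name is
// hypothetical.
#include "nupic/os/FStream.hpp"
#include <string>

void utf8StreamSketch() {
  nupic::OFStream out("résultats.txt");   // utf-8 path works on Windows as well
  out << "hello world\n";
  out.close();

  nupic::IFStream in("résultats.txt");
  std::string word;
  in >> word;                             // reads "hello"
  in.close();
}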
/////////////////////////////////////////////////////////////////////////////////// virtual ~OFStream() {} - - + ////////////////////////////////////////////////////////////////////////// /// Construct an OFStream /// /// @param filename the name of the file to open /// @param mode the open mode ///////////////////////////////////////////////////////////////////////// - OFStream (const char * filename, ios_base::openmode mode = ios_base::out ) : std::ofstream() - { + OFStream(const char *filename, ios_base::openmode mode = ios_base::out) + : std::ofstream() { open(filename, mode); } @@ -144,27 +142,16 @@ class OFStream : public std::ofstream /// @param filename the name of the file to open /// @param mode the open mode ///////////////////////////////////////////////////////////////////////// - void open(const char * filename, ios_base::openmode mode = ios_base::out ); - - + void open(const char *filename, ios_base::openmode mode = ios_base::out); + }; // end class OFStream -class ZLib -{ +class ZLib { public: static void *fopen(const std::string &filename, const std::string &mode, - std::string *errorMessage=nullptr); + std::string *errorMessage = nullptr); }; - - } // end namespace nupic - - - - #endif // NTA_F_STREAM_HPP - - - diff --git a/src/nupic/os/OS.cpp b/src/nupic/os/OS.cpp index 940b3a9988..2e2a3b6977 100644 --- a/src/nupic/os/OS.cpp +++ b/src/nupic/os/OS.cpp @@ -20,49 +20,40 @@ * --------------------------------------------------------------------- */ - -/** @file +/** @file * Generic OS Implementations for the OS class */ -#include -#include +#include +#include +#include #include #include +#include +#include #include -#include -#include -#include - #if defined(NTA_OS_DARWIN) extern "C" { -#include #include +#include } #elif defined(NTA_OS_WINDOWS) -//We only run on XP/2003 and above +// We only run on XP/2003 and above #undef _WIN32_WINNT #define _WIN32_WINNT 0x0501 #include #endif - - using namespace nupic; - - - -void OS::getProcessMemoryUsage(size_t& realMem, size_t& virtualMem) -{ +void OS::getProcessMemoryUsage(size_t &realMem, size_t &virtualMem) { #if defined(NTA_OS_DARWIN) struct task_basic_info t_info; mach_msg_type_number_t t_info_count = TASK_BASIC_INFO_COUNT; - - if (KERN_SUCCESS != task_info(mach_task_self(), - TASK_BASIC_INFO, (task_info_t)&t_info, &t_info_count)) - { + + if (KERN_SUCCESS != task_info(mach_task_self(), TASK_BASIC_INFO, + (task_info_t)&t_info, &t_info_count)) { NTA_THROW << "getProcessMemoryUsage -- unable to get memory usage"; } realMem = t_info.resident_size; @@ -74,26 +65,23 @@ void OS::getProcessMemoryUsage(size_t& realMem, size_t& virtualMem) ::GetSystemInfo(&si); - PSAPI_WORKING_SET_INFORMATION * pWSI = NULL; + PSAPI_WORKING_SET_INFORMATION *pWSI = NULL; unsigned int pageCount = 2500; unsigned int size; - + unsigned int retries; - for(retries = 0; retries < 20; retries++) - { + for (retries = 0; retries < 20; retries++) { size = sizeof(PSAPI_WORKING_SET_INFORMATION) + - pageCount * sizeof(PSAPI_WORKING_SET_BLOCK); + pageCount * sizeof(PSAPI_WORKING_SET_BLOCK); - pWSI = (PSAPI_WORKING_SET_INFORMATION *) realloc((void *) pWSI, size); + pWSI = (PSAPI_WORKING_SET_INFORMATION *)realloc((void *)pWSI, size); - if(::QueryWorkingSet(hProcess, pWSI, size)) - { + if (::QueryWorkingSet(hProcess, pWSI, size)) { break; } - if(::GetLastError()!=ERROR_BAD_LENGTH) - { - free((void *) pWSI); + if (::GetLastError() != ERROR_BAD_LENGTH) { + free((void *)pWSI); ::CloseHandle(hProcess); NTA_THROW << "getProcessMemoryUsage -- unable to get memory usage"; } @@ 
-102,52 +90,47 @@ void OS::getProcessMemoryUsage(size_t& realMem, size_t& virtualMem) pageCount += pageCount >> 2; } - if(retries >= 20) - { - free((void *) pWSI); + if (retries >= 20) { + free((void *)pWSI); ::CloseHandle(hProcess); NTA_THROW << "getProcessMemoryUsage -- unable to get memory usage"; } unsigned int actualPages; - pWSI->NumberOfEntries > pageCount ? (actualPages = pageCount) : - (actualPages = pWSI->NumberOfEntries); + pWSI->NumberOfEntries > pageCount ? (actualPages = pageCount) + : (actualPages = pWSI->NumberOfEntries); unsigned int privateWorkingSet = 0; - for(unsigned int i = 0; i < actualPages; i++) - { - if(!pWSI->WorkingSetInfo[i].Shared) - { + for (unsigned int i = 0; i < actualPages; i++) { + if (!pWSI->WorkingSetInfo[i].Shared) { privateWorkingSet += si.dwPageSize; } } - //subtract off memory allocated for our pWSI + // subtract off memory allocated for our pWSI privateWorkingSet -= ((size / si.dwPageSize) + 1) * si.dwPageSize; - free((void *) pWSI); + free((void *)pWSI); PROCESS_MEMORY_COUNTERS_EX pmcEx; pmcEx.cb = sizeof(PROCESS_MEMORY_COUNTERS_EX); rc = ::GetProcessMemoryInfo( - hProcess, - reinterpret_cast(&pmcEx), - sizeof(PROCESS_MEMORY_COUNTERS_EX)); + hProcess, reinterpret_cast(&pmcEx), + sizeof(PROCESS_MEMORY_COUNTERS_EX)); - if (!rc) - { + if (!rc) { NTA_THROW << "getProcessMemoryUsage -- unable to get memory usage"; } - //Private usage corresponds to the total amount of private virtual memory + // Private usage corresponds to the total amount of private virtual memory virtualMem = pmcEx.PrivateUsage; - //The private working set corresponds to the unshared virtual memory in - //the processes' working set + // The private working set corresponds to the unshared virtual memory in + // the processes' working set realMem = privateWorkingSet; ::CloseHandle(hProcess); @@ -158,23 +141,19 @@ void OS::getProcessMemoryUsage(size_t& realMem, size_t& virtualMem) #endif } -std::string OS::executeCommand(std::string command) -{ +std::string OS::executeCommand(std::string command) { #if defined(NTA_OS_WINDOWS) && defined(NTA_COMPILER_MSVC) - FILE* pipe = _popen(&command[0], "r"); + FILE *pipe = _popen(&command[0], "r"); #else - FILE* pipe = popen(&command[0], "r"); + FILE *pipe = popen(&command[0], "r"); #endif - if (!pipe) - { + if (!pipe) { return "ERROR"; } char buffer[128]; std::string result = ""; - while(!feof(pipe)) - { - if(fgets(buffer, 128, pipe) != nullptr) - { + while (!feof(pipe)) { + if (fgets(buffer, 128, pipe) != nullptr) { result += buffer; } } diff --git a/src/nupic/os/OS.hpp b/src/nupic/os/OS.hpp index 64334926be..81b8987b4f 100644 --- a/src/nupic/os/OS.hpp +++ b/src/nupic/os/OS.hpp @@ -20,122 +20,117 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Interface for the OS class */ #ifndef NTA_OS_HPP #define NTA_OS_HPP +#include #include #include -#include #ifdef _MSC_VER - #pragma warning (disable: 4996) - // The POSIX name for this item is deprecated. Instead, use the ISO C++ - // conformant name: _getpid. +#pragma warning(disable : 4996) +// The POSIX name for this item is deprecated. Instead, use the ISO C++ +// conformant name: _getpid. 
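// Sketch of typical calls into the OS helpers implemented above (illustrative
// caller code, not part of the patch); the shell command is hypothetical.
#include "nupic/os/OS.hpp"
#include <cstddef>
#include <iostream>
#include <string>

void osUsageSketch() {
  size_t realMem = 0, virtualMem = 0;
  nupic::OS::getProcessMemoryUsage(realMem, virtualMem);
  std::cout << "resident bytes: " << realMem
            << ", private virtual bytes: " << virtualMem << std::endl;

  std::string listing = nupic::OS::executeCommand("ls");   // returns "ERROR" if the pipe fails
  std::cout << listing;
}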
#endif -namespace nupic -{ - /* - * removed for NuPIC 2: - * getHostname - * getUserNTADir - * setUserNTADir - * getProcessID - * getTempDir - * makeTempFilename - * sleep - * executeCommand - * genCryptoString - * verifyHostname - * isProcessAliveWin32 - * killWin32 - * getStackTrace - */ +namespace nupic { +/* + * removed for NuPIC 2: + * getHostname + * getUserNTADir + * setUserNTADir + * getProcessID + * getTempDir + * makeTempFilename + * sleep + * executeCommand + * genCryptoString + * verifyHostname + * isProcessAliveWin32 + * killWin32 + * getStackTrace + */ +/** + * @b Responsibility + * Operating system functionality. + * + * @b Description + * OS is a set of static methods that provide access to operating system + * functionality for Numenta apps. + */ +class OS { +public: /** - * @b Responsibility - * Operating system functionality. - * - * @b Description - * OS is a set of static methods that provide access to operating system functionality - * for Numenta apps. + * Get the last error string + * + * @retval Returns character string containing the last error message. */ + static std::string getErrorMessage(); - class OS - { - public: - /** - * Get the last error string - * - * @retval Returns character string containing the last error message. - */ - static std::string getErrorMessage(); - - /** - * - * - * @return An OS/system library error code. - */ - static int getLastErrorCode(); - - /** - * Get an OS-level error message associated with an error code. - * - * If no error code is specified, gets the error message associated - * with the last error code. - * - * @param An error code, usually reported by getLastErrorCode(). - * - * @return An error message string. - */ - static std::string getErrorMessageFromErrorCode( - int errorCode=getLastErrorCode()); + /** + * + * + * @return An OS/system library error code. + */ + static int getLastErrorCode(); - /** - * Get the user's home directory - * - * The home directory is determined by common environment variables - * on different platforms. - * - * @retval Returns character string containing the user's home directory. - */ - static std::string getHomeDir(); + /** + * Get an OS-level error message associated with an error code. + * + * If no error code is specified, gets the error message associated + * with the last error code. + * + * @param An error code, usually reported by getLastErrorCode(). + * + * @return An error message string. + */ + static std::string + getErrorMessageFromErrorCode(int errorCode = getLastErrorCode()); + /** + * Get the user's home directory + * + * The home directory is determined by common environment variables + * on different platforms. + * + * @retval Returns character string containing the user's home directory. + */ + static std::string getHomeDir(); - /** - * Get the user name - * - * A user name is disovered on unix by checking a few environment variables - * (USER, LOGNAME) and if not found defaulting to the user id. On Windows the - * USERNAME environment variable is set by the OS. - * - * @retval Returns character string containing the user name. - */ - static std::string getUserName(); + /** + * Get the user name + * + * A user name is disovered on unix by checking a few environment variables + * (USER, LOGNAME) and if not found defaulting to the user id. On Windows the + * USERNAME environment variable is set by the OS. + * + * @retval Returns character string containing the user name. 
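// Sketch of the error-reporting and user-info helpers documented above
// (illustrative caller code, not part of the patch).
#include "nupic/os/OS.hpp"
#include <iostream>

void osErrorSketch() {
  std::cout << "home: " << nupic::OS::getHomeDir()
            << ", user: " << nupic::OS::getUserName() << std::endl;

  int code = nupic::OS::getLastErrorCode();   // errno on unix, GetLastError() on Windows
  std::cout << "last OS error " << code << ": "
            << nupic::OS::getErrorMessageFromErrorCode(code) << std::endl;
}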
+ */ + static std::string getUserName(); - /** - * Get process memory usage - * - * Real and Virtual memory usage are returned in bytes - */ - static void getProcessMemoryUsage(size_t& realMem, size_t& virtualMem); + /** + * Get process memory usage + * + * Real and Virtual memory usage are returned in bytes + */ + static void getProcessMemoryUsage(size_t &realMem, size_t &virtualMem); - /** - * Execute a command and and return its output. - * - * @param command - * The command to execute - * @returns - * The output of the command. - */ - static std::string executeCommand(std::string command); - }; -} + /** + * Execute a command and and return its output. + * + * @param command + * The command to execute + * @returns + * The output of the command. + */ + static std::string executeCommand(std::string command); +}; +} // namespace nupic #endif // NTA_OS_HPP - diff --git a/src/nupic/os/OSUnix.cpp b/src/nupic/os/OSUnix.cpp old mode 100755 new mode 100644 index 1bd07a8986..88c995a9aa --- a/src/nupic/os/OSUnix.cpp +++ b/src/nupic/os/OSUnix.cpp @@ -20,41 +20,35 @@ * --------------------------------------------------------------------- */ - -/** @file +/** @file * Unix Implementations for the OS class */ #if !defined(NTA_OS_WINDOWS) -#include -#include +#include +#include +#include +#include +#include #include #include +#include +#include #include -#include -#include -#include // getuid() #include -#include -#include -#include - +#include // getuid() using namespace nupic; -std::string OS::getErrorMessage() -{ +std::string OS::getErrorMessage() { char buff[1024]; apr_status_t st = apr_get_os_error(); - ::apr_strerror(st , buff, 1024); + ::apr_strerror(st, buff, 1024); return std::string(buff); } - - -std::string OS::getHomeDir() -{ +std::string OS::getHomeDir() { std::string home; bool found = Env::get("HOME", home); if (!found) @@ -62,8 +56,7 @@ std::string OS::getHomeDir() return home; } -std::string OS::getUserName() -{ +std::string OS::getUserName() { std::string username; bool found = Env::get("USER", username); @@ -71,40 +64,37 @@ std::string OS::getUserName() if (!found) found = Env::get("LOGNAME", username); - if (!found) - { - NTA_WARN << "OS::getUserName -- USER and LOGNAME environment variables are not set. Using userid = " << getuid(); + if (!found) { + NTA_WARN << "OS::getUserName -- USER and LOGNAME environment variables are " + "not set. 
Using userid = " + << getuid(); std::stringstream ss(""); ss << getuid(); - username = ss.str(); - } + username = ss.str(); + } return username; } - - - - int OS::getLastErrorCode() { return errno; } -std::string OS::getErrorMessageFromErrorCode(int errorCode) -{ +std::string OS::getErrorMessageFromErrorCode(int errorCode) { std::stringstream errorMessage; char errorBuffer[1024]; errorBuffer[0] = '\0'; - + #if defined(__APPLE__) || (defined(NTA_ARCH_64) && defined(NTA_OS_SPARC)) int result = ::strerror_r(errorCode, errorBuffer, 1024); - if(result == 0) errorMessage << errorBuffer; + if (result == 0) + errorMessage << errorBuffer; #else char *result = ::strerror_r(errorCode, errorBuffer, 1024); - if(result != nullptr) errorMessage << errorBuffer; -#endif - else errorMessage << "Error code " << errorCode; + if (result != nullptr) + errorMessage << errorBuffer; +#endif + else + errorMessage << "Error code " << errorCode; return errorMessage.str(); } #endif - - diff --git a/src/nupic/os/OSWin.cpp b/src/nupic/os/OSWin.cpp old mode 100755 new mode 100644 index a1e7140d3e..6c8ed2ca08 --- a/src/nupic/os/OSWin.cpp +++ b/src/nupic/os/OSWin.cpp @@ -20,29 +20,25 @@ * --------------------------------------------------------------------- */ - -/** @file +/** @file * Win32 Implementations for the OS class */ #if defined(NTA_OS_WINDOWS) -#include #include +#include -#include -#include -#include +#include #include +#include #include +#include +#include #include -#include -#include - using namespace nupic; -std::string OS::getHomeDir() -{ +std::string OS::getHomeDir() { std::string homeDrive; std::string homePath; bool found = Env::get("HOMEDRIVE", homeDrive); @@ -52,8 +48,7 @@ std::string OS::getHomeDir() return homeDrive + homePath; } -std::string OS::getUserName() -{ +std::string OS::getUserName() { std::string username; bool found = Env::get("USERNAME", username); NTA_CHECK(found) << "Environment variable USERNAME is not defined"; @@ -61,32 +56,22 @@ std::string OS::getUserName() return username; } -int OS::getLastErrorCode() -{ - return ::GetLastError(); -} +int OS::getLastErrorCode() { return ::GetLastError(); } -std::string OS::getErrorMessageFromErrorCode(int errorCode) -{ +std::string OS::getErrorMessageFromErrorCode(int errorCode) { // Retrieve the system error message for the last-error code LPVOID lpMsgBuf; DWORD msgLen = ::FormatMessageA( - FORMAT_MESSAGE_ALLOCATE_BUFFER | - FORMAT_MESSAGE_FROM_SYSTEM | - FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, - errorCode, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPSTR) &lpMsgBuf, - 0, NULL - ); + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, errorCode, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPSTR)&lpMsgBuf, 0, NULL); std::ostringstream errMessage; - if(msgLen > 0) { - errMessage.write((LPSTR) lpMsgBuf, msgLen); - } - else { + if (msgLen > 0) { + errMessage.write((LPSTR)lpMsgBuf, msgLen); + } else { errMessage << "code: " << errorCode; } @@ -95,11 +80,8 @@ std::string OS::getErrorMessageFromErrorCode(int errorCode) return errMessage.str(); } -std::string OS::getErrorMessage() -{ - return getErrorMessageFromErrorCode (getLastErrorCode()); +std::string OS::getErrorMessage() { + return getErrorMessageFromErrorCode(getLastErrorCode()); } - #endif //#if defined(NTA_OS_WINDOWS) - diff --git a/src/nupic/os/Path.cpp b/src/nupic/os/Path.cpp index c701f0cef8..ab45cf8ec5 100644 --- a/src/nupic/os/Path.cpp +++ b/src/nupic/os/Path.cpp @@ -20,847 +20,726 @@ * 
--------------------------------------------------------------------- */ -/** @file -*/ +/** @file + */ -#include +#include +#include #include -#include #include +#include +#include #include -#include -#include -#include -#include #include - +#include +#include #if defined(NTA_OS_WINDOWS) - extern "C" { - #include - } - #include +extern "C" { +#include +} +#include #else - #include - #include - #include +#include +#include +#include #if defined(NTA_OS_DARWIN) - #include // _NSGetExecutablePath - #else - // linux - #include // readlink - #endif +#include // _NSGetExecutablePath +#else +// linux +#include // readlink +#endif #endif -namespace nupic -{ +namespace nupic { #if defined(NTA_OS_WINDOWS) - const char * Path::sep = "\\"; - const char * Path::pathSep = ";"; +const char *Path::sep = "\\"; +const char *Path::pathSep = ";"; #else - const char * Path::sep = "/"; - const char * Path::pathSep = ":"; +const char *Path::sep = "/"; +const char *Path::pathSep = ":"; #endif - const char * Path::parDir = ".."; +const char *Path::parDir = ".."; - Path::Path(std::string path) : path_(std::move(path)) - { - } +Path::Path(std::string path) : path_(std::move(path)) {} - static apr_status_t getInfo(const std::string & path, apr_int32_t wanted, apr_finfo_t & info) - { - NTA_CHECK(!path.empty()) << "Can't get the info of an empty path"; - - apr_status_t res; - apr_pool_t * pool = nullptr; - - #if defined(NTA_OS_WINDOWS) - res = ::apr_pool_create(&pool, NULL); - if (res != APR_SUCCESS) - { - NTA_WARN << "Internal error: unable to create APR pool when getting info on path '" << path << "'"; - } - #endif - - res = ::apr_stat(&info, path.c_str(), wanted, pool); - - #if defined(NTA_OS_WINDOWS) - ::apr_pool_destroy(pool); - #endif - - return res; - } - - bool Path::exists(const std::string & path) - { - if (path.empty()) - return false; - - apr_finfo_t st; - apr_status_t res = getInfo(path, APR_FINFO_TYPE, st); - return res == APR_SUCCESS; - } +static apr_status_t getInfo(const std::string &path, apr_int32_t wanted, + apr_finfo_t &info) { + NTA_CHECK(!path.empty()) << "Can't get the info of an empty path"; - static apr_filetype_e getType(const std::string & path, bool check = true) - { - apr_finfo_t st; - apr_status_t res = getInfo(path, APR_FINFO_TYPE, st); - if (check) - { - NTA_CHECK(res == APR_SUCCESS) - << "Can't get info for '" << path << "', " << OS::getErrorMessage(); - } - - return st.filetype; - } + apr_status_t res; + apr_pool_t *pool = nullptr; - bool Path::isFile(const std::string & path) - { - return getType(path, false) == APR_REG; +#if defined(NTA_OS_WINDOWS) + res = ::apr_pool_create(&pool, NULL); + if (res != APR_SUCCESS) { + NTA_WARN << "Internal error: unable to create APR pool when getting info " + "on path '" + << path << "'"; } +#endif - bool Path::isDirectory(const std::string & path) - { - return getType(path) == APR_DIR; - } + res = ::apr_stat(&info, path.c_str(), wanted, pool); - bool Path::isSymbolicLink(const std::string & path) - { - return getType(path) == APR_LNK; - } - - bool Path::isAbsolute(const std::string & path) - { - NTA_CHECK(!path.empty()) << "Empty path is invalid"; - #if defined(NTA_OS_WINDOWS) - if (path.size() < 2) - return false; - else - { - bool local = ::isalpha(path[0]) && path[1] == ':'; - bool unc = path.size() > 2 && path[0] == '\\' && path[1] == '\\'; - return local || unc; - } +#if defined(NTA_OS_WINDOWS) + ::apr_pool_destroy(pool); +#endif - #else - return path[0] == '/'; - #endif - } - - bool Path::areEquivalent(const std::string & path1, const 
std::string & path2) - { - apr_finfo_t st1; - apr_finfo_t st2; - apr_int32_t wanted = APR_FINFO_IDENT; - - - apr_status_t s; - s = getInfo(path1.c_str(), wanted, st1); - // If either of the paths does not exist, then we say they are not equivalent - if (s != APR_SUCCESS) - return false; - - s = getInfo(path2.c_str(), wanted, st2); - if (s != APR_SUCCESS) - return false; - - bool res = true; - res &= st1.device == st2.device; - res &= st1.inode == st2.inode; - // We do not require the names to match. Could be a hard link. - // res &= std::string(st1.fname) == std::string(st2.fname); - - return res; - } - - std::string Path::getParent(const std::string & path) - { - if (path == "") - return ""; + return res; +} - std::string np = Path::normalize(path); - Path::StringVec sv = Path::split(np); - sv.push_back(".."); +bool Path::exists(const std::string &path) { + if (path.empty()) + return false; - return Path::normalize(Path::join(sv.begin(), sv.end())); - } - - std::string Path::getBasename(const std::string & path) - { - std::string::size_type index = path.find_last_of(Path::sep); - - if (index == std::string::npos) - return path; - - return path.substr(index+1); - } - - std::string Path::getExtension(const std::string & path) - { - std::string basename = Path::getBasename(path); - std::string::size_type index = basename.find_last_of('.'); - - // If its a regular or hidden filenames with no extension - // return an empty string - if (index == std::string::npos || // regular filename with no ext - index == 0 || // hidden file (starts with a '.') - index == basename.length() -1) // filename ends with a dot - return ""; - - // Don't include the dot, just the extension itself (unlike Python) - return std::string(basename.c_str() + index + 1, basename.length() - index - 1); + apr_finfo_t st; + apr_status_t res = getInfo(path, APR_FINFO_TYPE, st); + return res == APR_SUCCESS; +} + +static apr_filetype_e getType(const std::string &path, bool check = true) { + apr_finfo_t st; + apr_status_t res = getInfo(path, APR_FINFO_TYPE, st); + if (check) { + NTA_CHECK(res == APR_SUCCESS) + << "Can't get info for '" << path << "', " << OS::getErrorMessage(); } - - - Size Path::getFileSize(const std::string & path) - { - apr_finfo_t st; - apr_int32_t wanted = APR_FINFO_TYPE | APR_FINFO_SIZE; - apr_status_t res = getInfo(path.c_str(), wanted, st); - NTA_CHECK(res == APR_SUCCESS); - NTA_CHECK(st.filetype == APR_REG) << "Can't get the size of a non-file object"; - - return (Size)st.size; + + return st.filetype; +} + +bool Path::isFile(const std::string &path) { + return getType(path, false) == APR_REG; +} + +bool Path::isDirectory(const std::string &path) { + return getType(path) == APR_DIR; +} + +bool Path::isSymbolicLink(const std::string &path) { + return getType(path) == APR_LNK; +} + +bool Path::isAbsolute(const std::string &path) { + NTA_CHECK(!path.empty()) << "Empty path is invalid"; +#if defined(NTA_OS_WINDOWS) + if (path.size() < 2) + return false; + else { + bool local = ::isalpha(path[0]) && path[1] == ':'; + bool unc = path.size() > 2 && path[0] == '\\' && path[1] == '\\'; + return local || unc; } - std::string Path::normalize(const std::string & path) - { - // Easiest way is: split, then remove "." and remove a/.. (but not ../.. !) - // This does not get a/b/../.. so if we remove a/.., go through the string again - // Also need to treat rootdir/.. specially - // Also normalize(foo/..) -> "." but normalize(foo/bar/..) 
-> "foo" - StringVec v = Path::split(path); - if (v.size() == 0) - return ""; - - StringVec outv; - bool doAgain = true; - while (doAgain) - { - doAgain = false; - for (unsigned int i = 0; i < v.size(); i++) - { - if (v[i] == "") continue; // remove empty fields - if (v[i] == "." && v.size() > 1) continue; // skip "." unless it is by itself - if (i == 0 && isRootdir(v[i]) && i+1 < v.size() && v[i+1] == "..") - { - // /.. -> - outv.push_back(v[i]); - i++; // skipped following ".." - doAgain = true; - continue; - } - // remove "foo/.." - if (i+1 < v.size() && v[i] != ".." && v[i+1] == "..") - { - // but as a special case, if the full path is "foo/.." return "." - if (v.size() == 2) return "."; - i++; - doAgain = true; - continue; - } +#else + return path[0] == '/'; +#endif +} + +bool Path::areEquivalent(const std::string &path1, const std::string &path2) { + apr_finfo_t st1; + apr_finfo_t st2; + apr_int32_t wanted = APR_FINFO_IDENT; + + apr_status_t s; + s = getInfo(path1.c_str(), wanted, st1); + // If either of the paths does not exist, then we say they are not equivalent + if (s != APR_SUCCESS) + return false; + + s = getInfo(path2.c_str(), wanted, st2); + if (s != APR_SUCCESS) + return false; + + bool res = true; + res &= st1.device == st2.device; + res &= st1.inode == st2.inode; + // We do not require the names to match. Could be a hard link. + // res &= std::string(st1.fname) == std::string(st2.fname); + + return res; +} + +std::string Path::getParent(const std::string &path) { + if (path == "") + return ""; + + std::string np = Path::normalize(path); + Path::StringVec sv = Path::split(np); + sv.push_back(".."); + + return Path::normalize(Path::join(sv.begin(), sv.end())); +} + +std::string Path::getBasename(const std::string &path) { + std::string::size_type index = path.find_last_of(Path::sep); + + if (index == std::string::npos) + return path; + + return path.substr(index + 1); +} + +std::string Path::getExtension(const std::string &path) { + std::string basename = Path::getBasename(path); + std::string::size_type index = basename.find_last_of('.'); + + // If its a regular or hidden filenames with no extension + // return an empty string + if (index == std::string::npos || // regular filename with no ext + index == 0 || // hidden file (starts with a '.') + index == basename.length() - 1) // filename ends with a dot + return ""; + + // Don't include the dot, just the extension itself (unlike Python) + return std::string(basename.c_str() + index + 1, + basename.length() - index - 1); +} + +Size Path::getFileSize(const std::string &path) { + apr_finfo_t st; + apr_int32_t wanted = APR_FINFO_TYPE | APR_FINFO_SIZE; + apr_status_t res = getInfo(path.c_str(), wanted, st); + NTA_CHECK(res == APR_SUCCESS); + NTA_CHECK(st.filetype == APR_REG) + << "Can't get the size of a non-file object"; + + return (Size)st.size; +} + +std::string Path::normalize(const std::string &path) { + // Easiest way is: split, then remove "." and remove a/.. (but not ../.. !) + // This does not get a/b/../.. so if we remove a/.., go through the string + // again Also need to treat rootdir/.. specially Also normalize(foo/..) -> "." + // but normalize(foo/bar/..) -> "foo" + StringVec v = Path::split(path); + if (v.size() == 0) + return ""; + + StringVec outv; + bool doAgain = true; + while (doAgain) { + doAgain = false; + for (unsigned int i = 0; i < v.size(); i++) { + if (v[i] == "") + continue; // remove empty fields + if (v[i] == "." && v.size() > 1) + continue; // skip "." 
unless it is by itself + if (i == 0 && isRootdir(v[i]) && i + 1 < v.size() && v[i + 1] == "..") { + // /.. -> outv.push_back(v[i]); + i++; // skipped following ".." + doAgain = true; + continue; } - if (doAgain) - { - v = outv; - outv.clear(); + // remove "foo/.." + if (i + 1 < v.size() && v[i] != ".." && v[i + 1] == "..") { + // but as a special case, if the full path is "foo/.." return "." + if (v.size() == 2) + return "."; + i++; + doAgain = true; + continue; } + outv.push_back(v[i]); + } + if (doAgain) { + v = outv; + outv.clear(); } - return Path::join(outv.begin(), outv.end()); + } + return Path::join(outv.begin(), outv.end()); +} + +std::string Path::makeAbsolute(const std::string &path) { + if (Path::isAbsolute(path)) + return path; + std::string cwd = Directory::getCWD(); + // If its already absolute just return the original path + if (::strncmp(cwd.c_str(), path.c_str(), cwd.length()) == 0) + return path; + + // Get rid of trailing separators if any + if (path.find_last_of(Path::sep) == path.length() - 1) { + cwd = std::string(cwd.c_str(), cwd.length() - 1); } - - std::string Path::makeAbsolute(const std::string & path) - { - if (Path::isAbsolute(path)) - return path; - - std::string cwd = Directory::getCWD(); - // If its already absolute just return the original path - if (::strncmp(cwd.c_str(), path.c_str(), cwd.length()) == 0) - return path; - - // Get rid of trailing separators if any - if (path.find_last_of(Path::sep) == path.length() - 1) - { - cwd = std::string(cwd.c_str(), cwd.length()-1); - } - // join the cwd to the path and return it (handle duplicate separators) - std::string result = cwd; - if (path.find_first_of(Path::sep) == 0) - { - return cwd + path; - } - else - { - return cwd + Path::sep + path; - } - - return ""; + // join the cwd to the path and return it (handle duplicate separators) + std::string result = cwd; + if (path.find_first_of(Path::sep) == 0) { + return cwd + path; + } else { + return cwd + Path::sep + path; } - - #if defined(NTA_OS_WINDOWS) - std::string Path::unicodeToUtf8(const std::wstring& path) - { - // Assume the worst we can do is have 6 UTF-8 bytes per unicode - // character. - apr_size_t tmpNameSize = path.size() * 6 + 1; - - // Store buffer in a boost::scoped_array so it gets cleaned up for us. - boost::scoped_array tmpNameBuf; - char* tmpNameP = new char[tmpNameSize]; - tmpNameBuf.reset(tmpNameP); - - apr_size_t inWords = path.size()+1; - apr_size_t outChars = tmpNameSize; - - apr_status_t result = ::apr_conv_ucs2_to_utf8((apr_wchar_t *)path.c_str(), - &inWords, tmpNameP, &outChars); - if (result != 0 || inWords != 0) - { - std::stringstream ss; - ss << "Path::unicodeToUtf8() - error converting path to UTF-8:" - << std::endl << "error code: " << result; - NTA_THROW << ss.str(); - } - - return std::string(tmpNameP); + + return ""; +} + +#if defined(NTA_OS_WINDOWS) +std::string Path::unicodeToUtf8(const std::wstring &path) { + // Assume the worst we can do is have 6 UTF-8 bytes per unicode + // character. + apr_size_t tmpNameSize = path.size() * 6 + 1; + + // Store buffer in a boost::scoped_array so it gets cleaned up for us. 
+ boost::scoped_array tmpNameBuf; + char *tmpNameP = new char[tmpNameSize]; + tmpNameBuf.reset(tmpNameP); + + apr_size_t inWords = path.size() + 1; + apr_size_t outChars = tmpNameSize; + + apr_status_t result = ::apr_conv_ucs2_to_utf8((apr_wchar_t *)path.c_str(), + &inWords, tmpNameP, &outChars); + if (result != 0 || inWords != 0) { + std::stringstream ss; + ss << "Path::unicodeToUtf8() - error converting path to UTF-8:" << std::endl + << "error code: " << result; + NTA_THROW << ss.str(); } - std::wstring Path::utf8ToUnicode(const std::string& path) - { - // Assume the number of unicode characters is <= number of UTF-8 bytes. - apr_size_t tmpNameSize = path.size() + 1; - - // Store buffer in a boost::scoped_array so it gets cleaned up for us. - boost::scoped_array tmpNameBuf; - wchar_t* tmpNameP = new wchar_t[tmpNameSize]; - tmpNameBuf.reset(tmpNameP); - - apr_size_t inBytes = path.size()+1; - apr_size_t outWords = tmpNameSize; - - apr_status_t result = ::apr_conv_utf8_to_ucs2(path.c_str(), - &inBytes, (apr_wchar_t*)tmpNameP, &outWords); - if (result != 0 || inBytes != 0) - { - char errBuffer[1024]; - std::stringstream ss; - ss << "Path::utf8ToUnicode() - error converting path to Unicode" - << std::endl - << ::apr_strerror(result, errBuffer, 1024); - - NTA_THROW << ss.str(); - } - - return std::wstring(tmpNameP); + return std::string(tmpNameP); +} + +std::wstring Path::utf8ToUnicode(const std::string &path) { + // Assume the number of unicode characters is <= number of UTF-8 bytes. + apr_size_t tmpNameSize = path.size() + 1; + + // Store buffer in a boost::scoped_array so it gets cleaned up for us. + boost::scoped_array tmpNameBuf; + wchar_t *tmpNameP = new wchar_t[tmpNameSize]; + tmpNameBuf.reset(tmpNameP); + + apr_size_t inBytes = path.size() + 1; + apr_size_t outWords = tmpNameSize; + + apr_status_t result = ::apr_conv_utf8_to_ucs2( + path.c_str(), &inBytes, (apr_wchar_t *)tmpNameP, &outWords); + if (result != 0 || inBytes != 0) { + char errBuffer[1024]; + std::stringstream ss; + ss << "Path::utf8ToUnicode() - error converting path to Unicode" + << std::endl + << ::apr_strerror(result, errBuffer, 1024); + + NTA_THROW << ss.str(); } - #endif - - - Path::StringVec Path::split(const std::string & path) - { - /** - * Don't use boost::tokenizer because we need to handle the prefix specially. - * Handling the prefix is messy on windows, but this is the only place we have - * to do it - */ - StringVec parts; - std::string::size_type curpos = 0; - if (path.size() == 0) - return parts; + + return std::wstring(tmpNameP); +} +#endif + +Path::StringVec Path::split(const std::string &path) { + /** + * Don't use boost::tokenizer because we need to handle the prefix specially. 
+ * Handling the prefix is messy on windows, but this is the only place we have + * to do it + */ + StringVec parts; + std::string::size_type curpos = 0; + if (path.size() == 0) + return parts; #if defined(NTA_OS_WINDOWS) - // prefix may be 1) "\", 2) "\\", 3) "[a-z]:", 4) "[a-z]:\" - if (path.size() == 1) - { - // captures both "\" and "a" - parts.push_back(path); - return parts; + // prefix may be 1) "\", 2) "\\", 3) "[a-z]:", 4) "[a-z]:\" + if (path.size() == 1) { + // captures both "\" and "a" + parts.push_back(path); + return parts; + } + if (path[0] == '\\') { + if (path[1] == '\\') { + // case 2 + parts.push_back("\\\\"); + curpos = 2; + } else { + // case 1 + parts.push_back("\\"); + curpos = 1; } - if (path[0] == '\\') - { - if (path[1] == '\\') - { - // case 2 - parts.push_back("\\\\"); + } else { + if (path[1] == ':') { + if (path.size() > 2 && path[2] == '\\') { + // case 4 + parts.push_back(path.substr(0, 3)); + curpos = 3; + } else { + parts.push_back(path.substr(0, 2)); curpos = 2; - } - else - { - // case 1 - parts.push_back("\\"); - curpos = 1; - } - } - else - { - if (path[1] == ':') - { - if (path.size() > 2 && path[2] == '\\') - { - // case 4 - parts.push_back(path.substr(0, 3)); - curpos = 3; - } - else - { - parts.push_back(path.substr(0, 2)); - curpos = 2; - } } } + } #else - // only possible prefix is "/" - if (path[0] == '/') - { - parts.push_back("/"); - curpos++; - } + // only possible prefix is "/" + if (path[0] == '/') { + parts.push_back("/"); + curpos++; + } #endif - // simple tokenization based on separator. Note that "foo//bar" -> "foo", "", "bar" - std::string::size_type newpos; - while (curpos < path.size() && curpos != std::string::npos) - { - // Be able to split on either separator including mixed separators on Windows - #if defined(NTA_OS_WINDOWS) - std::string::size_type p1 = path.find("\\", curpos); - std::string::size_type p2 = path.find("/", curpos); - newpos = p1 < p2 ? p1 : p2; - #else - newpos = path.find(Path::sep, curpos); - #endif - - if (newpos == std::string::npos) - { - parts.push_back(path.substr(curpos)); - curpos = newpos; - } - else - { - // note: if we have a "//" then newpos == curpos and this string is empty - if (newpos != curpos) - { - parts.push_back(path.substr(curpos, newpos - curpos)); - } - curpos = newpos + 1; + // simple tokenization based on separator. Note that "foo//bar" -> "foo", "", + // "bar" + std::string::size_type newpos; + while (curpos < path.size() && curpos != std::string::npos) { + // Be able to split on either separator including mixed separators on + // Windows +#if defined(NTA_OS_WINDOWS) + std::string::size_type p1 = path.find("\\", curpos); + std::string::size_type p2 = path.find("/", curpos); + newpos = p1 < p2 ? 
p1 : p2; +#else + newpos = path.find(Path::sep, curpos); +#endif + + if (newpos == std::string::npos) { + parts.push_back(path.substr(curpos)); + curpos = newpos; + } else { + // note: if we have a "//" then newpos == curpos and this string is empty + if (newpos != curpos) { + parts.push_back(path.substr(curpos, newpos - curpos)); } + curpos = newpos + 1; } - - return parts; - } - bool Path::isPrefix(const std::string & s) - { + return parts; +} + +bool Path::isPrefix(const std::string &s) { #if defined(NTA_OS_WINDOWS) - size_t len = s.length(); - if (len < 2) - return false; - if (len == 2) - return ::isalpha(s[0]) && s[1] == ':'; - else if (len == 3) - { - bool localPrefix = ::isalpha(s[0]) && s[1] == ':' && s[2] == '\\'; - bool uncPrefix = s[0] == '\\' && s[1] == '\\' && ::isalpha(s[2]); - - return localPrefix || uncPrefix; - } - else // len > 3 - return s[0] == '\\' && s[1] == '\\' && ::isalpha(s[2]); + size_t len = s.length(); + if (len < 2) + return false; + if (len == 2) + return ::isalpha(s[0]) && s[1] == ':'; + else if (len == 3) { + bool localPrefix = ::isalpha(s[0]) && s[1] == ':' && s[2] == '\\'; + bool uncPrefix = s[0] == '\\' && s[1] == '\\' && ::isalpha(s[2]); + + return localPrefix || uncPrefix; + } else // len > 3 + return s[0] == '\\' && s[1] == '\\' && ::isalpha(s[2]); #else return s == "/"; #endif } - bool Path::isRootdir(const std::string& s) - { - // redundant test on unix, but second test covers windows - return isPrefix(s); - } +bool Path::isRootdir(const std::string &s) { + // redundant test on unix, but second test covers windows + return isPrefix(s); +} - std::string Path::join(StringVec::const_iterator begin, StringVec::const_iterator end) - { - if (begin == end) - return ""; - - if (begin + 1 == end) - return std::string(*begin); - - std::string path(*begin); - #if defined(NTA_OS_WINDOWS) - if (path[path.length()-1] != Path::sep[0]) - path += Path::sep; - #else - // Treat first element specially (on Unix) - // it may be a prefix, which is not followed by "/" - if (!Path::isPrefix(*begin)) - path += Path::sep; - #endif - begin++; +std::string Path::join(StringVec::const_iterator begin, + StringVec::const_iterator end) { + if (begin == end) + return ""; - while (begin != end) - { - path += *begin; - begin++; - if (begin != end) - { - path += Path::sep; - } + if (begin + 1 == end) + return std::string(*begin); + + std::string path(*begin); +#if defined(NTA_OS_WINDOWS) + if (path[path.length() - 1] != Path::sep[0]) + path += Path::sep; +#else + // Treat first element specially (on Unix) + // it may be a prefix, which is not followed by "/" + if (!Path::isPrefix(*begin)) + path += Path::sep; +#endif + begin++; + + while (begin != end) { + path += *begin; + begin++; + if (begin != end) { + path += Path::sep; } - - return path; } + return path; +} - void Path::copy(const std::string & source, const std::string & destination) - { - NTA_CHECK(!source.empty()) - << "Can't copy from an empty source"; +void Path::copy(const std::string &source, const std::string &destination) { + NTA_CHECK(!source.empty()) << "Can't copy from an empty source"; - NTA_CHECK(!destination.empty()) - << "Can't copy to an empty destination"; + NTA_CHECK(!destination.empty()) << "Can't copy to an empty destination"; - NTA_CHECK(source != destination) + NTA_CHECK(source != destination) << "Source and destination must be different"; - - if (isDirectory(source)) - { - Directory::copyTree(source, destination); - return; - } - - // The target is always a filename. 
The input destination - // Can be either a directory or a filename. If the destination - // doesn't exist it is treated as a filename. - std::string target(destination); - if (Path::exists(destination) && isDirectory(destination)) - target = Path::normalize(Path::join(destination, Path::getBasename(source))); - - bool success = true; - #if defined(NTA_OS_WINDOWS) - - // Must remove read-only or hidden files before copy - // because they cannot be overwritten. For simplicity - // I just always remove if it exists. - if (Path::exists(target)) - Path::remove(target); - - // This will quietly overwrite the destination file if it exists - std::wstring wsource(utf8ToUnicode(source)); - std::wstring wtarget(utf8ToUnicode(target)); - BOOL res = ::CopyFileW(wsource.c_str(), - wtarget.c_str(), - FALSE); - - success = res != FALSE; - #else - - try - { - OFStream out(target.c_str()); - out.exceptions(std::ofstream::failbit | std::ofstream::badbit); - UInt64 size = Path::getFileSize(source); - if(size) { - IFStream in(source.c_str()); - if(out.fail()) { - std::cout << OS::getErrorMessage() << std::endl; - } - in.exceptions(std::ifstream::failbit | std::ifstream::badbit); - out << in.rdbuf(); + + if (isDirectory(source)) { + Directory::copyTree(source, destination); + return; + } + + // The target is always a filename. The input destination + // Can be either a directory or a filename. If the destination + // doesn't exist it is treated as a filename. + std::string target(destination); + if (Path::exists(destination) && isDirectory(destination)) + target = + Path::normalize(Path::join(destination, Path::getBasename(source))); + + bool success = true; +#if defined(NTA_OS_WINDOWS) + + // Must remove read-only or hidden files before copy + // because they cannot be overwritten. For simplicity + // I just always remove if it exists. + if (Path::exists(target)) + Path::remove(target); + + // This will quietly overwrite the destination file if it exists + std::wstring wsource(utf8ToUnicode(source)); + std::wstring wtarget(utf8ToUnicode(target)); + BOOL res = ::CopyFileW(wsource.c_str(), wtarget.c_str(), FALSE); + + success = res != FALSE; +#else + + try { + OFStream out(target.c_str()); + out.exceptions(std::ofstream::failbit | std::ofstream::badbit); + UInt64 size = Path::getFileSize(source); + if (size) { + IFStream in(source.c_str()); + if (out.fail()) { + std::cout << OS::getErrorMessage() << std::endl; } + in.exceptions(std::ifstream::failbit | std::ifstream::badbit); + out << in.rdbuf(); } - catch(std::exception &e) { - std::cerr << "Path::copy('" << source << "', '" << target << "'): " - << e.what() << std::endl; - } - catch (...) - { - success = false; - } - #endif - if (!success) - NTA_THROW << "Path::copy() - failed copying file " - << source << " to " << destination << " os error: " - << OS::getErrorMessage(); + } catch (std::exception &e) { + std::cerr << "Path::copy('" << source << "', '" << target + << "'): " << e.what() << std::endl; + } catch (...) 
{ + success = false; } +#endif + if (!success) + NTA_THROW << "Path::copy() - failed copying file " << source << " to " + << destination << " os error: " << OS::getErrorMessage(); +} - void Path::setPermissions(const std::string &path, - bool userRead, bool userWrite, - bool groupRead, bool groupWrite, - bool otherRead, bool otherWrite - ) - { - - if(Path::isDirectory(path)) { - Directory::Iterator iter(path); - Directory::Entry e; - while(iter.next(e)) { - std::string sub = Path::join(path, e.path); - setPermissions(sub, - userRead, userWrite, - groupRead, groupWrite, - otherRead, otherWrite); - } +void Path::setPermissions(const std::string &path, bool userRead, + bool userWrite, bool groupRead, bool groupWrite, + bool otherRead, bool otherWrite) { + + if (Path::isDirectory(path)) { + Directory::Iterator iter(path); + Directory::Entry e; + while (iter.next(e)) { + std::string sub = Path::join(path, e.path); + setPermissions(sub, userRead, userWrite, groupRead, groupWrite, otherRead, + otherWrite); } + } #if defined(NTA_OS_WINDOWS) - int countFailure = 0; - std::wstring wpath(utf8ToUnicode(path)); - DWORD attr = GetFileAttributesW(wpath.c_str()); - if(attr != INVALID_FILE_ATTRIBUTES) { - if(userWrite) attr &= ~FILE_ATTRIBUTE_READONLY; - BOOL res = SetFileAttributesW(wpath.c_str(), attr); - if(!res) { - NTA_WARN << "Path::setPermissions: Failed to set attributes for " << path; - ++countFailure; - } - } - else { - NTA_WARN << "Path::setPermissions: Failed to get attributes for " << path; + int countFailure = 0; + std::wstring wpath(utf8ToUnicode(path)); + DWORD attr = GetFileAttributesW(wpath.c_str()); + if (attr != INVALID_FILE_ATTRIBUTES) { + if (userWrite) + attr &= ~FILE_ATTRIBUTE_READONLY; + BOOL res = SetFileAttributesW(wpath.c_str(), attr); + if (!res) { + NTA_WARN << "Path::setPermissions: Failed to set attributes for " << path; ++countFailure; } + } else { + NTA_WARN << "Path::setPermissions: Failed to get attributes for " << path; + ++countFailure; + } + + if (countFailure > 0) { + NTA_THROW << "Path::setPermissions failed for " << path; + } - if(countFailure > 0) { - NTA_THROW << "Path::setPermissions failed for " << path; - } - #else - mode_t mode = 0; - if (userRead) mode |= S_IRUSR; - if (userWrite) mode |= S_IRUSR; - if (groupRead) mode |= S_IRGRP; - if (groupWrite) mode |= S_IWGRP; - if (otherRead) mode |= S_IROTH; - if (otherWrite) mode |= S_IWOTH; - chmod(path.c_str(), mode); + mode_t mode = 0; + if (userRead) + mode |= S_IRUSR; + if (userWrite) + mode |= S_IRUSR; + if (groupRead) + mode |= S_IRGRP; + if (groupWrite) + mode |= S_IWGRP; + if (otherRead) + mode |= S_IROTH; + if (otherWrite) + mode |= S_IWOTH; + chmod(path.c_str(), mode); #endif +} + +void Path::remove(const std::string &path) { + NTA_CHECK(!path.empty()) << "Can't remove an empty path"; + + // Just return if it doesn't exist already + if (!Path::exists(path)) + return; + + if (isDirectory(path)) { + Directory::removeTree(path); + return; } - - void Path::remove(const std::string & path) - { - NTA_CHECK(!path.empty()) - << "Can't remove an empty path"; - - // Just return if it doesn't exist already - if (!Path::exists(path)) - return; - - if (isDirectory(path)) - { - Directory::removeTree(path); - return; - } - - #if defined(NTA_OS_WINDOWS) - std::wstring wpath(utf8ToUnicode(path)); - BOOL res = ::DeleteFileW(wpath.c_str()); - if (res == FALSE) - NTA_THROW << "Path::remove() -- unable to delete '" << path - << "' error message: " << OS::getErrorMessage(); - #else - int res = ::remove(path.c_str()); - if (res 
!= 0) - NTA_THROW << "Path::remove() -- unable to delete '" << path - << "' error message: " << OS::getErrorMessage(); - #endif - } - - void Path::rename(const std::string & oldPath, const std::string & newPath) - { - NTA_CHECK(!oldPath.empty() && !newPath.empty()) + +#if defined(NTA_OS_WINDOWS) + std::wstring wpath(utf8ToUnicode(path)); + BOOL res = ::DeleteFileW(wpath.c_str()); + if (res == FALSE) + NTA_THROW << "Path::remove() -- unable to delete '" << path + << "' error message: " << OS::getErrorMessage(); +#else + int res = ::remove(path.c_str()); + if (res != 0) + NTA_THROW << "Path::remove() -- unable to delete '" << path + << "' error message: " << OS::getErrorMessage(); +#endif +} + +void Path::rename(const std::string &oldPath, const std::string &newPath) { + NTA_CHECK(!oldPath.empty() && !newPath.empty()) << "Can't rename to/from empty path"; - #if defined(NTA_OS_WINDOWS) - std::wstring wOldPath(utf8ToUnicode(oldPath)); - std::wstring wNewPath(utf8ToUnicode(newPath)); - BOOL res = ::MoveFileW(wOldPath.c_str(), wNewPath.c_str()); - if (res == FALSE) - NTA_THROW << "Path::rename() -- unable to rename '" - << oldPath << "' to '" << newPath - << "' error message: " << OS::getErrorMessage(); - #else - int res = ::rename(oldPath.c_str(), newPath.c_str()); - if (res == -1) - NTA_THROW << "Path::rename() -- unable to rename '" - << oldPath << "' to '" << newPath - << "' error message: " << OS::getErrorMessage(); - #endif - } - - Path::operator const char *() const - { - return path_.c_str(); - } - - Path & Path::operator+=(const Path & path) - { - Path::StringVec sv; - sv.push_back(std::string(path_)); - sv.push_back(std::string(path.path_)); - path_ = Path::join(sv.begin(), sv.end()); - return *this; - } +#if defined(NTA_OS_WINDOWS) + std::wstring wOldPath(utf8ToUnicode(oldPath)); + std::wstring wNewPath(utf8ToUnicode(newPath)); + BOOL res = ::MoveFileW(wOldPath.c_str(), wNewPath.c_str()); + if (res == FALSE) + NTA_THROW << "Path::rename() -- unable to rename '" << oldPath << "' to '" + << newPath << "' error message: " << OS::getErrorMessage(); +#else + int res = ::rename(oldPath.c_str(), newPath.c_str()); + if (res == -1) + NTA_THROW << "Path::rename() -- unable to rename '" << oldPath << "' to '" + << newPath << "' error message: " << OS::getErrorMessage(); +#endif +} - bool Path::operator==(const Path & other) - { - return Path::normalize(path_) == Path::normalize(other.path_); - } +Path::operator const char *() const { return path_.c_str(); } - Path Path::getParent() const - { - return Path::getParent(path_); - } - - Path Path::getBasename() const - { - return Path::getBasename(path_); - } - - Path Path::getExtension() const - { - return Path::getExtension(path_); - } - - Size Path::getFileSize() const - { - return Path::getFileSize(path_); - } - - Path & Path::normalize() - { - path_ = Path::normalize(path_); - return *this; - } - - Path & Path::makeAbsolute() - { - if (!isAbsolute()) - path_ = Path::makeAbsolute(path_); - return *this; - } - - Path::StringVec Path::split() const - { - return Path::split(path_); - } - - void Path::remove() const - { - Path::remove(path_); - } - - void Path::rename(const std::string & newPath) - { - Path::rename(path_, newPath); - path_ = newPath; - } +Path &Path::operator+=(const Path &path) { + Path::StringVec sv; + sv.push_back(std::string(path_)); + sv.push_back(std::string(path.path_)); + path_ = Path::join(sv.begin(), sv.end()); + return *this; +} - bool Path::isDirectory() const - { - return Path::isDirectory(path_); - } - - bool 
Path::isFile() const - { - return Path::isFile(path_); - } +bool Path::operator==(const Path &other) { + return Path::normalize(path_) == Path::normalize(other.path_); +} - bool Path::isSymbolicLink() const - { - return Path::isSymbolicLink(path_); - } - bool Path::isAbsolute() const - { - return Path::isAbsolute(path_); - } +Path Path::getParent() const { return Path::getParent(path_); } - bool Path::isRootdir() const - { - return Path::isRootdir(path_); - } +Path Path::getBasename() const { return Path::getBasename(path_); } - bool Path::exists() const - { - return Path::exists(path_); - } +Path Path::getExtension() const { return Path::getExtension(path_); } - bool Path::isEmpty() const - { - return path_.empty(); - } - - Path operator+(const Path & p1, const Path & p2) - { - Path::StringVec sv; - sv.push_back(std::string(p1)); - sv.push_back(std::string(p2)); - return Path::join(sv.begin(), sv.end()); - } +Size Path::getFileSize() const { return Path::getFileSize(path_); } - std::string Path::join(const std::string & path1, const std::string & path2) - { - return path1 + Path::sep + path2; - } +Path &Path::normalize() { + path_ = Path::normalize(path_); + return *this; +} - std::string Path::join(const std::string & path1, const std::string & path2, - const std::string & path3) - { - return path1 + Path::sep + path2 + Path::sep + path3; - } +Path &Path::makeAbsolute() { + if (!isAbsolute()) + path_ = Path::makeAbsolute(path_); + return *this; +} - std::string Path::join(const std::string & path1, const std::string & path2, - const std::string & path3, const std::string & path4) - { - return path1 + Path::sep + path2 + Path::sep + path3 + Path::sep + path4; - } - - std::string Path::getExecutablePath() - { +Path::StringVec Path::split() const { return Path::split(path_); } + +void Path::remove() const { Path::remove(path_); } + +void Path::rename(const std::string &newPath) { + Path::rename(path_, newPath); + path_ = newPath; +} - std::string epath = "UnknownExecutablePath"; +bool Path::isDirectory() const { return Path::isDirectory(path_); } + +bool Path::isFile() const { return Path::isFile(path_); } + +bool Path::isSymbolicLink() const { return Path::isSymbolicLink(path_); } +bool Path::isAbsolute() const { return Path::isAbsolute(path_); } + +bool Path::isRootdir() const { return Path::isRootdir(path_); } + +bool Path::exists() const { return Path::exists(path_); } + +bool Path::isEmpty() const { return path_.empty(); } + +Path operator+(const Path &p1, const Path &p2) { + Path::StringVec sv; + sv.push_back(std::string(p1)); + sv.push_back(std::string(p2)); + return Path::join(sv.begin(), sv.end()); +} + +std::string Path::join(const std::string &path1, const std::string &path2) { + return path1 + Path::sep + path2; +} + +std::string Path::join(const std::string &path1, const std::string &path2, + const std::string &path3) { + return path1 + Path::sep + path2 + Path::sep + path3; +} + +std::string Path::join(const std::string &path1, const std::string &path2, + const std::string &path3, const std::string &path4) { + return path1 + Path::sep + path2 + Path::sep + path3 + Path::sep + path4; +} + +std::string Path::getExecutablePath() { + + std::string epath = "UnknownExecutablePath"; #if !defined(NTA_OS_WINDOWS) - auto buf = new char[1000]; - UInt32 bufsize = 1000; - // sets bufsize to actual length. 
- #if defined(NTA_OS_DARWIN) - _NSGetExecutablePath(buf, &bufsize); - if (bufsize < 1000) - buf[bufsize] = '\0'; - #elif defined(NTA_OS_LINUX) - int count = readlink("/proc/self/exe", buf, bufsize); - if (count < 0) - NTA_THROW << "Unable to read /proc/self/exe to get executable name"; - if (count < 1000) - buf[count] = '\0'; - #elif defined(NTA_ARCH_64) && defined(NTA_OS_SPARC) - const char *tmp = getexecname(); - if (!tmp) - NTA_THROW << "Unable to determine executable name"; - strncpy(buf, tmp, bufsize); - #endif - - // make sure it's null-terminated - buf[999] = '\0'; - epath = buf; - delete[] buf; + auto buf = new char[1000]; + UInt32 bufsize = 1000; + // sets bufsize to actual length. +#if defined(NTA_OS_DARWIN) + _NSGetExecutablePath(buf, &bufsize); + if (bufsize < 1000) + buf[bufsize] = '\0'; +#elif defined(NTA_OS_LINUX) + int count = readlink("/proc/self/exe", buf, bufsize); + if (count < 0) + NTA_THROW << "Unable to read /proc/self/exe to get executable name"; + if (count < 1000) + buf[count] = '\0'; +#elif defined(NTA_ARCH_64) && defined(NTA_OS_SPARC) + const char *tmp = getexecname(); + if (!tmp) + NTA_THROW << "Unable to determine executable name"; + strncpy(buf, tmp, bufsize); +#endif + + // make sure it's null-terminated + buf[999] = '\0'; + epath = buf; + delete[] buf; #else - // windows - auto buf = new wchar_t[1000]; - GetModuleFileNameW(NULL, buf, 1000); - // null-terminated string guaranteed unless length > 999 - buf[999] = '\0'; - std::wstring wpath(buf); - delete[] buf; - epath = unicodeToUtf8(wpath); + // windows + auto buf = new wchar_t[1000]; + GetModuleFileNameW(NULL, buf, 1000); + // null-terminated string guaranteed unless length > 999 + buf[999] = '\0'; + std::wstring wpath(buf); + delete[] buf; + epath = unicodeToUtf8(wpath); #endif - return epath; - } + return epath; +} } // namespace nupic diff --git a/src/nupic/os/Path.hpp b/src/nupic/os/Path.hpp old mode 100755 new mode 100644 index e37ed4331e..87411789b7 --- a/src/nupic/os/Path.hpp +++ b/src/nupic/os/Path.hpp @@ -33,277 +33,272 @@ //---------------------------------------------------------------------- -namespace nupic -{ - /** - * @b Responsibility: - * 1. Represent a cross-platform path to a filesystem object - * (file, directory, symlink) - * - * 2. Provide a slew of of path manipulation operations - * - * @b Rationale: - * File system paths are used a lot. It makes sense to have - * a cross-platform class with a nice interface tailored to our needs. - * In particular operations throw NTA::Exception on failure and - * don't return error codes, which is alligned nicely with the - * way we handle errors. - * - * Operations are both static and instance methods (use single implementation). - * - * @b Resource/Ownerships: - * 1. A path string for the instance. - * - * @b Notes: - * The Path() constructors don't try to validate the path string - * for efficiency reasons (it's complicated too). If you pass - * an invalid path string it will fail when you actually try to use - * the resulting path. - * - * The error handling strategy is to return error NULLs and not to throw exceptions. - * The reason is that it is a very generic low-level class that should not be aware - * and depend on the runtime's error handling policy. It may be used in different - * contexts like tools and utilities that may utilize a different error handling - * strategy. It is also a common idiom to return NULL from a failed factory method. +namespace nupic { +/** + * @b Responsibility: + * 1. 
Represent a cross-platform path to a filesystem object + * (file, directory, symlink) + * + * 2. Provide a slew of of path manipulation operations + * + * @b Rationale: + * File system paths are used a lot. It makes sense to have + * a cross-platform class with a nice interface tailored to our needs. + * In particular operations throw NTA::Exception on failure and + * don't return error codes, which is alligned nicely with the + * way we handle errors. + * + * Operations are both static and instance methods (use single implementation). + * + * @b Resource/Ownerships: + * 1. A path string for the instance. + * + * @b Notes: + * The Path() constructors don't try to validate the path string + * for efficiency reasons (it's complicated too). If you pass + * an invalid path string it will fail when you actually try to use + * the resulting path. + * + * The error handling strategy is to return error NULLs and not to throw + * exceptions. The reason is that it is a very generic low-level class that + * should not be aware and depend on the runtime's error handling policy. It may + * be used in different contexts like tools and utilities that may utilize a + * different error handling strategy. It is also a common idiom to return NULL + * from a failed factory method. + * + * @b Performance: + * The emphasis is on code readability and ease of use. Performance takes + * second place, because the critical path of our codebase doesn't involve a lot + * of path manipulation. In particular, simple ANSI C or POSIX cross-platform + * implementation is often preffered to calling specific platform APIs. Whenever + * possible APR is used under the covers. + * + * Note, that constructing a Path object (or calling the Path instance methods) + * involve an extra copy of the path string into the new Path instance. Again, + * this is not prohibitive in most cases. If you are concerned use plain strings + * and the static methods. + * + * @b Details, details + * Portable filesystem interfaces are tricky to get right. We are targeting a + * simple and intuitive interface like Python rather than the + * difficult-to-understand boost interface. The current implementation does not + * cover every corner case, but it gets many of them. For more insight into the + * details, see the python os.path documentation, java.io.file documentation and + * the Wikipedia entry on Path_(computing) + * + * @todo We do not support unicode filenames (yet) + */ +class Path { +public: + typedef std::vector StringVec; + + static const char *sep; + static const char *pathSep; + static const char *parDir; + + /** + * This first set of methods act symbolically on strings + * and don't look at an actual filesystem. + */ + + /** + * getParent(path) -> normalize(path/..) + * Examples: + * getParent("/foo/bar") -> "/foo" + * getParent("foo") -> "." + * getParent(foo/bar.txt) -> "foo" + * getParent(rootdir) -> rootdir + * getParent("../../a") -> "../.." + + * @discussion + * Can't we do better? + * What if: getParent(path) -> normalize(path) - lastElement + * The problems with this are: + * - getParent("../..") can't be done + * - Also we have to normalize first, because we don't want + * getParent("foo/bar/..") -> foo/bar * - * @b Performance: - * The emphasis is on code readability and ease of use. Performance takes second - * place, because the critical path of our codebase doesn't involve a lot of - * path manipulation. 
In particular, simple ANSI C or POSIX cross-platform implementation - * is often preffered to calling specific platform APIs. Whenever possible APR is - * used under the covers. + * The main issue with adding ".." are + * - ".." doesn't exist if you're getting the parent of a file + * This is ok because we normalize, which is a symbolic manipulation * - * Note, that constructing a Path object (or calling the Path instance methods) - * involve an extra copy of the path string into the new Path instance. Again, this - * is not prohibitive in most cases. If you are concerned use plain strings and - * the static methods. + * For both solutions, we have to know when we reach the "top" + * if we want to iterate "up" the stack of directories. + * With an absolute path, we can check using isRootdir(), but + * with a relative path, we keep adding ".." forever. + * The application has to be aware of this and do the right thing. + */ + static std::string getParent(const std::string &path); + + /** + * getBasename(foo/bar.baz) -> bar.baz + */ + static std::string getBasename(const std::string &path); + + /** + * getExtension(foo/bar.baz) -> .baz + */ + static std::string getExtension(const std::string &path); + + /** + * Normalize: + * - remove "../" and "./" (unless leading) c + * - convert "//" to "/" + * - remove trailing "/" + * - normalize(rootdir/..) -> rootdir + * - normalize(foo/..) -> "." * - * @b Details, details - * Portable filesystem interfaces are tricky to get right. We are targeting a simple - * and intuitive interface like Python rather than the difficult-to-understand boost interface. - * The current implementation does not cover every corner case, but it gets many of them. - * For more insight into the details, see the python os.path documentation, java.io.file - * documentation and the Wikipedia entry on Path_(computing) - * - * @todo We do not support unicode filenames (yet) + * Note that because we are operating symbolically, the results might + * be unexpected if there are symbolic links in the path. + * For example if /foo/bar is a link to /quux/glorp then + * normalize("/foo/bar/..")-> "/foo", not "/quux" + * Also, "path/file/.." is converted to "path" even if "path/file" is a + * regular file (which doesn't have a ".." entry). On windows, a path starting + * with "\\" is a UNC path and the prefix is not converted. */ - class Path - { - public: - typedef std::vector StringVec; - - static const char * sep; - static const char * pathSep; - static const char * parDir; - - - /** - * This first set of methods act symbolically on strings - * and don't look at an actual filesystem. - */ - - /** - * getParent(path) -> normalize(path/..) - * Examples: - * getParent("/foo/bar") -> "/foo" - * getParent("foo") -> "." - * getParent(foo/bar.txt) -> "foo" - * getParent(rootdir) -> rootdir - * getParent("../../a") -> "../.." - - * @discussion - * Can't we do better? - * What if: getParent(path) -> normalize(path) - lastElement - * The problems with this are: - * - getParent("../..") can't be done - * - Also we have to normalize first, because we don't want - * getParent("foo/bar/..") -> foo/bar - * - * The main issue with adding ".." are - * - ".." doesn't exist if you're getting the parent of a file - * This is ok because we normalize, which is a symbolic manipulation - * - * For both solutions, we have to know when we reach the "top" - * if we want to iterate "up" the stack of directories. 
- * With an absolute path, we can check using isRootdir(), but - * with a relative path, we keep adding ".." forever. - * The application has to be aware of this and do the right thing. - */ - static std::string getParent(const std::string & path); - - /** - * getBasename(foo/bar.baz) -> bar.baz - */ - static std::string getBasename(const std::string & path); - - /** - * getExtension(foo/bar.baz) -> .baz - */ - static std::string getExtension(const std::string & path); - - /** - * Normalize: - * - remove "../" and "./" (unless leading) c - * - convert "//" to "/" - * - remove trailing "/" - * - normalize(rootdir/..) -> rootdir - * - normalize(foo/..) -> "." - * - * Note that because we are operating symbolically, the results might - * be unexpected if there are symbolic links in the path. - * For example if /foo/bar is a link to /quux/glorp then - * normalize("/foo/bar/..")-> "/foo", not "/quux" - * Also, "path/file/.." is converted to "path" even if "path/file" is a regular - * file (which doesn't have a ".." entry). - * On windows, a path starting with "\\" is a UNC path and the prefix is not converted. - */ - static std::string normalize(const std::string & path); - - /** - * makeAbsolute(path) -> - * if isAbsolute(path) -> path - * unix: -> join(cwd, path) - * windows: makeAbsolute("c:foo") -> join("c:", cwd, "foo") - * windows: makeAbsolute("/foo") -> join(cwd.split()[0], "foo") - */ - static std::string makeAbsolute(const std::string & path); - - /** - * Convert a unicode string to UTF-8 - */ - static std::string unicodeToUtf8(const std::wstring& path); - - /** - * Convert a UTF-8 path to a unicode string - */ - static std::wstring utf8ToUnicode(const std::string& path); - - /** - * When splitting a path into components, the "prefix" has to be - * treated specially. We do not store it in a separate data - * structure -- the prefix is just the first element of the split. - * No normalization is performed. We always have path == join(split(path)) - * except when there are empty components, e.g. foo//bar. Empty components - * are omitted. - * See the java.io.file module documentation for some good background - * split("foo/bar/../quux") -> ("foo", "bar", "..", "quux") - * split("/foo/bar/quux") -> ("/", "foo", "bar", "quux") - * split("a:\foo\bar") -> ("a:\", "foo", "bar") - * split("\\host\drive\file") -> ("\\", "host", "drive", "file") - * split("/foo//bar/") -> ("/", "foo", "bar") - * Note: this behavior is different from the original behavior. - */ - static StringVec split(const std::string & path); - - /** - * Construct a path from components. path == join(split(path)) - */ - static std::string join(StringVec::const_iterator begin, - StringVec::const_iterator end); - - /** - * path == "/" on unix - * path == "/" or path == "a:/" on windows - */ - static bool isRootdir(const std::string & path); - - /** - * isAbsolute("/foo/bar") -> true isAbsolute("foo")->false on Unix - * is Absolute("a:\foo\bar") -> true isAbsolute("\foo\bar") -> false on windows - */ - static bool isAbsolute(const std::string & path); - - /** - * varargs through overloading - */ - static std::string join(const std::string & path1, const std::string & path2); - static std::string join(const std::string & path1, const std::string & path2, - const std::string & path3); - static std::string join(const std::string & path1, const std::string & path2, - const std::string & path3, const std::string & path4); - - - /** - * This second set of methods must interact with the filesystem - * to do their work. 
- */ - - /** - * true if path exists. false is for broken links - * @todo lexists() - */ - static bool exists(const std::string & path); - - /** - * getFileSize throws an exception if does not exist or is a directory - */ - static Size getFileSize(const std::string & path); - - /** - * @todo What if source is directory? What id source is file and dest is directory? - * @todo What if one of source or dest is a symbolic link? - */ - static void copy(const std::string & source, const std::string & destination); - static void remove(const std::string & path); - static void rename(const std::string & oldPath, const std::string & newPath); - static bool isDirectory(const std::string & path); - static bool isFile(const std::string & path); - static bool isSymbolicLink(const std::string & path); - static bool areEquivalent(const std::string & path1, const std::string & path2); - // Get a path to the currently running executable - static std::string getExecutablePath(); - - static void setPermissions(const std::string &path, - bool userRead, bool userWrite, - bool groupRead, bool groupWrite, - bool otherRead, bool otherWrite - ); - - - Path(std::string path); - operator const char*() const; - - /** - * Test for symbolic equivalence, i.e. normalize(a) == normalize(b) - * To test if they refer to the same file/directory, use areEquivalent - */ - bool operator==(const Path & other); - - Path & operator +=(const Path & path); - bool exists() const; - Path getParent() const; - Path getBasename() const; - Path getExtension() const; - Size getFileSize() const; - - Path & normalize(); - Path & makeAbsolute(); - StringVec split() const; - - void remove() const; - void copy(const std::string & destination) const; - void rename(const std::string & newPath); - - bool isDirectory() const; - bool isFile() const; - bool isRootdir() const; - bool isAbsolute() const; - bool isSymbolicLink() const; - bool isEmpty() const; - - private: - // on unix: == "/"; on windows: == "/" || == "C:" || == "C:/" - static bool isPrefix(const std::string&); - Path(); - - private: - std::string path_; - }; - - // Global operator - Path operator+(const Path & p1, const Path & p2); -} + static std::string normalize(const std::string &path); -#endif // NTA_PATH_HPP + /** + * makeAbsolute(path) -> + * if isAbsolute(path) -> path + * unix: -> join(cwd, path) + * windows: makeAbsolute("c:foo") -> join("c:", cwd, "foo") + * windows: makeAbsolute("/foo") -> join(cwd.split()[0], "foo") + */ + static std::string makeAbsolute(const std::string &path); + /** + * Convert a unicode string to UTF-8 + */ + static std::string unicodeToUtf8(const std::wstring &path); + + /** + * Convert a UTF-8 path to a unicode string + */ + static std::wstring utf8ToUnicode(const std::string &path); + /** + * When splitting a path into components, the "prefix" has to be + * treated specially. We do not store it in a separate data + * structure -- the prefix is just the first element of the split. + * No normalization is performed. We always have path == join(split(path)) + * except when there are empty components, e.g. foo//bar. Empty components + * are omitted. 
+ * See the java.io.file module documentation for some good background + * split("foo/bar/../quux") -> ("foo", "bar", "..", "quux") + * split("/foo/bar/quux") -> ("/", "foo", "bar", "quux") + * split("a:\foo\bar") -> ("a:\", "foo", "bar") + * split("\\host\drive\file") -> ("\\", "host", "drive", "file") + * split("/foo//bar/") -> ("/", "foo", "bar") + * Note: this behavior is different from the original behavior. + */ + static StringVec split(const std::string &path); + + /** + * Construct a path from components. path == join(split(path)) + */ + static std::string join(StringVec::const_iterator begin, + StringVec::const_iterator end); + + /** + * path == "/" on unix + * path == "/" or path == "a:/" on windows + */ + static bool isRootdir(const std::string &path); + + /** + * isAbsolute("/foo/bar") -> true isAbsolute("foo")->false on Unix + * is Absolute("a:\foo\bar") -> true isAbsolute("\foo\bar") -> false on + * windows + */ + static bool isAbsolute(const std::string &path); + + /** + * varargs through overloading + */ + static std::string join(const std::string &path1, const std::string &path2); + static std::string join(const std::string &path1, const std::string &path2, + const std::string &path3); + static std::string join(const std::string &path1, const std::string &path2, + const std::string &path3, const std::string &path4); + + /** + * This second set of methods must interact with the filesystem + * to do their work. + */ + + /** + * true if path exists. false is for broken links + * @todo lexists() + */ + static bool exists(const std::string &path); + + /** + * getFileSize throws an exception if does not exist or is a directory + */ + static Size getFileSize(const std::string &path); + + /** + * @todo What if source is directory? What id source is file and dest is + * directory? + * @todo What if one of source or dest is a symbolic link? + */ + static void copy(const std::string &source, const std::string &destination); + static void remove(const std::string &path); + static void rename(const std::string &oldPath, const std::string &newPath); + static bool isDirectory(const std::string &path); + static bool isFile(const std::string &path); + static bool isSymbolicLink(const std::string &path); + static bool areEquivalent(const std::string &path1, const std::string &path2); + // Get a path to the currently running executable + static std::string getExecutablePath(); + + static void setPermissions(const std::string &path, bool userRead, + bool userWrite, bool groupRead, bool groupWrite, + bool otherRead, bool otherWrite); + + Path(std::string path); + operator const char *() const; + + /** + * Test for symbolic equivalence, i.e. 
normalize(a) == normalize(b) + * To test if they refer to the same file/directory, use areEquivalent + */ + bool operator==(const Path &other); + + Path &operator+=(const Path &path); + bool exists() const; + Path getParent() const; + Path getBasename() const; + Path getExtension() const; + Size getFileSize() const; + + Path &normalize(); + Path &makeAbsolute(); + StringVec split() const; + + void remove() const; + void copy(const std::string &destination) const; + void rename(const std::string &newPath); + + bool isDirectory() const; + bool isFile() const; + bool isRootdir() const; + bool isAbsolute() const; + bool isSymbolicLink() const; + bool isEmpty() const; + +private: + // on unix: == "/"; on windows: == "/" || == "C:" || == "C:/" + static bool isPrefix(const std::string &); + Path(); + +private: + std::string path_; +}; + +// Global operator +Path operator+(const Path &p1, const Path &p2); +} // namespace nupic + +#endif // NTA_PATH_HPP diff --git a/src/nupic/os/Regex.cpp b/src/nupic/os/Regex.cpp index 521b5845f4..91fbd36f4b 100644 --- a/src/nupic/os/Regex.cpp +++ b/src/nupic/os/Regex.cpp @@ -20,53 +20,49 @@ * --------------------------------------------------------------------- */ -/** @file -*/ +/** @file + */ #include #include #if defined(NTA_OS_WINDOWS) - // TODO: See https://github.com/numenta/nupic.core/issues/128 - #include +// TODO: See https://github.com/numenta/nupic.core/issues/128 +#include #else - //https://gcc.gnu.org/bugzilla/show_bug.cgi?id=53631 - #include +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=53631 +#include #endif -namespace nupic -{ - namespace regex - { - bool match(const std::string & re, const std::string & text) - { - NTA_CHECK(!re.empty()) << "Empty regular expressions is invalid"; - - // Make sure the regex will perform an exact match - std::string exactRegExp; - if (re[0] != '^') - exactRegExp += '^'; - exactRegExp += re; - if (re[re.length()-1] != '$') - exactRegExp += '$'; +namespace nupic { +namespace regex { +bool match(const std::string &re, const std::string &text) { + NTA_CHECK(!re.empty()) << "Empty regular expressions is invalid"; + + // Make sure the regex will perform an exact match + std::string exactRegExp; + if (re[0] != '^') + exactRegExp += '^'; + exactRegExp += re; + if (re[re.length() - 1] != '$') + exactRegExp += '$'; #if defined(NTA_OS_WINDOWS) - std::regex r(exactRegExp, std::regex::extended | std::regex::nosubs); - if (std::regex_match(text, r)) - return true; + std::regex r(exactRegExp, std::regex::extended | std::regex::nosubs); + if (std::regex_match(text, r)) + return true; - return false; + return false; #else - regex_t r; - int res = ::regcomp(&r, exactRegExp.c_str(), REG_EXTENDED|REG_NOSUB); - NTA_CHECK(res == 0) - << "regcomp() failed to compile the regular expression: " - << re << " . The error code is: " << res; - - res = regexec(&r, text.c_str(), (size_t) 0, nullptr, 0); - ::regfree(&r); + regex_t r; + int res = ::regcomp(&r, exactRegExp.c_str(), REG_EXTENDED | REG_NOSUB); + NTA_CHECK(res == 0) << "regcomp() failed to compile the regular expression: " + << re << " . 
The error code is: " << res; + + res = regexec(&r, text.c_str(), (size_t)0, nullptr, 0); + ::regfree(&r); - return res == 0; + return res == 0; #endif - } - } } +} // namespace regex +} // namespace nupic diff --git a/src/nupic/os/Regex.hpp b/src/nupic/os/Regex.hpp old mode 100755 new mode 100644 index 840f754fe1..eaf869fcd7 --- a/src/nupic/os/Regex.hpp +++ b/src/nupic/os/Regex.hpp @@ -22,7 +22,6 @@ /** @file */ - #ifndef NTA_REGEX_HPP #define NTA_REGEX_HPP @@ -32,14 +31,10 @@ //---------------------------------------------------------------------- -namespace nupic -{ - namespace regex - { - bool match(const std::string & re, const std::string & text); - } +namespace nupic { +namespace regex { +bool match(const std::string &re, const std::string &text); } +} // namespace nupic #endif // NTA_REGEX_HPP - - diff --git a/src/nupic/os/Timer.cpp b/src/nupic/os/Timer.cpp index 1c01bdb9a9..755053589c 100644 --- a/src/nupic/os/Timer.cpp +++ b/src/nupic/os/Timer.cpp @@ -20,8 +20,7 @@ * --------------------------------------------------------------------- */ - -/** @file +/** @file * Generic OS Implementations for the OS class */ @@ -38,10 +37,8 @@ static nupic::UInt64 initialTicks_ = 0; // initTime is called by the constructor, so it will always // have been called by the time we call getTicksPerSec or getCurrentTime -static inline void initTime() -{ - if (initialTicks_ == 0) - { +static inline void initTime() { + if (initialTicks_ == 0) { LARGE_INTEGER f; QueryPerformanceCounter(&f); initialTicks_ = (nupic::UInt64)(f.QuadPart); @@ -51,14 +48,9 @@ static inline void initTime() } } -static inline nupic::UInt64 getTicksPerSec() -{ - return ticksPerSec_; -} - +static inline nupic::UInt64 getTicksPerSec() { return ticksPerSec_; } -static nupic::UInt64 getCurrentTime() -{ +static nupic::UInt64 getCurrentTime() { LARGE_INTEGER v; QueryPerformanceCounter(&v); return (nupic::UInt64)(v.QuadPart) - initialTicks_; @@ -66,8 +58,8 @@ static nupic::UInt64 getCurrentTime() #elif defined(NTA_OS_DARWIN) -// This include defines a UInt64 type that conflicts with the nupic::UInt64 type. -// Because of this, all UInt64 is explicitly qualified in the interface. +// This include defines a UInt64 type that conflicts with the nupic::UInt64 +// type. Because of this, all UInt64 is explicitly qualified in the interface. 
#include #include #include @@ -78,29 +70,22 @@ static nupic::UInt64 getCurrentTime() static uint64_t initialTicks_ = 0; static nupic::UInt64 ticksPerSec_ = 0; -static inline void initTime() -{ +static inline void initTime() { if (initialTicks_ == 0) initialTicks_ = mach_absolute_time(); - if (ticksPerSec_ == 0) - { + if (ticksPerSec_ == 0) { mach_timebase_info_data_t sTimebaseInfo; mach_timebase_info(&sTimebaseInfo); ticksPerSec_ = (nupic::UInt64)(1e9 * (uint64_t)sTimebaseInfo.denom / - (uint64_t)sTimebaseInfo.numer); + (uint64_t)sTimebaseInfo.numer); } } -static inline nupic::UInt64 getCurrentTime() -{ +static inline nupic::UInt64 getCurrentTime() { return (nupic::UInt64)(mach_absolute_time() - initialTicks_); } -static inline nupic::UInt64 getTicksPerSec() -{ - return ticksPerSec_; -} - +static inline nupic::UInt64 getTicksPerSec() { return ticksPerSec_; } #else // linux @@ -108,107 +93,81 @@ static inline nupic::UInt64 getTicksPerSec() static nupic::UInt64 initialTicks_ = 0; -static inline void initTime() -{ - if (initialTicks_ == 0) - { +static inline void initTime() { + if (initialTicks_ == 0) { struct timeval t; ::gettimeofday(&t, nullptr); initialTicks_ = nupic::UInt64((t.tv_sec * 1e6) + t.tv_usec); } } -static inline nupic::UInt64 getCurrentTime() -{ +static inline nupic::UInt64 getCurrentTime() { struct timeval t; ::gettimeofday(&t, nullptr); nupic::UInt64 ticks = nupic::UInt64((t.tv_sec * 1e6) + t.tv_usec); return ticks - initialTicks_; } +static inline nupic::UInt64 getTicksPerSec() { return (nupic::UInt64)(1e6); } +#endif -static inline nupic::UInt64 getTicksPerSec() -{ - return (nupic::UInt64)(1e6); +namespace nupic { + +Timer::Timer(bool startme) { + initTime(); + reset(); + if (startme) + start(); } -#endif +void Timer::start() { + if (started_ == false) { + start_ = getCurrentTime(); + nstarts_++; + started_ = true; + } +} -namespace nupic -{ +/** + * Stop the stopwatch. When restarted, time will accumulate + */ - Timer::Timer(bool startme) - { - initTime(); - reset(); - if (startme) - start(); - } - - - void Timer::start() - { - if (started_ == false) - { - start_ = getCurrentTime(); - nstarts_++; - started_ = true; - } - } - - /** - * Stop the stopwatch. 
When restarted, time will accumulate - */ - - void Timer::stop() - { // stop the stopwatch - if (started_ == true) - { - prevElapsed_ += (getCurrentTime() - start_); - start_ = 0; - started_ = false; - } - } - - Real64 Timer::getElapsed() const - { - nupic::UInt64 elapsed = prevElapsed_; - if (started_) - { - elapsed += (getCurrentTime() - start_); - } - - return (Real64)(elapsed) / (Real64)getTicksPerSec(); - } - - void Timer::reset() - { - prevElapsed_ = 0; +void Timer::stop() { // stop the stopwatch + if (started_ == true) { + prevElapsed_ += (getCurrentTime() - start_); start_ = 0; - nstarts_ = 0; started_ = false; } - - UInt64 Timer::getStartCount() const - { - return nstarts_; - } - - bool Timer::isStarted() const - { - return started_; - } - - - std::string Timer::toString() const - { - std::stringstream ss; - ss << "[Elapsed: " << getElapsed() << " Starts: " << getStartCount(); - if (isStarted()) - ss << " (running)"; - ss << "]"; - return ss.str(); +} + +Real64 Timer::getElapsed() const { + nupic::UInt64 elapsed = prevElapsed_; + if (started_) { + elapsed += (getCurrentTime() - start_); } -} // namespace nupic + return (Real64)(elapsed) / (Real64)getTicksPerSec(); +} + +void Timer::reset() { + prevElapsed_ = 0; + start_ = 0; + nstarts_ = 0; + started_ = false; +} + +UInt64 Timer::getStartCount() const { return nstarts_; } + +bool Timer::isStarted() const { return started_; } + +std::string Timer::toString() const { + std::stringstream ss; + ss << "[Elapsed: " << getElapsed() << " Starts: " << getStartCount(); + if (isStarted()) + ss << " (running)"; + ss << "]"; + return ss.str(); +} + +} // namespace nupic diff --git a/src/nupic/os/Timer.hpp b/src/nupic/os/Timer.hpp index 1b584ee2c7..30375fa6a7 100644 --- a/src/nupic/os/Timer.hpp +++ b/src/nupic/os/Timer.hpp @@ -27,89 +27,75 @@ #ifndef NTA_TIMER2_HPP #define NTA_TIMER2_HPP -#include #include +#include -namespace nupic -{ +namespace nupic { +/** + * @Responsibility + * Simple stopwatch services + * + * @Description + * A timer object is a stopwatch. You can start it, stop it, read the + * elapsed time, and reset it. It is very convenient for performance + * measurements. + * + * Uses the most precise and lowest overhead timer available on a given system. + * + */ +class Timer { +public: /** - * @Responsibility - * Simple stopwatch services - * - * @Description - * A timer object is a stopwatch. You can start it, stop it, read the - * elapsed time, and reset it. It is very convenient for performance - * measurements. - * - * Uses the most precise and lowest overhead timer available on a given system. + * Create a stopwatch * + * @param startme If true, the timer is started when created + */ + Timer(bool startme = false); + + /** + * Start the stopwatch + */ + void start(); + + /** + * Stop the stopwatch. When restarted, time will accumulate + */ + void stop(); + + /** + * If stopped, return total elapsed time. + * If started, return current elapsed time but don't stop the clock + * return the value in seconds; + */ + Real64 getElapsed() const; + + /** + * Reset the stopwatch, setting accumulated time to zero. + */ + void reset(); + + /**Train + * Return the number of time the stopwatch has been started. 
+ */ + UInt64 getStartCount() const; + + /** + * Returns true is the stopwatch is currently running */ - class Timer - { - public: - - /** - * Create a stopwatch - * - * @param startme If true, the timer is started when created - */ - Timer(bool startme = false); - - - /** - * Start the stopwatch - */ - void - start(); - - - /** - * Stop the stopwatch. When restarted, time will accumulate - */ - void - stop(); - - - /** - * If stopped, return total elapsed time. - * If started, return current elapsed time but don't stop the clock - * return the value in seconds; - */ - Real64 - getElapsed() const; - - /** - * Reset the stopwatch, setting accumulated time to zero. - */ - void - reset(); - - /**Train - * Return the number of time the stopwatch has been started. - */ - UInt64 - getStartCount() const; - - /** - * Returns true is the stopwatch is currently running - */ - bool - isStarted() const; - - std::string - toString() const; - - private: - // internally times are stored as ticks - UInt64 prevElapsed_; // total time as of last stop() (in ticks) - UInt64 start_; // time that start() was called (in ticks) - UInt64 nstarts_; // number of times start() was called - bool started_; // true if was started - - }; // class Timer - + bool isStarted() const; + + std::string toString() const; + +private: + // internally times are stored as ticks + UInt64 prevElapsed_; // total time as of last stop() (in ticks) + UInt64 start_; // time that start() was called (in ticks) + UInt64 nstarts_; // number of times start() was called + bool started_; // true if was started + +}; // class Timer + } // namespace nupic #endif // NTA_TIMER2_HPP - diff --git a/src/nupic/py_support/NumpyArrayObject.cpp b/src/nupic/py_support/NumpyArrayObject.cpp index 108ca2c3ea..412cf4fe23 100644 --- a/src/nupic/py_support/NumpyArrayObject.cpp +++ b/src/nupic/py_support/NumpyArrayObject.cpp @@ -30,14 +30,12 @@ #include namespace nupic { - void initializeNumpy() - { - // Use _import_array() because import_array() is a macro that contains a - // return statement. - if (_import_array() != 0) - { - throw std::runtime_error( +void initializeNumpy() { + // Use _import_array() because import_array() is a macro that contains a + // return statement. + if (_import_array() != 0) { + throw std::runtime_error( "initializeNumpy: numpy.core.multiarray failed to import."); - } } } +} // namespace nupic diff --git a/src/nupic/py_support/NumpyArrayObject.hpp b/src/nupic/py_support/NumpyArrayObject.hpp index e3a7ae1806..449a22d52e 100644 --- a/src/nupic/py_support/NumpyArrayObject.hpp +++ b/src/nupic/py_support/NumpyArrayObject.hpp @@ -43,14 +43,14 @@ #undef NO_IMPORT_ARRAY namespace nupic { - /** - * This method needs to be called some time in the process lifetime before - * calling the numpy C APIs. - * - * It initializes the global "NTA_NumpyArray_API" to an array of function - * pointers -- i.e. it dynamically links with the numpy library. - */ - void initializeNumpy(); -}; +/** + * This method needs to be called some time in the process lifetime before + * calling the numpy C APIs. + * + * It initializes the global "NTA_NumpyArray_API" to an array of function + * pointers -- i.e. it dynamically links with the numpy library. 
+ */ +void initializeNumpy(); +}; // namespace nupic #endif // NTA_NUMPY_ARRAY_OBJECT diff --git a/src/nupic/py_support/NumpyVector.cpp b/src/nupic/py_support/NumpyVector.cpp index 6dcf122f0a..ee82bbdf21 100644 --- a/src/nupic/py_support/NumpyVector.cpp +++ b/src/nupic/py_support/NumpyVector.cpp @@ -23,8 +23,6 @@ /** @file */ - - #include // workaround for change in numpy config.h for python2.5 on windows @@ -39,8 +37,8 @@ #include -#include #include +#include using namespace std; using namespace nupic; @@ -49,22 +47,25 @@ using namespace nupic; // Auto-convert a compile-time type to a Numpy dtype. // -------------------------------------------------------------- -template -class NumpyDTypeTraits {}; +template class NumpyDTypeTraits {}; -template -int LookupNumpyDTypeT(const T *) - { return NumpyDTypeTraits::numpyDType; } +template int LookupNumpyDTypeT(const T *) { + return NumpyDTypeTraits::numpyDType; +} -#define NTA_DEF_NUMPY_DTYPE_TRAIT(a, b) \ -template<> class NumpyDTypeTraits { public: enum { numpyDType=b }; }; \ -int nupic::LookupNumpyDType(const a *p) { return LookupNumpyDTypeT(p); } +#define NTA_DEF_NUMPY_DTYPE_TRAIT(a, b) \ + template <> class NumpyDTypeTraits { \ + public: \ + enum { numpyDType = b }; \ + }; \ + int nupic::LookupNumpyDType(const a *p) { return LookupNumpyDTypeT(p); } NTA_DEF_NUMPY_DTYPE_TRAIT(nupic::Byte, NPY_BYTE); NTA_DEF_NUMPY_DTYPE_TRAIT(nupic::Int16, NPY_INT16); NTA_DEF_NUMPY_DTYPE_TRAIT(nupic::UInt16, NPY_UINT16); -#if defined(NTA_ARCH_64) && (defined(NTA_OS_LINUX) || defined(NTA_OS_DARWIN) || defined(NTA_OS_SPARC)) +#if defined(NTA_ARCH_64) && \ + (defined(NTA_OS_LINUX) || defined(NTA_OS_DARWIN) || defined(NTA_OS_SPARC)) NTA_DEF_NUMPY_DTYPE_TRAIT(size_t, NPY_UINT64); #else NTA_DEF_NUMPY_DTYPE_TRAIT(size_t, NPY_UINT32); @@ -79,23 +80,24 @@ NTA_DEF_NUMPY_DTYPE_TRAIT(nupic::UInt32, NPY_UINT32); NTA_DEF_NUMPY_DTYPE_TRAIT(nupic::Int64, NPY_INT64); -#if (!(defined(NTA_ARCH_64) && (defined(NTA_OS_LINUX) || defined(NTA_OS_DARWIN) || defined(NTA_OS_SPARC))) && !defined(NTA_OS_WINDOWS)) +#if (!(defined(NTA_ARCH_64) && \ + (defined(NTA_OS_LINUX) || defined(NTA_OS_DARWIN) || \ + defined(NTA_OS_SPARC))) && \ + !defined(NTA_OS_WINDOWS)) NTA_DEF_NUMPY_DTYPE_TRAIT(nupic::UInt64, NPY_UINT64); #endif - NTA_DEF_NUMPY_DTYPE_TRAIT(nupic::Real32, NPY_FLOAT32); NTA_DEF_NUMPY_DTYPE_TRAIT(nupic::Real64, NPY_FLOAT64); // -------------------------------------------------------------- NumpyArray::NumpyArray(int nd, const int *ndims, int dtype) - : p_(0), dtype_(dtype) -{ + : p_(0), dtype_(dtype) { // declare static to avoid new/delete with every call static npy_intp ndims_intp[NPY_MAXDIMS]; - if(nd < 0) + if (nd < 0) throw runtime_error("Negative dimensioned arrays not supported."); if (nd > NPY_MAXDIMS) @@ -105,86 +107,88 @@ NumpyArray::NumpyArray(int nd, const int *ndims, int dtype) * npy_intp is an integer that can hold a pointer. On 64-bit * systems this is not the same as an int. 
*/ - for (int i = 0; i < nd; i++) - { + for (int i = 0; i < nd; i++) { ndims_intp[i] = (npy_intp)ndims[i]; } - p_ = (PyArrayObject *) PyArray_SimpleNew(nd, ndims_intp, dtype); - + p_ = (PyArrayObject *)PyArray_SimpleNew(nd, ndims_intp, dtype); } NumpyArray::NumpyArray(PyObject *p, int dtype, int requiredDimension) - : p_(0), dtype_(dtype) -{ + : p_(0), dtype_(dtype) { PyObject *contiguous = PyArray_ContiguousFromObject(p, NPY_NOTYPE, 0, 0); - if(!contiguous) + if (!contiguous) throw std::runtime_error("Array could not be made contiguous."); - if(!PyArray_Check(contiguous)) + if (!PyArray_Check(contiguous)) throw std::logic_error("Failed to convert to array."); - PyObject *casted = PyArray_Cast((PyArrayObject *) contiguous, dtype); + PyObject *casted = PyArray_Cast((PyArrayObject *)contiguous, dtype); Py_CLEAR(contiguous); - if(!casted) throw std::runtime_error("Array could not be cast to requested type."); - if(!PyArray_Check(casted)) throw std::logic_error("Array is not contiguous."); - PyArrayObject *final = (PyArrayObject *) casted; - if((requiredDimension != 0) && (PyArray_NDIM(final) != requiredDimension)) + if (!casted) + throw std::runtime_error("Array could not be cast to requested type."); + if (!PyArray_Check(casted)) + throw std::logic_error("Array is not contiguous."); + PyArrayObject *final = (PyArrayObject *)casted; + if ((requiredDimension != 0) && (PyArray_NDIM(final) != requiredDimension)) throw std::runtime_error("Array is not of the required dimension."); p_ = final; } -NumpyArray::~NumpyArray() -{ - PyObject *generic = (PyObject *) p_; +NumpyArray::~NumpyArray() { + PyObject *generic = (PyObject *)p_; p_ = 0; Py_CLEAR(generic); } -int NumpyArray::getRank() const -{ - if(!p_) throw runtime_error("Null NumpyArray."); +int NumpyArray::getRank() const { + if (!p_) + throw runtime_error("Null NumpyArray."); return PyArray_NDIM(p_); } -int NumpyArray::dimension(int i) const -{ - if(!p_) throw runtime_error("Null NumpyArray."); - if(i < 0) throw runtime_error("Negative dimension requested."); - if(i >= PyArray_NDIM(p_)) throw out_of_range("Dimension exceeds number available."); +int NumpyArray::dimension(int i) const { + if (!p_) + throw runtime_error("Null NumpyArray."); + if (i < 0) + throw runtime_error("Negative dimension requested."); + if (i >= PyArray_NDIM(p_)) + throw out_of_range("Dimension exceeds number available."); return int(PyArray_DIMS(p_)[i]); } -void NumpyArray::getDims(int *out) const -{ - if(!p_) throw runtime_error("Null NumpyArray."); +void NumpyArray::getDims(int *out) const { + if (!p_) + throw runtime_error("Null NumpyArray."); int n = PyArray_NDIM(p_); - for(int i=0; i // For std::copy. +#include // for 'type_id' #include #include // For nupic::Real. -#include // For NTA_ASSERT -#include // For std::copy. -#include // for 'type_id' +#include // For NTA_ASSERT namespace nupic { - extern int LookupNumpyDType(const size_t *); - extern int LookupNumpyDType(const nupic::Byte *); - extern int LookupNumpyDType(const nupic::Int16 *); - extern int LookupNumpyDType(const nupic::UInt16 *); - extern int LookupNumpyDType(const nupic::Int32 *); - extern int LookupNumpyDType(const nupic::UInt32 *); - extern int LookupNumpyDType(const nupic::Int64 *); - extern int LookupNumpyDType(const nupic::UInt64 *); - extern int LookupNumpyDType(const nupic::Real32 *); - extern int LookupNumpyDType(const nupic::Real64 *); - /** - * Concrete Numpy multi-d array wrapper whose implementation cannot be visible - * due to the specifics of dynamically loading the Numpy C function API. 
- */ - class NumpyArray - { - - NumpyArray(const NumpyArray &); // Verboten. - NumpyArray &operator=(const NumpyArray &); // Verboten. - - protected: - PyArrayObject *p_; - int dtype_; - - const char *addressOf0() const; - char *addressOf0(); - int stride(int i) const; - - NumpyArray(int nd, const int *dims, int dtype); - NumpyArray(PyObject *p, int dtype, int requiredDimension=0); - - public: - - /////////////////////////////////////////////////////////// - /// Destructor. - /// - /// Releases the reference to the internal numpy array. - /////////////////////////////////////////////////////////// - virtual ~NumpyArray(); - - /////////////////////////////////////////////////////////// - /// The number of dimensions of the internal numpy array. - /// - /// Will always be 1, as enforced by the constructors. - /////////////////////////////////////////////////////////// - int numDimensions() const { return getRank(); } - - int getRank() const; - - /////////////////////////////////////////////////////////// - /// Gets the size of the array along dimension i. - /// - /// Does not check the validity of the passed-in dimension. - /////////////////////////////////////////////////////////// - int dimension(int i) const; - - - void getDims(int *) const; - - /////////////////////////////////////////////////////////// - /// Gets the size of the array (along dimension 0). - /////////////////////////////////////////////////////////// - int size() const { return dimension(0); } - - /////////////////////////////////////////////////////////// - /// Returns a PyObject that can be returned from C code to Python. - /// - /// The PyObject returned is a new reference, and the caller must - /// dereference the object when done. - /// The PyObject is produced by PyArray_Return (whatever that does). - /////////////////////////////////////////////////////////// - PyObject *forPython(); - - }; +extern int LookupNumpyDType(const size_t *); +extern int LookupNumpyDType(const nupic::Byte *); +extern int LookupNumpyDType(const nupic::Int16 *); +extern int LookupNumpyDType(const nupic::UInt16 *); +extern int LookupNumpyDType(const nupic::Int32 *); +extern int LookupNumpyDType(const nupic::UInt32 *); +extern int LookupNumpyDType(const nupic::Int64 *); +extern int LookupNumpyDType(const nupic::UInt64 *); +extern int LookupNumpyDType(const nupic::Real32 *); +extern int LookupNumpyDType(const nupic::Real64 *); +/** + * Concrete Numpy multi-d array wrapper whose implementation cannot be visible + * due to the specifics of dynamically loading the Numpy C function API. + */ +class NumpyArray { + + NumpyArray(const NumpyArray &); // Verboten. + NumpyArray &operator=(const NumpyArray &); // Verboten. + +protected: + PyArrayObject *p_; + int dtype_; + + const char *addressOf0() const; + char *addressOf0(); + int stride(int i) const; + + NumpyArray(int nd, const int *dims, int dtype); + NumpyArray(PyObject *p, int dtype, int requiredDimension = 0); + +public: + /////////////////////////////////////////////////////////// + /// Destructor. + /// + /// Releases the reference to the internal numpy array. + /////////////////////////////////////////////////////////// + virtual ~NumpyArray(); /////////////////////////////////////////////////////////// - /// A wrapper for 1D numpy arrays of data type equaivalent to nupic::Real. + /// The number of dimensions of the internal numpy array. /// - /// Numpy is a Python extension written in C. - /// Accessing numpy's C API directly is tricky but possible. 
- /// Such access can be performed with SWIG typemaps, - /// using a slow and feature-poor set of SWIG typemap definitions - /// provided as an example with the numpy documentation. - /// This class bypasses that method of access, in favor of - /// a faster interface. + /// Will always be 1, as enforced by the constructors. + /////////////////////////////////////////////////////////// + int numDimensions() const { return getRank(); } + + int getRank() const; + + /////////////////////////////////////////////////////////// + /// Gets the size of the array along dimension i. /// - /// This wrapper should only be used within Python bindings, - /// as numpy data structures will only be passed in from Python code. - /// For an example of its use, see the nupic::SparseMatrix Python bindings - /// in nupic/python/bindings/math/SparseMatrix.i + /// Does not check the validity of the passed-in dimension. /////////////////////////////////////////////////////////// - template - class NumpyVectorT : public NumpyArray - { - - NumpyVectorT(const NumpyVectorT &); // Verboten. - NumpyVectorT &operator=(const NumpyVectorT &); // Verboten. - - public: - - /////////////////////////////////////////////////////////// - /// Create a new 1D numpy array of size n. - /////////////////////////////////////////////////////////// - NumpyVectorT(int n, const T& val=0) - : NumpyArray(1, &n, LookupNumpyDType((const T *) 0)) - { - std::fill(begin(), end(), val); - } + int dimension(int i) const; - NumpyVectorT(int n, const T *val) - : NumpyArray(1, &n, LookupNumpyDType((const T *) 0)) - { - if(val) std::copy(val, val+n, begin()); - } + void getDims(int *) const; - /////////////////////////////////////////////////////////// - /// Reference an existing 1D numpy array, or copy it if - /// it differs in type. - /// - /// Produces a really annoying warning if this will do a slow copy. - /// Do not use in this case. Make sure the data coming in is in - /// the appropriate format (1D contiguous numpy array of type - /// equivalent to nupic::Real). If nupic::Real is float, - /// the incoming array should have been created with dtype=numpy.float32 - /////////////////////////////////////////////////////////// - NumpyVectorT(PyObject *p) - : NumpyArray(p, LookupNumpyDType((const T *) 0), 1) - {} - - virtual ~NumpyVectorT() {} - - T* begin() { return addressOf(0); } - T* end() { return begin() + size(); } - const T* begin() const { return addressOf(0); } - const T* end() const { return begin() + size(); } - - /////////////////////////////////////////////////////////// - /// Get a pointer to element i. - /// - /// Does not check the validity of the index. - /////////////////////////////////////////////////////////// - const T *addressOf(int i) const - { return (const T *) (addressOf0() + i*stride(0)); } - - /////////////////////////////////////////////////////////// - /// Get a non-const pointer to element i. - /// - /// Does not check the validity of the index. - /////////////////////////////////////////////////////////// - T *addressOf(int i) - { return (T *) (addressOf0() + i*stride(0)); } - - /////////////////////////////////////////////////////////// - /// Get the increment (in number of Reals) from one element - /// to the next. 
- /////////////////////////////////////////////////////////// - int incr() const { return int(addressOf(1) - addressOf(0)); } - - inline T& get(int i) { return *addressOf(i); } - inline T get(int i) const { return *addressOf(i); } - inline void set(int i, const T& val) { *addressOf(i) = val; } - }; - - //-------------------------------------------------------------------------------- - template - class NumpyMatrixT : public NumpyArray - { - NumpyMatrixT(const NumpyMatrixT &); // Verboten. - NumpyMatrixT &operator=(const NumpyMatrixT &); // Verboten. - - public: - - typedef int size_type; - - /////////////////////////////////////////////////////////// - /// Create a new 2D numpy array of size n. - /////////////////////////////////////////////////////////// - NumpyMatrixT(const int nRowsCols[2]) - : NumpyArray(2, nRowsCols, LookupNumpyDType((const T *) 0)) - {} - - NumpyMatrixT(PyObject *p) - : NumpyArray(p, LookupNumpyDType((const T *) 0), 2) - {} - - /////////////////////////////////////////////////////////// - /// Destructor. - /// - /// Releases the reference to the internal numpy array. - /////////////////////////////////////////////////////////// - virtual ~NumpyMatrixT() {} - - int rows() const { return dimension(0); } - int columns() const { return dimension(1); } - int nRows() const { return dimension(0); } - int nCols() const { return dimension(1); } - - inline const T *addressOf(int row, int col) const - { return (const T *) (addressOf0() + row*stride(0) + col*stride(1)); } - - inline T *addressOf(int row, int col) - { return (T *) (addressOf0() + row*stride(0) + col*stride(1)); } - - inline const T* begin(int row) const - { return (const T*)(addressOf0() + row*stride(0)); } - - inline const T* end(int row) const - { return (const T*)(addressOf0() + row*stride(0) + nCols()*stride(1)); } - - inline T* begin(int row) - { return (T*)(addressOf0() + row*stride(0)); } - - inline T* end(int row) - { return (T*)(addressOf0() + row*stride(0) + nCols()*stride(1)); } - - inline T& get(int i, int j) { return *addressOf(i,j); } - inline T get(int i, int j) const { return *addressOf(i,j); } - inline void set(int i, int j, const T& val) { *addressOf(i,j) = val; } - }; - - template - class NumpyNDArrayT : public NumpyArray - { - NumpyNDArrayT(const NumpyNDArrayT &); // Verboten. - NumpyNDArrayT &operator=(const NumpyNDArrayT &); // Verboten. - - public: - NumpyNDArrayT(PyObject *p) - : NumpyArray(p, LookupNumpyDType((const T *) 0)) - {} - NumpyNDArrayT(int rank, const int *dims) - : NumpyArray(rank, dims, LookupNumpyDType((const T *) 0)) - {} - virtual ~NumpyNDArrayT() {} - - const T *getData() const { return (const T *) addressOf0(); } - T *getData() { return (T *) addressOf0(); } - }; - - //-------------------------------------------------------------------------------- - typedef NumpyVectorT<> NumpyVector; - typedef NumpyMatrixT<> NumpyMatrix; - typedef NumpyNDArrayT<> NumpyNDArray; - - //-------------------------------------------------------------------------------- - template - inline T convertToValueType(PyObject *val) - { - return * nupic::NumpyNDArrayT(val).getData(); + /////////////////////////////////////////////////////////// + /// Gets the size of the array (along dimension 0). + /////////////////////////////////////////////////////////// + int size() const { return dimension(0); } + + /////////////////////////////////////////////////////////// + /// Returns a PyObject that can be returned from C code to Python. 
+ /// + /// The PyObject returned is a new reference, and the caller must + /// dereference the object when done. + /// The PyObject is produced by PyArray_Return (whatever that does). + /////////////////////////////////////////////////////////// + PyObject *forPython(); +}; + +/////////////////////////////////////////////////////////// +/// A wrapper for 1D numpy arrays of data type equaivalent to nupic::Real. +/// +/// Numpy is a Python extension written in C. +/// Accessing numpy's C API directly is tricky but possible. +/// Such access can be performed with SWIG typemaps, +/// using a slow and feature-poor set of SWIG typemap definitions +/// provided as an example with the numpy documentation. +/// This class bypasses that method of access, in favor of +/// a faster interface. +/// +/// This wrapper should only be used within Python bindings, +/// as numpy data structures will only be passed in from Python code. +/// For an example of its use, see the nupic::SparseMatrix Python bindings +/// in nupic/python/bindings/math/SparseMatrix.i +/////////////////////////////////////////////////////////// +template class NumpyVectorT : public NumpyArray { + + NumpyVectorT(const NumpyVectorT &); // Verboten. + NumpyVectorT &operator=(const NumpyVectorT &); // Verboten. + +public: + /////////////////////////////////////////////////////////// + /// Create a new 1D numpy array of size n. + /////////////////////////////////////////////////////////// + NumpyVectorT(int n, const T &val = 0) + : NumpyArray(1, &n, LookupNumpyDType((const T *)0)) { + std::fill(begin(), end(), val); } - //-------------------------------------------------------------------------------- - template - inline PyObject* convertFromValueType(const T& value) { - nupic::NumpyNDArrayT ret(0, NULL); - *ret.getData() = value; - return ret.forPython(); + NumpyVectorT(int n, const T *val) + : NumpyArray(1, &n, LookupNumpyDType((const T *)0)) { + if (val) + std::copy(val, val + n, begin()); } - //-------------------------------------------------------------------------------- - template - inline PyObject* convertToPairOfLists(I i_begin, I i_end, T val) - { - const size_t n = (size_t) (i_end - i_begin); - - PyObject *indOut = PyTuple_New(n); - // Steals the new references. - for (size_t i = 0; i != n; ++i, ++i_begin) - PyTuple_SET_ITEM(indOut, i, PyInt_FromLong(*i_begin)); - - PyObject *valOut = PyTuple_New(n); - // Steals the new references. - for (size_t i = 0; i != n; ++i, ++val) - PyTuple_SET_ITEM(valOut, i, PyFloat_FromDouble(*val)); - - PyObject *toReturn = PyTuple_New(2); - // Steals the index tuple reference. - PyTuple_SET_ITEM(toReturn, 0, indOut); - // Steals the index tuple reference. - PyTuple_SET_ITEM(toReturn, 1, valOut); - - // Returns a single new reference. - return toReturn; + /////////////////////////////////////////////////////////// + /// Reference an existing 1D numpy array, or copy it if + /// it differs in type. + /// + /// Produces a really annoying warning if this will do a slow copy. + /// Do not use in this case. Make sure the data coming in is in + /// the appropriate format (1D contiguous numpy array of type + /// equivalent to nupic::Real). 
If nupic::Real is float, + /// the incoming array should have been created with dtype=numpy.float32 + /////////////////////////////////////////////////////////// + NumpyVectorT(PyObject *p) + : NumpyArray(p, LookupNumpyDType((const T *)0), 1) {} + + virtual ~NumpyVectorT() {} + + T *begin() { return addressOf(0); } + T *end() { return begin() + size(); } + const T *begin() const { return addressOf(0); } + const T *end() const { return begin() + size(); } + + /////////////////////////////////////////////////////////// + /// Get a pointer to element i. + /// + /// Does not check the validity of the index. + /////////////////////////////////////////////////////////// + const T *addressOf(int i) const { + return (const T *)(addressOf0() + i * stride(0)); } - //-------------------------------------------------------------------------------- - template - inline PyObject* createPair32(I i, T v) - { - PyObject *result = PyTuple_New(2); - PyTuple_SET_ITEM(result, 0, PyInt_FromLong(i)); - PyTuple_SET_ITEM(result, 1, PyFloat_FromDouble(v)); - return result; + /////////////////////////////////////////////////////////// + /// Get a non-const pointer to element i. + /// + /// Does not check the validity of the index. + /////////////////////////////////////////////////////////// + T *addressOf(int i) { return (T *)(addressOf0() + i * stride(0)); } + + /////////////////////////////////////////////////////////// + /// Get the increment (in number of Reals) from one element + /// to the next. + /////////////////////////////////////////////////////////// + int incr() const { return int(addressOf(1) - addressOf(0)); } + + inline T &get(int i) { return *addressOf(i); } + inline T get(int i) const { return *addressOf(i); } + inline void set(int i, const T &val) { *addressOf(i) = val; } +}; + +//-------------------------------------------------------------------------------- +template class NumpyMatrixT : public NumpyArray { + NumpyMatrixT(const NumpyMatrixT &); // Verboten. + NumpyMatrixT &operator=(const NumpyMatrixT &); // Verboten. + +public: + typedef int size_type; + + /////////////////////////////////////////////////////////// + /// Create a new 2D numpy array of size n. + /////////////////////////////////////////////////////////// + NumpyMatrixT(const int nRowsCols[2]) + : NumpyArray(2, nRowsCols, LookupNumpyDType((const T *)0)) {} + + NumpyMatrixT(PyObject *p) + : NumpyArray(p, LookupNumpyDType((const T *)0), 2) {} + + /////////////////////////////////////////////////////////// + /// Destructor. + /// + /// Releases the reference to the internal numpy array. 
+ /////////////////////////////////////////////////////////// + virtual ~NumpyMatrixT() {} + + int rows() const { return dimension(0); } + int columns() const { return dimension(1); } + int nRows() const { return dimension(0); } + int nCols() const { return dimension(1); } + + inline const T *addressOf(int row, int col) const { + return (const T *)(addressOf0() + row * stride(0) + col * stride(1)); } - //-------------------------------------------------------------------------------- - template - inline PyObject* createPair64(I i, T v) - { - PyObject *result = PyTuple_New(2); - PyTuple_SET_ITEM(result, 0, PyLong_FromLongLong(i)); - PyTuple_SET_ITEM(result, 1, PyFloat_FromDouble(v)); - return result; + inline T *addressOf(int row, int col) { + return (T *)(addressOf0() + row * stride(0) + col * stride(1)); } - //-------------------------------------------------------------------------------- - template - inline PyObject* createTriplet32(I i1, I i2, T v1) - { - PyObject *result = PyTuple_New(3); - PyTuple_SET_ITEM(result, 0, PyInt_FromLong(i1)); - PyTuple_SET_ITEM(result, 1, PyInt_FromLong(i2)); - PyTuple_SET_ITEM(result, 2, PyFloat_FromDouble(v1)); - return result; + inline const T *begin(int row) const { + return (const T *)(addressOf0() + row * stride(0)); } - //-------------------------------------------------------------------------------- - template - inline PyObject* createTriplet64(I i1, I i2, T v1) - { - PyObject *result = PyTuple_New(3); - PyTuple_SET_ITEM(result, 0, PyLong_FromLongLong(i1)); - PyTuple_SET_ITEM(result, 1, PyLong_FromLongLong(i2)); - PyTuple_SET_ITEM(result, 2, PyFloat_FromDouble(v1)); - return result; + inline const T *end(int row) const { + return (const T *)(addressOf0() + row * stride(0) + nCols() * stride(1)); } - //-------------------------------------------------------------------------------- - template - PyObject* PyInt32Vector(TIter begin, TIter end) - { - Py_ssize_t n = end - begin; - PyObject *p = PyTuple_New(n); - Py_ssize_t i = 0; - for (TIter cur=begin; cur!=end; ++cur, ++i) { - PyTuple_SET_ITEM(p, i, PyInt_FromLong(*cur)); - } + inline T *begin(int row) { return (T *)(addressOf0() + row * stride(0)); } - return p; + inline T *end(int row) { + return (T *)(addressOf0() + row * stride(0) + nCols() * stride(1)); } - //-------------------------------------------------------------------------------- - template - PyObject* PyInt64Vector(TIter begin, TIter end) - { - Py_ssize_t n = end - begin; - PyObject *p = PyTuple_New(n); - Py_ssize_t i = 0; - for (TIter cur=begin; cur!=end; ++cur, ++i) { - PyTuple_SET_ITEM(p, i, PyLong_FromLongLong(*cur)); - } + inline T &get(int i, int j) { return *addressOf(i, j); } + inline T get(int i, int j) const { return *addressOf(i, j); } + inline void set(int i, int j, const T &val) { *addressOf(i, j) = val; } +}; + +template class NumpyNDArrayT : public NumpyArray { + NumpyNDArrayT(const NumpyNDArrayT &); // Verboten. + NumpyNDArrayT &operator=(const NumpyNDArrayT &); // Verboten. 
+ +public: + NumpyNDArrayT(PyObject *p) : NumpyArray(p, LookupNumpyDType((const T *)0)) {} + NumpyNDArrayT(int rank, const int *dims) + : NumpyArray(rank, dims, LookupNumpyDType((const T *)0)) {} + virtual ~NumpyNDArrayT() {} + + const T *getData() const { return (const T *)addressOf0(); } + T *getData() { return (T *)addressOf0(); } +}; + +//-------------------------------------------------------------------------------- +typedef NumpyVectorT<> NumpyVector; +typedef NumpyMatrixT<> NumpyMatrix; +typedef NumpyNDArrayT<> NumpyNDArray; + +//-------------------------------------------------------------------------------- +template inline T convertToValueType(PyObject *val) { + return *nupic::NumpyNDArrayT(val).getData(); +} + +//-------------------------------------------------------------------------------- +template inline PyObject *convertFromValueType(const T &value) { + nupic::NumpyNDArrayT ret(0, NULL); + *ret.getData() = value; + return ret.forPython(); +} + +//-------------------------------------------------------------------------------- +template +inline PyObject *convertToPairOfLists(I i_begin, I i_end, T val) { + const size_t n = (size_t)(i_end - i_begin); + + PyObject *indOut = PyTuple_New(n); + // Steals the new references. + for (size_t i = 0; i != n; ++i, ++i_begin) + PyTuple_SET_ITEM(indOut, i, PyInt_FromLong(*i_begin)); + + PyObject *valOut = PyTuple_New(n); + // Steals the new references. + for (size_t i = 0; i != n; ++i, ++val) + PyTuple_SET_ITEM(valOut, i, PyFloat_FromDouble(*val)); + + PyObject *toReturn = PyTuple_New(2); + // Steals the index tuple reference. + PyTuple_SET_ITEM(toReturn, 0, indOut); + // Steals the index tuple reference. + PyTuple_SET_ITEM(toReturn, 1, valOut); + + // Returns a single new reference. + return toReturn; +} + +//-------------------------------------------------------------------------------- +template inline PyObject *createPair32(I i, T v) { + PyObject *result = PyTuple_New(2); + PyTuple_SET_ITEM(result, 0, PyInt_FromLong(i)); + PyTuple_SET_ITEM(result, 1, PyFloat_FromDouble(v)); + return result; +} + +//-------------------------------------------------------------------------------- +template inline PyObject *createPair64(I i, T v) { + PyObject *result = PyTuple_New(2); + PyTuple_SET_ITEM(result, 0, PyLong_FromLongLong(i)); + PyTuple_SET_ITEM(result, 1, PyFloat_FromDouble(v)); + return result; +} + +//-------------------------------------------------------------------------------- +template +inline PyObject *createTriplet32(I i1, I i2, T v1) { + PyObject *result = PyTuple_New(3); + PyTuple_SET_ITEM(result, 0, PyInt_FromLong(i1)); + PyTuple_SET_ITEM(result, 1, PyInt_FromLong(i2)); + PyTuple_SET_ITEM(result, 2, PyFloat_FromDouble(v1)); + return result; +} + +//-------------------------------------------------------------------------------- +template +inline PyObject *createTriplet64(I i1, I i2, T v1) { + PyObject *result = PyTuple_New(3); + PyTuple_SET_ITEM(result, 0, PyLong_FromLongLong(i1)); + PyTuple_SET_ITEM(result, 1, PyLong_FromLongLong(i2)); + PyTuple_SET_ITEM(result, 2, PyFloat_FromDouble(v1)); + return result; +} + +//-------------------------------------------------------------------------------- +template PyObject *PyInt32Vector(TIter begin, TIter end) { + Py_ssize_t n = end - begin; + PyObject *p = PyTuple_New(n); + Py_ssize_t i = 0; + for (TIter cur = begin; cur != end; ++cur, ++i) { + PyTuple_SET_ITEM(p, i, PyInt_FromLong(*cur)); + } + + return p; +} - return p; 
+//-------------------------------------------------------------------------------- +template PyObject *PyInt64Vector(TIter begin, TIter end) { + Py_ssize_t n = end - begin; + PyObject *p = PyTuple_New(n); + Py_ssize_t i = 0; + for (TIter cur = begin; cur != end; ++cur, ++i) { + PyTuple_SET_ITEM(p, i, PyLong_FromLongLong(*cur)); } - //-------------------------------------------------------------------------------- - template - PyObject* PyFloatVector(TIter begin, TIter end) - { - Py_ssize_t n = end - begin; - PyObject *p = PyTuple_New(n); - Py_ssize_t i = 0; - for (TIter cur=begin; cur!=end; ++cur, ++i) { - PyTuple_SET_ITEM(p, i, PyFloat_FromDouble(*cur)); - } + return p; +} - return p; +//-------------------------------------------------------------------------------- +template PyObject *PyFloatVector(TIter begin, TIter end) { + Py_ssize_t n = end - begin; + PyObject *p = PyTuple_New(n); + Py_ssize_t i = 0; + for (TIter cur = begin; cur != end; ++cur, ++i) { + PyTuple_SET_ITEM(p, i, PyFloat_FromDouble(*cur)); } + return p; +} - /** - * Extract a 1D Numpy array's buffer. - */ - template - class NumpyVectorWeakRefT - { - public: - NumpyVectorWeakRefT(PyObject* pyArray) - : pyArray_((PyArrayObject*)pyArray) - { - NTA_ASSERT(PyArray_NDIM(pyArray_) == 1); - NTA_ASSERT(PyArray_EquivTypenums( - PyArray_TYPE(pyArray_), LookupNumpyDType((const T *) 0))); - } +/** + * Extract a 1D Numpy array's buffer. + */ +template class NumpyVectorWeakRefT { +public: + NumpyVectorWeakRefT(PyObject *pyArray) : pyArray_((PyArrayObject *)pyArray) { + NTA_ASSERT(PyArray_NDIM(pyArray_) == 1); + NTA_ASSERT(PyArray_EquivTypenums(PyArray_TYPE(pyArray_), + LookupNumpyDType((const T *)0))); + } - T* begin() const - { - return (T*) PyArray_DATA(pyArray_); - } + T *begin() const { return (T *)PyArray_DATA(pyArray_); } - T* end() const - { - return (T*) PyArray_DATA(pyArray_) + size(); - } + T *end() const { return (T *)PyArray_DATA(pyArray_) + size(); } - size_t size() const - { - return PyArray_DIMS(pyArray_)[0]; - } + size_t size() const { return PyArray_DIMS(pyArray_)[0]; } + +protected: + PyArrayObject *pyArray_; +}; - protected: - PyArrayObject* pyArray_; - }; - - /** - * Similar to NumpyVectorWeakRefT but also provides extra type checking - */ - template - class CheckedNumpyVectorWeakRefT : public NumpyVectorWeakRefT - { - public: - CheckedNumpyVectorWeakRefT(PyObject* pyArray) - : NumpyVectorWeakRefT(pyArray) - { - if (PyArray_NDIM(this->pyArray_) != 1) - { - NTA_THROW << "Expecting 1D array " - << "but got " << PyArray_NDIM(this->pyArray_) << "D array"; - } - if (!PyArray_EquivTypenums( - PyArray_TYPE(this->pyArray_), LookupNumpyDType((const T *) 0))) - { - boost::typeindex::stl_type_index expectedType = - boost::typeindex::stl_type_index::type_id(); - NTA_THROW << "Expecting '" << expectedType.pretty_name() << "' " - << "but got '" << PyArray_DTYPE(this->pyArray_)->type << "'"; - } - } - }; +/** + * Similar to NumpyVectorWeakRefT but also provides extra type checking + */ +template +class CheckedNumpyVectorWeakRefT : public NumpyVectorWeakRefT { +public: + CheckedNumpyVectorWeakRefT(PyObject *pyArray) + : NumpyVectorWeakRefT(pyArray) { + if (PyArray_NDIM(this->pyArray_) != 1) { + NTA_THROW << "Expecting 1D array " + << "but got " << PyArray_NDIM(this->pyArray_) << "D array"; + } + if (!PyArray_EquivTypenums(PyArray_TYPE(this->pyArray_), + LookupNumpyDType((const T *)0))) { + boost::typeindex::stl_type_index expectedType = + boost::typeindex::stl_type_index::type_id(); + NTA_THROW << "Expecting '" << 
expectedType.pretty_name() << "' " + << "but got '" << PyArray_DTYPE(this->pyArray_)->type << "'"; + } + } +}; } // End namespace nupic. #endif diff --git a/src/nupic/py_support/PyArray.cpp b/src/nupic/py_support/PyArray.cpp index e2ad454761..a961cf0d42 100644 --- a/src/nupic/py_support/PyArray.cpp +++ b/src/nupic/py_support/PyArray.cpp @@ -22,7 +22,7 @@ // The Python.h #include MUST always be #included first in every // compilation unit (.c or .cpp file). That means that PyHelpers.hpp -// must be #included first and transitively every .hpp file that +// must be #included first and transitively every .hpp file that // #includes directly or indirectly PyHelpers.hpp must be #included // first. #include @@ -32,88 +32,82 @@ #include #include -#include #include +#include -namespace nupic -{ - // ------------------------------------- - // - // G E T B A S I C T Y P E - // - // ------------------------------------- +namespace nupic { +// ------------------------------------- +// +// G E T B A S I C T Y P E +// +// ------------------------------------- - NTA_BasicType getBasicType(NTA_Byte) { return NTA_BasicType_Byte; } - NTA_BasicType getBasicType(NTA_Int16) { return NTA_BasicType_Int16; } - NTA_BasicType getBasicType(NTA_UInt16) { return NTA_BasicType_UInt16; } - NTA_BasicType getBasicType(NTA_Int32) { return NTA_BasicType_Int32; } - NTA_BasicType getBasicType(NTA_UInt32) { return NTA_BasicType_UInt32; } - NTA_BasicType getBasicType(NTA_Int64) { return NTA_BasicType_Int64; } - NTA_BasicType getBasicType(NTA_UInt64) { return NTA_BasicType_UInt64; } - NTA_BasicType getBasicType(NTA_Real32) { return NTA_BasicType_Real32; } - NTA_BasicType getBasicType(NTA_Real64) { return NTA_BasicType_Real64; } - NTA_BasicType getBasicType(bool) { return NTA_BasicType_Bool; } +NTA_BasicType getBasicType(NTA_Byte) { return NTA_BasicType_Byte; } +NTA_BasicType getBasicType(NTA_Int16) { return NTA_BasicType_Int16; } +NTA_BasicType getBasicType(NTA_UInt16) { return NTA_BasicType_UInt16; } +NTA_BasicType getBasicType(NTA_Int32) { return NTA_BasicType_Int32; } +NTA_BasicType getBasicType(NTA_UInt32) { return NTA_BasicType_UInt32; } +NTA_BasicType getBasicType(NTA_Int64) { return NTA_BasicType_Int64; } +NTA_BasicType getBasicType(NTA_UInt64) { return NTA_BasicType_UInt64; } +NTA_BasicType getBasicType(NTA_Real32) { return NTA_BasicType_Real32; } +NTA_BasicType getBasicType(NTA_Real64) { return NTA_BasicType_Real64; } +NTA_BasicType getBasicType(bool) { return NTA_BasicType_Bool; } - // ------------------------------------- - // - // A R R A Y 2 N U M P Y - // - // ------------------------------------- - // Wrap an Array object with a numpy array PyObject - PyObject * array2numpy(const ArrayBase & a) - { - npy_intp dims[1]; - dims[0] = npy_intp(a.getCount()); +// ------------------------------------- +// +// A R R A Y 2 N U M P Y +// +// ------------------------------------- +// Wrap an Array object with a numpy array PyObject +PyObject *array2numpy(const ArrayBase &a) { + npy_intp dims[1]; + dims[0] = npy_intp(a.getCount()); - NTA_BasicType t = a.getType(); - int dtype; - switch (t) - { - case NTA_BasicType_Byte: - dtype = NPY_INT8; - break; - case NTA_BasicType_Int16: - dtype = NPY_INT16; - break; - case NTA_BasicType_UInt16: - dtype = NPY_UINT16; - break; - case NTA_BasicType_Int32: - dtype = NPY_INT32; - break; - case NTA_BasicType_UInt32: - dtype = NPY_UINT32; - break; - case NTA_BasicType_Int64: - dtype = NPY_INT64; - break; - case NTA_BasicType_UInt64: - dtype = NPY_UINT64; - break; - case NTA_BasicType_Real32: - 
dtype = NPY_FLOAT32; - break; - case NTA_BasicType_Real64: - dtype = NPY_FLOAT64; - break; - case NTA_BasicType_Bool: - dtype = NPY_BOOL; - break; - default: - NTA_THROW << "Unknown basic type: " << t; - }; + NTA_BasicType t = a.getType(); + int dtype; + switch (t) { + case NTA_BasicType_Byte: + dtype = NPY_INT8; + break; + case NTA_BasicType_Int16: + dtype = NPY_INT16; + break; + case NTA_BasicType_UInt16: + dtype = NPY_UINT16; + break; + case NTA_BasicType_Int32: + dtype = NPY_INT32; + break; + case NTA_BasicType_UInt32: + dtype = NPY_UINT32; + break; + case NTA_BasicType_Int64: + dtype = NPY_INT64; + break; + case NTA_BasicType_UInt64: + dtype = NPY_UINT64; + break; + case NTA_BasicType_Real32: + dtype = NPY_FLOAT32; + break; + case NTA_BasicType_Real64: + dtype = NPY_FLOAT64; + break; + case NTA_BasicType_Bool: + dtype = NPY_BOOL; + break; + default: + NTA_THROW << "Unknown basic type: " << t; + }; - return (PyObject *)PyArray_SimpleNewFromData(1, - dims, - dtype, - a.getBuffer()); - } + return (PyObject *)PyArray_SimpleNewFromData(1, dims, dtype, a.getBuffer()); +} //// ------------------------------------- //// //// P Y A R R A Y B A S E //// -//// ------------------------------------- +//// ------------------------------------- // template // PyArrayBase::PyArrayBase() : A(getType()) // { @@ -128,23 +122,23 @@ namespace nupic // //PyArrayBase::PyArrayBase(A * a) : A(*a) // //{ // //} -// +// // template // NTA_BasicType PyArrayBase::getType() // { // T t = 0; // return getBasicType(t); // } -// +// // template // T PyArrayBase::__getitem__(int i) const -// { +// { // return ((T *)(A::getBuffer()))[i]; // } // // template // void PyArrayBase::__setitem__(int i, T x) -// { +// { // ((T *)(A::getBuffer()))[i] = x; // } // @@ -175,172 +169,136 @@ namespace nupic // PyObject * PyArrayBase::asNumpyArray() const // { // return array2numpy(*this); -// } +// } // ------------------------------------- // // P Y A R R A Y // // ------------------------------------- - template - PyArray::PyArray() : Array(getType()) - //PyArray::PyArray() : PyArrayBase() - { - } - - template - //PyArray::PyArray(size_t count) : PyArrayBase() - PyArray::PyArray(size_t count) : Array(getType()) - { - allocateBuffer(count); - } +template +PyArray::PyArray() + : Array(getType()) +// PyArray::PyArray() : PyArrayBase() +{} - template - NTA_BasicType PyArray::getType() - { - T t = 0; - return getBasicType(t); - } +template +// PyArray::PyArray(size_t count) : PyArrayBase() +PyArray::PyArray(size_t count) : Array(getType()) { + allocateBuffer(count); +} - template - T PyArray::__getitem__(int i) const - { - return ((T *)(getBuffer()))[i]; - //return PyArrayBase::__getitem__(i); - } +template NTA_BasicType PyArray::getType() { + T t = 0; + return getBasicType(t); +} - template - void PyArray::__setitem__(int i, T x) - { - ((T *)(getBuffer()))[i] = x; - //PyArrayBase::__setitem__(i, x); - } +template T PyArray::__getitem__(int i) const { + return ((T *)(getBuffer()))[i]; + // return PyArrayBase::__getitem__(i); +} - template - size_t PyArray::__len__() const - { - return getCount(); - //return PyArrayBase::__len__(); - } +template void PyArray::__setitem__(int i, T x) { + ((T *)(getBuffer()))[i] = x; + // PyArrayBase::__setitem__(i, x); +} - template - std::string PyArray::__repr__() const - { - std::stringstream ss; - ss << "[ "; - for (size_t i = 0; i < __len__(); ++i) - ss << __getitem__(i) << " "; - ss << "]"; - return ss.str(); +template size_t PyArray::__len__() const { + return getCount(); + // return 
PyArrayBase::__len__(); +} - //return PyArrayBase::__repr__(); - } +template std::string PyArray::__repr__() const { + std::stringstream ss; + ss << "[ "; + for (size_t i = 0; i < __len__(); ++i) + ss << __getitem__(i) << " "; + ss << "]"; + return ss.str(); - template - std::string PyArray::__str__() const - { - return __repr__(); - //return PyArrayBase::__str__(); - } + // return PyArrayBase::__repr__(); +} - template - PyObject * PyArray::asNumpyArray() const - { - return array2numpy(*this); - //return PyArrayBase::asNumpyArray(); - } +template std::string PyArray::__str__() const { + return __repr__(); + // return PyArrayBase::__str__(); +} + +template PyObject *PyArray::asNumpyArray() const { + return array2numpy(*this); + // return PyArrayBase::asNumpyArray(); +} // ------------------------------------- // // P Y A R R A Y R E F // // ------------------------------------- - template - PyArrayRef::PyArrayRef() : ArrayRef(getType()) - { - } - - template - PyArrayRef::PyArrayRef(const ArrayRef & a) : ArrayRef(a) - { - } +template PyArrayRef::PyArrayRef() : ArrayRef(getType()) {} - template - NTA_BasicType PyArrayRef::getType() - { - T t = 0; - return getBasicType(t); - } +template +PyArrayRef::PyArrayRef(const ArrayRef &a) : ArrayRef(a) {} - template - T PyArrayRef::__getitem__(int i) const - { - return ((T *)(getBuffer()))[i]; - //return PyArrayBase::__getitem__(i); - } - - template - void PyArrayRef::__setitem__(int i, T x) - { - ((T *)(getBuffer()))[i] = x; - //PyArrayBase::__setitem__(i, x); - } +template NTA_BasicType PyArrayRef::getType() { + T t = 0; + return getBasicType(t); +} - template - size_t PyArrayRef::__len__() const - { - return getCount(); - //return PyArrayBase::__len__(); - } +template T PyArrayRef::__getitem__(int i) const { + return ((T *)(getBuffer()))[i]; + // return PyArrayBase::__getitem__(i); +} - template - std::string PyArrayRef::__repr__() const - { - std::stringstream ss; - ss << "[ "; - for (size_t i = 0; i < __len__(); ++i) - ss << __getitem__(i) << " "; - ss << "]"; - return ss.str(); +template void PyArrayRef::__setitem__(int i, T x) { + ((T *)(getBuffer()))[i] = x; + // PyArrayBase::__setitem__(i, x); +} - //return PyArrayBase::__repr__(); - } +template size_t PyArrayRef::__len__() const { + return getCount(); + // return PyArrayBase::__len__(); +} - template - std::string PyArrayRef::__str__() const - { - return __repr__(); - //return PyArrayBase::__str__(); - } +template std::string PyArrayRef::__repr__() const { + std::stringstream ss; + ss << "[ "; + for (size_t i = 0; i < __len__(); ++i) + ss << __getitem__(i) << " "; + ss << "]"; + return ss.str(); - template - PyObject * PyArrayRef::asNumpyArray() const - { - return array2numpy(*this); - //return PyArrayBase::asNumpyArray(); - } + // return PyArrayBase::__repr__(); +} +template std::string PyArrayRef::__str__() const { + return __repr__(); + // return PyArrayBase::__str__(); +} - template class PyArray; - template class PyArray; - template class PyArray; - template class PyArray; - template class PyArray; - template class PyArray; - template class PyArray; - template class PyArray; - template class PyArray; - template class PyArray; - - template class PyArrayRef; - template class PyArrayRef; - template class PyArrayRef; - template class PyArrayRef; - template class PyArrayRef; - template class PyArrayRef; - template class PyArrayRef; - template class PyArrayRef; - template class PyArrayRef; - template class PyArrayRef; +template PyObject *PyArrayRef::asNumpyArray() const { + return 
array2numpy(*this); + // return PyArrayBase::asNumpyArray(); } +template class PyArray; +template class PyArray; +template class PyArray; +template class PyArray; +template class PyArray; +template class PyArray; +template class PyArray; +template class PyArray; +template class PyArray; +template class PyArray; + +template class PyArrayRef; +template class PyArrayRef; +template class PyArrayRef; +template class PyArrayRef; +template class PyArrayRef; +template class PyArrayRef; +template class PyArrayRef; +template class PyArrayRef; +template class PyArrayRef; +template class PyArrayRef; +} // namespace nupic diff --git a/src/nupic/py_support/PyArray.hpp b/src/nupic/py_support/PyArray.hpp index f96e58713d..fa7c0dd8d5 100644 --- a/src/nupic/py_support/PyArray.hpp +++ b/src/nupic/py_support/PyArray.hpp @@ -20,29 +20,29 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Definitions for the PyArrayBase class - * - * A PyArrayBase object is a Python compatible wrapper around an nupic::Array object - * - * It delegates everything to its Array and exposes a Python facade that - * includes: indexed access with operator[], len() support, An Array contains: - * - a pointer to a buffer - * - a length - * - a type - * - a flag indicating whether or not the object owns the buffer. - */ + * + * A PyArrayBase object is a Python compatible wrapper around an nupic::Array + * object + * + * It delegates everything to its Array and exposes a Python facade that + * includes: indexed access with operator[], len() support, An Array contains: + * - a pointer to a buffer + * - a length + * - a type + * - a flag indicating whether or not the object owns the buffer. + */ #ifndef NTA_PY_ARRAY_HPP #define NTA_PY_ARRAY_HPP -#include -#include #include #include +#include +#include -namespace nupic -{ +namespace nupic { // ------------------------------------- // // G E T B A S I C T Y P E @@ -66,21 +66,21 @@ NTA_BasicType getBasicType(bool); // // ------------------------------------- // Wrap an Array object with a numpy array PyObject -PyObject * array2numpy(const ArrayBase & a); +PyObject *array2numpy(const ArrayBase &a); // ------------------------------------- // // P Y A R R A Y B A S E // // ------------------------------------- -//template -//class PyArrayBase : public A +// template +// class PyArrayBase : public A //{ -//public: +// public: // PyArrayBase(); // //PyArrayBase(A * a); // //PyArrayBase(ArrayBase * a); -// +// // NTA_BasicType getType(); // T __getitem__(int i) const; // void __setitem__(int i, T x); @@ -95,14 +95,14 @@ PyObject * array2numpy(const ArrayBase & a); // P Y A R R A Y // // ------------------------------------- -template -class PyArray : public Array //public PyArrayBase +template +class PyArray : public Array // public PyArrayBase { public: PyArray(); PyArray(size_t count); - //PyArray(Array * a); - //PyArray(ArrayBase * a); + // PyArray(Array * a); + // PyArray(ArrayBase * a); NTA_BasicType getType(); T __getitem__(int i) const; @@ -110,33 +110,28 @@ class PyArray : public Array //public PyArrayBase size_t __len__() const; std::string __repr__() const; std::string __str__() const; - PyObject * asNumpyArray() const; + PyObject *asNumpyArray() const; }; - // ------------------------------------- // // P Y A R R A Y R E F // // ------------------------------------- -template -class PyArrayRef : public ArrayRef -{ +template class PyArrayRef : public ArrayRef { public: PyArrayRef(); - PyArrayRef(const ArrayRef & a); - + PyArrayRef(const 
ArrayRef &a); + NTA_BasicType getType(); T __getitem__(int i) const; void __setitem__(int i, T x); size_t __len__() const; std::string __repr__() const; std::string __str__() const; - PyObject * asNumpyArray() const; + PyObject *asNumpyArray() const; }; - -} +} // namespace nupic #endif - diff --git a/src/nupic/py_support/PyCapnp.hpp b/src/nupic/py_support/PyCapnp.hpp index d46f2b6181..b29897e161 100644 --- a/src/nupic/py_support/PyCapnp.hpp +++ b/src/nupic/py_support/PyCapnp.hpp @@ -27,113 +27,103 @@ #ifndef NTA_PY_CAPNP_HPP #define NTA_PY_CAPNP_HPP -#include // for std::logic_error +#include // for std::logic_error #include #if !CAPNP_LITE - #include - #include - #include - #include +#include +#include +#include +#include #endif // !CAPNP_LITE #include - - -namespace nupic -{ - - class PyCapnpHelper - { - public: - /** - * Serialize the given nupic::Serializable-based instance, returning a capnp - * byte buffer as python byte string. - * - * :param obj: The Serializable object - * - * :returns: capnp byte buffer encoded as python byte string. - * - * :example: PyObject* pyBytes = PyCapnpHelper::writeAsPyBytes(*netPtr); - */ - template - static PyObject* writeAsPyBytes(const nupic::Serializable& obj) - { - #if !CAPNP_LITE - capnp::MallocMessageBuilder message; - typename MessageType::Builder proto = message.initRoot(); - - obj.write(proto); - - // Extract message data and convert to Python byte object - kj::Array array = capnp::messageToFlatArray(message); // copy - kj::ArrayPtr byteArray = array.asBytes(); - PyObject* result = PyString_FromStringAndSize( - (const char*)byteArray.begin(), - byteArray.size()); // copy - return result; - #else - throw std::logic_error( +namespace nupic { + +class PyCapnpHelper { +public: + /** + * Serialize the given nupic::Serializable-based instance, returning a capnp + * byte buffer as python byte string. + * + * :param obj: The Serializable object + * + * :returns: capnp byte buffer encoded as python byte string. + * + * :example: PyObject* pyBytes = PyCapnpHelper::writeAsPyBytes(*netPtr); + */ + template + static PyObject *writeAsPyBytes(const nupic::Serializable &obj) { +#if !CAPNP_LITE + capnp::MallocMessageBuilder message; + typename MessageType::Builder proto = message.initRoot(); + + obj.write(proto); + + // Extract message data and convert to Python byte object + kj::Array array = capnp::messageToFlatArray(message); // copy + kj::ArrayPtr byteArray = array.asBytes(); + PyObject *result = + PyString_FromStringAndSize((const char *)byteArray.begin(), + byteArray.size()); // copy + return result; +#else + throw std::logic_error( "PyCapnpHelper::writeAsPyBytes is not implemented when " "compiled with CAPNP_LITE=1."); - #endif - } +#endif + } + + /** + * Initialize the given nupic::Serializable-based instance from the given + * Capnp message reader. + * + * :param pyBytes: The Serializable object + * + * :returns: capnp byte buffer encoded as python byte string. + * + * :example: PyCapnpHelper::initFromPyBytes(network, pyBytes); + */ + template + static void initFromPyBytes(nupic::Serializable &obj, + const PyObject *pyBytes) { +#if !CAPNP_LITE + const char *srcBytes = nullptr; + Py_ssize_t srcNumBytes = 0; - /** - * Initialize the given nupic::Serializable-based instance from the given - * Capnp message reader. - * - * :param pyBytes: The Serializable object - * - * :returns: capnp byte buffer encoded as python byte string. 
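// Taken together, writeAsPyBytes() and initFromPyBytes() form a round trip
// between a Serializable C++ object and a Python byte string. A minimal
// sketch, assuming a capnp-generated message type NetworkProto and a matching
// Serializable instance net (both names are placeholders for illustration):
//
//   PyObject *pyBytes = PyCapnpHelper::writeAsPyBytes<NetworkProto>(net); // C++ -> Python str (new reference)
//   PyCapnpHelper::initFromPyBytes<NetworkProto>(net, pyBytes);           // Python str -> C++
//   Py_DECREF(pyBytes);                                                   // caller owns the returned byte string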
- * - * :example: PyCapnpHelper::initFromPyBytes(network, pyBytes); - */ - template - static void initFromPyBytes( - nupic::Serializable& obj, - const PyObject* pyBytes) - { - #if !CAPNP_LITE - const char* srcBytes = nullptr; - Py_ssize_t srcNumBytes = 0; - - // NOTE: srcBytes will be set to point to the internal buffer inside - // pyRegionProtoBytes' - PyString_AsStringAndSize(const_cast(pyBytes), - const_cast(&srcBytes), - &srcNumBytes); - - if (srcNumBytes % sizeof(capnp::word) != 0) - { - throw std::logic_error( + // NOTE: srcBytes will be set to point to the internal buffer inside + // pyRegionProtoBytes' + PyString_AsStringAndSize(const_cast(pyBytes), + const_cast(&srcBytes), &srcNumBytes); + + if (srcNumBytes % sizeof(capnp::word) != 0) { + throw std::logic_error( "PyCapnpHelper.initFromPyBytes input length must be a multiple of " "capnp::word."); - } - const int srcNumWords = srcNumBytes / sizeof(capnp::word); - - // Ensure alignment on capnp::word boundary; the buffer inside PyObject appears - // to be unaligned on capnp::word boundary. - kj::Array array = kj::heapArray(srcNumWords); - memcpy(array.asBytes().begin(), srcBytes, srcNumBytes); // copy - - capnp::FlatArrayMessageReader reader(array.asPtr()); // copy ? - typename MessageType::Reader proto = reader.getRoot(); - obj.read(proto); - #else - throw std::logic_error( + } + const int srcNumWords = srcNumBytes / sizeof(capnp::word); + + // Ensure alignment on capnp::word boundary; the buffer inside PyObject + // appears to be unaligned on capnp::word boundary. + kj::Array array = kj::heapArray(srcNumWords); + memcpy(array.asBytes().begin(), srcBytes, srcNumBytes); // copy + + capnp::FlatArrayMessageReader reader(array.asPtr()); // copy ? + typename MessageType::Reader proto = reader.getRoot(); + obj.read(proto); +#else + throw std::logic_error( "PyCapnpHelper::initFromPyBytes is not implemented when " "compiled with CAPNP_LITE=1."); - #endif - } - - }; // class PyCapnpHelper +#endif + } -} // namespace nupic +}; // class PyCapnpHelper +} // namespace nupic -#endif // NTA_PY_CAPNP_HPP +#endif // NTA_PY_CAPNP_HPP diff --git a/src/nupic/py_support/PyHelpers.cpp b/src/nupic/py_support/PyHelpers.cpp index ff213dc9eb..34b62c315e 100644 --- a/src/nupic/py_support/PyHelpers.cpp +++ b/src/nupic/py_support/PyHelpers.cpp @@ -20,781 +20,605 @@ * --------------------------------------------------------------------- */ - #include "PyHelpers.hpp" // Nested namespace nupic::py -namespace nupic { namespace py -{ - static bool runningUnderPython = false; - - void setRunningUnderPython() - { - runningUnderPython = true; - } - - // --- - // Get the stack trace from a Python tracebak object - // --- - static std::string getTraceback(PyObject * p) - { - if (!p) - return ""; - - std::stringstream ss; - - PyTracebackObject * tb = (PyTracebackObject *)p; - NTA_CHECK(PyTraceBack_Check(tb)); - - while (tb) - { - PyCodeObject * c = tb->tb_frame->f_code; - std::string filename(PyString_AsString(c->co_filename)); - std::string function(PyString_AsString(c->co_name)); - int lineno = tb->tb_lineno; - // Read source line from the file (assumes line is shorter than 256) - char line[256]; - std::ifstream f(filename.c_str()); - for (int i = 0; i < lineno; ++i) - { - f.getline(line, 256); - } - - ss << " File \" " - << filename - << ", line " << lineno << ", in " - << function - << std::endl - << std::string(line) - << std::endl; - - tb = tb->tb_next; - } - - return ss.str(); - } - - void checkPyError(int lineno) - { - if (!PyErr_Occurred()) - return; - - PyObject 
* exceptionClass = NULL; - PyObject * exceptionValue = NULL; - PyObject * exceptionTraceback = NULL; - - // Get the Python exception info - PyErr_Fetch(&exceptionClass, - &exceptionValue, - &exceptionTraceback); - - if (!exceptionValue) - { - NTA_THROW << "Python exception raised. Unable to extract info"; - } - - // Normalize the exception value to make sure - // it is an instance of the exception class - // (Python often goes crazy with exception values) - PyErr_NormalizeException(&exceptionClass, - &exceptionValue, - &exceptionTraceback); - - std::string exception; - std::string traceback; - if (exceptionValue) - { - // Extract the exception message as a string - PyObject * sExcValue = PyObject_Str(exceptionValue); - exception = std::string(PyString_AsString(sExcValue)); - traceback = getTraceback(exceptionTraceback); - - Py_XDECREF(sExcValue); - } - - // If running under Python restore the fetched exception so it can - // be handled at the Python interpreter level - - if (runningUnderPython) - { - PyErr_Restore(exceptionClass, exceptionValue, exceptionTraceback); - } - else - { - //PyErr_Clear(); - Py_XDECREF(exceptionClass); - Py_XDECREF(exceptionTraceback); - Py_XDECREF(exceptionValue); - } - - // Throw a correponding C++ exception - throw nupic::Exception(__FILE__, lineno, exception, traceback); - } - - // --- - // Implementation of Ptr class - // --- - Ptr::Ptr(PyObject * p, bool allowNULL) : p_(p), allowNULL_(allowNULL) - { - if (!p && !allowNULL) - NTA_THROW << "The PyObject * is NULL"; - } - - Ptr::~Ptr() - { - Py_XDECREF(p_); - } - - PyObject * Ptr::release() - { - PyObject * result = p_; - p_ = NULL; - return result; - } - - std::string Ptr::getTypeName() - { - if (!p_) - return "(NULL)"; - - // Do not wrap t in a Ptr object because it will break - // the tracing macros (with recursive calls) - PyObject * t = PyObject_Type(p_); - std::string result(((PyTypeObject *)t)->tp_name); - Py_DECREF(t); - - if (PyString_Check(p_)) - { - result += "\"" + std::string(PyString_AsString(p_)) + "\""; - } - return result; - } - - void Ptr::assign(PyObject * p) - { - // Identity check - if (p == p_) - return; - - // Check for NULL and allow NULL - NTA_CHECK(p || allowNULL_); - - // Verify that the new object type matches the current type - // unless one of them is NULL - if (p && p_) - { - NTA_CHECK(PyObject_Type(p_) == PyObject_Type(p)); - } - - // decrease ref count of the existing pointer if not NULL - Py_XDECREF(p_); - - // Assign the new pointer - p_ = p; - - // increment the ref count of the new pointer if not NULL - Py_XINCREF(p_); - } - - Ptr::operator PyObject *() - { - return p_; - } - - Ptr::operator const PyObject *() const - { - return p_; - } - - bool Ptr::isNULL() - { - return p_ == NULL; - } - - - // --- - // Implementation of String class - // --- - String::String(const std::string & s, bool allowNULL) : - Ptr(createString_(s.c_str(), s.size()), allowNULL) - { - } - - String::String(const char * s, size_t size, bool allowNULL) : - Ptr(createString_(s, size), allowNULL) - { - } - - String::String(const char * s, bool allowNULL) : - Ptr(createString_(s), allowNULL) - { - } - - String::String(PyObject * p) : Ptr(p) - { - NTA_CHECK(PyString_Check(p)); - } - - String::operator const char * () - { - if (!p_) - return NULL; - - return PyString_AsString(p_); - } - - PyObject * String::createString_(const char * s, size_t size) - { - if (size == 0) - { - NTA_CHECK(s) << "The input string must not be NULL when size == 0"; - size = ::strlen(s); - } - return PyString_FromStringAndSize(s, 
size); - } - - // --- - // Implementation of Int class - // --- - Int::Int(long n) : Ptr(PyInt_FromLong(n)) - { - } - - Int::Int(PyObject * p) : Ptr(p) - { - NTA_CHECK(PyInt_Check(p)); - } - - Int::operator long() - { - NTA_CHECK(p_); - return PyInt_AsLong(p_); - } - - - // --- - // Implementation of Long class - // --- - Long::Long(long n) : Ptr(PyInt_FromLong(n)) - { - } - - Long::Long(PyObject * p) : Ptr(p) - { - NTA_CHECK(PyLong_Check(p) || PyInt_Check(p)); - } - - Long::operator long() - { - NTA_CHECK(p_); - return PyInt_AsLong(p_); - } - - // --- - // Implementation of UnsignedLong class - // --- - UnsignedLong::UnsignedLong(unsigned long n) : Ptr(PyInt_FromLong((long)n)) - { - } - - UnsignedLong::UnsignedLong(PyObject * p) : Ptr(p) - { - NTA_CHECK(PyLong_Check(p) || PyInt_Check(p)); - } - - UnsignedLong::operator unsigned long() - { - NTA_CHECK(p_); - return (unsigned long)PyInt_AsLong(p_); - } - - // --- - // Implementation of LongLong class - // --- - LongLong::LongLong(long long n) : Ptr(PyLong_FromLongLong(n)) - { - } - - LongLong::LongLong(PyObject * p) : Ptr(p) - { - NTA_CHECK(PyLong_Check(p) || PyInt_Check(p)); - } - - LongLong::operator long long() - { - NTA_CHECK(p_); - return PyLong_AsLongLong(p_); - } - - - // --- - // Implementation of UnsignedLongLong class - // --- - UnsignedLongLong::UnsignedLongLong(unsigned long long n) : - Ptr(PyLong_FromUnsignedLongLong(n)) - { - } - - UnsignedLongLong::UnsignedLongLong(PyObject * p) : Ptr(p) - { - NTA_CHECK(PyLong_Check(p) || PyInt_Check(p)); - } - - UnsignedLongLong::operator unsigned long long() - { - NTA_CHECK(p_); - return PyLong_AsUnsignedLongLong(p_); - } - - // --- - // Implementation of Float class - // --- - Float::Float(const char * n) : Ptr(PyFloat_FromString(String(n), NULL)) - { - } - - Float::Float(double n) : Ptr(PyFloat_FromDouble(n)) - { - } - - Float::Float(PyObject * p) : Ptr(p) - { - NTA_CHECK(PyFloat_Check(p)); - } - - Float::operator double() - { - NTA_CHECK(p_); - - return PyFloat_AsDouble(p_); - } - - double Float::getMax() - { - return PyFloat_GetMax(); - } - - double Float::getMin() - { - return PyFloat_GetMin(); - } - - // --- - // Implementation of Bool class - // --- - - Bool::Bool(bool b) : Ptr(b ? 
Py_True : Py_False) - { - Py_XINCREF(p_); - } - - Bool::Bool(PyObject * p) : Ptr(p) - { - NTA_CHECK(PyBool_Check(p_)); - } - - Bool::operator bool() - { - NTA_CHECK(p_); - - if (p_ == Py_True) - { - return true; - } - else if (p_ == Py_False) - { - return false; +namespace nupic { +namespace py { +static bool runningUnderPython = false; + +void setRunningUnderPython() { runningUnderPython = true; } + +// --- +// Get the stack trace from a Python tracebak object +// --- +static std::string getTraceback(PyObject *p) { + if (!p) + return ""; + + std::stringstream ss; + + PyTracebackObject *tb = (PyTracebackObject *)p; + NTA_CHECK(PyTraceBack_Check(tb)); + + while (tb) { + PyCodeObject *c = tb->tb_frame->f_code; + std::string filename(PyString_AsString(c->co_filename)); + std::string function(PyString_AsString(c->co_name)); + int lineno = tb->tb_lineno; + // Read source line from the file (assumes line is shorter than 256) + char line[256]; + std::ifstream f(filename.c_str()); + for (int i = 0; i < lineno; ++i) { + f.getline(line, 256); } - else - { - NTA_THROW << "Invalid ptr"; - } - } - // --- - // Implementation of Tuple class - // --- - Tuple::Tuple(PyObject * p) : - Ptr(p), - size_(PyTuple_Size(p)) - { - } - - Tuple::Tuple(Py_ssize_t size) : - Ptr(PyTuple_New(size)), - size_(size) - { - } + ss << " File \" " << filename << ", line " << lineno << ", in " << function + << std::endl + << std::string(line) << std::endl; - void Tuple::assign(PyObject * p) - { - Ptr::assign(p); - size_ = PyTuple_Size(p); + tb = tb->tb_next; } - PyObject * Tuple::getItem(Py_ssize_t index) - { - NTA_CHECK(index < size_); - PyObject * p = PyTuple_GetItem(p_, index); - NTA_CHECK(p); - // Increment refcount of borrowed item (caller must DECREF) - Py_INCREF(p); - return p; - } + return ss.str(); +} - PyObject * Tuple::fastGetItem(Py_ssize_t index) - { - NTA_ASSERT(index < getCount()); - PyObject * p = PyTuple_GET_ITEM(p_, index); - NTA_ASSERT(p); - return p; - } - - void Tuple::setItem(Py_ssize_t index, PyObject * item) - { - NTA_CHECK(item); - NTA_CHECK(index < getCount()); +void checkPyError(int lineno) { + if (!PyErr_Occurred()) + return; - // Increment refcount, so caller still needs to DECREF the item - // eventhough PyTuple_SetItem steals the reference - Py_INCREF(item); + PyObject *exceptionClass = NULL; + PyObject *exceptionValue = NULL; + PyObject *exceptionTraceback = NULL; - //! item reference is stolen here - int res = PyTuple_SetItem(p_, index, item); - NTA_CHECK(res == 0); - } + // Get the Python exception info + PyErr_Fetch(&exceptionClass, &exceptionValue, &exceptionTraceback); - Py_ssize_t Tuple::getCount() - { - return PyTuple_Size(p_); + if (!exceptionValue) { + NTA_THROW << "Python exception raised. 
Unable to extract info"; } - // --- - // Implementation of List class - // --- - List::List(PyObject * p) : Ptr(p) - { - } + // Normalize the exception value to make sure + // it is an instance of the exception class + // (Python often goes crazy with exception values) + PyErr_NormalizeException(&exceptionClass, &exceptionValue, + &exceptionTraceback); + std::string exception; + std::string traceback; + if (exceptionValue) { + // Extract the exception message as a string + PyObject *sExcValue = PyObject_Str(exceptionValue); + exception = std::string(PyString_AsString(sExcValue)); + traceback = getTraceback(exceptionTraceback); - List::List() : Ptr(PyList_New(0)) - { + Py_XDECREF(sExcValue); } - PyObject * List::getItem(Py_ssize_t index) - { - NTA_CHECK(index < getCount()); - PyObject * p = PyList_GetItem(p_, index); - NTA_CHECK(p); - // Increment refcount of borrowed item (caller must DECREF) - Py_INCREF(p); - return p; - } + // If running under Python restore the fetched exception so it can + // be handled at the Python interpreter level - PyObject * List::fastGetItem(Py_ssize_t index) - { - NTA_ASSERT(index < getCount()); - PyObject * p = PyList_GET_ITEM(p_, index); - NTA_ASSERT(p); - return p; + if (runningUnderPython) { + PyErr_Restore(exceptionClass, exceptionValue, exceptionTraceback); + } else { + // PyErr_Clear(); + Py_XDECREF(exceptionClass); + Py_XDECREF(exceptionTraceback); + Py_XDECREF(exceptionValue); } - void List::setItem(Py_ssize_t index, PyObject * item) - { - NTA_CHECK(item); - NTA_CHECK(index < getCount()); - - // Increment refcount, so caller still needs to DECREF the item - // eventhough PyList_SetItem steals the reference - Py_INCREF(item); + // Throw a correponding C++ exception + throw nupic::Exception(__FILE__, lineno, exception, traceback); +} - //! 
item reference is stolen here - int res = PyList_SetItem(p_, index, item); - NTA_CHECK(res == 0); - } +// --- +// Implementation of Ptr class +// --- +Ptr::Ptr(PyObject *p, bool allowNULL) : p_(p), allowNULL_(allowNULL) { + if (!p && !allowNULL) + NTA_THROW << "The PyObject * is NULL"; +} - void List::append(PyObject * item) - { - NTA_CHECK(item); +Ptr::~Ptr() { Py_XDECREF(p_); } - int res = PyList_Append(p_, item); - NTA_CHECK(res == 0); - } +PyObject *Ptr::release() { + PyObject *result = p_; + p_ = NULL; + return result; +} - Py_ssize_t List::getCount() - { - return PyList_Size(p_); - } +std::string Ptr::getTypeName() { + if (!p_) + return "(NULL)"; - // --- - // Implementation of Dict class - // --- - Dict::Dict() : Ptr(PyDict_New()) - { - } + // Do not wrap t in a Ptr object because it will break + // the tracing macros (with recursive calls) + PyObject *t = PyObject_Type(p_); + std::string result(((PyTypeObject *)t)->tp_name); + Py_DECREF(t); - Dict::Dict(PyObject * dict) : Ptr(dict) - { - NTA_CHECK(PyDict_Check(dict)); + if (PyString_Check(p_)) { + result += "\"" + std::string(PyString_AsString(p_)) + "\""; } + return result; +} - PyObject * Dict::getItem(const std::string & name, PyObject * defaultItem) - { - // PyDict_GetItem() returns a borrowed reference - PyObject * pItem = PyDict_GetItem(p_, String(name)); - if (!pItem) - return defaultItem; +void Ptr::assign(PyObject *p) { + // Identity check + if (p == p_) + return; - // Increment ref count, so the caller has to call DECREF - Py_INCREF(pItem); + // Check for NULL and allow NULL + NTA_CHECK(p || allowNULL_); - return pItem; + // Verify that the new object type matches the current type + // unless one of them is NULL + if (p && p_) { + NTA_CHECK(PyObject_Type(p_) == PyObject_Type(p)); } - void Dict::setItem(const std::string & name, PyObject * pItem) - { - int res = PyDict_SetItem(p_, String(name), pItem); - NTA_CHECK(res == 0); - } + // decrease ref count of the existing pointer if not NULL + Py_XDECREF(p_); + // Assign the new pointer + p_ = p; - // --- - // Implementation of Module class - // --- - Module::Module(const std::string & moduleName) : - Ptr(createModule_(moduleName)) - { - } - - // Invoke a module method. Equivalent to: - // return module.method(*args, **kwargs) - // This code is identical to Instance::invoke - PyObject * Module::invoke(std::string method, - PyObject * args, - PyObject * kwargs) const - { - NTA_CHECK(p_); - PyObject * pMethod = getAttr(method); - NTA_CHECK(PyCallable_Check(pMethod)); - - Ptr m(pMethod); - PyObject * result = PyObject_Call(m, args, kwargs); - checkPyError(__LINE__); - NTA_CHECK(result); - return result; - } + // increment the ref count of the new pointer if not NULL + Py_XINCREF(p_); +} - // Get an module attribute. 
Equivalent to: - // - // return module.name - // This code is identical to Instance::getAttr - PyObject * Module::getAttr(std::string name) const - { - NTA_CHECK(p_); - PyObject * attr = PyObject_GetAttrString(p_, name.c_str()); - checkPyError(__LINE__); - NTA_CHECK(attr); - return attr; - } +Ptr::operator PyObject *() { return p_; } - PyObject * Module::createModule_(const std::string & moduleName) - { - String name(moduleName); - // Import the module - PyObject * pModule = PyImport_Import(name); - checkPyError(__LINE__); +Ptr::operator const PyObject *() const { return p_; } - if (pModule == NULL || !(PyModule_Check(pModule))) - { - NTA_THROW << "Unable to import module: " << moduleName; - } - return pModule; - } +bool Ptr::isNULL() { return p_ == NULL; } - // --- - // Implementation of Class class - // --- - // Wraps a Python class object and allows invoking class methods - // The constructor is equivalent the following Python code: - // - // from moduleName import className - Class::Class(const std::string & moduleName, const std::string & className) : - Ptr(createClass_(Module(moduleName), className)) - { - } - - Class::Class(PyObject * pModule, const std::string & className) : - Ptr(createClass_(pModule, className)) - { - } +// --- +// Implementation of String class +// --- +String::String(const std::string &s, bool allowNULL) + : Ptr(createString_(s.c_str(), s.size()), allowNULL) {} - // The invoke() method is equivalent the following Python code - // which invokes a class method: - // - // return className.method(*args, **kwargs) - PyObject * Class::invoke(std::string method, - PyObject * args, - PyObject * kwargs) - { - NTA_CHECK(p_); - PyObject * pMethod = PyObject_GetAttrString(p_, method.c_str()); - //printPythonError(); - NTA_CHECK(pMethod); - NTA_CHECK(PyCallable_Check(pMethod)); - - Ptr m(pMethod); - PyObject * result = PyObject_Call(m, args, kwargs); - checkPyError(__LINE__); - - return result; - } +String::String(const char *s, size_t size, bool allowNULL) + : Ptr(createString_(s, size), allowNULL) {} - PyObject * Class::createClass_(PyObject * pModule, const std::string & className) - { - // Get the node class from the module (as a new reference) - PyObject * pClass = PyObject_GetAttrString(pModule, className.c_str()); - NTA_CHECK(pClass && PyType_Check(pClass)); +String::String(const char *s, bool allowNULL) + : Ptr(createString_(s), allowNULL) {} - return pClass; - } +String::String(PyObject *p) : Ptr(p) { NTA_CHECK(PyString_Check(p)); } - // --- - // Implementation of Instance class - // --- - // Wraps an instance of a Python object - - // A constructor that takes an existing PyObject * (can be NULL too) - Instance::Instance(PyObject * p) : Ptr(p, p == NULL) - { - } +String::operator const char *() { + if (!p_) + return NULL; - // A constructor that instantiates an instance of the requested - // class from the requested module. 
Equivalent to the follwoing: - // - // from moduleName import className - // instance = className(*args, **kwargs) - Instance::Instance(const std::string & moduleName, - const std::string & className, - PyObject * args, - PyObject * kwargs) : - Ptr(createInstance_(Class(moduleName, className), - args, - kwargs)) - { - } - - Instance::Instance(PyObject * pClass, - PyObject * args, - PyObject * kwargs) : - Ptr(createInstance_(pClass, args, kwargs)) - { - } + return PyString_AsString(p_); +} - // Return true if the instance has attribute 'name' and false otherwise - bool Instance::hasAttr(std::string name) - { - checkPyError(__LINE__); - NTA_CHECK(p_); - return PyObject_HasAttrString(p_, name.c_str()) != 0; +PyObject *String::createString_(const char *s, size_t size) { + if (size == 0) { + NTA_CHECK(s) << "The input string must not be NULL when size == 0"; + size = ::strlen(s); } + return PyString_FromStringAndSize(s, size); +} - // Get an instance attribute. Equivalent to: - // - // return instance.name - PyObject * Instance::getAttr(std::string name) const - { - NTA_CHECK(p_); - PyObject * attr = PyObject_GetAttrString(p_, name.c_str()); - checkPyError(__LINE__); - NTA_CHECK(attr); - - return attr; - } +// --- +// Implementation of Int class +// --- +Int::Int(long n) : Ptr(PyInt_FromLong(n)) {} - // Set an instance attribute. Equivalent to: - // - // return instance.name = value - void Instance::setAttr(std::string name, PyObject * value) - { - NTA_CHECK(p_); - int rc = PyObject_SetAttrString(p_, name.c_str(), value); - checkPyError(__LINE__); - NTA_CHECK(rc != -1); - } +Int::Int(PyObject *p) : Ptr(p) { NTA_CHECK(PyInt_Check(p)); } +Int::operator long() { + NTA_CHECK(p_); + return PyInt_AsLong(p_); +} - // Return a string representation of an instance. Equivalent to: - // - // return str(instance) - PyObject * Instance::toString() - { - NTA_CHECK(p_); - PyObject * s = PyObject_Str(p_); - checkPyError(__LINE__); - NTA_CHECK(s); +// --- +// Implementation of Long class +// --- +Long::Long(long n) : Ptr(PyInt_FromLong(n)) {} - return s; - } +Long::Long(PyObject *p) : Ptr(p) { + NTA_CHECK(PyLong_Check(p) || PyInt_Check(p)); +} - // Invoke an instance method. Equivalent to: - // - // return instance.method(*args, **kwargs) - PyObject * Instance::invoke(std::string method, - PyObject * args, - PyObject * kwargs) const - { - NTA_CHECK(p_); - PyObject * pMethod = getAttr(method); - NTA_CHECK(PyCallable_Check(pMethod)); - - Ptr m(pMethod); - PyObject * result = PyObject_Call(m, args, kwargs); - checkPyError(__LINE__); - NTA_CHECK(result); - return result; - } +Long::operator long() { + NTA_CHECK(p_); + return PyInt_AsLong(p_); +} - PyObject * Instance::createInstance_(PyObject * pClass, - PyObject * args, - PyObject * kwargs) - { - NTA_CHECK(pClass && PyCallable_Check(pClass)); - NTA_CHECK(args && PyTuple_Check(args)); - NTA_CHECK(!kwargs || PyDict_Check(kwargs)); - - PyObject * pInstance = PyObject_Call(pClass, args, kwargs); - checkPyError(__LINE__); - NTA_CHECK(pInstance); - - return pInstance; - } +// --- +// Implementation of UnsignedLong class +// --- +UnsignedLong::UnsignedLong(unsigned long n) : Ptr(PyInt_FromLong((long)n)) {} - //// --- - //// Raise a Python RuntimeError exception from C++ with - //// an error message and an optional stack trace. The - //// stack trace will be added as a custom attribute of the - //// exception object. 
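// A minimal sketch of the checkPyError() pattern used throughout this file:
// call into the Python C-API, then convert any pending Python error into a
// nupic::Exception carrying the message and traceback (the module name below
// is deliberately bogus and used only for illustration):
//
//   PyObject *m = PyImport_ImportModule("no_such_module"); // fails, leaves a Python error set
//   py::checkPyError(__LINE__);                            // fetches the error and throws nupic::Exception
//   Py_XDECREF(m);                                         // not reached when the import failed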
- //// --- - //void setPyError(const char * message, const char * stackTrace = NULL) - //{ - // // Create a RuntimeError Python exception object - // Tuple args(1); - // args.setItem(0, py::String(message)); - // Instance e(PyExc_RuntimeError, args); +UnsignedLong::UnsignedLong(PyObject *p) : Ptr(p) { + NTA_CHECK(PyLong_Check(p) || PyInt_Check(p)); +} - // // Add a new attribute "stackTrace" if available - // if (stackTrace) - // { - // e.setAttr("stackTrace", py::String(stackTrace)); - // } +UnsignedLong::operator unsigned long() { + NTA_CHECK(p_); + return (unsigned long)PyInt_AsLong(p_); +} - // // Set the Python error. Equivalent to: "raise e" - // PyErr_SetObject(PyExc_RuntimeError, e); - //} +// --- +// Implementation of LongLong class +// --- +LongLong::LongLong(long long n) : Ptr(PyLong_FromLongLong(n)) {} +LongLong::LongLong(PyObject *p) : Ptr(p) { + NTA_CHECK(PyLong_Check(p) || PyInt_Check(p)); +} +LongLong::operator long long() { + NTA_CHECK(p_); + return PyLong_AsLongLong(p_); +} -} } // end of nupic::py namespace +// --- +// Implementation of UnsignedLongLong class +// --- +UnsignedLongLong::UnsignedLongLong(unsigned long long n) + : Ptr(PyLong_FromUnsignedLongLong(n)) {} +UnsignedLongLong::UnsignedLongLong(PyObject *p) : Ptr(p) { + NTA_CHECK(PyLong_Check(p) || PyInt_Check(p)); +} +UnsignedLongLong::operator unsigned long long() { + NTA_CHECK(p_); + return PyLong_AsUnsignedLongLong(p_); +} +// --- +// Implementation of Float class +// --- +Float::Float(const char *n) : Ptr(PyFloat_FromString(String(n), NULL)) {} + +Float::Float(double n) : Ptr(PyFloat_FromDouble(n)) {} + +Float::Float(PyObject *p) : Ptr(p) { NTA_CHECK(PyFloat_Check(p)); } + +Float::operator double() { + NTA_CHECK(p_); + + return PyFloat_AsDouble(p_); +} + +double Float::getMax() { return PyFloat_GetMax(); } + +double Float::getMin() { return PyFloat_GetMin(); } + +// --- +// Implementation of Bool class +// --- + +Bool::Bool(bool b) : Ptr(b ? Py_True : Py_False) { Py_XINCREF(p_); } + +Bool::Bool(PyObject *p) : Ptr(p) { NTA_CHECK(PyBool_Check(p_)); } + +Bool::operator bool() { + NTA_CHECK(p_); + + if (p_ == Py_True) { + return true; + } else if (p_ == Py_False) { + return false; + } else { + NTA_THROW << "Invalid ptr"; + } +} + +// --- +// Implementation of Tuple class +// --- +Tuple::Tuple(PyObject *p) : Ptr(p), size_(PyTuple_Size(p)) {} + +Tuple::Tuple(Py_ssize_t size) : Ptr(PyTuple_New(size)), size_(size) {} + +void Tuple::assign(PyObject *p) { + Ptr::assign(p); + size_ = PyTuple_Size(p); +} + +PyObject *Tuple::getItem(Py_ssize_t index) { + NTA_CHECK(index < size_); + PyObject *p = PyTuple_GetItem(p_, index); + NTA_CHECK(p); + // Increment refcount of borrowed item (caller must DECREF) + Py_INCREF(p); + return p; +} + +PyObject *Tuple::fastGetItem(Py_ssize_t index) { + NTA_ASSERT(index < getCount()); + PyObject *p = PyTuple_GET_ITEM(p_, index); + NTA_ASSERT(p); + return p; +} + +void Tuple::setItem(Py_ssize_t index, PyObject *item) { + NTA_CHECK(item); + NTA_CHECK(index < getCount()); + + // Increment refcount, so caller still needs to DECREF the item + // eventhough PyTuple_SetItem steals the reference + Py_INCREF(item); + + //! 
item reference is stolen here + int res = PyTuple_SetItem(p_, index, item); + NTA_CHECK(res == 0); +} + +Py_ssize_t Tuple::getCount() { return PyTuple_Size(p_); } + +// --- +// Implementation of List class +// --- +List::List(PyObject *p) : Ptr(p) {} + +List::List() : Ptr(PyList_New(0)) {} + +PyObject *List::getItem(Py_ssize_t index) { + NTA_CHECK(index < getCount()); + PyObject *p = PyList_GetItem(p_, index); + NTA_CHECK(p); + // Increment refcount of borrowed item (caller must DECREF) + Py_INCREF(p); + return p; +} + +PyObject *List::fastGetItem(Py_ssize_t index) { + NTA_ASSERT(index < getCount()); + PyObject *p = PyList_GET_ITEM(p_, index); + NTA_ASSERT(p); + return p; +} + +void List::setItem(Py_ssize_t index, PyObject *item) { + NTA_CHECK(item); + NTA_CHECK(index < getCount()); + + // Increment refcount, so caller still needs to DECREF the item + // eventhough PyList_SetItem steals the reference + Py_INCREF(item); + + //! item reference is stolen here + int res = PyList_SetItem(p_, index, item); + NTA_CHECK(res == 0); +} + +void List::append(PyObject *item) { + NTA_CHECK(item); + + int res = PyList_Append(p_, item); + NTA_CHECK(res == 0); +} + +Py_ssize_t List::getCount() { return PyList_Size(p_); } + +// --- +// Implementation of Dict class +// --- +Dict::Dict() : Ptr(PyDict_New()) {} + +Dict::Dict(PyObject *dict) : Ptr(dict) { NTA_CHECK(PyDict_Check(dict)); } + +PyObject *Dict::getItem(const std::string &name, PyObject *defaultItem) { + // PyDict_GetItem() returns a borrowed reference + PyObject *pItem = PyDict_GetItem(p_, String(name)); + if (!pItem) + return defaultItem; + + // Increment ref count, so the caller has to call DECREF + Py_INCREF(pItem); + + return pItem; +} + +void Dict::setItem(const std::string &name, PyObject *pItem) { + int res = PyDict_SetItem(p_, String(name), pItem); + NTA_CHECK(res == 0); +} + +// --- +// Implementation of Module class +// --- +Module::Module(const std::string &moduleName) + : Ptr(createModule_(moduleName)) {} + +// Invoke a module method. Equivalent to: +// return module.method(*args, **kwargs) +// This code is identical to Instance::invoke +PyObject *Module::invoke(std::string method, PyObject *args, + PyObject *kwargs) const { + NTA_CHECK(p_); + PyObject *pMethod = getAttr(method); + NTA_CHECK(PyCallable_Check(pMethod)); + + Ptr m(pMethod); + PyObject *result = PyObject_Call(m, args, kwargs); + checkPyError(__LINE__); + NTA_CHECK(result); + return result; +} + +// Get an module attribute. 
Equivalent to: +// +// return module.name +// This code is identical to Instance::getAttr +PyObject *Module::getAttr(std::string name) const { + NTA_CHECK(p_); + PyObject *attr = PyObject_GetAttrString(p_, name.c_str()); + checkPyError(__LINE__); + NTA_CHECK(attr); + return attr; +} + +PyObject *Module::createModule_(const std::string &moduleName) { + String name(moduleName); + // Import the module + PyObject *pModule = PyImport_Import(name); + checkPyError(__LINE__); + + if (pModule == NULL || !(PyModule_Check(pModule))) { + NTA_THROW << "Unable to import module: " << moduleName; + } + return pModule; +} + +// --- +// Implementation of Class class +// --- +// Wraps a Python class object and allows invoking class methods +// The constructor is equivalent the following Python code: +// +// from moduleName import className +Class::Class(const std::string &moduleName, const std::string &className) + : Ptr(createClass_(Module(moduleName), className)) {} + +Class::Class(PyObject *pModule, const std::string &className) + : Ptr(createClass_(pModule, className)) {} + +// The invoke() method is equivalent the following Python code +// which invokes a class method: +// +// return className.method(*args, **kwargs) +PyObject *Class::invoke(std::string method, PyObject *args, PyObject *kwargs) { + NTA_CHECK(p_); + PyObject *pMethod = PyObject_GetAttrString(p_, method.c_str()); + // printPythonError(); + NTA_CHECK(pMethod); + NTA_CHECK(PyCallable_Check(pMethod)); + + Ptr m(pMethod); + PyObject *result = PyObject_Call(m, args, kwargs); + checkPyError(__LINE__); + + return result; +} + +PyObject *Class::createClass_(PyObject *pModule, const std::string &className) { + // Get the node class from the module (as a new reference) + PyObject *pClass = PyObject_GetAttrString(pModule, className.c_str()); + NTA_CHECK(pClass && PyType_Check(pClass)); + + return pClass; +} + +// --- +// Implementation of Instance class +// --- +// Wraps an instance of a Python object + +// A constructor that takes an existing PyObject * (can be NULL too) +Instance::Instance(PyObject *p) : Ptr(p, p == NULL) {} + +// A constructor that instantiates an instance of the requested +// class from the requested module. Equivalent to the follwoing: +// +// from moduleName import className +// instance = className(*args, **kwargs) +Instance::Instance(const std::string &moduleName, const std::string &className, + PyObject *args, PyObject *kwargs) + : Ptr(createInstance_(Class(moduleName, className), args, kwargs)) {} + +Instance::Instance(PyObject *pClass, PyObject *args, PyObject *kwargs) + : Ptr(createInstance_(pClass, args, kwargs)) {} + +// Return true if the instance has attribute 'name' and false otherwise +bool Instance::hasAttr(std::string name) { + checkPyError(__LINE__); + NTA_CHECK(p_); + return PyObject_HasAttrString(p_, name.c_str()) != 0; +} + +// Get an instance attribute. Equivalent to: +// +// return instance.name +PyObject *Instance::getAttr(std::string name) const { + NTA_CHECK(p_); + PyObject *attr = PyObject_GetAttrString(p_, name.c_str()); + checkPyError(__LINE__); + NTA_CHECK(attr); + + return attr; +} + +// Set an instance attribute. Equivalent to: +// +// return instance.name = value +void Instance::setAttr(std::string name, PyObject *value) { + NTA_CHECK(p_); + int rc = PyObject_SetAttrString(p_, name.c_str(), value); + checkPyError(__LINE__); + NTA_CHECK(rc != -1); +} + +// Return a string representation of an instance. 
Equivalent to: +// +// return str(instance) +PyObject *Instance::toString() { + NTA_CHECK(p_); + PyObject *s = PyObject_Str(p_); + checkPyError(__LINE__); + NTA_CHECK(s); + + return s; +} + +// Invoke an instance method. Equivalent to: +// +// return instance.method(*args, **kwargs) +PyObject *Instance::invoke(std::string method, PyObject *args, + PyObject *kwargs) const { + NTA_CHECK(p_); + PyObject *pMethod = getAttr(method); + NTA_CHECK(PyCallable_Check(pMethod)); + + Ptr m(pMethod); + PyObject *result = PyObject_Call(m, args, kwargs); + checkPyError(__LINE__); + NTA_CHECK(result); + return result; +} + +PyObject *Instance::createInstance_(PyObject *pClass, PyObject *args, + PyObject *kwargs) { + NTA_CHECK(pClass && PyCallable_Check(pClass)); + NTA_CHECK(args && PyTuple_Check(args)); + NTA_CHECK(!kwargs || PyDict_Check(kwargs)); + + PyObject *pInstance = PyObject_Call(pClass, args, kwargs); + checkPyError(__LINE__); + NTA_CHECK(pInstance); + + return pInstance; +} + +//// --- +//// Raise a Python RuntimeError exception from C++ with +//// an error message and an optional stack trace. The +//// stack trace will be added as a custom attribute of the +//// exception object. +//// --- +// void setPyError(const char * message, const char * stackTrace = NULL) +//{ +// // Create a RuntimeError Python exception object +// Tuple args(1); +// args.setItem(0, py::String(message)); +// Instance e(PyExc_RuntimeError, args); + +// // Add a new attribute "stackTrace" if available +// if (stackTrace) +// { +// e.setAttr("stackTrace", py::String(stackTrace)); +// } + +// // Set the Python error. Equivalent to: "raise e" +// PyErr_SetObject(PyExc_RuntimeError, e); +//} + +} // namespace py +} // namespace nupic diff --git a/src/nupic/py_support/PyHelpers.hpp b/src/nupic/py_support/PyHelpers.hpp index 6a69bb586f..fe849ebbee 100644 --- a/src/nupic/py_support/PyHelpers.hpp +++ b/src/nupic/py_support/PyHelpers.hpp @@ -25,15 +25,15 @@ // The Python.h #include MUST always be #included first in every // compilation unit (.c or .cpp file). That means that PyHelpers.hpp -// must be #included first and transitively every .hpp file that +// must be #included first and transitively every .hpp file that // #includes directly or indirectly PyHelpers.hpp must be #included // first. #include #include -#include -#include #include +#include +#include // === // --------------------------------- @@ -43,22 +43,22 @@ // --------------------------------- // // The PyHelpers module contain C++ classes that wrap Python C-API -// objects and allows working with Python C-API in a safe and user +// objects and allows working with Python C-API in a safe and user // friendly manner. Please refer to the Python C-API docs for general // background: http://docs.python.org/c-api/ // // The purpose of this module is to support internally the C++ PyNode -// class and it implements exactly the classes needed by PyNode. +// class and it implements exactly the classes needed by PyNode. // It is not a comprehensive wrapper around the Python C-API. -// -//The follwoing classes are implemented in the +// +// The follwoing classes are implemented in the // namespac nupic::py // // Ptr: // A class that manages a PyObject pointer and serves as base class // for all other helper objects -// -// Int, Long, LongLong, UnsignedLong, UnsignedLongLong: +// +// Int, Long, LongLong, UnsignedLong, UnsignedLongLong: // Integral types that match the corresponding C integral types and // map to either the Python int or Python long types. 
They provide // constructors and conversion operators that were chosen to reflect @@ -68,13 +68,13 @@ // Float: // A floating point number class that maps a double precision C double // to a Python float. It provides a constructor and conversion operator. -// Python has just one floating point type (not including numpy and +// Python has just one floating point type (not including numpy and // the complex type). // // Bool: // A boolean class that maps a C bool to a Python bool. -// -// String: +// +// String: // A string type that maps to the Python string and provides // a constractor and conversion operator for the C char * type. // @@ -84,334 +84,308 @@ // // Module, Class, Instance: // Types for working with the Python object system. Module is for importing -// modules. Class is for invoking class methods and Instance is for +// modules. Class is for invoking class methods and Instance is for // instantiating objects and invoking their methods. // === // Nested namespace nupic::py -namespace nupic { namespace py -{ - void setRunningUnderPython(); - - - void checkPyError(int lineno); - - // A RAII class to hold a PyObject * - // It decrements the refCount when it - // is destroyed. +namespace nupic { +namespace py { +void setRunningUnderPython(); + +void checkPyError(int lineno); + +// A RAII class to hold a PyObject * +// It decrements the refCount when it +// is destroyed. +// +// Ptr objects can be passed directly to Python C API calls. +// Most functions just use the object and when the function +// returns the refCount stays the same. Some functions take +// ownership of the PyObject pointer. In this case you should +// call release(). +// +// The Ptr class also serves as the base class for all the other helper +// classes that rely on it to manage the underlying PyObject * and +// just add type checks and type-specific constructors, conversion +// operator and special methods. +// +// In general, the specific sub-classes should be used to store Python +// objects (e.g use PyDict to store a dict object). You should use PyPtr +// directly only when storing Python objects that don't have a +// corresponding PyHelpers sub-class. +class Ptr { +public: + // Constructor of the Ptr class // - // Ptr objects can be passed directly to Python C API calls. - // Most functions just use the object and when the function - // returns the refCount stays the same. Some functions take - // ownership of the PyObject pointer. In this case you should - // call release(). + // PyObject * p - the managed pointer (ref count not incremented) + // bool allowNULL - If false (default) p must not be NULL // - // The Ptr class also serves as the base class for all the other helper - // classes that rely on it to manage the underlying PyObject * and - // just add type checks and type-specific constructors, conversion - // operator and special methods. + Ptr(PyObject *p, bool allowNULL = false); + + virtual ~Ptr(); + + // Relinquish ownership of this object. Call release() + // if you pass the Ptr to a function that takes ownership + // like PyTuple_SetItem() + PyObject *release(); + std::string getTypeName(); + void assign(PyObject *p); + operator PyObject *(); + operator const PyObject *() const; + bool isNULL(); + +private: + // This constructors MUST never be implemented to maintain + // the integrity of the Ptr class. As a RAII class you don't + // want copies sharing (and later releasing) the same pointer. 
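// A minimal sketch of the release() ownership rule described above (the raw
// tuple is created here only for illustration):
//
//   PyObject *tup = PyTuple_New(1);          // raw new reference
//   py::Int num(42);                         // Ptr subclass; owns one reference
//   PyTuple_SetItem(tup, 0, num.release());  // PyTuple_SetItem steals the reference,
//                                            // so relinquish ownership first
//   Py_DECREF(tup);                          // done with the tuple itself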
+ Ptr(); + Ptr(const Ptr &); + +protected: + PyObject *p_; + bool allowNULL_; +}; + +// String +class String : public Ptr { +public: + explicit String(const std::string &s, bool allowNULL = false); + explicit String(const char *s, size_t size, bool allowNULL = false); + explicit String(const char *s, bool allowNULL = false); + explicit String(PyObject *p); + + operator const char *(); + +private: + PyObject *createString_(const char *s, size_t size = 0); +}; + +// Int +class Int : public Ptr { +public: + Int(long n); + Int(PyObject *p); + operator long(); +}; + +// Long +class Long : public Ptr { +public: + Long(long n); + Long(PyObject *p); + operator long(); +}; + +// UnsignedLong +class UnsignedLong : public Ptr { +public: + UnsignedLong(unsigned long n); + UnsignedLong(PyObject *p); + operator unsigned long(); +}; + +// LongLong +class LongLong : public Ptr { +public: + LongLong(long long n); + LongLong(PyObject *p); + operator long long(); +}; + +// UnsignedLongLong +class UnsignedLongLong : public Ptr { +public: + UnsignedLongLong(unsigned long long n); + UnsignedLongLong(PyObject *p); + operator unsigned long long(); +}; + +// Float +class Float : public Ptr { +public: + Float(const char *n); + Float(double n); + Float(PyObject *p); + operator double(); + static double getMax(); + static double getMin(); +}; + +// Bool +class Bool : public Ptr { +public: + Bool(bool b); + Bool(PyObject *p); + operator bool(); +}; + +// Tuple +class Tuple : public Ptr { +public: + Tuple(Py_ssize_t size = 0); + Tuple(PyObject *p); + void assign(PyObject *); + + PyObject *getItem(Py_ssize_t index); + + // The fast version of getItem() doesn't do bounds checking + // and doesn't increment the ref count of returned item. The original + // tuple still owns the item. If you try to access out of bounds item + // you will crash (or worse). If you assign the returned item to a py::Ptr + // or a sub-class you MUST call release() to prevent the py::Ptr from + // decrementing the ref count + PyObject *fastGetItem(Py_ssize_t index); + void setItem(Py_ssize_t index, PyObject *item); + Py_ssize_t getCount(); + +private: + Py_ssize_t size_; +}; + +// List +class List : public Ptr { +public: + List(); + List(PyObject *); + PyObject *getItem(Py_ssize_t index); + + // The fast version of getItem() doesn't do bounds checking + // and doesn't increment the ref count of returned item. The original + // list still owns the item. If you try to access out of bounds item + // you will crash (or worse). If you assign the returned item to a py::Ptr + // or a sub-class you MUST call release() to prevent the py::Ptr from + // decrementing the ref count + PyObject *fastGetItem(Py_ssize_t index); + void setItem(Py_ssize_t index, PyObject *item); + void append(PyObject *item); + Py_ssize_t getCount(); +}; + +// Dict +class Dict : public Ptr { +public: + Dict(); + Dict(PyObject *dict); + PyObject *getItem(const std::string &name, PyObject *defaultItem = NULL); + void setItem(const std::string &name, PyObject *pItem); +}; + +// Module +// +// Wraps a Python module object and can import by name +// The Python interpreter sys.path must contain the +// requested module. +class Module : public Ptr { +public: + Module(const std::string &moduleName); + + // Invoke a module method. Equivalent to: + // return module.method(*args, **kwargs) + // This code is identical to Instance::invoke + PyObject *invoke(std::string method, PyObject *args, + PyObject *kwargs = NULL) const; + // Get an module attribute. 
Equivalent to: // - // In general, the specific sub-classes should be used to store Python - // objects (e.g use PyDict to store a dict object). You should use PyPtr - // directly only when storing Python objects that don't have a - // corresponding PyHelpers sub-class. - class Ptr - { - public: - // Constructor of the Ptr class - // - // PyObject * p - the managed pointer (ref count not incremented) - // bool allowNULL - If false (default) p must not be NULL - // - Ptr(PyObject * p, bool allowNULL = false); - - virtual ~Ptr(); - - // Relinquish ownership of this object. Call release() - // if you pass the Ptr to a function that takes ownership - // like PyTuple_SetItem() - PyObject * release(); - std::string getTypeName(); - void assign(PyObject * p); - operator PyObject *(); - operator const PyObject *() const; - bool isNULL(); - private: - // This constructors MUST never be implemented to maintain - // the integrity of the Ptr class. As a RAII class you don't - // want copies sharing (and later releasing) the same pointer. - Ptr(); - Ptr(const Ptr &); - - protected: - PyObject * p_; - bool allowNULL_; - }; - - - // String - class String : public Ptr - { - public: - explicit String(const std::string & s, bool allowNULL = false); - explicit String(const char * s, size_t size, bool allowNULL = false); - explicit String(const char * s, bool allowNULL = false); - explicit String(PyObject * p); - - operator const char * (); - - private: - PyObject * createString_(const char * s, size_t size = 0); - }; - - // Int - class Int : public Ptr - { - public: - Int(long n); - Int(PyObject * p); - operator long(); - }; - - // Long - class Long : public Ptr - { - public: - Long(long n); - Long(PyObject * p); - operator long(); - }; - - // UnsignedLong - class UnsignedLong : public Ptr - { - public: - UnsignedLong(unsigned long n); - UnsignedLong(PyObject * p); - operator unsigned long(); - }; - - // LongLong - class LongLong : public Ptr - { - public: - LongLong(long long n); - LongLong(PyObject * p); - operator long long(); - }; - - // UnsignedLongLong - class UnsignedLongLong : public Ptr - { - public: - UnsignedLongLong(unsigned long long n); - UnsignedLongLong(PyObject * p); - operator unsigned long long(); - }; - - - // Float - class Float : public Ptr - { - public: - Float(const char * n); - Float(double n); - Float(PyObject * p); - operator double(); - static double getMax(); - static double getMin(); - }; - - // Bool - class Bool : public Ptr - { - public: - Bool(bool b); - Bool(PyObject * p); - operator bool(); - }; - - // Tuple - class Tuple : public Ptr - { - public: - Tuple(Py_ssize_t size = 0); - Tuple(PyObject * p); - void assign(PyObject *); - - PyObject * getItem(Py_ssize_t index); - - // The fast version of getItem() doesn't do bounds checking - // and doesn't increment the ref count of returned item. The original - // tuple still owns the item. If you try to access out of bounds item - // you will crash (or worse). If you assign the returned item to a py::Ptr - // or a sub-class you MUST call release() to prevent the py::Ptr from - // decrementing the ref count - PyObject * fastGetItem(Py_ssize_t index); - void setItem(Py_ssize_t index, PyObject * item); - Py_ssize_t getCount(); - private: - Py_ssize_t size_; - }; - - // List - class List : public Ptr - { - public: - List(); - List(PyObject *); - PyObject * getItem(Py_ssize_t index); - - // The fast version of getItem() doesn't do bounds checking - // and doesn't increment the ref count of returned item. 
The original - // list still owns the item. If you try to access out of bounds item - // you will crash (or worse). If you assign the returned item to a py::Ptr - // or a sub-class you MUST call release() to prevent the py::Ptr from - // decrementing the ref count - PyObject * fastGetItem(Py_ssize_t index); - void setItem(Py_ssize_t index, PyObject * item); - void append(PyObject * item); - Py_ssize_t getCount(); - }; - - // Dict - class Dict : public Ptr - { - public: - Dict(); - Dict(PyObject * dict); - PyObject * getItem(const std::string & name, PyObject * defaultItem = NULL); - void setItem(const std::string & name, PyObject * pItem); - }; - - // Module + // return module.name + // This code is identical to Instance::getAttr + PyObject *getAttr(std::string name) const; + +private: + PyObject *createModule_(const std::string &moduleName); +}; + +// Class +// +// Wraps a Python class object and allows invoking class methods +class Class : public Ptr { +public: + // The constructor is equivalent the following Python code: // - // Wraps a Python module object and can import by name - // The Python interpreter sys.path must contain the - // requested module. - class Module : public Ptr - { - public: - Module(const std::string & moduleName); - - // Invoke a module method. Equivalent to: - // return module.method(*args, **kwargs) - // This code is identical to Instance::invoke - PyObject * invoke(std::string method, - PyObject * args, - PyObject * kwargs = NULL) const; - // Get an module attribute. Equivalent to: - // - // return module.name - // This code is identical to Instance::getAttr - PyObject * getAttr(std::string name) const; - private: - PyObject * createModule_(const std::string & moduleName); - }; - - // Class + // from moduleName import className + Class(const std::string &moduleName, const std::string &className); + Class(PyObject *pModule, const std::string &className); + + // The invoke() method is equivalent the following Python code + // which invokes a class method: // - // Wraps a Python class object and allows invoking class methods - class Class : public Ptr - { - public: - // The constructor is equivalent the following Python code: - // - // from moduleName import className - Class(const std::string & moduleName, const std::string & className); - Class(PyObject * pModule, const std::string & className); - - // The invoke() method is equivalent the following Python code - // which invokes a class method: - // - // return className.method(*args, **kwargs) - PyObject * invoke(std::string method, - PyObject * args, - PyObject * kwargs = NULL); - private: - PyObject * createClass_(PyObject * pModule, const std::string & className); - }; - - - // Instance + // return className.method(*args, **kwargs) + PyObject *invoke(std::string method, PyObject *args, PyObject *kwargs = NULL); + +private: + PyObject *createClass_(PyObject *pModule, const std::string &className); +}; + +// Instance +// +// Wraps an instance of a Python object +class Instance : public Ptr { +public: + // A constructor that takes an existing PyObject * (can be NULL too) + Instance(PyObject *p = NULL); + + // A constructor that instantiates an instance of the requested + // class from the requested module. 
Equivalent to the follwoing: // - // Wraps an instance of a Python object - class Instance : public Ptr - { - public: - // A constructor that takes an existing PyObject * (can be NULL too) - Instance(PyObject * p = NULL); - - // A constructor that instantiates an instance of the requested - // class from the requested module. Equivalent to the follwoing: - // - // from moduleName import className - // instance = className(*args, **kwargs) - Instance(const std::string & moduleName, - const std::string & className, - PyObject * args, - PyObject * kwargs = NULL); - - Instance(PyObject * pClass, - PyObject * args, - PyObject * kwargs = NULL); - - // Return true if the instance has attribute 'name' and false otherwise - bool hasAttr(std::string name); - - // Get an instance attribute. Equivalent to: - // - // return instance.name - PyObject * getAttr(std::string name) const; - - // Set an instance attribute. Equivalent to: - // - // return instance.name = value - void setAttr(std::string name, PyObject * value); - - // Return a string representation of an instance. Equivalent to: - // - // return str(instance) - PyObject * toString(); - - // Invoke an instance method. Equivalent to: - // - // return instance.method(*args, **kwargs) - PyObject * invoke(std::string method, - PyObject * args, - PyObject * kwargs = NULL) const; - private: - PyObject * createInstance_(PyObject * pClass, - PyObject * args, - PyObject * kwargs); - }; - - //// --- - //// Raise a Python RuntimeError exception from C++ with - //// an error message and an optional stack trace. The - //// stack trace will be added as a custom attribute of the - //// exception object. - //// --- - //void setPyError(const char * message, const char * stackTrace = NULL) - //{ - // // Create a RuntimeError Python exception object - // Tuple args(1); - // args.setItem(0, py::String(message)); - // Instance e(PyExc_RuntimeError, args); - - // // Add a new attribute "stackTrace" if available - // if (stackTrace) - // { - // e.setAttr("stackTrace", py::String(stackTrace)); - // } - - // // Set the Python error. Equivalent to: "raise e" - // PyErr_SetObject(PyExc_RuntimeError, e); - //} - - - -} } // end of nupic::py namespace + // from moduleName import className + // instance = className(*args, **kwargs) + Instance(const std::string &moduleName, const std::string &className, + PyObject *args, PyObject *kwargs = NULL); -#endif // NTA_PY_HELPERS_HPP + Instance(PyObject *pClass, PyObject *args, PyObject *kwargs = NULL); + + // Return true if the instance has attribute 'name' and false otherwise + bool hasAttr(std::string name); + + // Get an instance attribute. Equivalent to: + // + // return instance.name + PyObject *getAttr(std::string name) const; + // Set an instance attribute. Equivalent to: + // + // return instance.name = value + void setAttr(std::string name, PyObject *value); + + // Return a string representation of an instance. Equivalent to: + // + // return str(instance) + PyObject *toString(); + + // Invoke an instance method. Equivalent to: + // + // return instance.method(*args, **kwargs) + PyObject *invoke(std::string method, PyObject *args, + PyObject *kwargs = NULL) const; + +private: + PyObject *createInstance_(PyObject *pClass, PyObject *args, PyObject *kwargs); +}; + +//// --- +//// Raise a Python RuntimeError exception from C++ with +//// an error message and an optional stack trace. The +//// stack trace will be added as a custom attribute of the +//// exception object. 
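// A minimal usage sketch for py::Instance (the module, class and method names
// come from the Python 2 standard library and are used only for illustration):
//
//   py::Tuple ctorArgs(3);
//   ctorArgs.setItem(0, py::Int(2017));
//   ctorArgs.setItem(1, py::Int(1));
//   ctorArgs.setItem(2, py::Int(1));
//   py::Instance d("datetime", "date", ctorArgs);         // from datetime import date
//                                                         // d = date(2017, 1, 1)
//   py::String iso(d.invoke("isoformat", py::Tuple(0)));  // "2017-01-01" (new reference, owned by iso)
//   const char *text = iso;                               // borrowed C string, valid while iso is alive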
+//// --- +// void setPyError(const char * message, const char * stackTrace = NULL) +//{ +// // Create a RuntimeError Python exception object +// Tuple args(1); +// args.setItem(0, py::String(message)); +// Instance e(PyExc_RuntimeError, args); + +// // Add a new attribute "stackTrace" if available +// if (stackTrace) +// { +// e.setAttr("stackTrace", py::String(stackTrace)); +// } + +// // Set the Python error. Equivalent to: "raise e" +// PyErr_SetObject(PyExc_RuntimeError, e); +//} + +} // namespace py +} // namespace nupic + +#endif // NTA_PY_HELPERS_HPP diff --git a/src/nupic/py_support/PythonStream.cpp b/src/nupic/py_support/PythonStream.cpp index d10a49d4a5..00c952a686 100644 --- a/src/nupic/py_support/PythonStream.cpp +++ b/src/nupic/py_support/PythonStream.cpp @@ -27,35 +27,25 @@ * Bumps up size to a nicely aligned larger size. * Taken for NuPIC2 from PythonUtils.hpp */ -static size_t NextPythonSize(size_t n) -{ +static size_t NextPythonSize(size_t n) { n += 1; n += 8 - (n % 8); return n; } // ------------------------------------------------------------- -SharedPythonOStream::SharedPythonOStream(size_t maxSize) : - target_size_(NextPythonSize(maxSize)), - ss_(std::ios_base::out) -{ -} +SharedPythonOStream::SharedPythonOStream(size_t maxSize) + : target_size_(NextPythonSize(maxSize)), ss_(std::ios_base::out) {} // ------------------------------------------------------------- -std::ostream &SharedPythonOStream::getStream() -{ - return ss_; -} +std::ostream &SharedPythonOStream::getStream() { return ss_; } // ------------------------------------------------------------- -PyObject * SharedPythonOStream::close() -{ - ss_.flush(); +PyObject *SharedPythonOStream::close() { + ss_.flush(); - if (ss_.str().length() > target_size_) + if (ss_.str().length() > target_size_) throw std::runtime_error("Stream output larger than allocated buffer."); return PyString_FromStringAndSize(ss_.str().c_str(), ss_.str().length()); } - - diff --git a/src/nupic/py_support/PythonStream.hpp b/src/nupic/py_support/PythonStream.hpp index 5528054e0e..f4742a9a1b 100644 --- a/src/nupic/py_support/PythonStream.hpp +++ b/src/nupic/py_support/PythonStream.hpp @@ -23,8 +23,8 @@ * --------------------------------------------------------------------- */ -#include #include +#include #include /////////////////////////////////////////////////////////////////// @@ -37,30 +37,24 @@ /// @b Description /// After instantiation, a call to getStream() returns an ostream /// that collects the characters fed to it. Any subsequent call -/// to close() will return a PyObject * to a PyString that +/// to close() will return a PyObject * to a PyString that /// contains the current contents of the ostream. -/// +/// /// @note /// A close() before a getStream() will return an empty PyString. 
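/// A minimal usage sketch (the 1024-byte cap is an arbitrary value chosen
/// for illustration):
///
///   SharedPythonOStream out(1024);        // reserve room for up to ~1KB of text
///   out.getStream() << "score=" << 0.75;  // write through a normal std::ostream
///   PyObject *result = out.close();       // new PyString holding "score=0.75"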
-/// +/// /////////////////////////////////////////////////////////////////// -class SharedPythonOStream -{ +class SharedPythonOStream { public: SharedPythonOStream(size_t maxSize); std::ostream &getStream(); PyObject *close(); private: - size_t target_size_; - std::stringstream ss_; + size_t target_size_; + std::stringstream ss_; }; //------------------------------------------------------------------ #endif // NTA_PYTHON_STREAM_HPP - - - - - diff --git a/src/nupic/py_support/WrappedVector.hpp b/src/nupic/py_support/WrappedVector.hpp index 348e3fc58a..c8ebf7627d 100644 --- a/src/nupic/py_support/WrappedVector.hpp +++ b/src/nupic/py_support/WrappedVector.hpp @@ -23,29 +23,28 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file */ - + #include #include -#include +#include +#include #include #include -#include -#include +#include namespace nupic { -template -inline std::string tts(const T x) { +template inline std::string tts(const T x) { std::ostringstream s; s << x; return s.str(); } -template -class WrappedVectorIter //: public random_access_iterator +template +class WrappedVectorIter //: public random_access_iterator { public: int n_; @@ -53,59 +52,94 @@ class WrappedVectorIter //: public random_access_iterator T *p_; typedef T value_type; -// typedef int size_type; -// typedef ptrdiff_t difference_type; + // typedef int size_type; + // typedef ptrdiff_t difference_type; typedef std::random_access_iterator_tag iterator_category; typedef int difference_type; typedef T *pointer; -// typedef const T *const_pointer; + // typedef const T *const_pointer; typedef T &reference; -// typedef const T &const_reference; + // typedef const T &const_reference; WrappedVectorIter(int n, int incr, T *p) : n_(n), incr_(incr), p_(p) {} int size() const { return n_; } - T operator[](int i) const { return *(p_ + (i*incr_)); } - T &operator[](int i) { return *(p_ + (i*incr_)); } + T operator[](int i) const { return *(p_ + (i * incr_)); } + T &operator[](int i) { return *(p_ + (i * incr_)); } /// Slices without bounds-checking. 
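/// Example (illustrative values): with
///
///   nupic::Real d[4] = {10, 20, 30, 40};
///   WrappedVectorIter<nupic::Real> it(4, 1, d);
///
/// it.slice(1, 3) views {20, 30}, while it.slice(3, 1) views {40, 30}:
/// when j < i the stride flips sign and the view runs backwards from index i.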
WrappedVectorIter slice(int i, int j) const { - if(j >= i) return WrappedVectorIter(j - i, incr_, p_ + i*incr_); - else return WrappedVectorIter(i - j, -incr_, p_ + i*incr_); + if (j >= i) + return WrappedVectorIter(j - i, incr_, p_ + i * incr_); + else + return WrappedVectorIter(i - j, -incr_, p_ + i * incr_); } - WrappedVectorIter slice(int start, int /*stop*/, int step, int length) const { - return WrappedVectorIter(length, incr_*step, p_ + start*incr_); + WrappedVectorIter slice(int start, int /*stop*/, int step, + int length) const { + return WrappedVectorIter(length, incr_ * step, p_ + start * incr_); } WrappedVectorIter end() const { return slice(n_, n_); } WrappedVectorIter reversed() const { - return WrappedVectorIter(n_, -incr_, p_ + (n_-1)*incr_); + return WrappedVectorIter(n_, -incr_, p_ + (n_ - 1) * incr_); } /// Advance - WrappedVectorIter operator+(int n) const { return slice(n, n_-n); } - WrappedVectorIter operator+(size_t n) const { return slice(int(n), n_-int(n)); } + WrappedVectorIter operator+(int n) const { return slice(n, n_ - n); } + WrappedVectorIter operator+(size_t n) const { + return slice(int(n), n_ - int(n)); + } /// Advance in-place - WrappedVectorIter &operator+=(int n) { *this = slice(n, n_-n); return *this; } - WrappedVectorIter &operator+=(size_t n) { *this = slice(int(n), n_-int(n)); return *this; } + WrappedVectorIter &operator+=(int n) { + *this = slice(n, n_ - n); + return *this; + } + WrappedVectorIter &operator+=(size_t n) { + *this = slice(int(n), n_ - int(n)); + return *this; + } /// Prefix increment - WrappedVectorIter &operator++() { *this = slice(1, n_-1); return *this; } + WrappedVectorIter &operator++() { + *this = slice(1, n_ - 1); + return *this; + } /// Postfix increment - WrappedVectorIter operator++(int) { WrappedVectorIter a = *this; *this = slice(1, n_-1); return a; } + WrappedVectorIter operator++(int) { + WrappedVectorIter a = *this; + *this = slice(1, n_ - 1); + return a; + } /// Move back - WrappedVectorIter operator-(int n) const { return slice(-n, n_+n); } - WrappedVectorIter operator-(size_t n) const { return slice(-int(n), n_+int(n)); } + WrappedVectorIter operator-(int n) const { return slice(-n, n_ + n); } + WrappedVectorIter operator-(size_t n) const { + return slice(-int(n), n_ + int(n)); + } /// Move back in-place - WrappedVectorIter &operator-=(int n) { *this = slice(-n, n_+n); return *this; } - WrappedVectorIter &operator-=(size_t n) { *this = slice(-int(n), n_+int(n)); return *this; } + WrappedVectorIter &operator-=(int n) { + *this = slice(-n, n_ + n); + return *this; + } + WrappedVectorIter &operator-=(size_t n) { + *this = slice(-int(n), n_ + int(n)); + return *this; + } /// Prefix decrement - WrappedVectorIter &operator--() { *this = slice(-1, n_+1); return *this; } + WrappedVectorIter &operator--() { + *this = slice(-1, n_ + 1); + return *this; + } /// Postfix decrement - WrappedVectorIter operator--(int) { WrappedVectorIter a = slice(-1, n_+1); *this = slice(-1, n_+1); return a; } + WrappedVectorIter operator--(int) { + WrappedVectorIter a = slice(-1, n_ + 1); + *this = slice(-1, n_ + 1); + return a; + } /// Difference between two iterators - int operator-(const WrappedVectorIter &i) const { return int((p_ - i.p_) / incr_); } + int operator-(const WrappedVectorIter &i) const { + return int((p_ - i.p_) / incr_); + } /// Dereference T operator*() const { return *p_; } @@ -119,7 +153,7 @@ class WrappedVectorIter //: public random_access_iterator const T *operator->() const { return p_; } /// Access as pointer, 
non-const T *operator->() { return p_; } - + /// Not-equal comparison bool neq(const T *p) const { return p_ != p; } /// Equal comparison @@ -140,34 +174,45 @@ class WrappedVectorIter //: public random_access_iterator bool operator>(const WrappedVectorIter &x) const { return ge(x.p_); } bool operator>=(const WrappedVectorIter &x) const { return geq(x.p_); } - template - void copyFrom(int dim, int incr, const T2 *in) { + template void copyFrom(int dim, int incr, const T2 *in) { const int n = n_, inc1 = incr_, inc2 = incr; T *p1 = p_; const T2 *p2 = in; - for(int i=0; i - void into(int dim, int incr, T2 *out) const { + template void into(int dim, int incr, T2 *out) const { const int n = n_, inc1 = incr_, inc2 = incr; const T *p1 = p_; T2 *p2 = out; - for(int i=0; i p_; nupic::Real *own_; - void free() { if(own_) { delete own_; own_ = 0; } } + void free() { + if (own_) { + delete own_; + own_ = 0; + } + } public: /// Extremely dangerous constructor! Only use for unit testing! @@ -175,25 +220,23 @@ class WrappedVector { public: void checkIndex(int i) const { - if(!((i >= 0) && (i < p_.n_))) + if (!((i >= 0) && (i < p_.n_))) throw std::invalid_argument("Index " + tts(i) + " out of bounds."); } void checkBeginEnd(int begin, int end) const { - if(end > begin) { - if(!(begin >= 0)) + if (end > begin) { + if (!(begin >= 0)) throw std::invalid_argument("Begin " + tts(begin) + " out of bounds."); - if(!(end <= p_.n_)) + if (!(end <= p_.n_)) throw std::invalid_argument("End " + tts(end) + " out of bounds."); - } - else if(end == begin) { - if(!((begin >= 0) && (end <= p_.n_))) + } else if (end == begin) { + if (!((begin >= 0) && (end <= p_.n_))) throw std::invalid_argument("Out of bounds."); - } - else if(end < begin) { - if(!(end >= (-1))) + } else if (end < begin) { + if (!(end >= (-1))) throw std::invalid_argument("End " + tts(end) + " out of bounds."); - if(!(begin < p_.n_)) + if (!(begin < p_.n_)) throw std::invalid_argument("Begin " + tts(begin) + " out of bounds."); } } @@ -204,19 +247,41 @@ class WrappedVector { WrappedVector() : p_(0, 1, 0), own_(0) {} WrappedVector(const WrappedVectorIter &p) : p_(p), own_(0) {} WrappedVector(int size, nupic::Real *p) : p_(size, 1, p), own_(0) {} - -// WrappedVector(const nupic::Belief &b) : p_(b.size(), 1, const_cast(b.ptr())), own_(0) {} - WrappedVector(const std::vector &v) : p_(int(v.size()), 1, const_cast(&(v[0]))), own_(0) {} + + // WrappedVector(const nupic::Belief &b) : p_(b.size(), 1, + // const_cast(b.ptr())), own_(0) {} + WrappedVector(const std::vector &v) + : p_(int(v.size()), 1, const_cast(&(v[0]))), own_(0) {} WrappedVector(const WrappedVector &v) : p_(v.p_), own_(0) {} - WrappedVector &operator=(const WrappedVector &v) { this->free(); p_ = v.p_; own_ = 0; return *this; } - + WrappedVector &operator=(const WrappedVector &v) { + this->free(); + p_ = v.p_; + own_ = 0; + return *this; + } + ~WrappedVector() { this->free(); } - - WrappedVector wvector(size_t lag=0) const { return *this; } - void clear() { this->free(); p_.n_ = 0; p_.incr_ = 1; p_.p_ = 0; } - void setPointer(int n, int incr, nupic::Real *p) { this->free(); p_.n_ = n; p_.incr_ = incr; p_.p_ = p; } - void setPointer(int n, nupic::Real *p) { this->free(); p_.n_ = n; p_.incr_ = 1; p_.p_ = p; } + WrappedVector wvector(size_t lag = 0) const { return *this; } + + void clear() { + this->free(); + p_.n_ = 0; + p_.incr_ = 1; + p_.p_ = 0; + } + void setPointer(int n, int incr, nupic::Real *p) { + this->free(); + p_.n_ = n; + p_.incr_ = incr; + p_.p_ = p; + } + void setPointer(int n, 
nupic::Real *p) { + this->free(); + p_.n_ = n; + p_.incr_ = 1; + p_.p_ = p; + } // Returns the beginning address of the underlying // data buffer as an integer. @@ -227,44 +292,44 @@ class WrappedVector { nupic::Size __len__() const { return p_.size(); } - template - void adjust(T &endPoint) const { - T n = (T) p_.size(); - if(endPoint < 0) endPoint += n; - else if(endPoint > n) endPoint = n; + template void adjust(T &endPoint) const { + T n = (T)p_.size(); + if (endPoint < 0) + endPoint += n; + else if (endPoint > n) + endPoint = n; } nupic::Real __getitem__(int i) const { - adjust(i); - checkIndex(i); - return p_[i]; + adjust(i); + checkIndex(i); + return p_[i]; } void __setitem__(int i, nupic::Real x) { adjust(i); - checkIndex(i); - p_[i] = x; + checkIndex(i); + p_[i] = x; } std::string __repr__() const { std::ostringstream s; s << "["; const nupic::Real *p = p_.p_; - int nm1 = p_.n_-1; - for(int i=0; i - void copyFromT(int n, int incr, const T2 *p) { - if(!(p_.n_ == n)) - throw std::invalid_argument("Sizes must match: " + - tts(p_.n_) + " " + tts(n)); + template void copyFromT(int n, int incr, const T2 *p) { + if (!(p_.n_ == n)) + throw std::invalid_argument("Sizes must match: " + tts(p_.n_) + " " + + tts(n)); p_.copyFrom(n, incr, p); } - template - void copyIntoT(int n, int incr, T2 *p) const { + template void copyIntoT(int n, int incr, T2 *p) const { p_.into(n, incr, p); } void setSlice(int i, int j, const WrappedVector &v) { checkBeginEnd(i, j); int n = v.p_.n_; - if(!(n == abs(j-i))) throw std::invalid_argument("Sizes must match."); + if (!(n == abs(j - i))) + throw std::invalid_argument("Sizes must match."); p_.slice(i, j).copyFrom(n, v.p_.incr_, v.p_.p_); } @@ -346,19 +419,27 @@ class WrappedVector { void fill(nupic::Real x) { const int n = p_.n_, inc1 = p_.incr_; nupic::Real *p1 = p_.p_; - for(int i=0; i= 1)) + if (!(n >= 1)) throw std::runtime_error("Cannot call argmax on a 0-length vector."); const nupic::Real *p1 = p_.p_; int mi = 0; - nupic::Real mv = *p1; p1 += inc1; - for(int i=1; i mv) { mi = i; mv = x; } + nupic::Real mv = *p1; + p1 += inc1; + for (int i = 1; i < n; ++i) { + nupic::Real x = *p1; + p1 += inc1; + if (x > mv) { + mi = i; + mv = x; + } } return mi; } @@ -367,8 +448,9 @@ class WrappedVector { const int n = p_.n_, inc1 = p_.incr_; const nupic::Real *p1 = p_.p_; nupic::Real sum = 0; - for(int i=0; i #include +#include // std::memcpy #include -#include #include -#include // std::memcpy +#include #include -#include -#include #include #include -#include -#include // IWrite/ReadBuffer -#include +#include +#include #include #include -#include #include -#include +#include // IWrite/ReadBuffer +#include #include +#include #include #include #include +#include +#include using namespace nupic; @@ -53,181 +53,147 @@ using namespace nupic; static char lastError[LAST_ERROR_LENGTH]; static bool finalizePython; -extern "C" -{ - // NTA_initPython() must be called by the MultinodeFactory before any call to - // NTA_createPyNode() - void PyRegion::NTA_initPython() - { - if(Py_IsInitialized()) - { - // Set the PyHelpers flag so it knows its running under Python. - // This is necessary for PyHelpers to determine if it should - // clear or restore Python exceptions (see NPC-113) - py::setRunningUnderPython(); - finalizePython = true; - } - } - - // NTA_finalizePython() must be called before unloading the pynode dynamic library - // to ensure proper cleanup. 
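// A minimal sketch of the intended call order for the extern "C" factory
// hooks above; the real caller is the MultinodeFactory, and the wrapper name
// and argument values here are placeholders.
static void pyNodeLifecycleSketch(const char *module, void *nodeParams,
                                  void *region) {
  void *exception = nullptr;
  nupic::PyRegion::NTA_initPython();  // before any NTA_createPyNode() call
  void *node = nupic::PyRegion::NTA_createPyNode(module, nodeParams, region,
                                                 &exception, "");
  (void)node;  // ... run the node here ...
  // Tear down before unloading the pynode dynamic library:
  nupic::PyRegion::NTA_finalizePython();
}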
- void PyRegion::NTA_finalizePython() - { - if (finalizePython) - { - //NTA_DEBUG << "Called Py_Finalize()"; - Py_Finalize(); - } +extern "C" { +// NTA_initPython() must be called by the MultinodeFactory before any call to +// NTA_createPyNode() +void PyRegion::NTA_initPython() { + if (Py_IsInitialized()) { + // Set the PyHelpers flag so it knows its running under Python. + // This is necessary for PyHelpers to determine if it should + // clear or restore Python exceptions (see NPC-113) + py::setRunningUnderPython(); + finalizePython = true; } +} - // createPyNode is used by the MultinodeFactory to create a C++ PyNode instance - // That references a Python instance. The function tries to create a NuPIC 2.0 - // Py node first and if it fails it tries to create a NuPIC 1.x Py node - void * PyRegion::NTA_createPyNode(const char * module, void * nodeParams, - void * region, void ** exception, const char* className) - { - try - { - NTA_CHECK(nodeParams != NULL); - NTA_CHECK(region != NULL); - - ValueMap * valueMap = static_cast(nodeParams); - Region * r = static_cast(region); - RegionImpl * p = NULL; - p = new nupic::PyRegion(module, *valueMap, r, className); - - return p; - } - catch (nupic::Exception & e) - { - *exception = new nupic::Exception(e); - return NULL; - } - catch (...) - { - return NULL; - } +// NTA_finalizePython() must be called before unloading the pynode dynamic +// library to ensure proper cleanup. +void PyRegion::NTA_finalizePython() { + if (finalizePython) { + // NTA_DEBUG << "Called Py_Finalize()"; + Py_Finalize(); } +} - // deserializePyNode is used by the MultinodeFactory to create a C++ PyNode instance - // that references a Python instance which has been deserialized from saved state - void * PyRegion::NTA_deserializePyNode(const char * module, void * bundle, - void * region, void ** exception, const char* className) - { - try - { - NTA_CHECK(region != NULL); - - Region * r = static_cast(region); - BundleIO *b = static_cast(bundle); - RegionImpl * p = NULL; - p = new PyRegion(module, *b, r, className); - return p; - } - catch (nupic::Exception & e) - { - *exception = new nupic::Exception(e); - return NULL; - } - catch (...) - { - return NULL; - } +// createPyNode is used by the MultinodeFactory to create a C++ PyNode instance +// That references a Python instance. The function tries to create a NuPIC 2.0 +// Py node first and if it fails it tries to create a NuPIC 1.x Py node +void *PyRegion::NTA_createPyNode(const char *module, void *nodeParams, + void *region, void **exception, + const char *className) { + try { + NTA_CHECK(nodeParams != NULL); + NTA_CHECK(region != NULL); + + ValueMap *valueMap = static_cast(nodeParams); + Region *r = static_cast(region); + RegionImpl *p = NULL; + p = new nupic::PyRegion(module, *valueMap, r, className); + + return p; + } catch (nupic::Exception &e) { + *exception = new nupic::Exception(e); + return NULL; + } catch (...) { + return NULL; } +} - void * PyRegion::NTA_deserializePyNodeProto(const char * module, void * proto, - void * region, void ** exception, const char* className) - { - try - { - NTA_CHECK(region != NULL); - - Region * r = static_cast(region); - capnp::AnyPointer::Reader *c = static_cast(proto); - RegionImpl * p = NULL; - p = new PyRegion(module, *c, r, className); - return p; - } - catch (nupic::Exception & e) - { - *exception = new nupic::Exception(e); - return NULL; - } - catch (...) 
- { - return NULL; - } +// deserializePyNode is used by the MultinodeFactory to create a C++ PyNode +// instance that references a Python instance which has been deserialized from +// saved state +void *PyRegion::NTA_deserializePyNode(const char *module, void *bundle, + void *region, void **exception, + const char *className) { + try { + NTA_CHECK(region != NULL); + + Region *r = static_cast(region); + BundleIO *b = static_cast(bundle); + RegionImpl *p = NULL; + p = new PyRegion(module, *b, r, className); + return p; + } catch (nupic::Exception &e) { + *exception = new nupic::Exception(e); + return NULL; + } catch (...) { + return NULL; } +} - // getLastError() returns the last error message - const char * PyRegion::NTA_getLastError() - { - return lastError; +void *PyRegion::NTA_deserializePyNodeProto(const char *module, void *proto, + void *region, void **exception, + const char *className) { + try { + NTA_CHECK(region != NULL); + + Region *r = static_cast(region); + capnp::AnyPointer::Reader *c = + static_cast(proto); + RegionImpl *p = NULL; + p = new PyRegion(module, *c, r, className); + return p; + } catch (nupic::Exception &e) { + *exception = new nupic::Exception(e); + return NULL; + } catch (...) { + return NULL; } +} - // createSpec is used by the RegionImplFactory to get the node spec - // and cache it. It is a static function so there is no need to instantiate - // a dummy node, just to get its node spec. - void * PyRegion::NTA_createSpec(const char * nodeType, void ** exception, const char* className) - { - try - { - return PyRegion::createSpec(nodeType, className); - } - catch (nupic::Exception & e) - { - NTA_WARN << "PyRegion::createSpec failed: " << exception; - - *exception = new nupic::Exception(e); - return NULL; - } - catch (...) - { - return NULL; - } +// getLastError() returns the last error message +const char *PyRegion::NTA_getLastError() { return lastError; } + +// createSpec is used by the RegionImplFactory to get the node spec +// and cache it. It is a static function so there is no need to instantiate +// a dummy node, just to get its node spec. +void *PyRegion::NTA_createSpec(const char *nodeType, void **exception, + const char *className) { + try { + return PyRegion::createSpec(nodeType, className); + } catch (nupic::Exception &e) { + NTA_WARN << "PyRegion::createSpec failed: " << exception; + + *exception = new nupic::Exception(e); + return NULL; + } catch (...) { + return NULL; } +} - // destroySpec is used by the RegionImplFactory to destroy - // a cached node spec. - int PyRegion::NTA_destroySpec(const char * nodeType, const char* className) - { - try - { - PyRegion::destroySpec(nodeType, className); - return 0; - } - catch (...) - { - return -1; - } +// destroySpec is used by the RegionImplFactory to destroy +// a cached node spec. +int PyRegion::NTA_destroySpec(const char *nodeType, const char *className) { + try { + PyRegion::destroySpec(nodeType, className); + return 0; + } catch (...) { + return -1; } } +} // This map stores the node specs for all the Python nodes std::map PyRegion::specs_; - // // Get the node spec from the underlying Python node // and populate a dynamically C++ node spec object. // Return the node spec pointer (that will be owned // by RegionImplFactory. // -Spec * PyRegion::createSpec(const char * nodeType, const char* className) -{ +Spec *PyRegion::createSpec(const char *nodeType, const char *className) { // If the node spec for a node type is requested more than once // return the exisiting one from the map. 
std::string name(nodeType); std::string realClassName(className); name = name + "."; - if (!realClassName.empty()) - { + if (!realClassName.empty()) { name = name + realClassName; } - if (specs_.find(name) != specs_.end()) - { - Spec & ns = specs_[name]; + if (specs_.find(name) != specs_.end()) { + Spec &ns = specs_[name]; return &ns; } @@ -235,101 +201,87 @@ Spec * PyRegion::createSpec(const char * nodeType, const char* className) createSpec(nodeType, ns, className); specs_[name] = ns; - //NTA_DEBUG << "node type: " << nodeType << std::endl; - //NTA_DEBUG << specs_[name].toString() << std::endl; + // NTA_DEBUG << "node type: " << nodeType << std::endl; + // NTA_DEBUG << specs_[name].toString() << std::endl; return &specs_[name]; } -void PyRegion::destroySpec(const char * nodeType, const char* className) -{ +void PyRegion::destroySpec(const char *nodeType, const char *className) { std::string name(nodeType); std::string realClassName(className); name = name + "."; - if (!realClassName.empty()) - { + if (!realClassName.empty()) { name = name + realClassName; } specs_.erase(name); } -namespace nupic -{ +namespace nupic { class RegionImpl; -static PyObject * makePyValue(const Value & v) -{ +static PyObject *makePyValue(const Value &v) { if (v.isArray()) return array2numpy(*(v.getArray().get())); - if (v.isString()) - { + if (v.isString()) { return py::String(*(v.getString().get())).release(); } - switch (v.getType()) - { - case NTA_BasicType_Byte: - NTA_THROW << "Scalar parameters of type Byte are not supported"; - break; - case NTA_BasicType_Int16: - return py::Long(v.getScalarT()).release(); - case NTA_BasicType_Int32: - return py::Long(v.getScalarT()).release(); - case NTA_BasicType_Int64: - return py::LongLong(v.getScalarT()).release(); - case NTA_BasicType_UInt16: - return py::UnsignedLong(v.getScalarT()).release(); - case NTA_BasicType_UInt32: - return py::UnsignedLong(v.getScalarT()).release(); - case NTA_BasicType_UInt64: - return py::UnsignedLongLong(v.getScalarT()).release(); - case NTA_BasicType_Real32: - { - std::stringstream ss; - ss << v.getScalarT(); - return py::Float(ss.str().c_str()).release(); - } - case NTA_BasicType_Real64: - return py::Float(v.getScalarT()).release(); - case NTA_BasicType_Bool: - return py::Bool(v.getScalarT()).release(); - case NTA_BasicType_Handle: - return (PyObject *)(v.getScalarT()); - default: - NTA_THROW << "Invalid type: " << v.getType(); + switch (v.getType()) { + case NTA_BasicType_Byte: + NTA_THROW << "Scalar parameters of type Byte are not supported"; + break; + case NTA_BasicType_Int16: + return py::Long(v.getScalarT()).release(); + case NTA_BasicType_Int32: + return py::Long(v.getScalarT()).release(); + case NTA_BasicType_Int64: + return py::LongLong(v.getScalarT()).release(); + case NTA_BasicType_UInt16: + return py::UnsignedLong(v.getScalarT()).release(); + case NTA_BasicType_UInt32: + return py::UnsignedLong(v.getScalarT()).release(); + case NTA_BasicType_UInt64: + return py::UnsignedLongLong(v.getScalarT()).release(); + case NTA_BasicType_Real32: { + std::stringstream ss; + ss << v.getScalarT(); + return py::Float(ss.str().c_str()).release(); + } + case NTA_BasicType_Real64: + return py::Float(v.getScalarT()).release(); + case NTA_BasicType_Bool: + return py::Bool(v.getScalarT()).release(); + case NTA_BasicType_Handle: + return (PyObject *)(v.getScalarT()); + default: + NTA_THROW << "Invalid type: " << v.getType(); } } -static void prepareCreationParams(const ValueMap & vm, py::Dict & d) -{ +static void prepareCreationParams(const 
ValueMap &vm, py::Dict &d) { ValueMap::const_iterator it; - for (it = vm.begin(); it != vm.end(); ++it) - { - try - { + for (it = vm.begin(); it != vm.end(); ++it) { + try { py::Ptr v(makePyValue(*(it->second))); d.setItem(it->first, v); - } catch (Exception& e) { + } catch (Exception &e) { NTA_THROW << "Unable to create a Python object for parameter '" << it->first << ": " << e.what(); } } }; -PyRegion::PyRegion(const char * module, const ValueMap & nodeParams, Region * - region, const char* className) : - RegionImpl(region), - module_(module), - className_(className) -{ +PyRegion::PyRegion(const char *module, const ValueMap &nodeParams, + Region *region, const char *className) + : RegionImpl(region), module_(module), className_(className) { NTA_CHECK(region != NULL); std::string realClassName(className); - if (realClassName.empty()) - { + if (realClassName.empty()) { realClassName = Path::getExtension(module_); } @@ -343,44 +295,32 @@ PyRegion::PyRegion(const char * module, const ValueMap & nodeParams, Region * NTA_CHECK(node_); } -PyRegion::PyRegion(const char* module, BundleIO& bundle, Region * region, const - char* className) : - RegionImpl(region), - module_(module), - className_(className) +PyRegion::PyRegion(const char *module, BundleIO &bundle, Region *region, + const char *className) + : RegionImpl(region), module_(module), className_(className) { deserialize(bundle); // XXX ADD CHECK TO MAKE SURE THE TYPE MATCHES! } -PyRegion::PyRegion(const char * module, - capnp::AnyPointer::Reader& proto, - Region * region, - const char* className): - RegionImpl(region), - module_(module), - className_(className) -{ +PyRegion::PyRegion(const char *module, capnp::AnyPointer::Reader &proto, + Region *region, const char *className) + : RegionImpl(region), module_(module), className_(className) { NTA_CHECK(region != NULL); read(proto); } -PyRegion::~PyRegion() -{ - for (std::map::iterator i = inputArrays_.begin(); - i != inputArrays_.end(); - i++) - { +PyRegion::~PyRegion() { + for (std::map::iterator i = inputArrays_.begin(); + i != inputArrays_.end(); i++) { delete i->second; i->second = NULL; } - } -void PyRegion::serialize(BundleIO& bundle) -{ +void PyRegion::serialize(BundleIO &bundle) { // 1. serialize main state using pickle // 2. call class method to serialize external state @@ -415,16 +355,12 @@ void PyRegion::serialize(BundleIO& bundle) // Need to put the None result in py::Ptr to decrement the ref count py::Ptr none1(node_.invoke("serializeExtraData", args1)); - - } -void PyRegion::deserialize(BundleIO& bundle) -{ +void PyRegion::deserialize(BundleIO &bundle) { // 1. deserialize main state using pickle // 2. call class method to deserialize external state - // 1. de-serialize main state using pickle // f = open(path, "rb") # binary mode needed on windows py::Tuple args(2); @@ -453,21 +389,17 @@ void PyRegion::deserialize(BundleIO& bundle) // Need to put the None result in py::Ptr to decrement the ref count py::Ptr none1(node_.invoke("deSerializeExtraData", args1)); - } -void PyRegion::write(capnp::AnyPointer::Builder& proto) const -{ +void PyRegion::write(capnp::AnyPointer::Builder &proto) const { #if !CAPNP_LITE - class Helper - { + class Helper { public: /** * NOTE: We wrap several operations in this method to reduce the number of * region data copies (could be huge memory size) that coexist on the heap. 
*/ - static kj::Array serialize(const py::Instance& node) - { + static kj::Array serialize(const py::Instance &node) { // Request python object to write itself out and return PyRegionProto // serialized as a python byte array py::Class pyCapnpHelperCls("nupic.bindings.engine_internal", @@ -476,15 +408,15 @@ void PyRegion::write(capnp::AnyPointer::Builder& proto) const py::Dict kwargs; // NOTE py::Dict::setItem doesn't accept a const PyObject*, however we // know that we won't modify it, so casting trickery is okay here - kwargs.setItem("region", - const_cast(static_cast(node))); + kwargs.setItem("region", const_cast( + static_cast(node))); kwargs.setItem("methodName", py::String("write")); // Wrap result in py::Ptr to force dereferencing when going out of scope py::Ptr pyRegionProtoBytes( - pyCapnpHelperCls.invoke("writePyRegion", args, kwargs)); + pyCapnpHelperCls.invoke("writePyRegion", args, kwargs)); - char * srcBytes = nullptr; + char *srcBytes = nullptr; Py_ssize_t srcNumBytes = 0; // NOTE: srcBytes will be set to point to the internal buffer inside // pyRegionProtoBytes @@ -515,25 +447,21 @@ void PyRegion::write(capnp::AnyPointer::Builder& proto) const #endif } -void PyRegion::read(capnp::AnyPointer::Reader& proto) -{ +void PyRegion::read(capnp::AnyPointer::Reader &proto) { #if !CAPNP_LITE - class Helper - { + class Helper { public: - static kj::Array flatArrayFromReader( - PyRegionProto::Reader& reader) - { + static kj::Array + flatArrayFromReader(PyRegionProto::Reader &reader) { // NOTE: this requires conversion to builder, which incurs an additional // copy, because readers aren't supported by capnp::messageToFlatArray capnp::MallocMessageBuilder builder; - builder.setRoot(reader); // copy + builder.setRoot(reader); // copy return capnp::messageToFlatArray(builder); // copy } - static PyObject* pyBytesFromFlatArray(kj::Array flatArray) - { + static PyObject *pyBytesFromFlatArray(kj::Array flatArray) { kj::ArrayPtr byteArray = flatArray.asBytes(); // Copy from array to PyObject so that we can pass it to the Python layer py::String pyRegionBytes((const char *)byteArray.begin(), @@ -551,14 +479,12 @@ void PyRegion::read(capnp::AnyPointer::Reader& proto) // NOTE: the intention of this nested call is to reduce the number of // sumulatneously present copies of data to no more than two at any given // moment. 
- py::String pyRegionBytes( - Helper::pyBytesFromFlatArray( + py::String pyRegionBytes(Helper::pyBytesFromFlatArray( Helper::flatArrayFromReader(pyRegionReader))); // Construct the python region instance by thunking into python std::string realClassName(className_); - if (realClassName.empty()) - { + if (realClassName.empty()) { realClassName = Path::getExtension(module_); } @@ -573,7 +499,8 @@ void PyRegion::read(capnp::AnyPointer::Reader& proto) kwargs.setItem("methodName", py::String("read")); // Deserialize the data into a new python region and assign it to node_ - py::Class pyCapnpHelperCls("nupic.bindings.engine_internal", "_PyCapnpHelper"); + py::Class pyCapnpHelperCls("nupic.bindings.engine_internal", + "_PyCapnpHelper"); // NOTE: wrap result in py::Ptr so that unnecessary refcount will be // decremented upon it going out of scope and after node_.assign increments it py::Ptr pyRegionImpl(pyCapnpHelperCls.invoke("readPyRegion", args, kwargs)); @@ -586,8 +513,7 @@ void PyRegion::read(capnp::AnyPointer::Reader& proto) #endif } -const Spec & PyRegion::getSpec() -{ +const Spec &PyRegion::getSpec() { return *(PyRegion::createSpec(module_.c_str(), className_.c_str())); } @@ -597,7 +523,7 @@ const Spec & PyRegion::getSpec() //// Return the node spec pointer (that will be owned //// by RegionImplFactory. //// -//void PyRegion::createSpec(const char * nodeType, Spec & ns) +// void PyRegion::createSpec(const char * nodeType, Spec & ns) //{ // // Get the Python class object // std::string className = Path::getExtension(nodeType); @@ -642,8 +568,8 @@ const Spec & PyRegion::getSpec() // bool required = py::Int(input.getItem("required")) != 0; // bool regionLevel = py::Int(input.getItem("regionLevel")) != 0; // bool isDefaultInput = py::Int(input.getItem("isDefaultInput")) != 0; -// bool requireSplitterMap = py::Int(input.getItem("requireSplitterMap")) != 0; -// ns.inputs.add( +// bool requireSplitterMap = py::Int(input.getItem("requireSplitterMap")) != +// 0; ns.inputs.add( // name, // InputSpec( // description, @@ -732,8 +658,7 @@ const Spec & PyRegion::getSpec() //} template -T PyRegion::getParameterT(const std::string & name, Int64 index) -{ +T PyRegion::getParameterT(const std::string &name, Int64 index) { py::Tuple args(2); args.setItem(0, py::String(name)); args.setItem(1, py::LongLong(index)); @@ -743,8 +668,7 @@ T PyRegion::getParameterT(const std::string & name, Int64 index) } template -void PyRegion::setParameterT(const std::string & name, Int64 index, T value) -{ +void PyRegion::setParameterT(const std::string &name, Int64 index, T value) { py::Tuple args(3); args.setItem(0, py::String(name)); args.setItem(1, py::LongLong(index)); @@ -753,47 +677,38 @@ void PyRegion::setParameterT(const std::string & name, Int64 index, T value) py::Ptr none(node_.invoke("setParameter", args)); } -Byte PyRegion::getParameterByte(const std::string& name, Int64 index) -{ +Byte PyRegion::getParameterByte(const std::string &name, Int64 index) { return getParameterT(name, index); } -Int32 PyRegion::getParameterInt32(const std::string& name, Int64 index) -{ - //return getParameterT(name, index); +Int32 PyRegion::getParameterInt32(const std::string &name, Int64 index) { + // return getParameterT(name, index); return getParameterT(name, index); } -UInt32 PyRegion::getParameterUInt32(const std::string& name, Int64 index) -{ +UInt32 PyRegion::getParameterUInt32(const std::string &name, Int64 index) { return getParameterT(name, index); } -Int64 PyRegion::getParameterInt64(const std::string& name, Int64 index) -{ 
+Int64 PyRegion::getParameterInt64(const std::string &name, Int64 index) { return getParameterT(name, index); } -UInt64 PyRegion::getParameterUInt64(const std::string& name, Int64 index) -{ +UInt64 PyRegion::getParameterUInt64(const std::string &name, Int64 index) { return getParameterT(name, index); } -Real32 PyRegion::getParameterReal32(const std::string& name, Int64 index) -{ +Real32 PyRegion::getParameterReal32(const std::string &name, Int64 index) { return getParameterT(name, index); } -Real64 PyRegion::getParameterReal64(const std::string& name, Int64 index) -{ +Real64 PyRegion::getParameterReal64(const std::string &name, Int64 index) { return getParameterT(name, index); } -Handle PyRegion::getParameterHandle(const std::string& name, Int64 index) -{ - if (name == std::string("self")) - { - PyObject * o = (PyObject *)node_; +Handle PyRegion::getParameterHandle(const std::string &name, Int64 index) { + if (name == std::string("self")) { + PyObject *o = (PyObject *)node_; Py_INCREF(o); return o; } @@ -801,58 +716,57 @@ Handle PyRegion::getParameterHandle(const std::string& name, Int64 index) return getParameterT(name, index); } -bool PyRegion::getParameterBool(const std::string& name, Int64 index) -{ +bool PyRegion::getParameterBool(const std::string &name, Int64 index) { return getParameterT(name, index); } -void PyRegion::setParameterByte(const std::string& name, Int64 index, Byte value) -{ +void PyRegion::setParameterByte(const std::string &name, Int64 index, + Byte value) { setParameterT(name, index, value); } -void PyRegion::setParameterInt32(const std::string& name, Int64 index, Int32 value) -{ +void PyRegion::setParameterInt32(const std::string &name, Int64 index, + Int32 value) { setParameterT(name, index, value); } -void PyRegion::setParameterUInt32(const std::string& name, Int64 index, UInt32 value) -{ +void PyRegion::setParameterUInt32(const std::string &name, Int64 index, + UInt32 value) { setParameterT(name, index, value); } -void PyRegion::setParameterInt64(const std::string& name, Int64 index, Int64 value) -{ +void PyRegion::setParameterInt64(const std::string &name, Int64 index, + Int64 value) { setParameterT(name, index, value); } -void PyRegion::setParameterUInt64(const std::string& name, Int64 index, UInt64 value) -{ +void PyRegion::setParameterUInt64(const std::string &name, Int64 index, + UInt64 value) { setParameterT(name, index, value); } -void PyRegion::setParameterReal32(const std::string& name, Int64 index, Real32 value) -{ +void PyRegion::setParameterReal32(const std::string &name, Int64 index, + Real32 value) { setParameterT(name, index, value); } -void PyRegion::setParameterReal64(const std::string& name, Int64 index, Real64 value) -{ +void PyRegion::setParameterReal64(const std::string &name, Int64 index, + Real64 value) { setParameterT(name, index, value); } -void PyRegion::setParameterHandle(const std::string& name, Int64 index, Handle value) -{ +void PyRegion::setParameterHandle(const std::string &name, Int64 index, + Handle value) { setParameterT(name, index, (PyObject *)value); } -void PyRegion::setParameterBool(const std::string& name, Int64 index, bool value) -{ +void PyRegion::setParameterBool(const std::string &name, Int64 index, + bool value) { setParameterT(name, index, value); } -void PyRegion::getParameterArray(const std::string& name, Int64 index, Array & a) -{ +void PyRegion::getParameterArray(const std::string &name, Int64 index, + Array &a) { py::Tuple args(3); args.setItem(0, py::String(name)); args.setItem(1, py::LongLong(index)); @@ 
-862,8 +776,8 @@ void PyRegion::getParameterArray(const std::string& name, Int64 index, Array & a py::Ptr none(node_.invoke("getParameterArray", args)); } -void PyRegion::setParameterArray(const std::string& name, Int64 index, const Array & a) -{ +void PyRegion::setParameterArray(const std::string &name, Int64 index, + const Array &a) { py::Tuple args(3); args.setItem(0, py::String(name)); args.setItem(1, py::LongLong(index)); @@ -873,18 +787,17 @@ void PyRegion::setParameterArray(const std::string& name, Int64 index, const Arr py::Ptr none(node_.invoke("setParameterArray", args)); } -std::string PyRegion::getParameterString(const std::string& name, Int64 index) -{ +std::string PyRegion::getParameterString(const std::string &name, Int64 index) { py::Tuple args(2); args.setItem(0, py::String(name)); args.setItem(1, py::LongLong(index)); py::String result(node_.invoke("getParameter", args)); - return (const char*)result; + return (const char *)result; } -void PyRegion::setParameterString(const std::string& name, Int64 index, const std::string& value) -{ +void PyRegion::setParameterString(const std::string &name, Int64 index, + const std::string &value) { py::Tuple args(3); args.setItem(0, py::String(name)); args.setItem(1, py::LongLong(index)); @@ -892,27 +805,21 @@ void PyRegion::setParameterString(const std::string& name, Int64 index, const st py::Ptr none(node_.invoke("setParameter", args)); } -void PyRegion::getParameterFromBuffer(const std::string& name, - Int64 index, - IWriteBuffer& value) -{ +void PyRegion::getParameterFromBuffer(const std::string &name, Int64 index, + IWriteBuffer &value) { // we override getParameterX for every type, so this should never // be called NTA_THROW << "::getParameterFromBuffer should not have been called"; } - -void PyRegion::setParameterFromBuffer(const std::string& name, - Int64 index, - IReadBuffer& value) -{ +void PyRegion::setParameterFromBuffer(const std::string &name, Int64 index, + IReadBuffer &value) { // we override getParameterX for every type, so this should never // be called NTA_THROW << "::setParameterFromBuffer should not have been called"; } -size_t PyRegion::getParameterArrayCount(const std::string& name, Int64 index) -{ +size_t PyRegion::getParameterArrayCount(const std::string &name, Int64 index) { py::Tuple args(2); args.setItem(0, py::String(name)); args.setItem(1, py::LongLong(index)); @@ -922,8 +829,7 @@ size_t PyRegion::getParameterArrayCount(const std::string& name, Int64 index) return size_t(result); } -size_t PyRegion::getNodeOutputElementCount(const std::string& outputName) -{ +size_t PyRegion::getNodeOutputElementCount(const std::string &outputName) { py::Tuple args(1); args.setItem(0, py::String(outputName)); @@ -932,14 +838,13 @@ size_t PyRegion::getNodeOutputElementCount(const std::string& outputName) return size_t(result); } -std::string PyRegion::executeCommand(const std::vector& args, Int64 index) -{ +std::string PyRegion::executeCommand(const std::vector &args, + Int64 index) { py::String cmd(args[0]); py::Tuple t(args.size() - 1); - for (size_t i = 1; i < args.size(); ++i) - { + for (size_t i = 1; i < args.size(); ++i) { py::String s(args[i]); - t.setItem(i-1, s); + t.setItem(i - 1, s); } py::Tuple commandArgs(2); @@ -949,29 +854,26 @@ std::string PyRegion::executeCommand(const std::vector& args, Int64 py::Instance res(node_.invoke("executeMethod", commandArgs)); py::String s(res.invoke("__str__", py::Tuple())); - const char * ss = (const char *)s; + const char *ss = (const char *)s; std::string result(ss); return 
ss; } -void PyRegion::compute() -{ - const Spec & ns = getSpec(); +void PyRegion::compute() { + const Spec &ns = getSpec(); // Prepare the inputs dict py::Dict inputs; - for (size_t i = 0; i < ns.inputs.getCount(); ++i) - { + for (size_t i = 0; i < ns.inputs.getCount(); ++i) { // Get the current InputSpec object - const std::pair & p = - ns.inputs.getByIndex(i); + const std::pair &p = ns.inputs.getByIndex(i); // Get the corresponding input buffer - Input * inp = region_->getInput(p.first); + Input *inp = region_->getInput(p.first); NTA_CHECK(inp); // Set pa to point to the original input array - const Array * pa = &(inp->getData()); + const Array *pa = &(inp->getData()); // Skip unlinked inputs of size 0 if (pa->getCount() == 0) @@ -979,24 +881,25 @@ void PyRegion::compute() // If the input requires a splitter map then // Copy the original input array to the stored input array, which is larger - // by one element and put 0 in the extra element. This is needed for splitter map - // access. - if (p.second.requireSplitterMap) - { + // by one element and put 0 in the extra element. This is needed for + // splitter map access. + if (p.second.requireSplitterMap) { // Verify that this input has a stored input array NTA_ASSERT(inputArrays_.find(p.first) != inputArrays_.end()); - Array & a = *(inputArrays_[p.first]); + Array &a = *(inputArrays_[p.first]); - // Verify that the stored input array is larger by 1 then the original input + // Verify that the stored input array is larger by 1 then the original + // input NTA_ASSERT(a.getCount() == pa->getCount() + 1); // Work at the char * level because there is no good way - // to work with the actual data type of the input (since the buffer is void *) + // to work with the actual data type of the input (since the buffer is + // void *) size_t itemSize = BasicType::getSize(p.second.dataType); - char * begin1 = (char *)pa->getBuffer(); - char * end1 = begin1 + pa->getCount() * itemSize; - char * begin2 = (char *)a.getBuffer(); - char * end2 = begin2 + a.getCount() * itemSize; + char *begin1 = (char *)pa->getBuffer(); + char *end1 = begin1 + pa->getCount() * itemSize; + char *begin2 = (char *)a.getBuffer(); + char *end2 = begin2 + a.getCount() * itemSize; // Copy the original input array to the stored array std::copy(begin1, end1, begin2); @@ -1017,19 +920,17 @@ void PyRegion::compute() // Prepare the outputs dict py::Dict outputs; - for (size_t i = 0; i < ns.outputs.getCount(); ++i) - { + for (size_t i = 0; i < ns.outputs.getCount(); ++i) { // Get the current OutputSpec object - const std::pair & p = - ns.outputs.getByIndex(i); + const std::pair &p = ns.outputs.getByIndex(i); // Get the corresponding output buffer - Output * out = region_->getOutput(p.first); + Output *out = region_->getOutput(p.first); // Skip optional outputs if (!out) continue; - const Array & data = out->getData(); + const Array &data = out->getData(); py::Ptr numpyArray(array2numpy(data)); @@ -1046,19 +947,17 @@ void PyRegion::compute() py::Ptr none(node_.invoke("guardedCompute", args)); } - // // Get the node spec from the underlying Python node // and populate a dynamically C++ node spec object. // Return the node spec pointer (that will be owned // by RegionImplFactory. 
// -void PyRegion::createSpec(const char * nodeType, Spec & ns, const char* className) -{ +void PyRegion::createSpec(const char *nodeType, Spec &ns, + const char *className) { // Get the Python class object std::string realClassName(className); - if (realClassName.empty()) - { + if (realClassName.empty()) { realClassName = Path::getExtension(nodeType); } @@ -1077,7 +976,7 @@ void PyRegion::createSpec(const char * nodeType, Spec & ns, const char* classNam // Extract the 4 dicts from the node spec py::Dict inputs(nodeSpec.getItem("inputs", py::Dict())); - //NTA_DEBUG << "'inputs' type: " << inputs.getTypeName(); + // NTA_DEBUG << "'inputs' type: " << inputs.getTypeName(); py::Dict outputs(nodeSpec.getItem("outputs", py::Dict())); py::Dict parameters(nodeSpec.getItem("parameters", py::Dict())); @@ -1086,14 +985,13 @@ void PyRegion::createSpec(const char * nodeType, Spec & ns, const char* classNam // key, value and pos are used to iterate over the // inputs, outputs, parameters and commands dicts // of the Python node spec - PyObject * key; - PyObject * value; + PyObject *key; + PyObject *value; Py_ssize_t pos; // Add inputs pos = 0; - while (PyDict_Next(inputs, &pos, &key, &value)) - { + while (PyDict_Next(inputs, &pos, &key, &value)) { // key and value are borrowed from the dict. Their ref count // must be incremented so they can be used with // the Py helpers safely @@ -1106,7 +1004,8 @@ void PyRegion::createSpec(const char * nodeType, Spec & ns, const char* classNam // Add an InputSpec object for each input spec dict std::ostringstream inputMessagePrefix; inputMessagePrefix << "Region " << realClassName - << " spec has missing key for input section " << name << ": "; + << " spec has missing key for input section " << name + << ": "; NTA_ASSERT(input.getItem("description") != nullptr) << inputMessagePrefix.str() << "description"; @@ -1135,9 +1034,8 @@ void PyRegion::createSpec(const char * nodeType, Spec & ns, const char* classNam // make regionLevel optional and default to true. bool regionLevel = true; - if (input.getItem("regionLevel") != nullptr) - { - regionLevel = py::Int(input.getItem("regionLevel")) != 0; + if (input.getItem("regionLevel") != nullptr) { + regionLevel = py::Int(input.getItem("regionLevel")) != 0; } NTA_ASSERT(input.getItem("isDefaultInput") != nullptr) @@ -1146,27 +1044,18 @@ void PyRegion::createSpec(const char * nodeType, Spec & ns, const char* classNam // make requireSplitterMap optional and default to false. bool requireSplitterMap = false; - if (input.getItem("requireSplitterMap") != nullptr) - { - requireSplitterMap = py::Int(input.getItem("requireSplitterMap")) != 0; + if (input.getItem("requireSplitterMap") != nullptr) { + requireSplitterMap = py::Int(input.getItem("requireSplitterMap")) != 0; } - ns.inputs.add( - name, - InputSpec( - description, - dataType, - count, - required, - regionLevel, - isDefaultInput, - requireSplitterMap)); + ns.inputs.add(name, + InputSpec(description, dataType, count, required, regionLevel, + isDefaultInput, requireSplitterMap)); } // Add outputs pos = 0; - while (PyDict_Next(outputs, &pos, &key, &value)) - { + while (PyDict_Next(outputs, &pos, &key, &value)) { // key and value are borrowed from the dict. 
Their ref count // must be incremented so they can be used with // the Py helpers safely @@ -1179,7 +1068,8 @@ void PyRegion::createSpec(const char * nodeType, Spec & ns, const char* classNam // Add an OutputSpec object for each output spec dict std::ostringstream outputMessagePrefix; outputMessagePrefix << "Region " << realClassName - << " spec has missing key for output section " << name << ": "; + << " spec has missing key for output section " << name + << ": "; NTA_ASSERT(output.getItem("description") != nullptr) << outputMessagePrefix.str() << "description"; @@ -1204,29 +1094,21 @@ void PyRegion::createSpec(const char * nodeType, Spec & ns, const char* classNam // make regionLevel optional and default to true. bool regionLevel = true; - if (output.getItem("regionLevel") != nullptr) - { - regionLevel = py::Int(output.getItem("regionLevel")) != 0; + if (output.getItem("regionLevel") != nullptr) { + regionLevel = py::Int(output.getItem("regionLevel")) != 0; } NTA_ASSERT(output.getItem("isDefaultOutput") != nullptr) << outputMessagePrefix.str() << "isDefaultOutput"; bool isDefaultOutput = py::Int(output.getItem("isDefaultOutput")) != 0; - ns.outputs.add( - name, - OutputSpec( - description, - dataType, - count, - regionLevel, - isDefaultOutput)); + ns.outputs.add(name, OutputSpec(description, dataType, count, regionLevel, + isDefaultOutput)); } // Add parameters pos = 0; - while (PyDict_Next(parameters, &pos, &key, &value)) - { + while (PyDict_Next(parameters, &pos, &key, &value)) { // key and value are borrowed from the dict. Their ref count // must be incremented so they can be used with // the Py helpers safely @@ -1239,7 +1121,8 @@ void PyRegion::createSpec(const char * nodeType, Spec & ns, const char* classNam // Add an ParameterSpec object for each output spec dict std::ostringstream parameterMessagePrefix; parameterMessagePrefix << "Region " << realClassName - << " spec has missing key for parameter section " << name << ": "; + << " spec has missing key for parameter section " + << name << ": "; NTA_ASSERT(parameter.getItem("description") != nullptr) << parameterMessagePrefix.str() << "description"; @@ -1262,10 +1145,9 @@ void PyRegion::createSpec(const char * nodeType, Spec & ns, const char* classNam << parameterMessagePrefix.str() << "count"; UInt32 count = py::Int(parameter.getItem("count")); - std::string constraints; // This parameter is optional - if (parameter.getItem("constraints") != nullptr){ + if (parameter.getItem("constraints") != nullptr) { constraints = py::String(parameter.getItem("constraints")); } else { constraints = py::String(""); @@ -1286,8 +1168,7 @@ void PyRegion::createSpec(const char * nodeType, Spec & ns, const char* classNam // Get default value as a string if it's a create parameter std::string defaultValue; - if (am == "Create") - { + if (am == "Create") { NTA_ASSERT(parameter.getItem("defaultValue") != nullptr) << parameterMessagePrefix.str() << "defaultValue"; py::Instance dv(parameter.getItem("defaultValue")); @@ -1297,32 +1178,20 @@ void PyRegion::createSpec(const char * nodeType, Spec & ns, const char* classNam if (defaultValue == "None") defaultValue = ""; - ns.parameters.add( - name, - ParameterSpec( - description, - dataType, - count, - constraints, - defaultValue, - accessMode)); + ns.parameters.add(name, + ParameterSpec(description, dataType, count, constraints, + defaultValue, accessMode)); } // Add the automatic "self" parameter ns.parameters.add( - "self", - ParameterSpec( - "The PyObject * of the region's Python classd", - 
NTA_BasicType_Handle, - 1, - "", - "", - ParameterSpec::ReadOnlyAccess)); + "self", ParameterSpec("The PyObject * of the region's Python classd", + NTA_BasicType_Handle, 1, "", "", + ParameterSpec::ReadOnlyAccess)); // Add commands pos = 0; - while (PyDict_Next(commands, &pos, &key, &value)) - { + while (PyDict_Next(commands, &pos, &key, &value)) { // key and value are borrowed from the dict. Their ref count // must be incremented so they can be used with // the Py helpers safely @@ -1335,24 +1204,21 @@ void PyRegion::createSpec(const char * nodeType, Spec & ns, const char* classNam // Add a CommandSpec object for each output spec dict std::ostringstream commandsMessagePrefix; commandsMessagePrefix << "Region " << realClassName - << " spec has missing key for commands section " << name << ": "; + << " spec has missing key for commands section " + << name << ": "; NTA_ASSERT(command.getItem("description") != nullptr) << commandsMessagePrefix.str() << "description"; std::string description(py::String(command.getItem("description"))); - ns.commands.add( - name, - CommandSpec(description)); + ns.commands.add(name, CommandSpec(description)); } } -void PyRegion::initialize() -{ +void PyRegion::initialize() { // Call the Python initialize() method // Need to put the None result in py::Ptr, so decrement the ref count py::Ptr none(node_.invoke("initialize", py::Tuple())); } - } // end namespace nupic diff --git a/src/nupic/regions/PyRegion.hpp b/src/nupic/regions/PyRegion.hpp index 196dd588f0..7120a794d8 100644 --- a/src/nupic/regions/PyRegion.hpp +++ b/src/nupic/regions/PyRegion.hpp @@ -20,15 +20,14 @@ * --------------------------------------------------------------------- */ - #ifndef NTA_PY_REGION_HPP #define NTA_PY_REGION_HPP #include +#include #include #include -#include #include @@ -36,147 +35,153 @@ #include #include -namespace nupic -{ - struct Spec; - - class PyRegion : public RegionImpl - { - typedef std::map SpecMap; - public: - // Used by RegionImplFactory to create and cache a nodespec - static Spec * createSpec(const char * nodeType, const char* className=""); - - // Used by RegionImplFactory to destroy a node spec when clearing its cache - static void destroySpec(const char * nodeType, const char* className=""); - - PyRegion(const char * module, const ValueMap & nodeParams, Region * region, const char* className=""); - PyRegion(const char * module, BundleIO& bundle, Region * region, const char* className=""); - PyRegion(const char * module, capnp::AnyPointer::Reader& proto, Region * region, const char* className=""); - virtual ~PyRegion(); - - // DynamicPythonLibrary functions. Originally used NTA_EXPORT - static void NTA_initPython(); - static void NTA_finalizePython(); - static void * NTA_createPyNode(const char * module, void * nodeParams, - void * region, void ** exception, const char* className=""); - static void * NTA_deserializePyNode(const char * module, void * bundle, - void * region, void ** exception, const char* className=""); - static void * NTA_deserializePyNodeProto(const char * module, void * proto, - void * region, void ** exception, const char* className=""); - static const char * NTA_getLastError(); - static void * NTA_createSpec(const char * nodeType, void ** exception, const char* className=""); - static int NTA_destroySpec(const char * nodeType, const char* className=""); - - // Manual serialization methods. Current recommended method. 
- void serialize(BundleIO& bundle) override; - void deserialize(BundleIO& bundle) override; - - // Capnp serialization methods - not yet implemented for PyRegions. This - // method will replace serialize/deserialize once fully implemented - // throughout codebase. - using RegionImpl::write; - /** - * Serialize instance to the given message builder - * - * :param proto: PyRegionProto builder masquerading as AnyPointer builder - */ - void write(capnp::AnyPointer::Builder& proto) const override; - - using RegionImpl::read; - /** - * Initialize instance from the given message reader - * - * :param proto: PyRegionProto reader masquerading as AnyPointer reader - */ - void read(capnp::AnyPointer::Reader& proto) override; - - const Spec & getSpec(); - - static void createSpec(const char * nodeType, Spec & ns, const char* className=""); - - // RegionImpl interface - - size_t getNodeOutputElementCount(const std::string& outputName) override; - void getParameterFromBuffer(const std::string& name, Int64 index, IWriteBuffer& value) override; - void setParameterFromBuffer(const std::string& name, Int64 index, IReadBuffer& value) override; - - void initialize() override; - void compute() override; - std::string executeCommand( - const std::vector& args, Int64 index) override; - - size_t getParameterArrayCount(const std::string& name, Int64 index) - override; - - virtual Byte getParameterByte(const std::string& name, Int64 index); - virtual Int32 getParameterInt32(const std::string& name, Int64 index) - override; - virtual UInt32 getParameterUInt32(const std::string& name, Int64 index) - override; - virtual Int64 getParameterInt64(const std::string& name, Int64 index) - override; - virtual UInt64 getParameterUInt64(const std::string& name, Int64 index) - override; - virtual Real32 getParameterReal32(const std::string& name, Int64 index) - override; - virtual Real64 getParameterReal64(const std::string& name, Int64 index) - override; - virtual Handle getParameterHandle(const std::string& name, Int64 index) - override; - virtual bool getParameterBool(const std::string& name, Int64 index) - override; - virtual std::string getParameterString( - const std::string& name, Int64 index) override; - - virtual void setParameterByte( - const std::string& name, Int64 index, Byte value); - virtual void setParameterInt32( - const std::string& name, Int64 index, Int32 value) override; - virtual void setParameterUInt32( - const std::string& name, Int64 index, UInt32 value) override; - virtual void setParameterInt64( - const std::string& name, Int64 index, Int64 value) override; - virtual void setParameterUInt64( - const std::string& name, Int64 index, UInt64 value) override; - virtual void setParameterReal32( - const std::string& name, Int64 index, Real32 value) override; - virtual void setParameterReal64( - const std::string& name, Int64 index, Real64 value) override; - virtual void setParameterHandle( - const std::string& name, Int64 index, Handle value) override; - virtual void setParameterBool( - const std::string& name, Int64 index, bool value) override; - virtual void setParameterString( - const std::string& name, Int64 index, const std::string& value) - override; - - virtual void getParameterArray( - const std::string& name, Int64 index, Array & array) override; - virtual void setParameterArray( - const std::string& name, Int64 index, const Array & array) override; - - // Helper methods - template - T getParameterT(const std::string & name, Int64 index); - - template - void setParameterT(const std::string & name, 
Int64 index, T value); - - private: - PyRegion(); - PyRegion(const Region &); - - private: - static SpecMap specs_; - std::string module_; - std::string className_; - py::Instance node_; - std::set > > splitterMaps_; - // pointers rather than objects because Array doesnt - // have a default constructor - std::map inputArrays_; - }; -} +namespace nupic { +struct Spec; + +class PyRegion : public RegionImpl { + typedef std::map SpecMap; + +public: + // Used by RegionImplFactory to create and cache a nodespec + static Spec *createSpec(const char *nodeType, const char *className = ""); + + // Used by RegionImplFactory to destroy a node spec when clearing its cache + static void destroySpec(const char *nodeType, const char *className = ""); + + PyRegion(const char *module, const ValueMap &nodeParams, Region *region, + const char *className = ""); + PyRegion(const char *module, BundleIO &bundle, Region *region, + const char *className = ""); + PyRegion(const char *module, capnp::AnyPointer::Reader &proto, Region *region, + const char *className = ""); + virtual ~PyRegion(); + + // DynamicPythonLibrary functions. Originally used NTA_EXPORT + static void NTA_initPython(); + static void NTA_finalizePython(); + static void *NTA_createPyNode(const char *module, void *nodeParams, + void *region, void **exception, + const char *className = ""); + static void *NTA_deserializePyNode(const char *module, void *bundle, + void *region, void **exception, + const char *className = ""); + static void *NTA_deserializePyNodeProto(const char *module, void *proto, + void *region, void **exception, + const char *className = ""); + static const char *NTA_getLastError(); + static void *NTA_createSpec(const char *nodeType, void **exception, + const char *className = ""); + static int NTA_destroySpec(const char *nodeType, const char *className = ""); + + // Manual serialization methods. Current recommended method. + void serialize(BundleIO &bundle) override; + void deserialize(BundleIO &bundle) override; + + // Capnp serialization methods - not yet implemented for PyRegions. This + // method will replace serialize/deserialize once fully implemented + // throughout codebase. 
+ using RegionImpl::write; + /** + * Serialize instance to the given message builder + * + * :param proto: PyRegionProto builder masquerading as AnyPointer builder + */ + void write(capnp::AnyPointer::Builder &proto) const override; + + using RegionImpl::read; + /** + * Initialize instance from the given message reader + * + * :param proto: PyRegionProto reader masquerading as AnyPointer reader + */ + void read(capnp::AnyPointer::Reader &proto) override; + + const Spec &getSpec(); + + static void createSpec(const char *nodeType, Spec &ns, + const char *className = ""); + + // RegionImpl interface + + size_t getNodeOutputElementCount(const std::string &outputName) override; + void getParameterFromBuffer(const std::string &name, Int64 index, + IWriteBuffer &value) override; + void setParameterFromBuffer(const std::string &name, Int64 index, + IReadBuffer &value) override; + + void initialize() override; + void compute() override; + std::string executeCommand(const std::vector &args, + Int64 index) override; + + size_t getParameterArrayCount(const std::string &name, Int64 index) override; + + virtual Byte getParameterByte(const std::string &name, Int64 index); + virtual Int32 getParameterInt32(const std::string &name, + Int64 index) override; + virtual UInt32 getParameterUInt32(const std::string &name, + Int64 index) override; + virtual Int64 getParameterInt64(const std::string &name, + Int64 index) override; + virtual UInt64 getParameterUInt64(const std::string &name, + Int64 index) override; + virtual Real32 getParameterReal32(const std::string &name, + Int64 index) override; + virtual Real64 getParameterReal64(const std::string &name, + Int64 index) override; + virtual Handle getParameterHandle(const std::string &name, + Int64 index) override; + virtual bool getParameterBool(const std::string &name, Int64 index) override; + virtual std::string getParameterString(const std::string &name, + Int64 index) override; + + virtual void setParameterByte(const std::string &name, Int64 index, + Byte value); + virtual void setParameterInt32(const std::string &name, Int64 index, + Int32 value) override; + virtual void setParameterUInt32(const std::string &name, Int64 index, + UInt32 value) override; + virtual void setParameterInt64(const std::string &name, Int64 index, + Int64 value) override; + virtual void setParameterUInt64(const std::string &name, Int64 index, + UInt64 value) override; + virtual void setParameterReal32(const std::string &name, Int64 index, + Real32 value) override; + virtual void setParameterReal64(const std::string &name, Int64 index, + Real64 value) override; + virtual void setParameterHandle(const std::string &name, Int64 index, + Handle value) override; + virtual void setParameterBool(const std::string &name, Int64 index, + bool value) override; + virtual void setParameterString(const std::string &name, Int64 index, + const std::string &value) override; + + virtual void getParameterArray(const std::string &name, Int64 index, + Array &array) override; + virtual void setParameterArray(const std::string &name, Int64 index, + const Array &array) override; + + // Helper methods + template + T getParameterT(const std::string &name, Int64 index); + + template + void setParameterT(const std::string &name, Int64 index, T value); + +private: + PyRegion(); + PyRegion(const Region &); + +private: + static SpecMap specs_; + std::string module_; + std::string className_; + py::Instance node_; + std::set>> splitterMaps_; + // pointers rather than objects because Array doesnt + // have a 
default constructor + std::map inputArrays_; +}; +} // namespace nupic #endif // NTA_PY_REGION_HPP diff --git a/src/nupic/regions/VectorFile.cpp b/src/nupic/regions/VectorFile.cpp index 7451e2208a..18cb9ca35e 100644 --- a/src/nupic/regions/VectorFile.cpp +++ b/src/nupic/regions/VectorFile.cpp @@ -21,47 +21,41 @@ */ /** @file -* Implementation for VectorFile class -*/ + * Implementation for VectorFile class + */ #include // memset -#include -#include #include -#include #include -#include -#include #include // For isSystemLittleEndian and utils::swapBytesInPlace. #include #include +#include +#include +#include #include +#include #include using namespace std; using namespace nupic; //---------------------------------------------------------------------------- -VectorFile::VectorFile() -{ } - +VectorFile::VectorFile() {} //---------------------------------------------------------------------------- -VectorFile::~VectorFile() -{ - clear(); -} +VectorFile::~VectorFile() { clear(); } //---------------------------------------------------------------------------- -void VectorFile::clear(bool clearScaling) -{ +void VectorFile::clear(bool clearScaling) { Size n = fileVectors_.size(); - if(own_.size() != fileVectors_.size()) { + if (own_.size() != fileVectors_.size()) { throw logic_error("Invalid ownership flags."); } - - for(Size i=0; i 4 ) { + + if (fileFormat > 4) { NTA_THROW << "VectorFile::appendFile - incorrect file format: " - << fileFormat; + << fileFormat; } - + try { - if (fileFormat==3) - { + if (fileFormat == 3) { appendCSVFile(inFile, expectedElementCount); - } - else - { + } else { // Read in space separated text file string sLine; NTA_Size elementCount = expectedElementCount; if (fileFormat != 2) { inFile >> elementCount; getline(inFile, sLine); - + if (elementCount != expectedElementCount) { NTA_THROW << "VectorFile::appendFile - number of elements" - << " in file (" << elementCount << ") does not match" - << " output element count (" << expectedElementCount << ")"; + << " in file (" << elementCount << ") does not match" + << " output element count (" << expectedElementCount + << ")"; } } - - // If format is 'labeled', read in the next line, which is a label per elmnt + + // If format is 'labeled', read in the next line, which is a label per + // elmnt if (fileFormat == 1) { getline(inFile, sLine); - + // Pull out all the words from the first line istringstream aLine(sLine.c_str()); - while(1) { + while (1) { string aWord; aLine >> aWord; - if(aLine.fail()) break; + if (aLine.fail()) + break; elementLabels_.push_back(aWord); } - + // Ensure we have the right number of words if (elementLabels_.size() != elementCount) { - NTA_THROW << "VectorFile::appendFile - wrong number of element labels (" - << elementLabels_.size() << ") in file " << fileName; + NTA_THROW + << "VectorFile::appendFile - wrong number of element labels (" + << elementLabels_.size() << ") in file " << fileName; } } - + // Read each vector in, including labels if so indicated - while (!inFile.eof()) - { + while (!inFile.eof()) { string vectorLabel; if (fileFormat == 1) { inFile >> vectorLabel; } - + auto b = new NTA_Real[elementCount]; - for (Size i= 0; i < elementCount; ++i) { + for (Size i = 0; i < elementCount; ++i) { inFile >> b[i]; } - + if (!inFile.eof()) { fileVectors_.push_back(b); own_.push_back(true); vectorLabels_.push_back(vectorLabel); - } - else delete [] b; + } else + delete[] b; } } - } catch(ios_base::failure&) { - if (!inFile.eof()) - NTA_THROW << "VectorFile::appendFile" - << "Error reading from sensor 
input file: " << fileName - << " : improperly formatted data"; + } catch (ios_base::failure &) { + if (!inFile.eof()) + NTA_THROW << "VectorFile::appendFile" + << "Error reading from sensor input file: " << fileName + << " : improperly formatted data"; } } NTA_CHECK(fileVectors_.size() > 0) - << "VectorFile::appendFile - no vectors were read in."; + << "VectorFile::appendFile - no vectors were read in."; // Reset scaling only if the vector lengths changed - if (scaleVector_.size() != expectedElementCount) - { + if (scaleVector_.size() != expectedElementCount) { NTA_INFO << "appendFile - need to reset scale and offset vectors."; resetScaling((UInt)expectedElementCount); } } -// Determine if the file is a DOS file (ASCII 13, or CTRL-M line ending) +// Determine if the file is a DOS file (ASCII 13, or CTRL-M line ending) // Searches for ASCII 13 or 10. Return true if ASCII 13 is found before ASCII 10 -// False otherwise. -// This is a bit of a hack - if a string contains 13 or 10 it should not count, but -// we don't support strings in files anyway. -static bool dosEndings(IFStream &inFile) -{ +// False otherwise. +// This is a bit of a hack - if a string contains 13 or 10 it should not count, +// but we don't support strings in files anyway. +static bool dosEndings(IFStream &inFile) { bool unixLines = true; int pos = inFile.tellg(); - while (!inFile.eof()) - { + while (!inFile.eof()) { int c = inFile.get(); - if (c==10) - { + if (c == 10) { unixLines = false; break; - } - else if (c==13) - { + } else if (c == 13) { unixLines = true; break; } } - inFile.seekg(pos); // Reset back to where we were + inFile.seekg(pos); // Reset back to where we were return unixLines; } -void VectorFile::saveVectors(ostream &out, Size nColumns, UInt32 fileFormat, - Int64 begin, const char *lineEndings) -{ - saveVectors(out, nColumns, fileFormat, begin, fileVectors_.size(), lineEndings); +void VectorFile::saveVectors(ostream &out, Size nColumns, UInt32 fileFormat, + Int64 begin, const char *lineEndings) { + saveVectors(out, nColumns, fileFormat, begin, fileVectors_.size(), + lineEndings); } -void VectorFile::saveVectors(ostream &out, Size nColumns, UInt32 fileFormat, - Int64 begin, Int64 end, const char *lineEndings) -{ +void VectorFile::saveVectors(ostream &out, Size nColumns, UInt32 fileFormat, + Int64 begin, Int64 end, const char *lineEndings) { out.exceptions(ios_base::failbit | ios_base::badbit); Size n = fileVectors_.size(); - if(begin < 0) begin += n; - if(end < 0) end += n; - if(begin > Int64(n)) { + if (begin < 0) + begin += n; + if (end < 0) + end += n; + if (begin > Int64(n)) { stringstream msg; msg << "Begin (" << begin << ") out of bounds."; throw runtime_error(msg.str()); } - if(end > Int64(n)) { + if (end > Int64(n)) { stringstream msg; msg << "End (" << begin << ") out of bounds."; throw runtime_error(msg.str()); } - if(end < begin) end = begin; + if (end < begin) + end = begin; // Setup iterators for the rows. - vector::const_iterator i=(fileVectors_.begin() + size_t(begin)); + vector::const_iterator i = (fileVectors_.begin() + size_t(begin)); auto iend = i + size_t(end - begin); - switch(fileFormat) { + switch (fileFormat) { + case 0: + case 1: + case 2: + case 3: { + // Could be single chars (faster), but might prevent future extension. + const char *lineSep = lineEndings ? lineEndings : "\n"; + const char *sep = (fileFormat == 3) ? 
"," : " "; + + // Output the number of columns + switch (fileFormat) { case 0: case 1: - case 2: - case 3: - { - // Could be single chars (faster), but might prevent future extension. - const char *lineSep = lineEndings ? lineEndings : "\n"; - const char *sep = (fileFormat == 3) ? ",": " "; - - // Output the number of columns - switch(fileFormat) { - case 0: case 1: out << nColumns << lineSep; break; - default: break; - } - - // Decide if each row should be labelled in the output. - bool hasRowLabels = false; - vector::const_iterator iRowLabel; - switch(fileFormat) { - case 1: - // case 3: // Could be supported, but is not. - { - hasRowLabels = !vectorLabels_.empty(); - if(Int64(vectorLabels_.size()) < end) { - stringstream msg; - msg << "Too few vector labels (" << vectorLabels_.size() << ") " - "to write to file (writing to row " << end << ")."; - throw runtime_error(msg.str()); - } - break; - } - default: break; - } + out << nColumns << lineSep; + break; + default: + break; + } - // Output the column labels. - switch(fileFormat) { - case 1: - { - if(nColumns && !elementLabels_.size()) - throw runtime_error("Format '1' requires column labels."); - vector::const_iterator iLabel=elementLabels_.begin(), - labelEnd=elementLabels_.end(); - if(hasRowLabels) out << sep; // No row label for header row. - out << *(iLabel++); - for(; iLabel!=labelEnd; ++iLabel) out << sep << (*iLabel); - out << lineSep; - break; - } - case 3: // Identical to 1, but different error conditions. - { - if(elementLabels_.size()) { - vector::const_iterator iLabel=elementLabels_.begin(), - labelEnd=elementLabels_.end(); - if(hasRowLabels) out << sep; // No row label for header row. - out << *(iLabel++); - for(; iLabel!=labelEnd; ++iLabel) out << sep << (*iLabel); - out << lineSep; - } - break; + // Decide if each row should be labelled in the output. + bool hasRowLabels = false; + vector::const_iterator iRowLabel; + switch (fileFormat) { + case 1: + // case 3: // Could be supported, but is not. + { + hasRowLabels = !vectorLabels_.empty(); + if (Int64(vectorLabels_.size()) < end) { + stringstream msg; + msg << "Too few vector labels (" << vectorLabels_.size() + << ") " + "to write to file (writing to row " + << end << ")."; + throw runtime_error(msg.str()); } - default: break; + break; } + default: + break; + } - // Output the rows. - for(; i!=iend; ++i) { - if(hasRowLabels) { - out << *(iRowLabel++); - if(nColumns) out << sep; - } - const Real *p = *i; - if(nColumns) { - const Real *pEnd = p + nColumns; - out << *(p++); - for(; p::const_iterator iLabel = elementLabels_.begin(), + labelEnd = elementLabels_.end(); + if (hasRowLabels) + out << sep; // No row label for header row. + out << *(iLabel++); + for (; iLabel != labelEnd; ++iLabel) + out << sep << (*iLabel); + out << lineSep; + break; + } + case 3: // Identical to 1, but different error conditions. + { + if (elementLabels_.size()) { + vector::const_iterator iLabel = elementLabels_.begin(), + labelEnd = elementLabels_.end(); + if (hasRowLabels) + out << sep; // No row label for header row. 
+ out << *(iLabel++); + for (; iLabel != labelEnd; ++iLabel) + out << sep << (*iLabel); out << lineSep; } - break; } - case 4: - case 5: - { - if(end <= begin) return; - const Size rowBytes = nColumns * sizeof(Real32); - const bool bigEndian = (fileFormat == 5); - const bool needSwap = (nupic::isSystemLittleEndian() == bigEndian); - const bool needConversion = (sizeof(Real32) == sizeof(Real)); - - if(needSwap || needConversion) { - auto buffer = new Real32[nColumns]; - try { - for(; i!=iend; ++i) { - if(needConversion) { - const Real *p = *i; - for(Size j=0; j(block) + (totalElements - 1); Real *pWrite = block + (totalElements - 1); - for(; pWrite>=block;) { + for (; pWrite >= block;) { *(pWrite--) = *(pRead--); } } - } - catch(...) { + } catch (...) { delete[] block; fileVectors_.resize(offset); own_.resize(offset); - if(nRowLabels) vectorLabels_.resize(offset); + if (nRowLabels) + vectorLabels_.resize(offset); throw; } } -// Append a CSV file to the list of stored vectors. There are some strict -// assumptions here. We assume that each row has at least expectedElements -// numbers separated by commas. It is ok to have more, we keep the first +// Append a CSV file to the list of stored vectors. There are some strict +// assumptions here. We assume that each row has at least expectedElements +// numbers separated by commas. It is ok to have more, we keep the first // expectedElements numbers. In addition, the first expectedElements values // must be numbers. We do not handle having a bunch of strings or empty values -// interspersed in the middle. If a row does have any of the above errors, +// interspersed in the middle. If a row does have any of the above errors, // the routine will silently skip it. The code handles error conditions like: // 23,,43 // 23,hello,42 @@ -474,113 +477,123 @@ void VectorFile::appendFloat32File(const string &filename, // 23,24, // 23,"42,d",55 -void VectorFile::appendCSVFile(IFStream &inFile, Size expectedElements) -{ - // Read in csv file one line at a time. If that line contains any errors, +void VectorFile::appendCSVFile(IFStream &inFile, Size expectedElements) { + // Read in csv file one line at a time. If that line contains any errors, // skip it and move onto the next one. - try - { + try { bool dosLines = dosEndings(inFile); - while (!inFile.eof()) - { + while (!inFile.eof()) { Size elementsFound = 0; auto b = new Real[expectedElements]; // Read and parse a single line - string sLine; // We'll use string for robust line parsing - stringstream converted; // We'll use stringstream for robust ascii text to Real conversion + string sLine; // We'll use string for robust line parsing + stringstream converted; // We'll use stringstream for robust ascii text to + // Real conversion size_t beg = 0, pos = 0; - + // Read the next line, using the appropriate delimiter - try - { - if (dosLines) getline(inFile,sLine,'\r'); - else getline(inFile, sLine); - } catch(...) { + try { + if (dosLines) + getline(inFile, sLine, '\r'); + else + getline(inFile, sLine); + } catch (...) 
{ // An exception here is ok break; } - - while (pos != string::npos) - { + + while (pos != string::npos) { pos = sLine.find(',', beg); - converted << sLine.substr(beg, pos-beg) << " "; - beg = pos+1; + converted << sLine.substr(beg, pos - beg) << " "; + beg = pos + 1; } - for(Size i = 0; i < expectedElements; i++) { + for (Size i = 0; i < expectedElements; i++) { converted >> b[i]; - if(converted.fail()) break; + if (converted.fail()) + break; elementsFound++; } // Validate the parsed line and store it. - // At this point b will have some numbers. If elementsFound == expectedElements - // then everything went well and we can insert the array into fileVectors_ - // If not, then we deallocate the memory for b - if (elementsFound == expectedElements) - { + // At this point b will have some numbers. If elementsFound == + // expectedElements then everything went well and we can insert the array + // into fileVectors_ If not, then we deallocate the memory for b + if (elementsFound == expectedElements) { fileVectors_.push_back(b); own_.push_back(true); vectorLabels_.push_back(string()); b = nullptr; - } - else - { - delete [] b; + } else { + delete[] b; b = nullptr; } } - // any bizarre errors and cleanup - } catch(...) { + // any bizarre errors and cleanup + } catch (...) { NTA_THROW << "VectorFile - Error reading CSV file"; } } -template -void convert(T2 *pOut, const T1 *pIn, TSize n, TSize fill) -{ - for(TSize i=0; i +void convert(T2 *pOut, const T1 *pIn, TSize n, TSize fill) { + for (TSize i = 0; i < n; ++i) *(pOut++) = T2(*(pIn++)); - if(fill) ::memset(pOut, 0, fill * sizeof(T2)); + if (fill) + ::memset(pOut, 0, fill * sizeof(T2)); } -void VectorFile::appendIDXFile(const string &filename, - int expectedElements, bool bigEndian) -{ +void VectorFile::appendIDXFile(const string &filename, int expectedElements, + bool bigEndian) { const bool needSwap = (nupic::isSystemLittleEndian() == bigEndian); AutoReleaseFile file(filename); char header[4]; file.read(header, 4); - + int nDims = header[3]; - if(nDims < 1) throw runtime_error("Invalid number of dimensions."); + if (nDims < 1) + throw runtime_error("Invalid number of dimensions."); int dims[256]; file.read(dims, nDims * sizeof(int)); - if(needSwap) nupic::swapBytesInPlace(dims, nDims); + if (needSwap) + nupic::swapBytesInPlace(dims, nDims); int vectorSize = 1; - for(int i=1; i(readBuffer); - for(int row=0; row(readBuffer); + for (int row = 0; row < nRows; ++row) { + file.read(pRead, readRow); + // No need for byte swapping. + convert(pBlock, pRead, copy, fill); + pBlock += expectedElements; } - case 0x09: // signed byte. - { - signed char *pRead = reinterpret_cast(readBuffer); - for(int row=0; row(readBuffer); + for (int row = 0; row < nRows; ++row) { + file.read(pRead, readRow); + // No need for byte swapping. + convert(pBlock, pRead, copy, fill); + pBlock += expectedElements; } - case 0x0B: // signed short. - { - short *pRead = reinterpret_cast(readBuffer); - for(int row=0; row(readBuffer); + for (int row = 0; row < nRows; ++row) { + file.read(pRead, readRow); + if (needSwap) + nupic::swapBytesInPlace(pRead, copy); + convert(pBlock, pRead, copy, fill); + pBlock += expectedElements; } - case 0x0C: // signed int. - { - int *pRead = reinterpret_cast(readBuffer); - for(int row=0; row(readBuffer); + for (int row = 0; row < nRows; ++row) { + file.read(pRead, readRow); + if (needSwap) + nupic::swapBytesInPlace(pRead, copy); + convert(pBlock, pRead, copy, fill); + pBlock += expectedElements; } - case 0x0D: // 32-bit float. 
- { - float *pRead = reinterpret_cast(readBuffer); - for(int row=0; row(readBuffer); + for (int row = 0; row < nRows; ++row) { + file.read(pRead, readRow); + if (needSwap) + nupic::swapBytesInPlace(pRead, copy); + convert(pBlock, pRead, copy, fill); + pBlock += expectedElements; } - case 0x0E: // 64-bit float. - { - double *pRead = reinterpret_cast(readBuffer); - for(int row=0; row(readBuffer); + for (int row = 0; row < nRows; ++row) { + file.read(pRead, readRow); + if (needSwap) + nupic::swapBytesInPlace(pRead, copy); + convert(pBlock, pRead, copy, fill); + pBlock += expectedElements; } - default: throw logic_error("Unsupported type."); break; + break; + } + default: + throw logic_error("Unsupported type."); + break; } // Set up the ownership. own_.resize(offset + nRows, false); own_[offset] = true; // The first vector pointer points to the whole block. - if(nRowLabels) vectorLabels_.resize(offset + nRows); + if (nRowLabels) + vectorLabels_.resize(offset + nRows); // Set all the row pointers. fileVectors_.resize(offset + nRows); auto cur = fileVectors_.begin() + offset; Real *pEnd = block + (nRows * expectedElements); - for(Real *pCur=block; pCur!=pEnd; pCur+=expectedElements) + for (Real *pCur = block; pCur != pEnd; pCur += expectedElements) *(cur++) = pCur; - } - catch(...) { + } catch (...) { delete[] block; delete[] readBuffer; fileVectors_.resize(offset); own_.resize(offset); - if(nRowLabels) vectorLabels_.resize(offset); + if (nRowLabels) + vectorLabels_.resize(offset); throw; } - delete[] readBuffer; readBuffer = nullptr; + delete[] readBuffer; + readBuffer = nullptr; // Don't delete block, as it is owned by fileVectors_ now. } -/// Reset scaling to have no effect (unitary scaling vector and zero offset vector) -void VectorFile::resetScaling(UInt nElements) -{ - if (nElements != 0) - { +/// Reset scaling to have no effect (unitary scaling vector and zero offset +/// vector) +void VectorFile::resetScaling(UInt nElements) { + if (nElements != 0) { scaleVector_.resize(nElements); offsetVector_.resize(nElements); } - for (unsigned int i= 0; i < scaleVector_.size(); i++) - { + for (unsigned int i = 0; i < scaleVector_.size(); i++) { scaleVector_[i] = 1.0; offsetVector_[i] = 0.0; } } -size_t VectorFile::getElementCount() const -{ - return scaleVector_.size(); -} - +size_t VectorFile::getElementCount() const { return scaleVector_.size(); } /// Retrieve the i'th vector and copy into output without scaling /// output must have size at least elementCount -void VectorFile::getRawVector(const UInt v, Real *out, UInt offset, Size count) -{ +void VectorFile::getRawVector(const UInt v, Real *out, UInt offset, + Size count) { if (v >= vectorCount()) NTA_THROW << "Requested non-existent vector: " << v; - if ( !out || (count==0) ) + if (!out || (count == 0)) NTA_THROW << "Invalid arguments out is null and/or count is zero"; if (getElementCount() < offset + count) - NTA_THROW - << "Wrong offset/count: the sum " << offset << "+" << count << " = " << offset + count - << ", must be smaller than element count: " << getElementCount(); + NTA_THROW << "Wrong offset/count: the sum " << offset << "+" << count + << " = " << offset + count + << ", must be smaller than element count: " << getElementCount(); // Get the pointers and copy over the vector - Real *vec = fileVectors_[v]; - for (Size i= 0; i < count; i++) + Real *vec = fileVectors_[v]; + for (Size i = 0; i < count; i++) out[i] = vec[offset + i]; } /// Retrieve i'th vector, apply scaling and copy result into output /// output must have size at least 
'count' elements -void VectorFile::getScaledVector(const UInt v, Real *out, UInt offset, Size count) -{ +void VectorFile::getScaledVector(const UInt v, Real *out, UInt offset, + Size count) { if (v >= vectorCount()) NTA_THROW << "Requested non-existent vector: " << v; NTA_CHECK(getElementCount() <= offset + count); // Get the pointers and copy over the vector - Real *vec = fileVectors_[v]; - for (Size i = 0; i < count; i++) - { - out[i] = scaleVector_[i]*(vec[i + offset] + offsetVector_[i]); + Real *vec = fileVectors_[v]; + for (Size i = 0; i < count; i++) { + out[i] = scaleVector_[i] * (vec[i + offset] + offsetVector_[i]); } } /// Get the scaling and offset values for element e -void VectorFile::getScaling(const UInt e, Real &scale, Real &offset) -{ +void VectorFile::getScaling(const UInt e, Real &scale, Real &offset) { if (e >= getElementCount()) NTA_THROW << "Requested non-existent element: " << e; scale = scaleVector_[e]; @@ -762,16 +775,14 @@ void VectorFile::getScaling(const UInt e, Real &scale, Real &offset) } /// Set the scale value for element e -void VectorFile::setScale(const UInt e, const Real scale) -{ +void VectorFile::setScale(const UInt e, const Real scale) { if (e >= getElementCount()) NTA_THROW << "Requested non-existent element: " << e; scaleVector_[e] = scale; } /// Set the offset value for element e -void VectorFile::setOffset(const UInt e, const Real offset) -{ +void VectorFile::setOffset(const UInt e, const Real offset) { if (e >= getElementCount()) NTA_THROW << "Requested non-existent element: " << e; offsetVector_[e] = offset; @@ -780,47 +791,45 @@ void VectorFile::setOffset(const UInt e, const Real offset) /// Set the scale and offset vectors to correspond to standard form /// Sets the offset component of each element to be -mean /// Sets the scale component of each element to be 1/stddev -void VectorFile::setStandardScaling() -{ - if ( (getElementCount() == 0) || (vectorCount() <= 1) ) - NTA_THROW << "Error in setting standard scaling: insufficient vectors loaded in memory."; +void VectorFile::setStandardScaling() { + if ((getElementCount() == 0) || (vectorCount() <= 1)) + NTA_THROW << "Error in setting standard scaling: insufficient vectors " + "loaded in memory."; Size nv = vectorCount(); - for (UInt e= 0; e < getElementCount(); e++) - { - double sum = 0, sum2 = 0; // Accumulate sums as doubles + for (UInt e = 0; e < getElementCount(); e++) { + double sum = 0, sum2 = 0; // Accumulate sums as doubles // First compute the mean and offset - for (Size i= 0; i < nv; i++) sum += fileVectors_[i][e]; + for (Size i = 0; i < nv; i++) + sum += fileVectors_[i][e]; double mean = sum / nv; - offsetVector_[e] = (Real) (-mean); + offsetVector_[e] = (Real)(-mean); // Now compute the squared term for stdev - for (Size i= 0; i < nv; i++) - { - double s = (fileVectors_[i][e]-mean); - sum2 += s*s; + for (Size i = 0; i < nv; i++) { + double s = (fileVectors_[i][e] - mean); + sum2 += s * s; } // Now compute the "unbiased" or "n-1" form of standard deviation - double stdev = sqrt( sum2/(nv-1) ); + double stdev = sqrt(sum2 / (nv - 1)); if (fabs(stdev) < 0.00000001) - NTA_THROW << "Error setting standard form, stdeviation is almost zero for some component."; - scaleVector_[e] = (Real) (1.0 / stdev); + NTA_THROW << "Error setting standard form, stdeviation is almost zero " + "for some component."; + scaleVector_[e] = (Real)(1.0 / stdev); } } /// Save the scale and offset vectors to this stream -void VectorFile::saveState(ostream &str) -{ +void VectorFile::saveState(ostream &str) { if 
(!str.good()) NTA_THROW << "saveState(): Internal error - Bad stream"; - // Save the number of elements, followed by scaling + // Save the number of elements, followed by scaling // and offset numbers for each component str << getElementCount() << " "; - for (UInt i= 0; i < getElementCount(); i++) - { + for (UInt i = 0; i < getElementCount(); i++) { str << scaleVector_[i] << " " << offsetVector_[i] << " "; } if (!str.good()) @@ -828,24 +837,21 @@ void VectorFile::saveState(ostream &str) } /// Save the scale and offset vectors to this stream -void VectorFile::readState(istream& str) -{ +void VectorFile::readState(istream &str) { if (!str.good()) NTA_THROW << "readState(): Internal error - Bad stream or network file"; UInt numElts; str >> numElts; - if ( (vectorCount()>0) && (numElts != getElementCount()) ) + if ((vectorCount() > 0) && (numElts != getElementCount())) NTA_THROW << "readState(): Number of elements in stream does not match " << "stored vectors"; resetScaling(numElts); - for (UInt i= 0; i < numElts; i++) - { + for (UInt i = 0; i < numElts; i++) { str >> scaleVector_[i]; str >> offsetVector_[i]; } if (!str.good()) NTA_THROW << "readState(): Internal error - Bad stream or network file"; } - diff --git a/src/nupic/regions/VectorFile.hpp b/src/nupic/regions/VectorFile.hpp index 80aa19b2ab..59f1f97fe3 100644 --- a/src/nupic/regions/VectorFile.hpp +++ b/src/nupic/regions/VectorFile.hpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Simple class for reading and processing data files */ @@ -31,126 +31,122 @@ //---------------------------------------------------------------------- -#include -#include #include +#include +#include -namespace nupic -{ - /** - * VectorFile is a simple container class for lists of numerical vectors. Its only - * purpose is to support the needs of the VectorFileSensor. Key features of - * interest are its ability to read in different text file formats and its - * ability to dynamically scale its outputs. - */ - class VectorFile - { - public: - - VectorFile(); - virtual ~VectorFile(); - - static Int32 maxFormat() { return 6; } - - /// Read in vectors from the given filename. All vectors are expected to - /// have the same size (i.e. same number of elements). - /// If a list already exists, new vectors are expected to have the same size - /// and will be appended to the end of the list - /// appendFile will NOT change the scaling vectors as long as the expectedElementCount - /// is the same as previously stored vectors. 
- /// The fileFormat number corresponds to the file formats in VectorFileSensor: - /// 0 # Reads in unlabeled file with first number = element count - /// 1 # Reads in a labeled file with first number = element count - /// 2 # Reads in unlabeled file without element count - /// 3 # Reads in a csv file - /// 4 # Reads in a little-endian float32 binary file - /// 5 # Reads in a big-endian float32 binary file - /// 6 # Reads in a big-endian IDX binary file - void appendFile(const std::string &fileName, - NTA_Size expectedElementCount, - UInt32 fileFormat); - - /// Retrieve i'th vector, apply scaling and copy result into output - /// output must have size of at least 'count' elements - void getScaledVector(const UInt i, Real *out, UInt offset, Size count); - - /// Retrieve the i'th vector and copy into output without scaling - /// output must have size at least 'count' elements - void getRawVector(const UInt i, Real *out, UInt offset, Size count); - - /// Return the number of stored vectors - size_t vectorCount() const { return fileVectors_.size(); } - - /// Return the size of each vevtor (number of elements per vector) - size_t getElementCount() const; - - /// Set the scale and offset vectors to correspond to standard form - /// Sets the offset component of each element to be -mean - /// Sets the scale component of each element to be 1/stddev - void setStandardScaling(); - - /// Reset scaling to have no effect (unitary scaling vector and zero offset vector) - /// If nElements > 0, also resize the scaling vector to have that many elements, - /// otherwise leave it as-is - void resetScaling(UInt nElements = 0); - - /// Get the scaling and offset values for element e - void getScaling(const UInt e, Real &scale, Real &offset); - - /// Set the scale value for element e - void setScale(const UInt e, const Real scale); - - /// Set the offset value for element e - void setOffset(const UInt e, const Real offset); - - /// Clear the set of vectors and labels, including scale and offset vectors, - /// release all memory, and set numElements back to zero. - void clear(bool clearScaling = true); - - // Return true iff a labeled file was read in - inline bool isLabeled() const - { return (! (elementLabels_.empty() || vectorLabels_.empty()) ); } - - /// Save the scale and offset vectors to this stream - void saveState(std::ostream &str); - - /// Initialize the scaling and offset vectors from this stream - /// If vectorCount() > 0, it is an error if numElements() - /// does not match the data in the stream - void readState(std::istream& state); - - /// Save vectors, unscaled, to a file with the specified format. - void saveVectors(std::ostream &out, Size nColumns, UInt32 fileFormat, - Int64 begin=0, const char *lineEndings=nullptr); - void saveVectors(std::ostream &out, Size nColumns, UInt32 fileFormat, - Int64 begin, Int64 end, const char *lineEndings=nullptr); - - private: - std::vector fileVectors_; // list of vectors - std::vector own_; // memory ownership flags - std::vector scaleVector_; // the scaling vector - std::vector offsetVector_; // the offset vector - - std::vector elementLabels_; // string denoting the meaning of each element - std::vector vectorLabels_; // a string label for each vector - - //------------------- Utility routines - void appendCSVFile(IFStream &inFile, Size expectedElementCount); - - /// Read vectors from a binary file. - void appendFloat32File(const std::string &filename, Size expectedElements, - bool bigEndian); - - /// Read vectors from a binary IDX file. 
- void appendIDXFile(const std::string &filename, int expectedElements, - bool bigEndian); - - }; // end class VectorFile - - //---------------------------------------------------------------------- - -} +namespace nupic { +/** + * VectorFile is a simple container class for lists of numerical vectors. Its + * only purpose is to support the needs of the VectorFileSensor. Key features of + * interest are its ability to read in different text file formats and its + * ability to dynamically scale its outputs. + */ +class VectorFile { +public: + VectorFile(); + virtual ~VectorFile(); + + static Int32 maxFormat() { return 6; } + + /// Read in vectors from the given filename. All vectors are expected to + /// have the same size (i.e. same number of elements). + /// If a list already exists, new vectors are expected to have the same size + /// and will be appended to the end of the list + /// appendFile will NOT change the scaling vectors as long as the + /// expectedElementCount is the same as previously stored vectors. The + /// fileFormat number corresponds to the file formats in VectorFileSensor: + /// 0 # Reads in unlabeled file with first number = element + /// count 1 # Reads in a labeled file with first number = + /// element count 2 # Reads in unlabeled file without element + /// count 3 # Reads in a csv file 4 # Reads in a + /// little-endian float32 binary file 5 # Reads in a + /// big-endian float32 binary file 6 # Reads in a big-endian + /// IDX binary file + void appendFile(const std::string &fileName, NTA_Size expectedElementCount, + UInt32 fileFormat); + + /// Retrieve i'th vector, apply scaling and copy result into output + /// output must have size of at least 'count' elements + void getScaledVector(const UInt i, Real *out, UInt offset, Size count); + + /// Retrieve the i'th vector and copy into output without scaling + /// output must have size at least 'count' elements + void getRawVector(const UInt i, Real *out, UInt offset, Size count); + + /// Return the number of stored vectors + size_t vectorCount() const { return fileVectors_.size(); } + + /// Return the size of each vevtor (number of elements per vector) + size_t getElementCount() const; + + /// Set the scale and offset vectors to correspond to standard form + /// Sets the offset component of each element to be -mean + /// Sets the scale component of each element to be 1/stddev + void setStandardScaling(); + + /// Reset scaling to have no effect (unitary scaling vector and zero offset + /// vector) If nElements > 0, also resize the scaling vector to have that many + /// elements, otherwise leave it as-is + void resetScaling(UInt nElements = 0); + + /// Get the scaling and offset values for element e + void getScaling(const UInt e, Real &scale, Real &offset); + + /// Set the scale value for element e + void setScale(const UInt e, const Real scale); + + /// Set the offset value for element e + void setOffset(const UInt e, const Real offset); + + /// Clear the set of vectors and labels, including scale and offset vectors, + /// release all memory, and set numElements back to zero. 
+ void clear(bool clearScaling = true); + + // Return true iff a labeled file was read in + inline bool isLabeled() const { + return (!(elementLabels_.empty() || vectorLabels_.empty())); + } + + /// Save the scale and offset vectors to this stream + void saveState(std::ostream &str); + + /// Initialize the scaling and offset vectors from this stream + /// If vectorCount() > 0, it is an error if numElements() + /// does not match the data in the stream + void readState(std::istream &state); + + /// Save vectors, unscaled, to a file with the specified format. + void saveVectors(std::ostream &out, Size nColumns, UInt32 fileFormat, + Int64 begin = 0, const char *lineEndings = nullptr); + void saveVectors(std::ostream &out, Size nColumns, UInt32 fileFormat, + Int64 begin, Int64 end, const char *lineEndings = nullptr); + +private: + std::vector fileVectors_; // list of vectors + std::vector own_; // memory ownership flags + std::vector scaleVector_; // the scaling vector + std::vector offsetVector_; // the offset vector + + std::vector + elementLabels_; // string denoting the meaning of each element + std::vector vectorLabels_; // a string label for each vector + + //------------------- Utility routines + void appendCSVFile(IFStream &inFile, Size expectedElementCount); + + /// Read vectors from a binary file. + void appendFloat32File(const std::string &filename, Size expectedElements, + bool bigEndian); + + /// Read vectors from a binary IDX file. + void appendIDXFile(const std::string &filename, int expectedElements, + bool bigEndian); + +}; // end class VectorFile -#endif // NTA_VECTOR_FILE_HPP +//---------------------------------------------------------------------- +} // namespace nupic +#endif // NTA_VECTOR_FILE_HPP diff --git a/src/nupic/regions/VectorFileEffector.cpp b/src/nupic/regions/VectorFileEffector.cpp index 23ba07b583..b668a2a997 100644 --- a/src/nupic/regions/VectorFileEffector.cpp +++ b/src/nupic/regions/VectorFileEffector.cpp @@ -24,140 +24,113 @@ * Implementation for VectorFileEffector class */ -#include -#include #include -#include #include +#include +#include +#include #include #include #include +#include +#include #include #include -#include -#include -namespace nupic -{ +namespace nupic { -VectorFileEffector::VectorFileEffector(const ValueMap& params, Region* region) : - RegionImpl(region), - dataIn_(NTA_BasicType_Real32), - filename_(""), - outFile_(nullptr) -{ +VectorFileEffector::VectorFileEffector(const ValueMap ¶ms, Region *region) + : RegionImpl(region), dataIn_(NTA_BasicType_Real32), filename_(""), + outFile_(nullptr) { if (params.contains("outputFile")) filename_ = *params.getString("outputFile"); else filename_ = ""; - } -VectorFileEffector::VectorFileEffector(BundleIO& bundle, Region* region) : - RegionImpl(region), - dataIn_(NTA_BasicType_Real32), - filename_(""), - outFile_(nullptr) -{ -} +VectorFileEffector::VectorFileEffector(BundleIO &bundle, Region *region) + : RegionImpl(region), dataIn_(NTA_BasicType_Real32), filename_(""), + outFile_(nullptr) {} -VectorFileEffector::VectorFileEffector( - capnp::AnyPointer::Reader& proto, Region* region) : - RegionImpl(region), - dataIn_(NTA_BasicType_Real32), - filename_(""), - outFile_(nullptr) -{ +VectorFileEffector::VectorFileEffector(capnp::AnyPointer::Reader &proto, + Region *region) + : RegionImpl(region), dataIn_(NTA_BasicType_Real32), filename_(""), + outFile_(nullptr) { read(proto); } +VectorFileEffector::~VectorFileEffector() { closeFile(); } -VectorFileEffector::~VectorFileEffector() -{ - closeFile(); -} - - - 
-void VectorFileEffector::initialize() -{ +void VectorFileEffector::initialize() { NTA_CHECK(region_ != nullptr); // We have no outputs or parameters; just need our input. dataIn_ = region_->getInputData("dataIn"); - if (dataIn_.getCount() == 0) - { + if (dataIn_.getCount() == 0) { NTA_THROW << "VectorFileEffector::init - no input found\n"; } } +void VectorFileEffector::compute() { -void VectorFileEffector::compute() -{ - - // It's not necessarily an error to have no inputs. In this case we just return + // It's not necessarily an error to have no inputs. In this case we just + // return if (dataIn_.getCount() == 0) return; // Don't write if there is no open file. - if (outFile_ == nullptr) - { - NTA_WARN << "VectorFileEffector compute() called, but there is no open file"; + if (outFile_ == nullptr) { + NTA_WARN + << "VectorFileEffector compute() called, but there is no open file"; return; } // Ensure we can write to it - if ( outFile_->fail() ) - { + if (outFile_->fail()) { NTA_THROW << "VectorFileEffector: There was an error writing to the file " << filename_.c_str() << "\n"; } - Real *inputVec = (Real*)(dataIn_.getBuffer()); + Real *inputVec = (Real *)(dataIn_.getBuffer()); NTA_CHECK(inputVec != nullptr); OFStream &outFile = *outFile_; - for(Size offset = 0; offset < dataIn_.getCount(); ++offset) - { + for (Size offset = 0; offset < dataIn_.getCount(); ++offset) { // TBD -- could be very inefficient to do one at a time outFile << inputVec[offset] << " "; } outFile << "\n"; } - -void VectorFileEffector::closeFile() -{ - if (outFile_) - { +void VectorFileEffector::closeFile() { + if (outFile_) { outFile_->close(); outFile_ = nullptr; filename_ = ""; } } -void VectorFileEffector::openFile(const std::string& filename) -{ +void VectorFileEffector::openFile(const std::string &filename) { if (outFile_ && !outFile_->fail()) - closeFile(); + closeFile(); if (filename == "") return; outFile_ = new OFStream(filename.c_str(), std::ios::app); - if (outFile_->fail()) - { + if (outFile_->fail()) { delete outFile_; outFile_ = nullptr; - NTA_THROW << "VectorFileEffector::openFile -- unable to create or open file: " << filename.c_str(); + NTA_THROW + << "VectorFileEffector::openFile -- unable to create or open file: " + << filename.c_str(); } filename_ = filename; } - -void VectorFileEffector::setParameterString(const std::string& paramName, Int64 index, const std::string& s) -{ +void VectorFileEffector::setParameterString(const std::string ¶mName, + Int64 index, const std::string &s) { if (paramName == "outputFile") { if (s == filename_ && outFile_) @@ -168,11 +141,10 @@ void VectorFileEffector::setParameterString(const std::string& paramName, Int64 } else { NTA_THROW << "VectorFileEffector -- Unknown string parameter " << paramName; } - } -std::string VectorFileEffector::getParameterString(const std::string& paramName, Int64 index) -{ +std::string VectorFileEffector::getParameterString(const std::string ¶mName, + Int64 index) { if (paramName == "outputFile") { return filename_; } else { @@ -180,133 +152,103 @@ std::string VectorFileEffector::getParameterString(const std::string& paramName, } } -std::string VectorFileEffector::executeCommand(const std::vector& args, Int64 index) -{ +std::string +VectorFileEffector::executeCommand(const std::vector &args, + Int64 index) { NTA_CHECK(args.size() > 0); // Process the flushFile command - if (args[0] == "flushFile") - { + if (args[0] == "flushFile") { // Ensure we have a valid file before flushing, otherwise fail silently. 
- if (!((outFile_ == nullptr) || (outFile_->fail()))) - { + if (!((outFile_ == nullptr) || (outFile_->fail()))) { outFile_->flush(); } - } else if (args[0] == "closeFile") - { + } else if (args[0] == "closeFile") { closeFile(); - } else if (args[0] == "echo") - { + } else if (args[0] == "echo") { // Ensure we have a valid file before flushing, otherwise fail silently. - if ((outFile_ == nullptr) || (outFile_->fail())) - { - NTA_THROW << "VectorFileEffector: echo command failed because there is no file open"; + if ((outFile_ == nullptr) || (outFile_->fail())) { + NTA_THROW << "VectorFileEffector: echo command failed because there is " + "no file open"; } - for (size_t i = 1; i < args.size(); i++) - { + for (size_t i = 1; i < args.size(); i++) { *outFile_ << args[i]; } *outFile_ << "\n"; - } else - { + } else { NTA_THROW << "VectorFileEffector: Unknown execute '" << args[0] << "'"; } return ""; - } - -Spec* VectorFileEffector::createSpec() -{ +Spec *VectorFileEffector::createSpec() { auto ns = new Spec; ns->description = - "VectorFileEffector is a node that simply writes its\n" - "input vectors to a text file. The target filename is specified\n" - "using the 'outputFile' parameter at run time. On each\n" - "compute, the current input vector is written (but not flushed)\n" - "to the file.\n"; - - ns->inputs.add( - "dataIn", - InputSpec("Data to be written to file", - NTA_BasicType_Real32, - 0, // count - false, // required? - false, // isRegionLevel - true // isDefaultInput - )); - - - - ns->parameters.add( - "outputFile", - ParameterSpec( - "Writes output vectors to this file on each compute. Will append to any\n" - "existing data in the file. This parameter must be set at runtime before\n" - "the first compute is called. Throws an exception if it is not set or\n" - "the file cannot be written to.\n", - NTA_BasicType_Byte, - 0, // elementCount - "", // constraints - "", // defaultValue - ParameterSpec::ReadWriteAccess - )); - - ns->commands.add( - "flushFile", - CommandSpec("Flush file data to disk")); - - ns->commands.add( - "closeFile", - CommandSpec("Close the current file, if open.")); + "VectorFileEffector is a node that simply writes its\n" + "input vectors to a text file. The target filename is specified\n" + "using the 'outputFile' parameter at run time. On each\n" + "compute, the current input vector is written (but not flushed)\n" + "to the file.\n"; + + ns->inputs.add("dataIn", + InputSpec("Data to be written to file", NTA_BasicType_Real32, + 0, // count + false, // required? + false, // isRegionLevel + true // isDefaultInput + )); + + ns->parameters.add("outputFile", + ParameterSpec("Writes output vectors to this file on each " + "compute. Will append to any\n" + "existing data in the file. This parameter " + "must be set at runtime before\n" + "the first compute is called. 
Throws an " + "exception if it is not set or\n" + "the file cannot be written to.\n", + NTA_BasicType_Byte, + 0, // elementCount + "", // constraints + "", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->commands.add("flushFile", CommandSpec("Flush file data to disk")); + + ns->commands.add("closeFile", + CommandSpec("Close the current file, if open.")); return ns; } - -size_t VectorFileEffector::getNodeOutputElementCount(const std::string& outputName) -{ - NTA_THROW << "VectorFileEffector::getNodeOutputElementCount -- unknown output '" - << outputName << "'"; +size_t +VectorFileEffector::getNodeOutputElementCount(const std::string &outputName) { + NTA_THROW + << "VectorFileEffector::getNodeOutputElementCount -- unknown output '" + << outputName << "'"; } - -void VectorFileEffector::getParameterFromBuffer(const std::string& name, Int64 index, IWriteBuffer& value) -{ +void VectorFileEffector::getParameterFromBuffer(const std::string &name, + Int64 index, + IWriteBuffer &value) { NTA_THROW << "VectorFileEffector -- unknown parameter '" << name << "'"; } - -void VectorFileEffector::setParameterFromBuffer(const std::string& name, Int64 index, IReadBuffer& value) -{ +void VectorFileEffector::setParameterFromBuffer(const std::string &name, + Int64 index, + IReadBuffer &value) { NTA_THROW << "VectorFileEffector -- unknown parameter '" << name << "'"; } +void VectorFileEffector::serialize(BundleIO &bundle) { return; } -void VectorFileEffector::serialize(BundleIO& bundle) -{ - return; -} - - -void VectorFileEffector::deserialize(BundleIO& bundle) -{ - return; -} - - -void VectorFileEffector::write(capnp::AnyPointer::Builder& anyProto) const -{ - return; -} +void VectorFileEffector::deserialize(BundleIO &bundle) { return; } - -void VectorFileEffector::read(capnp::AnyPointer::Reader& anyProto) -{ +void VectorFileEffector::write(capnp::AnyPointer::Builder &anyProto) const { return; } +void VectorFileEffector::read(capnp::AnyPointer::Reader &anyProto) { return; } -} +} // namespace nupic diff --git a/src/nupic/regions/VectorFileEffector.hpp b/src/nupic/regions/VectorFileEffector.hpp index 39d42bac41..d15f5dabf3 100644 --- a/src/nupic/regions/VectorFileEffector.hpp +++ b/src/nupic/regions/VectorFileEffector.hpp @@ -33,101 +33,96 @@ #include -#include -#include #include #include #include +#include +#include -namespace nupic -{ - - class ValueMap; - - /** - * VectorFileEffector is a node that takes its input vectors and - * writes them sequentially to a file. - * - * The current input vector is written (but not flushed) to the file - * each time the effector's compute() method is called. - * - * The file format for the file is a space-separated list of numbers, with - * one vector per line: - * - * e11 e12 e13 ... e1N - * e21 e22 e23 ... e2N - * : - * eM1 eM2 eM3 ... eMN - * - * VectorFileEffector implements the execute() commands as defined in the - * nodeSpec. 
- * - */ - class VectorFileEffector : public RegionImpl - { - public: - - static Spec* createSpec(); - size_t getNodeOutputElementCount(const std::string& outputName) override; - void getParameterFromBuffer(const std::string& name, Int64 index, IWriteBuffer& value) override; - - void setParameterFromBuffer(const std::string& name, Int64 index, IReadBuffer& value) override; - - void setParameterString(const std::string& name, Int64 index, const std::string& s) override; - std::string getParameterString(const std::string& name, Int64 index) override; - - void initialize() override; +namespace nupic { - VectorFileEffector(const ValueMap& params, Region *region); +class ValueMap; - VectorFileEffector(BundleIO& bundle, Region* region); +/** + * VectorFileEffector is a node that takes its input vectors and + * writes them sequentially to a file. + * + * The current input vector is written (but not flushed) to the file + * each time the effector's compute() method is called. + * + * The file format for the file is a space-separated list of numbers, with + * one vector per line: + * + * e11 e12 e13 ... e1N + * e21 e22 e23 ... e2N + * : + * eM1 eM2 eM3 ... eMN + * + * VectorFileEffector implements the execute() commands as defined in the + * nodeSpec. + * + */ +class VectorFileEffector : public RegionImpl { +public: + static Spec *createSpec(); + size_t getNodeOutputElementCount(const std::string &outputName) override; + void getParameterFromBuffer(const std::string &name, Int64 index, + IWriteBuffer &value) override; - VectorFileEffector(capnp::AnyPointer::Reader& proto, Region* region); + void setParameterFromBuffer(const std::string &name, Int64 index, + IReadBuffer &value) override; - virtual ~VectorFileEffector(); + void setParameterString(const std::string &name, Int64 index, + const std::string &s) override; + std::string getParameterString(const std::string &name, Int64 index) override; + void initialize() override; - // --- - /// Serialize state to bundle - // --- - virtual void serialize(BundleIO& bundle) override; + VectorFileEffector(const ValueMap ¶ms, Region *region); - // --- - /// De-serialize state from bundle - // --- - virtual void deserialize(BundleIO& bundle) override; + VectorFileEffector(BundleIO &bundle, Region *region); - using RegionImpl::write; - virtual void write(capnp::AnyPointer::Builder& anyProto) const override; + VectorFileEffector(capnp::AnyPointer::Reader &proto, Region *region); - using RegionImpl::read; - virtual void read(capnp::AnyPointer::Reader& anyProto) override; + virtual ~VectorFileEffector(); - void compute() override; + // --- + /// Serialize state to bundle + // --- + virtual void serialize(BundleIO &bundle) override; - virtual std::string executeCommand(const std::vector& args, Int64 index) override; + // --- + /// De-serialize state from bundle + // --- + virtual void deserialize(BundleIO &bundle) override; + using RegionImpl::write; + virtual void write(capnp::AnyPointer::Builder &anyProto) const override; + using RegionImpl::read; + virtual void read(capnp::AnyPointer::Reader &anyProto) override; - private: + void compute() override; - void closeFile(); - void openFile(const std::string& filename); + virtual std::string executeCommand(const std::vector &args, + Int64 index) override; - ArrayRef dataIn_; - std::string filename_; // Name of the output file - nupic::OFStream *outFile_; // Handle to current file +private: + void closeFile(); + void openFile(const std::string &filename); - /// Disable unsupported default constructors - 
VectorFileEffector(const VectorFileEffector&); - VectorFileEffector& operator=(const VectorFileEffector&); + ArrayRef dataIn_; + std::string filename_; // Name of the output file + nupic::OFStream *outFile_; // Handle to current file - }; // end class VectorFileEffector + /// Disable unsupported default constructors + VectorFileEffector(const VectorFileEffector &); + VectorFileEffector &operator=(const VectorFileEffector &); - //---------------------------------------------------------------------- +}; // end class VectorFileEffector +//---------------------------------------------------------------------- } // namespace nupic - #endif // NTA_VECTOR_FILE_EFFECTOR_HPP diff --git a/src/nupic/regions/VectorFileSensor.cpp b/src/nupic/regions/VectorFileSensor.cpp index 9e50dd292e..9d0335a1fd 100644 --- a/src/nupic/regions/VectorFileSensor.cpp +++ b/src/nupic/regions/VectorFileSensor.cpp @@ -21,15 +21,15 @@ */ /** @file -* Implementation for VectorFileSensor class -*/ + * Implementation for VectorFileSensor class + */ -#include -#include +#include // strlen #include -#include #include -#include // strlen +#include +#include +#include #include @@ -40,119 +40,86 @@ #include #include //#include -#include #include +#include using namespace std; -namespace nupic -{ +namespace nupic { //---------------------------------------------------------------------------- -VectorFileSensor::VectorFileSensor(const ValueMap& params, Region* region) : - RegionImpl(region), - - repeatCount_(1), - iterations_(0), - curVector_(0), - activeOutputCount_(0), - hasCategoryOut_(false), - hasResetOut_(false), - dataOut_(NTA_BasicType_Real32), - categoryOut_(NTA_BasicType_Real32), - resetOut_(NTA_BasicType_Real32), - filename_(""), - scalingMode_("none"), - recentFile_("") -{ - activeOutputCount_ = params.getScalar("activeOutputCount")->getValue(); +VectorFileSensor::VectorFileSensor(const ValueMap ¶ms, Region *region) + : RegionImpl(region), + + repeatCount_(1), iterations_(0), curVector_(0), activeOutputCount_(0), + hasCategoryOut_(false), hasResetOut_(false), + dataOut_(NTA_BasicType_Real32), categoryOut_(NTA_BasicType_Real32), + resetOut_(NTA_BasicType_Real32), filename_(""), scalingMode_("none"), + recentFile_("") { + activeOutputCount_ = + params.getScalar("activeOutputCount")->getValue(); if (params.contains("hasCategoryOut")) hasCategoryOut_ = - params.getScalar("hasCategoryOut")->getValue() == 1; + params.getScalar("hasCategoryOut")->getValue() == 1; if (params.contains("hasResetOut")) - hasResetOut_ = - params.getScalar("hasResetOut")->getValue() == 1; + hasResetOut_ = params.getScalar("hasResetOut")->getValue() == 1; if (params.contains("inputFile")) filename_ = *params.getString("inputFile"); if (params.contains("repeatCount")) repeatCount_ = params.getScalar("repeatCount")->getValue(); } -VectorFileSensor::VectorFileSensor(BundleIO& bundle, Region* region) : - RegionImpl(region), - repeatCount_(1), - iterations_(0), - curVector_(0), - activeOutputCount_(0), - hasCategoryOut_(false), - hasResetOut_(false), - dataOut_(NTA_BasicType_Real32), - categoryOut_(NTA_BasicType_Real32), - resetOut_(NTA_BasicType_Real32), - filename_(""), - scalingMode_("none"), - recentFile_("") -{ +VectorFileSensor::VectorFileSensor(BundleIO &bundle, Region *region) + : RegionImpl(region), repeatCount_(1), iterations_(0), curVector_(0), + activeOutputCount_(0), hasCategoryOut_(false), hasResetOut_(false), + dataOut_(NTA_BasicType_Real32), categoryOut_(NTA_BasicType_Real32), + resetOut_(NTA_BasicType_Real32), filename_(""), 
scalingMode_("none"), + recentFile_("") { deserialize(bundle); } -VectorFileSensor::VectorFileSensor( - capnp::AnyPointer::Reader& proto, Region* region) : - RegionImpl(region), - repeatCount_(1), - iterations_(0), - curVector_(0), - activeOutputCount_(0), - hasCategoryOut_(false), - hasResetOut_(false), - dataOut_(NTA_BasicType_Real32), - categoryOut_(NTA_BasicType_Real32), - resetOut_(NTA_BasicType_Real32), - filename_(""), - scalingMode_("none"), - recentFile_("") -{ +VectorFileSensor::VectorFileSensor(capnp::AnyPointer::Reader &proto, + Region *region) + : RegionImpl(region), repeatCount_(1), iterations_(0), curVector_(0), + activeOutputCount_(0), hasCategoryOut_(false), hasResetOut_(false), + dataOut_(NTA_BasicType_Real32), categoryOut_(NTA_BasicType_Real32), + resetOut_(NTA_BasicType_Real32), filename_(""), scalingMode_("none"), + recentFile_("") { read(proto); } -void VectorFileSensor::initialize() -{ +void VectorFileSensor::initialize() { NTA_CHECK(region_ != nullptr); dataOut_ = region_->getOutputData("dataOut"); categoryOut_ = region_->getOutputData("categoryOut"); resetOut_ = region_->getOutputData("resetOut"); - if (dataOut_.getCount() != activeOutputCount_) - { - NTA_THROW - << "VectorFileSensor::init - wrong output size: " << dataOut_.getCount() - << " should be: " << activeOutputCount_; + if (dataOut_.getCount() != activeOutputCount_) { + NTA_THROW << "VectorFileSensor::init - wrong output size: " + << dataOut_.getCount() << " should be: " << activeOutputCount_; } } -VectorFileSensor::~VectorFileSensor() -{ -} +VectorFileSensor::~VectorFileSensor() {} //---------------------------------------------------------------------------- -void VectorFileSensor::compute() -{ - // It's not necessarily an error to have no outputs. In this case we just return +void VectorFileSensor::compute() { + // It's not necessarily an error to have no outputs. In this case we just + // return if (dataOut_.getCount() == 0) return; // Don't write if there is no open file. - if (recentFile_ == "") - { + if (recentFile_ == "") { NTA_WARN << "VectorFileSesnsor compute() called, but there is no open file"; return; } NTA_CHECK(vectorFile_.vectorCount() > 0) - << "VectorFileSensor::compute - no data vectors in memory." - << "Perhaps no data file has been loaded using the 'loadFile'" - << " execute command."; + << "VectorFileSensor::compute - no data vectors in memory." 
+ << "Perhaps no data file has been loaded using the 'loadFile'" + << " execute command."; if (iterations_ % repeatCount_ == 0) { // Get index to next vector and copy scaled vector to our output @@ -160,21 +127,19 @@ void VectorFileSensor::compute() curVector_ %= vectorFile_.vectorCount(); } - Real *out = (Real *) dataOut_.getBuffer(); + Real *out = (Real *)dataOut_.getBuffer(); Size count = dataOut_.getCount(); UInt offset = 0; - if (hasCategoryOut_) - { - Real * categoryOut = reinterpret_cast(categoryOut_.getBuffer()); + if (hasCategoryOut_) { + Real *categoryOut = reinterpret_cast(categoryOut_.getBuffer()); vectorFile_.getRawVector((nupic::UInt)curVector_, categoryOut, offset, 1); offset++; } - if (hasResetOut_) - { - Real * resetOut = reinterpret_cast(resetOut_.getBuffer()); + if (hasResetOut_) { + Real *resetOut = reinterpret_cast(resetOut_.getBuffer()); vectorFile_.getRawVector((nupic::UInt)curVector_, resetOut, offset, 1); offset++; } @@ -185,11 +150,10 @@ void VectorFileSensor::compute() //-------------------------------------------------------------------------------- inline const char *checkExtensions(const std::string &filename, - const char *const *extensions) -{ - while(*extensions) { + const char *const *extensions) { + while (*extensions) { const char *ext = *extensions; - if(filename.rfind(ext) == (filename.size() - ::strlen(ext))) + if (filename.rfind(ext) == (filename.size() - ::strlen(ext))) return ext; ++extensions; @@ -199,7 +163,9 @@ inline const char *checkExtensions(const std::string &filename, //-------------------------------------------------------------------------------- /// Execute a VectorFilesensor specific command -std::string VectorFileSensor::executeCommand(const std::vector& args, Int64 index) +std::string +VectorFileSensor::executeCommand(const std::vector &args, + Int64 index) { UInt32 argCount = args.size(); @@ -208,26 +174,22 @@ std::string VectorFileSensor::executeCommand(const std::vector& arg string command = args[0]; // Process each command - if ((command == "loadFile") || (command == "appendFile")) - { - NTA_CHECK(argCount > 1) << "VectorFileSensor: no filename specified for " << command; + if ((command == "loadFile") || (command == "appendFile")) { + NTA_CHECK(argCount > 1) + << "VectorFileSensor: no filename specified for " << command; - UInt32 labeled = 2; // Default format is 2 + UInt32 labeled = 2; // Default format is 2 // string filename = ReadStringFromBuffer(*buf2); string filename(args[1]); cout << "In VectorFileSensor " << filename << endl; - if (argCount == 3) - { + if (argCount == 3) { labeled = StringUtils::toUInt32(args[2]); - } - else - { + } else { // Check for some common extensions. - const char *csvExtensions[] = { ".csv", ".CSV", nullptr }; - if(checkExtensions(filename, csvExtensions)) - { + const char *csvExtensions[] = {".csv", ".CSV", nullptr}; + if (checkExtensions(filename, csvExtensions)) { cout << "Reading CSV file" << endl; labeled = 3; // CSV format. 
} @@ -240,26 +202,26 @@ std::string VectorFileSensor::executeCommand(const std::vector& arg labeled = 4; } - if (labeled > (UInt32) VectorFile::maxFormat()) + if (labeled > (UInt32)VectorFile::maxFormat()) NTA_THROW << "VectorFileSensor: unknown file format '" << labeled << "'"; // Read in new set of vectors - // If the command is loadFile, we clear the list first and reset the position - // to the beginning + // If the command is loadFile, we clear the list first and reset the + // position to the beginning if (command == "loadFile") vectorFile_.clear(false); - //Timer t(true); + // Timer t(true); UInt32 elementCount = activeOutputCount_; if (hasCategoryOut_) - elementCount ++; + elementCount++; if (hasResetOut_) - elementCount ++; + elementCount++; vectorFile_.appendFile(filename, elementCount, labeled); cout << "Read " << vectorFile_.vectorCount() << " vectors" << endl; - //in " << t.getValue() << " seconds" << endl; + // in " << t.getValue() << " seconds" << endl; if (command == "loadFile") seek(0); @@ -267,47 +229,44 @@ std::string VectorFileSensor::executeCommand(const std::vector& arg recentFile_ = filename; } - else if (command == "dump") - { + else if (command == "dump") { nupic::Byte message[256]; Size n = ::sprintf(message, - "VectorFileSensor isLabeled = %d repeatCount = %d vectorCount = %d iterations = %d\n", - vectorFile_.isLabeled(), (int) repeatCount_, (int) vectorFile_.vectorCount(), (int) iterations_); - //out.write(message, n); + "VectorFileSensor isLabeled = %d repeatCount = %d " + "vectorCount = %d iterations = %d\n", + vectorFile_.isLabeled(), (int)repeatCount_, + (int)vectorFile_.vectorCount(), (int)iterations_); + // out.write(message, n); return string(message, n); } - else if (command == "saveFile") - { - NTA_CHECK(argCount > 1) << "VectorFileSensor: no filename specified for " << command; + else if (command == "saveFile") { + NTA_CHECK(argCount > 1) + << "VectorFileSensor: no filename specified for " << command; - Int32 format = 2; // Default format is 2 + Int32 format = 2; // Default format is 2 Int64 begin = 0, end = 0; bool hasEnd = false; string filename(args[1]); - if (argCount > 2) - { + if (argCount > 2) { format = StringUtils::toUInt32(args[2]); if ((format < 0) || (format > VectorFile::maxFormat())) NTA_THROW << "VectorFileSensor: unknown file format '" << format << "'"; } - if (argCount > 3) - { + if (argCount > 3) { begin = StringUtils::toUInt32(args[3]); } - if (argCount > 4) - { + if (argCount > 4) { end = StringUtils::toUInt32(args[4]); hasEnd = true; } NTA_CHECK(argCount <= 5) << "VectorFileSensor: too many arguments"; - OFStream f(filename.c_str()); if (hasEnd) vectorFile_.saveVectors(f, dataOut_.getCount(), format, begin, end); @@ -315,9 +274,9 @@ std::string VectorFileSensor::executeCommand(const std::vector& arg vectorFile_.saveVectors(f, dataOut_.getCount(), format, begin, end); } - else - { - NTA_THROW << "VectorFileSensor: Unknown execute command: '" << command << "' sent!"; + else { + NTA_THROW << "VectorFileSensor: Unknown execute command: '" << command + << "' sent!"; } return ""; @@ -325,48 +284,41 @@ std::string VectorFileSensor::executeCommand(const std::vector& arg //-------------------------------------------------------------------------------- -void VectorFileSensor::setParameterFromBuffer(const std::string& name, - Int64 index, - IReadBuffer& value) -{ - const char* where = "VectorFileSensor, while setting parameter: "; +void VectorFileSensor::setParameterFromBuffer(const std::string &name, + Int64 index, IReadBuffer &value) { + 
const char *where = "VectorFileSensor, while setting parameter: "; UInt32 int_param = 0; - if (name == "repeatCount") - { + if (name == "repeatCount") { NTA_CHECK(value.read(int_param) == 0) - << where << "Unable to read repeatCount: " - << int_param << " - Should be a positive integer"; + << where << "Unable to read repeatCount: " << int_param + << " - Should be a positive integer"; - if (int_param >= 1) - { + if (int_param >= 1) { repeatCount_ = int_param; } } - else if (name == "position") - { + else if (name == "position") { NTA_CHECK(value.read(int_param) == 0) - << where << "Unable to read position: " - << int_param << " - Should be a positive integer"; - if ( int_param < vectorFile_.vectorCount() ) - { + << where << "Unable to read position: " << int_param + << " - Should be a positive integer"; + if (int_param < vectorFile_.vectorCount()) { seek(int_param); - } - else - { + } else { NTA_THROW << "VectorFileSensor: invalid position " - << " to seek to: " << int_param; + << " to seek to: " << int_param; } } - else if (name == "scalingMode") - { + else if (name == "scalingMode") { // string mode = ReadStringFromvaluefer(value); string mode(value.getData(), value.getSize()); - if (mode == "none") vectorFile_.resetScaling(); - else if (mode == "standardForm") vectorFile_.setStandardScaling(); + if (mode == "none") + vectorFile_.resetScaling(); + else if (mode == "standardForm") + vectorFile_.setStandardScaling(); else if (mode != "custom") // Do nothing if set to custom NTA_THROW << where << " Unknown scaling mode: " << mode; scalingMode_ = mode; @@ -374,33 +326,30 @@ void VectorFileSensor::setParameterFromBuffer(const std::string& name, else if (name == "hasCategoryOut") { NTA_CHECK(value.read(int_param) == 0) - << where << "Unable to read hasCategoryOut: " - << int_param << " - Should be a positive integer"; + << where << "Unable to read hasCategoryOut: " << int_param + << " - Should be a positive integer"; hasCategoryOut_ = int_param == 1; } else if (name == "hasResetOut") { NTA_CHECK(value.read(int_param) == 0) - << where << "Unable to read hasResetOut: " - << int_param << " - Should be a positive integer"; + << where << "Unable to read hasResetOut: " << int_param + << " - Should be a positive integer"; hasResetOut_ = int_param == 1; } - else - { + else { NTA_THROW << where << "couldn't set '" << name << "'"; } - } //-------------------------------------------------------------------------------- -void VectorFileSensor::getParameterFromBuffer(const std::string& name, +void VectorFileSensor::getParameterFromBuffer(const std::string &name, Int64 index, - IWriteBuffer& value) -{ - const char* where = "VectorFileSensor, while getting parameter: "; + IWriteBuffer &value) { + const char *where = "VectorFileSensor, while getting parameter: "; Int32 res = 0; @@ -409,7 +358,7 @@ void VectorFileSensor::getParameterFromBuffer(const std::string& name, } else if (name == "position") { - res = value.write(UInt32(curVector_+1)); + res = value.write(UInt32(curVector_ + 1)); } else if (name == "repeatCount") { @@ -423,12 +372,9 @@ void VectorFileSensor::getParameterFromBuffer(const std::string& name, else if (name == "recentFile") { // res = value.writeString(recentFile_.data(), (Size)recentFile_.size()); - if (recentFile_.empty()) - { + if (recentFile_.empty()) { res = value.write("", 1); - } - else - { + } else { res = value.write(recentFile_.data(), (Size)recentFile_.size()); } } @@ -436,8 +382,7 @@ void VectorFileSensor::getParameterFromBuffer(const std::string& name, else if (name == 
"scaleVector") { stringstream buf; Real s = 0, o = 0; - for (UInt i = 0; i < vectorFile_.getElementCount(); i++) - { + for (UInt i = 0; i < vectorFile_.getElementCount(); i++) { vectorFile_.getScaling(i, s, o); buf << s << " "; } @@ -456,8 +401,7 @@ void VectorFileSensor::getParameterFromBuffer(const std::string& name, else if (name == "offsetVector") { stringstream buf; Real s = 0, o = 0; - for (UInt i = 0; i < vectorFile_.getElementCount(); i++) - { + for (UInt i = 0; i < vectorFile_.getElementCount(); i++) { vectorFile_.getScaling(i, s, o); buf << o << " "; } @@ -477,45 +421,37 @@ void VectorFileSensor::getParameterFromBuffer(const std::string& name, } //---------------------------------------------------------------------- -void VectorFileSensor::seek(int n) -{ - NTA_CHECK( (n >= 0) && ((unsigned int) n < vectorFile_.vectorCount()) ); +void VectorFileSensor::seek(int n) { + NTA_CHECK((n >= 0) && ((unsigned int)n < vectorFile_.vectorCount())); // Set curVector_ to be one before the vector we want and reset iterations iterations_ = 0; curVector_ = n - 1; - //circular-buffer, reached one end of vector/line, continue fro the other - if (n - 1 <= 0) curVector_ = (NTA_Size)vectorFile_.vectorCount() - 1; + // circular-buffer, reached one end of vector/line, continue fro the other + if (n - 1 <= 0) + curVector_ = (NTA_Size)vectorFile_.vectorCount() - 1; } -size_t VectorFileSensor::getNodeOutputElementCount(const std::string& outputName) -{ +size_t +VectorFileSensor::getNodeOutputElementCount(const std::string &outputName) { NTA_CHECK(outputName == "dataOut") << "Invalid output name: " << outputName; return activeOutputCount_; } -void VectorFileSensor::serialize(BundleIO& bundle) -{ - std::ofstream & f = bundle.getOutputStream("vfs"); - f << repeatCount_ << " " - << activeOutputCount_ << " " - << filename_ << " " +void VectorFileSensor::serialize(BundleIO &bundle) { + std::ofstream &f = bundle.getOutputStream("vfs"); + f << repeatCount_ << " " << activeOutputCount_ << " " << filename_ << " " << scalingMode_ << " "; f.close(); } -void VectorFileSensor::deserialize(BundleIO& bundle) -{ - std::ifstream& f = bundle.getInputStream("vfs"); - f >> repeatCount_ - >> activeOutputCount_ - >> filename_ - >> scalingMode_; +void VectorFileSensor::deserialize(BundleIO &bundle) { + std::ifstream &f = bundle.getInputStream("vfs"); + f >> repeatCount_ >> activeOutputCount_ >> filename_ >> scalingMode_; f.close(); } -void VectorFileSensor::write(capnp::AnyPointer::Builder& anyProto) const -{ +void VectorFileSensor::write(capnp::AnyPointer::Builder &anyProto) const { auto proto = anyProto.getAs(); proto.setRepeatCount(repeatCount_); proto.setActiveOutputCount(activeOutputCount_); @@ -523,8 +459,7 @@ void VectorFileSensor::write(capnp::AnyPointer::Builder& anyProto) const proto.setScalingMode(scalingMode_.c_str()); } -void VectorFileSensor::read(capnp::AnyPointer::Reader& anyProto) -{ +void VectorFileSensor::read(capnp::AnyPointer::Reader &anyProto) { auto proto = anyProto.getAs(); repeatCount_ = proto.getRepeatCount(); activeOutputCount_ = proto.getActiveOutputCount(); @@ -532,310 +467,288 @@ void VectorFileSensor::read(capnp::AnyPointer::Reader& anyProto) scalingMode_ = proto.getScalingMode().cStr(); } -Spec* VectorFileSensor::createSpec() -{ +Spec *VectorFileSensor::createSpec() { auto ns = new Spec; ns->description = - "VectorFileSensor is a basic sensor for reading files containing vectors.\n" - "\n" - "VectorFileSensor reads in a text file containing lists of numbers\n" - "and outputs these vectors in 
sequence. The output is updated\n" - "each time the sensor's compute() method is called. If\n" - "repeatCount is > 1, then each vector is repeated that many times\n" - "before moving to the next one. The sensor loops when the end of\n" - "the vector list is reached. The default file format\n" - "is as follows (assuming the sensor is configured with N outputs):\n" - "\n" - " e11 e12 e13 ... e1N\n" - " e21 e22 e23 ... e2N\n" - " : \n" - " eM1 eM2 eM3 ... eMN\n" - "\n" - "In this format the sensor ignores all whitespace in the file, including newlines\n" - "If the file contains an incorrect number of floats, the sensor has no way\n" - "of checking and will silently ignore the extra numbers at the end of the file.\n" - "\n" - "The sensor can also read in comma-separated (CSV) files following the format:\n" - "\n" - " e11, e12, e13, ... ,e1N\n" - " e21, e22, e23, ... ,e2N\n" - " : \n" - " eM1, eM2, eM3, ... ,eMN\n" - "\n" - "When reading CSV files the sensor expects that each line contains a new vector\n" - "Any line containing too few elements or any text will be ignored. If there are\n" - "more than N numbers on a line, the sensor retains only the first N.\n"; - - ns->outputs.add( - "dataOut", - OutputSpec("Data read from file", - NTA_BasicType_Real32, - 0, // count - true, // isRegionLevel - true // isDefaultOutput - )); - - - ns->outputs.add( - "categoryOut", - OutputSpec("The current category encoded as a float (represent a whole number)", - NTA_BasicType_Real32, - 1, // count - true, // isRegionLevel - false // isDefaultOutput - )); + "VectorFileSensor is a basic sensor for reading files containing " + "vectors.\n" + "\n" + "VectorFileSensor reads in a text file containing lists of numbers\n" + "and outputs these vectors in sequence. The output is updated\n" + "each time the sensor's compute() method is called. If\n" + "repeatCount is > 1, then each vector is repeated that many times\n" + "before moving to the next one. The sensor loops when the end of\n" + "the vector list is reached. The default file format\n" + "is as follows (assuming the sensor is configured with N outputs):\n" + "\n" + " e11 e12 e13 ... e1N\n" + " e21 e22 e23 ... e2N\n" + " : \n" + " eM1 eM2 eM3 ... eMN\n" + "\n" + "In this format the sensor ignores all whitespace in the file, including " + "newlines\n" + "If the file contains an incorrect number of floats, the sensor has no " + "way\n" + "of checking and will silently ignore the extra numbers at the end of " + "the file.\n" + "\n" + "The sensor can also read in comma-separated (CSV) files following the " + "format:\n" + "\n" + " e11, e12, e13, ... ,e1N\n" + " e21, e22, e23, ... ,e2N\n" + " : \n" + " eM1, eM2, eM3, ... ,eMN\n" + "\n" + "When reading CSV files the sensor expects that each line contains a new " + "vector\n" + "Any line containing too few elements or any text will be ignored. 
If " + "there are\n" + "more than N numbers on a line, the sensor retains only the first N.\n"; + + ns->outputs.add("dataOut", + OutputSpec("Data read from file", NTA_BasicType_Real32, + 0, // count + true, // isRegionLevel + true // isDefaultOutput + )); ns->outputs.add( - "resetOut", - OutputSpec("Sequence reset signal: 0 - do nothing, otherwise start a new sequence", - NTA_BasicType_Real32, - 1, // count - true, // isRegionLevel - false // isDefaultOutput - )); - - ns->parameters.add( - "vectorCount", - ParameterSpec( - "The number of vectors currently loaded in memory.", - NTA_BasicType_UInt32, - 1, // elementCount - "interval: [0, ...]", // constraints - "0", // defaultValue - ParameterSpec::ReadOnlyAccess - )); - - ns->parameters.add( - "position", - ParameterSpec( - "Set or get the current position within the list of vectors in memory.", - NTA_BasicType_UInt32, - 1, // elementCount - "interval: [0, ...]", // constraints - "0", // defaultValue - ParameterSpec::ReadWriteAccess - )); - - ns->parameters.add( - "repeatCount", - ParameterSpec( - "Set or get the current repeatCount. Each vector is repeated\n" - "repeatCount times before moving to the next one.", - NTA_BasicType_UInt32, - 1, // elementCount - "interval: [1, ...]", // constraints - "1", // defaultValue - ParameterSpec::ReadWriteAccess - )); - - ns->parameters.add( - "recentFile", - ParameterSpec( - "Writes output vectors to this file on each compute. Will append to any\n" - "existing data in the file. This parameter must be set at runtime before\n" - "the first compute is called. Throws an exception if it is not set or\n" - "the file cannot be written to.\n", - NTA_BasicType_Byte, - 0, // elementCount - "", // constraints - "", // defaultValue - ParameterSpec::ReadOnlyAccess - )); + "categoryOut", + OutputSpec( + "The current category encoded as a float (represent a whole number)", + NTA_BasicType_Real32, + 1, // count + true, // isRegionLevel + false // isDefaultOutput + )); + + ns->outputs.add("resetOut", + OutputSpec("Sequence reset signal: 0 - do nothing, otherwise " + "start a new sequence", + NTA_BasicType_Real32, + 1, // count + true, // isRegionLevel + false // isDefaultOutput + )); ns->parameters.add( - "scalingMode", - ParameterSpec( - "During compute, each vector is adjusted as follows. If X is the data vector,\n" - "S the scaling vector and O the offset vector, then the node's output\n" - " Y[i] = S[i]*(X[i] + O[i]).\n" - "\n" - "Scaling is applied according to scalingMode as follows:\n" - "\n" - " If 'none', the vectors are unchanged, i.e. 
S[i]=1 and O[i]=0.\n" - " If 'standardForm', S[i] is 1/standard deviation(i) and O[i] = - mean(i)\n" - " If 'custom', each component is adjusted according to the vectors specified by the\n" - "setScale and setOffset commands.\n", - NTA_BasicType_Byte, - 0, // elementCount - "", // constraints - "none", // defaultValue - ParameterSpec::ReadWriteAccess - )); + "vectorCount", + ParameterSpec("The number of vectors currently loaded in memory.", + NTA_BasicType_UInt32, + 1, // elementCount + "interval: [0, ...]", // constraints + "0", // defaultValue + ParameterSpec::ReadOnlyAccess)); + + ns->parameters.add("position", + ParameterSpec("Set or get the current position within the " + "list of vectors in memory.", + NTA_BasicType_UInt32, + 1, // elementCount + "interval: [0, ...]", // constraints + "0", // defaultValue + ParameterSpec::ReadWriteAccess)); ns->parameters.add( - "scaleVector", - ParameterSpec( - "Set or return the current scale vector S.\n", - NTA_BasicType_Real32, - 0, // elementCount - "", // constraints - "", // defaultValue - ParameterSpec::ReadWriteAccess - )); + "repeatCount", + ParameterSpec( + "Set or get the current repeatCount. Each vector is repeated\n" + "repeatCount times before moving to the next one.", + NTA_BasicType_UInt32, + 1, // elementCount + "interval: [1, ...]", // constraints + "1", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("recentFile", + ParameterSpec("Writes output vectors to this file on each " + "compute. Will append to any\n" + "existing data in the file. This parameter " + "must be set at runtime before\n" + "the first compute is called. Throws an " + "exception if it is not set or\n" + "the file cannot be written to.\n", + NTA_BasicType_Byte, + 0, // elementCount + "", // constraints + "", // defaultValue + ParameterSpec::ReadOnlyAccess)); ns->parameters.add( - "offsetVector", - ParameterSpec( - "Set or return the current offset vector 0.\n", - NTA_BasicType_Real32, - 0, // elementCount - "", // constraints - "", // defaultValue - ParameterSpec::ReadWriteAccess - )); + "scalingMode", + ParameterSpec( + "During compute, each vector is adjusted as follows. If X is the " + "data vector,\n" + "S the scaling vector and O the offset vector, then the node's " + "output\n" + " Y[i] = S[i]*(X[i] + O[i]).\n" + "\n" + "Scaling is applied according to scalingMode as follows:\n" + "\n" + " If 'none', the vectors are unchanged, i.e. 
S[i]=1 and O[i]=0.\n" + " If 'standardForm', S[i] is 1/standard deviation(i) and O[i] = - " + "mean(i)\n" + " If 'custom', each component is adjusted according to the " + "vectors specified by the\n" + "setScale and setOffset commands.\n", + NTA_BasicType_Byte, + 0, // elementCount + "", // constraints + "none", // defaultValue + ParameterSpec::ReadWriteAccess)); ns->parameters.add( - "activeOutputCount", - ParameterSpec( - "The number of active outputs of the node.", - NTA_BasicType_UInt32, - 1, // elementCount - "interval: [0, ...]", // constraints - "", // default Value - ParameterSpec::CreateAccess - )); + "scaleVector", + ParameterSpec("Set or return the current scale vector S.\n", + NTA_BasicType_Real32, + 0, // elementCount + "", // constraints + "", // defaultValue + ParameterSpec::ReadWriteAccess)); ns->parameters.add( - "maxOutputVectorCount", - ParameterSpec( - "The number of output vectors that can be generated by this sensor\n" - "under the current configuration.", - NTA_BasicType_UInt32, - 1, // elementCount - "interval: [0, ...]", // constraints - "0", // defaultValue - ParameterSpec::ReadOnlyAccess - )); + "offsetVector", + ParameterSpec("Set or return the current offset vector 0.\n", + NTA_BasicType_Real32, + 0, // elementCount + "", // constraints + "", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ns->parameters.add("activeOutputCount", + ParameterSpec("The number of active outputs of the node.", + NTA_BasicType_UInt32, + 1, // elementCount + "interval: [0, ...]", // constraints + "", // default Value + ParameterSpec::CreateAccess)); ns->parameters.add( - "hasCategoryOut", - ParameterSpec( - "Category info is present in data file.", - NTA_BasicType_UInt32, - 1, // elementCount - "enum: [0, 1]", // constraints - "0", // defaultValue - ParameterSpec::ReadWriteAccess - )); + "maxOutputVectorCount", + ParameterSpec( + "The number of output vectors that can be generated by this sensor\n" + "under the current configuration.", + NTA_BasicType_UInt32, + 1, // elementCount + "interval: [0, ...]", // constraints + "0", // defaultValue + ParameterSpec::ReadOnlyAccess)); + + ns->parameters.add("hasCategoryOut", + ParameterSpec("Category info is present in data file.", + NTA_BasicType_UInt32, + 1, // elementCount + "enum: [0, 1]", // constraints + "0", // defaultValue + ParameterSpec::ReadWriteAccess)); ns->parameters.add( - "hasResetOut", - ParameterSpec( - "New sequence reset signal is present in data file.", - NTA_BasicType_UInt32, - 1, // elementCount - "enum: [0, 1]", // constraints - "0", // defaultValue - ParameterSpec::ReadWriteAccess - )); + "hasResetOut", + ParameterSpec("New sequence reset signal is present in data file.", + NTA_BasicType_UInt32, + 1, // elementCount + "enum: [0, 1]", // constraints + "0", // defaultValue + ParameterSpec::ReadWriteAccess)); ns->commands.add( - "loadFile", - CommandSpec( - "loadFile [file_format]\n" - "Reads vectors from the specified file, replacing any vectors\n" - "currently in the list. Position is set to zero. \n" - "Available file formats are: \n" - " 0 # Reads in unlabeled file with first number = element count\n" - " 1 # Reads in a labeled file with first number = element count (deprecated)\n" - " 2 # Reads in unlabeled file without element count (default)\n" - " 3 # Reads in a csv file\n" - )); + "loadFile", + CommandSpec( + "loadFile [file_format]\n" + "Reads vectors from the specified file, replacing any vectors\n" + "currently in the list. Position is set to zero. 
\n" + "Available file formats are: \n" + " 0 # Reads in unlabeled file with first number = " + "element count\n" + " 1 # Reads in a labeled file with first number = " + "element count (deprecated)\n" + " 2 # Reads in unlabeled file without element count " + "(default)\n" + " 3 # Reads in a csv file\n")); ns->commands.add( - "appendFile", - CommandSpec( - "appendFile [file_format]\n" - "Reads vectors from the specified file, appending to current vector list.\n" - "Position remains unchanged. Available file formats are: \n" - " 0 # Reads in unlabeled file with first number = element count\n" - " 1 # Reads in a labeled file with first number = element count (deprecated)\n" - " 2 # Reads in unlabeled file without element count (default)\n" - " 3 # Reads in a csv file\n")); + "appendFile", + CommandSpec("appendFile [file_format]\n" + "Reads vectors from the specified file, appending to current " + "vector list.\n" + "Position remains unchanged. Available file formats are: \n" + " 0 # Reads in unlabeled file with first number " + "= element count\n" + " 1 # Reads in a labeled file with first number " + "= element count (deprecated)\n" + " 2 # Reads in unlabeled file without element " + "count (default)\n" + " 3 # Reads in a csv file\n")); ns->commands.add( - "saveFile", - CommandSpec( - "saveFile filename [format [begin [end]]]\n" - "Save the currently loaded vectors to a file. Typically used for debugging\n" - "but may be used to convert between formats.\n")); + "saveFile", CommandSpec("saveFile filename [format [begin [end]]]\n" + "Save the currently loaded vectors to a file. " + "Typically used for debugging\n" + "but may be used to convert between formats.\n")); ns->commands.add("dump", CommandSpec("Displays some debugging info.")); - - - - return ns; } - -void VectorFileSensor::getParameterArray(const std::string& name, Int64 index, Array & a) -{ +void VectorFileSensor::getParameterArray(const std::string &name, Int64 index, + Array &a) { if (a.getCount() != dataOut_.getCount()) - NTA_THROW << "getParameterArray(), array size is: " << a.getCount() << "instead of : " << dataOut_.getCount(); + NTA_THROW << "getParameterArray(), array size is: " << a.getCount() + << "instead of : " << dataOut_.getCount(); - Real * buf = (Real *)a.getBuffer(); + Real *buf = (Real *)a.getBuffer(); Real dummy; - if (name == "scaleVector") - { - Real * buf = (Real *)a.getBuffer(); - for (UInt i = 0; i < vectorFile_.getElementCount(); i++) - { + if (name == "scaleVector") { + Real *buf = (Real *)a.getBuffer(); + for (UInt i = 0; i < vectorFile_.getElementCount(); i++) { vectorFile_.getScaling(i, buf[i], dummy); } - } - else if (name == "offsetVector") - { + } else if (name == "offsetVector") { - for (UInt i = 0; i < vectorFile_.getElementCount(); i++) - { + for (UInt i = 0; i < vectorFile_.getElementCount(); i++) { vectorFile_.getScaling(i, dummy, buf[i]); } - } - else - { - NTA_THROW << "VectorfileSensor::getParameterArray(), unknown parameter: " << name; + } else { + NTA_THROW << "VectorfileSensor::getParameterArray(), unknown parameter: " + << name; } } -void VectorFileSensor::setParameterArray(const std::string& name, Int64 index, const Array & a) -{ +void VectorFileSensor::setParameterArray(const std::string &name, Int64 index, + const Array &a) { if (a.getCount() != dataOut_.getCount()) - NTA_THROW << "setParameterArray(), array size is: " << a.getCount() << "instead of : " << dataOut_.getCount(); + NTA_THROW << "setParameterArray(), array size is: " << a.getCount() + << "instead of : " << 
dataOut_.getCount(); - Real * buf = (Real *)a.getBuffer(); - if (name == "scaleVector") - { - for (UInt i = 0; i < vectorFile_.getElementCount(); i++) - { + Real *buf = (Real *)a.getBuffer(); + if (name == "scaleVector") { + for (UInt i = 0; i < vectorFile_.getElementCount(); i++) { vectorFile_.setScale(i, buf[i]); } - } - else if (name == "offsetVector") - { + } else if (name == "offsetVector") { - for (UInt i = 0; i < vectorFile_.getElementCount(); i++) - { + for (UInt i = 0; i < vectorFile_.getElementCount(); i++) { vectorFile_.setOffset(i, buf[i]); } - } - else - { - NTA_THROW << "VectorfileSensor::setParameterArray(), unknown parameter: " << name; + } else { + NTA_THROW << "VectorfileSensor::setParameterArray(), unknown parameter: " + << name; } scalingMode_ = "custom"; } -size_t VectorFileSensor::getParameterArrayCount(const std::string& name, Int64 index) -{ +size_t VectorFileSensor::getParameterArrayCount(const std::string &name, + Int64 index) { if (name != "scaleVector" && name != "offsetVector") - NTA_THROW << "VectorFileSensor::getParameterArrayCount(), unknown array parameter: " << name; + NTA_THROW << "VectorFileSensor::getParameterArrayCount(), unknown array " + "parameter: " + << name; return dataOut_.getCount(); } - } // end namespace nupic - diff --git a/src/nupic/regions/VectorFileSensor.hpp b/src/nupic/regions/VectorFileSensor.hpp index 5b27f6463d..4ee8ef4283 100644 --- a/src/nupic/regions/VectorFileSensor.hpp +++ b/src/nupic/regions/VectorFileSensor.hpp @@ -35,311 +35,322 @@ #include -#include -#include #include #include #include +#include #include +#include + +namespace nupic { +class ValueMap; + +/** + * VectorFileSensor is a sensor that reads in files containing lists of + * vectors and outputs these vectors in sequence. + * + * @b Description + * + * Three input file formats are supported: + * 0 - unlabeled files with element count + * 1 - labeled files with element count + * 2 - unlabeled files without element count (default) + * + * These input formats are described in more detail below. + * + * The Sensor implements the execute() commands as specified in the nodeSpec. + * + * @b Notes: + * The file format for an unlabeled file without element count is as follows: + * \verbatim + e11 e12 e13 ... e1N + e21 e22 e23 ... e2N + : + eM1 eM2 eM3 ... eMN + \endverbatim + * + * The file format for an unlabeled file with element count is as follows: + * \verbatim + N + e11 e12 e13 ... e1N + e21 e22 e23 ... e2N + : + eM1 eM2 eM3 ... eMN + \endverbatim + * + * The format for a labeled file with element count is as follows: + * \verbatim + N + EL1 EL2 EL3 ELN + VL1 e11 e12 e13 ... e1N + VL2 e21 e22 e23 ... e2N + : + VLM eM1 eM2 eM3 ... eMN + \endverbatim + * + * where ELi are string labels for each element in the vector and VLi are + * string labels for each vector. Strings are separated by whitespace. Strings + * with whitespace are not supported (e.g. no quoting of strings). + * + * Whitespace between numbers is ignored. + * The full list of vectors is read into memory when the loadFile command + * is executed. + * + */ -namespace nupic -{ - class ValueMap; - - - /** - * VectorFileSensor is a sensor that reads in files containing lists of - * vectors and outputs these vectors in sequence. - * - * @b Description - * - * Three input file formats are supported: - * 0 - unlabeled files with element count - * 1 - labeled files with element count - * 2 - unlabeled files without element count (default) - * - * These input formats are described in more detail below. 
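Aside (illustration only, not the sensor's actual parser): the header comment above describes the whitespace-separated input formats, where the default format 2 is simply M rows of N values with all whitespace, including newlines, treated alike. A minimal sketch of reading that layout, assuming a toy in-memory file and N = 3 outputs:

#include <cstddef>
#include <iostream>
#include <sstream>
#include <vector>

int main() {
  const std::size_t N = 3;                  // outputs per vector (assumed)
  std::istringstream file("0.1 0.2 0.3\n"   // vector 1
                          "1.0 1.1 1.2\n"); // vector 2
  std::vector<std::vector<float>> vectors;
  std::vector<float> current;
  float element = 0;
  while (file >> element) {    // spaces and newlines are equivalent separators
    current.push_back(element);
    if (current.size() == N) { // every N values form one complete vector
      vectors.push_back(current);
      current.clear();
    }
  }
  std::cout << "read " << vectors.size() << " vectors\n"; // read 2 vectors
  return 0;
}
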
- * - * The Sensor implements the execute() commands as specified in the nodeSpec. - * - * @b Notes: - * The file format for an unlabeled file without element count is as follows: - * \verbatim - e11 e12 e13 ... e1N - e21 e22 e23 ... e2N - : - eM1 eM2 eM3 ... eMN - \endverbatim - * - * The file format for an unlabeled file with element count is as follows: - * \verbatim - N - e11 e12 e13 ... e1N - e21 e22 e23 ... e2N - : - eM1 eM2 eM3 ... eMN - \endverbatim - * - * The format for a labeled file with element count is as follows: - * \verbatim - N - EL1 EL2 EL3 ELN - VL1 e11 e12 e13 ... e1N - VL2 e21 e22 e23 ... e2N - : - VLM eM1 eM2 eM3 ... eMN - \endverbatim - * - * where ELi are string labels for each element in the vector and VLi are - * string labels for each vector. Strings are separated by whitespace. Strings - * with whitespace are not supported (e.g. no quoting of strings). - * - * Whitespace between numbers is ignored. - * The full list of vectors is read into memory when the loadFile command - * is executed. - * - */ - - class VectorFileSensor : public RegionImpl - { - public: - - //------ Static methods for plug-in API ------------------------------------ - -// static const NTA_Spec * getSpec(const NTA_Byte * nodeType) -// { -// const char *description = -//"VectorFileSensor is a basic sensor for reading files containing vectors.\n" -//"\n" -//"VectorFileSensor reads in a text file containing lists of numbers\n" -//"and outputs these vectors in sequence. The output is updated\n" -//"each time the sensor's compute() method is called. If\n" -//"repeatCount is > 1, then each vector is repeated that many times\n" -//"before moving to the next one. The sensor loops when the end of\n" -//"the vector list is reached. The default file format\n" -//"is as follows (assuming the sensor is configured with N outputs):\n" -//"\n" -//" e11 e12 e13 ... e1N\n" -//" e21 e22 e23 ... e2N\n" -//" : \n" -//" eM1 eM2 eM3 ... eMN\n" -//"\n" -//"In this format the sensor ignores all whitespace in the file, including newlines\n" -//"If the file contains an incorrect number of floats, the sensor has no way\n" -//"of checking and will silently ignore the extra numbers at the end of the file.\n" -//"\n" -//"The sensor can also read in comma-separated (CSV) files following the format:\n" -//"\n" -//" e11, e12, e13, ... ,e1N\n" -//" e21, e22, e23, ... ,e2N\n" -//" : \n" -//" eM1, eM2, eM3, ... ,eMN\n" -//"\n" -//"When reading CSV files the sensor expects that each line contains a new vector\n" -//"Any line containing too few elements or any text will be ignored. If there are\n" -//"more than N numbers on a line, the sensor retains only the first N.\n" -//; -// -// -// nupic::SpecBuilder nsb("VectorFileSensor", description, 0 /* flags */); -// -// // ------ OUTPUTS -// nsb.addOutput("dataOut", "real", "This is VectorFileSensor's only output. " -// "It will be set to the next vector after each compute."); -// -// // ------ COMMANDS -// nsb.addCommand("loadFile", -// "loadFile [file_format]\n" -// "Reads vectors from the specified file, replacing any vectors\n" -// "currently in the list. Position is set to zero. 
\n" -// "Available file formats are: \n" -// " 0 # Reads in unlabeled file with first number = element count\n" -// " 1 # Reads in a labeled file with first number = element count (deprecated)\n" -// " 2 # Reads in unlabeled file without element count (default)\n" -// " 3 # Reads in a csv file\n"); -// -// nsb.addCommand("appendFile", -// "appendFile [file_format]\n" -// "Reads vectors from the specified file, appending to current vector list.\n" -// "Position remains unchanged. Available file formats are: \n" -// " 0 # Reads in unlabeled file with first number = element count\n" -// " 1 # Reads in a labeled file with first number = element count (deprecated)\n" -// " 2 # Reads in unlabeled file without element count (default)\n" -// " 3 # Reads in a csv file\n"); -// -// -// nsb.addCommand("dump", "Displays some debugging info."); -// -// nsb.addCommand("saveFile", -// "saveFile filename [format [begin [end]]]\n" -// "Save the currently loaded vectors to a file. Typically used for debugging\n" -// "but may be used to convert between formats.\n"); -// -// // ------ PARAMETERS -// -// nsb.addParameter("vectorCount", -// "uint32", -// "The number of vectors currently loaded in memory.", -// 1, /* elementCount */ -// "get", -// "interval: [0, ...]", -// "0" /* defaultValue */); -// -// nsb.addParameter("position", -// "uint32", -// "Set or get the current position within the list of vectors in memory.", -// 1, /* elementCount */ -// "getset", -// "interval: [0, ...]", -// "0" /* defaultValue */); -// -// nsb.addParameter("repeatCount", -// "uint32", -// "Set or get the current repeatCount. Each vector is repeated\n" -// "repeatCount times before moving to the next one.", -// 1, /* elementCount */ -// "getset", -// "interval: [1, ...]", -// "1" /* defaultValue */); -// -// nsb.addParameter("recentFile", -// "byteptr", -// "Name of the most recently file that is loaded or appended. Mostly to \n" -// "support interactive use.\n", -// 1, /* elementCount */ -// "get"); -// -// -// nsb.addParameter("scalingMode", -// "byteptr", -// "During compute, each vector is adjusted as follows. If X is the data vector,\n" -// "S the scaling vector and O the offset vector, then the node's output\n" -// " Y[i] = S[i]*(X[i] + O[i]).\n" -// "\n" -// "Scaling is applied according to scalingMode as follows:\n" -// "\n" -// " If 'none', the vectors are unchanged, i.e. 
S[i]=1 and O[i]=0.\n" -// " If 'standardForm', S[i] is 1/standard deviation(i) and O[i] = - mean(i)\n" -// " If 'custom', each component is adjusted according to the vectors specified by the\n" -// "setScale and setOffset commands.\n", -// 1, /* elementCount */ -// "all", /* access */ -// "", /* constraints */ -// "none" /* defaultValue */); -// -// nsb.addParameter("scaleVector", -// "real", -// "Set or return the current scale vector S.\n", -// 0, /* elementCount */ -// "all", /* access */ -// "", /* constraints */ -// "" /* defaultValue */); -// -// nsb.addParameter("offsetVector", -// "real", -// "Set or return the current offset vector 0.\n", -// 0, /* elementCount */ -// "all", /* access */ -// "", /* constraints */ -// "" /* defaultValue */); -// -// -// nsb.addParameter("activeOutputCount", -// "uint32", -// "The number of active outputs of the node.", -// 1, /* elementCount */ -// "get", /* access */ -// "interval: [0, ...]"); -// -// -// nsb.addParameter("maxOutputVectorCount", -// "uint32", -// "The number of output vectors that can be generated by this sensor\n" -// "under the current configuration.", -// 1, /* elementCount */ -// "get", -// "interval: [0, ...]"); -// -// -// -// return nsb.getSpec(); -// } - - static Spec* createSpec(); - size_t getNodeOutputElementCount(const std::string& outputName) override; - void getParameterFromBuffer(const std::string& name, Int64 index, IWriteBuffer& value) override; - - void setParameterFromBuffer(const std::string& name, Int64 index, IReadBuffer& value) override; - - size_t getParameterArrayCount(const std::string& name, Int64 index) override; - - virtual void getParameterArray(const std::string& name, Int64 index, Array & array) override; - virtual void setParameterArray(const std::string& name, Int64 index, const Array & array) override; - - //void setParameterString(const std::string& name, Int64 index, const std::string& s); - //std::string getParameterString(const std::string& name, Int64 index); - - void initialize() override; - - VectorFileSensor(const ValueMap & params, Region *region); - - VectorFileSensor(BundleIO& bundle, Region* region); - - VectorFileSensor(capnp::AnyPointer::Reader& proto, Region* region); - - - virtual ~VectorFileSensor(); - - - // --- - /// Serialize state to bundle - // --- - virtual void serialize(BundleIO& bundle) override; - - // --- - /// De-serialize state from bundle - // --- - virtual void deserialize(BundleIO& bundle) override; - - using RegionImpl::write; - virtual void write(capnp::AnyPointer::Builder& anyProto) const override; - - using RegionImpl::read; - virtual void read(capnp::AnyPointer::Reader& anyProto) override; - - void compute() override; - virtual std::string executeCommand(const std::vector& args, Int64 index) override; - - private: - void closeFile(); - void openFile(const std::string& filename); - - private: - NTA_UInt32 repeatCount_; // Repeat count for output vectors - NTA_UInt32 iterations_; // Number of times compute() has been called - NTA_UInt32 curVector_; // The index of the vector currently being output - NTA_UInt32 activeOutputCount_; // The number of elements in each input vector - bool hasCategoryOut_; // determine if a category output is needed - bool hasResetOut_; // determine if a reset output is needed - nupic::VectorFile vectorFile_; // Container class for the vectors - - ArrayRef dataOut_; - ArrayRef categoryOut_; - ArrayRef resetOut_; - std::string filename_; // Name of the output file - - std::string scalingMode_; - std::string recentFile_; // The most 
recently loaded or appended file - - //------------------- Utility routines and debugging support - - // Seek to the n'th vector in the list. n should be between 0 and - // numVectors-1. Logs a warning if n is outside those bounds. - void seek(int n); - - }; // end class VectorFileSensor +class VectorFileSensor : public RegionImpl { +public: + //------ Static methods for plug-in API ------------------------------------ + + // static const NTA_Spec * getSpec(const NTA_Byte * nodeType) + // { + // const char *description = + //"VectorFileSensor is a basic sensor for reading files containing vectors.\n" + //"\n" + //"VectorFileSensor reads in a text file containing lists of numbers\n" + //"and outputs these vectors in sequence. The output is updated\n" + //"each time the sensor's compute() method is called. If\n" + //"repeatCount is > 1, then each vector is repeated that many times\n" + //"before moving to the next one. The sensor loops when the end of\n" + //"the vector list is reached. The default file format\n" + //"is as follows (assuming the sensor is configured with N outputs):\n" + //"\n" + //" e11 e12 e13 ... e1N\n" + //" e21 e22 e23 ... e2N\n" + //" : \n" + //" eM1 eM2 eM3 ... eMN\n" + //"\n" + //"In this format the sensor ignores all whitespace in the file, including + // newlines\n" "If the file contains an incorrect number of floats, the sensor + // has no way\n" "of checking and will silently ignore the extra numbers at + // the end of the file.\n" + //"\n" + //"The sensor can also read in comma-separated (CSV) files following the + // format:\n" + //"\n" + //" e11, e12, e13, ... ,e1N\n" + //" e21, e22, e23, ... ,e2N\n" + //" : \n" + //" eM1, eM2, eM3, ... ,eMN\n" + //"\n" + //"When reading CSV files the sensor expects that each line contains a new + // vector\n" "Any line containing too few elements or any text will be + // ignored. If there are\n" "more than N numbers on a line, the sensor retains + // only the first N.\n" + //; + // + // + // nupic::SpecBuilder nsb("VectorFileSensor", description, 0 /* flags + // */); + // + // // ------ OUTPUTS + // nsb.addOutput("dataOut", "real", "This is VectorFileSensor's only + // output. " + // "It will be set to the next vector after each + // compute."); + // + // // ------ COMMANDS + // nsb.addCommand("loadFile", + // "loadFile [file_format]\n" + // "Reads vectors from the specified file, replacing any + // vectors\n" "currently in the list. Position is set to + // zero. \n" "Available file formats are: \n" " 0 # + // Reads in unlabeled file with first number = element + // count\n" " 1 # Reads in a labeled file + // with first number = element count (deprecated)\n" " 2 + // # Reads in unlabeled file without element count + // (default)\n" " 3 # Reads in a csv + // file\n"); + // + // nsb.addCommand("appendFile", + // "appendFile [file_format]\n" + // "Reads vectors from the specified file, appending to current + // vector list.\n" "Position remains unchanged. Available file + // formats are: \n" " 0 # Reads in unlabeled file + // with first number = element count\n" " 1 # Reads + // in a labeled file with first number = element count + // (deprecated)\n" " 2 # Reads in unlabeled file + // without element count (default)\n" " 3 # Reads in + // a csv file\n"); + // + // + // nsb.addCommand("dump", "Displays some debugging info."); + // + // nsb.addCommand("saveFile", + // "saveFile filename [format [begin [end]]]\n" + // "Save the currently loaded vectors to a file. 
Typically used for + // debugging\n" "but may be used to convert between formats.\n"); + // + // // ------ PARAMETERS + // + // nsb.addParameter("vectorCount", + // "uint32", + // "The number of vectors currently loaded in memory.", + // 1, /* elementCount */ + // "get", + // "interval: [0, ...]", + // "0" /* defaultValue */); + // + // nsb.addParameter("position", + // "uint32", + // "Set or get the current position within the list of + // vectors in memory.", 1, /* elementCount */ "getset", + // "interval: [0, ...]", + // "0" /* defaultValue */); + // + // nsb.addParameter("repeatCount", + // "uint32", + // "Set or get the current repeatCount. Each vector is + // repeated\n" "repeatCount times before moving to the + // next one.", 1, /* elementCount */ "getset", + // "interval: [1, ...]", + // "1" /* defaultValue */); + // + // nsb.addParameter("recentFile", + // "byteptr", + // "Name of the most recently file that is loaded or appended. Mostly + // to \n" + // "support interactive use.\n", + // 1, /* elementCount */ + // "get"); + // + // + // nsb.addParameter("scalingMode", + // "byteptr", + // "During compute, each vector is adjusted as follows. If X is the + // data vector,\n" "S the scaling vector and O the offset vector, then + // the node's output\n" " Y[i] = S[i]*(X[i] + O[i]).\n" + // "\n" + // "Scaling is applied according to scalingMode as follows:\n" + // "\n" + // " If 'none', the vectors are unchanged, i.e. S[i]=1 and + // O[i]=0.\n" " If 'standardForm', S[i] is 1/standard deviation(i) + // and O[i] = - mean(i)\n" " If 'custom', each component is adjusted + // according to the vectors specified by the\n" + // "setScale and setOffset commands.\n", + // 1, /* elementCount */ + // "all", /* access */ + // "", /* constraints */ + // "none" /* defaultValue */); + // + // nsb.addParameter("scaleVector", + // "real", + // "Set or return the current scale vector S.\n", + // 0, /* elementCount */ + // "all", /* access */ + // "", /* constraints */ + // "" /* defaultValue */); + // + // nsb.addParameter("offsetVector", + // "real", + // "Set or return the current offset vector 0.\n", + // 0, /* elementCount */ + // "all", /* access */ + // "", /* constraints */ + // "" /* defaultValue */); + // + // + // nsb.addParameter("activeOutputCount", + // "uint32", + // "The number of active outputs of the node.", + // 1, /* elementCount */ + // "get", /* access */ + // "interval: [0, ...]"); + // + // + // nsb.addParameter("maxOutputVectorCount", + // "uint32", + // "The number of output vectors that can be generated by this + // sensor\n" + // "under the current configuration.", + // 1, /* elementCount */ + // "get", + // "interval: [0, ...]"); + // + // + // + // return nsb.getSpec(); + // } + + static Spec *createSpec(); + size_t getNodeOutputElementCount(const std::string &outputName) override; + void getParameterFromBuffer(const std::string &name, Int64 index, + IWriteBuffer &value) override; + + void setParameterFromBuffer(const std::string &name, Int64 index, + IReadBuffer &value) override; + + size_t getParameterArrayCount(const std::string &name, Int64 index) override; + + virtual void getParameterArray(const std::string &name, Int64 index, + Array &array) override; + virtual void setParameterArray(const std::string &name, Int64 index, + const Array &array) override; + + // void setParameterString(const std::string& name, Int64 index, const + // std::string& s); std::string getParameterString(const std::string& name, + // Int64 index); + + void initialize() override; + + 
VectorFileSensor(const ValueMap ¶ms, Region *region); + + VectorFileSensor(BundleIO &bundle, Region *region); + + VectorFileSensor(capnp::AnyPointer::Reader &proto, Region *region); + + virtual ~VectorFileSensor(); + + // --- + /// Serialize state to bundle + // --- + virtual void serialize(BundleIO &bundle) override; + + // --- + /// De-serialize state from bundle + // --- + virtual void deserialize(BundleIO &bundle) override; + + using RegionImpl::write; + virtual void write(capnp::AnyPointer::Builder &anyProto) const override; + + using RegionImpl::read; + virtual void read(capnp::AnyPointer::Reader &anyProto) override; + + void compute() override; + virtual std::string executeCommand(const std::vector &args, + Int64 index) override; + +private: + void closeFile(); + void openFile(const std::string &filename); + +private: + NTA_UInt32 repeatCount_; // Repeat count for output vectors + NTA_UInt32 iterations_; // Number of times compute() has been called + NTA_UInt32 curVector_; // The index of the vector currently being output + NTA_UInt32 activeOutputCount_; // The number of elements in each input vector + bool hasCategoryOut_; // determine if a category output is needed + bool hasResetOut_; // determine if a reset output is needed + nupic::VectorFile vectorFile_; // Container class for the vectors + + ArrayRef dataOut_; + ArrayRef categoryOut_; + ArrayRef resetOut_; + std::string filename_; // Name of the output file + + std::string scalingMode_; + std::string recentFile_; // The most recently loaded or appended file + + //------------------- Utility routines and debugging support + + // Seek to the n'th vector in the list. n should be between 0 and + // numVectors-1. Logs a warning if n is outside those bounds. + void seek(int n); + +}; // end class VectorFileSensor - //---------------------------------------------------------------------- +//---------------------------------------------------------------------- } // end namespace nupic diff --git a/src/nupic/types/BasicType.cpp b/src/nupic/types/BasicType.cpp index 3c0c286a2e..d678dcbca9 100644 --- a/src/nupic/types/BasicType.cpp +++ b/src/nupic/types/BasicType.cpp @@ -20,183 +20,137 @@ * --------------------------------------------------------------------- */ - #include #include using namespace nupic; -bool BasicType::isValid(NTA_BasicType t) -{ +bool BasicType::isValid(NTA_BasicType t) { return (t >= 0) && (t < NTA_BasicType_Last); } -const char * BasicType::getName(NTA_BasicType t) -{ - static const char *names[] = - { - "Byte", - "Int16", - "UInt16", - "Int32", - "UInt32", - "Int64", - "UInt64", - "Real32", - "Real64", - "Handle", - "Bool", - }; - +const char *BasicType::getName(NTA_BasicType t) { + static const char *names[] = { + "Byte", "Int16", "UInt16", "Int32", "UInt32", "Int64", + "UInt64", "Real32", "Real64", "Handle", "Bool", + }; + if (!isValid(t)) - throw Exception(__FILE__, __LINE__, "BasicType::getName -- Basic type is not valid"); + throw Exception(__FILE__, __LINE__, + "BasicType::getName -- Basic type is not valid"); return names[t]; } - // gcc 4.2 requires (incorrectly) these to be defined inside a namespace -namespace nupic -{ - // getName - template <> const char* BasicType::getName() - { - return getName(NTA_BasicType_Byte); - } - - template <> const char* BasicType::getName() - { - return getName(NTA_BasicType_Int16); - } - - template <> const char* BasicType::getName() - { - return getName(NTA_BasicType_UInt16); - } - - template <> const char* BasicType::getName() - { - return getName(NTA_BasicType_Int32); - 
} - - template <> const char* BasicType::getName() - { - return getName(NTA_BasicType_UInt32); - } - - template <> const char* BasicType::getName() - { - return getName(NTA_BasicType_Int64); - } - - template <> const char* BasicType::getName() - { - return getName(NTA_BasicType_UInt64); - } - - template <> const char* BasicType::getName() - { - return getName(NTA_BasicType_Real32); - } - - template <> const char* BasicType::getName() - { - return getName(NTA_BasicType_Real64); - } - - template <> const char* BasicType::getName() - { - return getName(NTA_BasicType_Handle); - } - - template <> const char* BasicType::getName() - { - return getName(NTA_BasicType_Bool); - } - - - // getType - template <> NTA_BasicType BasicType::getType() - { - return NTA_BasicType_Byte; - } +namespace nupic { +// getName +template <> const char *BasicType::getName() { + return getName(NTA_BasicType_Byte); +} - template <> NTA_BasicType BasicType::getType() - { - return NTA_BasicType_Int16; - } +template <> const char *BasicType::getName() { + return getName(NTA_BasicType_Int16); +} - template <> NTA_BasicType BasicType::getType() - { - return NTA_BasicType_UInt16; - } +template <> const char *BasicType::getName() { + return getName(NTA_BasicType_UInt16); +} - template <> NTA_BasicType BasicType::getType() - { - return NTA_BasicType_Int32; - } +template <> const char *BasicType::getName() { + return getName(NTA_BasicType_Int32); +} - template <> NTA_BasicType BasicType::getType() - { - return NTA_BasicType_UInt32; - } +template <> const char *BasicType::getName() { + return getName(NTA_BasicType_UInt32); +} - template <> NTA_BasicType BasicType::getType() - { - return NTA_BasicType_Int64; - } +template <> const char *BasicType::getName() { + return getName(NTA_BasicType_Int64); +} - template <> NTA_BasicType BasicType::getType() - { - return NTA_BasicType_UInt64; - } +template <> const char *BasicType::getName() { + return getName(NTA_BasicType_UInt64); +} - template <> NTA_BasicType BasicType::getType() - { - return NTA_BasicType_Real32; - } +template <> const char *BasicType::getName() { + return getName(NTA_BasicType_Real32); +} - template <> NTA_BasicType BasicType::getType() - { - return NTA_BasicType_Real64; - } +template <> const char *BasicType::getName() { + return getName(NTA_BasicType_Real64); +} - template <> NTA_BasicType BasicType::getType() - { - return NTA_BasicType_Handle; - } +template <> const char *BasicType::getName() { + return getName(NTA_BasicType_Handle); +} - template <> NTA_BasicType BasicType::getType() - { - return NTA_BasicType_Bool; - } -} +template <> const char *BasicType::getName() { + return getName(NTA_BasicType_Bool); +} + +// getType +template <> NTA_BasicType BasicType::getType() { + return NTA_BasicType_Byte; +} + +template <> NTA_BasicType BasicType::getType() { + return NTA_BasicType_Int16; +} + +template <> NTA_BasicType BasicType::getType() { + return NTA_BasicType_UInt16; +} + +template <> NTA_BasicType BasicType::getType() { + return NTA_BasicType_Int32; +} + +template <> NTA_BasicType BasicType::getType() { + return NTA_BasicType_UInt32; +} + +template <> NTA_BasicType BasicType::getType() { + return NTA_BasicType_Int64; +} + +template <> NTA_BasicType BasicType::getType() { + return NTA_BasicType_UInt64; +} + +template <> NTA_BasicType BasicType::getType() { + return NTA_BasicType_Real32; +} + +template <> NTA_BasicType BasicType::getType() { + return NTA_BasicType_Real64; +} + +template <> NTA_BasicType BasicType::getType() { + return NTA_BasicType_Handle; +} + 
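Aside (not part of the patch): the specializations being reformatted here map concrete C++ types to NTA_BasicType tags at compile time. A standalone sketch of the same trait-style pattern, using a toy Tag enum and toy function names rather than the real NTA_BasicType API:

#include <cstddef>
#include <cstdint>
#include <iostream>

enum class Tag { Int32, Real32, Bool }; // toy stand-in for NTA_BasicType

// Primary template is declared but never defined: only the explicit
// specializations below are usable, so an unsupported T fails at link time.
template <typename T> Tag getTag();
template <> Tag getTag<std::int32_t>() { return Tag::Int32; }
template <> Tag getTag<float>() { return Tag::Real32; }
template <> Tag getTag<bool>() { return Tag::Bool; }

// Templated code can now ask "which tag corresponds to T?" without branching.
template <typename T> void describeBuffer(std::size_t count) {
  std::cout << "buffer of " << count << " elements, tag "
            << static_cast<int>(getTag<T>()) << ", element size " << sizeof(T)
            << " bytes\n";
}

int main() {
  describeBuffer<float>(128);       // tag 1, 4-byte elements on most targets
  describeBuffer<std::int32_t>(64); // tag 0
  return 0;
}
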
+template <> NTA_BasicType BasicType::getType() { + return NTA_BasicType_Bool; +} +} // namespace nupic // Return the size in bits of a basic type -size_t BasicType::getSize(NTA_BasicType t) -{ - static size_t basicTypeSizes[] = - { - sizeof(NTA_Byte), - sizeof(NTA_Int16), - sizeof(NTA_UInt16), - sizeof(NTA_Int32), - sizeof(NTA_UInt32), - sizeof(NTA_Int64), - sizeof(NTA_UInt64), - sizeof(NTA_Real32), - sizeof(NTA_Real64), - sizeof(NTA_Handle), - sizeof(bool), - }; - +size_t BasicType::getSize(NTA_BasicType t) { + static size_t basicTypeSizes[] = { + sizeof(NTA_Byte), sizeof(NTA_Int16), sizeof(NTA_UInt16), + sizeof(NTA_Int32), sizeof(NTA_UInt32), sizeof(NTA_Int64), + sizeof(NTA_UInt64), sizeof(NTA_Real32), sizeof(NTA_Real64), + sizeof(NTA_Handle), sizeof(bool), + }; + if (!isValid(t)) - throw Exception(__FILE__, __LINE__, "BasicType::getSize -- basic type is not valid"); + throw Exception(__FILE__, __LINE__, + "BasicType::getSize -- basic type is not valid"); return basicTypeSizes[t]; } -NTA_BasicType BasicType::parse(const std::string & s) -{ +NTA_BasicType BasicType::parse(const std::string &s) { if (s == std::string("Byte") || s == std::string("str")) return NTA_BasicType_Byte; else if (s == std::string("Int16")) @@ -222,9 +176,6 @@ NTA_BasicType BasicType::parse(const std::string & s) else if (s == std::string("Bool")) return NTA_BasicType_Bool; else - throw Exception(__FILE__, __LINE__, std::string("Invalid basic type name: ") + s); + throw Exception(__FILE__, __LINE__, + std::string("Invalid basic type name: ") + s); } - - - - diff --git a/src/nupic/types/BasicType.hpp b/src/nupic/types/BasicType.hpp index 0b713d7752..22d1efa754 100644 --- a/src/nupic/types/BasicType.hpp +++ b/src/nupic/types/BasicType.hpp @@ -26,55 +26,52 @@ #include #include -namespace nupic -{ - // The BasicType class provides operations on NTA_BasicType as static methods. +namespace nupic { +// The BasicType class provides operations on NTA_BasicType as static methods. +// +// The supported operations are: +// - isValid() +// - getName() +// - getSize() and parse(). +// +class BasicType { +public: + // Check if the provided basic type os in the proper range. // - // The supported operations are: - // - isValid() - // - getName() - // - getSize() and parse(). - // - class BasicType - { - public: - // Check if the provided basic type os in the proper range. - // - // In C++ enums are just glorified integers and you can cast - // an int to any enum even if the int value is outside of the range of - // definedenum values. The compiler will say nothing. The NTA_BasicType - // enum has a special value called NTA_BasicType_Last that marks the end of - // of the valid rnge of values and isValid() returns true if if the input - // falls in the range [0, NTA_BasicType_Last) and false otherwise. Note, - // that NTA_BasicType_Last itself is an invalid value eventhough it is - // defined in the enum. - static bool isValid(NTA_BasicType t); + // In C++ enums are just glorified integers and you can cast + // an int to any enum even if the int value is outside of the range of + // definedenum values. The compiler will say nothing. The NTA_BasicType + // enum has a special value called NTA_BasicType_Last that marks the end of + // of the valid rnge of values and isValid() returns true if if the input + // falls in the range [0, NTA_BasicType_Last) and false otherwise. Note, + // that NTA_BasicType_Last itself is an invalid value eventhough it is + // defined in the enum. 
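Aside (toy example, not the real NTA_BasicType): the comment above explains why isValid() exists at all: an arbitrary int can be cast to the enum even when it names no enumerator, so the trailing sentinel enumerator bounds the valid range before the value is used to index the names/sizes tables. A standalone sketch of that sentinel check:

#include <iostream>

// Toy enum standing in for NTA_BasicType; the trailing *_Last enumerator
// marks the end of the valid range, exactly as described above.
enum ToyBasicType : int { Toy_Byte = 0, Toy_Int32, Toy_Real32, Toy_Last };

static bool isValid(ToyBasicType t) { return t >= 0 && t < Toy_Last; }

int main() {
  ToyBasicType ok = static_cast<ToyBasicType>(1);   // Toy_Int32
  ToyBasicType bad = static_cast<ToyBasicType>(42); // compiles, but invalid
  std::cout << std::boolalpha << isValid(ok) << " " << isValid(bad) << "\n";
  // prints: true false
  return 0;
}
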
+ static bool isValid(NTA_BasicType t); - // Return the name of a basic type (without the "NTA_BasicType_") prefix. - // For example the name of NTA_BasicType_Int32 is "int32". - static const char * getName(NTA_BasicType t); + // Return the name of a basic type (without the "NTA_BasicType_") prefix. + // For example the name of NTA_BasicType_Int32 is "int32". + static const char *getName(NTA_BasicType t); - // Like getName above, but can be used in a templated method - template static const char* getName(); + // Like getName above, but can be used in a templated method + template static const char *getName(); - // To convert -> NTA_BasicType in a templated method - template static NTA_BasicType getType(); + // To convert -> NTA_BasicType in a templated method + template static NTA_BasicType getType(); - // Return the size in bits of a basic type - static size_t getSize(NTA_BasicType t); + // Return the size in bits of a basic type + static size_t getSize(NTA_BasicType t); - // Parse a string and return the corresponding basic type - // - // The string should contain the name of the basic type - // without the "NTA_BasicType_" prefix. For example the name - // of NTA_BasicType_Int32 is "Int32" - static NTA_BasicType parse(const std::string & s); + // Parse a string and return the corresponding basic type + // + // The string should contain the name of the basic type + // without the "NTA_BasicType_" prefix. For example the name + // of NTA_BasicType_Int32 is "Int32" + static NTA_BasicType parse(const std::string &s); - private: - BasicType(); - BasicType(const BasicType &); - }; -} +private: + BasicType(); + BasicType(const BasicType &); +}; +} // namespace nupic #endif - diff --git a/src/nupic/types/Exception.hpp b/src/nupic/types/Exception.hpp index 550dbe5db2..d73cc864cb 100644 --- a/src/nupic/types/Exception.hpp +++ b/src/nupic/types/Exception.hpp @@ -34,162 +34,133 @@ //---------------------------------------------------------------------- -namespace nupic -{ +namespace nupic { +/** + * @b Responsibility + * The Exception class is the standard Numenta exception class. + * It is responsible for storing rich error information: + * the filename and line number where the exceptional situation + * occured and a message that describes the exception. + * + * @b Rationale: + * It is important to store the location (filename, line number) where + * an exception originated, but no standard excepton class does it. + * The stack trace is even better and brings C++ programming to the + * ease of use of languages like Java and C#. + * + * @b Usage: + * This class may be used directly by instatiating an instance + * and throwing it, but usually you will use the NTA_THROW macro + * that makes it much simpler by automatically retreiving the __FILE__ + * and __LINE__ for you and also using a wrapping LogItem that allows + * you to construct the exception message conveniently using the << + * streame operator (see nupic/utils/Log.hpp for further details). + * + * @b Notes: + * 1. Exception is a subclass of the standard std::runtime_error. + * This is useful if your code needs to interoperate with other + * code that is not aware of the Exception class, but understands + * std::runtime_error. The what() method will return the exception message + * and the location information will not be avaialable to such code. + * + * 2. Source file and line number information is useful of course + * only if you have access to the source code. 
It is not recommended + * to display this information to users most of the time. + */ +class Exception : public std::runtime_error { +public: /** - * @b Responsibility - * The Exception class is the standard Numenta exception class. - * It is responsible for storing rich error information: - * the filename and line number where the exceptional situation - * occured and a message that describes the exception. - * - * @b Rationale: - * It is important to store the location (filename, line number) where - * an exception originated, but no standard excepton class does it. - * The stack trace is even better and brings C++ programming to the - * ease of use of languages like Java and C#. + * Constructor * - * @b Usage: - * This class may be used directly by instatiating an instance - * and throwing it, but usually you will use the NTA_THROW macro - * that makes it much simpler by automatically retreiving the __FILE__ - * and __LINE__ for you and also using a wrapping LogItem that allows - * you to construct the exception message conveniently using the << - * streame operator (see nupic/utils/Log.hpp for further details). + * Take the filename, line number and message + * and store it in private members * - * @b Notes: - * 1. Exception is a subclass of the standard std::runtime_error. - * This is useful if your code needs to interoperate with other - * code that is not aware of the Exception class, but understands - * std::runtime_error. The what() method will return the exception message - * and the location information will not be avaialable to such code. - * - * 2. Source file and line number information is useful of course - * only if you have access to the source code. It is not recommended - * to display this information to users most of the time. + * @param filename [const std::string &] the name of the source file + * where the exception originated. + * @param lineno [UInt32] the line number in the source file where + * the exception originated. + * + * @param message [const std::string &] the description of exception */ - class Exception : public std::runtime_error - { - public: - /** - * Constructor - * - * Take the filename, line number and message - * and store it in private members - * - * @param filename [const std::string &] the name of the source file - * where the exception originated. - * @param lineno [UInt32] the line number in the source file where - * the exception originated. - * - * @param message [const std::string &] the description of exception - */ - Exception(std::string filename, - UInt32 lineno, - std::string message, - std::string stacktrace = "") : - std::runtime_error(""), - filename_(std::move(filename)), - lineno_(lineno), - message_(std::move(message)), - stackTrace_(std::move(stacktrace)) - { - } - - /** - * Destructor - * - * Doesn't do anything, but must be present - * because the base class std::runtime_error - * defines a pure virtual destructor and the - * default destructor provided by the compiler - * doesn't have a empty exception specification - * - */ - virtual ~Exception() throw() - { - } - - /** - * Get the exception message via what() - * - * Overload the what() method of std::runtime_error - * and returns the exception description message. - * The emptry exception specification is required because - * it is part of the signature of std::runtime_error::what(). - * - * @retval [const Byte *] the exception message - */ - virtual const char * what() const throw() - { - try - { - return getMessage(); - } - catch (...) 
- { - return "Exception caught in non-throwing Exception::what()"; - } - } - - /** - * Get the source filename - * - * Returns the full path to the source file, from which - * the exception was thrown. - * - * @retval [const Byte *] the source filename - */ - const char * getFilename() const - { - return filename_.c_str(); - } - - /** - * Get the line number in the source file - * - * Returns the (0-based) line number in the source file, - * from which the exception was thrown. - * - * @retval [UInt32] the line number in the source file - */ - UInt32 getLineNumber() const - { - return lineno_; - } + Exception(std::string filename, UInt32 lineno, std::string message, + std::string stacktrace = "") + : std::runtime_error(""), filename_(std::move(filename)), lineno_(lineno), + message_(std::move(message)), stackTrace_(std::move(stacktrace)) {} - /** - * Get the error message - * - * @retval [const char *] the error message - */ - virtual const char * getMessage() const - { - return message_.c_str(); - } + /** + * Destructor + * + * Doesn't do anything, but must be present + * because the base class std::runtime_error + * defines a pure virtual destructor and the + * default destructor provided by the compiler + * doesn't have a empty exception specification + * + */ + virtual ~Exception() throw() {} - /** - * Get the stack trace - * - * Returns the stack trace from the point the exception - * was thrown. - * - * @retval [const Byte *] the stack trace - */ - virtual const char * getStackTrace() const - { - return stackTrace_.c_str(); + /** + * Get the exception message via what() + * + * Overload the what() method of std::runtime_error + * and returns the exception description message. + * The emptry exception specification is required because + * it is part of the signature of std::runtime_error::what(). + * + * @retval [const Byte *] the exception message + */ + virtual const char *what() const throw() { + try { + return getMessage(); + } catch (...) { + return "Exception caught in non-throwing Exception::what()"; } + } - protected: - std::string filename_; - UInt32 lineno_; - std::string message_; - std::string stackTrace_; + /** + * Get the source filename + * + * Returns the full path to the source file, from which + * the exception was thrown. + * + * @retval [const Byte *] the source filename + */ + const char *getFilename() const { return filename_.c_str(); } + + /** + * Get the line number in the source file + * + * Returns the (0-based) line number in the source file, + * from which the exception was thrown. + * + * @retval [UInt32] the line number in the source file + */ + UInt32 getLineNumber() const { return lineno_; } + + /** + * Get the error message + * + * @retval [const char *] the error message + */ + virtual const char *getMessage() const { return message_.c_str(); } + + /** + * Get the stack trace + * + * Returns the stack trace from the point the exception + * was thrown. 
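Because the diff format obscures how the pieces of the Exception interface fit together, here is a minimal throw/catch sketch. It is illustrative only; it constructs the exception directly with the constructor shown above rather than through the NTA_THROW convenience macro mentioned in the class comment, and the hypothetical requirePositive() helper exists only to give the throw a realistic home.

    #include <iostream>

    #include "nupic/types/Exception.hpp"

    using nupic::Exception;

    static void requirePositive(int value) {
      // Capture the throw site explicitly via __FILE__ / __LINE__.
      if (value <= 0)
        throw Exception(__FILE__, __LINE__,
                        "requirePositive - value must be > 0");
    }

    int main() {
      try {
        requirePositive(-3);
      } catch (const Exception &e) {
        // what() forwards to getMessage(); file and line were recorded at the
        // throw site and are retrievable separately.
        std::cerr << e.what() << " (" << e.getFilename() << ":"
                  << e.getLineNumber() << ")" << std::endl;
      }
      return 0;
    }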
+ * + * @retval [const Byte *] the stack trace + */ + virtual const char *getStackTrace() const { return stackTrace_.c_str(); } - - }; // end class Exception +protected: + std::string filename_; + UInt32 lineno_; + std::string message_; + std::string stackTrace_; + +}; // end class Exception } // end namespace nupic #endif // NTA_EXCEPTION_HPP - diff --git a/src/nupic/types/Fraction.cpp b/src/nupic/types/Fraction.cpp index 20ccfa6caa..435686bb18 100644 --- a/src/nupic/types/Fraction.cpp +++ b/src/nupic/types/Fraction.cpp @@ -20,367 +20,277 @@ * --------------------------------------------------------------------- */ -#include -#include -#include #include #include //abs! +#include +#include +#include -#include #include +#include -namespace nupic -{ - Fraction::Fraction(int _numerator, int _denominator) : - numerator_(_numerator), - denominator_(_denominator) - { - if (_denominator == 0) - { - throw Exception(__FILE__, - __LINE__, - "Fraction - attempt to create with invalid zero valued " - "denominator"); - } - - //can't use abs() because abs(std::numeric_limits::min()) == -2^31 - if ((numerator_ > overflowCutoff) || (numerator_ < -overflowCutoff) || - (denominator_ > overflowCutoff) || (denominator_ < -overflowCutoff)) - { - throw Exception(__FILE__, - __LINE__, - "Fraction - integer overflow."); - } +namespace nupic { +Fraction::Fraction(int _numerator, int _denominator) + : numerator_(_numerator), denominator_(_denominator) { + if (_denominator == 0) { + throw Exception(__FILE__, __LINE__, + "Fraction - attempt to create with invalid zero valued " + "denominator"); } - Fraction::Fraction(int _numerator) : - numerator_(_numerator), - denominator_(1) - { - //can't use abs() because abs(std::numeric_limits::min()) == -2^31 - if ((numerator_ > overflowCutoff) || (numerator_ < -overflowCutoff)) - { - throw Exception(__FILE__, - __LINE__, - "Fraction - integer overflow."); - } + // can't use abs() because abs(std::numeric_limits::min()) == -2^31 + if ((numerator_ > overflowCutoff) || (numerator_ < -overflowCutoff) || + (denominator_ > overflowCutoff) || (denominator_ < -overflowCutoff)) { + throw Exception(__FILE__, __LINE__, "Fraction - integer overflow."); } +} - Fraction::Fraction() : - numerator_(0), - denominator_(1) - { +Fraction::Fraction(int _numerator) : numerator_(_numerator), denominator_(1) { + // can't use abs() because abs(std::numeric_limits::min()) == -2^31 + if ((numerator_ > overflowCutoff) || (numerator_ < -overflowCutoff)) { + throw Exception(__FILE__, __LINE__, "Fraction - integer overflow."); } +} - bool Fraction::isNaturalNumber() - { - return(((numerator_ % denominator_) == 0) && ((*this > 0) || (numerator_ == 0))); - } - - int Fraction::getNumerator() - { - return numerator_; - } +Fraction::Fraction() : numerator_(0), denominator_(1) {} - int Fraction::getDenominator() - { - return denominator_; - } +bool Fraction::isNaturalNumber() { + return (((numerator_ % denominator_) == 0) && + ((*this > 0) || (numerator_ == 0))); +} - void Fraction::setNumerator(int _numerator) - { - numerator_ = _numerator; - } - - void Fraction::setDenominator(int _denominator) - { - if(_denominator == 0) - { - throw Exception(__FILE__, - __LINE__, - "Fraction - attempt to set an invalid zero valued " - "denominator"); - } - denominator_ = _denominator; +int Fraction::getNumerator() { return numerator_; } + +int Fraction::getDenominator() { return denominator_; } + +void Fraction::setNumerator(int _numerator) { numerator_ = _numerator; } + +void Fraction::setDenominator(int _denominator) { + 
if (_denominator == 0) { + throw Exception(__FILE__, __LINE__, + "Fraction - attempt to set an invalid zero valued " + "denominator"); } - - void Fraction::setFraction(int _numerator, int _denominator) - { - numerator_ = _numerator; - denominator_ = _denominator; - if(_denominator == 0) - { - throw Exception(__FILE__, - __LINE__, - "Fraction - attempt to set an invalid zero valued " - "denominator"); - } + denominator_ = _denominator; +} + +void Fraction::setFraction(int _numerator, int _denominator) { + numerator_ = _numerator; + denominator_ = _denominator; + if (_denominator == 0) { + throw Exception(__FILE__, __LINE__, + "Fraction - attempt to set an invalid zero valued " + "denominator"); } +} + +unsigned int Fraction::computeGCD(int a, int b) { + unsigned int x, y, r; - unsigned int Fraction::computeGCD(int a, int b) - { - unsigned int x, y, r; - - if(a == 0) - { - if(b > 0) - { - return b; - } - else - { - return 1; - } + if (a == 0) { + if (b > 0) { + return b; + } else { + return 1; } - else if(b == 0) - { - if (a > 0) - { - return a; - } - else - { - return 1; - } + } else if (b == 0) { + if (a > 0) { + return a; + } else { + return 1; } + } - //Euclid's algorithm - a > b ? (x = abs(a), y = abs(b)) : - (x = abs(b), y = abs(a)); + // Euclid's algorithm + a > b ? (x = abs(a), y = abs(b)) : (x = abs(b), y = abs(a)); - r = x % y; + r = x % y; - while(r!=0) - { - x = y; - y = r; + while (r != 0) { + x = y; + y = r; - r = x % y; - } - - return y; + r = x % y; } - unsigned int Fraction::computeLCM(int a,int b) - { - int lcm = a*b/((int)computeGCD(a,b)); - if(lcm < 0) - { - lcm = 0; - } - return lcm; + return y; +} + +unsigned int Fraction::computeLCM(int a, int b) { + int lcm = a * b / ((int)computeGCD(a, b)); + if (lcm < 0) { + lcm = 0; } + return lcm; +} - void Fraction::reduce() - { - if(numerator_ == 0) - { - denominator_ = 1; - } - else - { - unsigned int m = computeGCD(numerator_, denominator_); +void Fraction::reduce() { + if (numerator_ == 0) { + denominator_ = 1; + } else { + unsigned int m = computeGCD(numerator_, denominator_); - numerator_ /= (int)m; - denominator_ /= (int)m; - } - if(denominator_ < 0) - { - numerator_ *= -1; - denominator_ *= -1; - } + numerator_ /= (int)m; + denominator_ /= (int)m; } - - Fraction Fraction::operator*(const Fraction& rhs) - { - return Fraction(numerator_ * rhs.numerator_, - denominator_ * rhs.denominator_); + if (denominator_ < 0) { + numerator_ *= -1; + denominator_ *= -1; } +} - Fraction Fraction::operator*(const int rhs) - { - return Fraction(numerator_ * rhs, denominator_); - } +Fraction Fraction::operator*(const Fraction &rhs) { + return Fraction(numerator_ * rhs.numerator_, denominator_ * rhs.denominator_); +} - Fraction operator/(const Fraction& lhs, const Fraction& rhs) - { - if(rhs.numerator_ == 0) - { - throw Exception(__FILE__, - __LINE__, - "Fraction - division by zero error"); - } +Fraction Fraction::operator*(const int rhs) { + return Fraction(numerator_ * rhs, denominator_); +} - return Fraction(lhs.numerator_ * rhs.denominator_, - lhs.denominator_ * rhs.numerator_); +Fraction operator/(const Fraction &lhs, const Fraction &rhs) { + if (rhs.numerator_ == 0) { + throw Exception(__FILE__, __LINE__, "Fraction - division by zero error"); } - Fraction operator-(const Fraction& lhs, const Fraction& rhs) - { - int num, lcm; + return Fraction(lhs.numerator_ * rhs.denominator_, + lhs.denominator_ * rhs.numerator_); +} - lcm = Fraction::computeLCM(lhs.denominator_, rhs.denominator_); - num = lhs.numerator_*(lcm/lhs.denominator_) - - 
rhs.numerator_*(lcm/rhs.denominator_); +Fraction operator-(const Fraction &lhs, const Fraction &rhs) { + int num, lcm; - return Fraction(num,lcm); - } + lcm = Fraction::computeLCM(lhs.denominator_, rhs.denominator_); + num = lhs.numerator_ * (lcm / lhs.denominator_) - + rhs.numerator_ * (lcm / rhs.denominator_); - Fraction Fraction::operator+(const Fraction& rhs) - { - int num, den; + return Fraction(num, lcm); +} - den = computeLCM(denominator_, rhs.denominator_); - num = den/denominator_*numerator_ + den/rhs.denominator_*rhs.numerator_; +Fraction Fraction::operator+(const Fraction &rhs) { + int num, den; - return Fraction(num,den); - } + den = computeLCM(denominator_, rhs.denominator_); + num = + den / denominator_ * numerator_ + den / rhs.denominator_ * rhs.numerator_; - Fraction Fraction::operator%(const Fraction& rhs) - { - // a/b % c/d = (ad % bc) / bc. gives output with same sign as a/b - if(rhs.numerator_ == 0) - { - throw Exception(__FILE__, - __LINE__, - "Fraction - division by zero error"); - return Fraction(0,1); - } + return Fraction(num, den); +} - return Fraction((rhs.denominator_ * numerator_) % - (denominator_ * rhs.numerator_), - denominator_ * rhs.denominator_); +Fraction Fraction::operator%(const Fraction &rhs) { + // a/b % c/d = (ad % bc) / bc. gives output with same sign as a/b + if (rhs.numerator_ == 0) { + throw Exception(__FILE__, __LINE__, "Fraction - division by zero error"); + return Fraction(0, 1); } - bool Fraction::operator<(const Fraction& rhs) - { - // a/b < c/d if (ad)/(bd) < (bc)/(bd), i.e. if a*d < b*c - bool negLHS = (denominator_ < 0); - bool negRHS = (rhs.denominator_ < 0); - if((negLHS || negRHS) && !(negLHS && negRHS)) - { - return((numerator_ * rhs.denominator_) > (denominator_ * rhs.numerator_)); - } - else - { - return((numerator_ * rhs.denominator_) < (denominator_ * rhs.numerator_)); - } - } + return Fraction((rhs.denominator_ * numerator_) % + (denominator_ * rhs.numerator_), + denominator_ * rhs.denominator_); +} - bool Fraction::operator>(const Fraction& rhs) - { - // a/b > c/d if (ad)/(bd) > (bc)/(bd), i.e. if a*d > b*c - bool negLHS = (denominator_ < 0); - bool negRHS = (rhs.denominator_ < 0); - if((negLHS || negRHS) && !(negLHS && negRHS)) - { - return((numerator_ * rhs.denominator_) < (denominator_ * rhs.numerator_)); - } - else - { - return((numerator_ * rhs.denominator_) > (denominator_ * rhs.numerator_)); - } +bool Fraction::operator<(const Fraction &rhs) { + // a/b < c/d if (ad)/(bd) < (bc)/(bd), i.e. if a*d < b*c + bool negLHS = (denominator_ < 0); + bool negRHS = (rhs.denominator_ < 0); + if ((negLHS || negRHS) && !(negLHS && negRHS)) { + return ((numerator_ * rhs.denominator_) > (denominator_ * rhs.numerator_)); + } else { + return ((numerator_ * rhs.denominator_) < (denominator_ * rhs.numerator_)); } +} - bool Fraction::operator<=(const Fraction& rhs) - { - return ((Fraction(numerator_, denominator_) < rhs) || - (Fraction(numerator_, denominator_) == rhs)); +bool Fraction::operator>(const Fraction &rhs) { + // a/b > c/d if (ad)/(bd) > (bc)/(bd), i.e. 
if a*d > b*c + bool negLHS = (denominator_ < 0); + bool negRHS = (rhs.denominator_ < 0); + if ((negLHS || negRHS) && !(negLHS && negRHS)) { + return ((numerator_ * rhs.denominator_) < (denominator_ * rhs.numerator_)); + } else { + return ((numerator_ * rhs.denominator_) > (denominator_ * rhs.numerator_)); } +} - bool Fraction::operator>=(const Fraction& rhs) - { - return ((Fraction(numerator_, denominator_) > rhs) || - (Fraction(numerator_, denominator_) == rhs)); - } +bool Fraction::operator<=(const Fraction &rhs) { + return ((Fraction(numerator_, denominator_) < rhs) || + (Fraction(numerator_, denominator_) == rhs)); +} - bool operator==(Fraction lhs, Fraction rhs) - { - lhs.reduce(); - rhs.reduce(); +bool Fraction::operator>=(const Fraction &rhs) { + return ((Fraction(numerator_, denominator_) > rhs) || + (Fraction(numerator_, denominator_) == rhs)); +} - return(lhs.numerator_ == rhs.numerator_ && - lhs.denominator_ == rhs.denominator_); - } +bool operator==(Fraction lhs, Fraction rhs) { + lhs.reduce(); + rhs.reduce(); - std::ostream& operator<<(std::ostream& out, Fraction rhs) - { - rhs.reduce(); + return (lhs.numerator_ == rhs.numerator_ && + lhs.denominator_ == rhs.denominator_); +} - if(rhs.denominator_ == 1) - { - out << rhs.numerator_; - } - else - { - out << rhs.numerator_ << "/" << rhs.denominator_; - } +std::ostream &operator<<(std::ostream &out, Fraction rhs) { + rhs.reduce(); - return out; + if (rhs.denominator_ == 1) { + out << rhs.numerator_; + } else { + out << rhs.numerator_ << "/" << rhs.denominator_; } - // Recovers a fraction representation of a provided double by building - // a continued fraction and stopping when a continuation component's - // denominator exceeds the provided tolerance. - Fraction Fraction::fromDouble(double value, unsigned int tolerance) - { - std::vector components; - int component, numerator_, denominator_; - double continuation; - bool isNegative; - - if(value < 0) - { - isNegative = true; - continuation = -value; - } - else - { - isNegative = false; - continuation = value; - } + return out; +} - //use arbitrary cutoff for integer values set in Fraction.hpp - if(std::abs(value) > overflowCutoff) - { - throw Exception(__FILE__, - __LINE__, - "Fraction - integer overflow for abritrary cutoff."); - } - else if((std::fabs(value) < 1.0/overflowCutoff) && (std::fabs(value) > 0)) - { - throw Exception(__FILE__, - __LINE__, - "Fraction - integer underflow for arbitrary cutoff."); - } - - do { - component = (int) continuation; - components.push_back(component); - continuation = 1.0 / (continuation - (double) component); - } while(continuation < tolerance && components.size() < 100); +// Recovers a fraction representation of a provided double by building +// a continued fraction and stopping when a continuation component's +// denominator exceeds the provided tolerance. 
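As a quick sanity check of the continued-fraction recovery described in the comment above, here is a small driver. It is a sketch only, assuming the header installs as nupic/types/Fraction.hpp like the rest of the tree and that the program links against the library that provides these definitions.

    #include <iostream>

    #include "nupic/types/Fraction.hpp"

    using nupic::Fraction;

    int main() {
      // 0.375 expands to the continued fraction [0; 2, 1, 2], which the loop
      // in fromDouble() folds back into 3/8.
      Fraction f = Fraction::fromDouble(0.375);
      std::cout << f << std::endl; // operator<< reduces and prints "3/8"

      // The arithmetic operators put results over a common denominator
      // (computeLCM); reduce() then divides through by the GCD.
      Fraction sum = Fraction(1, 6) + Fraction(1, 3); // 3/6 before reduction
      sum.reduce();
      std::cout << sum << " = " << sum.toDouble() << std::endl; // 1/2 = 0.5
      return 0;
    }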
+Fraction Fraction::fromDouble(double value, unsigned int tolerance) { + std::vector components; + int component, numerator_, denominator_; + double continuation; + bool isNegative; + + if (value < 0) { + isNegative = true; + continuation = -value; + } else { + isNegative = false; + continuation = value; + } - denominator_ = 1; - numerator_ = components.back(); - components.pop_back(); + // use arbitrary cutoff for integer values set in Fraction.hpp + if (std::abs(value) > overflowCutoff) { + throw Exception(__FILE__, __LINE__, + "Fraction - integer overflow for abritrary cutoff."); + } else if ((std::fabs(value) < 1.0 / overflowCutoff) && + (std::fabs(value) > 0)) { + throw Exception(__FILE__, __LINE__, + "Fraction - integer underflow for arbitrary cutoff."); + } - while(components.size()) - { - std::swap(numerator_, denominator_); - numerator_ += denominator_ * components.back(); - components.pop_back(); - } + do { + component = (int)continuation; + components.push_back(component); + continuation = 1.0 / (continuation - (double)component); + } while (continuation < tolerance && components.size() < 100); - if(isNegative) - { - numerator_ *= -1; - } + denominator_ = 1; + numerator_ = components.back(); + components.pop_back(); - return(Fraction(numerator_, denominator_)); + while (components.size()) { + std::swap(numerator_, denominator_); + numerator_ += denominator_ * components.back(); + components.pop_back(); } - double Fraction::toDouble() - { - return((double) numerator_ / (double) denominator_); + if (isNegative) { + numerator_ *= -1; } + + return (Fraction(numerator_, denominator_)); +} + +double Fraction::toDouble() { + return ((double)numerator_ / (double)denominator_); } +} // namespace nupic diff --git a/src/nupic/types/Fraction.hpp b/src/nupic/types/Fraction.hpp index ab42d7020a..1dd4f02d03 100644 --- a/src/nupic/types/Fraction.hpp +++ b/src/nupic/types/Fraction.hpp @@ -25,50 +25,48 @@ #include -namespace nupic -{ - class Fraction - { - private: - int numerator_, denominator_; - // arbitrary cutoff -- need to fix overflow handling. 64-bits everywhere? - const static int overflowCutoff = 10000000; - - public: - Fraction(int _numerator, int _denominator); - Fraction(int _numerator); - Fraction(); - - bool isNaturalNumber(); - - int getNumerator(); - int getDenominator(); - - void setNumerator(int _numerator); - void setDenominator(int _denominator); - void setFraction(int _numerator, int _denominator); - - static unsigned int computeGCD(int a, int b); - static unsigned int computeLCM(int a,int b); - - void reduce(); - - Fraction operator*(const Fraction& rhs); - Fraction operator*(const int rhs); - friend Fraction operator/(const Fraction& lhs, const Fraction& rhs); - friend Fraction operator-(const Fraction& lhs, const Fraction& rhs); - Fraction operator+(const Fraction& rhs); - Fraction operator%(const Fraction& rhs); - bool operator<(const Fraction& rhs); - bool operator>(const Fraction& rhs); - bool operator<=(const Fraction& rhs); - bool operator>=(const Fraction& rhs); - friend bool operator==(Fraction lhs, Fraction rhs); - friend std::ostream& operator<<(std::ostream& out, Fraction rhs); - - static Fraction fromDouble(double value, unsigned int tolerance = 10000); - double toDouble(); - }; -} +namespace nupic { +class Fraction { +private: + int numerator_, denominator_; + // arbitrary cutoff -- need to fix overflow handling. 64-bits everywhere? 
+ const static int overflowCutoff = 10000000; -#endif //NTA_FRACTION_HPP +public: + Fraction(int _numerator, int _denominator); + Fraction(int _numerator); + Fraction(); + + bool isNaturalNumber(); + + int getNumerator(); + int getDenominator(); + + void setNumerator(int _numerator); + void setDenominator(int _denominator); + void setFraction(int _numerator, int _denominator); + + static unsigned int computeGCD(int a, int b); + static unsigned int computeLCM(int a, int b); + + void reduce(); + + Fraction operator*(const Fraction &rhs); + Fraction operator*(const int rhs); + friend Fraction operator/(const Fraction &lhs, const Fraction &rhs); + friend Fraction operator-(const Fraction &lhs, const Fraction &rhs); + Fraction operator+(const Fraction &rhs); + Fraction operator%(const Fraction &rhs); + bool operator<(const Fraction &rhs); + bool operator>(const Fraction &rhs); + bool operator<=(const Fraction &rhs); + bool operator>=(const Fraction &rhs); + friend bool operator==(Fraction lhs, Fraction rhs); + friend std::ostream &operator<<(std::ostream &out, Fraction rhs); + + static Fraction fromDouble(double value, unsigned int tolerance = 10000); + double toDouble(); +}; +} // namespace nupic + +#endif // NTA_FRACTION_HPP diff --git a/src/nupic/types/Serializable.hpp b/src/nupic/types/Serializable.hpp index be55919b9d..0598cc2825 100644 --- a/src/nupic/types/Serializable.hpp +++ b/src/nupic/types/Serializable.hpp @@ -35,37 +35,34 @@ namespace nupic { - /** - * Base Serializable class that any serializable class - * should inherit from. - */ - template - class Serializable { - public: - void write(std::ostream& stream) const - { - capnp::MallocMessageBuilder message; - typename ProtoT::Builder proto = message.initRoot(); - write(proto); +/** + * Base Serializable class that any serializable class + * should inherit from. + */ +template class Serializable { +public: + void write(std::ostream &stream) const { + capnp::MallocMessageBuilder message; + typename ProtoT::Builder proto = message.initRoot(); + write(proto); - kj::std::StdOutputStream out(stream); - capnp::writeMessage(out, message); - } + kj::std::StdOutputStream out(stream); + capnp::writeMessage(out, message); + } - void read(std::istream& stream) - { - kj::std::StdInputStream in(stream); + void read(std::istream &stream) { + kj::std::StdInputStream in(stream); - capnp::InputStreamMessageReader message(in); - typename ProtoT::Reader proto = message.getRoot(); - read(proto); - } + capnp::InputStreamMessageReader message(in); + typename ProtoT::Reader proto = message.getRoot(); + read(proto); + } - virtual void write(typename ProtoT::Builder& proto) const = 0; - virtual void read(typename ProtoT::Reader& proto) = 0; + virtual void write(typename ProtoT::Builder &proto) const = 0; + virtual void read(typename ProtoT::Reader &proto) = 0; - virtual ~Serializable() {} - }; + virtual ~Serializable() {} +}; } // end namespace nupic #endif // NTA_serializable_HPP diff --git a/src/nupic/types/Types.h b/src/nupic/types/Types.h index 54515b918e..4e96b522eb 100644 --- a/src/nupic/types/Types.h +++ b/src/nupic/types/Types.h @@ -20,11 +20,11 @@ * --------------------------------------------------------------------- */ -/** +/** * @file - * - * Basic C type definitions used throughout `nupic.core` . - * + * + * Basic C type definitions used throughout `nupic.core` . 
+ * * It is included by `Types.hpp` - the C++ basic types file */ @@ -35,92 +35,97 @@ #include #if defined(NTA_OS_WINDOWS) && defined(NTA_COMPILER_MSVC) && defined(NDEBUG) -#pragma warning( disable : 4244 ) // conversion from 'double' to 'nta::Real', possible loss of data (LOTS of various type combinations) -#pragma warning( disable : 4251 ) // needs to have dll-interface to be used by clients of class -#pragma warning( disable : 4275 ) // non dll-interface struct used as base for dll-interface class -#pragma warning( disable : 4305 ) // truncation from 'double' to 'nta::Real', possible loss of data (LOTS of various type combinations) -#pragma warning( once : 4838 ) // narrowing conversions +#pragma warning( \ + disable : 4244) // conversion from 'double' to 'nta::Real', possible loss of + // data (LOTS of various type combinations) +#pragma warning(disable : 4251) // needs to have dll-interface to be used by + // clients of class +#pragma warning(disable : 4275) // non dll-interface struct used as base for + // dll-interface class +#pragma warning( \ + disable : 4305) // truncation from 'double' to 'nta::Real', possible loss of + // data (LOTS of various type combinations) +#pragma warning(once : 4838) // narrowing conversions #endif /*---------------------------------------------------------------------- */ - -/** + +/** * Basic types enumeration */ -typedef enum NTA_BasicType - { - /** - * Represents a 8-bit byte. - */ - NTA_BasicType_Byte, - - /** - * Represents a 16-bit signed integer. - */ - NTA_BasicType_Int16, - - /** - * Represents a 16-bit unsigned integer. - */ - NTA_BasicType_UInt16, - - /** - * Represents a 32-bit signed integer. - */ - NTA_BasicType_Int32, - - /** - * Represents a 32-bit unsigned integer. - */ - NTA_BasicType_UInt32, - - /** - * Represents a 64-bit signed integer. - */ - NTA_BasicType_Int64, - - /** - * Represents a 64-bit unsigned integer. - */ - NTA_BasicType_UInt64, - - /** - * Represents a 32-bit real number(a floating-point number). - */ - NTA_BasicType_Real32, - - /** - * Represents a 64-bit real number(a floating-point number). - */ - NTA_BasicType_Real64, - - /** - * Represents a opaque handle/pointer, same as `void *` - */ - NTA_BasicType_Handle, - - /** - * Represents a boolean. The size is compiler-defined. - * - * There is no typedef'd "Bool" or "NTA_Bool". We just need a way to refer - * to bools with a NTA_BasicType. - */ - NTA_BasicType_Bool, - - /** - * @note This is not an actual type, just a marker for validation purposes - */ - NTA_BasicType_Last, - -#ifdef NTA_DOUBLE_PRECISION - /** TODO: document */ - NTA_BasicType_Real = NTA_BasicType_Real64, -#else - /** TODO: document */ - NTA_BasicType_Real = NTA_BasicType_Real32, +typedef enum NTA_BasicType { + /** + * Represents a 8-bit byte. + */ + NTA_BasicType_Byte, + + /** + * Represents a 16-bit signed integer. + */ + NTA_BasicType_Int16, + + /** + * Represents a 16-bit unsigned integer. + */ + NTA_BasicType_UInt16, + + /** + * Represents a 32-bit signed integer. + */ + NTA_BasicType_Int32, + + /** + * Represents a 32-bit unsigned integer. + */ + NTA_BasicType_UInt32, + + /** + * Represents a 64-bit signed integer. + */ + NTA_BasicType_Int64, + + /** + * Represents a 64-bit unsigned integer. + */ + NTA_BasicType_UInt64, + + /** + * Represents a 32-bit real number(a floating-point number). + */ + NTA_BasicType_Real32, + + /** + * Represents a 64-bit real number(a floating-point number). 
+ */ + NTA_BasicType_Real64, + + /** + * Represents a opaque handle/pointer, same as `void *` + */ + NTA_BasicType_Handle, + + /** + * Represents a boolean. The size is compiler-defined. + * + * There is no typedef'd "Bool" or "NTA_Bool". We just need a way to refer + * to bools with a NTA_BasicType. + */ + NTA_BasicType_Bool, + + /** + * @note This is not an actual type, just a marker for validation purposes + */ + NTA_BasicType_Last, + +#ifdef NTA_DOUBLE_PRECISION + /** TODO: document */ + NTA_BasicType_Real = NTA_BasicType_Real64, +#else + /** TODO: document */ + NTA_BasicType_Real = NTA_BasicType_Real32, #endif - } NTA_BasicType; +} NTA_BasicType; /** * @name Basic types @@ -131,116 +136,115 @@ typedef enum NTA_BasicType /** * Represents a 8-bit byte. */ -typedef char NTA_Byte; +typedef char NTA_Byte; /** * Represents lengths of arrays, strings and so on. */ -typedef size_t NTA_Size; +typedef size_t NTA_Size; /** * Represents a 16-bit signed integer. */ -typedef short NTA_Int16; +typedef short NTA_Int16; /** * Represents a 16-bit unsigned integer. */ -typedef unsigned short NTA_UInt16; - +typedef unsigned short NTA_UInt16; + /** * Represents a 32-bit real number(a floating-point number). */ -typedef float NTA_Real32; +typedef float NTA_Real32; /** * Represents a 64-bit real number(a floating-point number). */ -typedef double NTA_Real64; +typedef double NTA_Real64; /** * Represents an opaque handle/pointer, same as `void *` */ -typedef void * NTA_Handle; +typedef void *NTA_Handle; /** -* Represents an opaque pointer, same as `uintptr_t` -*/ -typedef uintptr_t NTA_UIntPtr; - + * Represents an opaque pointer, same as `uintptr_t` + */ +typedef uintptr_t NTA_UIntPtr; #if defined(NTA_OS_WINDOWS) - #if defined(NTA_ARCH_32) - /** - * Represents a 32-bit signed integer. - */ - typedef long NTA_Int32; - /** - * Represents a 32-bit unsigned integer. - */ - typedef unsigned long NTA_UInt32; - /** - * Represents a 64-bit signed integer. - */ - typedef long long NTA_Int64; - /** - * Represents a 64-bit unsigned integer. - */ - typedef unsigned long long NTA_UInt64; - #else // 64bit (LLP64 data models) - /** - * Represents a 32-bit signed integer. - */ - typedef int NTA_Int32; - /** - * Represents a 32-bit unsigned integer. - */ - typedef unsigned int NTA_UInt32; - /** - * Represents a 64-bit signed integer. - */ - typedef long long NTA_Int64; - /** - * Represents a 64-bit unsigned integer. - */ - typedef unsigned long long NTA_UInt64; - #endif +#if defined(NTA_ARCH_32) +/** + * Represents a 32-bit signed integer. + */ +typedef long NTA_Int32; +/** + * Represents a 32-bit unsigned integer. + */ +typedef unsigned long NTA_UInt32; +/** + * Represents a 64-bit signed integer. + */ +typedef long long NTA_Int64; +/** + * Represents a 64-bit unsigned integer. + */ +typedef unsigned long long NTA_UInt64; +#else // 64bit (LLP64 data models) +/** + * Represents a 32-bit signed integer. + */ +typedef int NTA_Int32; +/** + * Represents a 32-bit unsigned integer. + */ +typedef unsigned int NTA_UInt32; +/** + * Represents a 64-bit signed integer. + */ +typedef long long NTA_Int64; +/** + * Represents a 64-bit unsigned integer. + */ +typedef unsigned long long NTA_UInt64; +#endif #else // *nix (linux, darwin, etc) - #if defined(NTA_ARCH_32) - /** - * Represents a 32-bit signed integer. - */ - typedef int NTA_Int32; - /** - * Represents a 32-bit unsigned integer. - */ - typedef unsigned int NTA_UInt32; - /** - * Represents a 64-bit signed integer. 
- */ - typedef long long NTA_Int64; - /** - * Represents a 64-bit unsigned integer. - */ - typedef unsigned long long NTA_UInt64; - #else // 64bit - /** - * Represents a 32-bit signed integer. - */ - typedef int NTA_Int32; - /** - * Represents a 32-bit unsigned integer. - */ - typedef unsigned int NTA_UInt32; - /** - * Represents a 64-bit signed integer. - */ - typedef long NTA_Int64; - /** - * Represents a 64-bit unsigned integer. - */ - typedef unsigned long NTA_UInt64; - #endif +#if defined(NTA_ARCH_32) +/** + * Represents a 32-bit signed integer. + */ +typedef int NTA_Int32; +/** + * Represents a 32-bit unsigned integer. + */ +typedef unsigned int NTA_UInt32; +/** + * Represents a 64-bit signed integer. + */ +typedef long long NTA_Int64; +/** + * Represents a 64-bit unsigned integer. + */ +typedef unsigned long long NTA_UInt64; +#else // 64bit +/** + * Represents a 32-bit signed integer. + */ +typedef int NTA_Int32; +/** + * Represents a 32-bit unsigned integer. + */ +typedef unsigned int NTA_UInt32; +/** + * Represents a 64-bit signed integer. + */ +typedef long NTA_Int64; +/** + * Represents a 64-bit unsigned integer. + */ +typedef unsigned long NTA_UInt64; +#endif #endif /** * @} @@ -248,57 +252,60 @@ typedef uintptr_t NTA_UIntPtr; /** * @name Flexible types - * - * The following are flexible types depending on `NTA_DOUBLE_PRECISION` and `NTA_BIG_INTEGER`. + * + * The following are flexible types depending on `NTA_DOUBLE_PRECISION` and + * `NTA_BIG_INTEGER`. * * @{ - * + * */ -#ifdef NTA_DOUBLE_PRECISION - /** - * Represents a real number(a floating-point number). - * - * Same as NTA_Real64 if `NTA_DOUBLE_PRECISION` is defined, NTA_Real32 otherwise. - */ - typedef NTA_Real64 NTA_Real; - #define NTA_REAL_TYPE_STRING "NTA_Real64" +#ifdef NTA_DOUBLE_PRECISION +/** + * Represents a real number(a floating-point number). + * + * Same as NTA_Real64 if `NTA_DOUBLE_PRECISION` is defined, NTA_Real32 + * otherwise. + */ +typedef NTA_Real64 NTA_Real; +#define NTA_REAL_TYPE_STRING "NTA_Real64" #else - /** - * Represents a real number(a floating-point number). - * - * Same as NTA_Real64 if `NTA_DOUBLE_PRECISION` is defined, NTA_Real32 otherwise. - */ - typedef NTA_Real32 NTA_Real; - #define NTA_REAL_TYPE_STRING "NTA_Real32" +/** + * Represents a real number(a floating-point number). + * + * Same as NTA_Real64 if `NTA_DOUBLE_PRECISION` is defined, NTA_Real32 + * otherwise. + */ +typedef NTA_Real32 NTA_Real; +#define NTA_REAL_TYPE_STRING "NTA_Real32" #endif - + #ifdef NTA_BIG_INTEGER - /** - * Represents a signed integer. - * - * Same as NTA_Int64 if `NTA_BIG_INTEGER` is defined, NTA_Int32 otherwise. - */ - typedef NTA_Int64 NTA_Int; - /** - * Represents a unsigned integer. - * - * Same as NTA_UInt64 if `NTA_BIG_INTEGER` is defined, NTA_UInt32 otherwise. - */ - typedef NTA_UInt64 NTA_UInt; +/** + * Represents a signed integer. + * + * Same as NTA_Int64 if `NTA_BIG_INTEGER` is defined, NTA_Int32 otherwise. + */ +typedef NTA_Int64 NTA_Int; +/** + * Represents a unsigned integer. + * + * Same as NTA_UInt64 if `NTA_BIG_INTEGER` is defined, NTA_UInt32 otherwise. + */ +typedef NTA_UInt64 NTA_UInt; #else - /** - * Represents a signed integer. - * - * Same as NTA_Int64 if `NTA_BIG_INTEGER` is defined, NTA_Int32 otherwise. - */ - typedef NTA_Int32 NTA_Int; - /** - * Represents a unsigned integer. - * - * Same as NTA_UInt64 if `NTA_BIG_INTEGER` is defined, NTA_UInt32 otherwise. - */ - typedef NTA_UInt32 NTA_UInt; +/** + * Represents a signed integer. 
+ * + * Same as NTA_Int64 if `NTA_BIG_INTEGER` is defined, NTA_Int32 otherwise. + */ +typedef NTA_Int32 NTA_Int; +/** + * Represents a unsigned integer. + * + * Same as NTA_UInt64 if `NTA_BIG_INTEGER` is defined, NTA_UInt32 otherwise. + */ +typedef NTA_UInt32 NTA_UInt; #endif /** @@ -310,39 +317,37 @@ typedef uintptr_t NTA_UIntPtr; #define NTA_EXPORT __declspec(dllexport) #define NTA_HIDDEN #else -#define NTA_EXPORT __attribute__ ((visibility ("default"))) -#define NTA_HIDDEN __attribute__ ((visibility ("hidden"))) +#define NTA_EXPORT __attribute__((visibility("default"))) +#define NTA_HIDDEN __attribute__((visibility("hidden"))) #endif - #else #define NTA_HIDDEN #define NTA_EXPORT #endif -/** - * This enum represents the documented logging level of the debug logger. - * +/** + * This enum represents the documented logging level of the debug logger. + * * Use it like `LDEBUG(NTA_LogLevel_XXX)`. */ -typedef enum NTA_LogLevel - { - /** - * Log level: None. - */ - NTA_LogLevel_None, - /** - * Log level: Minimal. - */ - NTA_LogLevel_Minimal, - /** - * Log level: Normal. - */ - NTA_LogLevel_Normal, - /** - * Log level: Verbose. - */ - NTA_LogLevel_Verbose, - } NTA_LogLevel; +typedef enum NTA_LogLevel { + /** + * Log level: None. + */ + NTA_LogLevel_None, + /** + * Log level: Minimal. + */ + NTA_LogLevel_Minimal, + /** + * Log level: Normal. + */ + NTA_LogLevel_Normal, + /** + * Log level: Verbose. + */ + NTA_LogLevel_Verbose, +} NTA_LogLevel; #endif /* NTA_TYPES_H */ diff --git a/src/nupic/types/Types.hpp b/src/nupic/types/Types.hpp index 8fd86aa8bf..ea887a1d8f 100644 --- a/src/nupic/types/Types.hpp +++ b/src/nupic/types/Types.hpp @@ -31,137 +31,137 @@ //---------------------------------------------------------------------- -namespace nupic -{ - /** - * @name Basic types - * - * @{ - */ +namespace nupic { +/** + * @name Basic types + * + * @{ + */ - /** - * Represents a 8-bit byte. - */ - typedef NTA_Byte Byte; +/** + * Represents a 8-bit byte. + */ +typedef NTA_Byte Byte; - /** - * Represents a 16-bit signed integer. - */ - typedef NTA_Int16 Int16; +/** + * Represents a 16-bit signed integer. + */ +typedef NTA_Int16 Int16; - /** - * Represents a 16-bit unsigned integer. - */ - typedef NTA_UInt16 UInt16; +/** + * Represents a 16-bit unsigned integer. + */ +typedef NTA_UInt16 UInt16; - /** - * Represents a 32-bit signed integer. - */ - typedef NTA_Int32 Int32; +/** + * Represents a 32-bit signed integer. + */ +typedef NTA_Int32 Int32; - /** - * Represents a 32-bit unsigned integer. - */ - typedef NTA_UInt32 UInt32; +/** + * Represents a 32-bit unsigned integer. + */ +typedef NTA_UInt32 UInt32; - /** - * Represents a 64-bit signed integer. - */ - typedef NTA_Int64 Int64; +/** + * Represents a 64-bit signed integer. + */ +typedef NTA_Int64 Int64; - /** - * Represents a 64-bit unsigned integer. - */ - typedef NTA_UInt64 UInt64; +/** + * Represents a 64-bit unsigned integer. + */ +typedef NTA_UInt64 UInt64; +/** + * Represents a 32-bit real number(a floating-point number). + */ +typedef NTA_Real32 Real32; - /** - * Represents a 32-bit real number(a floating-point number). - */ - typedef NTA_Real32 Real32; +/** + * Represents a 64-bit real number(a floating-point number). + */ +typedef NTA_Real64 Real64; - /** - * Represents a 64-bit real number(a floating-point number). 
- */ - typedef NTA_Real64 Real64; +/** + * Represents an opaque handle/pointer, same as `void *` + */ +typedef NTA_Handle Handle; - /** - * Represents an opaque handle/pointer, same as `void *` - */ - typedef NTA_Handle Handle; +/** + * Represents an opaque pointer, same as `uintptr_t` + */ +typedef NTA_UIntPtr UIntPtr; - /** - * Represents an opaque pointer, same as `uintptr_t` - */ - typedef NTA_UIntPtr UIntPtr; +/** + * @} + */ - /** - * @} - */ +/** + * @name Flexible types + * + * The following are flexible types depending on `NTA_DOUBLE_PRECISION` and + * `NTA_BIG_INTEGER`. + * + * @{ + * + */ - /** - * @name Flexible types - * - * The following are flexible types depending on `NTA_DOUBLE_PRECISION` and `NTA_BIG_INTEGER`. - * - * @{ - * - */ - - /** - * Represents a real number(a floating-point number). - * - * Same as nupic::Real64 if `NTA_DOUBLE_PRECISION` is defined, nupic::Real32 otherwise. - */ - typedef NTA_Real Real; +/** + * Represents a real number(a floating-point number). + * + * Same as nupic::Real64 if `NTA_DOUBLE_PRECISION` is defined, nupic::Real32 + * otherwise. + */ +typedef NTA_Real Real; - /** - * Represents a signed integer. - * - * Same as nupic::Int64 if `NTA_BIG_INTEGER` is defined, nupic::Int32 otherwise. - */ - typedef NTA_Int Int; +/** + * Represents a signed integer. + * + * Same as nupic::Int64 if `NTA_BIG_INTEGER` is defined, nupic::Int32 otherwise. + */ +typedef NTA_Int Int; +/** + * Represents a unsigned integer. + * + * Same as nupic::UInt64 if `NTA_BIG_INTEGER` is defined, nupic::UInt32 + * otherwise. + */ +typedef NTA_UInt UInt; + +/** + * Represents lengths of arrays, strings and so on. + */ +typedef NTA_Size Size; + +/** + * @} + */ + +/** + * This enum represents the documented logging level of the debug logger. + * + * Use it like `LDEBUG(nupic::LogLevel_XXX)`. + */ +enum LogLevel { /** - * Represents a unsigned integer. - * - * Same as nupic::UInt64 if `NTA_BIG_INTEGER` is defined, nupic::UInt32 otherwise. + * Log level: None. */ - typedef NTA_UInt UInt; - + LogLevel_None = NTA_LogLevel_None, /** - * Represents lengths of arrays, strings and so on. + * Log level: Minimal. */ - typedef NTA_Size Size; - + LogLevel_Minimal, /** - * @} + * Log level: Normal. */ - - /** - * This enum represents the documented logging level of the debug logger. - * - * Use it like `LDEBUG(nupic::LogLevel_XXX)`. + LogLevel_Normal, + /** + * Log level: Verbose. */ - enum LogLevel - { - /** - * Log level: None. - */ - LogLevel_None = NTA_LogLevel_None, - /** - * Log level: Minimal. - */ - LogLevel_Minimal, - /** - * Log level: Normal. - */ - LogLevel_Normal, - /** - * Log level: Verbose. 
- */ - LogLevel_Verbose, - }; + LogLevel_Verbose, +}; } // end namespace nupic @@ -170,6 +170,3 @@ namespace nupic #endif // SWIG #endif // NTA_TYPES_HPP - - - diff --git a/src/nupic/utils/ArrayProtoUtils.cpp b/src/nupic/utils/ArrayProtoUtils.cpp index 301c23af6a..867509b26f 100644 --- a/src/nupic/utils/ArrayProtoUtils.cpp +++ b/src/nupic/utils/ArrayProtoUtils.cpp @@ -24,70 +24,56 @@ * Implementation of the Array Capnproto utilities */ -#include #include -#include #include #include +#include +#include #include // for size_t - using namespace nupic; - -void ArrayProtoUtils::copyArrayToArrayProto( - const Array& array, - ArrayProto::Builder arrayBuilder) -{ +void ArrayProtoUtils::copyArrayToArrayProto(const Array &array, + ArrayProto::Builder arrayBuilder) { const size_t elementCount = array.getCount(); const auto arrayType = array.getType(); - switch (arrayType) - { + switch (arrayType) { case NTA_BasicType_Byte: _templatedCopyArrayToArrayProto( - array, - arrayBuilder.initByteArray(elementCount)); + array, arrayBuilder.initByteArray(elementCount)); break; case NTA_BasicType_Int16: _templatedCopyArrayToArrayProto( - array, - arrayBuilder.initInt16Array(elementCount)); + array, arrayBuilder.initInt16Array(elementCount)); break; case NTA_BasicType_UInt16: _templatedCopyArrayToArrayProto( - array, - arrayBuilder.initUint16Array(elementCount)); + array, arrayBuilder.initUint16Array(elementCount)); break; case NTA_BasicType_Int32: _templatedCopyArrayToArrayProto( - array, - arrayBuilder.initInt32Array(elementCount)); + array, arrayBuilder.initInt32Array(elementCount)); break; case NTA_BasicType_UInt32: _templatedCopyArrayToArrayProto( - array, - arrayBuilder.initUint32Array(elementCount)); + array, arrayBuilder.initUint32Array(elementCount)); break; case NTA_BasicType_Int64: _templatedCopyArrayToArrayProto( - array, - arrayBuilder.initInt64Array(elementCount)); + array, arrayBuilder.initInt64Array(elementCount)); break; case NTA_BasicType_UInt64: _templatedCopyArrayToArrayProto( - array, - arrayBuilder.initUint64Array(elementCount)); + array, arrayBuilder.initUint64Array(elementCount)); break; case NTA_BasicType_Real32: _templatedCopyArrayToArrayProto( - array, - arrayBuilder.initReal32Array(elementCount)); + array, arrayBuilder.initReal32Array(elementCount)); break; case NTA_BasicType_Real64: _templatedCopyArrayToArrayProto( - array, - arrayBuilder.initReal64Array(elementCount)); + array, arrayBuilder.initReal64Array(elementCount)); break; default: NTA_THROW << "Unexpected Array Type: " << arrayType; @@ -95,68 +81,54 @@ void ArrayProtoUtils::copyArrayToArrayProto( } } - void ArrayProtoUtils::copyArrayProtoToArray( - const ArrayProto::Reader arrayReader, - Array& array, - bool allocArrayBuffer) -{ + const ArrayProto::Reader arrayReader, Array &array, bool allocArrayBuffer) { auto unionSelection = arrayReader.which(); - switch (unionSelection) - { + switch (unionSelection) { case ArrayProto::BYTE_ARRAY: - _templatedCopyArrayProtoToArray(arrayReader.getByteArray(), - array, + _templatedCopyArrayProtoToArray(arrayReader.getByteArray(), array, NTA_BasicType_Byte, allocArrayBuffer); break; case ArrayProto::INT16_ARRAY: _templatedCopyArrayProtoToArray(arrayReader.getInt16Array(), - array, - NTA_BasicType_Int16, + array, NTA_BasicType_Int16, allocArrayBuffer); break; case ArrayProto::UINT16_ARRAY: _templatedCopyArrayProtoToArray(arrayReader.getUint16Array(), - array, - NTA_BasicType_UInt16, + array, NTA_BasicType_UInt16, allocArrayBuffer); break; case ArrayProto::INT32_ARRAY: 
_templatedCopyArrayProtoToArray(arrayReader.getInt32Array(), - array, - NTA_BasicType_Int32, + array, NTA_BasicType_Int32, allocArrayBuffer); break; case ArrayProto::UINT32_ARRAY: _templatedCopyArrayProtoToArray(arrayReader.getUint32Array(), - array, - NTA_BasicType_UInt32, + array, NTA_BasicType_UInt32, allocArrayBuffer); break; case ArrayProto::INT64_ARRAY: _templatedCopyArrayProtoToArray(arrayReader.getInt64Array(), - array, - NTA_BasicType_Int64, + array, NTA_BasicType_Int64, allocArrayBuffer); break; case ArrayProto::UINT64_ARRAY: _templatedCopyArrayProtoToArray(arrayReader.getUint64Array(), - array, - NTA_BasicType_UInt64, + array, NTA_BasicType_UInt64, allocArrayBuffer); break; case ArrayProto::REAL32_ARRAY: _templatedCopyArrayProtoToArray(arrayReader.getReal32Array(), - array, - NTA_BasicType_Real32, + array, NTA_BasicType_Real32, allocArrayBuffer); break; case ArrayProto::REAL64_ARRAY: _templatedCopyArrayProtoToArray(arrayReader.getReal64Array(), - array, - NTA_BasicType_Real64, + array, NTA_BasicType_Real64, allocArrayBuffer); break; default: @@ -165,14 +137,11 @@ void ArrayProtoUtils::copyArrayProtoToArray( } } - NTA_BasicType ArrayProtoUtils::getArrayTypeFromArrayProtoReader( - const ArrayProto::Reader arrayReader) -{ + const ArrayProto::Reader arrayReader) { auto unionSelection = arrayReader.which(); - switch (unionSelection) - { + switch (unionSelection) { case ArrayProto::BYTE_ARRAY: return NTA_BasicType_Byte; break; diff --git a/src/nupic/utils/ArrayProtoUtils.hpp b/src/nupic/utils/ArrayProtoUtils.hpp index ae777efd46..c7b4174f5c 100644 --- a/src/nupic/utils/ArrayProtoUtils.hpp +++ b/src/nupic/utils/ArrayProtoUtils.hpp @@ -33,121 +33,110 @@ #include #include // for size_t - -namespace nupic -{ - - class ArrayProtoUtils - { - public: - /** - * Serialise NTA Array to ArrayProto - * - * @param array source array - * @param arrayBuilder destination capnproto array builder - */ - static void copyArrayToArrayProto(const Array& array, - ArrayProto::Builder arrayBuilder); - - /** - * De-serialize ArrayProto into NTA Array - * - * @param arrayReader - * source capnproto array reader - * @param array - * destination array. NOTE: the array's buffer must be - * preallocated and of size that matches the source data. - * @param allocArrayBuffer - * If allocArrayBuffer is false, the Array is assumed to have its - * buffer preinitialized with the count of elements of type - * DestElementT matching the count in reader; if allocArrayBuffer - * is True, the Array's buffer will be released and replaced by a - * new buffer of the appropriate size. - */ - static void copyArrayProtoToArray(const ArrayProto::Reader arrayReader, - Array& array, - bool allocArrayBuffer); - - /** - * Return the NTA_BasicType corresponding to the given ArrayProto reader. - * - * @param arrayReader - * capnproto array reader - */ - static NTA_BasicType getArrayTypeFromArrayProtoReader( - const ArrayProto::Reader arrayReader); - - private: - /** - * Element type-specific templated function for copying an NTA Array to - * capnproto ArrayProto builder. - * - * @param src - * Source Array with elements of type SourceElementT - * @param builder - * Destination type-specific array union element builder of - * capnproto ArrayProto. 
- */ - template - static void _templatedCopyArrayToArrayProto(const Array& src, - ArrayBuilderT builder) - { - NTA_CHECK(BasicType::getSize(src.getType()) == sizeof(SourceElementT)); - NTA_CHECK(builder.size() == src.getCount()); - - auto srcData = (SourceElementT*)src.getBuffer(); - - for (size_t i=0; i < src.getCount(); ++i) - { - builder.set(i, srcData[i]); - } +namespace nupic { + +class ArrayProtoUtils { +public: + /** + * Serialise NTA Array to ArrayProto + * + * @param array source array + * @param arrayBuilder destination capnproto array builder + */ + static void copyArrayToArrayProto(const Array &array, + ArrayProto::Builder arrayBuilder); + + /** + * De-serialize ArrayProto into NTA Array + * + * @param arrayReader + * source capnproto array reader + * @param array + * destination array. NOTE: the array's buffer must be + * preallocated and of size that matches the source data. + * @param allocArrayBuffer + * If allocArrayBuffer is false, the Array is assumed to have its + * buffer preinitialized with the count of elements of type + * DestElementT matching the count in reader; if allocArrayBuffer + * is True, the Array's buffer will be released and replaced by a + * new buffer of the appropriate size. + */ + static void copyArrayProtoToArray(const ArrayProto::Reader arrayReader, + Array &array, bool allocArrayBuffer); + + /** + * Return the NTA_BasicType corresponding to the given ArrayProto reader. + * + * @param arrayReader + * capnproto array reader + */ + static NTA_BasicType + getArrayTypeFromArrayProtoReader(const ArrayProto::Reader arrayReader); + +private: + /** + * Element type-specific templated function for copying an NTA Array to + * capnproto ArrayProto builder. + * + * @param src + * Source Array with elements of type SourceElementT + * @param builder + * Destination type-specific array union element builder of + * capnproto ArrayProto. + */ + template + static void _templatedCopyArrayToArrayProto(const Array &src, + ArrayBuilderT builder) { + NTA_CHECK(BasicType::getSize(src.getType()) == sizeof(SourceElementT)); + NTA_CHECK(builder.size() == src.getCount()); + + auto srcData = (SourceElementT *)src.getBuffer(); + + for (size_t i = 0; i < src.getCount(); ++i) { + builder.set(i, srcData[i]); + } + } + + /** + * Element type-specific templated function for copying an NTA Array to + * capnproto ArrayProto builder. + * + * @param reader + * Destination type-specific array union element reader of + * capnproto ArrayProto. + * @param dest + * Destination Array of type arrayType. + * @param arrayType + * NTA_BasicType of Array elements. + * @param allocArrayBuffer + * If allocArrayBuffer is false, the Array is assumed to have its + * buffer preinitialized with the count of elements of type + * DestElementT matching the count in reader; if allocArrayBuffer + * is True, the Array's buffer will be released and replaced by a + * new buffer of the appropriate size. + */ + template + static void _templatedCopyArrayProtoToArray(ArrayReaderT reader, Array &dest, + NTA_BasicType arrayType, + bool allocArrayBuffer) { + NTA_CHECK(dest.getType() == arrayType); + NTA_CHECK(BasicType::getSize(arrayType) == sizeof(DestElementT)); + + if (allocArrayBuffer) { + dest.releaseBuffer(); + dest.allocateBuffer(reader.size()); } - /** - * Element type-specific templated function for copying an NTA Array to - * capnproto ArrayProto builder. - * - * @param reader - * Destination type-specific array union element reader of - * capnproto ArrayProto. 
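To show how the two public copy helpers documented above are meant to be paired, here is a rough round-trip sketch. It is illustrative only: the Array constructor, allocateBuffer()/getBuffer() accessors, and the generated ArrayProto header path are assumptions drawn from the rest of the tree, not confirmed by this hunk.

    #include <capnp/message.h>

    #include "nupic/ntypes/Array.hpp"
    #include "nupic/proto/ArrayProto.capnp.h"
    #include "nupic/types/Types.hpp"
    #include "nupic/utils/ArrayProtoUtils.hpp"

    using namespace nupic;

    int main() {
      // Build a small Real32 array to serialize.
      Array src(NTA_BasicType_Real32);
      src.allocateBuffer(3);
      auto buf = (Real32 *)src.getBuffer();
      buf[0] = 0.5f;
      buf[1] = 1.5f;
      buf[2] = 2.5f;

      // Serialize into the type-specific union branch of ArrayProto...
      capnp::MallocMessageBuilder message;
      ArrayProto::Builder builder = message.initRoot<ArrayProto>();
      ArrayProtoUtils::copyArrayToArrayProto(src, builder);

      // ...and copy it back, letting the helper allocate the destination
      // buffer to match the serialized element count.
      Array dest(NTA_BasicType_Real32);
      ArrayProtoUtils::copyArrayProtoToArray(builder.asReader(), dest,
                                             /*allocArrayBuffer=*/true);
      return 0;
    }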
- * @param dest - * Destination Array of type arrayType. - * @param arrayType - * NTA_BasicType of Array elements. - * @param allocArrayBuffer - * If allocArrayBuffer is false, the Array is assumed to have its - * buffer preinitialized with the count of elements of type - * DestElementT matching the count in reader; if allocArrayBuffer - * is True, the Array's buffer will be released and replaced by a - * new buffer of the appropriate size. - */ - template - static void _templatedCopyArrayProtoToArray(ArrayReaderT reader, - Array & dest, - NTA_BasicType arrayType, - bool allocArrayBuffer) - { - NTA_CHECK(dest.getType() == arrayType); - NTA_CHECK(BasicType::getSize(arrayType) == sizeof(DestElementT)); - - if (allocArrayBuffer) - { - dest.releaseBuffer(); - dest.allocateBuffer(reader.size()); - } - - NTA_CHECK(reader.size() == dest.getCount()); - - auto destData = (DestElementT*)dest.getBuffer(); - - for (auto entry: reader) - { - *destData++ = entry; - } + NTA_CHECK(reader.size() == dest.getCount()); + + auto destData = (DestElementT *)dest.getBuffer(); + + for (auto entry : reader) { + *destData++ = entry; } - }; + } +}; } // namespace nupic - #endif // NTA_ARRAY_PROTO_UTILS_HPP diff --git a/src/nupic/utils/GroupBy.hpp b/src/nupic/utils/GroupBy.hpp index e629ff51b4..6e3d3cfb8e 100644 --- a/src/nupic/utils/GroupBy.hpp +++ b/src/nupic/utils/GroupBy.hpp @@ -23,1244 +23,948 @@ #ifndef NTA_GROUPBY_HPP #define NTA_GROUPBY_HPP -#include #include // is_sorted +#include #include -namespace nupic -{ - /** @file - * Implements a groupBy function. - * - * This is modeled after Python's itertools.groupby, but with the added - * ability to traverse multiple sequences. Similar to Python, it requires the - * input to be sorted according to the supplied key functions. - * - * There are two functions: - * - * - `groupBy`, which takes in collections - * - `iterGroupBy`, which takes in pairs of iterators - * - * Both functions take a key function for each sequence. - * - * Both functions return an iterable object. The iterator returns a tuple - * containing the key, followed by a begin and end iterator for each - * sequence. The sequences are traversed lazily as the iterator is advanced. - * - * Note: The implementation includes this "minFrontKey" to avoid GCC - * "maybe-initialized" false positives. This approach makes it very obvious to - * the compiler that the "key" variable always gets initialized. - * - * Feel free to add new GroupBy7, GroupBy8, ..., GroupByN classes as needed. - */ - - // ========================================================================== - // CONVENIENCE KEY FUNCTIONS - // ========================================================================== - - template - T identity(T x) - { - return x; +namespace nupic { +/** @file + * Implements a groupBy function. + * + * This is modeled after Python's itertools.groupby, but with the added + * ability to traverse multiple sequences. Similar to Python, it requires the + * input to be sorted according to the supplied key functions. + * + * There are two functions: + * + * - `groupBy`, which takes in collections + * - `iterGroupBy`, which takes in pairs of iterators + * + * Both functions take a key function for each sequence. + * + * Both functions return an iterable object. The iterator returns a tuple + * containing the key, followed by a begin and end iterator for each + * sequence. The sequences are traversed lazily as the iterator is advanced. 
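Since the behaviour described above is easier to follow with a concrete call, here is a single-sequence sketch. It is not part of the patch and assumes only the groupBy() overload and identity() helper declared in this header, plus a sorted std::vector as input.

    #include <iostream>
    #include <tuple>
    #include <vector>

    #include "nupic/utils/GroupBy.hpp"

    using namespace nupic;

    int main() {
      // groupBy() requires its input to be sorted by the key function; the
      // NTA_ASSERT in the constructor checks this in debug builds.
      const std::vector<int> values = {1, 1, 2, 4, 4, 4};

      // Each iteration yields (key, begin, end) for one run of equal keys.
      for (const auto &group : groupBy(values, identity<int>)) {
        int key;
        std::vector<int>::const_iterator begin, end;
        std::tie(key, begin, end) = group;
        std::cout << key << " occurs " << (end - begin) << " time(s)"
                  << std::endl;
      }
      return 0;
    }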
+ * + * Note: The implementation includes this "minFrontKey" to avoid GCC + * "maybe-initialized" false positives. This approach makes it very obvious to + * the compiler that the "key" variable always gets initialized. + * + * Feel free to add new GroupBy7, GroupBy8, ..., GroupByN classes as needed. + */ + +// ========================================================================== +// CONVENIENCE KEY FUNCTIONS +// ========================================================================== + +template T identity(T x) { return x; } + +// ========================================================================== +// 1 SEQUENCE +// ========================================================================== + +template ()), + typename KeyType = typename std::remove_const()))>::type>::type> +class GroupBy1 { +public: + GroupBy1(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0) + : begin0_(begin0), end0_(end0), keyFn0_(keyFn0) { + NTA_ASSERT( + std::is_sorted(begin0, end0, [&](const Element0 &a, const Element0 &b) { + return keyFn0(a) < keyFn0(b); + })); } - // ========================================================================== - // 1 SEQUENCE - // ========================================================================== - - template()), - typename KeyType = typename std::remove_const< - typename std::result_of< - KeyFn0(decltype(*std::declval())) - >::type >::type> - class GroupBy1 - { + class Iterator { public: - - GroupBy1( - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0) - : begin0_(begin0), end0_(end0), keyFn0_(keyFn0) - { - NTA_ASSERT(std::is_sorted(begin0, end0, - [&](const Element0& a, const Element0& b) - { - return keyFn0(a) < keyFn0(b); - })); + Iterator(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0) + : current0_(begin0), end0_(end0), keyFn0_(keyFn0), finished_(false) { + calculateNext_(); } - class Iterator - { - public: - Iterator( - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0) - :current0_(begin0), end0_(end0), keyFn0_(keyFn0), - finished_(false) - { - calculateNext_(); - } - - bool operator !=(const Iterator& other) - { - return (finished_ != other.finished_ || - current0_ != other.current0_); - } + bool operator!=(const Iterator &other) { + return (finished_ != other.finished_ || current0_ != other.current0_); + } - const std::tuple& operator*() const - { - NTA_ASSERT(!finished_); - return v_; - } + const std::tuple &operator*() const { + NTA_ASSERT(!finished_); + return v_; + } - const Iterator& operator++() - { - NTA_ASSERT(!finished_); - calculateNext_(); - return *this; - } + const Iterator &operator++() { + NTA_ASSERT(!finished_); + calculateNext_(); + return *this; + } - private: - - void calculateNext_() - { - if (current0_ != end0_) - { - const KeyType key = keyFn0_(*current0_); - std::get<0>(v_) = key; - - // Find all elements with this key. - std::get<1>(v_) = current0_; - while (current0_ != end0_ && keyFn0_(*current0_) == key) - { - current0_++; - } - std::get<2>(v_) = current0_; - } - else - { - finished_ = true; + private: + void calculateNext_() { + if (current0_ != end0_) { + const KeyType key = keyFn0_(*current0_); + std::get<0>(v_) = key; + + // Find all elements with this key. 
+ std::get<1>(v_) = current0_; + while (current0_ != end0_ && keyFn0_(*current0_) == key) { + current0_++; } + std::get<2>(v_) = current0_; + } else { + finished_ = true; } - - std::tuple v_; - - Iterator0 current0_; - Iterator0 end0_; - KeyFn0 keyFn0_; - - bool finished_; - }; - - Iterator begin() const - { - return Iterator(begin0_, end0_, keyFn0_); } - Iterator end() const - { - return Iterator(end0_, end0_, keyFn0_); - } + std::tuple v_; - private: - Iterator0 begin0_; + Iterator0 current0_; Iterator0 end0_; KeyFn0 keyFn0_; + + bool finished_; }; - template - GroupBy1 - groupBy(const Sequence0& sequence0, KeyFn0 keyFn0) - { - return {sequence0.begin(), sequence0.end(), keyFn0}; + Iterator begin() const { return Iterator(begin0_, end0_, keyFn0_); } + + Iterator end() const { return Iterator(end0_, end0_, keyFn0_); } + +private: + Iterator0 begin0_; + Iterator0 end0_; + KeyFn0 keyFn0_; +}; + +template +GroupBy1 +groupBy(const Sequence0 &sequence0, KeyFn0 keyFn0) { + return {sequence0.begin(), sequence0.end(), keyFn0}; +} + +template +GroupBy1 iterGroupBy(Iterator0 begin0, Iterator0 end0, + KeyFn0 keyFn0) { + return {begin0, end0, keyFn0}; +} + +// ========================================================================== +// 2 SEQUENCES +// ========================================================================== + +template ()))>::type>::type> +static KeyType minFrontKey(KeyType frontrunner, Iterator0 begin0, + Iterator0 end0, KeyFn0 keyFn0) { + if (begin0 != end0) { + return std::min(frontrunner, keyFn0(*begin0)); + } else { + return frontrunner; } - - template - GroupBy1 - iterGroupBy(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0) - { - return {begin0, end0, keyFn0}; +} + +template ()), + typename Element1 = decltype(*std::declval()), + typename KeyType = typename std::remove_const()))>::type>::type> +class GroupBy2 { +public: + GroupBy2(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1) + : begin0_(begin0), end0_(end0), keyFn0_(keyFn0), begin1_(begin1), + end1_(end1), keyFn1_(keyFn1) { + NTA_ASSERT( + std::is_sorted(begin0, end0, [&](const Element0 &a, const Element0 &b) { + return keyFn0(a) < keyFn0(b); + })); + NTA_ASSERT( + std::is_sorted(begin1, end1, [&](const Element1 &a, const Element1 &b) { + return keyFn1(a) < keyFn1(b); + })); } - - // ========================================================================== - // 2 SEQUENCES - // ========================================================================== - - template())) - >::type >::type> - static KeyType minFrontKey(KeyType frontrunner, - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0) - { - if (begin0 != end0) - { - return std::min(frontrunner, keyFn0(*begin0)); - } - else - { - return frontrunner; - } - } - - template()), - typename Element1 = decltype(*std::declval()), - typename KeyType = typename std::remove_const< - typename std::result_of< - KeyFn0(decltype(*std::declval())) - >::type >::type> - class GroupBy2 - { + class Iterator { public: + Iterator(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1) + : current0_(begin0), end0_(end0), keyFn0_(keyFn0), current1_(begin1), + end1_(end1), keyFn1_(keyFn1), finished_(false) { + calculateNext_(); + } - GroupBy2( - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1) - : begin0_(begin0), end0_(end0), keyFn0_(keyFn0), - begin1_(begin1), end1_(end1), keyFn1_(keyFn1) - { - NTA_ASSERT(std::is_sorted(begin0, end0, - [&](const 
Element0& a, const Element0& b) - { - return keyFn0(a) < keyFn0(b); - })); - NTA_ASSERT(std::is_sorted(begin1, end1, - [&](const Element1& a, const Element1& b) - { - return keyFn1(a) < keyFn1(b); - })); + bool operator!=(const Iterator &other) { + return (finished_ != other.finished_ || current0_ != other.current0_ || + current1_ != other.current1_); } - class Iterator - { - public: - Iterator( - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1) - :current0_(begin0), end0_(end0), keyFn0_(keyFn0), - current1_(begin1), end1_(end1), keyFn1_(keyFn1), - finished_(false) - { - calculateNext_(); - } + const std::tuple & + operator*() const { + NTA_ASSERT(!finished_); + return v_; + } - bool operator !=(const Iterator& other) - { - return (finished_ != other.finished_ || - current0_ != other.current0_ || - current1_ != other.current1_); - } + const Iterator &operator++() { + NTA_ASSERT(!finished_); + calculateNext_(); + return *this; + } - const std::tuple& operator*() const - { - NTA_ASSERT(!finished_); - return v_; - } + private: + void calculateNext_() { + if (current0_ != end0_ || current1_ != end1_) { + // Find the lowest key. + KeyType key; + if (current0_ != end0_) { + key = minFrontKey(keyFn0_(*current0_), current1_, end1_, keyFn1_); + } else { + key = keyFn1_(*current1_); + } - const Iterator& operator++() - { - NTA_ASSERT(!finished_); - calculateNext_(); - return *this; - } + std::get<0>(v_) = key; - private: - - void calculateNext_() - { - if (current0_ != end0_ || - current1_ != end1_) - { - // Find the lowest key. - KeyType key; - if (current0_ != end0_) - { - key = minFrontKey(keyFn0_(*current0_), - current1_, end1_, keyFn1_); - } - else - { - key = keyFn1_(*current1_); - } - - std::get<0>(v_) = key; - - // Find all elements with this key. - std::get<1>(v_) = current0_; - while (current0_ != end0_ && keyFn0_(*current0_) == key) - { - current0_++; - } - std::get<2>(v_) = current0_; - - std::get<3>(v_) = current1_; - while (current1_ != end1_ && keyFn1_(*current1_) == key) - { - current1_++; - } - std::get<4>(v_) = current1_; + // Find all elements with this key. 
+ std::get<1>(v_) = current0_; + while (current0_ != end0_ && keyFn0_(*current0_) == key) { + current0_++; } - else - { - finished_ = true; + std::get<2>(v_) = current0_; + + std::get<3>(v_) = current1_; + while (current1_ != end1_ && keyFn1_(*current1_) == key) { + current1_++; } + std::get<4>(v_) = current1_; + } else { + finished_ = true; } - - std::tuple v_; - - Iterator0 current0_; - Iterator0 end0_; - KeyFn0 keyFn0_; - - Iterator1 current1_; - Iterator1 end1_; - KeyFn1 keyFn1_; - - bool finished_; - }; - - Iterator begin() const - { - return Iterator(begin0_, end0_, keyFn0_, - begin1_, end1_, keyFn1_); } - Iterator end() const - { - return Iterator(end0_, end0_, keyFn0_, - end1_, end1_, keyFn1_); - } + std::tuple v_; - private: - Iterator0 begin0_; + Iterator0 current0_; Iterator0 end0_; KeyFn0 keyFn0_; - Iterator1 begin1_; + Iterator1 current1_; Iterator1 end1_; KeyFn1 keyFn1_; + + bool finished_; }; - template - GroupBy2 - groupBy(const Sequence0& sequence0, KeyFn0 keyFn0, - const Sequence1& sequence1, KeyFn1 keyFn1) - { - return {sequence0.begin(), sequence0.end(), keyFn0, - sequence1.begin(), sequence1.end(), keyFn1}; + Iterator begin() const { + return Iterator(begin0_, end0_, keyFn0_, begin1_, end1_, keyFn1_); } - template - GroupBy2 - iterGroupBy(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1) - { - return {begin0, end0, keyFn0, - begin1, end1, keyFn1}; + Iterator end() const { + return Iterator(end0_, end0_, keyFn0_, end1_, end1_, keyFn1_); } - // ========================================================================== - // 3 SEQUENCES - // ========================================================================== - - template())) - >::type >::type> - static KeyType minFrontKey(KeyType frontrunner, - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1) - { - KeyType ret = frontrunner; - - if (begin0 != end0) - { - ret = std::min(ret, keyFn0(*begin0)); - } +private: + Iterator0 begin0_; + Iterator0 end0_; + KeyFn0 keyFn0_; + + Iterator1 begin1_; + Iterator1 end1_; + KeyFn1 keyFn1_; +}; + +template +GroupBy2 +groupBy(const Sequence0 &sequence0, KeyFn0 keyFn0, const Sequence1 &sequence1, + KeyFn1 keyFn1) { + return {sequence0.begin(), sequence0.end(), keyFn0, + sequence1.begin(), sequence1.end(), keyFn1}; +} + +template +GroupBy2 +iterGroupBy(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1) { + return {begin0, end0, keyFn0, begin1, end1, keyFn1}; +} + +// ========================================================================== +// 3 SEQUENCES +// ========================================================================== + +template ()))>::type>::type> +static KeyType minFrontKey(KeyType frontrunner, Iterator0 begin0, + Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1) { + KeyType ret = frontrunner; + + if (begin0 != end0) { + ret = std::min(ret, keyFn0(*begin0)); + } - if (begin1 != end1) - { - ret = std::min(ret, keyFn1(*begin1)); - } + if (begin1 != end1) { + ret = std::min(ret, keyFn1(*begin1)); + } - return ret; + return ret; +} + +template ()), + typename Element1 = decltype(*std::declval()), + typename Element2 = decltype(*std::declval()), + typename KeyType = typename std::remove_const()))>::type>::type> +class GroupBy3 { +public: + GroupBy3(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1, Iterator2 begin2, Iterator2 end2, + KeyFn2 keyFn2) 
+ : begin0_(begin0), end0_(end0), keyFn0_(keyFn0), begin1_(begin1), + end1_(end1), keyFn1_(keyFn1), begin2_(begin2), end2_(end2), + keyFn2_(keyFn2) { + NTA_ASSERT( + std::is_sorted(begin0, end0, [&](const Element0 &a, const Element0 &b) { + return keyFn0(a) < keyFn0(b); + })); + NTA_ASSERT( + std::is_sorted(begin1, end1, [&](const Element1 &a, const Element1 &b) { + return keyFn1(a) < keyFn1(b); + })); + NTA_ASSERT( + std::is_sorted(begin2, end2, [&](const Element2 &a, const Element2 &b) { + return keyFn2(a) < keyFn2(b); + })); } - template()), - typename Element1 = decltype(*std::declval()), - typename Element2 = decltype(*std::declval()), - typename KeyType = typename std::remove_const< - typename std::result_of< - KeyFn0(decltype(*std::declval())) - >::type >::type> - class GroupBy3 - { + class Iterator { public: - GroupBy3( - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1, - Iterator2 begin2, Iterator2 end2, KeyFn2 keyFn2) - : begin0_(begin0), end0_(end0), keyFn0_(keyFn0), - begin1_(begin1), end1_(end1), keyFn1_(keyFn1), - begin2_(begin2), end2_(end2), keyFn2_(keyFn2) - { - NTA_ASSERT(std::is_sorted(begin0, end0, - [&](const Element0& a, const Element0& b) - { - return keyFn0(a) < keyFn0(b); - })); - NTA_ASSERT(std::is_sorted(begin1, end1, - [&](const Element1& a, const Element1& b) - { - return keyFn1(a) < keyFn1(b); - })); - NTA_ASSERT(std::is_sorted(begin2, end2, - [&](const Element2& a, const Element2& b) - { - return keyFn2(a) < keyFn2(b); - })); + Iterator(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1, Iterator2 begin2, Iterator2 end2, + KeyFn2 keyFn2) + : current0_(begin0), end0_(end0), keyFn0_(keyFn0), current1_(begin1), + end1_(end1), keyFn1_(keyFn1), current2_(begin2), end2_(end2), + keyFn2_(keyFn2), finished_(false) { + calculateNext_(); } - class Iterator - { - public: - Iterator( - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1, - Iterator2 begin2, Iterator2 end2, KeyFn2 keyFn2) - :current0_(begin0), end0_(end0), keyFn0_(keyFn0), - current1_(begin1), end1_(end1), keyFn1_(keyFn1), - current2_(begin2), end2_(end2), keyFn2_(keyFn2), - finished_(false) - { - calculateNext_(); - } - - bool operator !=(const Iterator& other) - { - return (finished_ != other.finished_ || - current0_ != other.current0_ || - current1_ != other.current1_ || - current2_ != other.current2_); - } + bool operator!=(const Iterator &other) { + return (finished_ != other.finished_ || current0_ != other.current0_ || + current1_ != other.current1_ || current2_ != other.current2_); + } - const std::tuple& operator*() const - { - NTA_ASSERT(!finished_); - return v_; - } + const std::tuple & + operator*() const { + NTA_ASSERT(!finished_); + return v_; + } - const Iterator& operator++() - { - NTA_ASSERT(!finished_); - calculateNext_(); - return *this; - } + const Iterator &operator++() { + NTA_ASSERT(!finished_); + calculateNext_(); + return *this; + } - private: - - void calculateNext_() - { - if (current0_ != end0_ || - current1_ != end1_ || - current2_ != end2_) - { - // Find the lowest key. - KeyType key; - if (current0_ != end0_) - { - key = minFrontKey(keyFn0_(*current0_), - current1_, end1_, keyFn1_, - current2_, end2_, keyFn2_); - } - else if (current1_ != end1_) - { - key = minFrontKey(keyFn1_(*current1_), - current2_, end2_, keyFn2_); - } - else - { - key = keyFn2_(*current2_); - } - - std::get<0>(v_) = key; - - // Find all elements with this key. 
- std::get<1>(v_) = current0_; - while (current0_ != end0_ && keyFn0_(*current0_) == key) - { - current0_++; - } - std::get<2>(v_) = current0_; - - std::get<3>(v_) = current1_; - while (current1_ != end1_ && keyFn1_(*current1_) == key) - { - current1_++; - } - std::get<4>(v_) = current1_; - - std::get<5>(v_) = current2_; - while (current2_ != end2_ && keyFn2_(*current2_) == key) - { - current2_++; - } - std::get<6>(v_) = current2_; - } - else - { - finished_ = true; + private: + void calculateNext_() { + if (current0_ != end0_ || current1_ != end1_ || current2_ != end2_) { + // Find the lowest key. + KeyType key; + if (current0_ != end0_) { + key = minFrontKey(keyFn0_(*current0_), current1_, end1_, keyFn1_, + current2_, end2_, keyFn2_); + } else if (current1_ != end1_) { + key = minFrontKey(keyFn1_(*current1_), current2_, end2_, keyFn2_); + } else { + key = keyFn2_(*current2_); } - } - - std::tuple v_; - - Iterator0 current0_; - Iterator0 end0_; - KeyFn0 keyFn0_; - Iterator1 current1_; - Iterator1 end1_; - KeyFn1 keyFn1_; + std::get<0>(v_) = key; - Iterator2 current2_; - Iterator2 end2_; - KeyFn2 keyFn2_; + // Find all elements with this key. + std::get<1>(v_) = current0_; + while (current0_ != end0_ && keyFn0_(*current0_) == key) { + current0_++; + } + std::get<2>(v_) = current0_; - bool finished_; - }; + std::get<3>(v_) = current1_; + while (current1_ != end1_ && keyFn1_(*current1_) == key) { + current1_++; + } + std::get<4>(v_) = current1_; - Iterator begin() const - { - return Iterator(begin0_, end0_, keyFn0_, - begin1_, end1_, keyFn1_, - begin2_, end2_, keyFn2_); + std::get<5>(v_) = current2_; + while (current2_ != end2_ && keyFn2_(*current2_) == key) { + current2_++; + } + std::get<6>(v_) = current2_; + } else { + finished_ = true; + } } - Iterator end() const - { - return Iterator(end0_, end0_, keyFn0_, - end1_, end1_, keyFn1_, - end2_, end2_, keyFn2_); - } + std::tuple + v_; - private: - Iterator0 begin0_; + Iterator0 current0_; Iterator0 end0_; KeyFn0 keyFn0_; - Iterator1 begin1_; + Iterator1 current1_; Iterator1 end1_; KeyFn1 keyFn1_; - Iterator2 begin2_; + Iterator2 current2_; Iterator2 end2_; KeyFn2 keyFn2_; + + bool finished_; }; - template - GroupBy3 - groupBy(const Sequence0& sequence0, KeyFn0 keyFn0, - const Sequence1& sequence1, KeyFn1 keyFn1, - const Sequence2& sequence2, KeyFn2 keyFn2) - { - return {sequence0.begin(), sequence0.end(), keyFn0, - sequence1.begin(), sequence1.end(), keyFn1, - sequence2.begin(), sequence2.end(), keyFn2}; + Iterator begin() const { + return Iterator(begin0_, end0_, keyFn0_, begin1_, end1_, keyFn1_, begin2_, + end2_, keyFn2_); + } + + Iterator end() const { + return Iterator(end0_, end0_, keyFn0_, end1_, end1_, keyFn1_, end2_, end2_, + keyFn2_); } - template - GroupBy3 - iterGroupBy( - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1, - Iterator2 begin2, Iterator2 end2, KeyFn2 keyFn2) - { - return {begin0, end0, keyFn0, - begin1, end1, keyFn1, - begin2, end2, keyFn2}; +private: + Iterator0 begin0_; + Iterator0 end0_; + KeyFn0 keyFn0_; + + Iterator1 begin1_; + Iterator1 end1_; + KeyFn1 keyFn1_; + + Iterator2 begin2_; + Iterator2 end2_; + KeyFn2 keyFn2_; +}; + +template +GroupBy3 +groupBy(const Sequence0 &sequence0, KeyFn0 keyFn0, const Sequence1 &sequence1, + KeyFn1 keyFn1, const Sequence2 &sequence2, KeyFn2 keyFn2) { + return {sequence0.begin(), sequence0.end(), keyFn0, + sequence1.begin(), sequence1.end(), keyFn1, + sequence2.begin(), sequence2.end(), keyFn2}; +} + +template +GroupBy3 
+iterGroupBy(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1, Iterator2 begin2, Iterator2 end2, + KeyFn2 keyFn2) { + return {begin0, end0, keyFn0, begin1, end1, keyFn1, begin2, end2, keyFn2}; +} + +// ========================================================================== +// 4 SEQUENCES +// ========================================================================== + +template ()))>::type>::type> +static KeyType minFrontKey(KeyType frontrunner, Iterator0 begin0, + Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1, Iterator2 begin2, + Iterator2 end2, KeyFn2 keyFn2) { + KeyType ret = frontrunner; + + if (begin0 != end0) { + ret = std::min(ret, keyFn0(*begin0)); } + if (begin1 != end1) { + ret = std::min(ret, keyFn1(*begin1)); + } - // ========================================================================== - // 4 SEQUENCES - // ========================================================================== - - template())) - >::type >::type> - static KeyType minFrontKey(KeyType frontrunner, - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1, - Iterator2 begin2, Iterator2 end2, KeyFn2 keyFn2) - { - KeyType ret = frontrunner; - - if (begin0 != end0) - { - ret = std::min(ret, keyFn0(*begin0)); - } + if (begin2 != end2) { + ret = std::min(ret, keyFn2(*begin2)); + } - if (begin1 != end1) - { - ret = std::min(ret, keyFn1(*begin1)); + return ret; +} + +template ()), + typename Element1 = decltype(*std::declval()), + typename Element2 = decltype(*std::declval()), + typename Element3 = decltype(*std::declval()), + typename KeyType = typename std::remove_const()))>::type>::type> +class GroupBy4 { +public: + GroupBy4(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1, Iterator2 begin2, Iterator2 end2, + KeyFn2 keyFn2, Iterator3 begin3, Iterator3 end3, KeyFn3 keyFn3) + : begin0_(begin0), end0_(end0), keyFn0_(keyFn0), begin1_(begin1), + end1_(end1), keyFn1_(keyFn1), begin2_(begin2), end2_(end2), + keyFn2_(keyFn2), begin3_(begin3), end3_(end3), keyFn3_(keyFn3) { + NTA_ASSERT( + std::is_sorted(begin0, end0, [&](const Element0 &a, const Element0 &b) { + return keyFn0(a) < keyFn0(b); + })); + NTA_ASSERT( + std::is_sorted(begin1, end1, [&](const Element1 &a, const Element1 &b) { + return keyFn1(a) < keyFn1(b); + })); + NTA_ASSERT( + std::is_sorted(begin2, end2, [&](const Element2 &a, const Element2 &b) { + return keyFn2(a) < keyFn2(b); + })); + NTA_ASSERT( + std::is_sorted(begin3, end3, [&](const Element3 &a, const Element3 &b) { + return keyFn3(a) < keyFn3(b); + })); + } + + class Iterator { + public: + Iterator(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1, Iterator2 begin2, Iterator2 end2, + KeyFn2 keyFn2, Iterator3 begin3, Iterator3 end3, KeyFn3 keyFn3) + : current0_(begin0), end0_(end0), keyFn0_(keyFn0), current1_(begin1), + end1_(end1), keyFn1_(keyFn1), current2_(begin2), end2_(end2), + keyFn2_(keyFn2), current3_(begin3), end3_(end3), keyFn3_(keyFn3), + finished_(false) { + calculateNext_(); } - if (begin2 != end2) - { - ret = std::min(ret, keyFn2(*begin2)); + bool operator!=(const Iterator &other) { + return (finished_ != other.finished_ || current0_ != other.current0_ || + current1_ != other.current1_ || current2_ != other.current2_ || + current3_ != other.current3_); } - return ret; - } + const std::tuple & + operator*() const { + NTA_ASSERT(!finished_); + return v_; + } 
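For reference, a minimal usage sketch of the two-sequence groupBy overload defined in this header (hypothetical input vectors; identity is the convenience key function above, and both inputs must already be sorted by their keys, as the NTA_ASSERT calls require):

#include <iostream>
#include <tuple>
#include <vector>

#include "nupic/utils/GroupBy.hpp"

using namespace nupic;

int main() {
  // Both sequences must be pre-sorted by their key functions.
  const std::vector<int> a = {1, 1, 3, 5};
  const std::vector<int> b = {3, 3, 4, 5};

  for (auto group : groupBy(a, identity<int>, b, identity<int>)) {
    int key;
    std::vector<int>::const_iterator aBegin, aEnd, bBegin, bEnd;
    std::tie(key, aBegin, aEnd, bBegin, bEnd) = group;

    // Each iteration yields one key plus begin/end ranges into both inputs;
    // a range is empty when the key is absent from that sequence.
    std::cout << "key " << key << ": " << (aEnd - aBegin) << " from a, "
              << (bEnd - bBegin) << " from b" << std::endl;
  }
  return 0;
}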
- template()), - typename Element1 = decltype(*std::declval()), - typename Element2 = decltype(*std::declval()), - typename Element3 = decltype(*std::declval()), - typename KeyType = typename std::remove_const< - typename std::result_of< - KeyFn0(decltype(*std::declval())) - >::type >::type> - class GroupBy4 - { - public: - GroupBy4( - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1, - Iterator2 begin2, Iterator2 end2, KeyFn2 keyFn2, - Iterator3 begin3, Iterator3 end3, KeyFn3 keyFn3) - : begin0_(begin0), end0_(end0), keyFn0_(keyFn0), - begin1_(begin1), end1_(end1), keyFn1_(keyFn1), - begin2_(begin2), end2_(end2), keyFn2_(keyFn2), - begin3_(begin3), end3_(end3), keyFn3_(keyFn3) - { - NTA_ASSERT(std::is_sorted(begin0, end0, - [&](const Element0& a, const Element0& b) - { - return keyFn0(a) < keyFn0(b); - })); - NTA_ASSERT(std::is_sorted(begin1, end1, - [&](const Element1& a, const Element1& b) - { - return keyFn1(a) < keyFn1(b); - })); - NTA_ASSERT(std::is_sorted(begin2, end2, - [&](const Element2& a, const Element2& b) - { - return keyFn2(a) < keyFn2(b); - })); - NTA_ASSERT(std::is_sorted(begin3, end3, - [&](const Element3& a, const Element3& b) - { - return keyFn3(a) < keyFn3(b); - })); + const Iterator &operator++() { + NTA_ASSERT(!finished_); + calculateNext_(); + return *this; } - class Iterator - { - public: - Iterator( - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1, - Iterator2 begin2, Iterator2 end2, KeyFn2 keyFn2, - Iterator3 begin3, Iterator3 end3, KeyFn3 keyFn3) - :current0_(begin0), end0_(end0), keyFn0_(keyFn0), - current1_(begin1), end1_(end1), keyFn1_(keyFn1), - current2_(begin2), end2_(end2), keyFn2_(keyFn2), - current3_(begin3), end3_(end3), keyFn3_(keyFn3), - finished_(false) - { - calculateNext_(); - } + private: + void calculateNext_() { + if (current0_ != end0_ || current1_ != end1_ || current2_ != end2_ || + current3_ != end3_) { + // Find the lowest key. + KeyType key; + if (current0_ != end0_) { + key = + minFrontKey(keyFn0_(*current0_), current1_, end1_, keyFn1_, + current2_, end2_, keyFn2_, current3_, end3_, keyFn3_); + } else if (current1_ != end1_) { + key = minFrontKey(keyFn1_(*current1_), current2_, end2_, keyFn2_, + current3_, end3_, keyFn3_); + } else if (current2_ != end2_) { + key = minFrontKey(keyFn2_(*current2_), current3_, end3_, keyFn3_); + } else { + key = keyFn3_(*current3_); + } - bool operator !=(const Iterator& other) - { - return (finished_ != other.finished_ || - current0_ != other.current0_ || - current1_ != other.current1_ || - current2_ != other.current2_ || - current3_ != other.current3_); - } + std::get<0>(v_) = key; - const std::tuple& operator*() const - { - NTA_ASSERT(!finished_); - return v_; - } + // Find all elements with this key. + std::get<1>(v_) = current0_; + while (current0_ != end0_ && keyFn0_(*current0_) == key) { + current0_++; + } + std::get<2>(v_) = current0_; - const Iterator& operator++() - { - NTA_ASSERT(!finished_); - calculateNext_(); - return *this; - } + std::get<3>(v_) = current1_; + while (current1_ != end1_ && keyFn1_(*current1_) == key) { + current1_++; + } + std::get<4>(v_) = current1_; - private: - - void calculateNext_() - { - if (current0_ != end0_ || - current1_ != end1_ || - current2_ != end2_ || - current3_ != end3_) - { - // Find the lowest key. 
- KeyType key; - if (current0_ != end0_) - { - key = minFrontKey(keyFn0_(*current0_), - current1_, end1_, keyFn1_, - current2_, end2_, keyFn2_, - current3_, end3_, keyFn3_); - } - else if (current1_ != end1_) - { - key = minFrontKey(keyFn1_(*current1_), - current2_, end2_, keyFn2_, - current3_, end3_, keyFn3_); - } - else if (current2_ != end2_) - { - key = minFrontKey(keyFn2_(*current2_), - current3_, end3_, keyFn3_); - } - else - { - key = keyFn3_(*current3_); - } - - std::get<0>(v_) = key; - - // Find all elements with this key. - std::get<1>(v_) = current0_; - while (current0_ != end0_ && keyFn0_(*current0_) == key) - { - current0_++; - } - std::get<2>(v_) = current0_; - - std::get<3>(v_) = current1_; - while (current1_ != end1_ && keyFn1_(*current1_) == key) - { - current1_++; - } - std::get<4>(v_) = current1_; - - std::get<5>(v_) = current2_; - while (current2_ != end2_ && keyFn2_(*current2_) == key) - { - current2_++; - } - std::get<6>(v_) = current2_; - - std::get<7>(v_) = current3_; - while (current3_ != end3_ && keyFn3_(*current3_) == key) - { - current3_++; - } - std::get<8>(v_) = current3_; + std::get<5>(v_) = current2_; + while (current2_ != end2_ && keyFn2_(*current2_) == key) { + current2_++; } - else - { - finished_ = true; + std::get<6>(v_) = current2_; + + std::get<7>(v_) = current3_; + while (current3_ != end3_ && keyFn3_(*current3_) == key) { + current3_++; } + std::get<8>(v_) = current3_; + } else { + finished_ = true; } - - std::tuple v_; - - Iterator0 current0_; - Iterator0 end0_; - KeyFn0 keyFn0_; - - Iterator1 current1_; - Iterator1 end1_; - KeyFn1 keyFn1_; - - Iterator2 current2_; - Iterator2 end2_; - KeyFn2 keyFn2_; - - Iterator3 current3_; - Iterator3 end3_; - KeyFn3 keyFn3_; - - bool finished_; - }; - - Iterator begin() const - { - return Iterator(begin0_, end0_, keyFn0_, - begin1_, end1_, keyFn1_, - begin2_, end2_, keyFn2_, - begin3_, end3_, keyFn3_); } - Iterator end() const - { - return Iterator(end0_, end0_, keyFn0_, - end1_, end1_, keyFn1_, - end2_, end2_, keyFn2_, - end3_, end3_, keyFn3_); - } + std::tuple + v_; - private: - Iterator0 begin0_; + Iterator0 current0_; Iterator0 end0_; KeyFn0 keyFn0_; - Iterator1 begin1_; + Iterator1 current1_; Iterator1 end1_; KeyFn1 keyFn1_; - Iterator2 begin2_; + Iterator2 current2_; Iterator2 end2_; KeyFn2 keyFn2_; - Iterator3 begin3_; + Iterator3 current3_; Iterator3 end3_; KeyFn3 keyFn3_; + + bool finished_; }; - template - GroupBy4 - groupBy(const Sequence0& sequence0, KeyFn0 keyFn0, - const Sequence1& sequence1, KeyFn1 keyFn1, - const Sequence2& sequence2, KeyFn2 keyFn2, - const Sequence3& sequence3, KeyFn3 keyFn3) - { - return {sequence0.begin(), sequence0.end(), keyFn0, - sequence1.begin(), sequence1.end(), keyFn1, - sequence2.begin(), sequence2.end(), keyFn2, - sequence3.begin(), sequence3.end(), keyFn3}; + Iterator begin() const { + return Iterator(begin0_, end0_, keyFn0_, begin1_, end1_, keyFn1_, begin2_, + end2_, keyFn2_, begin3_, end3_, keyFn3_); } - template - GroupBy4 - iterGroupBy( - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1, - Iterator2 begin2, Iterator2 end2, KeyFn2 keyFn2, - Iterator3 begin3, Iterator3 end3, KeyFn3 keyFn3) - { - return {begin0, end0, keyFn0, - begin1, end1, keyFn1, - begin2, end2, keyFn2, - begin3, end3, keyFn3}; + Iterator end() const { + return Iterator(end0_, end0_, keyFn0_, end1_, end1_, keyFn1_, end2_, end2_, + keyFn2_, end3_, end3_, keyFn3_); } +private: + Iterator0 begin0_; + Iterator0 end0_; + KeyFn0 keyFn0_; + + Iterator1 
begin1_; + Iterator1 end1_; + KeyFn1 keyFn1_; + + Iterator2 begin2_; + Iterator2 end2_; + KeyFn2 keyFn2_; + + Iterator3 begin3_; + Iterator3 end3_; + KeyFn3 keyFn3_; +}; + +template +GroupBy4 +groupBy(const Sequence0 &sequence0, KeyFn0 keyFn0, const Sequence1 &sequence1, + KeyFn1 keyFn1, const Sequence2 &sequence2, KeyFn2 keyFn2, + const Sequence3 &sequence3, KeyFn3 keyFn3) { + return {sequence0.begin(), sequence0.end(), keyFn0, + sequence1.begin(), sequence1.end(), keyFn1, + sequence2.begin(), sequence2.end(), keyFn2, + sequence3.begin(), sequence3.end(), keyFn3}; +} + +template +GroupBy4 +iterGroupBy(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1, Iterator2 begin2, Iterator2 end2, + KeyFn2 keyFn2, Iterator3 begin3, Iterator3 end3, KeyFn3 keyFn3) { + return {begin0, end0, keyFn0, begin1, end1, keyFn1, + begin2, end2, keyFn2, begin3, end3, keyFn3}; +} + +// ========================================================================== +// 5 SEQUENCES +// ========================================================================== + +template ()))>::type>::type> +static KeyType minFrontKey(KeyType frontrunner, Iterator0 begin0, + Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1, Iterator2 begin2, + Iterator2 end2, KeyFn2 keyFn2, Iterator3 begin3, + Iterator3 end3, KeyFn3 keyFn3) { + KeyType ret = frontrunner; + + if (begin0 != end0) { + ret = std::min(ret, keyFn0(*begin0)); + } - // ========================================================================== - // 5 SEQUENCES - // ========================================================================== - - template())) - >::type >::type> - static KeyType minFrontKey(KeyType frontrunner, - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1, - Iterator2 begin2, Iterator2 end2, KeyFn2 keyFn2, - Iterator3 begin3, Iterator3 end3, KeyFn3 keyFn3) - { - KeyType ret = frontrunner; - - if (begin0 != end0) - { - ret = std::min(ret, keyFn0(*begin0)); - } + if (begin1 != end1) { + ret = std::min(ret, keyFn1(*begin1)); + } - if (begin1 != end1) - { - ret = std::min(ret, keyFn1(*begin1)); - } + if (begin2 != end2) { + ret = std::min(ret, keyFn2(*begin2)); + } - if (begin2 != end2) - { - ret = std::min(ret, keyFn2(*begin2)); + if (begin3 != end3) { + ret = std::min(ret, keyFn3(*begin3)); + } + + return ret; +} + +template ()), + typename Element1 = decltype(*std::declval()), + typename Element2 = decltype(*std::declval()), + typename Element3 = decltype(*std::declval()), + typename Element4 = decltype(*std::declval()), + typename KeyType = typename std::remove_const()))>::type>::type> +class GroupBy5 { +public: + GroupBy5(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1, Iterator2 begin2, Iterator2 end2, + KeyFn2 keyFn2, Iterator3 begin3, Iterator3 end3, KeyFn3 keyFn3, + Iterator4 begin4, Iterator4 end4, KeyFn4 keyFn4) + : begin0_(begin0), end0_(end0), keyFn0_(keyFn0), begin1_(begin1), + end1_(end1), keyFn1_(keyFn1), begin2_(begin2), end2_(end2), + keyFn2_(keyFn2), begin3_(begin3), end3_(end3), keyFn3_(keyFn3), + begin4_(begin4), end4_(end4), keyFn4_(keyFn4) { + NTA_ASSERT( + std::is_sorted(begin0, end0, [&](const Element0 &a, const Element0 &b) { + return keyFn0(a) < keyFn0(b); + })); + NTA_ASSERT( + std::is_sorted(begin1, end1, [&](const Element1 &a, const Element1 &b) { + return keyFn1(a) < keyFn1(b); + })); + NTA_ASSERT( + std::is_sorted(begin2, end2, [&](const Element2 &a, const 
Element2 &b) { + return keyFn2(a) < keyFn2(b); + })); + NTA_ASSERT( + std::is_sorted(begin3, end3, [&](const Element3 &a, const Element3 &b) { + return keyFn3(a) < keyFn3(b); + })); + NTA_ASSERT( + std::is_sorted(begin4, end4, [&](const Element4 &a, const Element4 &b) { + return keyFn4(a) < keyFn4(b); + })); + } + + class Iterator { + public: + Iterator(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1, Iterator2 begin2, Iterator2 end2, + KeyFn2 keyFn2, Iterator3 begin3, Iterator3 end3, KeyFn3 keyFn3, + Iterator4 begin4, Iterator4 end4, KeyFn4 keyFn4) + : current0_(begin0), end0_(end0), keyFn0_(keyFn0), current1_(begin1), + end1_(end1), keyFn1_(keyFn1), current2_(begin2), end2_(end2), + keyFn2_(keyFn2), current3_(begin3), end3_(end3), keyFn3_(keyFn3), + current4_(begin4), end4_(end4), keyFn4_(keyFn4), finished_(false) { + calculateNext_(); } - if (begin3 != end3) - { - ret = std::min(ret, keyFn3(*begin3)); + bool operator!=(const Iterator &other) { + return (finished_ != other.finished_ || current0_ != other.current0_ || + current1_ != other.current1_ || current2_ != other.current2_ || + current3_ != other.current3_ || current4_ != other.current4_); } - return ret; - } + const std::tuple & + operator*() const { + NTA_ASSERT(!finished_); + return v_; + } - template()), - typename Element1 = decltype(*std::declval()), - typename Element2 = decltype(*std::declval()), - typename Element3 = decltype(*std::declval()), - typename Element4 = decltype(*std::declval()), - typename KeyType = typename std::remove_const< - typename std::result_of< - KeyFn0(decltype(*std::declval())) - >::type >::type> - class GroupBy5 - { - public: - GroupBy5( - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1, - Iterator2 begin2, Iterator2 end2, KeyFn2 keyFn2, - Iterator3 begin3, Iterator3 end3, KeyFn3 keyFn3, - Iterator4 begin4, Iterator4 end4, KeyFn4 keyFn4) - : begin0_(begin0), end0_(end0), keyFn0_(keyFn0), - begin1_(begin1), end1_(end1), keyFn1_(keyFn1), - begin2_(begin2), end2_(end2), keyFn2_(keyFn2), - begin3_(begin3), end3_(end3), keyFn3_(keyFn3), - begin4_(begin4), end4_(end4), keyFn4_(keyFn4) - { - NTA_ASSERT(std::is_sorted(begin0, end0, - [&](const Element0& a, const Element0& b) - { - return keyFn0(a) < keyFn0(b); - })); - NTA_ASSERT(std::is_sorted(begin1, end1, - [&](const Element1& a, const Element1& b) - { - return keyFn1(a) < keyFn1(b); - })); - NTA_ASSERT(std::is_sorted(begin2, end2, - [&](const Element2& a, const Element2& b) - { - return keyFn2(a) < keyFn2(b); - })); - NTA_ASSERT(std::is_sorted(begin3, end3, - [&](const Element3& a, const Element3& b) - { - return keyFn3(a) < keyFn3(b); - })); - NTA_ASSERT(std::is_sorted(begin4, end4, - [&](const Element4& a, const Element4& b) - { - return keyFn4(a) < keyFn4(b); - })); + const Iterator &operator++() { + NTA_ASSERT(!finished_); + calculateNext_(); + return *this; } - class Iterator - { - public: - Iterator( - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1, - Iterator2 begin2, Iterator2 end2, KeyFn2 keyFn2, - Iterator3 begin3, Iterator3 end3, KeyFn3 keyFn3, - Iterator4 begin4, Iterator4 end4, KeyFn4 keyFn4) - :current0_(begin0), end0_(end0), keyFn0_(keyFn0), - current1_(begin1), end1_(end1), keyFn1_(keyFn1), - current2_(begin2), end2_(end2), keyFn2_(keyFn2), - current3_(begin3), end3_(end3), keyFn3_(keyFn3), - current4_(begin4), end4_(end4), keyFn4_(keyFn4), - finished_(false) - { - calculateNext_(); - 
} + private: + void calculateNext_() { + if (current0_ != end0_ || current1_ != end1_ || current2_ != end2_ || + current3_ != end3_ || current4_ != end4_) { + // Find the lowest key. + KeyType key; + if (current0_ != end0_) { + key = minFrontKey(keyFn0_(*current0_), current1_, end1_, keyFn1_, + current2_, end2_, keyFn2_, current3_, end3_, + keyFn3_, current4_, end4_, keyFn4_); + } else if (current1_ != end1_) { + key = + minFrontKey(keyFn1_(*current1_), current2_, end2_, keyFn2_, + current3_, end3_, keyFn3_, current4_, end4_, keyFn4_); + } else if (current2_ != end2_) { + key = minFrontKey(keyFn2_(*current2_), current3_, end3_, keyFn3_, + current4_, end4_, keyFn4_); + } else if (current3_ != end3_) { + key = minFrontKey(keyFn3_(*current3_), current4_, end4_, keyFn4_); + } else { + key = keyFn4_(*current4_); + } - bool operator !=(const Iterator& other) - { - return (finished_ != other.finished_ || - current0_ != other.current0_ || - current1_ != other.current1_ || - current2_ != other.current2_ || - current3_ != other.current3_ || - current4_ != other.current4_); - } + std::get<0>(v_) = key; - const std::tuple& operator*() const - { - NTA_ASSERT(!finished_); - return v_; - } + // Find all elements with this key. + std::get<1>(v_) = current0_; + while (current0_ != end0_ && keyFn0_(*current0_) == key) { + current0_++; + } + std::get<2>(v_) = current0_; - const Iterator& operator++() - { - NTA_ASSERT(!finished_); - calculateNext_(); - return *this; - } + std::get<3>(v_) = current1_; + while (current1_ != end1_ && keyFn1_(*current1_) == key) { + current1_++; + } + std::get<4>(v_) = current1_; - private: - - void calculateNext_() - { - if (current0_ != end0_ || - current1_ != end1_ || - current2_ != end2_ || - current3_ != end3_ || - current4_ != end4_) - { - // Find the lowest key. - KeyType key; - if (current0_ != end0_) - { - key = minFrontKey(keyFn0_(*current0_), - current1_, end1_, keyFn1_, - current2_, end2_, keyFn2_, - current3_, end3_, keyFn3_, - current4_, end4_, keyFn4_); - } - else if (current1_ != end1_) - { - key = minFrontKey(keyFn1_(*current1_), - current2_, end2_, keyFn2_, - current3_, end3_, keyFn3_, - current4_, end4_, keyFn4_); - } - else if (current2_ != end2_) - { - key = minFrontKey(keyFn2_(*current2_), - current3_, end3_, keyFn3_, - current4_, end4_, keyFn4_); - } - else if (current3_ != end3_) - { - key = minFrontKey(keyFn3_(*current3_), - current4_, end4_, keyFn4_); - } - else - { - key = keyFn4_(*current4_); - } - - std::get<0>(v_) = key; - - // Find all elements with this key. 
- std::get<1>(v_) = current0_; - while (current0_ != end0_ && keyFn0_(*current0_) == key) - { - current0_++; - } - std::get<2>(v_) = current0_; - - std::get<3>(v_) = current1_; - while (current1_ != end1_ && keyFn1_(*current1_) == key) - { - current1_++; - } - std::get<4>(v_) = current1_; - - std::get<5>(v_) = current2_; - while (current2_ != end2_ && keyFn2_(*current2_) == key) - { - current2_++; - } - std::get<6>(v_) = current2_; - - std::get<7>(v_) = current3_; - while (current3_ != end3_ && keyFn3_(*current3_) == key) - { - current3_++; - } - std::get<8>(v_) = current3_; - - std::get<9>(v_) = current4_; - while (current4_ != end4_ && keyFn4_(*current4_) == key) - { - current4_++; - } - std::get<10>(v_) = current4_; + std::get<5>(v_) = current2_; + while (current2_ != end2_ && keyFn2_(*current2_) == key) { + current2_++; } - else - { - finished_ = true; + std::get<6>(v_) = current2_; + + std::get<7>(v_) = current3_; + while (current3_ != end3_ && keyFn3_(*current3_) == key) { + current3_++; } - } + std::get<8>(v_) = current3_; - std::tuple v_; - - Iterator0 current0_; - Iterator0 end0_; - KeyFn0 keyFn0_; - - Iterator1 current1_; - Iterator1 end1_; - KeyFn1 keyFn1_; - - Iterator2 current2_; - Iterator2 end2_; - KeyFn2 keyFn2_; - - Iterator3 current3_; - Iterator3 end3_; - KeyFn3 keyFn3_; - - Iterator4 current4_; - Iterator4 end4_; - KeyFn4 keyFn4_; - - bool finished_; - }; - - Iterator begin() const - { - return Iterator(begin0_, end0_, keyFn0_, - begin1_, end1_, keyFn1_, - begin2_, end2_, keyFn2_, - begin3_, end3_, keyFn3_, - begin4_, end4_, keyFn4_); + std::get<9>(v_) = current4_; + while (current4_ != end4_ && keyFn4_(*current4_) == key) { + current4_++; + } + std::get<10>(v_) = current4_; + } else { + finished_ = true; + } } - Iterator end() const - { - return Iterator(end0_, end0_, keyFn0_, - end1_, end1_, keyFn1_, - end2_, end2_, keyFn2_, - end3_, end3_, keyFn3_, - end4_, end4_, keyFn4_); - } + std::tuple + v_; - private: - Iterator0 begin0_; + Iterator0 current0_; Iterator0 end0_; KeyFn0 keyFn0_; - Iterator1 begin1_; + Iterator1 current1_; Iterator1 end1_; KeyFn1 keyFn1_; - Iterator2 begin2_; + Iterator2 current2_; Iterator2 end2_; KeyFn2 keyFn2_; - Iterator3 begin3_; + Iterator3 current3_; Iterator3 end3_; KeyFn3 keyFn3_; - Iterator4 begin4_; + Iterator4 current4_; Iterator4 end4_; KeyFn4 keyFn4_; + + bool finished_; }; - template - GroupBy5 - groupBy(const Sequence0& sequence0, KeyFn0 keyFn0, - const Sequence1& sequence1, KeyFn1 keyFn1, - const Sequence2& sequence2, KeyFn2 keyFn2, - const Sequence3& sequence3, KeyFn3 keyFn3, - const Sequence4& sequence4, KeyFn4 keyFn4) - { - return {sequence0.begin(), sequence0.end(), keyFn0, - sequence1.begin(), sequence1.end(), keyFn1, - sequence2.begin(), sequence2.end(), keyFn2, - sequence3.begin(), sequence3.end(), keyFn3, - sequence4.begin(), sequence4.end(), keyFn4}; + Iterator begin() const { + return Iterator(begin0_, end0_, keyFn0_, begin1_, end1_, keyFn1_, begin2_, + end2_, keyFn2_, begin3_, end3_, keyFn3_, begin4_, end4_, + keyFn4_); } - template - GroupBy5 - iterGroupBy( - Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, - Iterator1 begin1, Iterator1 end1, KeyFn1 keyFn1, - Iterator2 begin2, Iterator2 end2, KeyFn2 keyFn2, - Iterator3 begin3, Iterator3 end3, KeyFn3 keyFn3, - Iterator4 begin4, Iterator4 end4, KeyFn4 keyFn4) - { - return {begin0, end0, keyFn0, - begin1, end1, keyFn1, - begin2, end2, keyFn2, - begin3, end3, keyFn3, - begin4, end4, keyFn4}; + Iterator end() const { + return Iterator(end0_, end0_, keyFn0_, end1_, end1_, 
keyFn1_, end2_, end2_, + keyFn2_, end3_, end3_, keyFn3_, end4_, end4_, keyFn4_); } +private: + Iterator0 begin0_; + Iterator0 end0_; + KeyFn0 keyFn0_; + + Iterator1 begin1_; + Iterator1 end1_; + KeyFn1 keyFn1_; + + Iterator2 begin2_; + Iterator2 end2_; + KeyFn2 keyFn2_; + + Iterator3 begin3_; + Iterator3 end3_; + KeyFn3 keyFn3_; + + Iterator4 begin4_; + Iterator4 end4_; + KeyFn4 keyFn4_; +}; + +template +GroupBy5 +groupBy(const Sequence0 &sequence0, KeyFn0 keyFn0, const Sequence1 &sequence1, + KeyFn1 keyFn1, const Sequence2 &sequence2, KeyFn2 keyFn2, + const Sequence3 &sequence3, KeyFn3 keyFn3, const Sequence4 &sequence4, + KeyFn4 keyFn4) { + return {sequence0.begin(), sequence0.end(), keyFn0, + sequence1.begin(), sequence1.end(), keyFn1, + sequence2.begin(), sequence2.end(), keyFn2, + sequence3.begin(), sequence3.end(), keyFn3, + sequence4.begin(), sequence4.end(), keyFn4}; +} + +template +GroupBy5 +iterGroupBy(Iterator0 begin0, Iterator0 end0, KeyFn0 keyFn0, Iterator1 begin1, + Iterator1 end1, KeyFn1 keyFn1, Iterator2 begin2, Iterator2 end2, + KeyFn2 keyFn2, Iterator3 begin3, Iterator3 end3, KeyFn3 keyFn3, + Iterator4 begin4, Iterator4 end4, KeyFn4 keyFn4) { + return {begin0, end0, keyFn0, begin1, end1, keyFn1, begin2, end2, + keyFn2, begin3, end3, keyFn3, begin4, end4, keyFn4}; +} + } // end namespace nupic #endif // NTA_GROUPBY_HPP diff --git a/src/nupic/utils/Log.hpp b/src/nupic/utils/Log.hpp index 536e223d2b..b5d842b255 100644 --- a/src/nupic/utils/Log.hpp +++ b/src/nupic/utils/Log.hpp @@ -22,53 +22,64 @@ /** * @file - * Definition of C++ macros for logging. + * Definition of C++ macros for logging. */ #ifndef NTA_LOG2_HPP #define NTA_LOG2_HPP -#include #include +#include - -#define NTA_DEBUG nupic::LogItem(__FILE__, __LINE__, nupic::LogItem::debug).stream() +#define NTA_DEBUG \ + nupic::LogItem(__FILE__, __LINE__, nupic::LogItem::debug).stream() // Can be used in Loggable classes -#define NTA_LDEBUG(level) if (logLevel_ < (level)) {} \ - else nupic::LogItem(__FILE__, __LINE__, nupic::LogItem::debug).stream() - -// For informational messages that report status but do not indicate that anything is wrong -#define NTA_INFO nupic::LogItem(__FILE__, __LINE__, nupic::LogItem::info).stream() - -// For messages that indicate a recoverable error or something else that it may be -// important for the end user to know about. -#define NTA_WARN nupic::LogItem(__FILE__, __LINE__, nupic::LogItem::warn).stream() - -// To throw an exception and make sure the exception message is logged appropriately +#define NTA_LDEBUG(level) \ + if (logLevel_ < (level)) { \ + } else \ + nupic::LogItem(__FILE__, __LINE__, nupic::LogItem::debug).stream() + +// For informational messages that report status but do not indicate that +// anything is wrong +#define NTA_INFO \ + nupic::LogItem(__FILE__, __LINE__, nupic::LogItem::info).stream() + +// For messages that indicate a recoverable error or something else that it may +// be important for the end user to know about. +#define NTA_WARN \ + nupic::LogItem(__FILE__, __LINE__, nupic::LogItem::warn).stream() + +// To throw an exception and make sure the exception message is logged +// appropriately #define NTA_THROW throw nupic::LoggingException(__FILE__, __LINE__) // The difference between CHECK and ASSERT is that ASSERT is for // performance critical code and can be disabled in a release -// build. Both throw an exception on error. +// build. Both throw an exception on error. 
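For reference, a small sketch of how these logging macros are typically invoked (hypothetical function and values; LogItem::setOutputFile is optional here and only makes the output destination explicit):

#include <iostream>

#include "nupic/utils/Log.hpp"
#include "nupic/utils/LogItem.hpp"
#include "nupic/utils/LoggingException.hpp"

using namespace nupic;

static int computeBatchSize(int requested) {
  // CHECK stays active in release builds; ASSERT may be compiled out.
  NTA_CHECK(requested > 0) << "batch size must be positive, got " << requested;
  NTA_ASSERT(requested < 1000000) << "suspiciously large batch size";
  return requested;
}

int main() {
  LogItem::setOutputFile(std::cerr); // route log output explicitly

  NTA_INFO << "starting with batch size " << computeBatchSize(32);
  NTA_WARN << "this code path is deprecated";

  try {
    computeBatchSize(-1); // fails the CHECK, which throws a LoggingException
  } catch (const LoggingException &e) {
    NTA_INFO << "caught: " << e.getMessage();
  }
  return 0;
}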
-#define NTA_CHECK(condition) if (condition) {} \ -else NTA_THROW << "CHECK FAILED: \"" << #condition << "\" " +#define NTA_CHECK(condition) \ + if (condition) { \ + } else \ + NTA_THROW << "CHECK FAILED: \"" << #condition << "\" " #ifdef NTA_ASSERTIONS_ON -#define NTA_ASSERT(condition) if (condition) {} \ -else NTA_THROW << "ASSERTION FAILED: \"" << #condition << "\" " +#define NTA_ASSERT(condition) \ + if (condition) { \ + } else \ + NTA_THROW << "ASSERTION FAILED: \"" << #condition << "\" " #else -// NTA_ASSERT macro does nothing. -// The second line should never be executed, or even compiled, but we +// NTA_ASSERT macro does nothing. +// The second line should never be executed, or even compiled, but we // need something that is syntactically compatible with NTA_ASSERT -#define NTA_ASSERT(condition) if (1) {} \ - else nupic::LogItem(__FILE__, __LINE__, nupic::LogItem::debug).stream() - -#endif // NTA_ASSERTIONS_ON +#define NTA_ASSERT(condition) \ + if (1) { \ + } else \ + nupic::LogItem(__FILE__, __LINE__, nupic::LogItem::debug).stream() +#endif // NTA_ASSERTIONS_ON #endif // NTA_LOG2_HPP diff --git a/src/nupic/utils/LogItem.cpp b/src/nupic/utils/LogItem.cpp index 0b23d62daa..22f543fd08 100644 --- a/src/nupic/utils/LogItem.cpp +++ b/src/nupic/utils/LogItem.cpp @@ -21,33 +21,26 @@ */ /** @file -* LogItem implementation -*/ - + * LogItem implementation + */ -#include +#include // cerr #include -#include // cerr +#include #include // runtime_error using namespace nupic; -std::ostream* LogItem::ostream_ = nullptr; +std::ostream *LogItem::ostream_ = nullptr; -void LogItem::setOutputFile(std::ostream& ostream) -{ - ostream_ = &ostream; -} +void LogItem::setOutputFile(std::ostream &ostream) { ostream_ = &ostream; } LogItem::LogItem(const char *filename, int line, LogLevel level) - : filename_(filename), lineno_(line), level_(level), msg_("") -{} + : filename_(filename), lineno_(line), level_(level), msg_("") {} -LogItem::~LogItem() -{ +LogItem::~LogItem() { std::string slevel; - switch(level_) - { + switch (level_) { case debug: slevel = "DEBUG:"; break; @@ -74,11 +67,6 @@ LogItem::~LogItem() (*ostream_) << " [" << filename_ << " line " << lineno_ << "]"; (*ostream_) << std::endl; - -} - -std::ostringstream& LogItem::stream() { - return msg_; } - +std::ostringstream &LogItem::stream() { return msg_; } diff --git a/src/nupic/utils/LogItem.hpp b/src/nupic/utils/LogItem.hpp index dbdf2f2a59..b8f256a645 100644 --- a/src/nupic/utils/LogItem.hpp +++ b/src/nupic/utils/LogItem.hpp @@ -21,64 +21,59 @@ */ /** @file -* LogItem interface -*/ + * LogItem interface + */ #ifndef NTA_LOG_ITEM_HPP #define NTA_LOG_ITEM_HPP -#include #include +#include namespace nupic { +/** + * @b Description + * A LogItem represents a single log entry. It contains a stream that + * accumulates a log message, and its destructor calls the logger. + * + * A LogItem contains an internal stream + * which is used for building up an application message using + * << operators. + * + */ + +class LogItem { +public: + typedef enum { debug, info, warn, error } LogLevel; /** - * @b Description - * A LogItem represents a single log entry. It contains a stream that accumulates - * a log message, and its destructor calls the logger. - * - * A LogItem contains an internal stream - * which is used for building up an application message using - * << operators. 
- * + * Record information to be logged */ + LogItem(const char *filename, int line, LogLevel level); + /** + * Destructor performs the logging + */ + virtual ~LogItem(); - class LogItem { - public: - - typedef enum {debug, info, warn, error} LogLevel; - /** - * Record information to be logged - */ - LogItem(const char *filename, int line, LogLevel level); - - /** - * Destructor performs the logging - */ - virtual ~LogItem(); - - /* - * Return the underlying stream object. Caller will use it to construct the log message. - */ - std::ostringstream& stream(); - - static void setOutputFile(std::ostream& ostream); - - - protected: - const char *filename_; // name of file - int lineno_; // line number in file - LogLevel level_; - std::ostringstream msg_; - - private: - static std::ostream* ostream_; + /* + * Return the underlying stream object. Caller will use it to construct the + * log message. + */ + std::ostringstream &stream(); - }; + static void setOutputFile(std::ostream &ostream); +protected: + const char *filename_; // name of file + int lineno_; // line number in file + LogLevel level_; + std::ostringstream msg_; -} +private: + static std::ostream *ostream_; +}; +} // namespace nupic #endif // NTA_LOG_ITEM_HPP diff --git a/src/nupic/utils/LoggingException.cpp b/src/nupic/utils/LoggingException.cpp index 39de5be8f8..5c2ce2a311 100644 --- a/src/nupic/utils/LoggingException.cpp +++ b/src/nupic/utils/LoggingException.cpp @@ -29,8 +29,7 @@ #include using namespace nupic; -LoggingException::~LoggingException() throw() -{ +LoggingException::~LoggingException() throw() { if (!alreadyLogged_) { // Let LogItem do the work for us. This code is a bit complex // because LogItem was designed to be used from a logging macro diff --git a/src/nupic/utils/LoggingException.hpp b/src/nupic/utils/LoggingException.hpp index 8cf5067afe..d31d6bcb4f 100644 --- a/src/nupic/utils/LoggingException.hpp +++ b/src/nupic/utils/LoggingException.hpp @@ -29,73 +29,66 @@ #include #include -#include +#include -namespace nupic -{ - class LoggingException : public Exception - { - public: - LoggingException(const std::string& filename, UInt32 lineno) : - Exception(filename, lineno, std::string()), ss_(std::string()), - lmessageValid_(false), alreadyLogged_(false) - { - } +namespace nupic { +class LoggingException : public Exception { +public: + LoggingException(const std::string &filename, UInt32 lineno) + : Exception(filename, lineno, std::string()), ss_(std::string()), + lmessageValid_(false), alreadyLogged_(false) {} - virtual ~LoggingException() throw(); + virtual ~LoggingException() throw(); - const char * getMessage() const override - { - // Make sure we use a persistent string. Otherwise the pointer may - // become invalid. - // If the underlying stringstream object hasn't changed, don't regenerate lmessage_. - // This is important because if we catch this exception a second call to exception.what() - // will trash the buffer returned by a first call to exception.what() - if (! lmessageValid_) { - lmessage_ = ss_.str(); - lmessageValid_ = true; - } - return lmessage_.c_str(); + const char *getMessage() const override { + // Make sure we use a persistent string. Otherwise the pointer may + // become invalid. + // If the underlying stringstream object hasn't changed, don't regenerate + // lmessage_. 
This is important because if we catch this exception a second + // call to exception.what() will trash the buffer returned by a first call + // to exception.what() + if (!lmessageValid_) { + lmessage_ = ss_.str(); + lmessageValid_ = true; } + return lmessage_.c_str(); + } - // for Index.hpp: // because stringstream cant take << vector - LoggingException& operator<<(std::vector> v) - { - lmessageValid_ = false; - ss_ << "["; - for(auto & elem : v) - ss_ << elem << " "; - ss_ << "]"; - return *this; - } + // for Index.hpp: // because stringstream cant take << vector + LoggingException & + operator<<(std::vector> v) { + lmessageValid_ = false; + ss_ << "["; + for (auto &elem : v) + ss_ << elem << " "; + ss_ << "]"; + return *this; + } - template LoggingException& operator<<(const T& obj) - { - // underlying stringstream changes, so let getMessage() know - // to regenerate lmessage_ - lmessageValid_ = false; - ss_ << obj; - return *this; - } + template LoggingException &operator<<(const T &obj) { + // underlying stringstream changes, so let getMessage() know + // to regenerate lmessage_ + lmessageValid_ = false; + ss_ << obj; + return *this; + } - LoggingException(const LoggingException& l) : Exception(l), - ss_(l.ss_.str()), - lmessage_(""), - lmessageValid_(false), - alreadyLogged_(true) // copied exception does not log + LoggingException(const LoggingException &l) + : Exception(l), ss_(l.ss_.str()), lmessage_(""), lmessageValid_(false), + alreadyLogged_(true) // copied exception does not log - { - // make sure message string is up to date for debuggers. - getMessage(); - } + { + // make sure message string is up to date for debuggers. + getMessage(); + } - private: - std::stringstream ss_; - mutable std::string lmessage_; // mutable because getMesssage() modifies it - mutable bool lmessageValid_; - bool alreadyLogged_; - }; // class LoggingException +private: + std::stringstream ss_; + mutable std::string lmessage_; // mutable because getMesssage() modifies it + mutable bool lmessageValid_; + bool alreadyLogged_; +}; // class LoggingException -} +} // namespace nupic #endif // NTA_LOGGING_EXCEPTION_HPP diff --git a/src/nupic/utils/MovingAverage.cpp b/src/nupic/utils/MovingAverage.cpp index e4623956a0..b02e790c36 100644 --- a/src/nupic/utils/MovingAverage.cpp +++ b/src/nupic/utils/MovingAverage.cpp @@ -28,31 +28,22 @@ #include using namespace std; -using namespace::nupic; +using namespace ::nupic; using namespace nupic::util; - -MovingAverage::MovingAverage(UInt wSize, const vector& historicalValues) - : windowSize_(wSize) -{ - if (historicalValues.size() != 0) - { - copy( - historicalValues.begin() + historicalValues.size() - wSize, - historicalValues.end(), - back_inserter(slidingWindow_)); +MovingAverage::MovingAverage(UInt wSize, const vector &historicalValues) + : windowSize_(wSize) { + if (historicalValues.size() != 0) { + copy(historicalValues.begin() + historicalValues.size() - wSize, + historicalValues.end(), back_inserter(slidingWindow_)); } total_ = Real32(accumulate(slidingWindow_.begin(), slidingWindow_.end(), 0)); } - MovingAverage::MovingAverage(UInt wSize) : windowSize_(wSize), total_(0) {} - -Real32 MovingAverage::compute(Real32 newVal) -{ - if (windowSize_ == slidingWindow_.size()) - { +Real32 MovingAverage::compute(Real32 newVal) { + if (windowSize_ == slidingWindow_.size()) { total_ -= slidingWindow_.front(); slidingWindow_.erase(slidingWindow_.begin()); // pop front element } @@ -62,34 +53,21 @@ Real32 MovingAverage::compute(Real32 newVal) return getCurrentAvg(); } - 
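For reference, a brief sketch of the MovingAverage semantics (hypothetical sample values): compute() appends a sample to the fixed-size window, evicts the oldest sample once windowSize is reached, and returns the running mean.

#include <iostream>
#include <vector>

#include "nupic/utils/MovingAverage.hpp"

using namespace nupic;
using namespace nupic::util;

int main() {
  MovingAverage avg(3); // window of the three most recent samples

  std::vector<Real32> samples = {1.0f, 2.0f, 3.0f, 10.0f};
  for (Real32 s : samples) {
    std::cout << "after " << s << ": avg = " << avg.compute(s) << std::endl;
  }

  // Once the window is full, only {2, 3, 10} contribute: (2 + 3 + 10) / 3 = 5.
  std::cout << "current = " << avg.getCurrentAvg()
            << ", total = " << avg.getTotal() << std::endl;
  return 0;
}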
-std::vector MovingAverage::getSlidingWindow() const -{ +std::vector MovingAverage::getSlidingWindow() const { return slidingWindow_; } - -Real32 MovingAverage::getCurrentAvg() const -{ +Real32 MovingAverage::getCurrentAvg() const { return Real32(total_) / Real32(slidingWindow_.size()); } - -bool MovingAverage::operator==(const MovingAverage& r2) const -{ +bool MovingAverage::operator==(const MovingAverage &r2) const { return (windowSize_ == r2.windowSize_ && - slidingWindow_ == r2.slidingWindow_ && - total_ == r2.total_); + slidingWindow_ == r2.slidingWindow_ && total_ == r2.total_); } - -bool MovingAverage::operator!=(const MovingAverage& r2) const -{ +bool MovingAverage::operator!=(const MovingAverage &r2) const { return !operator==(r2); } - -Real32 MovingAverage::getTotal() const -{ - return total_; -} +Real32 MovingAverage::getTotal() const { return total_; } diff --git a/src/nupic/utils/MovingAverage.hpp b/src/nupic/utils/MovingAverage.hpp index 25bb649058..1bea68331f 100644 --- a/src/nupic/utils/MovingAverage.hpp +++ b/src/nupic/utils/MovingAverage.hpp @@ -27,30 +27,27 @@ #include - -namespace nupic -{ - - namespace util - { - - class MovingAverage - { - public: - MovingAverage(UInt wSize, const std::vector& historicalValues); - MovingAverage(UInt wSize); - std::vector getSlidingWindow() const; - Real32 getCurrentAvg() const; - Real32 compute(Real32 newValue); - Real32 getTotal() const; - bool operator==(const MovingAverage& r2) const; - bool operator!=(const MovingAverage& r2) const; - private: - UInt32 windowSize_; - std::vector slidingWindow_; - Real32 total_; - }; - } -} +namespace nupic { + +namespace util { + +class MovingAverage { +public: + MovingAverage(UInt wSize, const std::vector &historicalValues); + MovingAverage(UInt wSize); + std::vector getSlidingWindow() const; + Real32 getCurrentAvg() const; + Real32 compute(Real32 newValue); + Real32 getTotal() const; + bool operator==(const MovingAverage &r2) const; + bool operator!=(const MovingAverage &r2) const; + +private: + UInt32 windowSize_; + std::vector slidingWindow_; + Real32 total_; +}; +} // namespace util +} // namespace nupic #endif // NUPIC_UTIL_MOVING_AVERAGE_HPP diff --git a/src/nupic/utils/Random.cpp b/src/nupic/utils/Random.cpp index fc85f2c6d7..d8ed11ed76 100644 --- a/src/nupic/utils/Random.cpp +++ b/src/nupic/utils/Random.cpp @@ -24,9 +24,9 @@ Random Number Generator implementation */ +#include // For ldexp. #include #include -#include // For ldexp. #include // for istream, ostream #include @@ -39,15 +39,13 @@ #include using namespace nupic; -Random* Random::theInstanceP_ = nullptr; +Random *Random::theInstanceP_ = nullptr; RandomSeedFuncPtr Random::seeder_ = nullptr; const UInt32 Random::MAX32 = (UInt32)((Int32)(-1)); const UInt64 Random::MAX64 = (UInt64)((Int64)(-1)); - -static NTA_UInt64 badSeeder() -{ +static NTA_UInt64 badSeeder() { NTA_THROW << "Logic error in initialization of Random subsystem."; return 0; } @@ -62,41 +60,36 @@ static NTA_UInt64 badSeeder() // When we have different algorithms RandomImpl will become an interface // class and subclasses will implement specific algorithms -namespace nupic -{ - class RandomImpl - { - public: - RandomImpl(UInt64 seed); - ~RandomImpl() {}; - void write(RandomImplProto::Builder& proto) const; - void read(RandomImplProto::Reader& proto); - UInt32 getUInt32(); - // Note: copy constructor and operator= are needed - // The default is ok. 
- private: - friend std::ostream& operator<<(std::ostream& outStream, const RandomImpl& r); - friend std::istream& operator>>(std::istream& inStream, RandomImpl& r); - const static UInt32 VERSION = 2; - // internal state - static const int stateSize_ = 31; - static const int sep_ = 3; - UInt32 state_[stateSize_]; - int rptr_; - int fptr_; - - }; +namespace nupic { +class RandomImpl { +public: + RandomImpl(UInt64 seed); + ~RandomImpl(){}; + void write(RandomImplProto::Builder &proto) const; + void read(RandomImplProto::Reader &proto); + UInt32 getUInt32(); + // Note: copy constructor and operator= are needed + // The default is ok. +private: + friend std::ostream &operator<<(std::ostream &outStream, const RandomImpl &r); + friend std::istream &operator>>(std::istream &inStream, RandomImpl &r); + const static UInt32 VERSION = 2; + // internal state + static const int stateSize_ = 31; + static const int sep_ = 3; + UInt32 state_[stateSize_]; + int rptr_; + int fptr_; }; +}; // namespace nupic -Random::Random(const Random& r) -{ +Random::Random(const Random &r) { NTA_CHECK(r.impl_ != nullptr); seed_ = r.seed_; impl_ = new RandomImpl(*r.impl_); } -void Random::write(RandomProto::Builder& proto) const -{ +void Random::write(RandomProto::Builder &proto) const { // save Random state proto.setSeed(seed_); @@ -105,8 +98,7 @@ void Random::write(RandomProto::Builder& proto) const impl_->write(implProto); } -void Random::read(RandomProto::Reader& proto) -{ +void Random::read(RandomProto::Reader &proto) { // load Random state seed_ = proto.getSeed(); @@ -115,19 +107,15 @@ void Random::read(RandomProto::Reader& proto) impl_->read(implProto); } -void Random::reseed(UInt64 seed) -{ +void Random::reseed(UInt64 seed) { seed_ = seed; if (impl_) delete impl_; impl_ = new RandomImpl(seed); } - -Random& Random::operator=(const Random& other) -{ - if (this != &other) - { +Random &Random::operator=(const Random &other) { + if (this != &other) { seed_ = other.seed_; if (impl_) delete impl_; @@ -137,14 +125,9 @@ Random& Random::operator=(const Random& other) return *this; } -Random::~Random() -{ - delete impl_; -} - +Random::~Random() { delete impl_; } -Random::Random(UInt64 seed) -{ +Random::Random(UInt64 seed) { // Get the seeder even if we don't need it, because // this will have the side effect of allocating the // singleton if necessary. 
The singleton will actuallly @@ -167,11 +150,8 @@ Random::Random(UInt64 seed) impl_ = new RandomImpl(seed_); } - -RandomSeedFuncPtr Random::getSeeder() -{ - if (seeder_ == nullptr) - { +RandomSeedFuncPtr Random::getSeeder() { + if (seeder_ == nullptr) { NTA_CHECK(theInstanceP_ == nullptr); // set the seeder to something not NULL // so the constructor below will not @@ -183,24 +163,19 @@ RandomSeedFuncPtr Random::getSeeder() return seeder_; } -void Random::initSeeder(const RandomSeedFuncPtr r) -{ +void Random::initSeeder(const RandomSeedFuncPtr r) { NTA_CHECK(r != nullptr); seeder_ = r; } - -void Random::shutdown() -{ - if (theInstanceP_ != nullptr) - { +void Random::shutdown() { + if (theInstanceP_ != nullptr) { delete theInstanceP_; theInstanceP_ = nullptr; } } -UInt32 Random::getUInt32(const UInt32 max) -{ +UInt32 Random::getUInt32(const UInt32 max) { NTA_ASSERT(max > 0); UInt32 smax = Random::MAX32 - (Random::MAX32 % max); UInt32 sample; @@ -208,12 +183,12 @@ UInt32 Random::getUInt32(const UInt32 max) sample = impl_->getUInt32(); } while (sample > smax); - // NTA_WARN << "Random32(" << max << ") -> " << sample % max << " smax = " << smax; + // NTA_WARN << "Random32(" << max << ") -> " << sample % max << " smax = " << + // smax; return sample % max; } -UInt64 Random::getUInt64(const UInt64 max) -{ +UInt64 Random::getUInt64(const UInt64 max) { NTA_ASSERT(max > 0); UInt64 smax = Random::MAX64 - (Random::MAX64 % max); UInt64 sample, lo, hi; @@ -221,39 +196,36 @@ UInt64 Random::getUInt64(const UInt64 max) lo = impl_->getUInt32(); hi = impl_->getUInt32(); sample = lo | (hi << 32); - } while(sample > smax); - // NTA_WARN << "Random64(" << max << ") -> " << sample % max << " smax = " << smax; + } while (sample > smax); + // NTA_WARN << "Random64(" << max << ") -> " << sample % max << " smax = " << + // smax; return sample % max; } -double Random::getReal64() -{ +double Random::getReal64() { const int mantissaBits = 48; const UInt64 max = (UInt64)0x1U << mantissaBits; UInt64 value = getUInt64(max); - Real64 dvalue = (Real64) value; // No loss because we only need the 48 mantissa bits. + Real64 dvalue = + (Real64)value; // No loss because we only need the 48 mantissa bits. Real64 returnval = ::ldexp(dvalue, -mantissaBits); // NTA_WARN << "RandomReal -> " << returnval; return returnval; } - // ---- RandomImpl follows ---- - - - -UInt32 RandomImpl::getUInt32(void) -{ +UInt32 RandomImpl::getUInt32(void) { UInt32 i; #ifdef RANDOM_SUPERDEBUG - printf("Random::get *fptr = %ld; *rptr = %ld fptr = %ld rptr = %ld\n", state_[fptr_], state_[rptr_], fptr_, rptr_); + printf("Random::get *fptr = %ld; *rptr = %ld fptr = %ld rptr = %ld\n", + state_[fptr_], state_[rptr_], fptr_, rptr_); #endif - state_[fptr_] = (UInt32)( - ((UInt64)state_[fptr_] + (UInt64)state_[rptr_]) % Random::MAX32); + state_[fptr_] = + (UInt32)(((UInt64)state_[fptr_] + (UInt64)state_[rptr_]) % Random::MAX32); i = state_[fptr_]; - i = (i >> 1) & 0x7fffffff; /* chucking least random bit */ + i = (i >> 1) & 0x7fffffff; /* chucking least random bit */ if (++fptr_ >= stateSize_) { fptr_ = 0; ++rptr_; @@ -269,10 +241,7 @@ UInt32 RandomImpl::getUInt32(void) return i; } - - -RandomImpl::RandomImpl(UInt64 seed) -{ +RandomImpl::RandomImpl(UInt64 seed) { /** * Initialize our state. 
Taken from BSD source for random() @@ -286,8 +255,8 @@ RandomImpl::RandomImpl(UInt64 seed) * * 2^31-1 (prime) = 2147483647 = 127773*16807+2836 */ - Int32 quot = state_[i-1] / 127773; - Int32 rem = state_[i-1] % 127773; + Int32 quot = state_[i - 1] / 127773; + Int32 rem = state_[i - 1] % 127773; Int32 test = 16807 * rem - 2836 * quot; state_[i] = (UInt32)((test + (test < 0 ? 2147483647 : 0)) % Random::MAX32); } @@ -304,152 +273,119 @@ RandomImpl::RandomImpl(UInt64 seed) (void)getUInt32(); #ifdef RANDOM_SUPERDEBUG printf("Random: after init for seed = %lu\n", seed); - printf("Random: *fptr = %ld; *rptr = %ld fptr = %ld rptr = %ld\n", state_[fptr_], state_[rptr_], fptr_, rptr_); + printf("Random: *fptr = %ld; *rptr = %ld fptr = %ld rptr = %ld\n", + state_[fptr_], state_[rptr_], fptr_, rptr_); for (long i = 0; i < stateSize_; i++) { printf("Random: %d %ld\n", i, state_[i]); } #endif } - -void RandomImpl::write(RandomImplProto::Builder& proto) const -{ +void RandomImpl::write(RandomImplProto::Builder &proto) const { auto state = proto.initState(stateSize_); - for (UInt i = 0; i < stateSize_; ++i) - { + for (UInt i = 0; i < stateSize_; ++i) { state.set(i, state_[i]); } proto.setRptr(rptr_); proto.setFptr(fptr_); } - -void RandomImpl::read(RandomImplProto::Reader& proto) -{ +void RandomImpl::read(RandomImplProto::Reader &proto) { auto state = proto.getState(); - for (UInt i = 0; i < state.size(); ++i) - { + for (UInt i = 0; i < state.size(); ++i) { state_[i] = state[i]; } rptr_ = proto.getRptr(); fptr_ = proto.getFptr(); } +namespace nupic { +std::ostream &operator<<(std::ostream &outStream, const Random &r) { + outStream << "random-v1 "; + outStream << r.seed_ << " "; + NTA_CHECK(r.impl_ != nullptr); + outStream << *r.impl_; + outStream << " endrandom-v1"; + return outStream; +} + +std::istream &operator>>(std::istream &inStream, Random &r) { + std::string version; -namespace nupic -{ - std::ostream& operator<<(std::ostream& outStream, const Random& r) - { - outStream << "random-v1 "; - outStream << r.seed_ << " "; - NTA_CHECK(r.impl_ != nullptr); - outStream << *r.impl_; - outStream << " endrandom-v1"; - return outStream; + inStream >> version; + if (version != "random-v1") { + NTA_THROW << "Random() deserializer -- found unexpected version string '" + << version << "'"; } + inStream >> r.seed_; + if (!r.impl_) + r.impl_ = new RandomImpl(0); + inStream >> *r.impl_; - std::istream& operator>>(std::istream& inStream, Random& r) - { - std::string version; + std::string endtag; + inStream >> endtag; + if (endtag != "endrandom-v1") { + NTA_THROW << "Random() deserializer -- found unexpected end tag '" << endtag + << "'"; + } - inStream >> version; - if (version != "random-v1") - { - NTA_THROW << "Random() deserializer -- found unexpected version string '" - << version << "'"; - } - inStream >> r.seed_; - if (! 
r.impl_) - r.impl_ = new RandomImpl(0); - - inStream >> *r.impl_; - - std::string endtag; - inStream >> endtag; - if (endtag != "endrandom-v1") - { - NTA_THROW << "Random() deserializer -- found unexpected end tag '" - << endtag << "'"; - } + return inStream; +} - return inStream; +std::ostream &operator<<(std::ostream &outStream, const RandomImpl &r) { + outStream << "RandomImpl " << RandomImpl::VERSION << " "; + outStream << RandomImpl::stateSize_ << " "; + for (auto &elem : r.state_) { + outStream << elem << " "; } + outStream << r.rptr_ << " "; + outStream << r.fptr_; + return outStream; +} - std::ostream& operator<<(std::ostream& outStream, const RandomImpl& r) - { - outStream << "RandomImpl " << RandomImpl::VERSION << " "; - outStream << RandomImpl::stateSize_ << " "; - for (auto & elem : r.state_) - { - outStream << elem << " "; +std::istream &operator>>(std::istream &inStream, RandomImpl &r) { + std::string marker; + inStream >> marker; + UInt32 version; + if (marker == "RandomImpl") { + inStream >> version; + if (version != 2) { + NTA_THROW << "RandomImpl deserialization found unexpected version: " + << version; } - outStream << r.rptr_ << " "; - outStream << r.fptr_; - return outStream; + } else if (marker == "randomimpl-v1") { + version = 1; + } else { + NTA_THROW << "RandomImpl() deserializer -- found unexpected version " + << "string '" << marker << "'"; } - - std::istream& operator>>(std::istream& inStream, RandomImpl& r) - { - std::string marker; - inStream >> marker; - UInt32 version; - if (marker == "RandomImpl") - { - inStream >> version; - if (version != 2) - { - NTA_THROW << "RandomImpl deserialization found unexpected version: " - << version; - } - } - else if (marker == "randomimpl-v1") - { - version = 1; - } - else - { - NTA_THROW << "RandomImpl() deserializer -- found unexpected version " - << "string '" << marker << "'"; - } - UInt32 ss = 0; - inStream >> ss; - NTA_CHECK(ss == (UInt32)RandomImpl::stateSize_) << " ss = " << ss; - - int tmp; - for (auto & elem : r.state_) - { - if (version < 2) - { - inStream >> tmp; - elem = (UInt32)tmp; - } - else - { - inStream >> elem; - } + UInt32 ss = 0; + inStream >> ss; + NTA_CHECK(ss == (UInt32)RandomImpl::stateSize_) << " ss = " << ss; + + int tmp; + for (auto &elem : r.state_) { + if (version < 2) { + inStream >> tmp; + elem = (UInt32)tmp; + } else { + inStream >> elem; } - inStream >> r.rptr_; - inStream >> r.fptr_; - return inStream; } + inStream >> r.rptr_; + inStream >> r.fptr_; + return inStream; +} - // helper function for seeding RNGs across the plugin barrier - // Unless there is a logic error, should not be called if - // the Random singleton has not been initialized. - NTA_UInt64 GetRandomSeed() - { - Random* r = nupic::Random::theInstanceP_; - NTA_CHECK(r != nullptr); - NTA_UInt64 result = r->getUInt64(); - return result; - } - - +// helper function for seeding RNGs across the plugin barrier +// Unless there is a logic error, should not be called if +// the Random singleton has not been initialized. 
+NTA_UInt64 GetRandomSeed() { + Random *r = nupic::Random::theInstanceP_; + NTA_CHECK(r != nullptr); + NTA_UInt64 result = r->getUInt64(); + return result; +} } // namespace nupic - - - - - diff --git a/src/nupic/utils/Random.hpp b/src/nupic/utils/Random.hpp index b8717dd3a8..7b932567cc 100644 --- a/src/nupic/utils/Random.hpp +++ b/src/nupic/utils/Random.hpp @@ -30,8 +30,8 @@ #include #include #include -#include #include +#include #include #include @@ -41,186 +41,174 @@ typedef NTA_UInt64 (*RandomSeedFuncPtr)(); namespace nupic { +/** + * @b Responsibility + * Provides standardized random number generation for the NuPIC Runtime Engine. + * Seed can be logged in one run and then set in another. + * @b Rationale + * Makes it possible to reproduce tests that are driven by random number + * generation. + * + * @b Description + * Functionality is similar to the standard random() function that is provided + * by C. + * + * Each Random object is a random number generator. There are three ways of + * creating one: + * 1) explicit seed + * Random rng(seed); + * 2) self-seeded + * Random rng; + * 3) named generator -- normally self-seeded, but seed may be + * set explicitly through an environment variable + * Random rng("level2TP"); + * If NTA_RANDOM_DEBUG is set, this object will log its self-seed + * The seed can be explicitly set through NTA_RANDOM_SEED_level2TP + * + * Good self-seeds are generated by an internal global random number generator. + * This global rng is seeded from the current time, but its seed may be + * overridden with NTA_RANDOM_SEED + * + * Automated tests that use random numbers should normally use named generators. + * This allows them to get a different seed each time, but also allows + * reproducibility in the case that a test failure is triggered by a particular + * seed. + * + * Random should not be used if cryptographic strength is required (e.g. for + * generating a challenge in an authentication scheme). + * + * @todo Add ability to specify different rng algorithms. + */ +class RandomImpl; + +class Random : public Serializable { +public: /** - * @b Responsibility - * Provides standardized random number generation for the NuPIC Runtime Engine. - * Seed can be logged in one run and then set in another. - * @b Rationale - * Makes it possible to reproduce tests that are driven by random number generation. - * - * @b Description - * Functionality is similar to the standard random() function that is provided by C. - * - * Each Random object is a random number generator. There are three ways of - * creating one: - * 1) explicit seed - * Random rng(seed); - * 2) self-seeded - * Random rng; - * 3) named generator -- normally self-seeded, but seed may be - * set explicitly through an environment variable - * Random rng("level2TP"); - * If NTA_RANDOM_DEBUG is set, this object will log its self-seed - * The seed can be explicitly set through NTA_RANDOM_SEED_level2TP - * - * Good self-seeds are generated by an internal global random number generator. - * This global rng is seeded from the current time, but its seed may be - * overridden with NTA_RANDOM_SEED - * - * Automated tests that use random numbers should normally use named generators. - * This allows them to get a different seed each time, but also allows reproducibility - * in the case that a test failure is triggered by a particular seed. - * - * Random should not be used if cryptographic strength is required (e.g. for - * generating a challenge in an authentication scheme). 
- * - * @todo Add ability to specify different rng algorithms. + * Retrieve the seeder. If seeder not set, allocates the + * singleton and and initializes the seeder. */ - class RandomImpl; - - class Random : public Serializable - { - public: - /** - * Retrieve the seeder. If seeder not set, allocates the - * singleton and and initializes the seeder. - */ - static RandomSeedFuncPtr getSeeder(); - - Random(UInt64 seed = 0); - - // support copy constructor and operator= -- these require non-default - // implementations because of the impl_ pointer. - // They do a deep copy of impl_ so that an RNG and its copy generate the - // same set of numbers. - Random(const Random&); - Random& operator=(const Random&); - ~Random(); - - // write serialized data - using Serializable::write; - void write(RandomProto::Builder& proto) const override; - - // read and deserialize data - using Serializable::read; - void read(RandomProto::Reader& proto) override; - - // return a value uniformly distributed between 0 and max-1 - UInt32 getUInt32(UInt32 max = MAX32); - UInt64 getUInt64(UInt64 max = MAX64); - // return a double uniformly distributed on 0...1.0 - Real64 getReal64(); - - // populate choices with a random selection of nChoices elements from - // population. throws exception when nPopulation < nChoices - // templated functions must be defined in header - template - void sample(T population[], UInt32 nPopulation, - T choices[], UInt32 nChoices) - { - if (nChoices == 0) - { - return; - } - if (nChoices > nPopulation) - { - NTA_THROW << "population size must be greater than number of choices"; - } - UInt32 nextChoice = 0; - for (UInt32 i = 0; i < nPopulation; ++i) - { - if (getUInt32(nPopulation - i) < (nChoices - nextChoice)) - { - choices[nextChoice] = population[i]; - ++nextChoice; - if (nextChoice == nChoices) - { - break; - } + static RandomSeedFuncPtr getSeeder(); + + Random(UInt64 seed = 0); + + // support copy constructor and operator= -- these require non-default + // implementations because of the impl_ pointer. + // They do a deep copy of impl_ so that an RNG and its copy generate the + // same set of numbers. + Random(const Random &); + Random &operator=(const Random &); + ~Random(); + + // write serialized data + using Serializable::write; + void write(RandomProto::Builder &proto) const override; + + // read and deserialize data + using Serializable::read; + void read(RandomProto::Reader &proto) override; + + // return a value uniformly distributed between 0 and max-1 + UInt32 getUInt32(UInt32 max = MAX32); + UInt64 getUInt64(UInt64 max = MAX64); + // return a double uniformly distributed on 0...1.0 + Real64 getReal64(); + + // populate choices with a random selection of nChoices elements from + // population. 
throws exception when nPopulation < nChoices + // templated functions must be defined in header + template + void sample(T population[], UInt32 nPopulation, T choices[], + UInt32 nChoices) { + if (nChoices == 0) { + return; + } + if (nChoices > nPopulation) { + NTA_THROW << "population size must be greater than number of choices"; + } + UInt32 nextChoice = 0; + for (UInt32 i = 0; i < nPopulation; ++i) { + if (getUInt32(nPopulation - i) < (nChoices - nextChoice)) { + choices[nextChoice] = population[i]; + ++nextChoice; + if (nextChoice == nChoices) { + break; } } } - - // randomly shuffle the elements - template - void shuffle(RandomAccessIterator first, RandomAccessIterator last) - { - UInt n = last - first; - while (first != last) - { - // Pick a random position between the current and the end to swap the - // current element with. - UInt i = getUInt32(n); - std::swap(*first, *(first + i)); - - // Move to the next element and decrement the number of possible - // positions remaining. - first++; - n--; - } + } + + // randomly shuffle the elements + template + void shuffle(RandomAccessIterator first, RandomAccessIterator last) { + UInt n = last - first; + while (first != last) { + // Pick a random position between the current and the end to swap the + // current element with. + UInt i = getUInt32(n); + std::swap(*first, *(first + i)); + + // Move to the next element and decrement the number of possible + // positions remaining. + first++; + n--; } + } - // for STL compatibility - UInt32 operator()(UInt32 n = MAX32) { return getUInt32(n); } - - // normally used for debugging only - UInt64 getSeed() {return seed_;} + // for STL compatibility + UInt32 operator()(UInt32 n = MAX32) { return getUInt32(n); } - // for STL - typedef UInt32 argument_type; - typedef UInt32 result_type; + // normally used for debugging only + UInt64 getSeed() { return seed_; } - result_type max() { return MAX32; } - result_type min() { return 0; } + // for STL + typedef UInt32 argument_type; + typedef UInt32 result_type; - static const UInt32 MAX32; - static const UInt64 MAX64; + result_type max() { return MAX32; } + result_type min() { return 0; } - // called by the plugin framework so that plugins - // get the "global" seeder - static void initSeeder(const RandomSeedFuncPtr r); + static const UInt32 MAX32; + static const UInt64 MAX64; - static void shutdown(); + // called by the plugin framework so that plugins + // get the "global" seeder + static void initSeeder(const RandomSeedFuncPtr r); - protected: + static void shutdown(); - // each "universe" (application/plugin/python module) has its own instance, - // but the instance should be NULL in all but one - static Random *theInstanceP_; - // seeder_ is a function called by the constructor to get new random seeds - // If not set when we call Random constructor, then the singleton is allocated - // and seeder_ is set to a function that uses our singleton - // initFromPlatformServices can also be used to initialize the seeder_ - static RandomSeedFuncPtr seeder_; +protected: + // each "universe" (application/plugin/python module) has its own instance, + // but the instance should be NULL in all but one + static Random *theInstanceP_; + // seeder_ is a function called by the constructor to get new random seeds + // If not set when we call Random constructor, then the singleton is allocated + // and seeder_ is set to a function that uses our singleton + // initFromPlatformServices can also be used to initialize the seeder_ + static RandomSeedFuncPtr seeder_; - void 
reseed(UInt64 seed); + void reseed(UInt64 seed); - RandomImpl *impl_; - UInt64 seed_; + RandomImpl *impl_; + UInt64 seed_; - friend class RandomTest; - friend std::ostream& operator<<(std::ostream&, const Random&); - friend std::istream& operator>>(std::istream&, Random&); - friend NTA_UInt64 GetRandomSeed(); + friend class RandomTest; + friend std::ostream &operator<<(std::ostream &, const Random &); + friend std::istream &operator>>(std::istream &, Random &); + friend NTA_UInt64 GetRandomSeed(); +}; - }; - - // serialization/deserialization - std::ostream& operator<<(std::ostream&, const Random&); - std::istream& operator>>(std::istream&, Random&); - - // This function returns seeds from the Random singleton in our - // "universe" (application, plugin, python module). If, when the - // Random constructor is called, seeder_ is NULL, then seeder_ is - // set to this function. The plugin framework can override this - // behavior by explicitly setting the seeder to the RandomSeeder - // function provided by the application. - NTA_UInt64 GetRandomSeed(); +// serialization/deserialization +std::ostream &operator<<(std::ostream &, const Random &); +std::istream &operator>>(std::istream &, Random &); +// This function returns seeds from the Random singleton in our +// "universe" (application, plugin, python module). If, when the +// Random constructor is called, seeder_ is NULL, then seeder_ is +// set to this function. The plugin framework can override this +// behavior by explicitly setting the seeder to the RandomSeeder +// function provided by the application. +NTA_UInt64 GetRandomSeed(); } // namespace nupic - - #endif // NTA_RANDOM_HPP - diff --git a/src/nupic/utils/StringUtils.cpp b/src/nupic/utils/StringUtils.cpp index 393c59bdd9..3e022f1841 100644 --- a/src/nupic/utils/StringUtils.cpp +++ b/src/nupic/utils/StringUtils.cpp @@ -20,18 +20,17 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Implementation of utility functions for string conversion */ -#include -#include #include +#include +#include using namespace nupic; -bool StringUtils::toBool(const std::string& s, bool throwOnError, bool * fail) -{ +bool StringUtils::toBool(const std::string &s, bool throwOnError, bool *fail) { if (fail) *fail = false; bool b = false; @@ -41,31 +40,27 @@ bool StringUtils::toBool(const std::string& s, bool throwOnError, bool * fail) b = true; } else if (us == "false" || us == "no" || us == "0") { b = false; - } else if (! 
throwOnError) { + } else if (!throwOnError) { if (fail) *fail = true; } else { - NTA_THROW << "StringUtils::toBool: tried to parse non-boolean string \"" << s << "\""; + NTA_THROW << "StringUtils::toBool: tried to parse non-boolean string \"" + << s << "\""; } return b; } - -Real32 StringUtils::toReal32(const std::string& s, bool throwOnError, bool * fail) -{ +Real32 StringUtils::toReal32(const std::string &s, bool throwOnError, + bool *fail) { if (fail) *fail = false; Real32 r; std::istringstream ss(s); ss >> r; - if (ss.fail() || !ss.eof()) - { - if (throwOnError) - { + if (ss.fail() || !ss.eof()) { + if (throwOnError) { NTA_THROW << "StringUtils::toReal32 -- invalid string \"" << s << "\""; - } - else - { + } else { if (fail) *fail = true; } @@ -74,21 +69,17 @@ Real32 StringUtils::toReal32(const std::string& s, bool throwOnError, bool * fai return r; } -UInt32 StringUtils::toUInt32(const std::string& s, bool throwOnError, bool * fail) -{ +UInt32 StringUtils::toUInt32(const std::string &s, bool throwOnError, + bool *fail) { if (fail) *fail = false; UInt32 i; std::istringstream ss(s); ss >> i; - if (ss.fail() || !ss.eof()) - { - if (throwOnError) - { + if (ss.fail() || !ss.eof()) { + if (throwOnError) { NTA_THROW << "StringUtils::toInt -- invalid string \"" << s << "\""; - } - else - { + } else { if (fail) *fail = true; } @@ -97,21 +88,17 @@ UInt32 StringUtils::toUInt32(const std::string& s, bool throwOnError, bool * fai return i; } -Int32 StringUtils::toInt32(const std::string& s, bool throwOnError, bool * fail) -{ +Int32 StringUtils::toInt32(const std::string &s, bool throwOnError, + bool *fail) { if (fail) *fail = false; Int32 i; std::istringstream ss(s); ss >> i; - if (ss.fail() || !ss.eof()) - { - if (throwOnError) - { + if (ss.fail() || !ss.eof()) { + if (throwOnError) { NTA_THROW << "StringUtils::toInt -- invalid string \"" << s << "\""; - } - else - { + } else { if (fail) *fail = true; } @@ -120,21 +107,17 @@ Int32 StringUtils::toInt32(const std::string& s, bool throwOnError, bool * fail) return i; } -UInt64 StringUtils::toUInt64(const std::string& s, bool throwOnError, bool * fail) -{ +UInt64 StringUtils::toUInt64(const std::string &s, bool throwOnError, + bool *fail) { if (fail) *fail = false; UInt64 i; std::istringstream ss(s); ss >> i; - if (ss.fail() || !ss.eof()) - { - if (throwOnError) - { + if (ss.fail() || !ss.eof()) { + if (throwOnError) { NTA_THROW << "StringUtils::toInt -- invalid string \"" << s << "\""; - } - else - { + } else { if (fail) *fail = true; } @@ -143,36 +126,29 @@ UInt64 StringUtils::toUInt64(const std::string& s, bool throwOnError, bool * fai return i; } - -size_t StringUtils::toSizeT(const std::string& s, bool throwOnError, bool * fail) -{ +size_t StringUtils::toSizeT(const std::string &s, bool throwOnError, + bool *fail) { if (fail) *fail = false; size_t i; std::istringstream ss(s); ss >> i; - if (ss.fail() || !ss.eof()) - { - if (throwOnError) - { + if (ss.fail() || !ss.eof()) { + if (throwOnError) { NTA_THROW << "StringUtils::toSizeT -- invalid string \"" << s << "\""; - } - else - { + } else { if (fail) *fail = true; } - } + } return i; } -bool StringUtils::startsWith(const std::string& s, const std::string& prefix) -{ +bool StringUtils::startsWith(const std::string &s, const std::string &prefix) { return s.find(prefix) == 0; } -bool StringUtils::endsWith(const std::string& s, const std::string& ending) -{ +bool StringUtils::endsWith(const std::string &s, const std::string &ending) { if (ending.size() > s.size()) return false; size_t found = 
s.rfind(ending); @@ -183,102 +159,91 @@ bool StringUtils::endsWith(const std::string& s, const std::string& ending) return true; } - -std::string StringUtils::fromInt(long long i) -{ +std::string StringUtils::fromInt(long long i) { std::stringstream ss; ss << i; return ss.str(); } -std::string StringUtils::base64Encode(const void* buf, Size inLen) -{ +std::string StringUtils::base64Encode(const void *buf, Size inLen) { Size len = apr_base64_encode_len((int)inLen); // int-casting for win. std::string outS(len, '\0'); - apr_base64_encode((char*)outS.data(), (const char*)buf, (int) inLen); // int-casting for win. - outS.resize(len-1); // len includes the NULL at the end + apr_base64_encode((char *)outS.data(), (const char *)buf, + (int)inLen); // int-casting for win. + outS.resize(len - 1); // len includes the NULL at the end return outS; } - -std::string StringUtils::base64Encode(const std::string& s) -{ - Size len = apr_base64_encode_len ( (int) s.size() ); +std::string StringUtils::base64Encode(const std::string &s) { + Size len = apr_base64_encode_len((int)s.size()); std::string outS(len, '\0'); - apr_base64_encode((char*)outS.data(), s.data(), (int) s.size()); - outS.resize(len-1); // len includes the NULL at the end + apr_base64_encode((char *)outS.data(), s.data(), (int)s.size()); + outS.resize(len - 1); // len includes the NULL at the end return outS; } -std::string StringUtils::base64Decode(const void* buf, Size inLen) -{ - std::string outS(inLen+1, '\0'); - size_t decodedLen = apr_base64_decode_binary ((unsigned char*)outS.data(), (const char*)buf); +std::string StringUtils::base64Decode(const void *buf, Size inLen) { + std::string outS(inLen + 1, '\0'); + size_t decodedLen = + apr_base64_decode_binary((unsigned char *)outS.data(), (const char *)buf); outS.resize(decodedLen); return outS; } - -std::string StringUtils::base64Decode(const std::string& s) -{ - std::string outS(s.size()+1, '\0'); - size_t decodedLen = apr_base64_decode_binary ((unsigned char*)outS.data(), s.c_str()); +std::string StringUtils::base64Decode(const std::string &s) { + std::string outS(s.size() + 1, '\0'); + size_t decodedLen = + apr_base64_decode_binary((unsigned char *)outS.data(), s.c_str()); outS.resize(decodedLen); return outS; } -#define HEXIFY(val) ((val) > 9 ? ('a' + (val) - 10) : ('0' + (val))) +#define HEXIFY(val) ((val) > 9 ? 
('a' + (val)-10) : ('0' + (val))) -std::string StringUtils::hexEncode(const void* buf, Size inLen) -{ - std::string s(inLen*2, '\0'); - const unsigned char *charbuf = (const unsigned char*)buf; - for (Size i = 0; i < inLen; i++) - { +std::string StringUtils::hexEncode(const void *buf, Size inLen) { + std::string s(inLen * 2, '\0'); + const unsigned char *charbuf = (const unsigned char *)buf; + for (Size i = 0; i < inLen; i++) { unsigned char x = charbuf[i]; // high order bits unsigned char val = x >> 4; - s[i*2] = HEXIFY(val); + s[i * 2] = HEXIFY(val); val = x & 0xF; - s[i*2+1] = HEXIFY(val); + s[i * 2 + 1] = HEXIFY(val); } return s; } - - //-------------------------------------------------------------------------------- -void StringUtils::toIntList(const std::string& s, std::vector& list, bool allowAll, - bool asRanges) -{ - if(!toIntListNoThrow(s, list, allowAll, asRanges)) { +void StringUtils::toIntList(const std::string &s, std::vector &list, + bool allowAll, bool asRanges) { + if (!toIntListNoThrow(s, list, allowAll, asRanges)) { const std::string errPrefix = "StringUtils::toIntList() - "; - throw (std::runtime_error(errPrefix+"Invalid string: " + s)); + throw(std::runtime_error(errPrefix + "Invalid string: " + s)); } } //-------------------------------------------------------------------------------- -bool StringUtils::toIntListNoThrow(const std::string& s, std::vector& list, - bool allowAll, bool asRanges) -{ - +bool StringUtils::toIntListNoThrow(const std::string &s, std::vector &list, + bool allowAll, bool asRanges) { + UInt startNum, endNum; - const char* startP = s.c_str(); - char* endP; - - // Set global errno to 0. strtoul sets this if a conversion error occurs. + const char *startP = s.c_str(); + char *endP; + + // Set global errno to 0. strtoul sets this if a conversion error occurs. errno = 0; - + // Loop through the string list.clear(); // Skip white space at start while (*startP && isspace(*startP)) startP++; - + // Do we allow all? if (allowAll) { - if (!strncmp (startP, "all", 3) && startP[3] == 0) + if (!strncmp(startP, "all", 3) && startP[3] == 0) return true; if (startP[0] == 0) return true; @@ -286,20 +251,19 @@ bool StringUtils::toIntListNoThrow(const std::string& s, std::vector& list, if (startP[0] == 0) return false; } - - while (*startP) - { + + while (*startP) { // ------------------------------------------------------------------------------ - // Get first digit + // Get first digit startNum = strtoul(startP, &endP, 10 /*base*/); if (errno != 0) return false; startP = endP; - + // Skip white space while (*startP && isspace(*startP)) startP++; - + // ------------------------------------------------------------------------------ // Do we have a '-'? 
If so, get the second number if (*startP == '-') { @@ -308,7 +272,7 @@ bool StringUtils::toIntListNoThrow(const std::string& s, std::vector& list, if (errno != 0) return false; startP = endP; - + // Store all number into the vector if (endNum < startNum) return false; @@ -316,7 +280,7 @@ bool StringUtils::toIntListNoThrow(const std::string& s, std::vector& list, list.push_back((Int)startNum); list.push_back((Int)(endNum - startNum + 1)); } else { - for (UInt i=startNum; i<=endNum; i++) + for (UInt i = startNum; i <= endNum; i++) list.push_back((Int)i); } @@ -325,23 +289,23 @@ bool StringUtils::toIntListNoThrow(const std::string& s, std::vector& list, startP++; } else { list.push_back((Int)startNum); - if (asRanges) - list.push_back((Int)1); + if (asRanges) + list.push_back((Int)1); } - + // Done if end of string if (*startP == 0) break; - + // ------------------------------------------------------------------------------ - // Must have a comma between entries - if (*startP++ != ',') + // Must have a comma between entries + if (*startP++ != ',') return false; - + // Skip white space after the comma while (*startP && isspace(*startP)) startP++; - + // Must be more digits after the comma if (*startP == 0) return false; @@ -349,32 +313,31 @@ bool StringUtils::toIntListNoThrow(const std::string& s, std::vector& list, return true; } - - + //-------------------------------------------------------------------------------- -boost::shared_array StringUtils::toByteArray(const std::string& s, Size bitCount) -{ +boost::shared_array StringUtils::toByteArray(const std::string &s, + Size bitCount) { // Get list of integers std::vector list; StringUtils::toIntList(s, list, true /*allowAll*/); if (list.empty()) return boost::shared_array(nullptr); - + // Put this into the mask - Size numBytes = (bitCount+7) / 8; + Size numBytes = (bitCount + 7) / 8; boost::shared_array mask(new Byte[numBytes]); - Byte* maskP = mask.get(); + Byte *maskP = mask.get(); ::memset(maskP, 0, numBytes); - for (auto & elem : list) { - UInt entry = elem; + for (auto &elem : list) { + UInt entry = elem; if (entry >= bitCount) - NTA_THROW << "StringUtils::toByteArray() - " << "The list " << s - << " contains an entry greater than the max allowed of " << bitCount; - maskP[entry/8] |= 1 << (entry%8); + NTA_THROW << "StringUtils::toByteArray() - " + << "The list " << s + << " contains an entry greater than the max allowed of " + << bitCount; + maskP[entry / 8] |= 1 << (entry % 8); } - + // Return it return mask; } - - diff --git a/src/nupic/utils/StringUtils.hpp b/src/nupic/utils/StringUtils.hpp index cd96dfb7bc..9168809de2 100644 --- a/src/nupic/utils/StringUtils.hpp +++ b/src/nupic/utils/StringUtils.hpp @@ -20,177 +20,179 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Utility functions for string conversion */ #ifndef NTA_STRING_UTILS_HPP #define NTA_STRING_UTILS_HPP - -#include #include +#include +#include +#include #include #include -#include -#include -namespace nupic -{ - // TODO: Should this be a namespace instead of a class? - class StringUtils - { - public: - //-------------------------------------------------------------------------------- - /** - * @b Responsibility: - * Convert string to a typed value (using stringstream) - * Bool: Convert a string to a bool. Accepts "true", "yes", "1", with different - * capitalizations. Anything else returns false. - * Int32/Int64/etc Convert a string to a numerical type. - * Uses a stringstream to convert. 
- * - * @param s a string to convert - * @param throwOnError a bool that determines if to throw an error on failure - * @param fail a bool pointer that if not NULL gets set to true if the conversion fails - * @retval boolean value - */ - static bool toBool(const std::string& s, bool throwOnError = false, bool * fail = nullptr); - static UInt32 toUInt32(const std::string& s, bool throwOnError = false, bool * fail = nullptr); - static Int32 toInt32(const std::string& s, bool throwOnError = false, bool * fail = nullptr); - static UInt64 toUInt64(const std::string& s, bool throwOnError = false, bool * fail = nullptr); - static Real32 toReal32(const std::string& s, bool throwOnError = false, bool * fail = nullptr); - static Real64 toReal64(const std::string& s, bool throwOnError = false, bool * fail = nullptr); - static size_t toSizeT(const std::string& s, bool throwOnError = false, bool * fail = nullptr); - - static bool startsWith(const std::string& s, const std::string& prefix); - static bool endsWith(const std::string& s, const std::string& ending); - - - - - //-------------------------------------------------------------------------------- - /** - * @b Responsibility: - * Convert an integer to a string - * - * @param i an integer to convert - * @retval string - */ - - static std::string fromInt(long long i); - - //-------------------------------------------------------------------------------- - /** - * @b Responsibility: - * Base64 encode a string - * - * @param s a string to encode - * @retval encoded string - */ - - static std::string base64Encode(const std::string& s); - - - //-------------------------------------------------------------------------------- - /** - * @b Responsibility: - * Base64 encode a memory buffer - * - * @param buf buffer containing the data to encode - * @param inLen the length in bytes of the buffer to encode - * @retval encoded string - */ - static std::string base64Encode(const void* buf, Size inLen); - - //-------------------------------------------------------------------------------- - /** - * @b Responsibility: - * Base64 decode a string - * - * @param s a string to decode - * @retval decoded string - */ - - static std::string base64Decode(const std::string& s); - - //-------------------------------------------------------------------------------- - /** - * @b Responsibility: - * Base64 decode from a memory buffer - * - * @param buf a buffer to decode - * @param inLen length of buffer - * @retval decoded string - */ - static std::string base64Decode(const void* buf, Size inLen); - - //-------------------------------------------------------------------------------- - /** - * @b Responsibility: - * Represent a binary buffer with a hexidecimal string - * - * @param buf a buffer to represent - * @param inLen length of buffer - * @retval hexidecimal string - */ - static std::string hexEncode(const void* buf, Size inLen); - - //-------------------------------------------------------------------------------- - /** - * @b Responsibility: - * Convert a string specifying a list of unsigned numbers into a vector. - * The string can be of the form "0-9,10, 12, 13-19". - * - * If 'allowAll' is true, the empty string and the string "all" both return an empty list. 
- * If 'allowAll' is not true, only integer lists are accepted and the empty string and - * the string "all" throw an exception - * - * @param s a string to convert - * @param list vector to fill in - * @param allowAll if true, s can be set to "all" - * @param asRanges if true, list is filled in as pairs of integers that specify the begin - * and size of each range of integers in s. If false, list contains - * each and every one of the integers specified by s. - * @retval void - */ - static void toIntList(const std::string& s, std::vector& list, bool allowAll=false, - bool asRanges=false); - - /** - * Non-throwing version of toIntList. - * - * If 'allowAll' is true, the empty string and the string "all" both return an empty list. - * If 'allowAll' is not true, only integer lists are accepted and the empty string and - * the string "all" throw an exception - * - * @param s a string to convert - * @param v vector to fill in - * @retval true if successfully parsed. false if a parsing error occurred - */ - static bool toIntListNoThrow(const std::string& s, std::vector& list, - bool allowAll=false, bool asRanges=false); - - //-------------------------------------------------------------------------------- - /** - * @b Responsibility: - * Convert a string specifying a list of unsigned numbers into pointer to an array of - * bytes that specify a mask of which numbers were included in the list. If a number - * is in the list, the corresponding bit will be set in the mask. Each byte specifies - * 8 bits of the mask, bit 0 of byte 0 holds entry 0, bit 1 of byte 0 holds entry 1, etc. - * - * The string can be of the form "0-9,10, 12, 13-19", "all", or "". Both "all" and "" - * are special cases representing all bits and return a boost::shared_array with a - * NIL pointer (retval.get() == NULL). - * - * @param s a string to convert - * @param bitCount number of bits to include in the return mask. - * @retval boost::shared_array containing the dynamically allocated mask - * - */ - static boost::shared_array toByteArray(const std::string& s, Size bitCount); - - }; -} +namespace nupic { +// TODO: Should this be a namespace instead of a class? +class StringUtils { +public: + //-------------------------------------------------------------------------------- + /** + * @b Responsibility: + * Convert string to a typed value (using stringstream) + * Bool: Convert a string to a bool. Accepts "true", "yes", "1", with + * different capitalizations. Anything else returns false. Int32/Int64/etc + * Convert a string to a numerical type. Uses a stringstream to convert. 
+ * + * @param s a string to convert + * @param throwOnError a bool that determines if to throw an error on failure + * @param fail a bool pointer that if not NULL gets set to true if the + * conversion fails + * @retval boolean value + */ + static bool toBool(const std::string &s, bool throwOnError = false, + bool *fail = nullptr); + static UInt32 toUInt32(const std::string &s, bool throwOnError = false, + bool *fail = nullptr); + static Int32 toInt32(const std::string &s, bool throwOnError = false, + bool *fail = nullptr); + static UInt64 toUInt64(const std::string &s, bool throwOnError = false, + bool *fail = nullptr); + static Real32 toReal32(const std::string &s, bool throwOnError = false, + bool *fail = nullptr); + static Real64 toReal64(const std::string &s, bool throwOnError = false, + bool *fail = nullptr); + static size_t toSizeT(const std::string &s, bool throwOnError = false, + bool *fail = nullptr); + + static bool startsWith(const std::string &s, const std::string &prefix); + static bool endsWith(const std::string &s, const std::string &ending); + + //-------------------------------------------------------------------------------- + /** + * @b Responsibility: + * Convert an integer to a string + * + * @param i an integer to convert + * @retval string + */ + + static std::string fromInt(long long i); + + //-------------------------------------------------------------------------------- + /** + * @b Responsibility: + * Base64 encode a string + * + * @param s a string to encode + * @retval encoded string + */ + + static std::string base64Encode(const std::string &s); + + //-------------------------------------------------------------------------------- + /** + * @b Responsibility: + * Base64 encode a memory buffer + * + * @param buf buffer containing the data to encode + * @param inLen the length in bytes of the buffer to encode + * @retval encoded string + */ + static std::string base64Encode(const void *buf, Size inLen); + + //-------------------------------------------------------------------------------- + /** + * @b Responsibility: + * Base64 decode a string + * + * @param s a string to decode + * @retval decoded string + */ + + static std::string base64Decode(const std::string &s); + + //-------------------------------------------------------------------------------- + /** + * @b Responsibility: + * Base64 decode from a memory buffer + * + * @param buf a buffer to decode + * @param inLen length of buffer + * @retval decoded string + */ + static std::string base64Decode(const void *buf, Size inLen); + + //-------------------------------------------------------------------------------- + /** + * @b Responsibility: + * Represent a binary buffer with a hexidecimal string + * + * @param buf a buffer to represent + * @param inLen length of buffer + * @retval hexidecimal string + */ + static std::string hexEncode(const void *buf, Size inLen); + + //-------------------------------------------------------------------------------- + /** + * @b Responsibility: + * Convert a string specifying a list of unsigned numbers into a vector. + * The string can be of the form "0-9,10, 12, 13-19". + * + * If 'allowAll' is true, the empty string and the string "all" both return an + * empty list. 
If 'allowAll' is not true, only integer lists are accepted and + * the empty string and the string "all" throw an exception + * + * @param s a string to convert + * @param list vector to fill in + * @param allowAll if true, s can be set to "all" + * @param asRanges if true, list is filled in as pairs of integers that + * specify the begin and size of each range of integers in s. If false, list + * contains each and every one of the integers specified by s. + * @retval void + */ + static void toIntList(const std::string &s, std::vector &list, + bool allowAll = false, bool asRanges = false); + + /** + * Non-throwing version of toIntList. + * + * If 'allowAll' is true, the empty string and the string "all" both return an + * empty list. If 'allowAll' is not true, only integer lists are accepted and + * the empty string and the string "all" throw an exception + * + * @param s a string to convert + * @param v vector to fill in + * @retval true if successfully parsed. false if a parsing error occurred + */ + static bool toIntListNoThrow(const std::string &s, std::vector &list, + bool allowAll = false, bool asRanges = false); + + //-------------------------------------------------------------------------------- + /** + * @b Responsibility: + * Convert a string specifying a list of unsigned numbers into pointer to an + * array of bytes that specify a mask of which numbers were included in the + * list. If a number is in the list, the corresponding bit will be set in the + * mask. Each byte specifies 8 bits of the mask, bit 0 of byte 0 holds entry + * 0, bit 1 of byte 0 holds entry 1, etc. + * + * The string can be of the form "0-9,10, 12, 13-19", "all", or "". Both "all" + * and "" are special cases representing all bits and return a + * boost::shared_array with a NIL pointer (retval.get() == NULL). + * + * @param s a string to convert + * @param bitCount number of bits to include in the return mask. + * @retval boost::shared_array containing the dynamically allocated + * mask + * + */ + static boost::shared_array toByteArray(const std::string &s, + Size bitCount); +}; +} // namespace nupic #endif // NTA_STRING_UTILS_HPP diff --git a/src/nupic/utils/TRandom.cpp b/src/nupic/utils/TRandom.cpp index 10fe9a049e..db4ecc54b0 100644 --- a/src/nupic/utils/TRandom.cpp +++ b/src/nupic/utils/TRandom.cpp @@ -20,71 +20,62 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file Random Number Generator implementation */ -#include -#include -#include -#include #include #include +#include +#include +#include +#include using namespace nupic; -TRandom::TRandom(std::string name) -{ +TRandom::TRandom(std::string name) { UInt64 seed = 0; std::string optionName = "set_random"; - if (name != "") - { + if (name != "") { optionName += "_" + name; } bool seed_from_environment = false; - if (Env::isOptionSet(optionName)) - { + if (Env::isOptionSet(optionName)) { seed_from_environment = true; std::string val = Env::getOption(optionName); - try - { + try { seed = StringUtils::toUInt64(val, true); - } - catch (...) - { + } catch (...) { NTA_WARN << "Invalid value \"" << val << "\" for NTA_SET_RANDOM. Using 1"; seed = 1; } - } - else - { - // Seed the global rng from time(). - // Don't seed subsequent ones from time() because several random - // number generators may be initialized within the same second. - // Instead, use the global rng. + } else { + // Seed the global rng from time(). 
+ // Don't seed subsequent ones from time() because several random + // number generators may be initialized within the same second. + // Instead, use the global rng. if (theInstanceP_ == nullptr) { seed = (UInt64)time(nullptr); } else { seed = (*Random::getSeeder())(); } + } - } - - if (Env::isOptionSet("random_debug")) - { + if (Env::isOptionSet("random_debug")) { if (seed_from_environment) { - NTA_INFO << "TRandom(" << name << ") -- initializing with seed " << seed << " from environment"; + NTA_INFO << "TRandom(" << name << ") -- initializing with seed " << seed + << " from environment"; } else { - NTA_INFO << "TRandom(" << name << ") -- initializing with seed " << seed; + NTA_INFO << "TRandom(" << name << ") -- initializing with seed " << seed; } } // Create the actual RNG // @todo to add different algorithm support, this is where we will // instantiate different implementations depending on the requested - // algorithm + // algorithm reseed(seed); } diff --git a/src/nupic/utils/TRandom.hpp b/src/nupic/utils/TRandom.hpp index 727b510965..e7f10f6485 100644 --- a/src/nupic/utils/TRandom.hpp +++ b/src/nupic/utils/TRandom.hpp @@ -20,11 +20,10 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file Random Number Generator interface (for tests) */ - #ifndef NTA_TRANDOM_HPP #define NTA_TRANDOM_HPP @@ -32,46 +31,46 @@ #include namespace nupic { - /** - * @b Responsibility - * Provides standard random number generation for testing. - * Seed can be logged in one run and then set in another. - * @b Rationale - * Makes it possible to reproduce tests that are driven by random number generation. - * - * @b Description - * Functionality is similar to the standard random() function that is provided by C. - * - * TRandom is a subclass of Random with an additional constructor. - * This constructor creates a named generator -- normally self-seeded, but - * seed may be set explicitly through an environment variable. For example: - * Random rng("level2TP"); - * If NTA_RANDOM_DEBUG is set, this object will log its self-seed - * The seed can be explicitly set through NTA_RANDOM_SEED_level2TP - * - * If self-seeded, the seed comes from the same global random number generaetor - * used for Random. - * - * Automated tests that use random numbers should normally use named generators. - * This allows them to get a different seed each time, but also allows reproducibility - * in the case that a test failure is triggered by a particular seed. - * - * Random should not be used if cryptographic strength is required (e.g. for - * generating a challenge in an authentication scheme). - * - * @todo Add ability to specify different rng algorithms. - */ - - class TRandom : public Random - { - public: - TRandom(std::string name); +/** + * @b Responsibility + * Provides standard random number generation for testing. + * Seed can be logged in one run and then set in another. + * @b Rationale + * Makes it possible to reproduce tests that are driven by random number + * generation. + * + * @b Description + * Functionality is similar to the standard random() function that is provided + * by C. + * + * TRandom is a subclass of Random with an additional constructor. + * This constructor creates a named generator -- normally self-seeded, but + * seed may be set explicitly through an environment variable. 
For example: + * Random rng("level2TP"); + * If NTA_RANDOM_DEBUG is set, this object will log its self-seed + * The seed can be explicitly set through NTA_RANDOM_SEED_level2TP + * + * If self-seeded, the seed comes from the same global random number generaetor + * used for Random. + * + * Automated tests that use random numbers should normally use named generators. + * This allows them to get a different seed each time, but also allows + * reproducibility in the case that a test failure is triggered by a particular + * seed. + * + * Random should not be used if cryptographic strength is required (e.g. for + * generating a challenge in an authentication scheme). + * + * @todo Add ability to specify different rng algorithms. + */ - private: - friend class TRandomTest; +class TRandom : public Random { +public: + TRandom(std::string name); - }; -} +private: + friend class TRandomTest; +}; +} // namespace nupic #endif // NTA_TRANDOM_HPP - diff --git a/src/nupic/utils/Watcher.cpp b/src/nupic/utils/Watcher.cpp index eb860ab6ba..c4344aba77 100644 --- a/src/nupic/utils/Watcher.cpp +++ b/src/nupic/utils/Watcher.cpp @@ -20,526 +20,433 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Implementation of the Watcher class */ +#include #include #include #include -#include -#include -#include -#include #include -#include +#include +#include #include -#include -#include #include +#include +#include #include +#include +#include -namespace nupic -{ +namespace nupic { - Watcher::Watcher(std::string fileName) - { - data_.fileName = fileName; - try - { - data_.outStream = new OFStream(fileName.c_str()); - } - catch (std::exception &) - { - NTA_THROW << "Unable to open filename " << fileName << " for network watcher"; - } +Watcher::Watcher(std::string fileName) { + data_.fileName = fileName; + try { + data_.outStream = new OFStream(fileName.c_str()); + } catch (std::exception &) { + NTA_THROW << "Unable to open filename " << fileName + << " for network watcher"; } +} - Watcher::~Watcher() - { - this->flushFile(); - this->closeFile(); - delete data_.outStream; - } +Watcher::~Watcher() { + this->flushFile(); + this->closeFile(); + delete data_.outStream; +} - unsigned int Watcher::watchParam(std::string regionName, - std::string varName, - int nodeIndex, - bool sparseOutput) - { - watchData watch; - watch.varName = varName; - watch.wType = parameter; - watch.regionName = regionName; - watch.nodeIndex = nodeIndex; - watch.sparseOutput = sparseOutput; - watch.watchID = data_.watches.size()+1; - data_.watches.push_back(watch); - return watch.watchID; - } +unsigned int Watcher::watchParam(std::string regionName, std::string varName, + int nodeIndex, bool sparseOutput) { + watchData watch; + watch.varName = varName; + watch.wType = parameter; + watch.regionName = regionName; + watch.nodeIndex = nodeIndex; + watch.sparseOutput = sparseOutput; + watch.watchID = data_.watches.size() + 1; + data_.watches.push_back(watch); + return watch.watchID; +} - unsigned int Watcher::watchOutput(std::string regionName, - std::string varName, - bool sparseOutput) - { - watchData watch; - watch.varName = varName; - watch.wType = output; - watch.regionName = regionName; - watch.nodeIndex = -1; - watch.isArray = false; - watch.sparseOutput = sparseOutput; - watch.watchID = data_.watches.size()+1; - data_.watches.push_back(watch); - return watch.watchID; - } +unsigned int Watcher::watchOutput(std::string regionName, std::string varName, + bool sparseOutput) { + watchData watch; + 
watch.varName = varName; + watch.wType = output; + watch.regionName = regionName; + watch.nodeIndex = -1; + watch.isArray = false; + watch.sparseOutput = sparseOutput; + watch.watchID = data_.watches.size() + 1; + data_.watches.push_back(watch); + return watch.watchID; +} - //TODO: clean up, add support for uncloned arrays, - //add support for output of a different type than Real32 - void Watcher::watcherCallback(Network* net, UInt64 iteration, void* dataIn) - { - allData& data = *(static_cast(dataIn)); - //iterate through each watch - for (auto & elem : data.watches) - { - watchData watch = elem; - std::string value; - std::stringstream out; - if (watch.wType == parameter) +// TODO: clean up, add support for uncloned arrays, +// add support for output of a different type than Real32 +void Watcher::watcherCallback(Network *net, UInt64 iteration, void *dataIn) { + allData &data = *(static_cast(dataIn)); + // iterate through each watch + for (auto &elem : data.watches) { + watchData watch = elem; + std::string value; + std::stringstream out; + if (watch.wType == parameter) { + if (watch.isArray) // currently don't support uncloned arrays { - if (watch.isArray) //currently don't support uncloned arrays - { - switch(watch.varType) - { - case NTA_BasicType_Int32: - { - Array a(NTA_BasicType_Int32); - watch.region->getParameterArray(watch.varName, a); - Int32* buf = (Int32*) a.getBuffer(); - out << a.getCount(); - if (watch.sparseOutput) - { - for (UInt j = 0; j < a.getCount(); j++) - { - if (buf[j] != (Int32)0) - out << " " << j; - } + switch (watch.varType) { + case NTA_BasicType_Int32: { + Array a(NTA_BasicType_Int32); + watch.region->getParameterArray(watch.varName, a); + Int32 *buf = (Int32 *)a.getBuffer(); + out << a.getCount(); + if (watch.sparseOutput) { + for (UInt j = 0; j < a.getCount(); j++) { + if (buf[j] != (Int32)0) + out << " " << j; } - else - { - for (UInt j = 0; j < a.getCount(); j++) - { - out << " " << buf[j]; - } + } else { + for (UInt j = 0; j < a.getCount(); j++) { + out << " " << buf[j]; } - break; } - case NTA_BasicType_UInt32: - { - Array a(NTA_BasicType_UInt32); - watch.region->getParameterArray(watch.varName, a); - UInt32* buf = (UInt32*) a.getBuffer(); - out << a.getCount(); - if (watch.sparseOutput) - { - for (UInt j = 0; j < a.getCount(); j++) - { - if (buf[j] != (UInt32)0) - out << " " << j; - } + break; + } + case NTA_BasicType_UInt32: { + Array a(NTA_BasicType_UInt32); + watch.region->getParameterArray(watch.varName, a); + UInt32 *buf = (UInt32 *)a.getBuffer(); + out << a.getCount(); + if (watch.sparseOutput) { + for (UInt j = 0; j < a.getCount(); j++) { + if (buf[j] != (UInt32)0) + out << " " << j; } - else - { - for (UInt j = 0; j < a.getCount(); j++) - { - out << " " << buf[j]; - } + } else { + for (UInt j = 0; j < a.getCount(); j++) { + out << " " << buf[j]; } - break; } - case NTA_BasicType_Int64: - { - Array a(NTA_BasicType_Int64); - watch.region->getParameterArray(watch.varName, a); - Int64* buf = (Int64*) a.getBuffer(); - out << a.getCount(); - if (watch.sparseOutput) - { - for (UInt j = 0; j < a.getCount(); j++) - { - if (buf[j] != (Int64)0) - out << " " << j; - } + break; + } + case NTA_BasicType_Int64: { + Array a(NTA_BasicType_Int64); + watch.region->getParameterArray(watch.varName, a); + Int64 *buf = (Int64 *)a.getBuffer(); + out << a.getCount(); + if (watch.sparseOutput) { + for (UInt j = 0; j < a.getCount(); j++) { + if (buf[j] != (Int64)0) + out << " " << j; } - else - { - for (UInt j = 0; j < a.getCount(); j++) - { - out << " " << buf[j]; - } + 
} else { + for (UInt j = 0; j < a.getCount(); j++) { + out << " " << buf[j]; } - break; } - case NTA_BasicType_UInt64: - { - Array a(NTA_BasicType_UInt64); - watch.region->getParameterArray(watch.varName, a); - UInt64* buf = (UInt64*) a.getBuffer(); - out << a.getCount(); - if (watch.sparseOutput) - { - for (UInt j = 0; j < a.getCount(); j++) - { - if (buf[j] != (UInt64)0) - out << " " << j; - } + break; + } + case NTA_BasicType_UInt64: { + Array a(NTA_BasicType_UInt64); + watch.region->getParameterArray(watch.varName, a); + UInt64 *buf = (UInt64 *)a.getBuffer(); + out << a.getCount(); + if (watch.sparseOutput) { + for (UInt j = 0; j < a.getCount(); j++) { + if (buf[j] != (UInt64)0) + out << " " << j; } - else - { - for (UInt j = 0; j < a.getCount(); j++) - { - out << " " << buf[j]; - } + } else { + for (UInt j = 0; j < a.getCount(); j++) { + out << " " << buf[j]; } - break; } - case NTA_BasicType_Real32: - { - Array a(NTA_BasicType_Real32); - watch.region->getParameterArray(watch.varName, a); - Real32* buf = (Real32*) a.getBuffer(); - out << a.getCount(); - if (watch.sparseOutput) - { - for (UInt j = 0; j < a.getCount(); j++) - { - if (buf[j] != (Real32)0) - out << " " << j; - } + break; + } + case NTA_BasicType_Real32: { + Array a(NTA_BasicType_Real32); + watch.region->getParameterArray(watch.varName, a); + Real32 *buf = (Real32 *)a.getBuffer(); + out << a.getCount(); + if (watch.sparseOutput) { + for (UInt j = 0; j < a.getCount(); j++) { + if (buf[j] != (Real32)0) + out << " " << j; } - else - { - for (UInt j = 0; j < a.getCount(); j++) - { - out << " " << buf[j]; - } + } else { + for (UInt j = 0; j < a.getCount(); j++) { + out << " " << buf[j]; } - break; } - case NTA_BasicType_Real64: - { - Array a(NTA_BasicType_Real64); - watch.region->getParameterArray(watch.varName, a); - Real64* buf = (Real64*) a.getBuffer(); - out << a.getCount(); - if (watch.sparseOutput) - { - for (UInt j = 0; j < a.getCount(); j++) - { - if (buf[j] != (Real64)0) - out << " " << j; - } + break; + } + case NTA_BasicType_Real64: { + Array a(NTA_BasicType_Real64); + watch.region->getParameterArray(watch.varName, a); + Real64 *buf = (Real64 *)a.getBuffer(); + out << a.getCount(); + if (watch.sparseOutput) { + for (UInt j = 0; j < a.getCount(); j++) { + if (buf[j] != (Real64)0) + out << " " << j; } - else - { - for (UInt j = 0; j < a.getCount(); j++) - { - out << " " << buf[j]; - } + } else { + for (UInt j = 0; j < a.getCount(); j++) { + out << " " << buf[j]; } - break; } - case NTA_BasicType_Byte: - { - Array a(NTA_BasicType_Byte); - watch.region->getParameterArray(watch.varName, a); - Byte* buf = (Byte*) a.getBuffer(); - out << a.getCount(); - if (watch.sparseOutput) - { - for (UInt j = 0; j < a.getCount(); j++) - { - out << " " << buf[j]; - } + break; + } + case NTA_BasicType_Byte: { + Array a(NTA_BasicType_Byte); + watch.region->getParameterArray(watch.varName, a); + Byte *buf = (Byte *)a.getBuffer(); + out << a.getCount(); + if (watch.sparseOutput) { + for (UInt j = 0; j < a.getCount(); j++) { + out << " " << buf[j]; } - else - { - for (UInt j = 0; j < a.getCount(); j++) - { - out << " " << buf[j]; - } + } else { + for (UInt j = 0; j < a.getCount(); j++) { + out << " " << buf[j]; } - break; - } - default: - NTA_THROW << "Internal error."; } + break; } - else if (watch.nodeIndex == -1) - { - switch (watch.varType) - { - case NTA_BasicType_Int32: - { - Int32 p = watch.region->getParameterInt32(watch.varName); - out << p; - break; - } - case NTA_BasicType_UInt32: - { - UInt32 p = 
watch.region->getParameterUInt32(watch.varName); - out << p; - break; - } - case NTA_BasicType_Int64: - { - Int64 p = watch.region->getParameterInt64(watch.varName); - out << p; - break; - } - case NTA_BasicType_UInt64: - { - UInt64 p = watch.region->getParameterUInt64(watch.varName); - out << p; - break; - } - case NTA_BasicType_Real32: - { - Real32 p = watch.region->getParameterReal32(watch.varName); - out << p; - break; - } - case NTA_BasicType_Real64: - { - Real64 p = watch.region->getParameterReal64(watch.varName); - out << p; - break; - } - case NTA_BasicType_Byte: - { - std::string p = watch.region->getParameterString(watch.varName); - out << p; - break; - } - default: - NTA_THROW << "Internal error."; - } + default: + NTA_THROW << "Internal error."; + } + } else if (watch.nodeIndex == -1) { + switch (watch.varType) { + case NTA_BasicType_Int32: { + Int32 p = watch.region->getParameterInt32(watch.varName); + out << p; + break; + } + case NTA_BasicType_UInt32: { + UInt32 p = watch.region->getParameterUInt32(watch.varName); + out << p; + break; + } + case NTA_BasicType_Int64: { + Int64 p = watch.region->getParameterInt64(watch.varName); + out << p; + break; + } + case NTA_BasicType_UInt64: { + UInt64 p = watch.region->getParameterUInt64(watch.varName); + out << p; + break; + } + case NTA_BasicType_Real32: { + Real32 p = watch.region->getParameterReal32(watch.varName); + out << p; + break; + } + case NTA_BasicType_Real64: { + Real64 p = watch.region->getParameterReal64(watch.varName); + out << p; + break; + } + case NTA_BasicType_Byte: { + std::string p = watch.region->getParameterString(watch.varName); + out << p; + break; + } + default: + NTA_THROW << "Internal error."; } - //else //nodeIndex != -1 - //{ - // Node n = watch.region->getNodeAtIndex((size_t)watch.nodeIndex); - // switch (watch.varType) - // { - // case NTA_BasicType_Int32: - // { - // Int32 p = n.getParameterInt32(watch.varName); - // out << p; - // break; - // } - // case NTA_BasicType_UInt32: - // { - // UInt32 p = n.getParameterUInt32(watch.varName); - // out << p; - // break; - // } - // case NTA_BasicType_Int64: - // { - // Int64 p = n.getParameterInt64(watch.varName); - // out << p; - // break; - // } - // case NTA_BasicType_UInt64: - // { - // UInt64 p = n.getParameterUInt64(watch.varName); - // out << p; - // break; - // } - // case NTA_BasicType_Real32: - // { - // Real32 p = n.getParameterReal32(watch.varName); - // out << p; - // break; - // } - // case NTA_BasicType_Real64: - // { - // Real64 p = n.getParameterReal64(watch.varName); - // out << p; - // break; - // } - // case NTA_BasicType_Byte: - // { - // std::string p = n.getParameterString(watch.varName); - // out << p; - // break; - // } - // default: - // NTA_THROW << "Internal error."; - // } - //} } - else if (watch.wType == output) - { - switch (watch.varType) - { - case NTA_BasicType_Real32: - { - Real32* outputData = (Real32*)(watch.array->getBuffer()); - unsigned int numOuts = watch.array->getCount(); - out << numOuts; - - if (watch.sparseOutput) - { - for (UInt j = 0; j < numOuts; j++) - { - if (outputData[j] != (Real32)0) - { - out << " " << j; - } + // else //nodeIndex != -1 + //{ + // Node n = watch.region->getNodeAtIndex((size_t)watch.nodeIndex); + // switch (watch.varType) + // { + // case NTA_BasicType_Int32: + // { + // Int32 p = n.getParameterInt32(watch.varName); + // out << p; + // break; + // } + // case NTA_BasicType_UInt32: + // { + // UInt32 p = n.getParameterUInt32(watch.varName); + // out << p; + // break; + // } + // case 
NTA_BasicType_Int64: + // { + // Int64 p = n.getParameterInt64(watch.varName); + // out << p; + // break; + // } + // case NTA_BasicType_UInt64: + // { + // UInt64 p = n.getParameterUInt64(watch.varName); + // out << p; + // break; + // } + // case NTA_BasicType_Real32: + // { + // Real32 p = n.getParameterReal32(watch.varName); + // out << p; + // break; + // } + // case NTA_BasicType_Real64: + // { + // Real64 p = n.getParameterReal64(watch.varName); + // out << p; + // break; + // } + // case NTA_BasicType_Byte: + // { + // std::string p = n.getParameterString(watch.varName); + // out << p; + // break; + // } + // default: + // NTA_THROW << "Internal error."; + // } + //} + } else if (watch.wType == output) { + switch (watch.varType) { + case NTA_BasicType_Real32: { + Real32 *outputData = (Real32 *)(watch.array->getBuffer()); + unsigned int numOuts = watch.array->getCount(); + out << numOuts; + + if (watch.sparseOutput) { + for (UInt j = 0; j < numOuts; j++) { + if (outputData[j] != (Real32)0) { + out << " " << j; } } - else - { - for (UInt j = 0; j < numOuts; j++) - { - out << " " << outputData[j]; - } + } else { + for (UInt j = 0; j < numOuts; j++) { + out << " " << outputData[j]; } - break; } - case NTA_BasicType_Real64: - { - Real64* outputData = (Real64*)(watch.array->getBuffer()); - unsigned int numOuts = watch.array->getCount(); - out << numOuts; - - if (watch.sparseOutput) - { - for (UInt j = 0; j < numOuts; j++) - { - if (outputData[j] != (Real64)0) - { - out << " " << j; - } + break; + } + case NTA_BasicType_Real64: { + Real64 *outputData = (Real64 *)(watch.array->getBuffer()); + unsigned int numOuts = watch.array->getCount(); + out << numOuts; + + if (watch.sparseOutput) { + for (UInt j = 0; j < numOuts; j++) { + if (outputData[j] != (Real64)0) { + out << " " << j; } } - else - { - for (UInt j = 0; j < numOuts; j++) - { - out << " " << outputData[j]; - } + } else { + for (UInt j = 0; j < numOuts; j++) { + out << " " << outputData[j]; } - break; - } - default: - NTA_THROW << "Watcher only supports Real32 or Real64 outputs."; } + break; } - else //should never happen - { - NTA_THROW << "Watcher can only watch parameters or outputs."; + default: + NTA_THROW << "Watcher only supports Real32 or Real64 outputs."; } - - value = out.str(); - - (*data.outStream) << watch.watchID << ", " << iteration << ", " << value << "\n"; + } else // should never happen + { + NTA_THROW << "Watcher can only watch parameters or outputs."; } - data.outStream->flush(); - } - void Watcher::closeFile() - { - data_.outStream->close(); - } + value = out.str(); - void Watcher::flushFile() - { - data_.outStream->flush(); + (*data.outStream) << watch.watchID << ", " << iteration << ", " << value + << "\n"; } + data.outStream->flush(); +} - //attach Watcher to a network and do initial writing to files - void Watcher::attachToNetwork(Network& net) - { - (*data_.outStream) << "Info: watchID, regionName, nodeType, nodeIndex, varName" << "\n"; +void Watcher::closeFile() { data_.outStream->close(); } - //go through each watch - watchData watch; +void Watcher::flushFile() { data_.outStream->flush(); } - for (UInt i = 0; i < data_.watches.size(); i++) - { - watch = data_.watches.at(i); - const Collection& regions = net.getRegions(); - watch.region = regions.getByName(watch.regionName); +// attach Watcher to a network and do initial writing to files +void Watcher::attachToNetwork(Network &net) { + (*data_.outStream) + << "Info: watchID, regionName, nodeType, nodeIndex, varName" + << "\n"; - //output general 
information for each watch - (*data_.outStream) << watch.watchID << ", "; - (*data_.outStream) << watch.regionName << ", "; - (*data_.outStream) << watch.region->getType() << ", "; - (*data_.outStream) << watch.nodeIndex << ", "; + // go through each watch + watchData watch; - if (watch.wType == parameter) - { - //find out varType and add it to watch struct - ParameterSpec p = watch.region->getSpec()->parameters.getByName(watch.varName); - watch.varType = p.dataType; + for (UInt i = 0; i < data_.watches.size(); i++) { + watch = data_.watches.at(i); + const Collection ®ions = net.getRegions(); + watch.region = regions.getByName(watch.regionName); - //find out if varType is supported - if (watch.varType != NTA_BasicType_Int32 && - watch.varType != NTA_BasicType_UInt32 && - watch.varType != NTA_BasicType_Int64 && - watch.varType != NTA_BasicType_UInt64 && - watch.varType != NTA_BasicType_Real32 && - watch.varType != NTA_BasicType_Real64 && - watch.varType != NTA_BasicType_Byte) - { - NTA_THROW << BasicType::getName(watch.varType) << " is not an " - << "array parameter type supported by Watcher."; - } - - //found out whether parameter is an array or not - watch.isArray = ((p.count == 0 || p.count > 1) - && watch.varType != NTA_BasicType_Byte); + // output general information for each watch + (*data_.outStream) << watch.watchID << ", "; + (*data_.outStream) << watch.regionName << ", "; + (*data_.outStream) << watch.region->getType() << ", "; + (*data_.outStream) << watch.nodeIndex << ", "; + + if (watch.wType == parameter) { + // find out varType and add it to watch struct + ParameterSpec p = + watch.region->getSpec()->parameters.getByName(watch.varName); + watch.varType = p.dataType; - (*data_.outStream) << watch.varName << "\n"; + // find out if varType is supported + if (watch.varType != NTA_BasicType_Int32 && + watch.varType != NTA_BasicType_UInt32 && + watch.varType != NTA_BasicType_Int64 && + watch.varType != NTA_BasicType_UInt64 && + watch.varType != NTA_BasicType_Real32 && + watch.varType != NTA_BasicType_Real64 && + watch.varType != NTA_BasicType_Byte) { + NTA_THROW << BasicType::getName(watch.varType) << " is not an " + << "array parameter type supported by Watcher."; } - else if (watch.wType == output) - { - watch.output = watch.region->getOutput(watch.varName); - (*data_.outStream) << watch.varName << "\n"; - watch.array = &(watch.output->getData()); + // found out whether parameter is an array or not + watch.isArray = ((p.count == 0 || p.count > 1) && + watch.varType != NTA_BasicType_Byte); - watch.varType = watch.array->getType(); + (*data_.outStream) << watch.varName << "\n"; + } else if (watch.wType == output) { + watch.output = watch.region->getOutput(watch.varName); + (*data_.outStream) << watch.varName << "\n"; - } - else //should never happen - { - NTA_THROW << "Watcher can only watch parameters or outputs."; - } + watch.array = &(watch.output->getData()); - //add the modified watch struct to data_.watches - allWatchData::iterator it; - it = data_.watches.begin() + i; - data_.watches.insert(it, watch); - data_.watches.erase(data_.watches.begin() + i + 1); + watch.varType = watch.array->getType(); + + } else // should never happen + { + NTA_THROW << "Watcher can only watch parameters or outputs."; } - (*data_.outStream) << "Data: watchID, iteration, paramValue" << "\n"; - - //actually attach to the network - Collection& callbacks = net.getCallbacks(); - Network::callbackItem callback(watcherCallback, (void*)(&data_)); - std::string callbackName = "Watcher: "; - callbackName 
+= data_.fileName; - callbacks.add(callbackName, callback); + // add the modified watch struct to data_.watches + allWatchData::iterator it; + it = data_.watches.begin() + i; + data_.watches.insert(it, watch); + data_.watches.erase(data_.watches.begin() + i + 1); } - void Watcher::detachFromNetwork(Network& net) - { - Collection& callbacks = net.getCallbacks(); - std::string callbackName = "Watcher: "; - callbackName += data_.fileName; - callbacks.remove(callbackName); - } + (*data_.outStream) << "Data: watchID, iteration, paramValue" + << "\n"; + + // actually attach to the network + Collection &callbacks = net.getCallbacks(); + Network::callbackItem callback(watcherCallback, (void *)(&data_)); + std::string callbackName = "Watcher: "; + callbackName += data_.fileName; + callbacks.add(callbackName, callback); +} + +void Watcher::detachFromNetwork(Network &net) { + Collection &callbacks = net.getCallbacks(); + std::string callbackName = "Watcher: "; + callbackName += data_.fileName; + callbacks.remove(callbackName); } +} // namespace nupic diff --git a/src/nupic/utils/Watcher.hpp b/src/nupic/utils/Watcher.hpp index bf6bb28c09..294a4d4664 100644 --- a/src/nupic/utils/Watcher.hpp +++ b/src/nupic/utils/Watcher.hpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file * Interface for the Watcher class */ @@ -32,116 +32,98 @@ #include -namespace nupic -{ - class ArrayBase; - class Network; - class Region; - class OFStream; - - enum watcherType - { - parameter, - output - }; - - //Contains data specific for each individual parameter - //to be watched. - struct watchData - { - unsigned int watchID; //starts at 1 - std::string varName; - watcherType wType; - Output* output; - //Need regionName because we create data structure before - //we have the actual Network to attach it to. - std::string regionName; - Region* region; - Int64 nodeIndex; - NTA_BasicType varType; - std::string nodeName; - const ArrayBase * array; - bool isArray; - bool sparseOutput; - }; - - //Contains all data needed by the callback function. - struct allData - { - OFStream* outStream; - std::string fileName; - std::vector watches; - }; - - /* - * Writes the values of parameters and outputs to a file after each - * iteration of the network. - * - * Sample usage: - * - * Network net; - * ... - * ... - * - * Watcher w("fileName"); - * w.watchParam("regionName", "paramName"); - * w.watchParam("regionName", "paramName", nodeIndex); - * w.watchOutput("regionName", "bottomUpOut"); - * w.attachToNetwork(net); - * - * net.run(); - * - * w.detachFromNetwork(net); - */ - class Watcher - { - public: - Watcher(const std::string fileName); - - //calls flushFile() and closeFile() - ~Watcher(); - - //returns watchID - unsigned int - watchParam(std::string regionName, - std::string varName, - int nodeIndex = -1, - bool sparseOutput = true); - - //returns watchID - unsigned int - watchOutput(std::string regionName, - std::string varName, - bool sparseOutput = true); - - //callback function that will be called every time network is run - static void - watcherCallback(Network* net, UInt64 iteration, void* dataIn); - - //Attaches Watcher to a network and begins writing - //information to a file. Call this after adding all watches. - void - attachToNetwork(Network&); - - //Detaches the Watcher from the Network so the callback is no longer called - void - detachFromNetwork(Network&); - - //Closes the OFStream. - void - closeFile(); - - //Flushes the OFStream. 
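// --- Editorial sketch, not part of this patch ----------------------------
// The value encoding repeated by the big per-type switch in
// Watcher::watcherCallback (Watcher.cpp, above): each data line is the
// element count followed either by the indices of the non-zero elements
// (sparseOutput == true) or by every element value. The template helper is
// illustrative only, not part of the Watcher API, and assumes <sstream> and
// <string> are available.
template <typename T>
static std::string encodeWatchedArray(const T *buf, size_t count, bool sparse) {
  std::stringstream out;
  out << count;
  for (size_t j = 0; j < count; j++) {
    if (sparse) {
      if (buf[j] != (T)0)
        out << " " << j;    // emit only the index of a non-zero element
    } else {
      out << " " << buf[j]; // emit every element value
    }
  }
  return out.str();
}
// --------------------------------------------------------------------------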
- void - flushFile(); - - private: - typedef std::vector allWatchData; - - //private data structure - allData data_; - }; - -} //namespace nupic +namespace nupic { +class ArrayBase; +class Network; +class Region; +class OFStream; + +enum watcherType { parameter, output }; + +// Contains data specific for each individual parameter +// to be watched. +struct watchData { + unsigned int watchID; // starts at 1 + std::string varName; + watcherType wType; + Output *output; + // Need regionName because we create data structure before + // we have the actual Network to attach it to. + std::string regionName; + Region *region; + Int64 nodeIndex; + NTA_BasicType varType; + std::string nodeName; + const ArrayBase *array; + bool isArray; + bool sparseOutput; +}; + +// Contains all data needed by the callback function. +struct allData { + OFStream *outStream; + std::string fileName; + std::vector watches; +}; + +/* + * Writes the values of parameters and outputs to a file after each + * iteration of the network. + * + * Sample usage: + * + * Network net; + * ... + * ... + * + * Watcher w("fileName"); + * w.watchParam("regionName", "paramName"); + * w.watchParam("regionName", "paramName", nodeIndex); + * w.watchOutput("regionName", "bottomUpOut"); + * w.attachToNetwork(net); + * + * net.run(); + * + * w.detachFromNetwork(net); + */ +class Watcher { +public: + Watcher(const std::string fileName); + + // calls flushFile() and closeFile() + ~Watcher(); + + // returns watchID + unsigned int watchParam(std::string regionName, std::string varName, + int nodeIndex = -1, bool sparseOutput = true); + + // returns watchID + unsigned int watchOutput(std::string regionName, std::string varName, + bool sparseOutput = true); + + // callback function that will be called every time network is run + static void watcherCallback(Network *net, UInt64 iteration, void *dataIn); + + // Attaches Watcher to a network and begins writing + // information to a file. Call this after adding all watches. + void attachToNetwork(Network &); + + // Detaches the Watcher from the Network so the callback is no longer called + void detachFromNetwork(Network &); + + // Closes the OFStream. + void closeFile(); + + // Flushes the OFStream. + void flushFile(); + +private: + typedef std::vector allWatchData; + + // private data structure + allData data_; +}; + +} // namespace nupic #endif // NTA_WATCHER_HPP diff --git a/src/test/integration/ConnectionsPerformanceTest.cpp b/src/test/integration/ConnectionsPerformanceTest.cpp index 03f4c150ef..aec76c7c61 100644 --- a/src/test/integration/ConnectionsPerformanceTest.cpp +++ b/src/test/integration/ConnectionsPerformanceTest.cpp @@ -26,11 +26,11 @@ #include #include -#include #include +#include -#include #include +#include #include "ConnectionsPerformanceTest.hpp" @@ -41,297 +41,248 @@ using namespace nupic::algorithms::connections; #define SEED 42 -namespace nupic -{ +namespace nupic { - void ConnectionsPerformanceTest::RunTests() - { - srand(SEED); +void ConnectionsPerformanceTest::RunTests() { + srand(SEED); - testTemporalMemoryUsage(); - testLargeTemporalMemoryUsage(); - testSpatialPoolerUsage(); - testTemporalPoolerUsage(); - } - - /** - * Tests typical usage of Connections with Temporal Memory. 
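// --- Editorial note, not part of this patch -------------------------------
// Reading the numeric arguments in the benchmark calls above, using the
// parameter names from the function definitions later in this file:
//   runTemporalMemoryTest(numColumns, w, numSequences, numElements, label)
//     e.g. (2048, 40, 5, 100, ...) = 2048 columns, 40 active bits per SDR,
//     5 sequences of 100 elements each.
//   runSpatialPoolerTest(numCells, numInputs, w, numWinners, label)
//     e.g. (2048, 16384, 40, 400, ...) = 2048 cells, 16384 inputs, 40 active
//     input bits, 400 winner cells per iteration.
// --------------------------------------------------------------------------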
- */ - void ConnectionsPerformanceTest::testTemporalMemoryUsage() - { - runTemporalMemoryTest(2048, 40, 5, 100, "temporal memory"); - } + testTemporalMemoryUsage(); + testLargeTemporalMemoryUsage(); + testSpatialPoolerUsage(); + testTemporalPoolerUsage(); +} - /** - * Tests typical usage of Connections with a large Temporal Memory. - */ - void ConnectionsPerformanceTest::testLargeTemporalMemoryUsage() - { - runTemporalMemoryTest(16384, 328, 3, 40, "temporal memory (large)"); - } +/** + * Tests typical usage of Connections with Temporal Memory. + */ +void ConnectionsPerformanceTest::testTemporalMemoryUsage() { + runTemporalMemoryTest(2048, 40, 5, 100, "temporal memory"); +} - /** - * Tests typical usage of Connections with Spatial Pooler. - */ - void ConnectionsPerformanceTest::testSpatialPoolerUsage() - { - runSpatialPoolerTest(2048, 2048, 40, 40, "spatial pooler"); - } +/** + * Tests typical usage of Connections with a large Temporal Memory. + */ +void ConnectionsPerformanceTest::testLargeTemporalMemoryUsage() { + runTemporalMemoryTest(16384, 328, 3, 40, "temporal memory (large)"); +} - /** - * Tests typical usage of Connections with Temporal Pooler. - */ - void ConnectionsPerformanceTest::testTemporalPoolerUsage() - { - runSpatialPoolerTest(2048, 16384, 40, 400, "temporal pooler"); - } +/** + * Tests typical usage of Connections with Spatial Pooler. + */ +void ConnectionsPerformanceTest::testSpatialPoolerUsage() { + runSpatialPoolerTest(2048, 2048, 40, 40, "spatial pooler"); +} - void ConnectionsPerformanceTest::runTemporalMemoryTest(UInt numColumns, - UInt w, - int numSequences, - int numElements, - string label) - { - clock_t timer = clock(); +/** + * Tests typical usage of Connections with Temporal Pooler. + */ +void ConnectionsPerformanceTest::testTemporalPoolerUsage() { + runSpatialPoolerTest(2048, 16384, 40, 400, "temporal pooler"); +} - // Initialize +void ConnectionsPerformanceTest::runTemporalMemoryTest(UInt numColumns, UInt w, + int numSequences, + int numElements, + string label) { + clock_t timer = clock(); - TemporalMemory tm; - vector columnDim; - columnDim.push_back(numColumns); - tm.initialize(columnDim); + // Initialize - checkpoint(timer, label + ": initialize"); + TemporalMemory tm; + vector columnDim; + columnDim.push_back(numColumns); + tm.initialize(columnDim); - // Learn + checkpoint(timer, label + ": initialize"); - vector< vector< vector > >sequences; - vector< vector >sequence; - vector sdr; + // Learn - for (int i = 0; i < numSequences; i++) - { - for (int j = 0; j < numElements; j++) - { - sdr = randomSDR(numColumns, w); - sequence.push_back(sdr); - } + vector>> sequences; + vector> sequence; + vector sdr; - sequences.push_back(sequence); + for (int i = 0; i < numSequences; i++) { + for (int j = 0; j < numElements; j++) { + sdr = randomSDR(numColumns, w); + sequence.push_back(sdr); } - for (int i = 0; i < 5; i++) - { - for (auto sequence : sequences) - { - for (auto sdr : sequence) - { - feedTM(tm, sdr); - tm.reset(); - } + sequences.push_back(sequence); + } + + for (int i = 0; i < 5; i++) { + for (auto sequence : sequences) { + for (auto sdr : sequence) { + feedTM(tm, sdr); + tm.reset(); } } + } - checkpoint(timer, label + ": initialize + learn"); + checkpoint(timer, label + ": initialize + learn"); - // Test + // Test - for (auto sequence : sequences) - { - for (auto sdr : sequence) - { - feedTM(tm, sdr, false); - tm.reset(); - } + for (auto sequence : sequences) { + for (auto sdr : sequence) { + feedTM(tm, sdr, false); + tm.reset(); } - - checkpoint(timer, 
label + ": initialize + learn + test"); } - void ConnectionsPerformanceTest::runSpatialPoolerTest(UInt numCells, - UInt numInputs, - UInt w, - UInt numWinners, - string label) - { - clock_t timer = clock(); - - Connections connections(numCells); - Segment segment; - vector sdr; - - // Initialize - - for (UInt c = 0; c < numCells; c++) - { - segment = connections.createSegment(c); - - for (UInt i = 0; i < numInputs; i++) - { - const Permanence permanence = max((Permanence)0.000001, - (Permanence)rand()/RAND_MAX); - connections.createSynapse(segment, i, permanence); - } + checkpoint(timer, label + ": initialize + learn + test"); +} + +void ConnectionsPerformanceTest::runSpatialPoolerTest(UInt numCells, + UInt numInputs, UInt w, + UInt numWinners, + string label) { + clock_t timer = clock(); + + Connections connections(numCells); + Segment segment; + vector sdr; + + // Initialize + + for (UInt c = 0; c < numCells; c++) { + segment = connections.createSegment(c); + + for (UInt i = 0; i < numInputs; i++) { + const Permanence permanence = + max((Permanence)0.000001, (Permanence)rand() / RAND_MAX); + connections.createSynapse(segment, i, permanence); } + } - checkpoint(timer, label + ": initialize"); + checkpoint(timer, label + ": initialize"); - // Learn + // Learn - vector winnerCells; - Permanence permanence; + vector winnerCells; + Permanence permanence; - for (int i = 0; i < 500; i++) - { - sdr = randomSDR(numInputs, w); - vector numActiveConnectedSynapsesForSegment( + for (int i = 0; i < 500; i++) { + sdr = randomSDR(numInputs, w); + vector numActiveConnectedSynapsesForSegment( connections.segmentFlatListLength(), 0); - vector numActivePotentialSynapsesForSegment( + vector numActivePotentialSynapsesForSegment( connections.segmentFlatListLength(), 0); - connections.computeActivity(numActiveConnectedSynapsesForSegment, - numActivePotentialSynapsesForSegment, - sdr, 0.5); - winnerCells = computeSPWinnerCells(connections, numWinners, - numActiveConnectedSynapsesForSegment); - - for (CellIdx winnerCell : winnerCells) - { - segment = connections.getSegment(winnerCell, 0); - - const vector& synapses = - connections.synapsesForSegment(segment); - - for (SynapseIdx i = 0; i < (SynapseIdx)synapses.size();) - { - const Synapse synapse = synapses[i]; - const SynapseData& synapseData = connections.dataForSynapse(synapse); - permanence = synapseData.permanence; - - if (find(sdr.begin(), sdr.end(), synapseData.presynapticCell) != - sdr.end()) - { - permanence += 0.2; - } - else - { - permanence -= 0.1; - } - - permanence = max(permanence, (Permanence)0); - permanence = min(permanence, (Permanence)1); - - if (permanence == 0) - { - connections.destroySynapse(synapse); - // The synapses list is updated in-place, so don't update `i`. 
- } - else - { - connections.updateSynapsePermanence(synapse, permanence); - i++; - } + connections.computeActivity(numActiveConnectedSynapsesForSegment, + numActivePotentialSynapsesForSegment, sdr, 0.5); + winnerCells = computeSPWinnerCells(connections, numWinners, + numActiveConnectedSynapsesForSegment); + + for (CellIdx winnerCell : winnerCells) { + segment = connections.getSegment(winnerCell, 0); + + const vector &synapses = connections.synapsesForSegment(segment); + + for (SynapseIdx i = 0; i < (SynapseIdx)synapses.size();) { + const Synapse synapse = synapses[i]; + const SynapseData &synapseData = connections.dataForSynapse(synapse); + permanence = synapseData.permanence; + + if (find(sdr.begin(), sdr.end(), synapseData.presynapticCell) != + sdr.end()) { + permanence += 0.2; + } else { + permanence -= 0.1; + } + + permanence = max(permanence, (Permanence)0); + permanence = min(permanence, (Permanence)1); + + if (permanence == 0) { + connections.destroySynapse(synapse); + // The synapses list is updated in-place, so don't update `i`. + } else { + connections.updateSynapsePermanence(synapse, permanence); + i++; } } } + } - checkpoint(timer, label + ": initialize + learn"); + checkpoint(timer, label + ": initialize + learn"); - // Test + // Test - for (int i = 0; i < 500; i++) - { - sdr = randomSDR(numInputs, w); - vector numActiveConnectedSynapsesForSegment( + for (int i = 0; i < 500; i++) { + sdr = randomSDR(numInputs, w); + vector numActiveConnectedSynapsesForSegment( connections.segmentFlatListLength(), 0); - vector numActivePotentialSynapsesForSegment( + vector numActivePotentialSynapsesForSegment( connections.segmentFlatListLength(), 0); - connections.computeActivity(numActiveConnectedSynapsesForSegment, - numActivePotentialSynapsesForSegment, - sdr, 0.5); - winnerCells = computeSPWinnerCells(connections, numWinners, - numActiveConnectedSynapsesForSegment); - } - - checkpoint(timer, label + ": initialize + learn + test"); + connections.computeActivity(numActiveConnectedSynapsesForSegment, + numActivePotentialSynapsesForSegment, sdr, 0.5); + winnerCells = computeSPWinnerCells(connections, numWinners, + numActiveConnectedSynapsesForSegment); } - void ConnectionsPerformanceTest::checkpoint(clock_t timer, string text) - { - float duration = (float)(clock() - timer) / CLOCKS_PER_SEC; - cout << duration << " in " << text << endl; - } + checkpoint(timer, label + ": initialize + learn + test"); +} - vector ConnectionsPerformanceTest::randomSDR(UInt n, UInt w) - { - set sdrSet = set(); - vector sdr; +void ConnectionsPerformanceTest::checkpoint(clock_t timer, string text) { + float duration = (float)(clock() - timer) / CLOCKS_PER_SEC; + cout << duration << " in " << text << endl; +} - for (UInt i = 0; i < w; i++) - { - sdrSet.insert(rand() % (UInt)n); - } +vector ConnectionsPerformanceTest::randomSDR(UInt n, UInt w) { + set sdrSet = set(); + vector sdr; - for (UInt c : sdrSet) - { - sdr.push_back(c); - } + for (UInt i = 0; i < w; i++) { + sdrSet.insert(rand() % (UInt)n); + } - return sdr; + for (UInt c : sdrSet) { + sdr.push_back(c); } - void ConnectionsPerformanceTest::feedTM(TemporalMemory &tm, - vector sdr, - bool learn) - { - vector activeColumns; + return sdr; +} - for (auto c : sdr) - { - activeColumns.push_back(c); - } +void ConnectionsPerformanceTest::feedTM(TemporalMemory &tm, vector sdr, + bool learn) { + vector activeColumns; - tm.compute(activeColumns.size(), activeColumns.data(), learn); + for (auto c : sdr) { + activeColumns.push_back(c); } - vector 
ConnectionsPerformanceTest::computeSPWinnerCells( - Connections& connections, - UInt numCells, - const vector& numActiveSynapsesForSegment) - { - // Activate every segment, then choose the top few. - vector activeSegments; - for (Segment segment = 0; - segment < numActiveSynapsesForSegment.size(); - segment++) - { - activeSegments.push_back(segment); - } + tm.compute(activeColumns.size(), activeColumns.data(), learn); +} - set winnerCells; - std::sort(activeSegments.begin(), activeSegments.end(), - [&](Segment a, Segment b) - { - return - numActiveSynapsesForSegment[a] > - numActiveSynapsesForSegment[b]; - }); - - for (Segment segment : activeSegments) - { - winnerCells.insert(connections.cellForSegment(segment)); - if (winnerCells.size() >= numCells) - { - break; - } - } +vector ConnectionsPerformanceTest::computeSPWinnerCells( + Connections &connections, UInt numCells, + const vector &numActiveSynapsesForSegment) { + // Activate every segment, then choose the top few. + vector activeSegments; + for (Segment segment = 0; segment < numActiveSynapsesForSegment.size(); + segment++) { + activeSegments.push_back(segment); + } + + set winnerCells; + std::sort( + activeSegments.begin(), activeSegments.end(), [&](Segment a, Segment b) { + return numActiveSynapsesForSegment[a] > numActiveSynapsesForSegment[b]; + }); - return vector(winnerCells.begin(), winnerCells.end()); + for (Segment segment : activeSegments) { + winnerCells.insert(connections.cellForSegment(segment)); + if (winnerCells.size() >= numCells) { + break; + } } + return vector(winnerCells.begin(), winnerCells.end()); +} + } // end namespace nupic -int main(int argc, char *argv[]) -{ +int main(int argc, char *argv[]) { ConnectionsPerformanceTest test = ConnectionsPerformanceTest(); test.RunTests(); } diff --git a/src/test/integration/ConnectionsPerformanceTest.hpp b/src/test/integration/ConnectionsPerformanceTest.hpp index 427e0132e4..87ac181b0f 100644 --- a/src/test/integration/ConnectionsPerformanceTest.hpp +++ b/src/test/integration/ConnectionsPerformanceTest.hpp @@ -33,59 +33,46 @@ #include -namespace nupic -{ - - namespace algorithms - { - namespace temporal_memory - { - class TemporalMemory; - } - - namespace connections - { - typedef UInt32 Segment; - } - } - - class ConnectionsPerformanceTest - { - public: - ConnectionsPerformanceTest() {} - virtual ~ConnectionsPerformanceTest() {} - - // Run all appropriate tests - virtual void RunTests(); - - void testTemporalMemoryUsage(); - void testLargeTemporalMemoryUsage(); - void testSpatialPoolerUsage(); - void testTemporalPoolerUsage(); - - private: - void runTemporalMemoryTest(UInt numColumns, - UInt w, - int numSequences, - int numElements, - std::string label); - void runSpatialPoolerTest(UInt numCells, - UInt numInputs, - UInt w, - UInt numWinners, - std::string label); - - void checkpoint(clock_t timer, std::string text); - std::vector randomSDR(UInt n, UInt w); - void feedTM(algorithms::temporal_memory::TemporalMemory &tm, - std::vector sdr, - bool learn = true); - std::vector computeSPWinnerCells( - Connections& connections, - UInt numCells, - const vector& numActiveSynapsesForSegment); - - }; // end class ConnectionsPerformanceTest +namespace nupic { + +namespace algorithms { +namespace temporal_memory { +class TemporalMemory; +} + +namespace connections { +typedef UInt32 Segment; +} +} // namespace algorithms + +class ConnectionsPerformanceTest { +public: + ConnectionsPerformanceTest() {} + virtual ~ConnectionsPerformanceTest() {} + + // Run all appropriate tests + virtual 
void RunTests(); + + void testTemporalMemoryUsage(); + void testLargeTemporalMemoryUsage(); + void testSpatialPoolerUsage(); + void testTemporalPoolerUsage(); + +private: + void runTemporalMemoryTest(UInt numColumns, UInt w, int numSequences, + int numElements, std::string label); + void runSpatialPoolerTest(UInt numCells, UInt numInputs, UInt w, + UInt numWinners, std::string label); + + void checkpoint(clock_t timer, std::string text); + std::vector randomSDR(UInt n, UInt w); + void feedTM(algorithms::temporal_memory::TemporalMemory &tm, + std::vector sdr, bool learn = true); + std::vector + computeSPWinnerCells(Connections &connections, UInt numCells, + const vector &numActiveSynapsesForSegment); + +}; // end class ConnectionsPerformanceTest } // end namespace nupic diff --git a/src/test/integration/CppRegionTest.cpp b/src/test/integration/CppRegionTest.cpp index da150a0b18..17a34743de 100644 --- a/src/test/integration/CppRegionTest.cpp +++ b/src/test/integration/CppRegionTest.cpp @@ -20,79 +20,69 @@ * --------------------------------------------------------------------- */ - -#include +#include +#include #include +#include +#include #include #include -#include -#include -#include -#include #include #include -#include -#include // memory leak detection +#include #include +#include // memory leak detection #include #include +#include -#include -#include -#include // fabs/abs +#include // fabs/abs #include // exit #include #include +#include +#include bool ignore_negative_tests = false; -#define SHOULDFAIL(statement) \ - { \ - if (!ignore_negative_tests) \ - { \ - bool caughtException = false; \ - try { \ - statement; \ - } catch(std::exception& ) { \ - caughtException = true; \ - std::cout << "Caught exception as expected: " # statement "" << std::endl; \ - } \ - if (!caughtException) { \ - NTA_THROW << "Operation '" #statement "' did not fail as expected"; \ - } \ - } \ +#define SHOULDFAIL(statement) \ + { \ + if (!ignore_negative_tests) { \ + bool caughtException = false; \ + try { \ + statement; \ + } catch (std::exception &) { \ + caughtException = true; \ + std::cout << "Caught exception as expected: " #statement "" \ + << std::endl; \ + } \ + if (!caughtException) { \ + NTA_THROW << "Operation '" #statement "' did not fail as expected"; \ + } \ + } \ } using namespace nupic; bool verbose = false; -struct MemoryMonitor -{ - MemoryMonitor() - { - OS::getProcessMemoryUsage(initial_vmem, initial_rmem); - } +struct MemoryMonitor { + MemoryMonitor() { OS::getProcessMemoryUsage(initial_vmem, initial_rmem); } - ~MemoryMonitor() - { - if (hasMemoryLeaks()) - { - NTA_DEBUG - << "Memory leaks detected. " - << "Real Memory: " << diff_rmem - << ", Virtual Memory: " << diff_vmem; + ~MemoryMonitor() { + if (hasMemoryLeaks()) { + NTA_DEBUG << "Memory leaks detected. 
" + << "Real Memory: " << diff_rmem + << ", Virtual Memory: " << diff_vmem; } } - void update() - { + void update() { OS::getProcessMemoryUsage(current_vmem, current_rmem); diff_vmem = current_vmem - initial_vmem; diff_rmem = current_rmem - initial_rmem; } - bool hasMemoryLeaks() - { + bool hasMemoryLeaks() { update(); return diff_vmem > 0 || diff_rmem > 0; } @@ -105,29 +95,26 @@ struct MemoryMonitor size_t diff_vmem; }; - -void testCppInputOutputAccess(Region * level1) -{ +void testCppInputOutputAccess(Region *level1) { // --- input/output access for level 1 (C++ TestNode) --- - SHOULDFAIL( level1->getOutputData("doesnotexist") ); + SHOULDFAIL(level1->getOutputData("doesnotexist")); // getting access via zero-copy std::cout << "Getting output for zero-copy access" << std::endl; ArrayRef output = level1->getOutputData("bottomUpOut"); - std::cout << "Element count in bottomUpOut is " << output.getCount() << "" << std::endl; - Real64 *data_actual = (Real64*)output.getBuffer(); + std::cout << "Element count in bottomUpOut is " << output.getCount() << "" + << std::endl; + Real64 *data_actual = (Real64 *)output.getBuffer(); // set the actual output - data_actual[12] = 54321; + data_actual[12] = 54321; } - -void testCppLinking(std::string linkPolicy, std::string linkParams) -{ +void testCppLinking(std::string linkPolicy, std::string linkParams) { Network net = Network(); - Region* region1 = net.addRegion("region1", "TestNode", ""); - Region* region2 = net.addRegion("region2", "TestNode", ""); + Region *region1 = net.addRegion("region1", "TestNode", ""); + Region *region2 = net.addRegion("region2", "TestNode", ""); net.link("region1", "region2", linkPolicy, linkParams); std::cout << "Initialize should fail..." << std::endl; @@ -142,64 +129,58 @@ void testCppLinking(std::string linkPolicy, std::string linkParams) std::cout << "Initialize should now succeed" << std::endl; net.initialize(); - const Dimensions& r2dims = region2->getDimensions(); + const Dimensions &r2dims = region2->getDimensions(); NTA_CHECK(r2dims.size() == 2) << " actual dims: " << r2dims.toString(); NTA_CHECK(r2dims[0] == 3) << " actual dims: " << r2dims.toString(); NTA_CHECK(r2dims[1] == 2) << " actual dims: " << r2dims.toString(); SHOULDFAIL(region2->setDimensions(r1dims)); - + ArrayRef r1OutputArray = region1->getOutputData("bottomUpOut"); region1->compute(); std::cout << "Checking region1 output after first iteration..." 
<< std::endl; - Real64 *buffer = (Real64*) r1OutputArray.getBuffer(); + Real64 *buffer = (Real64 *)r1OutputArray.getBuffer(); - for (size_t i = 0; i < r1OutputArray.getCount(); i++) - { + for (size_t i = 0; i < r1OutputArray.getCount(); i++) { if (verbose) std::cout << " " << i << " " << buffer[i] << "" << std::endl; - if (i%2 == 0) + if (i % 2 == 0) NTA_CHECK(buffer[i] == 0); else - NTA_CHECK(buffer[i] == (i-1)/2); + NTA_CHECK(buffer[i] == (i - 1) / 2); } - region2->prepareInputs(); ArrayRef r2InputArray = region2->getInputData("bottomUpIn"); std::cout << "Region 2 input after first iteration:" << std::endl; - Real64 *buffer2 = (Real64*) r2InputArray.getBuffer(); + Real64 *buffer2 = (Real64 *)r2InputArray.getBuffer(); NTA_CHECK(buffer != buffer2); - for (size_t i = 0; i < r2InputArray.getCount(); i++) - { + for (size_t i = 0; i < r2InputArray.getCount(); i++) { if (verbose) std::cout << " " << i << " " << buffer2[i] << "" << std::endl; - if (i%2 == 0) + if (i % 2 == 0) NTA_CHECK(buffer[i] == 0); else - NTA_CHECK(buffer[i] == (i-1)/2); + NTA_CHECK(buffer[i] == (i - 1) / 2); } std::cout << "Region 2 input by node" << std::endl; std::vector r2NodeInput; - for (size_t node = 0; node < 6; node++) - { + for (size_t node = 0; node < 6; node++) { region2->getInput("bottomUpIn")->getInputForNode(node, r2NodeInput); - if (verbose) - { + if (verbose) { std::cout << "Node " << node << ": "; - for (auto & elem : r2NodeInput) - { + for (auto &elem : r2NodeInput) { std::cout << elem << " "; } std::cout << "" << std::endl; } // 4 nodes in r1 fan in to 1 node in r2 - int row = node/3; + int row = node / 3; int col = node - (row * 3); NTA_CHECK(r2NodeInput.size() == 8); NTA_CHECK(r2NodeInput[0] == 0); @@ -207,28 +188,26 @@ void testCppLinking(std::string linkPolicy, std::string linkParams) NTA_CHECK(r2NodeInput[4] == 0); NTA_CHECK(r2NodeInput[6] == 0); // these values are specific to the fanin2 link policy - NTA_CHECK(r2NodeInput[1] == row * 12 + col * 2) - << "row: " << row << " col: " << col << " val: " << r2NodeInput[1]; - NTA_CHECK(r2NodeInput[3] == row * 12 + col * 2 + 1) - << "row: " << row << " col: " << col << " val: " << r2NodeInput[3]; + NTA_CHECK(r2NodeInput[1] == row * 12 + col * 2) + << "row: " << row << " col: " << col << " val: " << r2NodeInput[1]; + NTA_CHECK(r2NodeInput[3] == row * 12 + col * 2 + 1) + << "row: " << row << " col: " << col << " val: " << r2NodeInput[3]; NTA_CHECK(r2NodeInput[5] == row * 12 + 6 + col * 2) - << "row: " << row << " col: " << col << " val: " << r2NodeInput[5]; + << "row: " << row << " col: " << col << " val: " << r2NodeInput[5]; NTA_CHECK(r2NodeInput[7] == row * 12 + 6 + col * 2 + 1) - << "row: " << row << " col: " << col << " val: " << r2NodeInput[7]; + << "row: " << row << " col: " << col << " val: " << r2NodeInput[7]; } } -void testYAML() -{ +void testYAML() { const char *params = "{int32Param: 1234, real64Param: 23.1}"; // badparams contains a non-existent parameter const char *badparams = "{int32Param: 1234, real64Param: 23.1, badParam: 4}"; - Network net = Network(); - Region* level1; + Region *level1; SHOULDFAIL(level1 = net.addRegion("level1", "TestNode", badparams);); - + level1 = net.addRegion("level1", "TestNode", params); Dimensions d; d.push_back(1); @@ -237,43 +216,45 @@ void testYAML() // check default values Real32 r32val = level1->getParameterReal32("real32Param"); - NTA_CHECK(::fabs(r32val - 32.1) < 0.00001) << "r32val = " << r32val << " diff = " << (r32val - 32.1); + NTA_CHECK(::fabs(r32val - 32.1) < 0.00001) + << "r32val = " << r32val << " 
diff = " << (r32val - 32.1); Int64 i64val = level1->getParameterInt64("int64Param"); NTA_CHECK(i64val == 64) << "i64val = " << i64val; - // check values set in region constructor Int32 ival = level1->getParameterInt32("int32Param"); NTA_CHECK(ival == 1234) << "ival = " << ival; Real64 rval = level1->getParameterReal64("real64Param"); NTA_CHECK(::fabs(rval - 23.1) < 0.00000000001) << "rval = " << rval; - // TODO: if we get the real64 param with getParameterInt32 + // TODO: if we get the real64 param with getParameterInt32 // it works -- should we flag an error? - std::cout << "Got the correct values for all parameters set at region creation" << std::endl; - + std::cout + << "Got the correct values for all parameters set at region creation" + << std::endl; } -int realmain(bool leakTest) -{ +int realmain(bool leakTest) { // verbose == true turns on extra output that is useful for - // debugging the test (e.g. when the TestNode compute() + // debugging the test (e.g. when the TestNode compute() // algorithm changes) - std::cout << "Creating network..." << std::endl; Network n; - - std::cout << "Region count is " << n.getRegions().getCount() << "" << std::endl; + + std::cout << "Region count is " << n.getRegions().getCount() << "" + << std::endl; std::cout << "Adding a FDRNode region..." << std::endl; - Region* level1 = n.addRegion("level1", "TestNode", ""); + Region *level1 = n.addRegion("level1", "TestNode", ""); - std::cout << "Region count is " << n.getRegions().getCount() << "" << std::endl; + std::cout << "Region count is " << n.getRegions().getCount() << "" + << std::endl; std::cout << "Node type: " << level1->getType() << "" << std::endl; - std::cout << "Nodespec is:\n" << level1->getSpec()->toString() << "" << std::endl; - + std::cout << "Nodespec is:\n" + << level1->getSpec()->toString() << "" << std::endl; + Int64 val; Real64 rval; std::string int64Param("int64Param"); @@ -283,18 +264,20 @@ int realmain(bool leakTest) rval = level1->getParameterReal64(real64Param); std::cout << "level1.int64Param = " << val << "" << std::endl; std::cout << "level1.real64Param = " << rval << "" << std::endl; - + val = 20; level1->setParameterInt64(int64Param, val); val = 0; val = level1->getParameterInt64(int64Param); - std::cout << "level1.int64Param = " << val << " after setting to 20" << std::endl; + std::cout << "level1.int64Param = " << val << " after setting to 20" + << std::endl; rval = 30.1; level1->setParameterReal64(real64Param, rval); rval = 0.0; rval = level1->getParameterReal64(real64Param); - std::cout << "level1.real64Param = " << rval << " after setting to 30.1" << std::endl; + std::cout << "level1.real64Param = " << rval << " after setting to 30.1" + << std::endl; // --- test getParameterInt64Array --- // Array a is not allocated by us. 
Will be allocated inside getParameter @@ -302,23 +285,23 @@ int realmain(bool leakTest) level1->getParameterArray("int64ArrayParam", a); std::cout << "level1.int64ArrayParam size = " << a.getCount() << std::endl; std::cout << "level1.int64ArrayParam = [ "; - Int64 * buff = (Int64 *)a.getBuffer(); + Int64 *buff = (Int64 *)a.getBuffer(); for (int i = 0; i < int(a.getCount()); ++i) std::cout << buff[i] << " "; std::cout << "]" << std::endl; - + // --- test setParameterInt64Array --- std::cout << "Setting level1.int64ArrayParam to [ 1 2 3 4 ]" << std::endl; std::vector v(4); for (int i = 0; i < 4; ++i) - v[i] = i+1 ; + v[i] = i + 1; Array newa(NTA_BasicType_Int64, &v[0], v.size()); level1->setParameterArray("int64ArrayParam", newa); // get the value of intArrayParam after the setParameter call. // The array a owns its buffer, so we can call releaseBuffer if we - // want, but the buffer should be reused if we just pass it again. + // want, but the buffer should be reused if we just pass it again. // a.releaseBuffer(); level1->getParameterArray("int64ArrayParam", a); std::cout << "level1.int64ArrayParam size = " << a.getCount() << std::endl; @@ -332,7 +315,7 @@ int realmain(bool leakTest) SHOULDFAIL(n.run(1)); // should fail because network can't be initialized - SHOULDFAIL (n.initialize() ); + SHOULDFAIL(n.initialize()); std::cout << "Setting dimensions of level1..." << std::endl; Dimensions d; @@ -340,7 +323,6 @@ int realmain(bool leakTest) d.push_back(4); level1->setDimensions(d); - std::cout << "Initializing again..." << std::endl; n.initialize(); @@ -348,7 +330,7 @@ int realmain(bool leakTest) testCppInputOutputAccess(level1); testCppLinking("TestFanIn2", ""); - testCppLinking("UniformLink","{mapping: in, rfSize: [2]}"); + testCppLinking("UniformLink", "{mapping: in, rfSize: [2]}"); testYAML(); std::cout << "Done -- all tests passed" << std::endl; @@ -356,28 +338,26 @@ int realmain(bool leakTest) return 0; } -int main(int argc, char *argv[]) -{ - - /* +int main(int argc, char *argv[]) { + + /* * Without arguments, this program is a simple end-to-end demo - * of NuPIC 2 functionality, used as a developer tool (when - * we add a feature, we add it to this program. + * of NuPIC 2 functionality, used as a developer tool (when + * we add a feature, we add it to this program. * With an integer argument N, runs the same test N times * and requires that memory use stay constant -- it can't - * grow by even one byte. + * grow by even one byte. */ // TODO: real argument parsing - // Optional arg is number of iterations to do. + // Optional arg is number of iterations to do. NTA_CHECK(argc == 1 || argc == 2); size_t count = 1; - if (argc == 2) - { + if (argc == 2) { std::stringstream ss(argv[1]); ss >> count; } - // Start checking memory usage after this many iterations. + // Start checking memory usage after this many iterations. #if defined(NTA_OS_WINDOWS) // takes longer to settle down on win32 size_t memoryLeakStartIter = 6000; @@ -385,76 +365,70 @@ int main(int argc, char *argv[]) size_t memoryLeakStartIter = 150; #endif - // This determines how frequently we check. + // This determines how frequently we check. 
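// --- Editorial note, not part of this patch -------------------------------
// Putting the numbers above and below together: memory usage is sampled every
// memoryLeakDeltaIterCheck (10) iterations, the baseline is recorded at
// iteration memoryLeakStartIter (150, or 6000 on Windows), and a
// leak-detection run must therefore request at least
// memoryLeakStartIter + 5 * memoryLeakDeltaIterCheck iterations, i.e. 200
// (or 6050 on Windows), to get a few checks past the baseline.
// --------------------------------------------------------------------------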
size_t memoryLeakDeltaIterCheck = 10; size_t minCount = memoryLeakStartIter + 5 * memoryLeakDeltaIterCheck; - if (count > 1 && count < minCount) - { + if (count > 1 && count < minCount) { std::cout << "Run count of " << count << " specified\n"; - std::cout << "When run in leak detection mode, count must be at least " << minCount << "\n"; + std::cout << "When run in leak detection mode, count must be at least " + << minCount << "\n"; ::exit(1); } - size_t initial_vmem = 0; size_t initial_rmem = 0; size_t current_vmem = 0; size_t current_rmem = 0; try { - for (size_t i = 0; i < count; i++) - { - //MemoryMonitor m; + for (size_t i = 0; i < count; i++) { + // MemoryMonitor m; NuPIC::init(); realmain(count > 1); - //testExceptionBug(); - //testCppLinking("TestFanIn2",""); + // testExceptionBug(); + // testCppLinking("TestFanIn2",""); NuPIC::shutdown(); // memory leak detection // we check even prior to the initial tracking iteration, because the act // of checking potentially modifies our memory usage - if (i % memoryLeakDeltaIterCheck == 0) - { + if (i % memoryLeakDeltaIterCheck == 0) { OS::getProcessMemoryUsage(current_rmem, current_vmem); - if(i == memoryLeakStartIter) - { + if (i == memoryLeakStartIter) { initial_rmem = current_rmem; initial_vmem = current_vmem; } - std::cout << "Memory usage: " << current_vmem << " (virtual) " + std::cout << "Memory usage: " << current_vmem << " (virtual) " << current_rmem << " (real) at iteration " << i << std::endl; - if(i >= memoryLeakStartIter) - { - if (current_vmem > initial_vmem || current_rmem > initial_rmem) - { + if (i >= memoryLeakStartIter) { + if (current_vmem > initial_vmem || current_rmem > initial_rmem) { std::cout << "Tracked memory usage (iteration " << memoryLeakStartIter << "): " << initial_vmem - << " (virtual) " << initial_rmem << " (real)" << std::endl; + << " (virtual) " << initial_rmem << " (real)" + << std::endl; throw std::runtime_error("Memory leak detected"); } } } } - } catch (nupic::Exception& e) { - std::cout - << "Exception: " << e.getMessage() - << " at: " << e.getFilename() << ":" << e.getLineNumber() - << std::endl; + } catch (nupic::Exception &e) { + std::cout << "Exception: " << e.getMessage() << " at: " << e.getFilename() + << ":" << e.getLineNumber() << std::endl; return 1; - } catch (std::exception& e) { + } catch (std::exception &e) { std::cout << "Exception: " << e.what() << "" << std::endl; return 1; - } - catch (...) { - std::cout << "\nHtmTest is exiting because an exception was thrown" << std::endl; + } catch (...) { + std::cout << "\nHtmTest is exiting because an exception was thrown" + << std::endl; return 1; } if (count > 20) - std::cout << "Memory leak check passed -- " << count << " iterations" << std::endl; + std::cout << "Memory leak check passed -- " << count << " iterations" + << std::endl; std::cout << "--- ALL TESTS PASSED ---" << std::endl; return 0; diff --git a/src/test/integration/PyRegionTest.cpp b/src/test/integration/PyRegionTest.cpp index ab4652bfaa..f83a6c5949 100644 --- a/src/test/integration/PyRegionTest.cpp +++ b/src/test/integration/PyRegionTest.cpp @@ -22,86 +22,76 @@ /* This file is similar to CppRegionTest except that it also tests Python nodes. - It is build in nupic.core but tested in nupic. So its execution and README instructions - remains in nupic. + It is build in nupic.core but tested in nupic. So its execution and README + instructions remains in nupic. 
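// --- Editorial note, not part of this patch -------------------------------
// Pattern exercised by this test file: Python-implemented regions are
// addressed through the normal Network API with a "py." type prefix
// (e.g. net.addRegion("r2", "py.TestNode", "")), and custom region classes
// are first made visible to the C++ core via
// Network::registerPyRegion("<python module>", "<class name>"), as done in
// testRegionDuplicateRegister() further down.
// --------------------------------------------------------------------------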
*/ - -#include +#include +#include #include +#include +#include #include #include -#include -#include -#include -#include #include #include -#include -#include -#include // memory leak detection +#include #include +#include // memory leak detection #include #include +#include +#include -#include -#include -#include // fabs/abs +#include // fabs/abs #include // exit #include #include +#include +#include #include bool ignore_negative_tests = false; -#define SHOULDFAIL(statement) \ - { \ - if (!ignore_negative_tests) \ - { \ - bool caughtException = false; \ - try { \ - statement; \ - } catch(std::exception& ) { \ - caughtException = true; \ - std::cout << "Caught exception as expected: " # statement "" << std::endl; \ - } \ - if (!caughtException) { \ - NTA_THROW << "Operation '" #statement "' did not fail as expected"; \ - } \ - } \ +#define SHOULDFAIL(statement) \ + { \ + if (!ignore_negative_tests) { \ + bool caughtException = false; \ + try { \ + statement; \ + } catch (std::exception &) { \ + caughtException = true; \ + std::cout << "Caught exception as expected: " #statement "" \ + << std::endl; \ + } \ + if (!caughtException) { \ + NTA_THROW << "Operation '" #statement "' did not fail as expected"; \ + } \ + } \ } using namespace nupic; bool verbose = false; -struct MemoryMonitor -{ - MemoryMonitor() - { - OS::getProcessMemoryUsage(initial_vmem, initial_rmem); - } +struct MemoryMonitor { + MemoryMonitor() { OS::getProcessMemoryUsage(initial_vmem, initial_rmem); } - ~MemoryMonitor() - { - if (hasMemoryLeaks()) - { - NTA_DEBUG - << "Memory leaks detected. " - << "Real Memory: " << diff_rmem - << ", Virtual Memory: " << diff_vmem; + ~MemoryMonitor() { + if (hasMemoryLeaks()) { + NTA_DEBUG << "Memory leaks detected. " + << "Real Memory: " << diff_rmem + << ", Virtual Memory: " << diff_vmem; } } - void update() - { + void update() { OS::getProcessMemoryUsage(current_vmem, current_rmem); diff_vmem = current_vmem - initial_vmem; diff_rmem = current_rmem - initial_rmem; } - bool hasMemoryLeaks() - { + bool hasMemoryLeaks() { update(); return diff_vmem > 0 || diff_rmem > 0; } @@ -114,29 +104,27 @@ struct MemoryMonitor size_t diff_vmem; }; - -void testPynodeInputOutputAccess(Region * level2) -{ +void testPynodeInputOutputAccess(Region *level2) { // --- input/output access for level 2 (Python py.TestNode) --- - SHOULDFAIL(level2->getOutputData("doesnotexist") ); + SHOULDFAIL(level2->getOutputData("doesnotexist")); // getting access via zero-copy std::cout << "Getting output for zero-copy access" << std::endl; ArrayRef output = level2->getOutputData("bottomUpOut"); - std::cout << "Element count in bottomUpOut is " << output.getCount() << "" << std::endl; - Real64 *data_actual = (Real64*)output.getBuffer(); + std::cout << "Element count in bottomUpOut is " << output.getCount() << "" + << std::endl; + Real64 *data_actual = (Real64 *)output.getBuffer(); // set the actual output data_actual[12] = 54321; } -void testPynodeArrayParameters(Region * level2) -{ +void testPynodeArrayParameters(Region *level2) { // Array a is not allocated by us. 
Will be allocated inside getParameter Array a(NTA_BasicType_Int64); level2->getParameterArray("int64ArrayParam", a); std::cout << "level2.int64ArrayParam size = " << a.getCount() << std::endl; std::cout << "level2.int64ArrayParam = [ "; - Int64 * buff = (Int64 *)a.getBuffer(); + Int64 *buff = (Int64 *)a.getBuffer(); for (int i = 0; i < int(a.getCount()); ++i) std::cout << buff[i] << " "; std::cout << "]" << std::endl; @@ -145,7 +133,7 @@ void testPynodeArrayParameters(Region * level2) std::cout << "Setting level2.int64ArrayParam to [ 1 2 3 4 ]" << std::endl; std::vector v(4); for (int i = 0; i < 4; ++i) - v[i] = i+1 ; + v[i] = i + 1; Array newa(NTA_BasicType_Int64, &v[0], v.size()); level2->setParameterArray("int64ArrayParam", newa); @@ -161,13 +149,11 @@ void testPynodeArrayParameters(Region * level2) std::cout << "]" << std::endl; } - -void testPynodeLinking() -{ +void testPynodeLinking() { Network net = Network(); - Region * region1 = net.addRegion("region1", "TestNode", ""); - Region * region2 = net.addRegion("region2", "py.TestNode", ""); + Region *region1 = net.addRegion("region1", "TestNode", ""); + Region *region2 = net.addRegion("region2", "py.TestNode", ""); std::cout << "Linking region 1 to region 2" << std::endl; net.link("region1", "region2", "TestFanIn2", ""); @@ -180,7 +166,7 @@ void testPynodeLinking() std::cout << "Initializing network..." << std::endl; net.initialize(); - const Dimensions& r2dims = region2->getDimensions(); + const Dimensions &r2dims = region2->getDimensions(); NTA_CHECK(r2dims.size() == 2) << " actual dims: " << r2dims.toString(); NTA_CHECK(r2dims[0] == 3) << " actual dims: " << r2dims.toString(); NTA_CHECK(r2dims[1] == 2) << " actual dims: " << r2dims.toString(); @@ -190,52 +176,47 @@ void testPynodeLinking() region1->compute(); std::cout << "Checking region1 output after first iteration..." 
<< std::endl; - Real64 *buffer = (Real64*) r1OutputArray.getBuffer(); + Real64 *buffer = (Real64 *)r1OutputArray.getBuffer(); - for (size_t i = 0; i < r1OutputArray.getCount(); i++) - { + for (size_t i = 0; i < r1OutputArray.getCount(); i++) { if (verbose) std::cout << " " << i << " " << buffer[i] << "" << std::endl; - if (i%2 == 0) + if (i % 2 == 0) NTA_CHECK(buffer[i] == 0); else - NTA_CHECK(buffer[i] == (i-1)/2); + NTA_CHECK(buffer[i] == (i - 1) / 2); } region2->prepareInputs(); ArrayRef r2InputArray = region2->getInputData("bottomUpIn"); std::cout << "Region 2 input after first iteration:" << std::endl; - Real64 *buffer2 = (Real64*) r2InputArray.getBuffer(); + Real64 *buffer2 = (Real64 *)r2InputArray.getBuffer(); NTA_CHECK(buffer != buffer2); - for (size_t i = 0; i < r2InputArray.getCount(); i++) - { + for (size_t i = 0; i < r2InputArray.getCount(); i++) { if (verbose) std::cout << " " << i << " " << buffer2[i] << "" << std::endl; - if (i%2 == 0) + if (i % 2 == 0) NTA_CHECK(buffer[i] == 0); else - NTA_CHECK(buffer[i] == (i-1)/2); + NTA_CHECK(buffer[i] == (i - 1) / 2); } std::cout << "Region 2 input by node" << std::endl; std::vector r2NodeInput; - for (size_t node = 0; node < 6; node++) - { + for (size_t node = 0; node < 6; node++) { region2->getInput("bottomUpIn")->getInputForNode(node, r2NodeInput); - if (verbose) - { + if (verbose) { std::cout << "Node " << node << ": "; - for (size_t i = 0; i < r2NodeInput.size(); i++) - { + for (size_t i = 0; i < r2NodeInput.size(); i++) { std::cout << r2NodeInput[i] << " "; } std::cout << "" << std::endl; } // 4 nodes in r1 fan in to 1 node in r2 - int row = node/3; + int row = node / 3; int col = node - (row * 3); NTA_CHECK(r2NodeInput.size() == 8); NTA_CHECK(r2NodeInput[0] == 0); @@ -243,66 +224,62 @@ void testPynodeLinking() NTA_CHECK(r2NodeInput[4] == 0); NTA_CHECK(r2NodeInput[6] == 0); // these values are specific to the fanin2 link policy - NTA_CHECK(r2NodeInput[1] == row * 12 + col * 2) - << "row: " << row << " col: " << col << " val: " << r2NodeInput[1]; - NTA_CHECK(r2NodeInput[3] == row * 12 + col * 2 + 1) - << "row: " << row << " col: " << col << " val: " << r2NodeInput[3]; + NTA_CHECK(r2NodeInput[1] == row * 12 + col * 2) + << "row: " << row << " col: " << col << " val: " << r2NodeInput[1]; + NTA_CHECK(r2NodeInput[3] == row * 12 + col * 2 + 1) + << "row: " << row << " col: " << col << " val: " << r2NodeInput[3]; NTA_CHECK(r2NodeInput[5] == row * 12 + 6 + col * 2) - << "row: " << row << " col: " << col << " val: " << r2NodeInput[5]; + << "row: " << row << " col: " << col << " val: " << r2NodeInput[5]; NTA_CHECK(r2NodeInput[7] == row * 12 + 6 + col * 2 + 1) - << "row: " << row << " col: " << col << " val: " << r2NodeInput[7]; + << "row: " << row << " col: " << col << " val: " << r2NodeInput[7]; } region2->compute(); } -void testSecondTimeLeak() -{ +void testSecondTimeLeak() { Network n; n.addRegion("r1", "py.TestNode", ""); n.addRegion("r2", "py.TestNode", ""); } -void testRegionDuplicateRegister() -{ +void testRegionDuplicateRegister() { // Register a region Network::registerPyRegion("nupic.regions.TestDuplicateNodes", "TestDuplicateNodes"); // Validate that the same region can be registered multiple times - try - { + try { Network::registerPyRegion("nupic.regions.TestDuplicateNodes", "TestDuplicateNodes"); - } catch (std::exception& e) { + } catch (std::exception &e) { NTA_THROW << "testRegionDuplicateRegister failed with exception: '" << e.what() << "'"; } // Validate that a region from a different module but with the same name // 
cannot be registered - try - { + try { Network::registerPyRegion("nupic.regions.DifferentModule", "TestDuplicateNodes"); NTA_THROW << "testRegionDuplicateRegister failed to throw exception for " << "region with same name but different module as existing " << "registered region"; - } catch (std::exception& e) { + } catch (std::exception &e) { } } -void testCreationParamTypes() -{ +void testCreationParamTypes() { // Verify that parameters of all types can be passed in through the creation // params. Network n; - Region* region = n.addRegion("test", "py.TestNode", - "{" - "int32Param: -2000000000, uint32Param: 3000000000, " - "int64Param: -5000000000, uint64Param: 5000000001, " - "real32Param: 10.5, real64Param: 11.5, " - "boolParam: true" - "}"); + Region *region = + n.addRegion("test", "py.TestNode", + "{" + "int32Param: -2000000000, uint32Param: 3000000000, " + "int64Param: -5000000000, uint64Param: 5000000001, " + "real32Param: 10.5, real64Param: 11.5, " + "boolParam: true" + "}"); NTA_CHECK(region->getParameterInt32("int32Param") == -2000000000); NTA_CHECK(region->getParameterUInt32("uint32Param") == 3000000000); @@ -313,32 +290,27 @@ void testCreationParamTypes() NTA_CHECK(region->getParameterBool("boolParam") == true); } -void testUnregisterRegion() -{ +void testUnregisterRegion() { Network n; n.addRegion("test", "py.TestNode", ""); Network::unregisterPyRegion("TestNode"); bool caughtException = false; - try - { + try { n.addRegion("test", "py.TestNode", ""); - } catch (std::exception& e) { + } catch (std::exception &e) { NTA_DEBUG << "Caught exception as expected: '" << e.what() << "'"; caughtException = true; } - if (caughtException) - { + if (caughtException) { NTA_DEBUG << "testUnregisterRegion passed"; } else { NTA_THROW << "testUnregisterRegion did not throw an exception as expected"; } - } -void testWriteRead() -{ +void testWriteRead() { Int32 int32Param = 42; UInt32 uint32Param = 43; Int64 int64Param = 44; @@ -349,34 +321,27 @@ void testWriteRead() std::string stringParam = "hello"; std::vector int64ArrayParamBuff(4); - for (int i = 0; i < 4; i++) - { + for (int i = 0; i < 4; i++) { int64ArrayParamBuff[i] = i + 1; } - Array int64ArrayParam(NTA_BasicType_Int64, - &int64ArrayParamBuff[0], + Array int64ArrayParam(NTA_BasicType_Int64, &int64ArrayParamBuff[0], int64ArrayParamBuff.size()); std::vector real32ArrayParamBuff(4); - for (int i = 0; i < 4; i++) - { + for (int i = 0; i < 4; i++) { real32ArrayParamBuff[i] = i + 1; } - Array real32ArrayParam(NTA_BasicType_Real32, - &real32ArrayParamBuff[0], + Array real32ArrayParam(NTA_BasicType_Real32, &real32ArrayParamBuff[0], real32ArrayParamBuff.size()); bool boolArrayParamBuff[4]; - for (int i = 0; i < 4; i++) - { + for (int i = 0; i < 4; i++) { boolArrayParamBuff[i] = (i % 2) == 1; } - Array boolArrayParam(NTA_BasicType_Bool, - boolArrayParamBuff, - 4); + Array boolArrayParam(NTA_BasicType_Bool, boolArrayParamBuff, 4); Network n1; - Region* region1 = n1.addRegion("rw1", "py.TestNode", ""); + Region *region1 = n1.addRegion("rw1", "py.TestNode", ""); region1->setParameterInt32("int32Param", int32Param); region1->setParameterUInt32("uint32Param", uint32Param); region1->setParameterInt64("int64Param", int64Param); @@ -395,9 +360,9 @@ void testWriteRead() n1.write(ss); n2.read(ss); - const Collection& regions = n2.getRegions(); - const std::pair& regionPair = regions.getByIndex(0); - Region* region2 = regionPair.second; + const Collection ®ions = n2.getRegions(); + const std::pair ®ionPair = regions.getByIndex(0); + Region *region2 = 
regionPair.second; NTA_CHECK(region2->getParameterInt32("int32Param") == int32Param); NTA_CHECK(region2->getParameterUInt32("uint32Param") == uint32Param); @@ -410,51 +375,49 @@ void testWriteRead() Array int64Array(NTA_BasicType_Int64); region2->getParameterArray("int64ArrayParam", int64Array); - Int64 * int64ArrayBuff = (Int64 *)int64Array.getBuffer(); + Int64 *int64ArrayBuff = (Int64 *)int64Array.getBuffer(); NTA_CHECK(int64ArrayParam.getCount() == int64Array.getCount()); - for (int i = 0; i < int(int64ArrayParam.getCount()); i++) - { + for (int i = 0; i < int(int64ArrayParam.getCount()); i++) { NTA_CHECK(int64ArrayBuff[i] == int64ArrayParamBuff[i]); } Array real32Array(NTA_BasicType_Real32); region2->getParameterArray("real32ArrayParam", real32Array); - Real32 * real32ArrayBuff = (Real32 *)real32Array.getBuffer(); + Real32 *real32ArrayBuff = (Real32 *)real32Array.getBuffer(); NTA_CHECK(real32ArrayParam.getCount() == real32Array.getCount()); - for (int i = 0; i < int(real32ArrayParam.getCount()); i++) - { + for (int i = 0; i < int(real32ArrayParam.getCount()); i++) { NTA_CHECK(real32ArrayBuff[i] == real32ArrayParamBuff[i]); } Array boolArray(NTA_BasicType_Bool); region2->getParameterArray("boolArrayParam", boolArray); - bool * boolArrayBuff = (bool *)boolArray.getBuffer(); + bool *boolArrayBuff = (bool *)boolArray.getBuffer(); NTA_CHECK(boolArrayParam.getCount() == boolArray.getCount()); - for (int i = 0; i < int(boolArrayParam.getCount()); i++) - { + for (int i = 0; i < int(boolArrayParam.getCount()); i++) { NTA_CHECK(boolArrayBuff[i] == boolArrayParamBuff[i]); } } -int realmain(bool leakTest) -{ +int realmain(bool leakTest) { // verbose == true turns on extra output that is useful for // debugging the test (e.g. when the TestNode compute() // algorithm changes) - std::cout << "Creating network..." << std::endl; Network n; - std::cout << "Region count is " << n.getRegions().getCount() << "" << std::endl; + std::cout << "Region count is " << n.getRegions().getCount() << "" + << std::endl; std::cout << "Adding a PyNode region..." << std::endl; Network::registerPyRegion("nupic.bindings.regions.TestNode", "TestNode"); - Region* level2 = n.addRegion("level2", "py.TestNode", "{int32Param: 444}"); + Region *level2 = n.addRegion("level2", "py.TestNode", "{int32Param: 444}"); - std::cout << "Region count is " << n.getRegions().getCount() << "" << std::endl; + std::cout << "Region count is " << n.getRegions().getCount() << "" + << std::endl; std::cout << "Node type: " << level2->getType() << "" << std::endl; - std::cout << "Nodespec is:\n" << level2->getSpec()->toString() << "" << std::endl; + std::cout << "Nodespec is:\n" + << level2->getSpec()->toString() << "" << std::endl; Real64 rval; std::string int64Param("int64Param"); @@ -476,14 +439,13 @@ int realmain(bool leakTest) SHOULDFAIL(n.run(1)); // should fail because network can't be initialized - SHOULDFAIL (n.initialize() ); + SHOULDFAIL(n.initialize()); std::cout << "Setting dimensions of level1..." << std::endl; Dimensions d; d.push_back(4); d.push_back(4); - std::cout << "Setting dimensions of level2..." 
<< std::endl; level2->setDimensions(d); @@ -496,10 +458,9 @@ int realmain(bool leakTest) testRegionDuplicateRegister(); testCreationParamTypes(); - if (!leakTest) - { - //testNuPIC1x(); - //testPynode1xLinking(); + if (!leakTest) { + // testNuPIC1x(); + // testPynode1xLinking(); } #if !CAPNP_LITE // PyRegion::write is implemented only when nupic.core is compiled with @@ -516,8 +477,7 @@ int realmain(bool leakTest) return 0; } -int main(int argc, char *argv[]) -{ +int main(int argc, char *argv[]) { // This isn't running inside one of the SWIG modules, so we need to // initialize the numpy C API. Py_Initialize(); @@ -537,8 +497,7 @@ int main(int argc, char *argv[]) // Optional arg is number of iterations to do. NTA_CHECK(argc == 1 || argc == 2); size_t count = 1; - if (argc == 2) - { + if (argc == 2) { std::stringstream ss(argv[1]); ss >> count; } @@ -555,75 +514,69 @@ int main(int argc, char *argv[]) size_t minCount = memoryLeakStartIter + 5 * memoryLeakDeltaIterCheck; - if (count > 1 && count < minCount) - { + if (count > 1 && count < minCount) { std::cout << "Run count of " << count << " specified\n"; - std::cout << "When run in leak detection mode, count must be at least " << minCount << "\n"; + std::cout << "When run in leak detection mode, count must be at least " + << minCount << "\n"; ::exit(1); } - size_t initial_vmem = 0; size_t initial_rmem = 0; size_t current_vmem = 0; size_t current_rmem = 0; try { - for (size_t i = 0; i < count; i++) - { - //MemoryMonitor m; + for (size_t i = 0; i < count; i++) { + // MemoryMonitor m; NuPIC::init(); realmain(count > 1); - //testExceptionBug(); - //testPynode1xLinking(); + // testExceptionBug(); + // testPynode1xLinking(); // testNuPIC1x(); - //testSecondTimeLeak(); - //testPynodeLinking(); - //testCppLinking("TestFanIn2",""); + // testSecondTimeLeak(); + // testPynodeLinking(); + // testCppLinking("TestFanIn2",""); NuPIC::shutdown(); // memory leak detection // we check even prior to the initial tracking iteration, because the act // of checking potentially modifies our memory usage - if (i % memoryLeakDeltaIterCheck == 0) - { + if (i % memoryLeakDeltaIterCheck == 0) { OS::getProcessMemoryUsage(current_rmem, current_vmem); - if(i == memoryLeakStartIter) - { + if (i == memoryLeakStartIter) { initial_rmem = current_rmem; initial_vmem = current_vmem; } std::cout << "Memory usage: " << current_vmem << " (virtual) " << current_rmem << " (real) at iteration " << i << std::endl; - if(i >= memoryLeakStartIter) - { - if (current_vmem > initial_vmem || current_rmem > initial_rmem) - { + if (i >= memoryLeakStartIter) { + if (current_vmem > initial_vmem || current_rmem > initial_rmem) { std::cout << "Tracked memory usage (iteration " << memoryLeakStartIter << "): " << initial_vmem - << " (virtual) " << initial_rmem << " (real)" << std::endl; + << " (virtual) " << initial_rmem << " (real)" + << std::endl; throw std::runtime_error("Memory leak detected"); } } } } - } catch (nupic::Exception& e) { - std::cout - << "Exception: " << e.getMessage() - << " at: " << e.getFilename() << ":" << e.getLineNumber() - << std::endl; + } catch (nupic::Exception &e) { + std::cout << "Exception: " << e.getMessage() << " at: " << e.getFilename() + << ":" << e.getLineNumber() << std::endl; return 1; - } catch (std::exception& e) { + } catch (std::exception &e) { std::cout << "Exception: " << e.what() << "" << std::endl; return 1; - } - catch (...) { - std::cout << "\nhtmtest is exiting because an exception was thrown" << std::endl; + } catch (...) 
{ + std::cout << "\nhtmtest is exiting because an exception was thrown" + << std::endl; return 1; } if (count > 20) - std::cout << "Memory leak check passed -- " << count << " iterations" << std::endl; + std::cout << "Memory leak check passed -- " << count << " iterations" + << std::endl; std::cout << "--- ALL TESTS PASSED ---" << std::endl; return 0; diff --git a/src/test/unit/UnitTestMain.cpp b/src/test/unit/UnitTestMain.cpp index e8ee037da8..159ccadd0b 100644 --- a/src/test/unit/UnitTestMain.cpp +++ b/src/test/unit/UnitTestMain.cpp @@ -20,7 +20,7 @@ * --------------------------------------------------------------------- */ -/** @file +/** @file Google test main program */ @@ -29,8 +29,8 @@ Google test main program #define WIN32_LEAN_AND_MEAN #endif -#include #include +#include using namespace std; using namespace nupic; @@ -38,12 +38,13 @@ using namespace nupic; // APR must be explicit initialized #include -int main(int argc, char ** argv) { +int main(int argc, char **argv) { // initialize APR - apr_status_t result; - result = apr_app_initialize(&argc, (char const *const **)&argv, nullptr /*env*/); - if (result) + apr_status_t result; + result = + apr_app_initialize(&argc, (char const *const **)&argv, nullptr /*env*/); + if (result) NTA_THROW << "error initializing APR. Err code: " << result; // initialize GoogleTest diff --git a/src/test/unit/algorithms/AnomalyTest.cpp b/src/test/unit/algorithms/AnomalyTest.cpp index d30f22db8a..c212ae7c12 100644 --- a/src/test/unit/algorithms/AnomalyTest.cpp +++ b/src/test/unit/algorithms/AnomalyTest.cpp @@ -20,8 +20,8 @@ * --------------------------------------------------------------------- */ -#include #include +#include #include "gtest/gtest.h" @@ -31,122 +31,90 @@ using namespace nupic::algorithms::anomaly; using namespace nupic; - -TEST(ComputeRawAnomalyScore, NoActiveOrPredicted) -{ +TEST(ComputeRawAnomalyScore, NoActiveOrPredicted) { std::vector active; std::vector predicted; ASSERT_FLOAT_EQ(computeRawAnomalyScore(active, predicted), 0.0); }; - -TEST(ComputeRawAnomalyScore, NoActive) -{ +TEST(ComputeRawAnomalyScore, NoActive) { std::vector active; std::vector predicted = {3, 5}; ASSERT_FLOAT_EQ(computeRawAnomalyScore(active, predicted), 0.0); }; - -TEST(ComputeRawAnomalyScore, PerfectMatch) -{ +TEST(ComputeRawAnomalyScore, PerfectMatch) { std::vector active = {3, 5, 7}; std::vector predicted = {3, 5, 7}; ASSERT_FLOAT_EQ(computeRawAnomalyScore(active, predicted), 0.0); }; - -TEST(ComputeRawAnomalyScore, NoMatch) -{ +TEST(ComputeRawAnomalyScore, NoMatch) { std::vector active = {2, 4, 6}; std::vector predicted = {3, 5, 7}; ASSERT_FLOAT_EQ(computeRawAnomalyScore(active, predicted), 1.0); }; - -TEST(ComputeRawAnomalyScore, PartialMatch) -{ +TEST(ComputeRawAnomalyScore, PartialMatch) { std::vector active = {2, 3, 6}; std::vector predicted = {3, 5, 7}; ASSERT_FLOAT_EQ(computeRawAnomalyScore(active, predicted), 2.0 / 3.0); }; - -TEST(Anomaly, ComputeScoreNoActiveOrPredicted) -{ +TEST(Anomaly, ComputeScoreNoActiveOrPredicted) { std::vector active; std::vector predicted; Anomaly a; ASSERT_FLOAT_EQ(a.compute(active, predicted), 0.0); } - -TEST(Anomaly, ComputeScoreNoActive) -{ +TEST(Anomaly, ComputeScoreNoActive) { std::vector active; std::vector predicted = {3, 5}; Anomaly a; ASSERT_FLOAT_EQ(a.compute(active, predicted), 0.0); } - -TEST(Anomaly, ComputeScorePerfectMatch) -{ +TEST(Anomaly, ComputeScorePerfectMatch) { std::vector active = {3, 5, 7}; std::vector predicted = {3, 5, 7}; Anomaly a; ASSERT_FLOAT_EQ(a.compute(active, predicted), 0.0); } - 
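The AnomalyTest cases above and below all encode one scoring rule: the raw anomaly score is the fraction of active indices that were not predicted, and it is defined as 0 when nothing is active. A minimal, self-contained sketch of that rule follows, using only the standard library; the helper name rawAnomaly is invented here for illustration and is not the nupic::algorithms::anomaly implementation.

#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

// Sketch of the rule the tests assert: score = |active \ predicted| / |active|,
// with an empty active set scoring 0. This reproduces the expected values
// 0, 1 and 2/3 used in the surrounding test cases.
static double rawAnomaly(std::vector<unsigned> active,
                         std::vector<unsigned> predicted) {
  if (active.empty())
    return 0.0;
  std::sort(active.begin(), active.end());
  std::sort(predicted.begin(), predicted.end());
  std::vector<unsigned> hits;
  std::set_intersection(active.begin(), active.end(), predicted.begin(),
                        predicted.end(), std::back_inserter(hits));
  return double(active.size() - hits.size()) / double(active.size());
}

int main() {
  assert(rawAnomaly({3, 5, 7}, {3, 5, 7}) == 0.0); // perfect match
  assert(rawAnomaly({2, 4, 6}, {3, 5, 7}) == 1.0); // no overlap
  assert(rawAnomaly({2, 3, 6}, {3, 5, 7}) > 0.6);  // partial match, 2/3
  return 0;
}

The Cumulative test further down (Anomaly a{3}) is consistent with averaging this raw score over a moving window of three values: for example, its third expected value 1/9 is the mean of the raw scores 0, 0 and 1/3 from the first three inputs.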
-TEST(Anomaly, ComputeScoreNoMatch) -{ +TEST(Anomaly, ComputeScoreNoMatch) { std::vector active = {2, 4, 6}; std::vector predicted = {3, 5, 7}; Anomaly a; ASSERT_FLOAT_EQ(a.compute(active, predicted), 1.0); } - -TEST(Anomaly, ComputeScorePartialMatch) -{ +TEST(Anomaly, ComputeScorePartialMatch) { std::vector active = {2, 3, 6}; std::vector predicted = {3, 5, 7}; Anomaly a; ASSERT_FLOAT_EQ(a.compute(active, predicted), 2.0 / 3.0); } - -TEST(Anomaly, Cumulative) -{ +TEST(Anomaly, Cumulative) { const int TEST_COUNT = 9; Anomaly a{3}; - std::vector< std::vector > preds{TEST_COUNT, {1, 2, 6}}; - - std::vector< std::vector > acts = { - {1, 2, 6}, - {1, 2, 6}, - {1, 4, 6}, - {10, 11, 6}, - {10, 11, 12}, - {10, 11, 12}, - {10, 11, 12}, - {1, 2, 6}, - {1, 2, 6} - }; - - std::vector expected = {0.0, 0.0, 1.0/9.0, 3.0/9.0, 2.0/3.0, 8.0/9.0, - 1.0, 2.0/3.0, 1.0/3.0}; - - for (int index = 0; index < TEST_COUNT; index++) - { - ASSERT_FLOAT_EQ(a.compute(acts[index], preds[index]), expected[index]); + std::vector> preds{TEST_COUNT, {1, 2, 6}}; + + std::vector> acts = { + {1, 2, 6}, {1, 2, 6}, {1, 4, 6}, {10, 11, 6}, {10, 11, 12}, + {10, 11, 12}, {10, 11, 12}, {1, 2, 6}, {1, 2, 6}}; + + std::vector expected = {0.0, 0.0, 1.0 / 9.0, + 3.0 / 9.0, 2.0 / 3.0, 8.0 / 9.0, + 1.0, 2.0 / 3.0, 1.0 / 3.0}; + + for (int index = 0; index < TEST_COUNT; index++) { + ASSERT_FLOAT_EQ(a.compute(acts[index], preds[index]), expected[index]); } } - -TEST(Anomaly, SelectModePure) -{ +TEST(Anomaly, SelectModePure) { Anomaly a{0, AnomalyMode::PURE, 0}; std::vector active = {2, 3, 6}; std::vector predicted = {3, 5, 7}; diff --git a/src/test/unit/algorithms/Cells4Test.cpp b/src/test/unit/algorithms/Cells4Test.cpp index d572bda90a..efb5ce0916 100644 --- a/src/test/unit/algorithms/Cells4Test.cpp +++ b/src/test/unit/algorithms/Cells4Test.cpp @@ -29,24 +29,20 @@ #include #include -#include #include +#include #include #include #include // is_in #include - using namespace nupic::algorithms::Cells4; - - template -std::vector _getOrderedSrcCellIndexesForSrcCells(const Segment& segment, +std::vector _getOrderedSrcCellIndexesForSrcCells(const Segment &segment, InputIterator first, - InputIterator last) -{ + InputIterator last) { std::vector result; const std::set srcCellsSet(first, last); @@ -61,12 +57,10 @@ std::vector _getOrderedSrcCellIndexesForSrcCells(const Segment& segment, return result; } - template -std::vector _getOrderedSynapseIndexesForSrcCells(const Segment& segment, +std::vector _getOrderedSynapseIndexesForSrcCells(const Segment &segment, InputIterator first, - InputIterator last) -{ + InputIterator last) { std::vector result; const std::set srcCellsSet(first, last); @@ -81,15 +75,12 @@ std::vector _getOrderedSynapseIndexesForSrcCells(const Segment& segment, return result; } - /** * Simple comparison function that does the easy checks. It can be expanded to * cover more of the attributes of Cells4 in the future. 
*/ -bool checkCells4Attributes(const Cells4& c1, const Cells4& c2) -{ - if (c1.nSegments() != c2.nSegments() || - c1.nCells() != c2.nCells() || +bool checkCells4Attributes(const Cells4 &c1, const Cells4 &c2) { + if (c1.nSegments() != c2.nSegments() || c1.nCells() != c2.nCells() || c1.nColumns() != c2.nColumns() || c1.nCellsPerCol() != c2.nCellsPerCol() || c1.getMinThreshold() != c2.getMinThreshold() || @@ -107,18 +98,15 @@ bool checkCells4Attributes(const Cells4& c1, const Cells4& c2) c1.getMaxSegmentsPerCell() != c2.getMaxSegmentsPerCell() || c1.getMaxSynapsesPerSegment() != c2.getMaxSynapsesPerSegment() || - c1.getCheckSynapseConsistency() != c2.getCheckSynapseConsistency()) - { + c1.getCheckSynapseConsistency() != c2.getCheckSynapseConsistency()) { return false; } return true; } - -TEST(Cells4Test, capnpSerialization) -{ - Cells4 cells( - 10, 2, 1, 1, 1, 1, 0.5, 0.8, 1, 0.1, 0.1, 0, false, -1, true, false); +TEST(Cells4Test, capnpSerialization) { + Cells4 cells(10, 2, 1, 1, 1, 1, 0.5, 0.8, 1, 0.1, 0.1, 0, false, -1, true, + false); std::vector input1(10, 0.0); input1[1] = 1.0; input1[4] = 1.0; @@ -139,9 +127,8 @@ TEST(Cells4Test, capnpSerialization) input4[4] = 1.0; input4[7] = 1.0; input4[8] = 1.0; - std::vector output(10*2); - for (UInt i = 0; i < 10; ++i) - { + std::vector output(10 * 2); + for (UInt i = 0; i < 10; ++i) { cells.compute(&input1.front(), &output.front(), true, true); cells.compute(&input2.front(), &output.front(), true, true); cells.compute(&input3.front(), &output.front(), true, true); @@ -166,69 +153,59 @@ TEST(Cells4Test, capnpSerialization) NTA_CHECK(checkCells4Attributes(cells, secondCells)); - std::vector secondOutput(10*2); + std::vector secondOutput(10 * 2); cells.compute(&input1.front(), &output.front(), true, true); secondCells.compute(&input1.front(), &secondOutput.front(), true, true); - for (UInt i = 0; i < 10; ++i) - { + for (UInt i = 0; i < 10; ++i) { ASSERT_EQ(output[i], secondOutput[i]) << "Outputs differ at index " << i; } NTA_CHECK(checkCells4Attributes(cells, secondCells)); } - - /* * Test Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment. 
*/ -TEST(Cells4Test, generateListsOfSynapsesToAdjustForAdaptSegment) -{ +TEST(Cells4Test, generateListsOfSynapsesToAdjustForAdaptSegment) { Segment segment; - const std::set srcCells {99, 88, 77, 66, 55, 44, 33, 22, 11, 0}; + const std::set srcCells{99, 88, 77, 66, 55, 44, 33, 22, 11, 0}; - segment.addSynapses(srcCells, - 0.8/*initStrength*/, - 0.5/*permConnected*/); + segment.addSynapses(srcCells, 0.8 /*initStrength*/, 0.5 /*permConnected*/); - std::set synapsesSet {222, 111, 77, 55, 22, 0}; + std::set synapsesSet{222, 111, 77, 55, 22, 0}; - std::vector inactiveSrcCellIdxs {(UInt)-1}; - std::vector inactiveSynapseIdxs {(UInt)-1}; - std::vector activeSrcCellIdxs {(UInt)-1}; - std::vector activeSynapseIdxs {(UInt)-1}; + std::vector inactiveSrcCellIdxs{(UInt)-1}; + std::vector inactiveSynapseIdxs{(UInt)-1}; + std::vector activeSrcCellIdxs{(UInt)-1}; + std::vector activeSynapseIdxs{(UInt)-1}; - const std::set expectedInactiveSrcCellSet {99, 88, 66, 44, 33, 11}; - const std::set expectedActiveSrcCellSet {77, 55, 22, 0}; - const std::set expectedSynapsesSet {222, 111}; + const std::set expectedInactiveSrcCellSet{99, 88, 66, 44, 33, 11}; + const std::set expectedActiveSrcCellSet{77, 55, 22, 0}; + const std::set expectedSynapsesSet{222, 111}; const std::vector expectedInactiveSrcCellIdxs = - _getOrderedSrcCellIndexesForSrcCells(segment, - expectedInactiveSrcCellSet.begin(), - expectedInactiveSrcCellSet.end()); + _getOrderedSrcCellIndexesForSrcCells(segment, + expectedInactiveSrcCellSet.begin(), + expectedInactiveSrcCellSet.end()); const std::vector expectedInactiveSynapseIdxs = - _getOrderedSynapseIndexesForSrcCells(segment, - expectedInactiveSrcCellSet.begin(), - expectedInactiveSrcCellSet.end()); + _getOrderedSynapseIndexesForSrcCells(segment, + expectedInactiveSrcCellSet.begin(), + expectedInactiveSrcCellSet.end()); const std::vector expectedActiveSrcCellIdxs = - _getOrderedSrcCellIndexesForSrcCells(segment, - expectedActiveSrcCellSet.begin(), - expectedActiveSrcCellSet.end()); + _getOrderedSrcCellIndexesForSrcCells(segment, + expectedActiveSrcCellSet.begin(), + expectedActiveSrcCellSet.end()); const std::vector expectedActiveSynapseIdxs = - _getOrderedSynapseIndexesForSrcCells(segment, - expectedActiveSrcCellSet.begin(), - expectedActiveSrcCellSet.end()); + _getOrderedSynapseIndexesForSrcCells(segment, + expectedActiveSrcCellSet.begin(), + expectedActiveSrcCellSet.end()); Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment( - segment, - synapsesSet, - inactiveSrcCellIdxs, - inactiveSynapseIdxs, - activeSrcCellIdxs, - activeSynapseIdxs); + segment, synapsesSet, inactiveSrcCellIdxs, inactiveSynapseIdxs, + activeSrcCellIdxs, activeSynapseIdxs); ASSERT_EQ(expectedSynapsesSet, synapsesSet); ASSERT_EQ(expectedInactiveSrcCellIdxs, inactiveSrcCellIdxs); @@ -237,60 +214,52 @@ TEST(Cells4Test, generateListsOfSynapsesToAdjustForAdaptSegment) ASSERT_EQ(expectedActiveSynapseIdxs, activeSynapseIdxs); } - - /* * Test Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment with new * synapes, but no active synapses. 
*/ -TEST(Cells4Test, generateListsOfSynapsesToAdjustForAdaptSegmentWithOnlyNewSynapses) -{ +TEST(Cells4Test, + generateListsOfSynapsesToAdjustForAdaptSegmentWithOnlyNewSynapses) { Segment segment; - const std::set srcCells {99, 88, 77, 66, 55}; + const std::set srcCells{99, 88, 77, 66, 55}; - segment.addSynapses(srcCells, - 0.8/*initStrength*/, - 0.5/*permConnected*/); + segment.addSynapses(srcCells, 0.8 /*initStrength*/, 0.5 /*permConnected*/); - std::set synapsesSet {222, 111}; + std::set synapsesSet{222, 111}; - std::vector inactiveSrcCellIdxs {(UInt)-1}; - std::vector inactiveSynapseIdxs {(UInt)-1}; - std::vector activeSrcCellIdxs {(UInt)-1}; - std::vector activeSynapseIdxs {(UInt)-1}; + std::vector inactiveSrcCellIdxs{(UInt)-1}; + std::vector inactiveSynapseIdxs{(UInt)-1}; + std::vector activeSrcCellIdxs{(UInt)-1}; + std::vector activeSynapseIdxs{(UInt)-1}; - const std::set expectedInactiveSrcCellSet {99, 88, 77, 66, 55}; - const std::set expectedActiveSrcCellSet {}; - const std::set expectedSynapsesSet {222, 111}; + const std::set expectedInactiveSrcCellSet{99, 88, 77, 66, 55}; + const std::set expectedActiveSrcCellSet{}; + const std::set expectedSynapsesSet{222, 111}; const std::vector expectedInactiveSrcCellIdxs = - _getOrderedSrcCellIndexesForSrcCells(segment, - expectedInactiveSrcCellSet.begin(), - expectedInactiveSrcCellSet.end()); + _getOrderedSrcCellIndexesForSrcCells(segment, + expectedInactiveSrcCellSet.begin(), + expectedInactiveSrcCellSet.end()); const std::vector expectedInactiveSynapseIdxs = - _getOrderedSynapseIndexesForSrcCells(segment, - expectedInactiveSrcCellSet.begin(), - expectedInactiveSrcCellSet.end()); + _getOrderedSynapseIndexesForSrcCells(segment, + expectedInactiveSrcCellSet.begin(), + expectedInactiveSrcCellSet.end()); const std::vector expectedActiveSrcCellIdxs = - _getOrderedSrcCellIndexesForSrcCells(segment, - expectedActiveSrcCellSet.begin(), - expectedActiveSrcCellSet.end()); + _getOrderedSrcCellIndexesForSrcCells(segment, + expectedActiveSrcCellSet.begin(), + expectedActiveSrcCellSet.end()); const std::vector expectedActiveSynapseIdxs = - _getOrderedSynapseIndexesForSrcCells(segment, - expectedActiveSrcCellSet.begin(), - expectedActiveSrcCellSet.end()); + _getOrderedSynapseIndexesForSrcCells(segment, + expectedActiveSrcCellSet.begin(), + expectedActiveSrcCellSet.end()); Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment( - segment, - synapsesSet, - inactiveSrcCellIdxs, - inactiveSynapseIdxs, - activeSrcCellIdxs, - activeSynapseIdxs); + segment, synapsesSet, inactiveSrcCellIdxs, inactiveSynapseIdxs, + activeSrcCellIdxs, activeSynapseIdxs); ASSERT_EQ(expectedSynapsesSet, synapsesSet); ASSERT_EQ(expectedInactiveSrcCellIdxs, inactiveSrcCellIdxs); @@ -299,60 +268,52 @@ TEST(Cells4Test, generateListsOfSynapsesToAdjustForAdaptSegmentWithOnlyNewSynaps ASSERT_EQ(expectedActiveSynapseIdxs, activeSynapseIdxs); } - - /* * Test Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment with active * synapses, but no new synapses. 
*/ -TEST(Cells4Test, generateListsOfSynapsesToAdjustForAdaptSegmentWithoutNewSynapses) -{ +TEST(Cells4Test, + generateListsOfSynapsesToAdjustForAdaptSegmentWithoutNewSynapses) { Segment segment; - const std::set srcCells {99, 88, 77, 66, 55}; + const std::set srcCells{99, 88, 77, 66, 55}; - segment.addSynapses(srcCells, - 0.8/*initStrength*/, - 0.5/*permConnected*/); + segment.addSynapses(srcCells, 0.8 /*initStrength*/, 0.5 /*permConnected*/); - std::set synapsesSet {88, 66}; + std::set synapsesSet{88, 66}; - std::vector inactiveSrcCellIdxs {(UInt)-1}; - std::vector inactiveSynapseIdxs {(UInt)-1}; - std::vector activeSrcCellIdxs {(UInt)-1}; - std::vector activeSynapseIdxs {(UInt)-1}; + std::vector inactiveSrcCellIdxs{(UInt)-1}; + std::vector inactiveSynapseIdxs{(UInt)-1}; + std::vector activeSrcCellIdxs{(UInt)-1}; + std::vector activeSynapseIdxs{(UInt)-1}; - const std::set expectedInactiveSrcCellSet {99, 77, 55}; - const std::set expectedActiveSrcCellSet {88, 66}; - const std::set expectedSynapsesSet {}; + const std::set expectedInactiveSrcCellSet{99, 77, 55}; + const std::set expectedActiveSrcCellSet{88, 66}; + const std::set expectedSynapsesSet{}; const std::vector expectedInactiveSrcCellIdxs = - _getOrderedSrcCellIndexesForSrcCells(segment, - expectedInactiveSrcCellSet.begin(), - expectedInactiveSrcCellSet.end()); + _getOrderedSrcCellIndexesForSrcCells(segment, + expectedInactiveSrcCellSet.begin(), + expectedInactiveSrcCellSet.end()); const std::vector expectedInactiveSynapseIdxs = - _getOrderedSynapseIndexesForSrcCells(segment, - expectedInactiveSrcCellSet.begin(), - expectedInactiveSrcCellSet.end()); + _getOrderedSynapseIndexesForSrcCells(segment, + expectedInactiveSrcCellSet.begin(), + expectedInactiveSrcCellSet.end()); const std::vector expectedActiveSrcCellIdxs = - _getOrderedSrcCellIndexesForSrcCells(segment, - expectedActiveSrcCellSet.begin(), - expectedActiveSrcCellSet.end()); + _getOrderedSrcCellIndexesForSrcCells(segment, + expectedActiveSrcCellSet.begin(), + expectedActiveSrcCellSet.end()); const std::vector expectedActiveSynapseIdxs = - _getOrderedSynapseIndexesForSrcCells(segment, - expectedActiveSrcCellSet.begin(), - expectedActiveSrcCellSet.end()); + _getOrderedSynapseIndexesForSrcCells(segment, + expectedActiveSrcCellSet.begin(), + expectedActiveSrcCellSet.end()); Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment( - segment, - synapsesSet, - inactiveSrcCellIdxs, - inactiveSynapseIdxs, - activeSrcCellIdxs, - activeSynapseIdxs); + segment, synapsesSet, inactiveSrcCellIdxs, inactiveSynapseIdxs, + activeSrcCellIdxs, activeSynapseIdxs); ASSERT_EQ(expectedSynapsesSet, synapsesSet); ASSERT_EQ(expectedInactiveSrcCellIdxs, inactiveSrcCellIdxs); @@ -361,60 +322,52 @@ TEST(Cells4Test, generateListsOfSynapsesToAdjustForAdaptSegmentWithoutNewSynapse ASSERT_EQ(expectedActiveSynapseIdxs, activeSynapseIdxs); } - - /* * Test Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment with active and * new synapses, but no inactive synapses. 
*/ -TEST(Cells4Test, generateListsOfSynapsesToAdjustForAdaptSegmentWithoutInactiveSynapses) -{ +TEST(Cells4Test, + generateListsOfSynapsesToAdjustForAdaptSegmentWithoutInactiveSynapses) { Segment segment; - const std::set srcCells {88, 77, 66}; + const std::set srcCells{88, 77, 66}; - segment.addSynapses(srcCells, - 0.8/*initStrength*/, - 0.5/*permConnected*/); + segment.addSynapses(srcCells, 0.8 /*initStrength*/, 0.5 /*permConnected*/); - std::set synapsesSet {222, 111, 88, 77, 66}; + std::set synapsesSet{222, 111, 88, 77, 66}; - std::vector inactiveSrcCellIdxs {(UInt)-1}; - std::vector inactiveSynapseIdxs {(UInt)-1}; - std::vector activeSrcCellIdxs {(UInt)-1}; - std::vector activeSynapseIdxs {(UInt)-1}; + std::vector inactiveSrcCellIdxs{(UInt)-1}; + std::vector inactiveSynapseIdxs{(UInt)-1}; + std::vector activeSrcCellIdxs{(UInt)-1}; + std::vector activeSynapseIdxs{(UInt)-1}; - const std::set expectedInactiveSrcCellSet {}; - const std::set expectedActiveSrcCellSet {88, 77, 66}; - const std::set expectedSynapsesSet {222, 111}; + const std::set expectedInactiveSrcCellSet{}; + const std::set expectedActiveSrcCellSet{88, 77, 66}; + const std::set expectedSynapsesSet{222, 111}; const std::vector expectedInactiveSrcCellIdxs = - _getOrderedSrcCellIndexesForSrcCells(segment, - expectedInactiveSrcCellSet.begin(), - expectedInactiveSrcCellSet.end()); + _getOrderedSrcCellIndexesForSrcCells(segment, + expectedInactiveSrcCellSet.begin(), + expectedInactiveSrcCellSet.end()); const std::vector expectedInactiveSynapseIdxs = - _getOrderedSynapseIndexesForSrcCells(segment, - expectedInactiveSrcCellSet.begin(), - expectedInactiveSrcCellSet.end()); + _getOrderedSynapseIndexesForSrcCells(segment, + expectedInactiveSrcCellSet.begin(), + expectedInactiveSrcCellSet.end()); const std::vector expectedActiveSrcCellIdxs = - _getOrderedSrcCellIndexesForSrcCells(segment, - expectedActiveSrcCellSet.begin(), - expectedActiveSrcCellSet.end()); + _getOrderedSrcCellIndexesForSrcCells(segment, + expectedActiveSrcCellSet.begin(), + expectedActiveSrcCellSet.end()); const std::vector expectedActiveSynapseIdxs = - _getOrderedSynapseIndexesForSrcCells(segment, - expectedActiveSrcCellSet.begin(), - expectedActiveSrcCellSet.end()); + _getOrderedSynapseIndexesForSrcCells(segment, + expectedActiveSrcCellSet.begin(), + expectedActiveSrcCellSet.end()); Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment( - segment, - synapsesSet, - inactiveSrcCellIdxs, - inactiveSynapseIdxs, - activeSrcCellIdxs, - activeSynapseIdxs); + segment, synapsesSet, inactiveSrcCellIdxs, inactiveSynapseIdxs, + activeSrcCellIdxs, activeSynapseIdxs); ASSERT_EQ(expectedSynapsesSet, synapsesSet); ASSERT_EQ(expectedInactiveSrcCellIdxs, inactiveSrcCellIdxs); @@ -423,60 +376,52 @@ TEST(Cells4Test, generateListsOfSynapsesToAdjustForAdaptSegmentWithoutInactiveSy ASSERT_EQ(expectedActiveSynapseIdxs, activeSynapseIdxs); } - - /* * Test Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment without initial * synapses, and only new synapses. 
*/ -TEST(Cells4Test, generateListsOfSynapsesToAdjustForAdaptSegmentWithoutInitialSynapses) -{ +TEST(Cells4Test, + generateListsOfSynapsesToAdjustForAdaptSegmentWithoutInitialSynapses) { Segment segment; - const std::set srcCells {}; + const std::set srcCells{}; - segment.addSynapses(srcCells, - 0.8/*initStrength*/, - 0.5/*permConnected*/); + segment.addSynapses(srcCells, 0.8 /*initStrength*/, 0.5 /*permConnected*/); - std::set synapsesSet {222, 111}; + std::set synapsesSet{222, 111}; - std::vector inactiveSrcCellIdxs {(UInt)-1}; - std::vector inactiveSynapseIdxs {(UInt)-1}; - std::vector activeSrcCellIdxs {(UInt)-1}; - std::vector activeSynapseIdxs {(UInt)-1}; + std::vector inactiveSrcCellIdxs{(UInt)-1}; + std::vector inactiveSynapseIdxs{(UInt)-1}; + std::vector activeSrcCellIdxs{(UInt)-1}; + std::vector activeSynapseIdxs{(UInt)-1}; - const std::set expectedInactiveSrcCellSet {}; - const std::set expectedActiveSrcCellSet {}; - const std::set expectedSynapsesSet {222, 111}; + const std::set expectedInactiveSrcCellSet{}; + const std::set expectedActiveSrcCellSet{}; + const std::set expectedSynapsesSet{222, 111}; const std::vector expectedInactiveSrcCellIdxs = - _getOrderedSrcCellIndexesForSrcCells(segment, - expectedInactiveSrcCellSet.begin(), - expectedInactiveSrcCellSet.end()); + _getOrderedSrcCellIndexesForSrcCells(segment, + expectedInactiveSrcCellSet.begin(), + expectedInactiveSrcCellSet.end()); const std::vector expectedInactiveSynapseIdxs = - _getOrderedSynapseIndexesForSrcCells(segment, - expectedInactiveSrcCellSet.begin(), - expectedInactiveSrcCellSet.end()); + _getOrderedSynapseIndexesForSrcCells(segment, + expectedInactiveSrcCellSet.begin(), + expectedInactiveSrcCellSet.end()); const std::vector expectedActiveSrcCellIdxs = - _getOrderedSrcCellIndexesForSrcCells(segment, - expectedActiveSrcCellSet.begin(), - expectedActiveSrcCellSet.end()); + _getOrderedSrcCellIndexesForSrcCells(segment, + expectedActiveSrcCellSet.begin(), + expectedActiveSrcCellSet.end()); const std::vector expectedActiveSynapseIdxs = - _getOrderedSynapseIndexesForSrcCells(segment, - expectedActiveSrcCellSet.begin(), - expectedActiveSrcCellSet.end()); + _getOrderedSynapseIndexesForSrcCells(segment, + expectedActiveSrcCellSet.begin(), + expectedActiveSrcCellSet.end()); Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment( - segment, - synapsesSet, - inactiveSrcCellIdxs, - inactiveSynapseIdxs, - activeSrcCellIdxs, - activeSynapseIdxs); + segment, synapsesSet, inactiveSrcCellIdxs, inactiveSynapseIdxs, + activeSrcCellIdxs, activeSynapseIdxs); ASSERT_EQ(expectedSynapsesSet, synapsesSet); ASSERT_EQ(expectedInactiveSrcCellIdxs, inactiveSrcCellIdxs); @@ -485,60 +430,52 @@ TEST(Cells4Test, generateListsOfSynapsesToAdjustForAdaptSegmentWithoutInitialSyn ASSERT_EQ(expectedActiveSynapseIdxs, activeSynapseIdxs); } - - /* * Test Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment with empty * update set. 
*/ -TEST(Cells4Test, generateListsOfSynapsesToAdjustForAdaptSegmentWithEmptySynapseSet) -{ +TEST(Cells4Test, + generateListsOfSynapsesToAdjustForAdaptSegmentWithEmptySynapseSet) { Segment segment; - const std::set srcCells {88, 77, 66}; + const std::set srcCells{88, 77, 66}; - segment.addSynapses(srcCells, - 0.8/*initStrength*/, - 0.5/*permConnected*/); + segment.addSynapses(srcCells, 0.8 /*initStrength*/, 0.5 /*permConnected*/); - std::set synapsesSet {}; + std::set synapsesSet{}; - std::vector inactiveSrcCellIdxs {(UInt)-1}; - std::vector inactiveSynapseIdxs {(UInt)-1}; - std::vector activeSrcCellIdxs {(UInt)-1}; - std::vector activeSynapseIdxs {(UInt)-1}; + std::vector inactiveSrcCellIdxs{(UInt)-1}; + std::vector inactiveSynapseIdxs{(UInt)-1}; + std::vector activeSrcCellIdxs{(UInt)-1}; + std::vector activeSynapseIdxs{(UInt)-1}; - const std::set expectedInactiveSrcCellSet {88, 77, 66}; - const std::set expectedActiveSrcCellSet {}; - const std::set expectedSynapsesSet {}; + const std::set expectedInactiveSrcCellSet{88, 77, 66}; + const std::set expectedActiveSrcCellSet{}; + const std::set expectedSynapsesSet{}; const std::vector expectedInactiveSrcCellIdxs = - _getOrderedSrcCellIndexesForSrcCells(segment, - expectedInactiveSrcCellSet.begin(), - expectedInactiveSrcCellSet.end()); + _getOrderedSrcCellIndexesForSrcCells(segment, + expectedInactiveSrcCellSet.begin(), + expectedInactiveSrcCellSet.end()); const std::vector expectedInactiveSynapseIdxs = - _getOrderedSynapseIndexesForSrcCells(segment, - expectedInactiveSrcCellSet.begin(), - expectedInactiveSrcCellSet.end()); + _getOrderedSynapseIndexesForSrcCells(segment, + expectedInactiveSrcCellSet.begin(), + expectedInactiveSrcCellSet.end()); const std::vector expectedActiveSrcCellIdxs = - _getOrderedSrcCellIndexesForSrcCells(segment, - expectedActiveSrcCellSet.begin(), - expectedActiveSrcCellSet.end()); + _getOrderedSrcCellIndexesForSrcCells(segment, + expectedActiveSrcCellSet.begin(), + expectedActiveSrcCellSet.end()); const std::vector expectedActiveSynapseIdxs = - _getOrderedSynapseIndexesForSrcCells(segment, - expectedActiveSrcCellSet.begin(), - expectedActiveSrcCellSet.end()); + _getOrderedSynapseIndexesForSrcCells(segment, + expectedActiveSrcCellSet.begin(), + expectedActiveSrcCellSet.end()); Cells4::_generateListsOfSynapsesToAdjustForAdaptSegment( - segment, - synapsesSet, - inactiveSrcCellIdxs, - inactiveSynapseIdxs, - activeSrcCellIdxs, - activeSynapseIdxs); + segment, synapsesSet, inactiveSrcCellIdxs, inactiveSynapseIdxs, + activeSrcCellIdxs, activeSynapseIdxs); ASSERT_EQ(expectedSynapsesSet, synapsesSet); ASSERT_EQ(expectedInactiveSrcCellIdxs, inactiveSrcCellIdxs); diff --git a/src/test/unit/algorithms/CondProbTableTest.cpp b/src/test/unit/algorithms/CondProbTableTest.cpp index 73afd4bde7..2a47045910 100644 --- a/src/test/unit/algorithms/CondProbTableTest.cpp +++ b/src/test/unit/algorithms/CondProbTableTest.cpp @@ -21,294 +21,288 @@ */ /** @file -* Notes -*/ + * Notes + */ +#include #include #include -#include +// clang-format off #include //FIXME this include is fix for the include below in boost .64, remove later when fixed upstream #include +// clang-format on #include "gtest/gtest.h" - using namespace std; using namespace boost; using namespace nupic; - namespace { - - // Size of the table we construct - Size numRows() {return 4;} - Size numCols() {return 3;} - - static vector makeRow(Real a, Real b, Real c) - { - vector result(3); - result[0] = a; - result[1] = b; - result[2] = c; - - return result; - } - - static vector 
makeCol(Real a, Real b, Real c, Real d) - { - vector result(4); - result[0] = a; - result[1] = b; - result[2] = c; - result[3] = d; - - return result; + +// Size of the table we construct +Size numRows() { return 4; } +Size numCols() { return 3; } + +static vector makeRow(Real a, Real b, Real c) { + vector result(3); + result[0] = a; + result[1] = b; + result[2] = c; + + return result; +} + +static vector makeCol(Real a, Real b, Real c, Real d) { + vector result(4); + result[0] = a; + result[1] = b; + result[2] = c; + result[3] = d; + + return result; +} + +void testVectors(const string &testName, const vector &v1, + const vector &v2) { + stringstream s1, s2; + s1 << v1; + s2 << v2; + EXPECT_EQ(s1.str(), s2.str()); +} + +void testTable(const string &testName, CondProbTable &table, + const vector> &rows) { + + // Test the numRows(), numCols() calls + ASSERT_EQ(numRows(), table.numRows()); + ASSERT_EQ(numCols(), table.numColumns()); + + // See if they got added right + vector testRow(numCols()); + for (Size i = 0; i < numRows(); i++) { + stringstream ss; + ss << "updateRow " << i; + + table.getRow((UInt)i, testRow); + ASSERT_NO_FATAL_FAILURE(testVectors(testName + ss.str(), rows[i], testRow)); } - void testVectors(const string& testName, const vector& v1, - const vector& v2) + // -------------------------------------------------------------------- + // Try out normal inference + vector expValue; + vector output(numRows()); + + // Row 0 matches row 3, so we get half and half hits on those rows + table.inferRow(rows[0], output, CondProbTable::inferMarginal); + ASSERT_NO_FATAL_FAILURE( + testVectors(testName + "row 0 infer", + makeCol((Real).16, (Real)0, (Real)0, (Real).24), output)); + + // Row 1 matches only row 1 + table.inferRow(rows[1], output, CondProbTable::inferMarginal); + ASSERT_NO_FATAL_FAILURE(testVectors( + testName + "row 1 infer", makeCol((Real)0, 1, (Real)0, (Real)0), output)); + + // Row 2 matches only row 2 and 3 + table.inferRow(rows[2], output, CondProbTable::inferMarginal); + ASSERT_NO_FATAL_FAILURE( + testVectors(testName + "row 2 infer", + makeCol((Real)0, (Real)0, (Real).36, (Real).24), output)); + + // Row 3 matches row 0 & row 2 halfway, and row 3 exactly + table.inferRow(rows[3], output, CondProbTable::inferMarginal); + ASSERT_NO_FATAL_FAILURE( + testVectors(testName + "row 3 infer", + makeCol((Real).24, (Real)0, (Real).24, (Real).52), output)); + + // -------------------------------------------------------------------- + // Try out inferEvidence inference + + // Row 0 matches row 0 and half row 3, so we get half and half hits on those + // rows + table.inferRow(rows[0], output, CondProbTable::inferRowEvidence); + ASSERT_NO_FATAL_FAILURE( + testVectors(testName + "row 0 inferEvidence", + makeCol((Real).4, (Real)0, (Real)0, (Real).24), output)); + + // Row 1 matches only row 1 + table.inferRow(rows[1], output, CondProbTable::inferRowEvidence); + ASSERT_NO_FATAL_FAILURE(testVectors(testName + "row 1 inferEvidence", + makeCol((Real)0, 1, (Real)0, (Real)0), + output)); + + // Row 2 matches only row 2 and half row 3 + table.inferRow(rows[2], output, CondProbTable::inferRowEvidence); + ASSERT_NO_FATAL_FAILURE( + testVectors(testName + "row 2 inferEvidence", + makeCol((Real)0, (Real)0, (Real).6, (Real).24), output)); + + // Row 3 matches row 0 & row 2 halfway, and row 3 exactly + table.inferRow(rows[3], output, CondProbTable::inferRowEvidence); + ASSERT_NO_FATAL_FAILURE( + testVectors(testName + "row 3 inferEvidence", + makeCol((Real).6, (Real)0, (Real).4, (Real).52), 
output)); + + // -------------------------------------------------------------------- + // Try out inferMaxProd inference + + // Row 0 matches row 0 and half row 3, so we get half and half hits on those + // rows + table.inferRow(rows[0], output, CondProbTable::inferMaxProd); + ASSERT_NO_FATAL_FAILURE( + testVectors(testName + "row 0 inferMaxProd", + makeCol((Real).16, (Real)0, (Real)0, (Real).24), output)); + + // Row 1 matches only row 1 + table.inferRow(rows[1], output, CondProbTable::inferMaxProd); + ASSERT_NO_FATAL_FAILURE(testVectors(testName + "row 1 inferMaxProd", + makeCol((Real)0, 1, (Real)0, (Real)0), + output)); + + // Row 2 matches only row 2 and half row 3 + table.inferRow(rows[2], output, CondProbTable::inferMaxProd); + ASSERT_NO_FATAL_FAILURE( + testVectors(testName + "row 2 inferMaxProd", + makeCol((Real)0, (Real)0, (Real).36, (Real).24), output)); + + // Row 3 matches row 0 & row 2 halfway, and row 3 exactly + table.inferRow(rows[3], output, CondProbTable::inferMaxProd); + ASSERT_NO_FATAL_FAILURE( + testVectors(testName + "row 3 inferMaxProd", + makeCol((Real).24, (Real)0, (Real).24, (Real).36), output)); + + // -------------------------------------------------------------------- + // Try out inferViterbi inference + + // Row 0 matches row 0 and half row 3, so we get half and half hits on those + // rows + table.inferRow(rows[0], output, CondProbTable::inferViterbi); + ASSERT_NO_FATAL_FAILURE( + testVectors(testName + "row 0 inferViterbi", + makeCol((Real)0, (Real)0, (Real)0, (Real).4), output)); + + // Row 1 matches only row 1 + table.inferRow(rows[1], output, CondProbTable::inferViterbi); + ASSERT_NO_FATAL_FAILURE(testVectors(testName + "row 1 inferViterbi", + makeCol((Real)0, 1, (Real)0, (Real)0), + output)); + + // Row 2 matches only row 2 and half row 3 + table.inferRow(rows[2], output, CondProbTable::inferViterbi); + ASSERT_NO_FATAL_FAILURE( + testVectors(testName + "row 2 inferViterbi", + makeCol((Real)0, (Real)0, (Real).6, (Real)0), output)); + + // Row 3 matches row 0 & row 2 halfway, and row 3 exactly + table.inferRow(rows[3], output, CondProbTable::inferViterbi); + ASSERT_NO_FATAL_FAILURE( + testVectors(testName + "row 3 inferViterbi", + makeCol((Real)0, (Real)0, (Real).4, (Real).6), output)); + + // Add a row a second time, the row should double in value + table.updateRow(0, rows[0]); + expValue = rows[0]; + for (Size i = 0; i < numCols(); i++) + expValue[i] *= 2; + table.getRow(0, testRow); + ASSERT_NO_FATAL_FAILURE( + testVectors(testName + "row 0 update#2", expValue, testRow)); +} + +//---------------------------------------------------------------------- +TEST(CondProbTableTest, Basic) { + // Our 4 rows + vector> rows; + rows.resize(numRows()); + rows[0] = makeRow((Real)0.0, (Real)0.4, (Real)0.0); + rows[1] = makeRow((Real)1.0, (Real)0.0, (Real)0.0); + rows[2] = makeRow((Real)0.0, (Real)0.0, (Real)0.6); + rows[3] = makeRow((Real)0.0, (Real)0.6, (Real)0.4); + + // Test constructing without # of columns { - stringstream s1, s2; - s1 << v1; - s2 << v2; - EXPECT_EQ(s1.str(), s2.str()); + CondProbTable table; + + // Add the 4 rows + for (Size i = 0; i < numRows(); i++) + table.updateRow((UInt)i, rows[i]); + + // Test it + ASSERT_NO_FATAL_FAILURE(testTable("Dynamic columns:", table, rows)); } - void testTable(const string& testName, CondProbTable& table, - const vector > & rows) + // Test constructing and growing the columns dynamically { - - // Test the numRows(), numCols() calls - ASSERT_EQ(numRows(), table.numRows()); - ASSERT_EQ(numCols(), 
table.numColumns()); - - // See if they got added right - vector testRow(numCols()); - for (Size i=0; i expValue; - vector output(numRows()); - - // Row 0 matches row 3, so we get half and half hits on those rows - table.inferRow (rows[0], output, CondProbTable::inferMarginal); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 0 infer", makeCol((Real).16, (Real)0, (Real)0, (Real).24), output)); - - // Row 1 matches only row 1 - table.inferRow (rows[1], output, CondProbTable::inferMarginal); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 1 infer", makeCol((Real)0, 1, (Real)0, (Real)0), output)); - - // Row 2 matches only row 2 and 3 - table.inferRow (rows[2], output, CondProbTable::inferMarginal); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 2 infer", makeCol((Real)0, (Real)0, (Real).36, (Real).24), output)); + CondProbTable table; - // Row 3 matches row 0 & row 2 halfway, and row 3 exactly - table.inferRow (rows[3], output, CondProbTable::inferMarginal); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 3 infer", makeCol((Real).24, (Real)0, (Real).24, (Real).52), output)); - - - // -------------------------------------------------------------------- - // Try out inferEvidence inference - - // Row 0 matches row 0 and half row 3, so we get half and half hits on those rows - table.inferRow (rows[0], output, CondProbTable::inferRowEvidence); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 0 inferEvidence", makeCol((Real).4, (Real)0, (Real)0, (Real).24), output)); - - // Row 1 matches only row 1 - table.inferRow (rows[1], output, CondProbTable::inferRowEvidence); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 1 inferEvidence", makeCol((Real)0, 1, (Real)0, (Real)0), output)); + // Add the 2nd row first which has just 1 column + vector row1(1); + row1[0] = rows[1][0]; + table.updateRow(1, row1); - // Row 2 matches only row 2 and half row 3 - table.inferRow (rows[2], output, CondProbTable::inferRowEvidence); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 2 inferEvidence", makeCol((Real)0, (Real)0, (Real).6, (Real).24), output)); + // Add the first row first with just 2 columns + vector row0(2); + row0[0] = rows[0][0]; + row0[1] = rows[0][1]; + table.updateRow(0, row0); - // Row 3 matches row 0 & row 2 halfway, and row 3 exactly - table.inferRow (rows[3], output, CondProbTable::inferRowEvidence); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 3 inferEvidence", makeCol((Real).6, (Real)0, (Real).4, (Real).52), output)); - - - // -------------------------------------------------------------------- - // Try out inferMaxProd inference - - // Row 0 matches row 0 and half row 3, so we get half and half hits on those rows - table.inferRow (rows[0], output, CondProbTable::inferMaxProd); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 0 inferMaxProd", makeCol((Real).16, (Real)0, (Real)0, (Real).24), output)); - - // Row 1 matches only row 1 - table.inferRow (rows[1], output, CondProbTable::inferMaxProd); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 1 inferMaxProd", makeCol((Real)0, 1, (Real)0, (Real)0), output)); + for (Size i = 2; i < numRows(); i++) + table.updateRow((UInt)i, rows[i]); - // Row 2 matches only row 2 and half row 3 - table.inferRow (rows[2], output, CondProbTable::inferMaxProd); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 2 inferMaxProd", makeCol((Real)0, (Real)0, (Real).36, (Real).24), output)); + // Test it + ASSERT_NO_FATAL_FAILURE(testTable("Growing columns:", table, rows)); + } - // Row 3 matches row 
0 & row 2 halfway, and row 3 exactly - table.inferRow (rows[3], output, CondProbTable::inferMaxProd); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 3 inferMaxProd", makeCol((Real).24, (Real)0, (Real).24, (Real).36), output)); - - - // -------------------------------------------------------------------- - // Try out inferViterbi inference - - // Row 0 matches row 0 and half row 3, so we get half and half hits on those rows - table.inferRow (rows[0], output, CondProbTable::inferViterbi); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 0 inferViterbi", makeCol((Real)0, (Real)0, (Real)0, (Real).4), output)); - - // Row 1 matches only row 1 - table.inferRow (rows[1], output, CondProbTable::inferViterbi); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 1 inferViterbi", makeCol((Real)0, 1, (Real)0, (Real)0), output)); + // Make a table with 3 columns + { + CondProbTable table((UInt)numCols()); - // Row 2 matches only row 2 and half row 3 - table.inferRow (rows[2], output, CondProbTable::inferViterbi); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 2 inferViterbi", makeCol((Real)0, (Real)0, (Real).6, (Real)0), output)); + // Add the 4 rows + for (Size i = 0; i < numRows(); i++) + table.updateRow((UInt)i, rows[i]); - // Row 3 matches row 0 & row 2 halfway, and row 3 exactly - table.inferRow (rows[3], output, CondProbTable::inferViterbi); - ASSERT_NO_FATAL_FAILURE( - testVectors(testName+"row 3 inferViterbi", makeCol((Real)0, (Real)0, (Real).4, (Real).6), output)); - - - // Add a row a second time, the row should double in value - table.updateRow(0, rows[0]); - expValue = rows[0]; - for (Size i=0; i > rows; - rows.resize(numRows()); - rows[0] = makeRow((Real)0.0, (Real)0.4, (Real)0.0); - rows[1] = makeRow((Real)1.0, (Real)0.0, (Real)0.0); - rows[2] = makeRow((Real)0.0, (Real)0.0, (Real)0.6); - rows[3] = makeRow((Real)0.0, (Real)0.6, (Real)0.4); - - // Test constructing without # of columns - { - CondProbTable table; - - // Add the 4 rows - for (Size i=0; i row1(1); - row1[0] = rows[1][0]; - table.updateRow(1, row1); - - // Add the first row first with just 2 columns - vector row0(2); - row0[0] = rows[0][0]; - row0[1] = rows[0][1]; - table.updateRow(0, row0); - - for (Size i=2; i #include #include -#include "gtest/gtest.h" using namespace std; using namespace nupic; @@ -37,531 +37,480 @@ using namespace nupic::algorithms::connections; namespace { - void setupSampleConnections(Connections &connections) - { - // Cell with 1 segment. - // Segment with: - // - 1 connected synapse: active - // - 2 matching synapses - const Segment segment1_1 = connections.createSegment(10); - connections.createSynapse(segment1_1, 150, 0.85); - connections.createSynapse(segment1_1, 151, 0.15); - - // Cell with 2 segments. 
- // Segment with: - // - 2 connected synapses: 2 active - // - 3 matching synapses: 3 active - const Segment segment2_1 = connections.createSegment(20); - connections.createSynapse(segment2_1, 80, 0.85); - connections.createSynapse(segment2_1, 81, 0.85); - Synapse synapse = connections.createSynapse(segment2_1, 82, 0.85); - connections.updateSynapsePermanence(synapse, 0.15); - - // Segment with: - // - 2 connected synapses: 1 active, 1 inactive - // - 3 matching synapses: 2 active, 1 inactive - // - 1 non-matching synapse: 1 active - const Segment segment2_2 = connections.createSegment(20); - connections.createSynapse(segment2_2, 50, 0.85); - connections.createSynapse(segment2_2, 51, 0.85); - connections.createSynapse(segment2_2, 52, 0.15); - connections.createSynapse(segment2_2, 53, 0.05); - - // Cell with one segment. - // Segment with: - // - 1 non-matching synapse: 1 active - const Segment segment3_1 = connections.createSegment(30); - connections.createSynapse(segment3_1, 53, 0.05); - } - - void computeSampleActivity(Connections &connections) - { - vector input = {50, 52, 53, - 80, 81, 82, - 150, 151}; - - vector numActiveConnectedSynapsesForSegment( +void setupSampleConnections(Connections &connections) { + // Cell with 1 segment. + // Segment with: + // - 1 connected synapse: active + // - 2 matching synapses + const Segment segment1_1 = connections.createSegment(10); + connections.createSynapse(segment1_1, 150, 0.85); + connections.createSynapse(segment1_1, 151, 0.15); + + // Cell with 2 segments. + // Segment with: + // - 2 connected synapses: 2 active + // - 3 matching synapses: 3 active + const Segment segment2_1 = connections.createSegment(20); + connections.createSynapse(segment2_1, 80, 0.85); + connections.createSynapse(segment2_1, 81, 0.85); + Synapse synapse = connections.createSynapse(segment2_1, 82, 0.85); + connections.updateSynapsePermanence(synapse, 0.15); + + // Segment with: + // - 2 connected synapses: 1 active, 1 inactive + // - 3 matching synapses: 2 active, 1 inactive + // - 1 non-matching synapse: 1 active + const Segment segment2_2 = connections.createSegment(20); + connections.createSynapse(segment2_2, 50, 0.85); + connections.createSynapse(segment2_2, 51, 0.85); + connections.createSynapse(segment2_2, 52, 0.15); + connections.createSynapse(segment2_2, 53, 0.05); + + // Cell with one segment. + // Segment with: + // - 1 non-matching synapse: 1 active + const Segment segment3_1 = connections.createSegment(30); + connections.createSynapse(segment3_1, 53, 0.05); +} + +void computeSampleActivity(Connections &connections) { + vector input = {50, 52, 53, 80, 81, 82, 150, 151}; + + vector numActiveConnectedSynapsesForSegment( connections.segmentFlatListLength(), 0); - vector numActivePotentialSynapsesForSegment( + vector numActivePotentialSynapsesForSegment( connections.segmentFlatListLength(), 0); - connections.computeActivity(numActiveConnectedSynapsesForSegment, - numActivePotentialSynapsesForSegment, - input, - 0.5); - } + connections.computeActivity(numActiveConnectedSynapsesForSegment, + numActivePotentialSynapsesForSegment, input, 0.5); +} - /** - * Creates a segment, and makes sure that it got created on the correct cell. - */ - TEST(ConnectionsTest, testCreateSegment) - { - Connections connections(1024); - UInt32 cell = 10; +/** + * Creates a segment, and makes sure that it got created on the correct cell. 
+ */ +TEST(ConnectionsTest, testCreateSegment) { + Connections connections(1024); + UInt32 cell = 10; - Segment segment1 = connections.createSegment(cell); - ASSERT_EQ(cell, connections.cellForSegment(segment1)); + Segment segment1 = connections.createSegment(cell); + ASSERT_EQ(cell, connections.cellForSegment(segment1)); - Segment segment2 = connections.createSegment(cell); - ASSERT_EQ(cell, connections.cellForSegment(segment2)); + Segment segment2 = connections.createSegment(cell); + ASSERT_EQ(cell, connections.cellForSegment(segment2)); - vector segments = connections.segmentsForCell(cell); - ASSERT_EQ(segments.size(), 2); + vector segments = connections.segmentsForCell(cell); + ASSERT_EQ(segments.size(), 2); - ASSERT_EQ(segment1, segments[0]); - ASSERT_EQ(segment2, segments[1]); - } + ASSERT_EQ(segment1, segments[0]); + ASSERT_EQ(segment2, segments[1]); +} - /** - * Creates a synapse, and makes sure that it got created on the correct - * segment, and that its data was correctly stored. - */ - TEST(ConnectionsTest, testCreateSynapse) - { - Connections connections(1024); - UInt32 cell = 10; - Segment segment = connections.createSegment(cell); +/** + * Creates a synapse, and makes sure that it got created on the correct + * segment, and that its data was correctly stored. + */ +TEST(ConnectionsTest, testCreateSynapse) { + Connections connections(1024); + UInt32 cell = 10; + Segment segment = connections.createSegment(cell); - Synapse synapse1 = connections.createSynapse(segment, 50, 0.34); - ASSERT_EQ(segment, connections.segmentForSynapse(synapse1)); + Synapse synapse1 = connections.createSynapse(segment, 50, 0.34); + ASSERT_EQ(segment, connections.segmentForSynapse(synapse1)); - Synapse synapse2 = connections.createSynapse(segment, 150, 0.48); - ASSERT_EQ(segment, connections.segmentForSynapse(synapse2)); + Synapse synapse2 = connections.createSynapse(segment, 150, 0.48); + ASSERT_EQ(segment, connections.segmentForSynapse(synapse2)); - vector synapses = connections.synapsesForSegment(segment); - ASSERT_EQ(synapses.size(), 2); + vector synapses = connections.synapsesForSegment(segment); + ASSERT_EQ(synapses.size(), 2); - ASSERT_EQ(synapse1, synapses[0]); - ASSERT_EQ(synapse2, synapses[1]); + ASSERT_EQ(synapse1, synapses[0]); + ASSERT_EQ(synapse2, synapses[1]); - SynapseData synapseData1 = connections.dataForSynapse(synapses[0]); - ASSERT_EQ(50, synapseData1.presynapticCell); - ASSERT_NEAR((Permanence)0.34, synapseData1.permanence, EPSILON); + SynapseData synapseData1 = connections.dataForSynapse(synapses[0]); + ASSERT_EQ(50, synapseData1.presynapticCell); + ASSERT_NEAR((Permanence)0.34, synapseData1.permanence, EPSILON); - SynapseData synapseData2 = connections.dataForSynapse(synapses[1]); - ASSERT_EQ(synapseData2.presynapticCell, 150); - ASSERT_NEAR((Permanence)0.48, synapseData2.permanence, EPSILON); - } + SynapseData synapseData2 = connections.dataForSynapse(synapses[1]); + ASSERT_EQ(synapseData2.presynapticCell, 150); + ASSERT_NEAR((Permanence)0.48, synapseData2.permanence, EPSILON); +} - /** - * Creates a segment, destroys it, and makes sure it got destroyed along with - * all of its synapses. - */ - TEST(ConnectionsTest, testDestroySegment) - { - Connections connections(1024); +/** + * Creates a segment, destroys it, and makes sure it got destroyed along with + * all of its synapses. 
+ */ +TEST(ConnectionsTest, testDestroySegment) { + Connections connections(1024); - /* segment1*/ connections.createSegment(10); - Segment segment2 = connections.createSegment(20); - /* segment3*/ connections.createSegment(20); - /* segment4*/ connections.createSegment(30); + /* segment1*/ connections.createSegment(10); + Segment segment2 = connections.createSegment(20); + /* segment3*/ connections.createSegment(20); + /* segment4*/ connections.createSegment(30); - connections.createSynapse(segment2, 80, 0.85); - connections.createSynapse(segment2, 81, 0.85); - connections.createSynapse(segment2, 82, 0.15); + connections.createSynapse(segment2, 80, 0.85); + connections.createSynapse(segment2, 81, 0.85); + connections.createSynapse(segment2, 82, 0.15); - ASSERT_EQ(4, connections.numSegments()); - ASSERT_EQ(3, connections.numSynapses()); + ASSERT_EQ(4, connections.numSegments()); + ASSERT_EQ(3, connections.numSynapses()); - connections.destroySegment(segment2); + connections.destroySegment(segment2); - ASSERT_EQ(3, connections.numSegments()); - ASSERT_EQ(0, connections.numSynapses()); + ASSERT_EQ(3, connections.numSegments()); + ASSERT_EQ(0, connections.numSynapses()); - vector numActiveConnectedSynapsesForSegment( + vector numActiveConnectedSynapsesForSegment( connections.segmentFlatListLength(), 0); - vector numActivePotentialSynapsesForSegment( + vector numActivePotentialSynapsesForSegment( connections.segmentFlatListLength(), 0); - connections.computeActivity(numActiveConnectedSynapsesForSegment, - numActivePotentialSynapsesForSegment, - {80, 81, 82}, - 0.5); + connections.computeActivity(numActiveConnectedSynapsesForSegment, + numActivePotentialSynapsesForSegment, + {80, 81, 82}, 0.5); - ASSERT_EQ(0, numActiveConnectedSynapsesForSegment[segment2]); - ASSERT_EQ(0, numActivePotentialSynapsesForSegment[segment2]); - } + ASSERT_EQ(0, numActiveConnectedSynapsesForSegment[segment2]); + ASSERT_EQ(0, numActivePotentialSynapsesForSegment[segment2]); +} - /** - * Creates a segment, creates a number of synapses on it, destroys a synapse, - * and makes sure it got destroyed. - */ - TEST(ConnectionsTest, testDestroySynapse) - { - Connections connections(1024); +/** + * Creates a segment, creates a number of synapses on it, destroys a synapse, + * and makes sure it got destroyed. 
+ */ +TEST(ConnectionsTest, testDestroySynapse) { + Connections connections(1024); - Segment segment = connections.createSegment(20); - /* synapse1*/ connections.createSynapse(segment, 80, 0.85); - Synapse synapse2 = connections.createSynapse(segment, 81, 0.85); - /* synapse3*/ connections.createSynapse(segment, 82, 0.15); + Segment segment = connections.createSegment(20); + /* synapse1*/ connections.createSynapse(segment, 80, 0.85); + Synapse synapse2 = connections.createSynapse(segment, 81, 0.85); + /* synapse3*/ connections.createSynapse(segment, 82, 0.15); - ASSERT_EQ(3, connections.numSynapses()); + ASSERT_EQ(3, connections.numSynapses()); - connections.destroySynapse(synapse2); + connections.destroySynapse(synapse2); - ASSERT_EQ(2, connections.numSynapses()); - ASSERT_EQ(2, connections.synapsesForSegment(segment).size()); + ASSERT_EQ(2, connections.numSynapses()); + ASSERT_EQ(2, connections.synapsesForSegment(segment).size()); - vector numActiveConnectedSynapsesForSegment( + vector numActiveConnectedSynapsesForSegment( connections.segmentFlatListLength(), 0); - vector numActivePotentialSynapsesForSegment( + vector numActivePotentialSynapsesForSegment( connections.segmentFlatListLength(), 0); - connections.computeActivity(numActiveConnectedSynapsesForSegment, - numActivePotentialSynapsesForSegment, - {80, 81, 82}, - 0.5); - - ASSERT_EQ(1, numActiveConnectedSynapsesForSegment[segment]); - ASSERT_EQ(2, numActivePotentialSynapsesForSegment[segment]); - } - - /** - * Creates segments and synapses, then destroys segments and synapses on - * either side of them and verifies that existing Segment and Synapse - * instances still point to the same segment / synapse as before. - */ - TEST(ConnectionsTest, PathsNotInvalidatedByOtherDestroys) - { - Connections connections(1024); - - Segment segment1 = connections.createSegment(11); - /* segment2*/ connections.createSegment(12); - - Segment segment3 = connections.createSegment(13); - Synapse synapse1 = connections.createSynapse(segment3, 201, 0.85); - /* synapse2*/ connections.createSynapse(segment3, 202, 0.85); - Synapse synapse3 = connections.createSynapse(segment3, 203, 0.85); - /* synapse4*/ connections.createSynapse(segment3, 204, 0.85); - Synapse synapse5 = connections.createSynapse(segment3, 205, 0.85); - - /* segment4*/ connections.createSegment(14); - Segment segment5 = connections.createSegment(15); - - ASSERT_EQ(203, connections.dataForSynapse(synapse3).presynapticCell); - connections.destroySynapse(synapse1); - EXPECT_EQ(203, connections.dataForSynapse(synapse3).presynapticCell); - connections.destroySynapse(synapse5); - EXPECT_EQ(203, connections.dataForSynapse(synapse3).presynapticCell); - - connections.destroySegment(segment1); - EXPECT_EQ(3, connections.synapsesForSegment(segment3).size()); - connections.destroySegment(segment5); - EXPECT_EQ(3, connections.synapsesForSegment(segment3).size()); - EXPECT_EQ(203, connections.dataForSynapse(synapse3).presynapticCell); - } + connections.computeActivity(numActiveConnectedSynapsesForSegment, + numActivePotentialSynapsesForSegment, + {80, 81, 82}, 0.5); + + ASSERT_EQ(1, numActiveConnectedSynapsesForSegment[segment]); + ASSERT_EQ(2, numActivePotentialSynapsesForSegment[segment]); +} + +/** + * Creates segments and synapses, then destroys segments and synapses on + * either side of them and verifies that existing Segment and Synapse + * instances still point to the same segment / synapse as before. 
+ */ +TEST(ConnectionsTest, PathsNotInvalidatedByOtherDestroys) { + Connections connections(1024); + + Segment segment1 = connections.createSegment(11); + /* segment2*/ connections.createSegment(12); + + Segment segment3 = connections.createSegment(13); + Synapse synapse1 = connections.createSynapse(segment3, 201, 0.85); + /* synapse2*/ connections.createSynapse(segment3, 202, 0.85); + Synapse synapse3 = connections.createSynapse(segment3, 203, 0.85); + /* synapse4*/ connections.createSynapse(segment3, 204, 0.85); + Synapse synapse5 = connections.createSynapse(segment3, 205, 0.85); + + /* segment4*/ connections.createSegment(14); + Segment segment5 = connections.createSegment(15); + + ASSERT_EQ(203, connections.dataForSynapse(synapse3).presynapticCell); + connections.destroySynapse(synapse1); + EXPECT_EQ(203, connections.dataForSynapse(synapse3).presynapticCell); + connections.destroySynapse(synapse5); + EXPECT_EQ(203, connections.dataForSynapse(synapse3).presynapticCell); + + connections.destroySegment(segment1); + EXPECT_EQ(3, connections.synapsesForSegment(segment3).size()); + connections.destroySegment(segment5); + EXPECT_EQ(3, connections.synapsesForSegment(segment3).size()); + EXPECT_EQ(203, connections.dataForSynapse(synapse3).presynapticCell); +} + +/** + * Destroy a segment that has a destroyed synapse and a non-destroyed synapse. + * Make sure nothing gets double-destroyed. + */ +TEST(ConnectionsTest, DestroySegmentWithDestroyedSynapses) { + Connections connections(1024); - /** - * Destroy a segment that has a destroyed synapse and a non-destroyed synapse. - * Make sure nothing gets double-destroyed. - */ - TEST(ConnectionsTest, DestroySegmentWithDestroyedSynapses) - { - Connections connections(1024); + Segment segment1 = connections.createSegment(11); + Segment segment2 = connections.createSegment(12); - Segment segment1 = connections.createSegment(11); - Segment segment2 = connections.createSegment(12); + /* synapse1_1*/ connections.createSynapse(segment1, 101, 0.85); + Synapse synapse2_1 = connections.createSynapse(segment2, 201, 0.85); + /* synapse2_2*/ connections.createSynapse(segment2, 202, 0.85); - /* synapse1_1*/ connections.createSynapse(segment1, 101, 0.85); - Synapse synapse2_1 = connections.createSynapse(segment2, 201, 0.85); - /* synapse2_2*/ connections.createSynapse(segment2, 202, 0.85); + ASSERT_EQ(3, connections.numSynapses()); - ASSERT_EQ(3, connections.numSynapses()); + connections.destroySynapse(synapse2_1); - connections.destroySynapse(synapse2_1); + ASSERT_EQ(2, connections.numSegments()); + ASSERT_EQ(2, connections.numSynapses()); - ASSERT_EQ(2, connections.numSegments()); - ASSERT_EQ(2, connections.numSynapses()); + connections.destroySegment(segment2); - connections.destroySegment(segment2); + EXPECT_EQ(1, connections.numSegments()); + EXPECT_EQ(1, connections.numSynapses()); +} - EXPECT_EQ(1, connections.numSegments()); - EXPECT_EQ(1, connections.numSynapses()); - } - - /** - * Destroy a segment that has a destroyed synapse and a non-destroyed synapse. - * Create a new segment in the same place. Make sure its synapse count is - * correct. - */ - TEST(ConnectionsTest, ReuseSegmentWithDestroyedSynapses) - { - Connections connections(1024); +/** + * Destroy a segment that has a destroyed synapse and a non-destroyed synapse. + * Create a new segment in the same place. Make sure its synapse count is + * correct. 
+ */ +TEST(ConnectionsTest, ReuseSegmentWithDestroyedSynapses) { + Connections connections(1024); - Segment segment = connections.createSegment(11); + Segment segment = connections.createSegment(11); - Synapse synapse1 = connections.createSynapse(segment, 201, 0.85); - /* synapse2*/ connections.createSynapse(segment, 202, 0.85); + Synapse synapse1 = connections.createSynapse(segment, 201, 0.85); + /* synapse2*/ connections.createSynapse(segment, 202, 0.85); - connections.destroySynapse(synapse1); + connections.destroySynapse(synapse1); - ASSERT_EQ(1, connections.numSynapses(segment)); + ASSERT_EQ(1, connections.numSynapses(segment)); - connections.destroySegment(segment); - Segment reincarnated = connections.createSegment(11); + connections.destroySegment(segment); + Segment reincarnated = connections.createSegment(11); - EXPECT_EQ(0, connections.numSynapses(reincarnated)); - EXPECT_EQ(0, connections.synapsesForSegment(reincarnated).size()); - } + EXPECT_EQ(0, connections.numSynapses(reincarnated)); + EXPECT_EQ(0, connections.synapsesForSegment(reincarnated).size()); +} - /** - * Creates a synapse and updates its permanence, and makes sure that its - * data was correctly updated. - */ - TEST(ConnectionsTest, testUpdateSynapsePermanence) - { - Connections connections(1024); - Segment segment = connections.createSegment(10); - Synapse synapse = connections.createSynapse(segment, 50, 0.34); +/** + * Creates a synapse and updates its permanence, and makes sure that its + * data was correctly updated. + */ +TEST(ConnectionsTest, testUpdateSynapsePermanence) { + Connections connections(1024); + Segment segment = connections.createSegment(10); + Synapse synapse = connections.createSynapse(segment, 50, 0.34); - connections.updateSynapsePermanence(synapse, 0.21); + connections.updateSynapsePermanence(synapse, 0.21); - SynapseData synapseData = connections.dataForSynapse(synapse); - ASSERT_NEAR(synapseData.permanence, (Real)0.21, EPSILON); - } + SynapseData synapseData = connections.dataForSynapse(synapse); + ASSERT_NEAR(synapseData.permanence, (Real)0.21, EPSILON); +} - /** - * Creates a sample set of connections, and makes sure that computing the - * activity for a collection of cells with no activity returns the right - * activity data. - */ - TEST(ConnectionsTest, testComputeActivity) - { - Connections connections(1024); - - // Cell with 1 segment. - // Segment with: - // - 1 connected synapse: active - // - 2 matching synapses: active - const Segment segment1_1 = connections.createSegment(10); - connections.createSynapse(segment1_1, 150, 0.85); - connections.createSynapse(segment1_1, 151, 0.15); - - // Cell with 1 segments. - // Segment with: - // - 2 connected synapses: 2 active - // - 3 matching synapses: 3 active - const Segment segment2_1 = connections.createSegment(20); - connections.createSynapse(segment2_1, 80, 0.85); - connections.createSynapse(segment2_1, 81, 0.85); - Synapse synapse = connections.createSynapse(segment2_1, 82, 0.85); - connections.updateSynapsePermanence(synapse, 0.15); - - vector input = {50, 52, 53, - 80, 81, 82, - 150, 151}; - - vector numActiveConnectedSynapsesForSegment( +/** + * Creates a sample set of connections, and makes sure that computing the + * activity for a collection of cells with no activity returns the right + * activity data. + */ +TEST(ConnectionsTest, testComputeActivity) { + Connections connections(1024); + + // Cell with 1 segment. 
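// The destroy/reuse tests above lean on two properties of Connections that
// are easy to miss in the reformatting noise: counts shrink as soon as a
// synapse or segment is destroyed, and a segment created after a destroy
// starts out empty even if it lands in the same slot. A compact sketch of
// that behaviour; the header path, the assert-based checks, and the
// function name are illustrative assumptions, not part of this diff.
#include <cassert>
#include <nupic/algorithms/Connections.hpp>

using namespace nupic::algorithms::connections;

inline void sketchDestroyAndReuse() {
  Connections c(1024);
  Segment seg = c.createSegment(11);
  Synapse syn = c.createSynapse(seg, 201, 0.85);
  c.createSynapse(seg, 202, 0.85);

  c.destroySynapse(syn);
  assert(c.numSynapses(seg) == 1); // the destroyed synapse no longer counts

  c.destroySegment(seg);
  Segment again = c.createSegment(11); // may reuse the freed slot
  assert(c.numSynapses(again) == 0);
  assert(c.synapsesForSegment(again).empty());
}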
+ // Segment with: + // - 1 connected synapse: active + // - 2 matching synapses: active + const Segment segment1_1 = connections.createSegment(10); + connections.createSynapse(segment1_1, 150, 0.85); + connections.createSynapse(segment1_1, 151, 0.15); + + // Cell with 1 segments. + // Segment with: + // - 2 connected synapses: 2 active + // - 3 matching synapses: 3 active + const Segment segment2_1 = connections.createSegment(20); + connections.createSynapse(segment2_1, 80, 0.85); + connections.createSynapse(segment2_1, 81, 0.85); + Synapse synapse = connections.createSynapse(segment2_1, 82, 0.85); + connections.updateSynapsePermanence(synapse, 0.15); + + vector input = {50, 52, 53, 80, 81, 82, 150, 151}; + + vector numActiveConnectedSynapsesForSegment( connections.segmentFlatListLength(), 0); - vector numActivePotentialSynapsesForSegment( + vector numActivePotentialSynapsesForSegment( connections.segmentFlatListLength(), 0); - connections.computeActivity(numActiveConnectedSynapsesForSegment, - numActivePotentialSynapsesForSegment, - input, - 0.5); + connections.computeActivity(numActiveConnectedSynapsesForSegment, + numActivePotentialSynapsesForSegment, input, 0.5); - ASSERT_EQ(1, numActiveConnectedSynapsesForSegment[segment1_1]); - ASSERT_EQ(2, numActivePotentialSynapsesForSegment[segment1_1]); + ASSERT_EQ(1, numActiveConnectedSynapsesForSegment[segment1_1]); + ASSERT_EQ(2, numActivePotentialSynapsesForSegment[segment1_1]); - ASSERT_EQ(2, numActiveConnectedSynapsesForSegment[segment2_1]); - ASSERT_EQ(3, numActivePotentialSynapsesForSegment[segment2_1]); - } + ASSERT_EQ(2, numActiveConnectedSynapsesForSegment[segment2_1]); + ASSERT_EQ(3, numActivePotentialSynapsesForSegment[segment2_1]); +} + +/** + * Test the mapSegmentsToCells method. + */ +TEST(ConnectionsTest, testMapSegmentsToCells) { + Connections connections(1024); + const Segment segment1 = connections.createSegment(42); + const Segment segment2 = connections.createSegment(42); + const Segment segment3 = connections.createSegment(43); - /** - * Test the mapSegmentsToCells method. 
- */ - TEST(ConnectionsTest, testMapSegmentsToCells) - { - Connections connections(1024); + const vector segments = {segment1, segment2, segment3, segment1}; + vector cells(segments.size()); - const Segment segment1 = connections.createSegment(42); - const Segment segment2 = connections.createSegment(42); - const Segment segment3 = connections.createSegment(43); + connections.mapSegmentsToCells( + segments.data(), segments.data() + segments.size(), cells.data()); - const vector segments = {segment1, segment2, segment3, segment1}; - vector cells(segments.size()); + const vector expected = {42, 42, 43, 42}; + ASSERT_EQ(expected, cells); +} - connections.mapSegmentsToCells(segments.data(), - segments.data() + segments.size(), - cells.data()); +bool TEST_EVENT_HANDLER_DESTRUCTED = false; - const vector expected = {42, 42, 43, 42}; - ASSERT_EQ(expected, cells); +class TestConnectionsEventHandler : public ConnectionsEventHandler { +public: + TestConnectionsEventHandler() + : didCreateSegment(false), didDestroySegment(false), + didCreateSynapse(false), didDestroySynapse(false), + didUpdateSynapsePermanence(false) {} + + virtual ~TestConnectionsEventHandler() { + TEST_EVENT_HANDLER_DESTRUCTED = true; } + virtual void onCreateSegment(Segment segment) { didCreateSegment = true; } + virtual void onDestroySegment(Segment segment) { didDestroySegment = true; } - bool TEST_EVENT_HANDLER_DESTRUCTED = false; + virtual void onCreateSynapse(Synapse synapse) { didCreateSynapse = true; } - class TestConnectionsEventHandler : public ConnectionsEventHandler - { - public: - TestConnectionsEventHandler() - :didCreateSegment(false), - didDestroySegment(false), - didCreateSynapse(false), - didDestroySynapse(false), - didUpdateSynapsePermanence(false) - { - } - - virtual ~TestConnectionsEventHandler() - { - TEST_EVENT_HANDLER_DESTRUCTED = true; - } - - virtual void onCreateSegment(Segment segment) - { - didCreateSegment = true; - } - - virtual void onDestroySegment(Segment segment) - { - didDestroySegment = true; - } - - virtual void onCreateSynapse(Synapse synapse) - { - didCreateSynapse = true; - } - - virtual void onDestroySynapse(Synapse synapse) - { - didDestroySynapse = true; - } - - virtual void onUpdateSynapsePermanence(Synapse synapse, - Permanence permanence) - { - didUpdateSynapsePermanence = true; - } - - bool didCreateSegment; - bool didDestroySegment; - bool didCreateSynapse; - bool didDestroySynapse; - bool didUpdateSynapsePermanence; - }; - - /** - * Make sure each event handler gets called. - */ - TEST(ConnectionsTest, subscribe) - { - Connections connections(1024); + virtual void onDestroySynapse(Synapse synapse) { didDestroySynapse = true; } - TestConnectionsEventHandler* handler = new TestConnectionsEventHandler(); - auto token = connections.subscribe(handler); + virtual void onUpdateSynapsePermanence(Synapse synapse, + Permanence permanence) { + didUpdateSynapsePermanence = true; + } - ASSERT_FALSE(handler->didCreateSegment); - Segment segment = connections.createSegment(42); - EXPECT_TRUE(handler->didCreateSegment); + bool didCreateSegment; + bool didDestroySegment; + bool didCreateSynapse; + bool didDestroySynapse; + bool didUpdateSynapsePermanence; +}; - ASSERT_FALSE(handler->didCreateSynapse); - Synapse synapse = connections.createSynapse(segment, 41, 0.50); - EXPECT_TRUE(handler->didCreateSynapse); +/** + * Make sure each event handler gets called. 
+ */ +TEST(ConnectionsTest, subscribe) { + Connections connections(1024); - ASSERT_FALSE(handler->didUpdateSynapsePermanence); - connections.updateSynapsePermanence(synapse, 0.60); - EXPECT_TRUE(handler->didUpdateSynapsePermanence); + TestConnectionsEventHandler *handler = new TestConnectionsEventHandler(); + auto token = connections.subscribe(handler); - ASSERT_FALSE(handler->didDestroySynapse); - connections.destroySynapse(synapse); - EXPECT_TRUE(handler->didDestroySynapse); + ASSERT_FALSE(handler->didCreateSegment); + Segment segment = connections.createSegment(42); + EXPECT_TRUE(handler->didCreateSegment); - ASSERT_FALSE(handler->didDestroySegment); - connections.destroySegment(segment); - EXPECT_TRUE(handler->didDestroySegment); + ASSERT_FALSE(handler->didCreateSynapse); + Synapse synapse = connections.createSynapse(segment, 41, 0.50); + EXPECT_TRUE(handler->didCreateSynapse); - connections.unsubscribe(token); - } + ASSERT_FALSE(handler->didUpdateSynapsePermanence); + connections.updateSynapsePermanence(synapse, 0.60); + EXPECT_TRUE(handler->didUpdateSynapsePermanence); - /** - * Make sure the event handler is destructed on unsubscribe. - */ - TEST(ConnectionsTest, unsubscribe) - { - Connections connections(1024); - TestConnectionsEventHandler* handler = new TestConnectionsEventHandler(); - auto token = connections.subscribe(handler); + ASSERT_FALSE(handler->didDestroySynapse); + connections.destroySynapse(synapse); + EXPECT_TRUE(handler->didDestroySynapse); - TEST_EVENT_HANDLER_DESTRUCTED = false; - connections.unsubscribe(token); - EXPECT_TRUE(TEST_EVENT_HANDLER_DESTRUCTED); - } + ASSERT_FALSE(handler->didDestroySegment); + connections.destroySegment(segment); + EXPECT_TRUE(handler->didDestroySegment); - /** - * Creates a sample set of connections, and makes sure that we can get the - * correct number of segments. - */ - TEST(ConnectionsTest, testNumSegments) - { - Connections connections(1024); - setupSampleConnections(connections); + connections.unsubscribe(token); +} - ASSERT_EQ(4, connections.numSegments()); - } +/** + * Make sure the event handler is destructed on unsubscribe. + */ +TEST(ConnectionsTest, unsubscribe) { + Connections connections(1024); + TestConnectionsEventHandler *handler = new TestConnectionsEventHandler(); + auto token = connections.subscribe(handler); + + TEST_EVENT_HANDLER_DESTRUCTED = false; + connections.unsubscribe(token); + EXPECT_TRUE(TEST_EVENT_HANDLER_DESTRUCTED); +} + +/** + * Creates a sample set of connections, and makes sure that we can get the + * correct number of segments. + */ +TEST(ConnectionsTest, testNumSegments) { + Connections connections(1024); + setupSampleConnections(connections); - /** - * Creates a sample set of connections, and makes sure that we can get the - * correct number of synapses. - */ - TEST(ConnectionsTest, testNumSynapses) - { - Connections connections(1024); - setupSampleConnections(connections); + ASSERT_EQ(4, connections.numSegments()); +} - ASSERT_EQ(10, connections.numSynapses()); - } +/** + * Creates a sample set of connections, and makes sure that we can get the + * correct number of synapses. + */ +TEST(ConnectionsTest, testNumSynapses) { + Connections connections(1024); + setupSampleConnections(connections); - /** - * Creates a sample set of connections with destroyed segments/synapses, - * computes sample activity, and makes sure that we can write to a - * filestream and read it back correctly. 
- */ - TEST(ConnectionsTest, testWriteRead) - { - const char* filename = "ConnectionsSerialization.tmp"; - Connections c1(1024), c2; - setupSampleConnections(c1); + ASSERT_EQ(10, connections.numSynapses()); +} - Segment segment = c1.createSegment(10); - c1.createSynapse(segment, 400, 0.5); - c1.destroySegment(segment); +/** + * Creates a sample set of connections with destroyed segments/synapses, + * computes sample activity, and makes sure that we can write to a + * filestream and read it back correctly. + */ +TEST(ConnectionsTest, testWriteRead) { + const char *filename = "ConnectionsSerialization.tmp"; + Connections c1(1024), c2; + setupSampleConnections(c1); - computeSampleActivity(c1); + Segment segment = c1.createSegment(10); + c1.createSynapse(segment, 400, 0.5); + c1.destroySegment(segment); - ofstream os(filename, ios::binary); - c1.write(os); - os.close(); + computeSampleActivity(c1); - ifstream is(filename, ios::binary); - c2.read(is); - is.close(); + ofstream os(filename, ios::binary); + c1.write(os); + os.close(); - ASSERT_EQ(c1, c2); + ifstream is(filename, ios::binary); + c2.read(is); + is.close(); - int ret = ::remove(filename); - NTA_CHECK(ret == 0) << "Failed to delete " << filename; - } + ASSERT_EQ(c1, c2); - TEST(ConnectionsTest, testSaveLoad) - { - Connections c1(1024), c2; - setupSampleConnections(c1); + int ret = ::remove(filename); + NTA_CHECK(ret == 0) << "Failed to delete " << filename; +} - auto segment = c1.createSegment(10); +TEST(ConnectionsTest, testSaveLoad) { + Connections c1(1024), c2; + setupSampleConnections(c1); - c1.createSynapse(segment, 400, 0.5); - c1.destroySegment(segment); + auto segment = c1.createSegment(10); - computeSampleActivity(c1); + c1.createSynapse(segment, 400, 0.5); + c1.destroySegment(segment); - { - stringstream ss; - c1.save(ss); - c2.load(ss); - } + computeSampleActivity(c1); - ASSERT_EQ(c1, c2); + { + stringstream ss; + c1.save(ss); + c2.load(ss); } -} // end namespace nupic + ASSERT_EQ(c1, c2); +} + +} // namespace diff --git a/src/test/unit/algorithms/NearestNeighborUnitTest.cpp b/src/test/unit/algorithms/NearestNeighborUnitTest.cpp index 921563dcc0..2e9d10a6fb 100644 --- a/src/test/unit/algorithms/NearestNeighborUnitTest.cpp +++ b/src/test/unit/algorithms/NearestNeighborUnitTest.cpp @@ -22,21 +22,19 @@ /** @file * Implementation of unit tests for NearestNeighbor - */ + */ -#include -#include #include "../math/SparseMatrixUnitTest.hpp" - +#include +#include using namespace std; namespace { -#define TEST_LOOP(M) \ - for (nrows = 1, ncols = M, zr = 15; \ - nrows < M; \ - nrows += M/10, ncols -= M/10, zr = ncols/10) \ +#define TEST_LOOP(M) \ + for (nrows = 1, ncols = M, zr = 15; nrows < M; \ + nrows += M / 10, ncols -= M / 10, zr = ncols / 10) #define M 64 // @@ -44,11 +42,11 @@ namespace { // void NearestNeighborUnitTest::unit_test_rowLpDist() // { // if (0) { // Visual tests, off by default -// -// UInt ncols = 11, nrows = 7, zr = 2; +// +// UInt ncols = 11, nrows = 7, zr = 2; // Dense dense(nrows, ncols, zr); -// NearestNeighbor > sparse(nrows, ncols, dense.begin()); -// std::vector x(ncols, 0); +// NearestNeighbor > sparse(nrows, ncols, +// dense.begin()); std::vector x(ncols, 0); // // for (UInt i = 0; i != ncols; ++i) // x[i] = i % 2; @@ -60,15 +58,15 @@ namespace { // for (UInt i = 0; i != nrows; ++i) { // cout << "L0 " << i << " " // << dense.rowL0Dist(i, x.begin()) << " " -// << sparse.rowL0Dist(i, x.begin()) +// << sparse.rowL0Dist(i, x.begin()) // << endl; // } // -// // L1 +// // L1 // for (UInt i = 0; i != nrows; ++i) 
{ // cout << "L1 " << i << " " -// << dense.rowLpDist(1.0, i, x.begin()) << " " -// << sparse.rowL1Dist(i, x.begin()) +// << dense.rowLpDist(1.0, i, x.begin()) << " " +// << sparse.rowL1Dist(i, x.begin()) // << endl; // } // @@ -76,7 +74,7 @@ namespace { // for (UInt i = 0; i != nrows; ++i) { // cout << "L2 " << i << " " // << dense.rowLpDist(2.0, i, x.begin()) << " " -// << sparse.rowL2Dist(i, x.begin()) +// << sparse.rowL2Dist(i, x.begin()) // << endl; // } // @@ -84,7 +82,7 @@ namespace { // for (UInt i = 0; i != nrows; ++i) { // cout << "Lmax " << i << " " // << dense.rowLMaxDist(i, x.begin()) << " " -// << sparse.rowLMaxDist(i, x.begin()) +// << sparse.rowLMaxDist(i, x.begin()) // << endl; // } // @@ -92,7 +90,7 @@ namespace { // for (UInt i = 0; i != nrows; ++i) { // cout << "Lp " << i << " " // << dense.rowLpDist(.35, i, x.begin()) << " " -// << sparse.rowLpDist(.35, i, x.begin()) +// << sparse.rowLpDist(.35, i, x.begin()) // << endl; // } // } // End visual tests @@ -106,7 +104,8 @@ namespace { // continue; // // Dense dense(nrows, ncols, zr); -// NearestNeighbor > sparse(nrows, ncols, dense.begin()); +// NearestNeighbor > sparse(nrows, ncols, +// dense.begin()); // // std::vector x(ncols, 0); // for (UInt i = 0; i < ncols; ++i) @@ -148,7 +147,7 @@ namespace { // << " - non compact"; // TEST(nupic::nearlyEqual(d1, d2)); // } -// +// // sparse.compact(); // d2 = sparse.rowLMaxDist(row, x.begin()); // { @@ -168,23 +167,23 @@ namespace { // // UInt ncols = 11, nrows = 7, zr = 2; // Dense dense(nrows, ncols, zr); -// NearestNeighbor > sparse(nrows, ncols, dense.begin()); -// std::vector x(ncols, 0), distances(nrows, 0); +// NearestNeighbor > sparse(nrows, ncols, +// dense.begin()); std::vector x(ncols, 0), distances(nrows, 0); // // for (UInt i = 0; i != ncols; ++i) -// x[i] = i % 2; +// x[i] = i % 2; // // cout << dense << endl << x << endl << endl; // -// // L0 +// // L0 // cout << "L0" << endl; // dense.L0Dist(x.begin(), distances.begin()); // cout << distances << endl; -// sparse.L0Dist(x.begin(), distances.begin()); +// sparse.L0Dist(x.begin(), distances.begin()); // cout << distances << endl; // cout << endl; // -// // L1 +// // L1 // cout << "L1" << endl; // dense.LpDist(1.0, x.begin(), distances.begin()); // cout << distances << endl; @@ -218,7 +217,7 @@ namespace { // } // End visual tests // // if (1) { // Automated tests -// +// // UInt ncols = 5, nrows = 7, zr = 2; // // TEST_LOOP(M) { @@ -227,12 +226,13 @@ namespace { // continue; // // Dense dense(nrows, ncols, zr); -// NearestNeighbor > sparse(nrows, ncols, dense.begin()); +// NearestNeighbor > sparse(nrows, ncols, +// dense.begin()); // // std::vector x(ncols, 0), yref(nrows, 0), y(nrows, 0); // for (UInt i = 0; i < ncols; ++i) // x[i] = Real(i); -// +// // for (double p = 0.0; p < 2.5; p += .5) { // // sparse.decompact(); @@ -243,8 +243,8 @@ namespace { // str << "LpDist A " << nrows << "X" << ncols << "/" << zr // << " - non compact"; // CompareVectors(nrows, y.begin(), yref.begin(), str.str().c_str()); -// } -// +// } +// // sparse.compact(); // sparse.LpDist(p, x.begin(), y.begin()); // { @@ -265,7 +265,7 @@ namespace { // << " - non compact"; // CompareVectors(nrows, y.begin(), yref.begin(), str.str().c_str()); // } -// +// // sparse.compact(); // sparse.LMaxDist(x.begin(), y.begin()); // @@ -283,15 +283,15 @@ namespace { // void NearestNeighborUnitTest::unit_test_LpNearest() // { // if (0) { // Visual tests, off by default -// +// // UInt ncols = 11, nrows = 7, zr = 2; // // Dense dense(nrows, ncols, zr); // for 
(UInt i = 0; i != nrows; ++i) // for (UInt j = 0; j != ncols; ++j) // dense.at(i,j) = rng_->getReal64() * 2.0; -// NearestNeighbor > sparse(nrows, ncols, dense.begin()); -// std::vector x(ncols, 0); +// NearestNeighbor > sparse(nrows, ncols, +// dense.begin()); std::vector x(ncols, 0); // std::vector > nn1(nrows), nn2(nrows); // // for (UInt i = 0; i != ncols; ++i) @@ -309,7 +309,7 @@ namespace { // cout << nn2[i].first << "," << nn2[i].second << " "; // cout << endl; // -// // L1 +// // L1 // cout << "L1" << endl; // dense.LpNearest(1.0, x.begin(), nn1.begin(), nrows); // sparse.L1Nearest(x.begin(), nn2.begin(), nrows); @@ -319,7 +319,7 @@ namespace { // for (UInt i = 0; i != nrows; ++i) // cout << nn2[i].first << "," << nn2[i].second << " "; // cout << endl << endl; -// +// // // L2 // cout << "L2" << endl; // dense.LpNearest(2.0, x.begin(), nn1.begin(), nrows); @@ -331,7 +331,7 @@ namespace { // cout << nn2[i].first << "," << nn2[i].second << " "; // cout << endl << endl; // -// // LMax +// // LMax // cout << "LMax" << endl; // dense.LMaxNearest(x.begin(), nn1.begin(), nrows); // sparse.LMaxNearest(x.begin(), nn2.begin(), nrows); @@ -364,13 +364,14 @@ namespace { // continue; // // Dense dense(nrows, ncols, zr); -// NearestNeighbor > sparse(nrows, ncols, dense.begin()); +// NearestNeighbor > sparse(nrows, ncols, +// dense.begin()); // // std::vector x(ncols, 0); // std::vector > yref(nrows), y(nrows); // for (UInt i = 0; i < ncols; ++i) // x[i] = Real(i); -// +// // for (double p = 0.0; p < 2.5; p += .5) { // // sparse.decompact(); @@ -381,8 +382,8 @@ namespace { // str << "LpNearest A " << nrows << "X" << ncols << "/" << zr // << " - non compact"; // Compare(y, yref, str.str().c_str()); -// } -// +// } +// // sparse.compact(); // sparse.LpNearest(p, x.begin(), y.begin(), nrows); // { @@ -401,7 +402,7 @@ namespace { // << " - non compact"; // Compare(y, yref, str.str().c_str()); // } -// +// // sparse.compact(); // sparse.LpNearest(p, x.begin(), y.begin()); // { @@ -413,7 +414,7 @@ namespace { // } // } // } // End automated tests -// } +// } // // //-------------------------------------------------------------------------------- // void NearestNeighborUnitTest::unit_test_dotNearest() @@ -431,14 +432,14 @@ namespace { // // pair res(0, 0), ref = dense.dotNearest(x.begin()); // -// NearestNeighbor > smc(nrows, ncols, dense.begin()); -// res = smc.dotNearest(x.begin()); -// ComparePair(res, ref, "dotNearest compact 1"); +// NearestNeighbor > smc(nrows, ncols, +// dense.begin()); res = smc.dotNearest(x.begin()); ComparePair(res, ref, +// "dotNearest compact 1"); // // { // nrows *= 10; // ncols *= 10; -// +// // Dense dense2(nrows, ncols); // for (i = 0; i < nrows; ++i) // for (j = 0; j < ncols; ++j) { @@ -446,15 +447,16 @@ namespace { // if (dense2.at(i,j) < .8) // dense2.at(i,j) = 0; // } -// -// NearestNeighbor > sm2(nrows, ncols, dense2.begin()); -// +// +// NearestNeighbor > sm2(nrows, ncols, +// dense2.begin()); +// // std::vector x2(ncols, 0); // for (j = 0; j < ncols; ++j) // x2[j] = rng_->getReal64(); -// +// // ref = dense2.dotNearest(x2.begin()); -// +// // res.first = 0; res.second = 0; // res = sm2.dotNearest(x2.begin()); // ComparePair(res, ref, "dotNearest compact 2"); @@ -464,7 +466,8 @@ namespace { // TEST_LOOP(M) { // // DenseMat dense2(nrows, ncols, zr); -// NearestNeighbor > sm2(nrows, ncols, dense2.begin()); +// NearestNeighbor > sm2(nrows, ncols, +// dense2.begin()); // // std::vector x2(ncols, 0), yref2(nrows, 0), y2(nrows, 0); // for (i = 0; i < ncols; ++i) @@ 
-494,8 +497,5 @@ namespace { // } // } // - //-------------------------------------------------------------------------------- -} // end namespace nupic - - - +//-------------------------------------------------------------------------------- +} // namespace diff --git a/src/test/unit/algorithms/SDRClassifierTest.cpp b/src/test/unit/algorithms/SDRClassifierTest.cpp index dbd6922eb6..498cfa29e9 100644 --- a/src/test/unit/algorithms/SDRClassifierTest.cpp +++ b/src/test/unit/algorithms/SDRClassifierTest.cpp @@ -40,383 +40,376 @@ using namespace nupic; using namespace nupic::algorithms::cla_classifier; using namespace nupic::algorithms::sdr_classifier; -namespace -{ +namespace { + +TEST(SDRClassifierTest, Basic) { + vector steps; + steps.push_back(1); + SDRClassifier c = SDRClassifier(steps, 0.1, 0.1, 0); + + // Create a vector of input bit indices + vector input1; + input1.push_back(1); + input1.push_back(5); + input1.push_back(9); + vector bucketIdxList1; + bucketIdxList1.push_back(4); + vector actValueList1; + actValueList1.push_back(34.7); + ClassifierResult result1; + c.compute(0, input1, bucketIdxList1, actValueList1, false, true, true, + &result1); + + // Create a vector of input bit indices + vector input2; + input2.push_back(1); + input2.push_back(5); + input2.push_back(9); + vector bucketIdxList2; + bucketIdxList2.push_back(4); + vector actValueList2; + actValueList2.push_back(34.7); + ClassifierResult result2; + c.compute(1, input2, bucketIdxList2, actValueList2, false, true, true, + &result2); - TEST(SDRClassifierTest, Basic) { - vector steps; - steps.push_back(1); - SDRClassifier c = SDRClassifier(steps, 0.1, 0.1, 0); - - // Create a vector of input bit indices - vector input1; - input1.push_back(1); - input1.push_back(5); - input1.push_back(9); - vector bucketIdxList1; - bucketIdxList1.push_back(4); - vector actValueList1; - actValueList1.push_back(34.7); - ClassifierResult result1; - c.compute(0, input1, bucketIdxList1, actValueList1, false, true, true, &result1); - - // Create a vector of input bit indices - vector input2; - input2.push_back(1); - input2.push_back(5); - input2.push_back(9); - vector bucketIdxList2; - bucketIdxList2.push_back(4); - vector actValueList2; - actValueList2.push_back(34.7); - ClassifierResult result2; - c.compute(1, input2, bucketIdxList2, actValueList2, false, true, true, &result2); - - { - bool foundMinus1 = false; - bool found1 = false; - for (auto it = result2.begin(); - it != result2.end(); ++it) - { - if (it->first == -1) - { - // The -1 key is used for the actual values - ASSERT_EQ(false, foundMinus1) + bool foundMinus1 = false; + bool found1 = false; + for (auto it = result2.begin(); it != result2.end(); ++it) { + if (it->first == -1) { + // The -1 key is used for the actual values + ASSERT_EQ(false, foundMinus1) << "Already found key -1 in classifier result"; - foundMinus1 = true; - ASSERT_EQ(5, it->second->size()) + foundMinus1 = true; + ASSERT_EQ(5, it->second->size()) << "Expected five buckets since it has only seen bucket 4 (so it " << "Has buckets 0-4)."; - ASSERT_TRUE(fabs(it->second->at(4) - 34.7) < 0.000001) + ASSERT_TRUE(fabs(it->second->at(4) - 34.7) < 0.000001) << "Incorrect actual value for bucket 4"; - } else if (it->first == 1) { - // Check the one-step prediction - ASSERT_EQ(false, found1) - << "Already found key 1 in classifier result"; - found1 = true; - ASSERT_EQ(5, it->second->size()) - << "Expected five bucket predictions"; - ASSERT_LT(fabs(it->second->at(0) - 0.2), 0.000001) + } else if (it->first == 1) { + // Check the 
one-step prediction + ASSERT_EQ(false, found1) << "Already found key 1 in classifier result"; + found1 = true; + ASSERT_EQ(5, it->second->size()) << "Expected five bucket predictions"; + ASSERT_LT(fabs(it->second->at(0) - 0.2), 0.000001) << "Incorrect prediction for bucket 0"; - ASSERT_LT(fabs(it->second->at(1) - 0.2), 0.000001) + ASSERT_LT(fabs(it->second->at(1) - 0.2), 0.000001) << "Incorrect prediction for bucket 1"; - ASSERT_LT(fabs(it->second->at(2) - 0.2), 0.000001) + ASSERT_LT(fabs(it->second->at(2) - 0.2), 0.000001) << "Incorrect prediction for bucket 2"; - ASSERT_LT(fabs(it->second->at(3) - 0.2), 0.000001) + ASSERT_LT(fabs(it->second->at(3) - 0.2), 0.000001) << "Incorrect prediction for bucket 3"; - ASSERT_LT(fabs(it->second->at(4) - 0.2), 0.000001) + ASSERT_LT(fabs(it->second->at(4) - 0.2), 0.000001) << "Incorrect prediction for bucket 4"; - } } - ASSERT_TRUE(foundMinus1) << "Key -1 not found in classifier result"; - ASSERT_TRUE(found1) << "key 1 not found in classifier result"; } + ASSERT_TRUE(foundMinus1) << "Key -1 not found in classifier result"; + ASSERT_TRUE(found1) << "key 1 not found in classifier result"; } - - TEST(SDRClassifierTest, SingleValue) - { - // Feed the same input 10 times, the corresponding probability should be - // very high - vector steps; - steps.push_back(1); - SDRClassifier c = SDRClassifier(steps, 0.1, 0.1, 0); - - // Create a vector of input bit indices - vector input1; - input1.push_back(1); - input1.push_back(5); - input1.push_back(9); - vector bucketIdxList; - bucketIdxList.push_back(4); - vector actValueList; - actValueList.push_back(34.7); +} + +TEST(SDRClassifierTest, SingleValue) { + // Feed the same input 10 times, the corresponding probability should be + // very high + vector steps; + steps.push_back(1); + SDRClassifier c = SDRClassifier(steps, 0.1, 0.1, 0); + + // Create a vector of input bit indices + vector input1; + input1.push_back(1); + input1.push_back(5); + input1.push_back(9); + vector bucketIdxList; + bucketIdxList.push_back(4); + vector actValueList; + actValueList.push_back(34.7); + ClassifierResult result1; + for (UInt i = 0; i < 10; ++i) { ClassifierResult result1; - for (UInt i = 0; i < 10; ++i) - { - ClassifierResult result1; - c.compute(i, input1, bucketIdxList, actValueList, false, true, true, &result1); - } + c.compute(i, input1, bucketIdxList, actValueList, false, true, true, + &result1); + } - { - for (auto it = result1.begin(); - it != result1.end(); ++it) - { - if (it->first == -1) - { - ASSERT_TRUE(fabs(it->second->at(4) - 10.0) < 0.000001) + { + for (auto it = result1.begin(); it != result1.end(); ++it) { + if (it->first == -1) { + ASSERT_TRUE(fabs(it->second->at(4) - 10.0) < 0.000001) << "Incorrect actual value for bucket 4"; - } else if (it->first == 1) { - ASSERT_GT(it->second->at(4), 0.9) + } else if (it->first == 1) { + ASSERT_GT(it->second->at(4), 0.9) << "Incorrect prediction for bucket 4"; - } } } - } +} + +TEST(SDRClassifierTest, ComputeComplex) { + // More complex classification + // This test is ported from the Python unit test + vector steps; + steps.push_back(1); + SDRClassifier c = SDRClassifier(steps, 1.0, 0.1, 0); + + // Create a input vector + vector input1; + input1.push_back(1); + input1.push_back(5); + input1.push_back(9); + vector bucketIdxList1; + bucketIdxList1.push_back(4); + vector actValueList1; + actValueList1.push_back(34.7); + + // Create a input vector + vector input2; + input2.push_back(0); + input2.push_back(6); + input2.push_back(9); + input2.push_back(11); + vector bucketIdxList2; + 
bucketIdxList2.push_back(5); + vector actValueList2; + actValueList2.push_back(41.7); + + // Create input vectors + vector input3; + input3.push_back(6); + input3.push_back(9); + vector bucketIdxList3; + bucketIdxList3.push_back(5); + vector actValueList3; + actValueList3.push_back(44.9); + + vector bucketIdxList4; + bucketIdxList4.push_back(4); + vector actValueList4; + actValueList4.push_back(42.9); + + vector bucketIdxList5; + bucketIdxList5.push_back(4); + vector actValueList5; + actValueList5.push_back(34.7); + + ClassifierResult result1; + c.compute(0, input1, bucketIdxList1, actValueList1, false, true, true, + &result1); + + ClassifierResult result2; + c.compute(1, input2, bucketIdxList2, actValueList2, false, true, true, + &result2); + + ClassifierResult result3; + c.compute(2, input3, bucketIdxList3, actValueList3, false, true, true, + &result3); + + ClassifierResult result4; + c.compute(3, input1, bucketIdxList4, actValueList4, false, true, true, + &result4); + + ClassifierResult result5; + c.compute(4, input1, bucketIdxList5, actValueList5, false, true, true, + &result5); - TEST(SDRClassifierTest, ComputeComplex) { - // More complex classification - // This test is ported from the Python unit test - vector steps; - steps.push_back(1); - SDRClassifier c = SDRClassifier(steps, 1.0, 0.1, 0); - - // Create a input vector - vector input1; - input1.push_back(1); - input1.push_back(5); - input1.push_back(9); - vector bucketIdxList1; - bucketIdxList1.push_back(4); - vector actValueList1; - actValueList1.push_back(34.7); - - // Create a input vector - vector input2; - input2.push_back(0); - input2.push_back(6); - input2.push_back(9); - input2.push_back(11); - vector bucketIdxList2; - bucketIdxList2.push_back(5); - vector actValueList2; - actValueList2.push_back(41.7); - - // Create input vectors - vector input3; - input3.push_back(6); - input3.push_back(9); - vector bucketIdxList3; - bucketIdxList3.push_back(5); - vector actValueList3; - actValueList3.push_back(44.9); - - vector bucketIdxList4; - bucketIdxList4.push_back(4); - vector actValueList4; - actValueList4.push_back(42.9); - - vector bucketIdxList5; - bucketIdxList5.push_back(4); - vector actValueList5; - actValueList5.push_back(34.7); - - ClassifierResult result1; - c.compute(0, input1, bucketIdxList1, actValueList1, false, true, true, &result1); - - ClassifierResult result2; - c.compute(1, input2, bucketIdxList2, actValueList2, false, true, true, &result2); - - ClassifierResult result3; - c.compute(2, input3, bucketIdxList3, actValueList3, false, true, true, &result3); - - ClassifierResult result4; - c.compute(3, input1, bucketIdxList4, actValueList4, false, true, true, &result4); - - ClassifierResult result5; - c.compute(4, input1, bucketIdxList5, actValueList5, false, true, true, &result5); - - { - bool foundMinus1 = false; - bool found1 = false; - for (auto it = result5.begin(); it != result5.end(); ++it) - { - ASSERT_TRUE(it->first == -1 || it->first == 1) + bool foundMinus1 = false; + bool found1 = false; + for (auto it = result5.begin(); it != result5.end(); ++it) { + ASSERT_TRUE(it->first == -1 || it->first == 1) << "Result vector should only have -1 or 1 as key"; - if (it->first == -1) - { - // The -1 key is used for the actual values - ASSERT_EQ(false, foundMinus1) + if (it->first == -1) { + // The -1 key is used for the actual values + ASSERT_EQ(false, foundMinus1) << "Already found key -1 in classifier result"; - foundMinus1 = true; - ASSERT_EQ(6, it->second->size()) + foundMinus1 = true; + ASSERT_EQ(6, 
it->second->size()) << "Expected six buckets since it has only seen bucket 4-5 (so it " << "has buckets 0-5)."; - ASSERT_TRUE(fabs(it->second->at(4) - 35.520000457763672) < 0.000001) + ASSERT_TRUE(fabs(it->second->at(4) - 35.520000457763672) < 0.000001) << "Incorrect actual value for bucket 4"; - ASSERT_TRUE(fabs(it->second->at(5) - 42.020000457763672) < 0.000001) + ASSERT_TRUE(fabs(it->second->at(5) - 42.020000457763672) < 0.000001) << "Incorrect actual value for bucket 5"; - } else if (it->first == 1) { - // Check the one-step prediction - ASSERT_EQ(false, found1) - << "Already found key 1 in classifier result"; - found1 = true; - - ASSERT_EQ(6, it->second->size()) - << "Expected six bucket predictions"; - ASSERT_LT(fabs(it->second->at(0) - 0.034234), 0.000001) + } else if (it->first == 1) { + // Check the one-step prediction + ASSERT_EQ(false, found1) << "Already found key 1 in classifier result"; + found1 = true; + + ASSERT_EQ(6, it->second->size()) << "Expected six bucket predictions"; + ASSERT_LT(fabs(it->second->at(0) - 0.034234), 0.000001) << "Incorrect prediction for bucket 0"; - ASSERT_LT(fabs(it->second->at(1) - 0.034234), 0.000001) + ASSERT_LT(fabs(it->second->at(1) - 0.034234), 0.000001) << "Incorrect prediction for bucket 1"; - ASSERT_LT(fabs(it->second->at(2) - 0.034234), 0.000001) + ASSERT_LT(fabs(it->second->at(2) - 0.034234), 0.000001) << "Incorrect prediction for bucket 2"; - ASSERT_LT(fabs(it->second->at(3) - 0.034234), 0.000001) + ASSERT_LT(fabs(it->second->at(3) - 0.034234), 0.000001) << "Incorrect prediction for bucket 3"; - ASSERT_LT(fabs(it->second->at(4) - 0.093058), 0.000001) + ASSERT_LT(fabs(it->second->at(4) - 0.093058), 0.000001) << "Incorrect prediction for bucket 4"; - ASSERT_LT(fabs(it->second->at(5) - 0.770004), 0.000001) + ASSERT_LT(fabs(it->second->at(5) - 0.770004), 0.000001) << "Incorrect prediction for bucket 5"; - } } - ASSERT_TRUE(foundMinus1) << "Key -1 not found in classifier result"; - ASSERT_TRUE(found1) << "Key 1 not found in classifier result"; } - + ASSERT_TRUE(foundMinus1) << "Key -1 not found in classifier result"; + ASSERT_TRUE(found1) << "Key 1 not found in classifier result"; } - - TEST(SDRClassifierTest, MultipleCategory) - { - // Test multiple category classification with single compute calls - // This test is ported from the Python unit test - vector steps; - steps.push_back(0); - SDRClassifier c = SDRClassifier(steps, 1.0, 0.1, 0); - - // Create a input vectors - vector input1; - input1.push_back(1); - input1.push_back(3); - input1.push_back(5); - vector bucketIdxList1; - bucketIdxList1.push_back(0); - bucketIdxList1.push_back(1); - vector actValueList1; - actValueList1.push_back(0); - actValueList1.push_back(1); - - // Create a input vectors - vector input2; - input2.push_back(2); - input2.push_back(4); - input2.push_back(6); - vector bucketIdxList2; - bucketIdxList2.push_back(2); - bucketIdxList2.push_back(3); - vector actValueList2; - actValueList2.push_back(2); - actValueList2.push_back(3); - - int recordNum=0; - for(int i=0; i<1000; i++) - { - ClassifierResult result1; - ClassifierResult result2; - c.compute(recordNum, input1, bucketIdxList1, actValueList1, false, true, true, &result1); - recordNum += 1; - c.compute(recordNum, input2, bucketIdxList2, actValueList2, false, true, true, &result2); - recordNum += 1; - } - +} + +TEST(SDRClassifierTest, MultipleCategory) { + // Test multiple category classification with single compute calls + // This test is ported from the Python unit test + vector steps; + steps.push_back(0); + 
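// The classifier calls in this file all follow one pattern: construct with
// the prediction steps, two rate parameters (assumed to be the learning rate
// and the actual-value averaging rate), and a verbosity level, then call
// compute() with a record number, the active input bit indices, the target
// bucket(s), and the actual value(s); iterating the resulting
// ClassifierResult yields key -1 for the learned actual values and key N for
// the N-step-ahead likelihoods, which is what the assertions above check. A
// minimal round-trip sketch; the header paths, the UInt/Real64 element
// types, and the flag names in the comments are assumptions:
#include <vector>
#include <nupic/algorithms/ClassifierResult.hpp>
#include <nupic/algorithms/SDRClassifier.hpp>

using namespace nupic;
using namespace nupic::algorithms::cla_classifier;
using namespace nupic::algorithms::sdr_classifier;

inline void sketchClassifierRoundTrip() {
  std::vector<UInt> steps = {1};
  SDRClassifier clf(steps, 0.1, 0.1, 0);

  std::vector<UInt> activeBits = {1, 5, 9};
  std::vector<UInt> bucketIdxList = {4};
  std::vector<Real64> actValueList = {34.7};

  ClassifierResult result;
  clf.compute(0, activeBits, bucketIdxList, actValueList,
              false, true, true, &result); // same flag pattern as the tests
                                           // above (assumed: category, learn,
                                           // infer)
  // Iterating result: key -1 -> per-bucket actual values,
  //                   key  1 -> one-step-ahead likelihoods.
}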
SDRClassifier c = SDRClassifier(steps, 1.0, 0.1, 0); + + // Create a input vectors + vector input1; + input1.push_back(1); + input1.push_back(3); + input1.push_back(5); + vector bucketIdxList1; + bucketIdxList1.push_back(0); + bucketIdxList1.push_back(1); + vector actValueList1; + actValueList1.push_back(0); + actValueList1.push_back(1); + + // Create a input vectors + vector input2; + input2.push_back(2); + input2.push_back(4); + input2.push_back(6); + vector bucketIdxList2; + bucketIdxList2.push_back(2); + bucketIdxList2.push_back(3); + vector actValueList2; + actValueList2.push_back(2); + actValueList2.push_back(3); + + int recordNum = 0; + for (int i = 0; i < 1000; i++) { ClassifierResult result1; ClassifierResult result2; - c.compute(recordNum, input1, bucketIdxList1, actValueList1, false, true, true, &result1); + c.compute(recordNum, input1, bucketIdxList1, actValueList1, false, true, + true, &result1); recordNum += 1; - c.compute(recordNum, input2, bucketIdxList2, actValueList2, false, true, true, &result2); + c.compute(recordNum, input2, bucketIdxList2, actValueList2, false, true, + true, &result2); recordNum += 1; + } - for (auto it = result1.begin(); it != result1.end(); ++it) - { - if (it->first == 0) { - ASSERT_LT(fabs(it->second->at(0) - 0.5), 0.1) - << "Incorrect prediction for bucket 0 (expected=0.5)"; - ASSERT_LT(fabs(it->second->at(1) - 0.5), 0.1) - << "Incorrect prediction for bucket 1 (expected=0.5)"; - } + ClassifierResult result1; + ClassifierResult result2; + c.compute(recordNum, input1, bucketIdxList1, actValueList1, false, true, true, + &result1); + recordNum += 1; + c.compute(recordNum, input2, bucketIdxList2, actValueList2, false, true, true, + &result2); + recordNum += 1; + + for (auto it = result1.begin(); it != result1.end(); ++it) { + if (it->first == 0) { + ASSERT_LT(fabs(it->second->at(0) - 0.5), 0.1) + << "Incorrect prediction for bucket 0 (expected=0.5)"; + ASSERT_LT(fabs(it->second->at(1) - 0.5), 0.1) + << "Incorrect prediction for bucket 1 (expected=0.5)"; } + } - for (auto it = result2.begin(); it != result2.end(); ++it) - { - if (it->first == 0) { - ASSERT_LT(fabs(it->second->at(2) - 0.5), 0.1) - << "Incorrect prediction for bucket 2 (expected=0.5)"; - ASSERT_LT(fabs(it->second->at(3) - 0.5), 0.1) - << "Incorrect prediction for bucket 3 (expected=0.5)"; - } + for (auto it = result2.begin(); it != result2.end(); ++it) { + if (it->first == 0) { + ASSERT_LT(fabs(it->second->at(2) - 0.5), 0.1) + << "Incorrect prediction for bucket 2 (expected=0.5)"; + ASSERT_LT(fabs(it->second->at(3) - 0.5), 0.1) + << "Incorrect prediction for bucket 3 (expected=0.5)"; } - } +} + +TEST(SDRClassifierTest, SaveLoad) { + vector steps; + steps.push_back(1); + SDRClassifier c1 = SDRClassifier(steps, 0.1, 0.1, 0); + SDRClassifier c2 = SDRClassifier(steps, 0.1, 0.1, 0); + + // Create a vector of input bit indices + vector input1; + input1.push_back(1); + input1.push_back(5); + input1.push_back(9); + vector bucketIdxList1; + bucketIdxList1.push_back(4); + vector actValueList1; + actValueList1.push_back(34.7); + ClassifierResult result; + c1.compute(0, input1, bucketIdxList1, actValueList1, false, true, true, + &result); - TEST(SDRClassifierTest, SaveLoad) { - vector steps; - steps.push_back(1); - SDRClassifier c1 = SDRClassifier(steps, 0.1, 0.1, 0); - SDRClassifier c2 = SDRClassifier(steps, 0.1, 0.1, 0); - - // Create a vector of input bit indices - vector input1; - input1.push_back(1); - input1.push_back(5); - input1.push_back(9); - vector bucketIdxList1; - 
bucketIdxList1.push_back(4); - vector actValueList1; - actValueList1.push_back(34.7); - ClassifierResult result; - c1.compute(0, input1, bucketIdxList1, actValueList1, false, true, true, &result); - - { - stringstream ss; - c1.save(ss); - c2.load(ss); - } - - ASSERT_TRUE(c1 == c2); - - ClassifierResult result1, result2; - c1.compute(1, input1, bucketIdxList1, actValueList1, false, true, true, &result1); - c2.compute(1, input1, bucketIdxList1, actValueList1, false, true, true, &result2); - - ASSERT_TRUE(result1 == result2); + stringstream ss; + c1.save(ss); + c2.load(ss); } - TEST(SDRClassifierTest, WriteRead) + ASSERT_TRUE(c1 == c2); + + ClassifierResult result1, result2; + c1.compute(1, input1, bucketIdxList1, actValueList1, false, true, true, + &result1); + c2.compute(1, input1, bucketIdxList1, actValueList1, false, true, true, + &result2); + + ASSERT_TRUE(result1 == result2); +} + +TEST(SDRClassifierTest, WriteRead) { + vector steps; + steps.push_back(1); + steps.push_back(2); + SDRClassifier c1 = SDRClassifier(steps, 0.1, 0.1, 0); + SDRClassifier c2 = SDRClassifier(steps, 0.1, 0.1, 0); + + // Create a vector of input bit indices + vector input1; + input1.push_back(1); + input1.push_back(5); + input1.push_back(9); + vector bucketIdxList1; + bucketIdxList1.push_back(4); + vector actValueList1; + actValueList1.push_back(34.7); + ClassifierResult trainResult1; + c1.compute(0, input1, bucketIdxList1, actValueList1, false, true, true, + &trainResult1); + + // Create a vector of input bit indices + vector input2; + input2.push_back(0); + input2.push_back(8); + input2.push_back(9); + vector bucketIdxList2; + bucketIdxList2.push_back(2); + vector actValueList2; + actValueList2.push_back(24.7); + ClassifierResult trainResult2; + c1.compute(1, input2, bucketIdxList2, actValueList2, false, true, true, + &trainResult2); + { - vector steps; - steps.push_back(1); - steps.push_back(2); - SDRClassifier c1 = SDRClassifier(steps, 0.1, 0.1, 0); - SDRClassifier c2 = SDRClassifier(steps, 0.1, 0.1, 0); - - // Create a vector of input bit indices - vector input1; - input1.push_back(1); - input1.push_back(5); - input1.push_back(9); - vector bucketIdxList1; - bucketIdxList1.push_back(4); - vector actValueList1; - actValueList1.push_back(34.7); - ClassifierResult trainResult1; - c1.compute(0, input1, bucketIdxList1, actValueList1, false, true, true, &trainResult1); - - // Create a vector of input bit indices - vector input2; - input2.push_back(0); - input2.push_back(8); - input2.push_back(9); - vector bucketIdxList2; - bucketIdxList2.push_back(2); - vector actValueList2; - actValueList2.push_back(24.7); - ClassifierResult trainResult2; - c1.compute(1, input2, bucketIdxList2, actValueList2, false, true, true, &trainResult2); - - { - stringstream ss; - c1.write(ss); - c2.read(ss); - } + stringstream ss; + c1.write(ss); + c2.read(ss); + } - ASSERT_TRUE(c1 == c2); + ASSERT_TRUE(c1 == c2); - ClassifierResult result1, result2; - c1.compute(2, input1, bucketIdxList1, actValueList1, false, true, true, &result1); - c2.compute(2, input1, bucketIdxList1, actValueList1, false, true, true, &result2); + ClassifierResult result1, result2; + c1.compute(2, input1, bucketIdxList1, actValueList1, false, true, true, + &result1); + c2.compute(2, input1, bucketIdxList1, actValueList1, false, true, true, + &result2); - ASSERT_TRUE(result1 == result2); - } + ASSERT_TRUE(result1 == result2); +} } // end namespace diff --git a/src/test/unit/algorithms/SegmentTest.cpp b/src/test/unit/algorithms/SegmentTest.cpp index 8eb498b2e8..9ee67df7fa 
100644 --- a/src/test/unit/algorithms/SegmentTest.cpp +++ b/src/test/unit/algorithms/SegmentTest.cpp @@ -24,38 +24,31 @@ * Implementation of unit tests for Segment */ -#include #include +#include #include using namespace nupic::algorithms::Cells4; using namespace std; - -void setUpSegment(Segment &segment, - vector &inactiveSegmentIndices, +void setUpSegment(Segment &segment, vector &inactiveSegmentIndices, vector &activeSegmentIndices, vector &activeSynapseIndices, - vector &inactiveSynapseIndices) -{ - vector permanences = {0.2, 0.9, 0.9, 0.7, 0.4, // active synapses + vector &inactiveSynapseIndices) { + vector permanences = {0.2, 0.9, 0.9, 0.7, 0.4, // active synapses 0.8, 0.1, 0.2, 0.3, 0.2}; // inactive synapses set srcCells; - for (UInt i = 0; i < permanences.size(); i++) - { + for (UInt i = 0; i < permanences.size(); i++) { srcCells.clear(); srcCells.insert(i); segment.addSynapses(srcCells, permanences[i], 0.5); - if (i < 5) - { + if (i < 5) { inactiveSegmentIndices.push_back(i); inactiveSynapseIndices.push_back(0); - } - else - { + } else { activeSegmentIndices.push_back(i); activeSynapseIndices.push_back(0); } @@ -63,11 +56,10 @@ void setUpSegment(Segment &segment, } /* -* Test that synapses are removed from inactive first even when there -* are active synapses with lower permanence. -*/ -TEST(SegmentTest, freeNSynapsesInactiveFirst) -{ + * Test that synapses are removed from inactive first even when there + * are active synapses with lower permanence. + */ +TEST(SegmentTest, freeNSynapsesInactiveFirst) { Segment segment; vector inactiveSegmentIndices; @@ -76,20 +68,13 @@ TEST(SegmentTest, freeNSynapsesInactiveFirst) vector inactiveSynapseIndices; vector removed; - setUpSegment(segment, - inactiveSegmentIndices, - activeSegmentIndices, - activeSynapseIndices, - inactiveSynapseIndices); + setUpSegment(segment, inactiveSegmentIndices, activeSegmentIndices, + activeSynapseIndices, inactiveSynapseIndices); ASSERT_EQ(segment.size(), 10); - segment.freeNSynapses(2, - inactiveSynapseIndices, - inactiveSegmentIndices, - activeSynapseIndices, - activeSegmentIndices, - removed, 0, + segment.freeNSynapses(2, inactiveSynapseIndices, inactiveSegmentIndices, + activeSynapseIndices, activeSegmentIndices, removed, 0, 10, 1.0); ASSERT_EQ(segment.size(), 8); @@ -100,11 +85,10 @@ TEST(SegmentTest, freeNSynapsesInactiveFirst) } /* -* Test that active synapses are removed once all inactive synapses are -* exhausted. -*/ -TEST(SegmentTest, freeNSynapsesActiveFallback) -{ + * Test that active synapses are removed once all inactive synapses are + * exhausted. 
+ */ +TEST(SegmentTest, freeNSynapsesActiveFallback) { Segment segment; vector inactiveSegmentIndices; @@ -114,20 +98,13 @@ TEST(SegmentTest, freeNSynapsesActiveFallback) vector inactiveSynapseIndices; vector removed; - setUpSegment(segment, - inactiveSegmentIndices, - activeSegmentIndices, - activeSynapseIndices, - inactiveSynapseIndices); + setUpSegment(segment, inactiveSegmentIndices, activeSegmentIndices, + activeSynapseIndices, inactiveSynapseIndices); ASSERT_EQ(segment.size(), 10); - segment.freeNSynapses(6, - inactiveSynapseIndices, - inactiveSegmentIndices, - activeSynapseIndices, - activeSegmentIndices, - removed, 0, + segment.freeNSynapses(6, inactiveSynapseIndices, inactiveSegmentIndices, + activeSynapseIndices, activeSegmentIndices, removed, 0, 10, 1.0); vector removed_expected = {0, 1, 2, 3, 4, 6}; @@ -135,12 +112,10 @@ TEST(SegmentTest, freeNSynapsesActiveFallback) ASSERT_EQ(removed, removed_expected); } - /* -* Test that removal respects insertion order (stable sort of permanences). -*/ -TEST(SegmentTest, freeNSynapsesStableSort) -{ + * Test that removal respects insertion order (stable sort of permanences). + */ +TEST(SegmentTest, freeNSynapsesStableSort) { Segment segment; vector inactiveSegmentIndices; @@ -150,20 +125,13 @@ TEST(SegmentTest, freeNSynapsesStableSort) vector inactiveSynapseIndices; vector removed; - setUpSegment(segment, - inactiveSegmentIndices, - activeSegmentIndices, - activeSynapseIndices, - inactiveSynapseIndices); + setUpSegment(segment, inactiveSegmentIndices, activeSegmentIndices, + activeSynapseIndices, inactiveSynapseIndices); ASSERT_EQ(segment.size(), 10); - segment.freeNSynapses(7, - inactiveSynapseIndices, - inactiveSegmentIndices, - activeSynapseIndices, - activeSegmentIndices, - removed, 0, + segment.freeNSynapses(7, inactiveSynapseIndices, inactiveSegmentIndices, + activeSynapseIndices, activeSegmentIndices, removed, 0, 10, 1.0); vector removed_expected = {0, 1, 2, 3, 4, 6, 7}; diff --git a/src/test/unit/algorithms/SpatialPoolerTest.cpp b/src/test/unit/algorithms/SpatialPoolerTest.cpp index dde200ad77..a9eccdd9ee 100644 --- a/src/test/unit/algorithms/SpatialPoolerTest.cpp +++ b/src/test/unit/algorithms/SpatialPoolerTest.cpp @@ -28,424 +28,397 @@ #include #include +#include "gtest/gtest.h" #include #include #include #include -#include "gtest/gtest.h" using namespace std; using namespace nupic; using namespace nupic::algorithms::spatial_pooler; namespace { - UInt countNonzero(const vector& vec) - { - UInt count = 0; +UInt countNonzero(const vector &vec) { + UInt count = 0; - for (UInt x : vec) - { - if (x > 0) - { - count++; - } + for (UInt x : vec) { + if (x > 0) { + count++; } - - return count; } - bool almost_eq(Real a, Real b) - { - Real diff = a - b; - return (diff > -1e-5 && diff < 1e-5); - } + return count; +} - bool check_vector_eq(UInt arr[], vector vec) - { - for (UInt i = 0; i < vec.size(); i++) { - if (arr[i] != vec[i]) { - return false; - } +bool almost_eq(Real a, Real b) { + Real diff = a - b; + return (diff > -1e-5 && diff < 1e-5); +} + +bool check_vector_eq(UInt arr[], vector vec) { + for (UInt i = 0; i < vec.size(); i++) { + if (arr[i] != vec[i]) { + return false; } - return true; } + return true; +} - bool check_vector_eq(Real arr[], vector vec) - { - for (UInt i = 0; i < vec.size(); i++) { - if (!almost_eq(arr[i],vec[i])) { - return false; - } +bool check_vector_eq(Real arr[], vector vec) { + for (UInt i = 0; i < vec.size(); i++) { + if (!almost_eq(arr[i], vec[i])) { + return false; } - return true; } + return true; +} - 
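// The helpers above compare Reals with a fixed absolute tolerance:
// almost_eq(a, b) holds iff |a - b| < 1e-5, and the Real overloads of
// check_vector_eq apply it elementwise, while the UInt overloads use exact
// equality. A standalone restatement of that contract (Real assumed to be
// float for this sketch):
#include <cassert>

static bool almost_eq_sketch(float a, float b) {
  float diff = a - b;
  return diff > -1e-5 && diff < 1e-5;   // absolute, not relative, tolerance
}

int main() {
  assert(almost_eq_sketch(0.5f, 0.500001f));   // inside the 1e-5 band
  assert(!almost_eq_sketch(0.5f, 0.5002f));    // outside the band
  return 0;
}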
bool check_vector_eq(UInt arr1[], UInt arr2[], UInt n) - { - for (UInt i = 0; i < n; i++) { - if (arr1[i] != arr2[i]) { - return false; - } +bool check_vector_eq(UInt arr1[], UInt arr2[], UInt n) { + for (UInt i = 0; i < n; i++) { + if (arr1[i] != arr2[i]) { + return false; } - return true; } + return true; +} - bool check_vector_eq(Real arr1[], Real arr2[], UInt n) - { - for (UInt i = 0; i < n; i++) { - if (!almost_eq(arr1[i], arr2[i])) { - return false; - } +bool check_vector_eq(Real arr1[], Real arr2[], UInt n) { + for (UInt i = 0; i < n; i++) { + if (!almost_eq(arr1[i], arr2[i])) { + return false; } - return true; } + return true; +} - bool check_vector_eq(vector vec1, vector vec2) - { - if (vec1.size() != vec2.size()) { +bool check_vector_eq(vector vec1, vector vec2) { + if (vec1.size() != vec2.size()) { + return false; + } + for (UInt i = 0; i < vec1.size(); i++) { + if (vec1[i] != vec2[i]) { return false; } - for (UInt i = 0; i < vec1.size(); i++) { - if (vec1[i] != vec2[i]) { - return false; - } - } - return true; } + return true; +} - bool check_vector_eq(vector vec1, vector vec2) - { - if (vec1.size() != vec2.size()) { +bool check_vector_eq(vector vec1, vector vec2) { + if (vec1.size() != vec2.size()) { + return false; + } + for (UInt i = 0; i < vec1.size(); i++) { + if (!almost_eq(vec1[i], vec2[i])) { return false; } - for (UInt i = 0; i < vec1.size(); i++) { - if (!almost_eq(vec1[i], vec2[i])) { - return false; - } - } - return true; } - - void check_spatial_eq(SpatialPooler sp1, SpatialPooler sp2) - { - UInt numColumns = sp1.getNumColumns(); - UInt numInputs = sp2.getNumInputs(); - - ASSERT_TRUE(sp1.getNumColumns() == sp2.getNumColumns()); - ASSERT_TRUE(sp1.getNumInputs() == sp2.getNumInputs()); - ASSERT_TRUE(sp1.getPotentialRadius() == - sp2.getPotentialRadius()); - ASSERT_TRUE(sp1.getPotentialPct() == sp2.getPotentialPct()); - ASSERT_TRUE(sp1.getGlobalInhibition() == - sp2.getGlobalInhibition()); - ASSERT_TRUE(sp1.getNumActiveColumnsPerInhArea() == + return true; +} + +void check_spatial_eq(SpatialPooler sp1, SpatialPooler sp2) { + UInt numColumns = sp1.getNumColumns(); + UInt numInputs = sp2.getNumInputs(); + + ASSERT_TRUE(sp1.getNumColumns() == sp2.getNumColumns()); + ASSERT_TRUE(sp1.getNumInputs() == sp2.getNumInputs()); + ASSERT_TRUE(sp1.getPotentialRadius() == sp2.getPotentialRadius()); + ASSERT_TRUE(sp1.getPotentialPct() == sp2.getPotentialPct()); + ASSERT_TRUE(sp1.getGlobalInhibition() == sp2.getGlobalInhibition()); + ASSERT_TRUE(sp1.getNumActiveColumnsPerInhArea() == sp2.getNumActiveColumnsPerInhArea()); - ASSERT_TRUE(almost_eq(sp1.getLocalAreaDensity(), - sp2.getLocalAreaDensity())); - ASSERT_TRUE(sp1.getStimulusThreshold() == - sp2.getStimulusThreshold()); - ASSERT_TRUE(sp1.getDutyCyclePeriod() == sp2.getDutyCyclePeriod()); - ASSERT_TRUE(almost_eq(sp1.getBoostStrength(), sp2.getBoostStrength())); - ASSERT_TRUE(sp1.getIterationNum() == sp2.getIterationNum()); - ASSERT_TRUE(sp1.getIterationLearnNum() == - sp2.getIterationLearnNum()); - ASSERT_TRUE(sp1.getSpVerbosity() == sp2.getSpVerbosity()); - ASSERT_TRUE(sp1.getWrapAround() == sp2.getWrapAround()); - ASSERT_TRUE(sp1.getUpdatePeriod() == sp2.getUpdatePeriod()); - ASSERT_TRUE(almost_eq(sp1.getSynPermTrimThreshold(), - sp2.getSynPermTrimThreshold())); - cout << "check: " << sp1.getSynPermActiveInc() << " " << - sp2.getSynPermActiveInc() << endl; - ASSERT_TRUE(almost_eq(sp1.getSynPermActiveInc(), - sp2.getSynPermActiveInc())); - ASSERT_TRUE(almost_eq(sp1.getSynPermInactiveDec(), - sp2.getSynPermInactiveDec())); - 
ASSERT_TRUE(almost_eq(sp1.getSynPermBelowStimulusInc(), - sp2.getSynPermBelowStimulusInc())); - ASSERT_TRUE(almost_eq(sp1.getSynPermConnected(), - sp2.getSynPermConnected())); - ASSERT_TRUE(almost_eq(sp1.getMinPctOverlapDutyCycles(), - sp2.getMinPctOverlapDutyCycles())); - - - auto boostFactors1 = new Real[numColumns]; - auto boostFactors2 = new Real[numColumns]; - sp1.getBoostFactors(boostFactors1); - sp2.getBoostFactors(boostFactors2); - ASSERT_TRUE(check_vector_eq(boostFactors1, boostFactors2, numColumns)); - delete[] boostFactors1; - delete[] boostFactors2; - - auto overlapDutyCycles1 = new Real[numColumns]; - auto overlapDutyCycles2 = new Real[numColumns]; - sp1.getOverlapDutyCycles(overlapDutyCycles1); - sp2.getOverlapDutyCycles(overlapDutyCycles2); - ASSERT_TRUE(check_vector_eq(overlapDutyCycles1, overlapDutyCycles2, numColumns)); - delete[] overlapDutyCycles1; - delete[] overlapDutyCycles2; - - auto activeDutyCycles1 = new Real[numColumns]; - auto activeDutyCycles2 = new Real[numColumns]; - sp1.getActiveDutyCycles(activeDutyCycles1); - sp2.getActiveDutyCycles(activeDutyCycles2); - ASSERT_TRUE(check_vector_eq(activeDutyCycles1, activeDutyCycles2, numColumns)); - delete[] activeDutyCycles1; - delete[] activeDutyCycles2; - - auto minOverlapDutyCycles1 = new Real[numColumns]; - auto minOverlapDutyCycles2 = new Real[numColumns]; - sp1.getMinOverlapDutyCycles(minOverlapDutyCycles1); - sp2.getMinOverlapDutyCycles(minOverlapDutyCycles2); - ASSERT_TRUE(check_vector_eq(minOverlapDutyCycles1, minOverlapDutyCycles2, numColumns)); - delete[] minOverlapDutyCycles1; - delete[] minOverlapDutyCycles2; - - for (UInt i = 0; i < numColumns; i++) { - auto potential1 = new UInt[numInputs]; - auto potential2 = new UInt[numInputs]; - sp1.getPotential(i, potential1); - sp2.getPotential(i, potential2); - ASSERT_TRUE(check_vector_eq(potential1, potential2, numInputs)); - delete[] potential1; - delete[] potential2; - } - - for (UInt i = 0; i < numColumns; i++) { - auto perm1 = new Real[numInputs]; - auto perm2 = new Real[numInputs]; - sp1.getPermanence(i, perm1); - sp2.getPermanence(i, perm2); - ASSERT_TRUE(check_vector_eq(perm1, perm2, numInputs)); - delete[] perm1; - delete[] perm2; - } - - for (UInt i = 0; i < numColumns; i++) { - auto con1 = new UInt[numInputs]; - auto con2 = new UInt[numInputs]; - sp1.getConnectedSynapses(i, con1); - sp2.getConnectedSynapses(i, con2); - ASSERT_TRUE(check_vector_eq(con1, con2, numInputs)); - delete[] con1; - delete[] con2; - } - - auto conCounts1 = new UInt[numColumns]; - auto conCounts2 = new UInt[numColumns]; - sp1.getConnectedCounts(conCounts1); - sp2.getConnectedCounts(conCounts2); - ASSERT_TRUE(check_vector_eq(conCounts1, conCounts2, numColumns)); - delete[] conCounts1; - delete[] conCounts2; + ASSERT_TRUE(almost_eq(sp1.getLocalAreaDensity(), sp2.getLocalAreaDensity())); + ASSERT_TRUE(sp1.getStimulusThreshold() == sp2.getStimulusThreshold()); + ASSERT_TRUE(sp1.getDutyCyclePeriod() == sp2.getDutyCyclePeriod()); + ASSERT_TRUE(almost_eq(sp1.getBoostStrength(), sp2.getBoostStrength())); + ASSERT_TRUE(sp1.getIterationNum() == sp2.getIterationNum()); + ASSERT_TRUE(sp1.getIterationLearnNum() == sp2.getIterationLearnNum()); + ASSERT_TRUE(sp1.getSpVerbosity() == sp2.getSpVerbosity()); + ASSERT_TRUE(sp1.getWrapAround() == sp2.getWrapAround()); + ASSERT_TRUE(sp1.getUpdatePeriod() == sp2.getUpdatePeriod()); + ASSERT_TRUE( + almost_eq(sp1.getSynPermTrimThreshold(), sp2.getSynPermTrimThreshold())); + cout << "check: " << sp1.getSynPermActiveInc() << " " + << 
sp2.getSynPermActiveInc() << endl; + ASSERT_TRUE(almost_eq(sp1.getSynPermActiveInc(), sp2.getSynPermActiveInc())); + ASSERT_TRUE( + almost_eq(sp1.getSynPermInactiveDec(), sp2.getSynPermInactiveDec())); + ASSERT_TRUE(almost_eq(sp1.getSynPermBelowStimulusInc(), + sp2.getSynPermBelowStimulusInc())); + ASSERT_TRUE(almost_eq(sp1.getSynPermConnected(), sp2.getSynPermConnected())); + ASSERT_TRUE(almost_eq(sp1.getMinPctOverlapDutyCycles(), + sp2.getMinPctOverlapDutyCycles())); + + auto boostFactors1 = new Real[numColumns]; + auto boostFactors2 = new Real[numColumns]; + sp1.getBoostFactors(boostFactors1); + sp2.getBoostFactors(boostFactors2); + ASSERT_TRUE(check_vector_eq(boostFactors1, boostFactors2, numColumns)); + delete[] boostFactors1; + delete[] boostFactors2; + + auto overlapDutyCycles1 = new Real[numColumns]; + auto overlapDutyCycles2 = new Real[numColumns]; + sp1.getOverlapDutyCycles(overlapDutyCycles1); + sp2.getOverlapDutyCycles(overlapDutyCycles2); + ASSERT_TRUE( + check_vector_eq(overlapDutyCycles1, overlapDutyCycles2, numColumns)); + delete[] overlapDutyCycles1; + delete[] overlapDutyCycles2; + + auto activeDutyCycles1 = new Real[numColumns]; + auto activeDutyCycles2 = new Real[numColumns]; + sp1.getActiveDutyCycles(activeDutyCycles1); + sp2.getActiveDutyCycles(activeDutyCycles2); + ASSERT_TRUE( + check_vector_eq(activeDutyCycles1, activeDutyCycles2, numColumns)); + delete[] activeDutyCycles1; + delete[] activeDutyCycles2; + + auto minOverlapDutyCycles1 = new Real[numColumns]; + auto minOverlapDutyCycles2 = new Real[numColumns]; + sp1.getMinOverlapDutyCycles(minOverlapDutyCycles1); + sp2.getMinOverlapDutyCycles(minOverlapDutyCycles2); + ASSERT_TRUE(check_vector_eq(minOverlapDutyCycles1, minOverlapDutyCycles2, + numColumns)); + delete[] minOverlapDutyCycles1; + delete[] minOverlapDutyCycles2; + + for (UInt i = 0; i < numColumns; i++) { + auto potential1 = new UInt[numInputs]; + auto potential2 = new UInt[numInputs]; + sp1.getPotential(i, potential1); + sp2.getPotential(i, potential2); + ASSERT_TRUE(check_vector_eq(potential1, potential2, numInputs)); + delete[] potential1; + delete[] potential2; } - void setup(SpatialPooler& sp, UInt numInputs, - UInt numColumns) - { - vector inputDim; - vector columnDim; - inputDim.push_back(numInputs); - columnDim.push_back(numColumns); - sp.initialize(inputDim,columnDim); + for (UInt i = 0; i < numColumns; i++) { + auto perm1 = new Real[numInputs]; + auto perm2 = new Real[numInputs]; + sp1.getPermanence(i, perm1); + sp2.getPermanence(i, perm2); + ASSERT_TRUE(check_vector_eq(perm1, perm2, numInputs)); + delete[] perm1; + delete[] perm2; } - TEST(SpatialPoolerTest, testUpdateInhibitionRadius) - { - SpatialPooler sp; - vector colDim, inputDim; - colDim.push_back(57); - colDim.push_back(31); - colDim.push_back(2); - inputDim.push_back(1); - inputDim.push_back(1); - inputDim.push_back(1); - - sp.initialize(inputDim, colDim); - sp.setGlobalInhibition(true); - ASSERT_TRUE(sp.getInhibitionRadius() == 57); - - colDim.clear(); - inputDim.clear(); - // avgColumnsPerInput = 4 - // avgConnectedSpanForColumn = 3 - UInt numInputs = 3; - inputDim.push_back(numInputs); - UInt numCols = 12; - colDim.push_back(numCols); - sp.initialize(inputDim, colDim); - sp.setGlobalInhibition(false); - - for (UInt i = 0; i < numCols; i++) { - Real permArr[] = {1, 1, 1}; - sp.setPermanence(i, permArr); - } - UInt trueInhibitionRadius = 6; - // ((3 * 4) - 1)/2 => round up - sp.updateInhibitionRadius_(); - ASSERT_TRUE(trueInhibitionRadius == sp.getInhibitionRadius()); - - 
colDim.clear(); - inputDim.clear(); - // avgColumnsPerInput = 1.2 - // avgConnectedSpanForColumn = 0.5 - numInputs = 5; - inputDim.push_back(numInputs); - numCols = 6; - colDim.push_back(numCols); - sp.initialize(inputDim, colDim); - sp.setGlobalInhibition(false); - - for (UInt i = 0; i < numCols; i++) { - Real permArr[] = {1, 0, 0, 0, 0}; - if (i % 2 == 0) { - permArr[0] = 0; - } - sp.setPermanence(i, permArr); - } - trueInhibitionRadius = 1; - sp.updateInhibitionRadius_(); - ASSERT_TRUE(trueInhibitionRadius == sp.getInhibitionRadius()); - - colDim.clear(); - inputDim.clear(); - // avgColumnsPerInput = 2.4 - // avgConnectedSpanForColumn = 2 - numInputs = 5; - inputDim.push_back(numInputs); - numCols = 12; - colDim.push_back(numCols); - sp.initialize(inputDim, colDim); - sp.setGlobalInhibition(false); - - for (UInt i = 0; i < numCols; i++) { - Real permArr[] = {1, 1, 0, 0, 0}; - sp.setPermanence(i,permArr); - } - trueInhibitionRadius = 2; - // ((2.4 * 2) - 1)/2 => round up - sp.updateInhibitionRadius_(); - ASSERT_TRUE(trueInhibitionRadius == sp.getInhibitionRadius()); + for (UInt i = 0; i < numColumns; i++) { + auto con1 = new UInt[numInputs]; + auto con2 = new UInt[numInputs]; + sp1.getConnectedSynapses(i, con1); + sp2.getConnectedSynapses(i, con2); + ASSERT_TRUE(check_vector_eq(con1, con2, numInputs)); + delete[] con1; + delete[] con2; } - TEST(SpatialPoolerTest, testUpdateMinDutyCycles) - { - SpatialPooler sp; - UInt numColumns = 10; - UInt numInputs = 5; - setup(sp, numInputs, numColumns); - sp.setMinPctOverlapDutyCycles(0.01); - - Real initOverlapDuty[10] = {0.01, 0.001, 0.02, 0.3, 0.012, 0.0512, - 0.054, 0.221, 0.0873, 0.309}; - - Real initActiveDuty[10] = {0.01, 0.045, 0.812, 0.091, 0.001, 0.0003, - 0.433, 0.136, 0.211, 0.129}; - - sp.setOverlapDutyCycles(initOverlapDuty); - sp.setActiveDutyCycles(initActiveDuty); - sp.setGlobalInhibition(true); - sp.setInhibitionRadius(2); - sp.updateMinDutyCycles_(); - Real resultMinOverlap[10]; - sp.getMinOverlapDutyCycles(resultMinOverlap); - - - sp.updateMinDutyCyclesGlobal_(); - Real resultMinOverlapGlobal[10]; - sp.getMinOverlapDutyCycles(resultMinOverlapGlobal); - - sp.updateMinDutyCyclesLocal_(); - Real resultMinOverlapLocal[10]; - sp.getMinOverlapDutyCycles(resultMinOverlapLocal); - - ASSERT_TRUE(check_vector_eq(resultMinOverlap, resultMinOverlapGlobal, - numColumns)); - - sp.setGlobalInhibition(false); - sp.updateMinDutyCycles_(); - sp.getMinOverlapDutyCycles(resultMinOverlap); - - ASSERT_TRUE(!check_vector_eq(resultMinOverlap, resultMinOverlapGlobal, - numColumns)); - + auto conCounts1 = new UInt[numColumns]; + auto conCounts2 = new UInt[numColumns]; + sp1.getConnectedCounts(conCounts1); + sp2.getConnectedCounts(conCounts2); + ASSERT_TRUE(check_vector_eq(conCounts1, conCounts2, numColumns)); + delete[] conCounts1; + delete[] conCounts2; +} + +void setup(SpatialPooler &sp, UInt numInputs, UInt numColumns) { + vector inputDim; + vector columnDim; + inputDim.push_back(numInputs); + columnDim.push_back(numColumns); + sp.initialize(inputDim, columnDim); +} + +TEST(SpatialPoolerTest, testUpdateInhibitionRadius) { + SpatialPooler sp; + vector colDim, inputDim; + colDim.push_back(57); + colDim.push_back(31); + colDim.push_back(2); + inputDim.push_back(1); + inputDim.push_back(1); + inputDim.push_back(1); + + sp.initialize(inputDim, colDim); + sp.setGlobalInhibition(true); + ASSERT_TRUE(sp.getInhibitionRadius() == 57); + + colDim.clear(); + inputDim.clear(); + // avgColumnsPerInput = 4 + // avgConnectedSpanForColumn = 3 + UInt numInputs = 3; + 
inputDim.push_back(numInputs); + UInt numCols = 12; + colDim.push_back(numCols); + sp.initialize(inputDim, colDim); + sp.setGlobalInhibition(false); + + for (UInt i = 0; i < numCols; i++) { + Real permArr[] = {1, 1, 1}; + sp.setPermanence(i, permArr); } - - TEST(SpatialPoolerTest, testUpdateMinDutyCyclesGlobal) { - SpatialPooler sp; - UInt numColumns = 5; - UInt numInputs = 5; - setup(sp, numInputs, numColumns); - Real minPctOverlap; - - minPctOverlap = 0.01; - - sp.setMinPctOverlapDutyCycles(minPctOverlap); - - Real overlapArr1[] = - {0.06, 1, 3, 6, 0.5}; - Real activeArr1[] = - {0.6, 0.07, 0.5, 0.4, 0.3}; - - sp.setOverlapDutyCycles(overlapArr1); - sp.setActiveDutyCycles(activeArr1); - - Real trueMinOverlap1 = 0.01 * 6; - - sp.updateMinDutyCyclesGlobal_(); - Real resultOverlap1[5]; - sp.getMinOverlapDutyCycles(resultOverlap1); - for (UInt i = 0; i < numColumns; i++) { - ASSERT_TRUE(resultOverlap1[i] == trueMinOverlap1); + UInt trueInhibitionRadius = 6; + // ((3 * 4) - 1)/2 => round up + sp.updateInhibitionRadius_(); + ASSERT_TRUE(trueInhibitionRadius == sp.getInhibitionRadius()); + + colDim.clear(); + inputDim.clear(); + // avgColumnsPerInput = 1.2 + // avgConnectedSpanForColumn = 0.5 + numInputs = 5; + inputDim.push_back(numInputs); + numCols = 6; + colDim.push_back(numCols); + sp.initialize(inputDim, colDim); + sp.setGlobalInhibition(false); + + for (UInt i = 0; i < numCols; i++) { + Real permArr[] = {1, 0, 0, 0, 0}; + if (i % 2 == 0) { + permArr[0] = 0; } + sp.setPermanence(i, permArr); + } + trueInhibitionRadius = 1; + sp.updateInhibitionRadius_(); + ASSERT_TRUE(trueInhibitionRadius == sp.getInhibitionRadius()); + + colDim.clear(); + inputDim.clear(); + // avgColumnsPerInput = 2.4 + // avgConnectedSpanForColumn = 2 + numInputs = 5; + inputDim.push_back(numInputs); + numCols = 12; + colDim.push_back(numCols); + sp.initialize(inputDim, colDim); + sp.setGlobalInhibition(false); + + for (UInt i = 0; i < numCols; i++) { + Real permArr[] = {1, 1, 0, 0, 0}; + sp.setPermanence(i, permArr); + } + trueInhibitionRadius = 2; + // ((2.4 * 2) - 1)/2 => round up + sp.updateInhibitionRadius_(); + ASSERT_TRUE(trueInhibitionRadius == sp.getInhibitionRadius()); +} + +TEST(SpatialPoolerTest, testUpdateMinDutyCycles) { + SpatialPooler sp; + UInt numColumns = 10; + UInt numInputs = 5; + setup(sp, numInputs, numColumns); + sp.setMinPctOverlapDutyCycles(0.01); + + Real initOverlapDuty[10] = {0.01, 0.001, 0.02, 0.3, 0.012, + 0.0512, 0.054, 0.221, 0.0873, 0.309}; + + Real initActiveDuty[10] = {0.01, 0.045, 0.812, 0.091, 0.001, + 0.0003, 0.433, 0.136, 0.211, 0.129}; + + sp.setOverlapDutyCycles(initOverlapDuty); + sp.setActiveDutyCycles(initActiveDuty); + sp.setGlobalInhibition(true); + sp.setInhibitionRadius(2); + sp.updateMinDutyCycles_(); + Real resultMinOverlap[10]; + sp.getMinOverlapDutyCycles(resultMinOverlap); + + sp.updateMinDutyCyclesGlobal_(); + Real resultMinOverlapGlobal[10]; + sp.getMinOverlapDutyCycles(resultMinOverlapGlobal); + + sp.updateMinDutyCyclesLocal_(); + Real resultMinOverlapLocal[10]; + sp.getMinOverlapDutyCycles(resultMinOverlapLocal); + + ASSERT_TRUE( + check_vector_eq(resultMinOverlap, resultMinOverlapGlobal, numColumns)); + + sp.setGlobalInhibition(false); + sp.updateMinDutyCycles_(); + sp.getMinOverlapDutyCycles(resultMinOverlap); + + ASSERT_TRUE( + !check_vector_eq(resultMinOverlap, resultMinOverlapGlobal, numColumns)); +} + +TEST(SpatialPoolerTest, testUpdateMinDutyCyclesGlobal) { + SpatialPooler sp; + UInt numColumns = 5; + UInt numInputs = 5; + setup(sp, numInputs, numColumns); 
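// The expected radii in testUpdateInhibitionRadius above are consistent with
// the rule sketched here: multiply avgConnectedSpanForColumn by
// avgColumnsPerInput to get a diameter, round (diameter - 1) / 2 to the
// nearest integer, and never drop below 1. This is inferred from the test's
// own comments ("((3 * 4) - 1)/2 => round up", "((2.4 * 2) - 1)/2 => round
// up"), not quoted from the SpatialPooler sources, and the function name is
// illustrative.
#include <algorithm>
#include <cassert>
#include <cmath>

static unsigned expectedInhibitionRadius(double avgConnectedSpan,
                                         double avgColumnsPerInput) {
  double diameter = avgConnectedSpan * avgColumnsPerInput;
  double radius = std::max(1.0, (diameter - 1.0) / 2.0);  // floor at 1
  return (unsigned)std::round(radius);
}

int main() {
  assert(expectedInhibitionRadius(3.0, 4.0) == 6);  // (12 - 1)/2 = 5.5 -> 6
  assert(expectedInhibitionRadius(0.5, 1.2) == 1);  // negative -> clamped to 1
  assert(expectedInhibitionRadius(2.0, 2.4) == 2);  // (4.8 - 1)/2 = 1.9 -> 2
  return 0;
}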
+ Real minPctOverlap; + + minPctOverlap = 0.01; + + sp.setMinPctOverlapDutyCycles(minPctOverlap); + + Real overlapArr1[] = {0.06, 1, 3, 6, 0.5}; + Real activeArr1[] = {0.6, 0.07, 0.5, 0.4, 0.3}; + + sp.setOverlapDutyCycles(overlapArr1); + sp.setActiveDutyCycles(activeArr1); + + Real trueMinOverlap1 = 0.01 * 6; + + sp.updateMinDutyCyclesGlobal_(); + Real resultOverlap1[5]; + sp.getMinOverlapDutyCycles(resultOverlap1); + for (UInt i = 0; i < numColumns; i++) { + ASSERT_TRUE(resultOverlap1[i] == trueMinOverlap1); + } + minPctOverlap = 0.015; - minPctOverlap = 0.015; - - sp.setMinPctOverlapDutyCycles(minPctOverlap); - - Real overlapArr2[] = {0.86, 2.4, 0.03, 1.6, 1.5}; - Real activeArr2[] = {0.16, 0.007, 0.15, 0.54, 0.13}; + sp.setMinPctOverlapDutyCycles(minPctOverlap); - sp.setOverlapDutyCycles(overlapArr2); - sp.setActiveDutyCycles(activeArr2); + Real overlapArr2[] = {0.86, 2.4, 0.03, 1.6, 1.5}; + Real activeArr2[] = {0.16, 0.007, 0.15, 0.54, 0.13}; - Real trueMinOverlap2 = 0.015 * 2.4; + sp.setOverlapDutyCycles(overlapArr2); + sp.setActiveDutyCycles(activeArr2); - sp.updateMinDutyCyclesGlobal_(); - Real resultOverlap2[5]; - sp.getMinOverlapDutyCycles(resultOverlap2); - for (UInt i = 0; i < numColumns; i++) { - ASSERT_TRUE(almost_eq(resultOverlap2[i],trueMinOverlap2)); - } + Real trueMinOverlap2 = 0.015 * 2.4; + sp.updateMinDutyCyclesGlobal_(); + Real resultOverlap2[5]; + sp.getMinOverlapDutyCycles(resultOverlap2); + for (UInt i = 0; i < numColumns; i++) { + ASSERT_TRUE(almost_eq(resultOverlap2[i], trueMinOverlap2)); + } - minPctOverlap = 0.015; + minPctOverlap = 0.015; - sp.setMinPctOverlapDutyCycles(minPctOverlap); + sp.setMinPctOverlapDutyCycles(minPctOverlap); - Real overlapArr3[] = {0, 0, 0, 0, 0}; - Real activeArr3[] = {0, 0, 0, 0, 0}; + Real overlapArr3[] = {0, 0, 0, 0, 0}; + Real activeArr3[] = {0, 0, 0, 0, 0}; - sp.setOverlapDutyCycles(overlapArr3); - sp.setActiveDutyCycles(activeArr3); + sp.setOverlapDutyCycles(overlapArr3); + sp.setActiveDutyCycles(activeArr3); - Real trueMinOverlap3 = 0; + Real trueMinOverlap3 = 0; - sp.updateMinDutyCyclesGlobal_(); - Real resultOverlap3[5]; - sp.getMinOverlapDutyCycles(resultOverlap3); - for (UInt i = 0; i < numColumns; i++) { - ASSERT_TRUE(almost_eq(resultOverlap3[i],trueMinOverlap3)); - } + sp.updateMinDutyCyclesGlobal_(); + Real resultOverlap3[5]; + sp.getMinOverlapDutyCycles(resultOverlap3); + for (UInt i = 0; i < numColumns; i++) { + ASSERT_TRUE(almost_eq(resultOverlap3[i], trueMinOverlap3)); } +} - TEST(SpatialPoolerTest, testUpdateMinDutyCyclesLocal) +TEST(SpatialPoolerTest, testUpdateMinDutyCyclesLocal) { + // wrapAround=false { - // wrapAround=false - { - UInt numColumns = 8; - SpatialPooler sp( - /*inputDimensions*/{5}, + UInt numColumns = 8; + SpatialPooler sp( + /*inputDimensions*/ {5}, /*columnDimensions*/ {numColumns}, /*potentialRadius*/ 16, /*potentialPct*/ 0.5, @@ -463,37 +436,31 @@ namespace { /*spVerbosity*/ 0, /*wrapAround*/ false); - sp.setInhibitionRadius(1); + sp.setInhibitionRadius(1); - Real activeDutyArr[] = {0.9, 0.3, 0.5, 0.7, 0.1, 0.01, 0.08, 0.12}; - sp.setActiveDutyCycles(activeDutyArr); + Real activeDutyArr[] = {0.9, 0.3, 0.5, 0.7, 0.1, 0.01, 0.08, 0.12}; + sp.setActiveDutyCycles(activeDutyArr); - Real overlapDutyArr[] = {0.7, 0.1, 0.5, 0.01, 0.78, 0.55, 0.1, 0.001}; - sp.setOverlapDutyCycles(overlapDutyArr); + Real overlapDutyArr[] = {0.7, 0.1, 0.5, 0.01, 0.78, 0.55, 0.1, 0.001}; + sp.setOverlapDutyCycles(overlapDutyArr); - sp.setMinPctOverlapDutyCycles(0.2); + sp.setMinPctOverlapDutyCycles(0.2); - 
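// The expected values in testUpdateMinDutyCyclesGlobal above follow one
// rule: every column's minimum overlap duty cycle becomes
// minPctOverlapDutyCycles times the largest overlap duty cycle across all
// columns (0.01 * 6, then 0.015 * 2.4, then 0 in the all-zero case). Worked
// sketch of that rule; the function name is illustrative and this is not the
// SpatialPooler API.
#include <algorithm>
#include <cassert>
#include <vector>

static std::vector<double> globalMinOverlap(const std::vector<double> &overlapDuty,
                                            double minPct) {
  double peak = *std::max_element(overlapDuty.begin(), overlapDuty.end());
  return std::vector<double>(overlapDuty.size(), minPct * peak);
}

int main() {
  const std::vector<double> overlap = {0.06, 1, 3, 6, 0.5};  // overlapArr1
  for (double v : globalMinOverlap(overlap, 0.01))
    assert(v == 0.01 * 6);   // every column gets the same global floor
  return 0;
}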
sp.updateMinDutyCyclesLocal_(); + sp.updateMinDutyCyclesLocal_(); - Real trueOverlapArr[] = {0.2*0.7, - 0.2*0.7, - 0.2*0.5, - 0.2*0.78, - 0.2*0.78, - 0.2*0.78, - 0.2*0.55, - 0.2*0.1}; - Real resultMinOverlapArr[8]; - sp.getMinOverlapDutyCycles(resultMinOverlapArr); - ASSERT_TRUE(check_vector_eq(resultMinOverlapArr, trueOverlapArr, - numColumns)); - } + Real trueOverlapArr[] = {0.2 * 0.7, 0.2 * 0.7, 0.2 * 0.5, 0.2 * 0.78, + 0.2 * 0.78, 0.2 * 0.78, 0.2 * 0.55, 0.2 * 0.1}; + Real resultMinOverlapArr[8]; + sp.getMinOverlapDutyCycles(resultMinOverlapArr); + ASSERT_TRUE( + check_vector_eq(resultMinOverlapArr, trueOverlapArr, numColumns)); + } - // wrapAround=true - { - UInt numColumns = 8; - SpatialPooler sp( - /*inputDimensions*/{5}, + // wrapAround=true + { + UInt numColumns = 8; + SpatialPooler sp( + /*inputDimensions*/ {5}, /*columnDimensions*/ {numColumns}, /*potentialRadius*/ 16, /*potentialPct*/ 0.5, @@ -511,976 +478,899 @@ namespace { /*spVerbosity*/ 0, /*wrapAround*/ true); - sp.setInhibitionRadius(1); + sp.setInhibitionRadius(1); - Real activeDutyArr[] = {0.9, 0.3, 0.5, 0.7, 0.1, 0.01, 0.08, 0.12}; - sp.setActiveDutyCycles(activeDutyArr); + Real activeDutyArr[] = {0.9, 0.3, 0.5, 0.7, 0.1, 0.01, 0.08, 0.12}; + sp.setActiveDutyCycles(activeDutyArr); - Real overlapDutyArr[] = {0.7, 0.1, 0.5, 0.01, 0.78, 0.55, 0.1, 0.001}; - sp.setOverlapDutyCycles(overlapDutyArr); + Real overlapDutyArr[] = {0.7, 0.1, 0.5, 0.01, 0.78, 0.55, 0.1, 0.001}; + sp.setOverlapDutyCycles(overlapDutyArr); - sp.setMinPctOverlapDutyCycles(0.2); + sp.setMinPctOverlapDutyCycles(0.2); - sp.updateMinDutyCyclesLocal_(); + sp.updateMinDutyCyclesLocal_(); - Real trueOverlapArr[] = {0.2*0.7, - 0.2*0.7, - 0.2*0.5, - 0.2*0.78, - 0.2*0.78, - 0.2*0.78, - 0.2*0.55, - 0.2*0.7}; - Real resultMinOverlapArr[8]; - sp.getMinOverlapDutyCycles(resultMinOverlapArr); - ASSERT_TRUE(check_vector_eq(resultMinOverlapArr, trueOverlapArr, - numColumns)); - } + Real trueOverlapArr[] = {0.2 * 0.7, 0.2 * 0.7, 0.2 * 0.5, 0.2 * 0.78, + 0.2 * 0.78, 0.2 * 0.78, 0.2 * 0.55, 0.2 * 0.7}; + Real resultMinOverlapArr[8]; + sp.getMinOverlapDutyCycles(resultMinOverlapArr); + ASSERT_TRUE( + check_vector_eq(resultMinOverlapArr, trueOverlapArr, numColumns)); } +} + +TEST(SpatialPoolerTest, testUpdateDutyCycles) { + SpatialPooler sp; + UInt numInputs = 5; + UInt numColumns = 5; + setup(sp, numInputs, numColumns); + vector overlaps; + + Real initOverlapArr1[] = {1, 1, 1, 1, 1}; + sp.setOverlapDutyCycles(initOverlapArr1); + Real overlapNewVal1[] = {1, 5, 7, 0, 0}; + overlaps.assign(overlapNewVal1, overlapNewVal1 + numColumns); + UInt active[] = {0, 0, 0, 0, 0}; + + sp.setIterationNum(2); + sp.updateDutyCycles_(overlaps, active); + + Real resultOverlapArr1[5]; + sp.getOverlapDutyCycles(resultOverlapArr1); + + Real trueOverlapArr1[] = {1, 1, 1, 0.5, 0.5}; + ASSERT_TRUE(check_vector_eq(resultOverlapArr1, trueOverlapArr1, numColumns)); + + sp.setOverlapDutyCycles(initOverlapArr1); + sp.setIterationNum(2000); + sp.setUpdatePeriod(1000); + sp.updateDutyCycles_(overlaps, active); + + Real resultOverlapArr2[5]; + sp.getOverlapDutyCycles(resultOverlapArr2); + Real trueOverlapArr2[] = {1, 1, 1, 0.999, 0.999}; + + ASSERT_TRUE(check_vector_eq(resultOverlapArr2, trueOverlapArr2, numColumns)); +} + +TEST(SpatialPoolerTest, testAvgColumnsPerInput) { + SpatialPooler sp; + vector inputDim, colDim; + inputDim.clear(); + colDim.clear(); + + UInt colDim1[4] = {2, 2, 2, 2}; + UInt inputDim1[4] = {4, 4, 4, 4}; + Real trueAvgColumnPerInput1 = 0.5; + + inputDim.assign(inputDim1, inputDim1 + 
4); + colDim.assign(colDim1, colDim1 + 4); + sp.initialize(inputDim, colDim); + Real result = sp.avgColumnsPerInput_(); + ASSERT_FLOAT_EQ(result, trueAvgColumnPerInput1); + + UInt colDim2[4] = {2, 2, 2, 2}; + UInt inputDim2[4] = {7, 5, 1, 3}; + Real trueAvgColumnPerInput2 = (2.0 / 7 + 2.0 / 5 + 2.0 / 1 + 2 / 3.0) / 4; + + inputDim.assign(inputDim2, inputDim2 + 4); + colDim.assign(colDim2, colDim2 + 4); + sp.initialize(inputDim, colDim); + result = sp.avgColumnsPerInput_(); + ASSERT_FLOAT_EQ(result, trueAvgColumnPerInput2); + + UInt colDim3[2] = {3, 3}; + UInt inputDim3[2] = {3, 3}; + Real trueAvgColumnPerInput3 = 1; + + inputDim.assign(inputDim3, inputDim3 + 2); + colDim.assign(colDim3, colDim3 + 2); + sp.initialize(inputDim, colDim); + result = sp.avgColumnsPerInput_(); + ASSERT_FLOAT_EQ(result, trueAvgColumnPerInput3); + + UInt colDim4[1] = {25}; + UInt inputDim4[1] = {5}; + Real trueAvgColumnPerInput4 = 5; + + inputDim.assign(inputDim4, inputDim4 + 1); + colDim.assign(colDim4, colDim4 + 1); + sp.initialize(inputDim, colDim); + result = sp.avgColumnsPerInput_(); + ASSERT_FLOAT_EQ(result, trueAvgColumnPerInput4); + + UInt colDim5[7] = {3, 5, 6}; + UInt inputDim5[7] = {3, 5, 6}; + Real trueAvgColumnPerInput5 = 1; + + inputDim.assign(inputDim5, inputDim5 + 3); + colDim.assign(colDim5, colDim5 + 3); + sp.initialize(inputDim, colDim); + result = sp.avgColumnsPerInput_(); + ASSERT_FLOAT_EQ(result, trueAvgColumnPerInput5); + + UInt colDim6[4] = {2, 4, 6, 8}; + UInt inputDim6[4] = {2, 2, 2, 2}; + // 1 2 3 4 + Real trueAvgColumnPerInput6 = 2.5; + + inputDim.assign(inputDim6, inputDim6 + 4); + colDim.assign(colDim6, colDim6 + 4); + sp.initialize(inputDim, colDim); + result = sp.avgColumnsPerInput_(); + ASSERT_FLOAT_EQ(result, trueAvgColumnPerInput6); +} + +TEST(SpatialPoolerTest, testAvgConnectedSpanForColumn1D) { + + SpatialPooler sp; + UInt numColumns = 9; + UInt numInputs = 8; + setup(sp, numInputs, numColumns); + + Real permArr[9][8] = {{0, 1, 0, 1, 0, 1, 0, 1}, {0, 0, 0, 1, 0, 0, 0, 1}, + {0, 0, 0, 0, 0, 0, 1, 0}, {0, 0, 1, 0, 0, 0, 1, 0}, + {0, 0, 0, 0, 0, 0, 0, 0}, {0, 1, 1, 0, 0, 0, 0, 0}, + {0, 0, 1, 1, 1, 0, 0, 0}, {0, 0, 1, 0, 1, 0, 0, 0}, + {1, 1, 1, 1, 1, 1, 1, 1}}; + + UInt trueAvgConnectedSpan[9] = {7, 5, 1, 5, 0, 2, 3, 3, 8}; + + for (UInt i = 0; i < numColumns; i++) { + sp.setPermanence(i, permArr[i]); + UInt result = sp.avgConnectedSpanForColumn1D_(i); + ASSERT_TRUE(result == trueAvgConnectedSpan[i]); + } +} - TEST(SpatialPoolerTest, testUpdateDutyCycles) - { - SpatialPooler sp; - UInt numInputs = 5; - UInt numColumns = 5; - setup(sp, numInputs, numColumns); - vector overlaps; +TEST(SpatialPoolerTest, testAvgConnectedSpanForColumn2D) { + SpatialPooler sp; - Real initOverlapArr1[] = {1, 1, 1, 1, 1}; - sp.setOverlapDutyCycles(initOverlapArr1); - Real overlapNewVal1[] = {1, 5, 7, 0, 0}; - overlaps.assign(overlapNewVal1, overlapNewVal1+numColumns); - UInt active[] = {0, 0, 0, 0, 0}; + UInt numColumns = 7; + UInt numInputs = 20; - sp.setIterationNum(2); - sp.updateDutyCycles_(overlaps, active); + vector colDim, inputDim; + Real permArr1[7][20] = { + {0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, + // rowspan = 3, colspan = 3, avg = 3 - Real resultOverlapArr1[5]; - sp.getOverlapDutyCycles(resultOverlapArr1); + {1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + // rowspan = 2 colspan = 4, avg = 3 - Real trueOverlapArr1[] = {1, 1, 1, 0.5, 0.5}; - ASSERT_TRUE(check_vector_eq(resultOverlapArr1, trueOverlapArr1, numColumns)); + {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 1}, + // row span = 5, colspan = 4, avg = 4.5 - sp.setOverlapDutyCycles(initOverlapArr1); - sp.setIterationNum(2000); - sp.setUpdatePeriod(1000); - sp.updateDutyCycles_(overlaps, active); + {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0}, + // rowspan = 5, colspan = 1, avg = 3 - Real resultOverlapArr2[5]; - sp.getOverlapDutyCycles(resultOverlapArr2); - Real trueOverlapArr2[] = {1, 1, 1, 0.999, 0.999}; + {0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + // rowspan = 1, colspan = 4, avg = 2.5 - ASSERT_TRUE(check_vector_eq(resultOverlapArr2, trueOverlapArr2, numColumns)); - } + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}, + // rowspan = 2, colspan = 2, avg = 2 - TEST(SpatialPoolerTest, testAvgColumnsPerInput) - { - SpatialPooler sp; - vector inputDim, colDim; - inputDim.clear(); - colDim.clear(); - - UInt colDim1[4] = {2, 2, 2, 2}; - UInt inputDim1[4] = {4, 4, 4, 4}; - Real trueAvgColumnPerInput1 = 0.5; - - inputDim.assign(inputDim1, inputDim1+4); - colDim.assign(colDim1, colDim1+4); - sp.initialize(inputDim, colDim); - Real result = sp.avgColumnsPerInput_(); - ASSERT_FLOAT_EQ(result, trueAvgColumnPerInput1); - - UInt colDim2[4] = {2, 2, 2, 2}; - UInt inputDim2[4] = {7, 5, 1, 3}; - Real trueAvgColumnPerInput2 = (2.0/7 + 2.0/5 + 2.0/1 + 2/3.0) / 4; - - inputDim.assign(inputDim2, inputDim2+4); - colDim.assign(colDim2, colDim2+4); - sp.initialize(inputDim, colDim); - result = sp.avgColumnsPerInput_(); - ASSERT_FLOAT_EQ(result, trueAvgColumnPerInput2); - - UInt colDim3[2] = {3, 3}; - UInt inputDim3[2] = {3, 3}; - Real trueAvgColumnPerInput3 = 1; - - inputDim.assign(inputDim3, inputDim3+2); - colDim.assign(colDim3, colDim3+2); - sp.initialize(inputDim, colDim); - result = sp.avgColumnsPerInput_(); - ASSERT_FLOAT_EQ(result, trueAvgColumnPerInput3); - - - UInt colDim4[1] = {25}; - UInt inputDim4[1] = {5}; - Real trueAvgColumnPerInput4 = 5; - - inputDim.assign(inputDim4, inputDim4+1); - colDim.assign(colDim4, colDim4+1); - sp.initialize(inputDim, colDim); - result = sp.avgColumnsPerInput_(); - ASSERT_FLOAT_EQ(result, trueAvgColumnPerInput4); - - UInt colDim5[7] = {3, 5, 6}; - UInt inputDim5[7] = {3, 5, 6}; - Real trueAvgColumnPerInput5 = 1; - - inputDim.assign(inputDim5, inputDim5+3); - colDim.assign(colDim5, colDim5+3); - sp.initialize(inputDim, colDim); - result = sp.avgColumnsPerInput_(); - ASSERT_FLOAT_EQ(result, trueAvgColumnPerInput5); - - UInt colDim6[4] = {2, 4, 6, 8}; - UInt inputDim6[4] = {2, 2, 2, 2}; - // 1 2 3 4 - Real trueAvgColumnPerInput6 = 2.5; - - inputDim.assign(inputDim6, inputDim6+4); - colDim.assign(colDim6, colDim6+4); - sp.initialize(inputDim, colDim); - result = sp.avgColumnsPerInput_(); - ASSERT_FLOAT_EQ(result, trueAvgColumnPerInput6); - } + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + // rowspan = 0, colspan = 0, avg = 0 + }; - TEST(SpatialPoolerTest, testAvgConnectedSpanForColumn1D) - { + inputDim.push_back(5); + inputDim.push_back(4); + colDim.push_back(10); + colDim.push_back(1); + sp.initialize(inputDim, colDim); - SpatialPooler sp; - UInt numColumns = 9; - UInt numInputs = 8; - setup(sp, numInputs, numColumns); - - Real permArr[9][8] = - {{0, 1, 0, 1, 0, 1, 0, 1}, - {0, 0, 0, 1, 0, 0, 0, 1}, - {0, 0, 0, 0, 0, 0, 1, 0}, - {0, 0, 1, 0, 0, 0, 1, 0}, - {0, 0, 0, 0, 0, 0, 0, 0}, - {0, 1, 1, 0, 0, 0, 0, 0}, - {0, 0, 1, 1, 1, 0, 0, 0}, - {0, 0, 1, 0, 1, 0, 0, 0}, - {1, 1, 1, 1, 1, 1, 1, 1}}; - - UInt trueAvgConnectedSpan[9] = - {7, 5, 1, 5, 0, 2, 3, 3, 8}; - - for (UInt i = 0; i < 
numColumns; i++) { - sp.setPermanence(i, permArr[i]); - UInt result = sp.avgConnectedSpanForColumn1D_(i); - ASSERT_TRUE(result == trueAvgConnectedSpan[i]); - } + UInt trueAvgConnectedSpan1[7] = {3, 3, 4, 3, 2, 2, 0}; + + for (UInt i = 0; i < numColumns; i++) { + sp.setPermanence(i, permArr1[i]); + UInt result = sp.avgConnectedSpanForColumn2D_(i); + ASSERT_TRUE(result == (trueAvgConnectedSpan1[i])); } - TEST(SpatialPoolerTest, testAvgConnectedSpanForColumn2D) - { - SpatialPooler sp; - - UInt numColumns = 7; - UInt numInputs = 20; - - vector colDim, inputDim; - Real permArr1[7][20] = - {{0, 1, 1, 1, - 0, 1, 1, 1, - 0, 1, 1, 1, - 0, 0, 0, 0, - 0, 0, 0, 0}, - // rowspan = 3, colspan = 3, avg = 3 - - {1, 1, 1, 1, - 0, 0, 1, 1, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0}, - // rowspan = 2 colspan = 4, avg = 3 - - {1, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 1}, - // row span = 5, colspan = 4, avg = 4.5 - - {0, 1, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 1, 0, 0, - 0, 1, 0, 0}, - // rowspan = 5, colspan = 1, avg = 3 - - {0, 0, 0, 0, - 1, 0, 0, 1, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0}, - // rowspan = 1, colspan = 4, avg = 2.5 - - {0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 1, 0, - 0, 0, 0, 1}, - // rowspan = 2, colspan = 2, avg = 2 - - {0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0} - // rowspan = 0, colspan = 0, avg = 0 - }; - - inputDim.push_back(5); - inputDim.push_back(4); - colDim.push_back(10); - colDim.push_back(1); - sp.initialize(inputDim, colDim); - - UInt trueAvgConnectedSpan1[7] = {3, 3, 4, 3, 2, 2, 0}; - - for (UInt i = 0; i < numColumns; i++) { - sp.setPermanence(i, permArr1[i]); - UInt result = sp.avgConnectedSpanForColumn2D_(i); - ASSERT_TRUE(result == (trueAvgConnectedSpan1[i])); - } + // 1D tests repeated + numColumns = 9; + numInputs = 8; - //1D tests repeated - numColumns = 9; - numInputs = 8; - - colDim.clear(); - inputDim.clear(); - inputDim.push_back(numInputs); - inputDim.push_back(1); - colDim.push_back(numColumns); - colDim.push_back(1); - - sp.initialize(inputDim, colDim); - - Real permArr2[9][8] = - {{0, 1, 0, 1, 0, 1, 0, 1}, - {0, 0, 0, 1, 0, 0, 0, 1}, - {0, 0, 0, 0, 0, 0, 1, 0}, - {0, 0, 1, 0, 0, 0, 1, 0}, - {0, 0, 0, 0, 0, 0, 0, 0}, - {0, 1, 1, 0, 0, 0, 0, 0}, - {0, 0, 1, 1, 1, 0, 0, 0}, - {0, 0, 1, 0, 1, 0, 0, 0}, - {1, 1, 1, 1, 1, 1, 1, 1}}; - - UInt trueAvgConnectedSpan2[9] = - {8, 5, 1, 5, 0, 2, 3, 3, 8}; - - for (UInt i = 0; i < numColumns; i++) { - sp.setPermanence(i, permArr2[i]); - UInt result = sp.avgConnectedSpanForColumn2D_(i); - ASSERT_TRUE(result == (trueAvgConnectedSpan2[i] + 1)/2); - } - } + colDim.clear(); + inputDim.clear(); + inputDim.push_back(numInputs); + inputDim.push_back(1); + colDim.push_back(numColumns); + colDim.push_back(1); - TEST(SpatialPoolerTest, testAvgConnectedSpanForColumnND) - { - SpatialPooler sp; - vector inputDim, colDim; - inputDim.push_back(4); - inputDim.push_back(4); - inputDim.push_back(2); - inputDim.push_back(5); - colDim.push_back(5); - colDim.push_back(1); - colDim.push_back(1); - colDim.push_back(1); - - sp.initialize(inputDim, colDim); - - UInt numInputs = 160; - UInt numColumns = 5; - - Real permArr0[4][4][2][5]; - Real permArr1[4][4][2][5]; - Real permArr2[4][4][2][5]; - Real permArr3[4][4][2][5]; - Real permArr4[4][4][2][5]; - - for (UInt i = 0; i < numInputs; i++) { - ((Real*)permArr0)[i] = 0; - ((Real*)permArr1)[i] = 0; - ((Real*)permArr2)[i] = 0; - ((Real*)permArr3)[i] = 0; - ((Real*)permArr4)[i] = 0; - } + sp.initialize(inputDim, colDim); - permArr0[1][0][1][0] = 
1; - permArr0[1][0][1][1] = 1; - permArr0[3][2][1][0] = 1; - permArr0[3][0][1][0] = 1; - permArr0[1][0][1][3] = 1; - permArr0[2][2][1][0] = 1; - - permArr1[2][0][1][0] = 1; - permArr1[2][0][0][0] = 1; - permArr1[3][0][0][0] = 1; - permArr1[3][0][1][0] = 1; - - permArr2[0][0][1][4] = 1; - permArr2[0][0][0][3] = 1; - permArr2[0][0][0][1] = 1; - permArr2[1][0][0][2] = 1; - permArr2[0][0][1][1] = 1; - permArr2[3][3][1][1] = 1; - - permArr3[3][3][1][4] = 1; - permArr3[0][0][0][0] = 1; - - sp.setPermanence(0, (Real *) permArr0); - sp.setPermanence(1, (Real *) permArr1); - sp.setPermanence(2, (Real *) permArr2); - sp.setPermanence(3, (Real *) permArr3); - sp.setPermanence(4, (Real *) permArr4); - - Real trueAvgConnectedSpan[5] = - {11.0/4, 6.0/4, 14.0/4, 15.0/4, 0}; - - for (UInt i = 0; i < numColumns; i++) { - Real result = sp.avgConnectedSpanForColumnND_(i); - ASSERT_TRUE(result == trueAvgConnectedSpan[i]); - } + Real permArr2[9][8] = {{0, 1, 0, 1, 0, 1, 0, 1}, {0, 0, 0, 1, 0, 0, 0, 1}, + {0, 0, 0, 0, 0, 0, 1, 0}, {0, 0, 1, 0, 0, 0, 1, 0}, + {0, 0, 0, 0, 0, 0, 0, 0}, {0, 1, 1, 0, 0, 0, 0, 0}, + {0, 0, 1, 1, 1, 0, 0, 0}, {0, 0, 1, 0, 1, 0, 0, 0}, + {1, 1, 1, 1, 1, 1, 1, 1}}; + + UInt trueAvgConnectedSpan2[9] = {8, 5, 1, 5, 0, 2, 3, 3, 8}; + + for (UInt i = 0; i < numColumns; i++) { + sp.setPermanence(i, permArr2[i]); + UInt result = sp.avgConnectedSpanForColumn2D_(i); + ASSERT_TRUE(result == (trueAvgConnectedSpan2[i] + 1) / 2); + } +} + +TEST(SpatialPoolerTest, testAvgConnectedSpanForColumnND) { + SpatialPooler sp; + vector inputDim, colDim; + inputDim.push_back(4); + inputDim.push_back(4); + inputDim.push_back(2); + inputDim.push_back(5); + colDim.push_back(5); + colDim.push_back(1); + colDim.push_back(1); + colDim.push_back(1); + + sp.initialize(inputDim, colDim); + + UInt numInputs = 160; + UInt numColumns = 5; + + Real permArr0[4][4][2][5]; + Real permArr1[4][4][2][5]; + Real permArr2[4][4][2][5]; + Real permArr3[4][4][2][5]; + Real permArr4[4][4][2][5]; + + for (UInt i = 0; i < numInputs; i++) { + ((Real *)permArr0)[i] = 0; + ((Real *)permArr1)[i] = 0; + ((Real *)permArr2)[i] = 0; + ((Real *)permArr3)[i] = 0; + ((Real *)permArr4)[i] = 0; } - TEST(SpatialPoolerTest, testAdaptSynapses) - { - SpatialPooler sp; - UInt numColumns = 4; - UInt numInputs = 8; - setup(sp, numInputs, numColumns); - - vector activeColumns; - vector inputVector; - - UInt potentialArr1[4][8] = - {{1, 1, 1, 1, 0, 0, 0, 0}, - {1, 0, 0, 0, 1, 1, 0, 1}, - {0, 0, 1, 0, 0, 0, 1, 0}, - {1, 0, 0, 0, 0, 0, 1, 0}}; - - Real permanencesArr1[5][8] = - {{0.200, 0.120, 0.090, 0.060, 0.000, 0.000, 0.000, 0.000}, - {0.150, 0.000, 0.000, 0.000, 0.180, 0.120, 0.000, 0.450}, - {0.000, 0.000, 0.014, 0.000, 0.000, 0.000, 0.110, 0.000}, - {0.070, 0.000, 0.000, 0.000, 0.000, 0.000, 0.178, 0.000}}; - - Real truePermanences1[5][8] = - {{ 0.300, 0.110, 0.080, 0.160, 0.000, 0.000, 0.000, 0.000}, + permArr0[1][0][1][0] = 1; + permArr0[1][0][1][1] = 1; + permArr0[3][2][1][0] = 1; + permArr0[3][0][1][0] = 1; + permArr0[1][0][1][3] = 1; + permArr0[2][2][1][0] = 1; + + permArr1[2][0][1][0] = 1; + permArr1[2][0][0][0] = 1; + permArr1[3][0][0][0] = 1; + permArr1[3][0][1][0] = 1; + + permArr2[0][0][1][4] = 1; + permArr2[0][0][0][3] = 1; + permArr2[0][0][0][1] = 1; + permArr2[1][0][0][2] = 1; + permArr2[0][0][1][1] = 1; + permArr2[3][3][1][1] = 1; + + permArr3[3][3][1][4] = 1; + permArr3[0][0][0][0] = 1; + + sp.setPermanence(0, (Real *)permArr0); + sp.setPermanence(1, (Real *)permArr1); + sp.setPermanence(2, (Real *)permArr2); + sp.setPermanence(3, (Real 
*)permArr3); + sp.setPermanence(4, (Real *)permArr4); + + Real trueAvgConnectedSpan[5] = {11.0 / 4, 6.0 / 4, 14.0 / 4, 15.0 / 4, 0}; + + for (UInt i = 0; i < numColumns; i++) { + Real result = sp.avgConnectedSpanForColumnND_(i); + ASSERT_TRUE(result == trueAvgConnectedSpan[i]); + } +} + +TEST(SpatialPoolerTest, testAdaptSynapses) { + SpatialPooler sp; + UInt numColumns = 4; + UInt numInputs = 8; + setup(sp, numInputs, numColumns); + + vector activeColumns; + vector inputVector; + + UInt potentialArr1[4][8] = {{1, 1, 1, 1, 0, 0, 0, 0}, + {1, 0, 0, 0, 1, 1, 0, 1}, + {0, 0, 1, 0, 0, 0, 1, 0}, + {1, 0, 0, 0, 0, 0, 1, 0}}; + + Real permanencesArr1[5][8] = { + {0.200, 0.120, 0.090, 0.060, 0.000, 0.000, 0.000, 0.000}, + {0.150, 0.000, 0.000, 0.000, 0.180, 0.120, 0.000, 0.450}, + {0.000, 0.000, 0.014, 0.000, 0.000, 0.000, 0.110, 0.000}, + {0.070, 0.000, 0.000, 0.000, 0.000, 0.000, 0.178, 0.000}}; + + Real truePermanences1[5][8] = { + {0.300, 0.110, 0.080, 0.160, 0.000, 0.000, 0.000, 0.000}, // Inc Dec Dec Inc - - - - - {0.250, 0.000, 0.000, 0.000, 0.280, 0.110, 0.000, 0.440}, + {0.250, 0.000, 0.000, 0.000, 0.280, 0.110, 0.000, 0.440}, // Inc - - - Inc Dec - Dec - {0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.210, 0.000}, + {0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.210, 0.000}, // - - Trim - - - Inc - - {0.070, 0.000, 0.000, 0.000, 0.000, 0.000, 0.178, 0.000}}; - // - - - - - - - - + {0.070, 0.000, 0.000, 0.000, 0.000, 0.000, 0.178, 0.000}}; + // - - - - - - - - - UInt inputArr1[8] = {1, 0, 0, 1, 1, 0, 1, 0}; - UInt activeColumnsArr1[3] = {0, 1, 2}; - - for (UInt column = 0; column < numColumns; column++) { - sp.setPotential(column, potentialArr1[column]); - sp.setPermanence(column, permanencesArr1[column]); - } - - activeColumns.assign(&activeColumnsArr1[0], &activeColumnsArr1[3]); - - sp.adaptSynapses_(inputArr1, activeColumns); - cout << endl; - for (UInt column = 0; column < numColumns; column++) { - auto permArr = new Real[numInputs]; - sp.getPermanence(column, permArr); - ASSERT_TRUE(check_vector_eq(truePermanences1[column], - permArr, - numInputs)); - delete[] permArr; - } - - - UInt potentialArr2[4][8] = - {{1, 1, 1, 0, 0, 0, 0, 0}, - {0, 1, 1, 1, 0, 0, 0, 0}, - {0, 0, 1, 1, 1, 0, 0, 0}, - {1, 0, 0, 0, 0, 0, 1, 0}}; - - Real permanencesArr2[4][8] = - {{0.200, 0.120, 0.090, 0.000, 0.000, 0.000, 0.000, 0.000}, - {0.000, 0.017, 0.232, 0.400, 0.000, 0.000, 0.000, 0.000}, - {0.000, 0.000, 0.014, 0.051, 0.730, 0.000, 0.000, 0.000}, - {0.170, 0.000, 0.000, 0.000, 0.000, 0.000, 0.380, 0.000}}; - - Real truePermanences2[4][8] = - {{0.30, 0.110, 0.080, 0.000, 0.000, 0.000, 0.000, 0.000}, - // # Inc Dec Dec - - - - - - {0.000, 0.000, 0.222, 0.500, 0.000, 0.000, 0.000, 0.000}, - // # - Trim Dec Inc - - - - - {0.000, 0.000, 0.000, 0.151, 0.830, 0.000, 0.000, 0.000}, - // # - - Trim Inc Inc - - - - {0.170, 0.000, 0.000, 0.000, 0.000, 0.000, 0.380, 0.000}}; - // # - - - - - - - - - - UInt inputArr2[8] = { 1, 0, 0, 1, 1, 0, 1, 0 }; - UInt activeColumnsArr2[3] = {0, 1, 2}; - - for (UInt column = 0; column < numColumns; column++) { - sp.setPotential(column, potentialArr2[column]); - sp.setPermanence(column, permanencesArr2[column]); - } - - activeColumns.assign(&activeColumnsArr2[0], &activeColumnsArr2[3]); - - sp.adaptSynapses_(inputArr2, activeColumns); - cout << endl; - for (UInt column = 0; column < numColumns; column++) { - auto permArr = new Real[numInputs]; - sp.getPermanence(column, permArr); - ASSERT_TRUE(check_vector_eq(truePermanences2[column], permArr, numInputs)); - delete[] permArr; - } + 
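// The Inc/Dec/Trim annotations in truePermanences1 above are consistent with
// this per-synapse update, applied only to columns listed in activeColumns:
// permanences of potential synapses on active input bits rise by 0.1, those
// on inactive input bits fall by 0.01, and results below the trim threshold
// are cut to 0 (0.014 - 0.01 -> 0.000). The increments are read off the
// expected values rather than taken from named SpatialPooler defaults, and
// the trim threshold of 0.05 is borrowed from testBumpUpWeakColumns below;
// the helper is an illustrative sketch, not the library routine.
#include <cassert>

static double adaptOne(double perm, bool potential, bool inputActive,
                       double inc = 0.1, double dec = 0.01,
                       double trim = 0.05) {
  if (!potential) return perm;        // only potential synapses are adapted
  perm += inputActive ? inc : -dec;   // strengthen or weaken
  if (perm < trim) perm = 0.0;        // trim tiny permanences to zero
  if (perm > 1.0) perm = 1.0;         // clip at the assumed maximum of 1.0
  return perm;
}

int main() {
  // Column 0, synapses 0-2 of the first case above, plus the trimmed synapse
  // of column 2.
  assert(adaptOne(0.200, true, true) > 0.299);   // Inc  -> 0.300
  assert(adaptOne(0.120, true, false) > 0.109);  // Dec  -> 0.110
  assert(adaptOne(0.014, true, false) == 0.0);   // Trim -> 0.000
  return 0;
}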
UInt inputArr1[8] = {1, 0, 0, 1, 1, 0, 1, 0}; + UInt activeColumnsArr1[3] = {0, 1, 2}; + for (UInt column = 0; column < numColumns; column++) { + sp.setPotential(column, potentialArr1[column]); + sp.setPermanence(column, permanencesArr1[column]); } - TEST(SpatialPoolerTest, testBumpUpWeakColumns) - { - SpatialPooler sp; - UInt numInputs = 8; - UInt numColumns = 5; - setup(sp,numInputs,numColumns); - sp.setSynPermBelowStimulusInc(0.01); - sp.setSynPermTrimThreshold(0.05); - Real overlapDutyCyclesArr[] = {0, 0.009, 0.1, 0.001, 0.002}; - sp.setOverlapDutyCycles(overlapDutyCyclesArr); - Real minOverlapDutyCyclesArr[] = {0.01, 0.01, 0.01, 0.01, 0.01}; - sp.setMinOverlapDutyCycles(minOverlapDutyCyclesArr); - - UInt potentialArr[5][8] = - {{1, 1, 1, 1, 0, 0, 0, 0}, - {1, 0, 0, 0, 1, 1, 0, 1}, - {0, 0, 1, 0, 1, 1, 1, 0}, - {1, 1, 1, 0, 0, 0, 1, 0}, - {1, 1, 1, 1, 1, 1, 1, 1}}; - - Real permArr[5][8] = - {{0.200, 0.120, 0.090, 0.040, 0.000, 0.000, 0.000, 0.000}, - {0.150, 0.000, 0.000, 0.000, 0.180, 0.120, 0.000, 0.450}, - {0.000, 0.000, 0.074, 0.000, 0.062, 0.054, 0.110, 0.000}, - {0.051, 0.000, 0.000, 0.000, 0.000, 0.000, 0.178, 0.000}, - {0.100, 0.738, 0.085, 0.002, 0.052, 0.008, 0.208, 0.034}}; - - Real truePermArr[5][8] = - {{0.210, 0.130, 0.100, 0.000, 0.000, 0.000, 0.000, 0.000}, - // Inc Inc Inc Trim - - - - - {0.160, 0.000, 0.000, 0.000, 0.190, 0.130, 0.000, 0.460}, - // Inc - - - Inc Inc - Inc - {0.000, 0.000, 0.074, 0.000, 0.062, 0.054, 0.110, 0.000}, // unchanged - // - - - - - - - - - {0.061, 0.000, 0.000, 0.000, 0.000, 0.000, 0.188, 0.000}, - // Inc Trim Trim - - - Inc - - {0.110, 0.748, 0.095, 0.000, 0.062, 0.000, 0.218, 0.000}}; - - for (UInt i = 0; i < numColumns; i++) { - sp.setPotential(i, potentialArr[i]); - sp.setPermanence(i, permArr[i]); - Real perm[8]; - sp.getPermanence(i, perm); - } - - sp.bumpUpWeakColumns_(); - - for (UInt i = 0; i < numColumns; i++) { - Real perm[8]; - sp.getPermanence(i, perm); - ASSERT_TRUE(check_vector_eq(truePermArr[i], perm, numInputs)); - } + activeColumns.assign(&activeColumnsArr1[0], &activeColumnsArr1[3]); + sp.adaptSynapses_(inputArr1, activeColumns); + cout << endl; + for (UInt column = 0; column < numColumns; column++) { + auto permArr = new Real[numInputs]; + sp.getPermanence(column, permArr); + ASSERT_TRUE(check_vector_eq(truePermanences1[column], permArr, numInputs)); + delete[] permArr; } - TEST(SpatialPoolerTest, testUpdateDutyCyclesHelper) - { - SpatialPooler sp; - vector dutyCycles; - vector newValues; - UInt period; - - dutyCycles.clear(); - newValues.clear(); - Real dutyCyclesArr1[] = {1000.0, 1000.0, 1000.0, 1000.0, 1000.0}; - Real newValues1[] = {0, 0, 0, 0, 0}; - period = 1000; - Real trueDutyCycles1[] = {999.0, 999.0, 999.0, 999.0, 999.0}; - dutyCycles.assign(dutyCyclesArr1, dutyCyclesArr1+5); - newValues.assign(newValues1, newValues1+5); - sp.updateDutyCyclesHelper_(dutyCycles, newValues, period); - ASSERT_TRUE(check_vector_eq(trueDutyCycles1, dutyCycles)); - - dutyCycles.clear(); - newValues.clear(); - Real dutyCyclesArr2[] = {1000.0, 1000.0, 1000.0, 1000.0, 1000.0}; - Real newValues2[] = {1000, 1000, 1000, 1000, 1000}; - period = 1000; - Real trueDutyCycles2[] = {1000.0, 1000.0, 1000.0, 1000.0, 1000.0}; - dutyCycles.assign(dutyCyclesArr2, dutyCyclesArr2+5); - newValues.assign(newValues2, newValues2+5); - sp.updateDutyCyclesHelper_(dutyCycles, newValues, period); - ASSERT_TRUE(check_vector_eq(trueDutyCycles2, dutyCycles)); - - dutyCycles.clear(); - newValues.clear(); - Real dutyCyclesArr3[] = {1000.0, 1000.0, 1000.0, 1000.0, 
1000.0}; - Real newValues3[] = {2000, 4000, 5000, 6000, 7000}; - period = 1000; - Real trueDutyCycles3[] = {1001.0, 1003.0, 1004.0, 1005.0, 1006.0}; - dutyCycles.assign(dutyCyclesArr3, dutyCyclesArr3+5); - newValues.assign(newValues3, newValues3+5); - sp.updateDutyCyclesHelper_(dutyCycles, newValues, period); - ASSERT_TRUE(check_vector_eq(trueDutyCycles3, dutyCycles)); - - dutyCycles.clear(); - newValues.clear(); - Real dutyCyclesArr4[] = {1000.0, 800.0, 600.0, 400.0, 2000.0}; - Real newValues4[] = {0, 0, 0, 0, 0}; - period = 2; - Real trueDutyCycles4[] = {500.0, 400.0, 300.0, 200.0, 1000.0}; - dutyCycles.assign(dutyCyclesArr4, dutyCyclesArr4+5); - newValues.assign(newValues4, newValues4+5); - sp.updateDutyCyclesHelper_(dutyCycles, newValues, period); - ASSERT_TRUE(check_vector_eq(trueDutyCycles4, dutyCycles)); - + UInt potentialArr2[4][8] = {{1, 1, 1, 0, 0, 0, 0, 0}, + {0, 1, 1, 1, 0, 0, 0, 0}, + {0, 0, 1, 1, 1, 0, 0, 0}, + {1, 0, 0, 0, 0, 0, 1, 0}}; + + Real permanencesArr2[4][8] = { + {0.200, 0.120, 0.090, 0.000, 0.000, 0.000, 0.000, 0.000}, + {0.000, 0.017, 0.232, 0.400, 0.000, 0.000, 0.000, 0.000}, + {0.000, 0.000, 0.014, 0.051, 0.730, 0.000, 0.000, 0.000}, + {0.170, 0.000, 0.000, 0.000, 0.000, 0.000, 0.380, 0.000}}; + + Real truePermanences2[4][8] = { + {0.30, 0.110, 0.080, 0.000, 0.000, 0.000, 0.000, 0.000}, + // # Inc Dec Dec - - - - - + {0.000, 0.000, 0.222, 0.500, 0.000, 0.000, 0.000, 0.000}, + // # - Trim Dec Inc - - - - + {0.000, 0.000, 0.000, 0.151, 0.830, 0.000, 0.000, 0.000}, + // # - - Trim Inc Inc - - - + {0.170, 0.000, 0.000, 0.000, 0.000, 0.000, 0.380, 0.000}}; + // # - - - - - - - - + + UInt inputArr2[8] = {1, 0, 0, 1, 1, 0, 1, 0}; + UInt activeColumnsArr2[3] = {0, 1, 2}; + + for (UInt column = 0; column < numColumns; column++) { + sp.setPotential(column, potentialArr2[column]); + sp.setPermanence(column, permanencesArr2[column]); } - TEST(SpatialPoolerTest, testUpdateBoostFactors) - { - SpatialPooler sp; - setup(sp, 5, 6); - - Real initActiveDutyCycles1[] = {0.1, 0.1, 0.1, 0.1, 0.1, 0.1}; - Real initBoostFactors1[] = {0, 0, 0, 0, 0, 0}; - vector trueBoostFactors1 = - {1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; - vector resultBoostFactors1(6, 0); - sp.setGlobalInhibition(false); - sp.setBoostStrength(10); - sp.setBoostFactors(initBoostFactors1); - sp.setActiveDutyCycles(initActiveDutyCycles1); - sp.updateBoostFactors_(); - sp.getBoostFactors(resultBoostFactors1.data()); - ASSERT_TRUE(check_vector_eq(trueBoostFactors1, resultBoostFactors1)); - - Real initActiveDutyCycles2[] = - {0.1, 0.3, 0.02, 0.04, 0.7, 0.12}; - Real initBoostFactors2[] = - {0, 0, 0, 0, 0, 0}; - vector trueBoostFactors2 = - {3.10599, 0.42035, 6.91251, 5.65949, 0.00769898, 2.54297}; - vector resultBoostFactors2(6, 0); - sp.setGlobalInhibition(false); - sp.setBoostStrength(10); - sp.setBoostFactors(initBoostFactors2); - sp.setActiveDutyCycles(initActiveDutyCycles2); - sp.updateBoostFactors_(); - sp.getBoostFactors(resultBoostFactors2.data()); - - ASSERT_TRUE(check_vector_eq(trueBoostFactors2, resultBoostFactors2)); - - Real initActiveDutyCycles3[] = - {0.1, 0.3, 0.02, 0.04, 0.7, 0.12}; - Real initBoostFactors3[] = - {0, 0, 0, 0, 0, 0}; - vector trueBoostFactors3 = - { 1.25441, 0.840857, 1.47207, 1.41435, 0.377822, 1.20523 }; - vector resultBoostFactors3(6, 0); - sp.setWrapAround(true); - sp.setGlobalInhibition(false); - sp.setBoostStrength(2.0); - sp.setInhibitionRadius(5); - sp.setNumActiveColumnsPerInhArea(1); - sp.setBoostFactors(initBoostFactors3); - sp.setActiveDutyCycles(initActiveDutyCycles3); - 
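// The expectations in testUpdateDutyCyclesHelper are consistent with a
// running average over `period` iterations: each duty cycle is updated as
// duty = (duty * (period - 1) + newValue) / period, e.g.
// (1000 * 999 + 0) / 1000 = 999 and (1000 * 1 + 0) / 2 = 500. Minimal sketch
// of that update; the function name is illustrative.
#include <cassert>

static double updateDutyCycle(double duty, double newValue, unsigned period) {
  return (duty * (period - 1) + newValue) / period;  // decaying average
}

int main() {
  assert(updateDutyCycle(1000.0, 0.0, 1000) == 999.0);     // trueDutyCycles1
  assert(updateDutyCycle(1000.0, 2000.0, 1000) == 1001.0); // trueDutyCycles3
  assert(updateDutyCycle(1000.0, 0.0, 2) == 500.0);        // trueDutyCycles4
  assert(updateDutyCycle(2000.0, 0.0, 2) == 1000.0);       // trueDutyCycles4
  return 0;
}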
sp.updateBoostFactors_(); - sp.getBoostFactors(resultBoostFactors3.data()); - - ASSERT_TRUE(check_vector_eq(trueBoostFactors3, resultBoostFactors3)); - - Real initActiveDutyCycles4[] = - {0.1, 0.3, 0.02, 0.04, 0.7, 0.12}; - Real initBoostFactors4[] = - {0, 0, 0, 0, 0, 0}; - vector trueBoostFactors4 = - { 1.94773, 0.263597, 4.33476, 3.549, 0.00482795, 1.59467 }; - vector resultBoostFactors4(6, 0); - sp.setGlobalInhibition(true); - sp.setBoostStrength(10); - sp.setNumActiveColumnsPerInhArea(1); - sp.setInhibitionRadius(3); - sp.setBoostFactors(initBoostFactors4); - sp.setActiveDutyCycles(initActiveDutyCycles4); - sp.updateBoostFactors_(); - sp.getBoostFactors(resultBoostFactors4.data()); - - ASSERT_TRUE(check_vector_eq(trueBoostFactors3, resultBoostFactors3)); - } + activeColumns.assign(&activeColumnsArr2[0], &activeColumnsArr2[3]); - TEST(SpatialPoolerTest, testUpdateBookeepingVars) - { - SpatialPooler sp; - sp.setIterationNum(5); - sp.setIterationLearnNum(3); - sp.updateBookeepingVars_(true); - ASSERT_TRUE(6 == sp.getIterationNum()); - ASSERT_TRUE(4 == sp.getIterationLearnNum()); - - sp.updateBookeepingVars_(false); - ASSERT_TRUE(7 == sp.getIterationNum()); - ASSERT_TRUE(4 == sp.getIterationLearnNum()); + sp.adaptSynapses_(inputArr2, activeColumns); + cout << endl; + for (UInt column = 0; column < numColumns; column++) { + auto permArr = new Real[numInputs]; + sp.getPermanence(column, permArr); + ASSERT_TRUE(check_vector_eq(truePermanences2[column], permArr, numInputs)); + delete[] permArr; } - - TEST(SpatialPoolerTest, testCalculateOverlap) - { - SpatialPooler sp; - UInt numInputs = 10; - UInt numColumns = 5; - UInt numTrials = 5; - setup(sp,numInputs,numColumns); - sp.setStimulusThreshold(0); - - Real permArr[5][10] = - {{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - {0, 0, 1, 1, 1, 1, 1, 1, 1, 1}, - {0, 0, 0, 0, 1, 1, 1, 1, 1, 1}, - {0, 0, 0, 0, 0, 0, 1, 1, 1, 1}, - {0, 0, 0, 0, 0, 0, 0, 0, 1, 1}}; - - - UInt inputs[5][10] = - {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - {0, 1, 0, 1, 0, 1, 0, 1, 0, 1}, - {1, 1, 1, 1, 1, 0, 0, 0, 0, 0}, - {0, 0, 0, 0, 0, 0, 0, 0, 0, 1}}; - - UInt trueOverlaps[5][5] = - {{ 0, 0, 0, 0, 0}, - {10, 8, 6, 4, 2}, - { 5, 4, 3, 2, 1}, - { 5, 3, 1, 0, 0}, - { 1, 1, 1, 1, 1}}; - - for (UInt i = 0; i < numColumns; i++) - { - sp.setPermanence(i, permArr[i]); - } - - for (UInt i = 0; i < numTrials; i++) - { - vector overlaps; - sp.calculateOverlap_(inputs[i], overlaps); - ASSERT_TRUE(check_vector_eq(trueOverlaps[i],overlaps)); - } +} + +TEST(SpatialPoolerTest, testBumpUpWeakColumns) { + SpatialPooler sp; + UInt numInputs = 8; + UInt numColumns = 5; + setup(sp, numInputs, numColumns); + sp.setSynPermBelowStimulusInc(0.01); + sp.setSynPermTrimThreshold(0.05); + Real overlapDutyCyclesArr[] = {0, 0.009, 0.1, 0.001, 0.002}; + sp.setOverlapDutyCycles(overlapDutyCyclesArr); + Real minOverlapDutyCyclesArr[] = {0.01, 0.01, 0.01, 0.01, 0.01}; + sp.setMinOverlapDutyCycles(minOverlapDutyCyclesArr); + + UInt potentialArr[5][8] = {{1, 1, 1, 1, 0, 0, 0, 0}, + {1, 0, 0, 0, 1, 1, 0, 1}, + {0, 0, 1, 0, 1, 1, 1, 0}, + {1, 1, 1, 0, 0, 0, 1, 0}, + {1, 1, 1, 1, 1, 1, 1, 1}}; + + Real permArr[5][8] = { + {0.200, 0.120, 0.090, 0.040, 0.000, 0.000, 0.000, 0.000}, + {0.150, 0.000, 0.000, 0.000, 0.180, 0.120, 0.000, 0.450}, + {0.000, 0.000, 0.074, 0.000, 0.062, 0.054, 0.110, 0.000}, + {0.051, 0.000, 0.000, 0.000, 0.000, 0.000, 0.178, 0.000}, + {0.100, 0.738, 0.085, 0.002, 0.052, 0.008, 0.208, 0.034}}; + + Real truePermArr[5][8] = { + {0.210, 0.130, 0.100, 0.000, 0.000, 0.000, 0.000, 
0.000}, + // Inc Inc Inc Trim - - - - + {0.160, 0.000, 0.000, 0.000, 0.190, 0.130, 0.000, 0.460}, + // Inc - - - Inc Inc - Inc + {0.000, 0.000, 0.074, 0.000, 0.062, 0.054, 0.110, 0.000}, // unchanged + // - - - - - - - - + {0.061, 0.000, 0.000, 0.000, 0.000, 0.000, 0.188, 0.000}, + // Inc Trim Trim - - - Inc - + {0.110, 0.748, 0.095, 0.000, 0.062, 0.000, 0.218, 0.000}}; + + for (UInt i = 0; i < numColumns; i++) { + sp.setPotential(i, potentialArr[i]); + sp.setPermanence(i, permArr[i]); + Real perm[8]; + sp.getPermanence(i, perm); } - TEST(SpatialPoolerTest, testCalculateOverlapPct) - { - SpatialPooler sp; - UInt numInputs = 10; - UInt numColumns = 5; - UInt numTrials = 5; - setup(sp,numInputs,numColumns); - sp.setStimulusThreshold(0); - - Real permArr[5][10] = - {{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, - {0, 0, 1, 1, 1, 1, 1, 1, 1, 1}, - {0, 0, 0, 0, 1, 1, 1, 1, 1, 1}, - {0, 0, 0, 0, 0, 0, 1, 1, 1, 1}, - {0, 0, 0, 0, 0, 0, 0, 0, 1, 1}}; - - UInt overlapsArr[5][10] = - {{ 0, 0, 0, 0, 0}, - {10, 8, 6, 4, 2}, - { 5, 4, 3, 2, 1}, - { 5, 3, 1, 0, 0}, - { 1, 1, 1, 1, 1}}; - - Real trueOverlapsPct[5][5] = - {{0.0, 0.0, 0.0, 0.0, 0.0}, - {1.0, 1.0, 1.0, 1.0, 1.0}, - {0.5, 0.5, 0.5, 0.5, 0.5}, - {0.5, 3.0/8, 1.0/6, 0, 0}, - { 1.0/10, 1.0/8, 1.0/6, 1.0/4, 1.0/2}}; - - for (UInt i = 0; i < numColumns; i++) - { - sp.setPermanence(i,permArr[i]); - } - - for (UInt i = 0; i < numTrials; i++) - { - vector overlapsPct; - vector overlaps; - overlaps.assign(&overlapsArr[i][0],&overlapsArr[i][numColumns]); - sp.calculateOverlapPct_(overlaps,overlapsPct); - ASSERT_TRUE(check_vector_eq(trueOverlapsPct[i],overlapsPct)); - } - + sp.bumpUpWeakColumns_(); + for (UInt i = 0; i < numColumns; i++) { + Real perm[8]; + sp.getPermanence(i, perm); + ASSERT_TRUE(check_vector_eq(truePermArr[i], perm, numInputs)); } - - TEST(SpatialPoolerTest, testIsWinner) - { - UInt numInputs = 10; - UInt numColumns = 5; - SpatialPooler sp({numInputs}, {numColumns}); - - vector > winners; - - UInt numWinners = 3; - Real score = -5; - ASSERT_FALSE(sp.isWinner_(score,winners,numWinners)); - score = 0; - ASSERT_TRUE(sp.isWinner_(score,winners,numWinners)); - - pair sc1; sc1.first = 1; sc1.second = 32; - pair sc2; sc2.first = 2; sc2.second = 27; - pair sc3; sc3.first = 17; sc3.second = 19.5; - winners.push_back(sc1); - winners.push_back(sc2); - winners.push_back(sc3); - - numWinners = 3; - score = -5; - ASSERT_TRUE(!sp.isWinner_(score,winners,numWinners)); - score = 18; - ASSERT_TRUE(!sp.isWinner_(score,winners,numWinners)); - score = 18; - numWinners = 4; - ASSERT_TRUE(sp.isWinner_(score,winners,numWinners)); - numWinners = 3; - score = 20; - ASSERT_TRUE(sp.isWinner_(score,winners,numWinners)); - score = 30; - ASSERT_TRUE(sp.isWinner_(score,winners,numWinners)); - score = 40; - ASSERT_TRUE(sp.isWinner_(score,winners,numWinners)); - score = 40; - numWinners = 6; - ASSERT_TRUE(sp.isWinner_(score,winners,numWinners)); - - pair sc4; sc4.first = 34; sc4.second = 17.1; - pair sc5; sc5.first = 51; sc5.second = 1.2; - pair sc6; sc6.first = 19; sc6.second = 0.3; - winners.push_back(sc4); - winners.push_back(sc5); - winners.push_back(sc6); - - score = 40; - numWinners = 6; - ASSERT_TRUE(sp.isWinner_(score,winners,numWinners)); - score = 12; - numWinners = 6; - ASSERT_TRUE(sp.isWinner_(score,winners,numWinners)); - score = 0.1; - numWinners = 6; - ASSERT_TRUE(!sp.isWinner_(score,winners,numWinners)); - score = 0.1; - numWinners = 7; - ASSERT_TRUE(sp.isWinner_(score,winners,numWinners)); +} + +TEST(SpatialPoolerTest, testUpdateDutyCyclesHelper) { + SpatialPooler 
sp; + vector<Real> dutyCycles; + vector<Real> newValues; + UInt period; + + dutyCycles.clear(); + newValues.clear(); + Real dutyCyclesArr1[] = {1000.0, 1000.0, 1000.0, 1000.0, 1000.0}; + Real newValues1[] = {0, 0, 0, 0, 0}; + period = 1000; + Real trueDutyCycles1[] = {999.0, 999.0, 999.0, 999.0, 999.0}; + dutyCycles.assign(dutyCyclesArr1, dutyCyclesArr1 + 5); + newValues.assign(newValues1, newValues1 + 5); + sp.updateDutyCyclesHelper_(dutyCycles, newValues, period); + ASSERT_TRUE(check_vector_eq(trueDutyCycles1, dutyCycles)); + + dutyCycles.clear(); + newValues.clear(); + Real dutyCyclesArr2[] = {1000.0, 1000.0, 1000.0, 1000.0, 1000.0}; + Real newValues2[] = {1000, 1000, 1000, 1000, 1000}; + period = 1000; + Real trueDutyCycles2[] = {1000.0, 1000.0, 1000.0, 1000.0, 1000.0}; + dutyCycles.assign(dutyCyclesArr2, dutyCyclesArr2 + 5); + newValues.assign(newValues2, newValues2 + 5); + sp.updateDutyCyclesHelper_(dutyCycles, newValues, period); + ASSERT_TRUE(check_vector_eq(trueDutyCycles2, dutyCycles)); + + dutyCycles.clear(); + newValues.clear(); + Real dutyCyclesArr3[] = {1000.0, 1000.0, 1000.0, 1000.0, 1000.0}; + Real newValues3[] = {2000, 4000, 5000, 6000, 7000}; + period = 1000; + Real trueDutyCycles3[] = {1001.0, 1003.0, 1004.0, 1005.0, 1006.0}; + dutyCycles.assign(dutyCyclesArr3, dutyCyclesArr3 + 5); + newValues.assign(newValues3, newValues3 + 5); + sp.updateDutyCyclesHelper_(dutyCycles, newValues, period); + ASSERT_TRUE(check_vector_eq(trueDutyCycles3, dutyCycles)); + + dutyCycles.clear(); + newValues.clear(); + Real dutyCyclesArr4[] = {1000.0, 800.0, 600.0, 400.0, 2000.0}; + Real newValues4[] = {0, 0, 0, 0, 0}; + period = 2; + Real trueDutyCycles4[] = {500.0, 400.0, 300.0, 200.0, 1000.0}; + dutyCycles.assign(dutyCyclesArr4, dutyCyclesArr4 + 5); + newValues.assign(newValues4, newValues4 + 5); + sp.updateDutyCyclesHelper_(dutyCycles, newValues, period); + ASSERT_TRUE(check_vector_eq(trueDutyCycles4, dutyCycles)); +} + +TEST(SpatialPoolerTest, testUpdateBoostFactors) { + SpatialPooler sp; + setup(sp, 5, 6); + + Real initActiveDutyCycles1[] = {0.1, 0.1, 0.1, 0.1, 0.1, 0.1}; + Real initBoostFactors1[] = {0, 0, 0, 0, 0, 0}; + vector<Real> trueBoostFactors1 = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + vector<Real> resultBoostFactors1(6, 0); + sp.setGlobalInhibition(false); + sp.setBoostStrength(10); + sp.setBoostFactors(initBoostFactors1); + sp.setActiveDutyCycles(initActiveDutyCycles1); + sp.updateBoostFactors_(); + sp.getBoostFactors(resultBoostFactors1.data()); + ASSERT_TRUE(check_vector_eq(trueBoostFactors1, resultBoostFactors1)); + + Real initActiveDutyCycles2[] = {0.1, 0.3, 0.02, 0.04, 0.7, 0.12}; + Real initBoostFactors2[] = {0, 0, 0, 0, 0, 0}; + vector<Real> trueBoostFactors2 = {3.10599, 0.42035, 6.91251, + 5.65949, 0.00769898, 2.54297}; + vector<Real> resultBoostFactors2(6, 0); + sp.setGlobalInhibition(false); + sp.setBoostStrength(10); + sp.setBoostFactors(initBoostFactors2); + sp.setActiveDutyCycles(initActiveDutyCycles2); + sp.updateBoostFactors_(); + sp.getBoostFactors(resultBoostFactors2.data()); + + ASSERT_TRUE(check_vector_eq(trueBoostFactors2, resultBoostFactors2)); + + Real initActiveDutyCycles3[] = {0.1, 0.3, 0.02, 0.04, 0.7, 0.12}; + Real initBoostFactors3[] = {0, 0, 0, 0, 0, 0}; + vector<Real> trueBoostFactors3 = {1.25441, 0.840857, 1.47207, + 1.41435, 0.377822, 1.20523}; + vector<Real> resultBoostFactors3(6, 0); + sp.setWrapAround(true); + sp.setGlobalInhibition(false); + sp.setBoostStrength(2.0); + sp.setInhibitionRadius(5); + sp.setNumActiveColumnsPerInhArea(1); + sp.setBoostFactors(initBoostFactors3); + 
sp.setActiveDutyCycles(initActiveDutyCycles3); + sp.updateBoostFactors_(); + sp.getBoostFactors(resultBoostFactors3.data()); + + ASSERT_TRUE(check_vector_eq(trueBoostFactors3, resultBoostFactors3)); + + Real initActiveDutyCycles4[] = {0.1, 0.3, 0.02, 0.04, 0.7, 0.12}; + Real initBoostFactors4[] = {0, 0, 0, 0, 0, 0}; + vector<Real> trueBoostFactors4 = {1.94773, 0.263597, 4.33476, + 3.549, 0.00482795, 1.59467}; + vector<Real> resultBoostFactors4(6, 0); + sp.setGlobalInhibition(true); + sp.setBoostStrength(10); + sp.setNumActiveColumnsPerInhArea(1); + sp.setInhibitionRadius(3); + sp.setBoostFactors(initBoostFactors4); + sp.setActiveDutyCycles(initActiveDutyCycles4); + sp.updateBoostFactors_(); + sp.getBoostFactors(resultBoostFactors4.data()); + + ASSERT_TRUE(check_vector_eq(trueBoostFactors4, resultBoostFactors4)); +} + +TEST(SpatialPoolerTest, testUpdateBookeepingVars) { + SpatialPooler sp; + sp.setIterationNum(5); + sp.setIterationLearnNum(3); + sp.updateBookeepingVars_(true); + ASSERT_TRUE(6 == sp.getIterationNum()); + ASSERT_TRUE(4 == sp.getIterationLearnNum()); + + sp.updateBookeepingVars_(false); + ASSERT_TRUE(7 == sp.getIterationNum()); + ASSERT_TRUE(4 == sp.getIterationLearnNum()); +} + +TEST(SpatialPoolerTest, testCalculateOverlap) { + SpatialPooler sp; + UInt numInputs = 10; + UInt numColumns = 5; + UInt numTrials = 5; + setup(sp, numInputs, numColumns); + sp.setStimulusThreshold(0); + + Real permArr[5][10] = {{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {0, 0, 1, 1, 1, 1, 1, 1, 1, 1}, + {0, 0, 0, 0, 1, 1, 1, 1, 1, 1}, + {0, 0, 0, 0, 0, 0, 1, 1, 1, 1}, + {0, 0, 0, 0, 0, 0, 0, 0, 1, 1}}; + + UInt inputs[5][10] = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {0, 1, 0, 1, 0, 1, 0, 1, 0, 1}, + {1, 1, 1, 1, 1, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 1}}; + + UInt trueOverlaps[5][5] = {{0, 0, 0, 0, 0}, + {10, 8, 6, 4, 2}, + {5, 4, 3, 2, 1}, + {5, 3, 1, 0, 0}, + {1, 1, 1, 1, 1}}; + + for (UInt i = 0; i < numColumns; i++) { + sp.setPermanence(i, permArr[i]); } - TEST(SpatialPoolerTest, testAddToWinners) - { - SpatialPooler sp; - vector > winners; - - UInt index; - Real score; - - index = 17; score = 19.5; - sp.addToWinners_(index,score,winners); - index = 1; score = 32; - sp.addToWinners_(index,score,winners); - index = 2; score = 27; - sp.addToWinners_(index,score,winners); - - ASSERT_TRUE(winners[0].first == 1); - ASSERT_TRUE(almost_eq(winners[0].second,32)); - ASSERT_TRUE(winners[1].first == 2); - ASSERT_TRUE(almost_eq(winners[1].second,27)); - ASSERT_TRUE(winners[2].first == 17); - ASSERT_TRUE(almost_eq(winners[2].second,19.5)); - - index = 15; score = 20.5; - sp.addToWinners_(index,score,winners); - ASSERT_TRUE(winners[0].first == 1); - ASSERT_TRUE(almost_eq(winners[0].second,32)); - ASSERT_TRUE(winners[1].first == 2); - ASSERT_TRUE(almost_eq(winners[1].second,27)); - ASSERT_TRUE(winners[2].first == 15); - ASSERT_TRUE(almost_eq(winners[2].second,20.5)); - ASSERT_TRUE(winners[3].first == 17); - ASSERT_TRUE(almost_eq(winners[3].second,19.5)); - - index = 7; score = 100; - sp.addToWinners_(index,score,winners); - ASSERT_TRUE(winners[0].first == 7); - ASSERT_TRUE(almost_eq(winners[0].second,100)); - ASSERT_TRUE(winners[1].first == 1); - ASSERT_TRUE(almost_eq(winners[1].second,32)); - ASSERT_TRUE(winners[2].first == 2); - ASSERT_TRUE(almost_eq(winners[2].second,27)); - ASSERT_TRUE(winners[3].first == 15); - ASSERT_TRUE(almost_eq(winners[3].second,20.5)); - ASSERT_TRUE(winners[4].first == 17); - ASSERT_TRUE(almost_eq(winners[4].second,19.5)); - - index = 22; score = 1; - 
sp.addToWinners_(index,score,winners); - ASSERT_TRUE(winners[0].first == 7); - ASSERT_TRUE(almost_eq(winners[0].second,100)); - ASSERT_TRUE(winners[1].first == 1); - ASSERT_TRUE(almost_eq(winners[1].second,32)); - ASSERT_TRUE(winners[2].first == 2); - ASSERT_TRUE(almost_eq(winners[2].second,27)); - ASSERT_TRUE(winners[3].first == 15); - ASSERT_TRUE(almost_eq(winners[3].second,20.5)); - ASSERT_TRUE(winners[4].first == 17); - ASSERT_TRUE(almost_eq(winners[4].second,19.5)); - ASSERT_TRUE(winners[5].first == 22); - ASSERT_TRUE(almost_eq(winners[5].second,1)); - + for (UInt i = 0; i < numTrials; i++) { + vector overlaps; + sp.calculateOverlap_(inputs[i], overlaps); + ASSERT_TRUE(check_vector_eq(trueOverlaps[i], overlaps)); } - - TEST(SpatialPoolerTest, testInhibitColumns) - { - SpatialPooler sp; - setup(sp, 10,10); - - vector overlapsReal; - vector overlaps; - vector activeColumns; - vector activeColumnsGlobal; - vector activeColumnsLocal; - Real density; - UInt inhibitionRadius; - UInt numColumns; - - density = 0.3; - numColumns = 10; - Real overlapsArray[10] = {10,21,34,4,18,3,12,5,7,1}; - - overlapsReal.assign(&overlapsArray[0],&overlapsArray[numColumns]); - sp.inhibitColumnsGlobal_(overlapsReal, density,activeColumnsGlobal); - overlapsReal.assign(&overlapsArray[0],&overlapsArray[numColumns]); - sp.inhibitColumnsLocal_(overlapsReal, density, activeColumnsLocal); - - sp.setInhibitionRadius(5); - sp.setGlobalInhibition(true); - sp.setLocalAreaDensity(density); - - overlaps.assign(&overlapsArray[0],&overlapsArray[numColumns]); - sp.inhibitColumns_(overlaps, activeColumns); - - ASSERT_TRUE(check_vector_eq(activeColumns, activeColumnsGlobal)); - ASSERT_TRUE(!check_vector_eq(activeColumns, activeColumnsLocal)); - - sp.setGlobalInhibition(false); - sp.setInhibitionRadius(numColumns + 1); - - overlaps.assign(&overlapsArray[0],&overlapsArray[numColumns]); - sp.inhibitColumns_(overlaps, activeColumns); - - ASSERT_TRUE(check_vector_eq(activeColumns, activeColumnsGlobal)); - ASSERT_TRUE(!check_vector_eq(activeColumns, activeColumnsLocal)); - - inhibitionRadius = 2; - density = 2.0 / 5; - - sp.setInhibitionRadius(inhibitionRadius); - sp.setNumActiveColumnsPerInhArea(2); - - overlapsReal.assign(&overlapsArray[0], &overlapsArray[numColumns]); - sp.inhibitColumnsGlobal_(overlapsReal, density,activeColumnsGlobal); - overlapsReal.assign(&overlapsArray[0], &overlapsArray[numColumns]); - sp.inhibitColumnsLocal_(overlapsReal, density, activeColumnsLocal); - - overlaps.assign(&overlapsArray[0],&overlapsArray[numColumns]); - sp.inhibitColumns_(overlaps, activeColumns); - - ASSERT_TRUE(!check_vector_eq(activeColumns, activeColumnsGlobal)); - ASSERT_TRUE(check_vector_eq(activeColumns, activeColumnsLocal)); +} + +TEST(SpatialPoolerTest, testCalculateOverlapPct) { + SpatialPooler sp; + UInt numInputs = 10; + UInt numColumns = 5; + UInt numTrials = 5; + setup(sp, numInputs, numColumns); + sp.setStimulusThreshold(0); + + Real permArr[5][10] = {{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {0, 0, 1, 1, 1, 1, 1, 1, 1, 1}, + {0, 0, 0, 0, 1, 1, 1, 1, 1, 1}, + {0, 0, 0, 0, 0, 0, 1, 1, 1, 1}, + {0, 0, 0, 0, 0, 0, 0, 0, 1, 1}}; + + UInt overlapsArr[5][10] = {{0, 0, 0, 0, 0}, + {10, 8, 6, 4, 2}, + {5, 4, 3, 2, 1}, + {5, 3, 1, 0, 0}, + {1, 1, 1, 1, 1}}; + + Real trueOverlapsPct[5][5] = {{0.0, 0.0, 0.0, 0.0, 0.0}, + {1.0, 1.0, 1.0, 1.0, 1.0}, + {0.5, 0.5, 0.5, 0.5, 0.5}, + {0.5, 3.0 / 8, 1.0 / 6, 0, 0}, + {1.0 / 10, 1.0 / 8, 1.0 / 6, 1.0 / 4, 1.0 / 2}}; + + for (UInt i = 0; i < numColumns; i++) { + sp.setPermanence(i, permArr[i]); } - 
TEST(SpatialPoolerTest, testInhibitColumnsGlobal) - { - SpatialPooler sp; - UInt numInputs = 10; - UInt numColumns = 10; - setup(sp,numInputs,numColumns); - vector overlaps; - vector activeColumns; - vector trueActive; - vector active; - Real density; - - density = 0.3; - Real overlapsArray[10] = {1,2,1,4,8,3,12,5,4,1}; - overlaps.assign(&overlapsArray[0],&overlapsArray[numColumns]); - sp.inhibitColumnsGlobal_(overlaps,density,activeColumns); - UInt trueActiveArray1[3] = {4,6,7}; - - trueActive.assign(numColumns, 0); - active.assign(numColumns, 0); - - for (auto & elem : trueActiveArray1) { - trueActive[elem] = 1; - } - - for (auto & activeColumn : activeColumns) { - active[activeColumn] = 1; - } - - ASSERT_TRUE(check_vector_eq(trueActive,active)); - + for (UInt i = 0; i < numTrials; i++) { + vector overlapsPct; + vector overlaps; + overlaps.assign(&overlapsArr[i][0], &overlapsArr[i][numColumns]); + sp.calculateOverlapPct_(overlaps, overlapsPct); + ASSERT_TRUE(check_vector_eq(trueOverlapsPct[i], overlapsPct)); + } +} + +TEST(SpatialPoolerTest, testIsWinner) { + UInt numInputs = 10; + UInt numColumns = 5; + SpatialPooler sp({numInputs}, {numColumns}); + + vector> winners; + + UInt numWinners = 3; + Real score = -5; + ASSERT_FALSE(sp.isWinner_(score, winners, numWinners)); + score = 0; + ASSERT_TRUE(sp.isWinner_(score, winners, numWinners)); + + pair sc1; + sc1.first = 1; + sc1.second = 32; + pair sc2; + sc2.first = 2; + sc2.second = 27; + pair sc3; + sc3.first = 17; + sc3.second = 19.5; + winners.push_back(sc1); + winners.push_back(sc2); + winners.push_back(sc3); + + numWinners = 3; + score = -5; + ASSERT_TRUE(!sp.isWinner_(score, winners, numWinners)); + score = 18; + ASSERT_TRUE(!sp.isWinner_(score, winners, numWinners)); + score = 18; + numWinners = 4; + ASSERT_TRUE(sp.isWinner_(score, winners, numWinners)); + numWinners = 3; + score = 20; + ASSERT_TRUE(sp.isWinner_(score, winners, numWinners)); + score = 30; + ASSERT_TRUE(sp.isWinner_(score, winners, numWinners)); + score = 40; + ASSERT_TRUE(sp.isWinner_(score, winners, numWinners)); + score = 40; + numWinners = 6; + ASSERT_TRUE(sp.isWinner_(score, winners, numWinners)); + + pair sc4; + sc4.first = 34; + sc4.second = 17.1; + pair sc5; + sc5.first = 51; + sc5.second = 1.2; + pair sc6; + sc6.first = 19; + sc6.second = 0.3; + winners.push_back(sc4); + winners.push_back(sc5); + winners.push_back(sc6); + + score = 40; + numWinners = 6; + ASSERT_TRUE(sp.isWinner_(score, winners, numWinners)); + score = 12; + numWinners = 6; + ASSERT_TRUE(sp.isWinner_(score, winners, numWinners)); + score = 0.1; + numWinners = 6; + ASSERT_TRUE(!sp.isWinner_(score, winners, numWinners)); + score = 0.1; + numWinners = 7; + ASSERT_TRUE(sp.isWinner_(score, winners, numWinners)); +} + +TEST(SpatialPoolerTest, testAddToWinners) { + SpatialPooler sp; + vector> winners; + + UInt index; + Real score; + + index = 17; + score = 19.5; + sp.addToWinners_(index, score, winners); + index = 1; + score = 32; + sp.addToWinners_(index, score, winners); + index = 2; + score = 27; + sp.addToWinners_(index, score, winners); + + ASSERT_TRUE(winners[0].first == 1); + ASSERT_TRUE(almost_eq(winners[0].second, 32)); + ASSERT_TRUE(winners[1].first == 2); + ASSERT_TRUE(almost_eq(winners[1].second, 27)); + ASSERT_TRUE(winners[2].first == 17); + ASSERT_TRUE(almost_eq(winners[2].second, 19.5)); + + index = 15; + score = 20.5; + sp.addToWinners_(index, score, winners); + ASSERT_TRUE(winners[0].first == 1); + ASSERT_TRUE(almost_eq(winners[0].second, 32)); + ASSERT_TRUE(winners[1].first == 2); 
+ ASSERT_TRUE(almost_eq(winners[1].second, 27)); + ASSERT_TRUE(winners[2].first == 15); + ASSERT_TRUE(almost_eq(winners[2].second, 20.5)); + ASSERT_TRUE(winners[3].first == 17); + ASSERT_TRUE(almost_eq(winners[3].second, 19.5)); + + index = 7; + score = 100; + sp.addToWinners_(index, score, winners); + ASSERT_TRUE(winners[0].first == 7); + ASSERT_TRUE(almost_eq(winners[0].second, 100)); + ASSERT_TRUE(winners[1].first == 1); + ASSERT_TRUE(almost_eq(winners[1].second, 32)); + ASSERT_TRUE(winners[2].first == 2); + ASSERT_TRUE(almost_eq(winners[2].second, 27)); + ASSERT_TRUE(winners[3].first == 15); + ASSERT_TRUE(almost_eq(winners[3].second, 20.5)); + ASSERT_TRUE(winners[4].first == 17); + ASSERT_TRUE(almost_eq(winners[4].second, 19.5)); + + index = 22; + score = 1; + sp.addToWinners_(index, score, winners); + ASSERT_TRUE(winners[0].first == 7); + ASSERT_TRUE(almost_eq(winners[0].second, 100)); + ASSERT_TRUE(winners[1].first == 1); + ASSERT_TRUE(almost_eq(winners[1].second, 32)); + ASSERT_TRUE(winners[2].first == 2); + ASSERT_TRUE(almost_eq(winners[2].second, 27)); + ASSERT_TRUE(winners[3].first == 15); + ASSERT_TRUE(almost_eq(winners[3].second, 20.5)); + ASSERT_TRUE(winners[4].first == 17); + ASSERT_TRUE(almost_eq(winners[4].second, 19.5)); + ASSERT_TRUE(winners[5].first == 22); + ASSERT_TRUE(almost_eq(winners[5].second, 1)); +} + +TEST(SpatialPoolerTest, testInhibitColumns) { + SpatialPooler sp; + setup(sp, 10, 10); + + vector overlapsReal; + vector overlaps; + vector activeColumns; + vector activeColumnsGlobal; + vector activeColumnsLocal; + Real density; + UInt inhibitionRadius; + UInt numColumns; + + density = 0.3; + numColumns = 10; + Real overlapsArray[10] = {10, 21, 34, 4, 18, 3, 12, 5, 7, 1}; + + overlapsReal.assign(&overlapsArray[0], &overlapsArray[numColumns]); + sp.inhibitColumnsGlobal_(overlapsReal, density, activeColumnsGlobal); + overlapsReal.assign(&overlapsArray[0], &overlapsArray[numColumns]); + sp.inhibitColumnsLocal_(overlapsReal, density, activeColumnsLocal); + + sp.setInhibitionRadius(5); + sp.setGlobalInhibition(true); + sp.setLocalAreaDensity(density); + + overlaps.assign(&overlapsArray[0], &overlapsArray[numColumns]); + sp.inhibitColumns_(overlaps, activeColumns); + + ASSERT_TRUE(check_vector_eq(activeColumns, activeColumnsGlobal)); + ASSERT_TRUE(!check_vector_eq(activeColumns, activeColumnsLocal)); + + sp.setGlobalInhibition(false); + sp.setInhibitionRadius(numColumns + 1); + + overlaps.assign(&overlapsArray[0], &overlapsArray[numColumns]); + sp.inhibitColumns_(overlaps, activeColumns); + + ASSERT_TRUE(check_vector_eq(activeColumns, activeColumnsGlobal)); + ASSERT_TRUE(!check_vector_eq(activeColumns, activeColumnsLocal)); + + inhibitionRadius = 2; + density = 2.0 / 5; + + sp.setInhibitionRadius(inhibitionRadius); + sp.setNumActiveColumnsPerInhArea(2); + + overlapsReal.assign(&overlapsArray[0], &overlapsArray[numColumns]); + sp.inhibitColumnsGlobal_(overlapsReal, density, activeColumnsGlobal); + overlapsReal.assign(&overlapsArray[0], &overlapsArray[numColumns]); + sp.inhibitColumnsLocal_(overlapsReal, density, activeColumnsLocal); + + overlaps.assign(&overlapsArray[0], &overlapsArray[numColumns]); + sp.inhibitColumns_(overlaps, activeColumns); + + ASSERT_TRUE(!check_vector_eq(activeColumns, activeColumnsGlobal)); + ASSERT_TRUE(check_vector_eq(activeColumns, activeColumnsLocal)); +} + +TEST(SpatialPoolerTest, testInhibitColumnsGlobal) { + SpatialPooler sp; + UInt numInputs = 10; + UInt numColumns = 10; + setup(sp, numInputs, numColumns); + vector overlaps; + vector 
activeColumns; + vector trueActive; + vector active; + Real density; + + density = 0.3; + Real overlapsArray[10] = {1, 2, 1, 4, 8, 3, 12, 5, 4, 1}; + overlaps.assign(&overlapsArray[0], &overlapsArray[numColumns]); + sp.inhibitColumnsGlobal_(overlaps, density, activeColumns); + UInt trueActiveArray1[3] = {4, 6, 7}; + + trueActive.assign(numColumns, 0); + active.assign(numColumns, 0); + + for (auto &elem : trueActiveArray1) { + trueActive[elem] = 1; + } - density = 0.5; - UInt overlapsArray2[10] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - overlaps.assign(&overlapsArray2[0],&overlapsArray2[numColumns]); - sp.inhibitColumnsGlobal_(overlaps, density, activeColumns); - UInt trueActiveArray2[5] = {5,6,7,8,9}; + for (auto &activeColumn : activeColumns) { + active[activeColumn] = 1; + } - for (auto & elem : trueActiveArray2) { - trueActive[elem] = 1; - } + ASSERT_TRUE(check_vector_eq(trueActive, active)); - for (auto & activeColumn : activeColumns) { - active[activeColumn] = 1; - } + density = 0.5; + UInt overlapsArray2[10] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + overlaps.assign(&overlapsArray2[0], &overlapsArray2[numColumns]); + sp.inhibitColumnsGlobal_(overlaps, density, activeColumns); + UInt trueActiveArray2[5] = {5, 6, 7, 8, 9}; - ASSERT_TRUE(check_vector_eq(trueActive,active)); + for (auto &elem : trueActiveArray2) { + trueActive[elem] = 1; } - TEST(SpatialPoolerTest, testValidateGlobalInhibitionParameters) { - // With 10 columns the minimum sparsity for global inhibition is 10% - // Setting sparsity to 2% should throw an exception - SpatialPooler sp; - setup(sp, 10, 10); - sp.setGlobalInhibition(true); - sp.setLocalAreaDensity(0.02); - vector input(sp.getNumInputs(), 1); - vector out1(sp.getNumColumns(), 0); - EXPECT_THROW(sp.compute(input.data(), false, out1.data()), nupic::LoggingException); + for (auto &activeColumn : activeColumns) { + active[activeColumn] = 1; } - TEST(SpatialPoolerTest, testInhibitColumnsLocal) + ASSERT_TRUE(check_vector_eq(trueActive, active)); +} + +TEST(SpatialPoolerTest, testValidateGlobalInhibitionParameters) { + // With 10 columns the minimum sparsity for global inhibition is 10% + // Setting sparsity to 2% should throw an exception + SpatialPooler sp; + setup(sp, 10, 10); + sp.setGlobalInhibition(true); + sp.setLocalAreaDensity(0.02); + vector input(sp.getNumInputs(), 1); + vector out1(sp.getNumColumns(), 0); + EXPECT_THROW(sp.compute(input.data(), false, out1.data()), + nupic::LoggingException); +} + +TEST(SpatialPoolerTest, testInhibitColumnsLocal) { + // wrapAround = false { - // wrapAround = false - { - SpatialPooler sp( - /*inputDimensions*/{10}, + SpatialPooler sp( + /*inputDimensions*/ {10}, /*columnDimensions*/ {10}, /*potentialRadius*/ 16, /*potentialPct*/ 0.5, @@ -1498,54 +1388,54 @@ namespace { /*spVerbosity*/ 0, /*wrapAround*/ false); - Real density; - UInt inhibitionRadius; - - vector overlaps; - vector active; - - Real overlapsArray1[10] = { 1, 2, 7, 0, 3, 4, 16, 1, 1.5, 1.7}; - // L W W L L W W L L W - - inhibitionRadius = 2; - density = 0.5; - overlaps.assign(&overlapsArray1[0], &overlapsArray1[10]); - UInt trueActive[5] = {1, 2, 5, 6, 9}; - sp.setInhibitionRadius(inhibitionRadius); - sp.inhibitColumnsLocal_(overlaps, density, active); - ASSERT_EQ(5, active.size()); - ASSERT_TRUE(check_vector_eq(trueActive, active)); - - Real overlapsArray2[10] = {1, 2, 7, 0, 3, 4, 16, 1, 1.5, 1.7}; - // L W W L L W W L L W - overlaps.assign(&overlapsArray2[0], &overlapsArray2[10]); - UInt trueActive2[6] = {1, 2, 4, 5, 6, 9}; - inhibitionRadius = 3; - density = 0.5; - 
sp.setInhibitionRadius(inhibitionRadius); - sp.inhibitColumnsLocal_(overlaps, density, active); - ASSERT_TRUE(active.size() == 6); - ASSERT_TRUE(check_vector_eq(trueActive2, active)); - - // Test arbitration - - Real overlapsArray3[10] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - // W L W L W L W L L L - overlaps.assign(&overlapsArray3[0], &overlapsArray3[10]); - UInt trueActive3[4] = {0, 2, 4, 6}; - inhibitionRadius = 3; - density = 0.25; - sp.setInhibitionRadius(inhibitionRadius); - sp.inhibitColumnsLocal_(overlaps, density, active); - - ASSERT_TRUE(active.size() == 4); - ASSERT_TRUE(check_vector_eq(trueActive3, active)); - } + Real density; + UInt inhibitionRadius; + + vector overlaps; + vector active; + + Real overlapsArray1[10] = {1, 2, 7, 0, 3, 4, 16, 1, 1.5, 1.7}; + // L W W L L W W L L W - // wrapAround = true - { - SpatialPooler sp( - /*inputDimensions*/{10}, + inhibitionRadius = 2; + density = 0.5; + overlaps.assign(&overlapsArray1[0], &overlapsArray1[10]); + UInt trueActive[5] = {1, 2, 5, 6, 9}; + sp.setInhibitionRadius(inhibitionRadius); + sp.inhibitColumnsLocal_(overlaps, density, active); + ASSERT_EQ(5, active.size()); + ASSERT_TRUE(check_vector_eq(trueActive, active)); + + Real overlapsArray2[10] = {1, 2, 7, 0, 3, 4, 16, 1, 1.5, 1.7}; + // L W W L L W W L L W + overlaps.assign(&overlapsArray2[0], &overlapsArray2[10]); + UInt trueActive2[6] = {1, 2, 4, 5, 6, 9}; + inhibitionRadius = 3; + density = 0.5; + sp.setInhibitionRadius(inhibitionRadius); + sp.inhibitColumnsLocal_(overlaps, density, active); + ASSERT_TRUE(active.size() == 6); + ASSERT_TRUE(check_vector_eq(trueActive2, active)); + + // Test arbitration + + Real overlapsArray3[10] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + // W L W L W L W L L L + overlaps.assign(&overlapsArray3[0], &overlapsArray3[10]); + UInt trueActive3[4] = {0, 2, 4, 6}; + inhibitionRadius = 3; + density = 0.25; + sp.setInhibitionRadius(inhibitionRadius); + sp.inhibitColumnsLocal_(overlaps, density, active); + + ASSERT_TRUE(active.size() == 4); + ASSERT_TRUE(check_vector_eq(trueActive3, active)); + } + + // wrapAround = true + { + SpatialPooler sp( + /*inputDimensions*/ {10}, /*columnDimensions*/ {10}, /*potentialRadius*/ 16, /*potentialPct*/ 0.5, @@ -1563,730 +1453,671 @@ namespace { /*spVerbosity*/ 0, /*wrapAround*/ true); - Real density; - UInt inhibitionRadius; - - vector overlaps; - vector active; - - Real overlapsArray1[10] = { 1, 2, 7, 0, 3, 4, 16, 1, 1.5, 1.7}; - // L W W L L W W L W W - - inhibitionRadius = 2; - density = 0.5; - overlaps.assign(&overlapsArray1[0], &overlapsArray1[10]); - UInt trueActive[6] = {1, 2, 5, 6, 8, 9}; - sp.setInhibitionRadius(inhibitionRadius); - sp.inhibitColumnsLocal_(overlaps, density, active); - ASSERT_EQ(6, active.size()); - ASSERT_TRUE(check_vector_eq(trueActive, active)); - - Real overlapsArray2[10] = {1, 2, 7, 0, 3, 4, 16, 1, 1.5, 1.7}; - // L W W L W W W L L W - overlaps.assign(&overlapsArray2[0], &overlapsArray2[10]); - UInt trueActive2[6] = {1, 2, 4, 5, 6, 9}; - inhibitionRadius = 3; - density = 0.5; - sp.setInhibitionRadius(inhibitionRadius); - sp.inhibitColumnsLocal_(overlaps, density, active); - ASSERT_TRUE(active.size() == 6); - ASSERT_TRUE(check_vector_eq(trueActive2, active)); - - // Test arbitration - - Real overlapsArray3[10] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - // W W L L W W L L L W - overlaps.assign(&overlapsArray3[0], &overlapsArray3[10]); - UInt trueActive3[4] = {0, 1, 4, 5}; - inhibitionRadius = 3; - density = 0.25; - sp.setInhibitionRadius(inhibitionRadius); - sp.inhibitColumnsLocal_(overlaps, density, 
active); - - ASSERT_TRUE(active.size() == 4); - ASSERT_TRUE(check_vector_eq(trueActive3, active)); - } - } + Real density; + UInt inhibitionRadius; - TEST(SpatialPoolerTest, testIsUpdateRound) - { - SpatialPooler sp; - sp.setUpdatePeriod(50); - sp.setIterationNum(1); - ASSERT_TRUE(!sp.isUpdateRound_()); - sp.setIterationNum(39); - ASSERT_TRUE(!sp.isUpdateRound_()); - sp.setIterationNum(50); - ASSERT_TRUE(sp.isUpdateRound_()); - sp.setIterationNum(1009); - ASSERT_TRUE(!sp.isUpdateRound_()); - sp.setIterationNum(1250); - ASSERT_TRUE(sp.isUpdateRound_()); - - sp.setUpdatePeriod(125); - sp.setIterationNum(0); - ASSERT_TRUE(sp.isUpdateRound_()); - sp.setIterationNum(200); - ASSERT_TRUE(!sp.isUpdateRound_()); - sp.setIterationNum(249); - ASSERT_TRUE(!sp.isUpdateRound_()); - sp.setIterationNum(1330); - ASSERT_TRUE(!sp.isUpdateRound_()); - sp.setIterationNum(1249); - ASSERT_TRUE(!sp.isUpdateRound_()); - sp.setIterationNum(1375); - ASSERT_TRUE(sp.isUpdateRound_()); + vector overlaps; + vector active; - } + Real overlapsArray1[10] = {1, 2, 7, 0, 3, 4, 16, 1, 1.5, 1.7}; + // L W W L L W W L W W - TEST(SpatialPoolerTest, testRaisePermanencesToThreshold) - { - SpatialPooler sp; - UInt stimulusThreshold = 3; - Real synPermConnected = 0.1; - Real synPermBelowStimulusInc = 0.01; - UInt numInputs = 5; - UInt numColumns = 7; - setup(sp,numInputs,numColumns); - sp.setStimulusThreshold(stimulusThreshold); - sp.setSynPermConnected(synPermConnected); - sp.setSynPermBelowStimulusInc(synPermBelowStimulusInc); - - UInt potentialArr[7][5] = - {{ 1, 1, 1, 1, 1 }, - { 1, 1, 1, 1, 1 }, - { 1, 1, 1, 1, 1 }, - { 1, 1, 1, 1, 1 }, - { 1, 1, 1, 1, 1 }, - { 1, 1, 0, 0, 1 }, - { 0, 1, 1, 1, 0 }}; - - - Real permArr[7][5] = - {{ 0.0, 0.11, 0.095, 0.092, 0.01 }, - { 0.12, 0.15, 0.02, 0.12, 0.09 }, - { 0.51, 0.081, 0.025, 0.089, 0.31 }, - { 0.18, 0.0601, 0.11, 0.011, 0.03 }, - { 0.011, 0.011, 0.011, 0.011, 0.011 }, - { 0.12, 0.056, 0, 0, 0.078 }, - { 0, 0.061, 0.07, 0.14, 0 }}; - - Real truePerm[7][5] = - {{ 0.01, 0.12, 0.105, 0.102, 0.02 }, // incremented once - { 0.12, 0.15, 0.02, 0.12, 0.09 }, // no change - { 0.53, 0.101, 0.045, 0.109, 0.33 }, // increment twice - { 0.22, 0.1001, 0.15, 0.051, 0.07 }, // increment four times - { 0.101, 0.101, 0.101, 0.101, 0.101 }, // increment 9 times - { 0.17, 0.106, 0, 0, 0.128 }, // increment 5 times - { 0, 0.101, 0.11, 0.18, 0 }}; // increment 4 times - - - UInt trueConnectedCount[7] = - {3, 3, 4, 3, 5, 3, 3}; - - for (UInt i = 0; i < numColumns; i++) - { - vector perm; - vector potential; - perm.assign(&permArr[i][0],&permArr[i][numInputs]); - for (UInt j = 0; j < numInputs; j++) { - if (potentialArr[i][j] > 0) { - potential.push_back(j); - } - } - UInt connected = - sp.raisePermanencesToThreshold_(perm, potential); - ASSERT_TRUE(check_vector_eq(truePerm[i],perm)); - ASSERT_TRUE(connected == trueConnectedCount[i]); - } + inhibitionRadius = 2; + density = 0.5; + overlaps.assign(&overlapsArray1[0], &overlapsArray1[10]); + UInt trueActive[6] = {1, 2, 5, 6, 8, 9}; + sp.setInhibitionRadius(inhibitionRadius); + sp.inhibitColumnsLocal_(overlaps, density, active); + ASSERT_EQ(6, active.size()); + ASSERT_TRUE(check_vector_eq(trueActive, active)); + + Real overlapsArray2[10] = {1, 2, 7, 0, 3, 4, 16, 1, 1.5, 1.7}; + // L W W L W W W L L W + overlaps.assign(&overlapsArray2[0], &overlapsArray2[10]); + UInt trueActive2[6] = {1, 2, 4, 5, 6, 9}; + inhibitionRadius = 3; + density = 0.5; + sp.setInhibitionRadius(inhibitionRadius); + sp.inhibitColumnsLocal_(overlaps, density, active); + 
ASSERT_TRUE(active.size() == 6); + ASSERT_TRUE(check_vector_eq(trueActive2, active)); + + // Test arbitration + + Real overlapsArray3[10] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + // W W L L W W L L L W + overlaps.assign(&overlapsArray3[0], &overlapsArray3[10]); + UInt trueActive3[4] = {0, 1, 4, 5}; + inhibitionRadius = 3; + density = 0.25; + sp.setInhibitionRadius(inhibitionRadius); + sp.inhibitColumnsLocal_(overlaps, density, active); + ASSERT_TRUE(active.size() == 4); + ASSERT_TRUE(check_vector_eq(trueActive3, active)); } - - TEST(SpatialPoolerTest, testUpdatePermanencesForColumn) - { - vector inputDim; - vector columnDim; - - UInt numInputs = 5; - UInt numColumns = 5; - SpatialPooler sp; - setup(sp,numInputs,numColumns); - Real synPermTrimThreshold = 0.05; - sp.setSynPermTrimThreshold(synPermTrimThreshold); - - Real permArr[5][5] = - {{ -0.10, 0.500, 0.400, 0.010, 0.020 }, - { 0.300, 0.010, 0.020, 0.120, 0.090 }, - { 0.070, 0.050, 1.030, 0.190, 0.060 }, - { 0.180, 0.090, 0.110, 0.010, 0.030 }, - { 0.200, 0.101, 0.050, -0.09, 1.100 }}; - - Real truePerm[5][5] = - {{ 0.000, 0.500, 0.400, 0.000, 0.000}, - // Clip - - Trim Trim - {0.300, 0.000, 0.000, 0.120, 0.090}, - // - Trim Trim - - - {0.070, 0.050, 1.000, 0.190, 0.060}, - // - - Clip - - - {0.180, 0.090, 0.110, 0.000, 0.000}, - // - - - Trim Trim - {0.200, 0.101, 0.050, 0.000, 1.000}}; - // - - - Clip Clip - - UInt trueConnectedSynapses[5][5] = - {{0, 1, 1, 0, 0}, - {1, 0, 0, 1, 0}, - {0, 0, 1, 1, 0}, - {1, 0, 1, 0, 0}, - {1, 1, 0, 0, 1 }}; - - UInt trueConnectedCount[5] = {2, 2, 2, 2, 3}; - - for (UInt i = 0; i < 5; i ++) - { - vector perm(&permArr[i][0], &permArr[i][5]); - sp.updatePermanencesForColumn_(perm, i, false); - auto permArr = new Real[numInputs]; - auto connectedArr = new UInt[numInputs]; - auto connectedCountsArr = new UInt[numColumns]; - sp.getPermanence(i, permArr); - sp.getConnectedSynapses(i, connectedArr); - sp.getConnectedCounts(connectedCountsArr); - ASSERT_TRUE(check_vector_eq(truePerm[i], permArr, numInputs)); - ASSERT_TRUE(check_vector_eq(trueConnectedSynapses[i],connectedArr, - numInputs)); - ASSERT_TRUE(trueConnectedCount[i] == connectedCountsArr[i]); - delete[] permArr; - delete[] connectedArr; - delete[] connectedCountsArr; +} + +TEST(SpatialPoolerTest, testIsUpdateRound) { + SpatialPooler sp; + sp.setUpdatePeriod(50); + sp.setIterationNum(1); + ASSERT_TRUE(!sp.isUpdateRound_()); + sp.setIterationNum(39); + ASSERT_TRUE(!sp.isUpdateRound_()); + sp.setIterationNum(50); + ASSERT_TRUE(sp.isUpdateRound_()); + sp.setIterationNum(1009); + ASSERT_TRUE(!sp.isUpdateRound_()); + sp.setIterationNum(1250); + ASSERT_TRUE(sp.isUpdateRound_()); + + sp.setUpdatePeriod(125); + sp.setIterationNum(0); + ASSERT_TRUE(sp.isUpdateRound_()); + sp.setIterationNum(200); + ASSERT_TRUE(!sp.isUpdateRound_()); + sp.setIterationNum(249); + ASSERT_TRUE(!sp.isUpdateRound_()); + sp.setIterationNum(1330); + ASSERT_TRUE(!sp.isUpdateRound_()); + sp.setIterationNum(1249); + ASSERT_TRUE(!sp.isUpdateRound_()); + sp.setIterationNum(1375); + ASSERT_TRUE(sp.isUpdateRound_()); +} + +TEST(SpatialPoolerTest, testRaisePermanencesToThreshold) { + SpatialPooler sp; + UInt stimulusThreshold = 3; + Real synPermConnected = 0.1; + Real synPermBelowStimulusInc = 0.01; + UInt numInputs = 5; + UInt numColumns = 7; + setup(sp, numInputs, numColumns); + sp.setStimulusThreshold(stimulusThreshold); + sp.setSynPermConnected(synPermConnected); + sp.setSynPermBelowStimulusInc(synPermBelowStimulusInc); + + UInt potentialArr[7][5] = {{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 
1, 1}, + {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 0, 0, 1}, + {0, 1, 1, 1, 0}}; + + Real permArr[7][5] = {{0.0, 0.11, 0.095, 0.092, 0.01}, + {0.12, 0.15, 0.02, 0.12, 0.09}, + {0.51, 0.081, 0.025, 0.089, 0.31}, + {0.18, 0.0601, 0.11, 0.011, 0.03}, + {0.011, 0.011, 0.011, 0.011, 0.011}, + {0.12, 0.056, 0, 0, 0.078}, + {0, 0.061, 0.07, 0.14, 0}}; + + Real truePerm[7][5] = { + {0.01, 0.12, 0.105, 0.102, 0.02}, // incremented once + {0.12, 0.15, 0.02, 0.12, 0.09}, // no change + {0.53, 0.101, 0.045, 0.109, 0.33}, // increment twice + {0.22, 0.1001, 0.15, 0.051, 0.07}, // increment four times + {0.101, 0.101, 0.101, 0.101, 0.101}, // increment 9 times + {0.17, 0.106, 0, 0, 0.128}, // increment 5 times + {0, 0.101, 0.11, 0.18, 0}}; // increment 4 times + + UInt trueConnectedCount[7] = {3, 3, 4, 3, 5, 3, 3}; + + for (UInt i = 0; i < numColumns; i++) { + vector perm; + vector potential; + perm.assign(&permArr[i][0], &permArr[i][numInputs]); + for (UInt j = 0; j < numInputs; j++) { + if (potentialArr[i][j] > 0) { + potential.push_back(j); + } } - + UInt connected = sp.raisePermanencesToThreshold_(perm, potential); + ASSERT_TRUE(check_vector_eq(truePerm[i], perm)); + ASSERT_TRUE(connected == trueConnectedCount[i]); } - - TEST(SpatialPoolerTest, testInitPermanence) - { - vector inputDim; - vector columnDim; - inputDim.push_back(8); - columnDim.push_back(2); - - SpatialPooler sp; - Real synPermConnected = 0.2; - Real synPermTrimThreshold = 0.1; - Real synPermActiveInc = 0.05; - sp.initialize(inputDim,columnDim); - sp.setSynPermConnected(synPermConnected); - sp.setSynPermTrimThreshold(synPermTrimThreshold); - sp.setSynPermActiveInc(synPermActiveInc); - - UInt arr[8] = { 0, 1, 1 , 0, 0, 1, 0, 1 }; - vector potential(&arr[0], &arr[8]); - vector perm = sp.initPermanence_(potential, 1.0); - for (UInt i = 0; i < 8; i++) - if (potential[i]) - ASSERT_TRUE(perm[i] >= synPermConnected); - else - ASSERT_TRUE(perm[i] < 1e-5); - - perm = sp.initPermanence_(potential, 0); - for (UInt i = 0; i < 8; i++) - if (potential[i]) - ASSERT_LE(perm[i], synPermConnected); - else - ASSERT_LT(perm[i], 1e-5); - - inputDim[0] = 100; - sp.initialize(inputDim,columnDim); - sp.setSynPermConnected(synPermConnected); - sp.setSynPermTrimThreshold(synPermTrimThreshold); - sp.setSynPermActiveInc(synPermActiveInc); - potential.clear(); - - for(UInt i = 0; i < 100; i++) - potential.push_back(1); - - perm = sp.initPermanence_(potential, 0.5); - int count = 0; - for (UInt i = 0; i < 100; i++) - { - ASSERT_TRUE(perm[i] < 1e-5 || perm[i] >= synPermTrimThreshold); - if (perm[i] >= synPermConnected) - count++; - } - ASSERT_TRUE(count > 5 && count < 95); +} + +TEST(SpatialPoolerTest, testUpdatePermanencesForColumn) { + vector inputDim; + vector columnDim; + + UInt numInputs = 5; + UInt numColumns = 5; + SpatialPooler sp; + setup(sp, numInputs, numColumns); + Real synPermTrimThreshold = 0.05; + sp.setSynPermTrimThreshold(synPermTrimThreshold); + + Real permArr[5][5] = {{-0.10, 0.500, 0.400, 0.010, 0.020}, + {0.300, 0.010, 0.020, 0.120, 0.090}, + {0.070, 0.050, 1.030, 0.190, 0.060}, + {0.180, 0.090, 0.110, 0.010, 0.030}, + {0.200, 0.101, 0.050, -0.09, 1.100}}; + + Real truePerm[5][5] = {{0.000, 0.500, 0.400, 0.000, 0.000}, + // Clip - - Trim Trim + {0.300, 0.000, 0.000, 0.120, 0.090}, + // - Trim Trim - - + {0.070, 0.050, 1.000, 0.190, 0.060}, + // - - Clip - - + {0.180, 0.090, 0.110, 0.000, 0.000}, + // - - - Trim Trim + {0.200, 0.101, 0.050, 0.000, 1.000}}; + // - - - Clip Clip + + UInt trueConnectedSynapses[5][5] = {{0, 1, 1, 0, 0}, + {1, 0, 0, 
1, 0}, + {0, 0, 1, 1, 0}, + {1, 0, 1, 0, 0}, + {1, 1, 0, 0, 1}}; + + UInt trueConnectedCount[5] = {2, 2, 2, 2, 3}; + + for (UInt i = 0; i < 5; i++) { + vector perm(&permArr[i][0], &permArr[i][5]); + sp.updatePermanencesForColumn_(perm, i, false); + auto permArr = new Real[numInputs]; + auto connectedArr = new UInt[numInputs]; + auto connectedCountsArr = new UInt[numColumns]; + sp.getPermanence(i, permArr); + sp.getConnectedSynapses(i, connectedArr); + sp.getConnectedCounts(connectedCountsArr); + ASSERT_TRUE(check_vector_eq(truePerm[i], permArr, numInputs)); + ASSERT_TRUE( + check_vector_eq(trueConnectedSynapses[i], connectedArr, numInputs)); + ASSERT_TRUE(trueConnectedCount[i] == connectedCountsArr[i]); + delete[] permArr; + delete[] connectedArr; + delete[] connectedCountsArr; + } +} + +TEST(SpatialPoolerTest, testInitPermanence) { + vector inputDim; + vector columnDim; + inputDim.push_back(8); + columnDim.push_back(2); + + SpatialPooler sp; + Real synPermConnected = 0.2; + Real synPermTrimThreshold = 0.1; + Real synPermActiveInc = 0.05; + sp.initialize(inputDim, columnDim); + sp.setSynPermConnected(synPermConnected); + sp.setSynPermTrimThreshold(synPermTrimThreshold); + sp.setSynPermActiveInc(synPermActiveInc); + + UInt arr[8] = {0, 1, 1, 0, 0, 1, 0, 1}; + vector potential(&arr[0], &arr[8]); + vector perm = sp.initPermanence_(potential, 1.0); + for (UInt i = 0; i < 8; i++) + if (potential[i]) + ASSERT_TRUE(perm[i] >= synPermConnected); + else + ASSERT_TRUE(perm[i] < 1e-5); + + perm = sp.initPermanence_(potential, 0); + for (UInt i = 0; i < 8; i++) + if (potential[i]) + ASSERT_LE(perm[i], synPermConnected); + else + ASSERT_LT(perm[i], 1e-5); + + inputDim[0] = 100; + sp.initialize(inputDim, columnDim); + sp.setSynPermConnected(synPermConnected); + sp.setSynPermTrimThreshold(synPermTrimThreshold); + sp.setSynPermActiveInc(synPermActiveInc); + potential.clear(); + + for (UInt i = 0; i < 100; i++) + potential.push_back(1); + + perm = sp.initPermanence_(potential, 0.5); + int count = 0; + for (UInt i = 0; i < 100; i++) { + ASSERT_TRUE(perm[i] < 1e-5 || perm[i] >= synPermTrimThreshold); + if (perm[i] >= synPermConnected) + count++; } + ASSERT_TRUE(count > 5 && count < 95); +} - TEST(SpatialPoolerTest, testInitPermConnected) - { - SpatialPooler sp; - Real synPermConnected = 0.2; - Real synPermMax = 1.0; +TEST(SpatialPoolerTest, testInitPermConnected) { + SpatialPooler sp; + Real synPermConnected = 0.2; + Real synPermMax = 1.0; - sp.setSynPermConnected(synPermConnected); - sp.setSynPermMax(synPermMax); + sp.setSynPermConnected(synPermConnected); + sp.setSynPermMax(synPermMax); - for (UInt i = 0; i < 100; i++) { - Real permVal = sp.initPermConnected_(); - ASSERT_GE(permVal, synPermConnected); - ASSERT_LE(permVal, synPermMax); - } + for (UInt i = 0; i < 100; i++) { + Real permVal = sp.initPermConnected_(); + ASSERT_GE(permVal, synPermConnected); + ASSERT_LE(permVal, synPermMax); } +} + +TEST(SpatialPoolerTest, testInitPermNonConnected) { + SpatialPooler sp; + Real synPermConnected = 0.2; + sp.setSynPermConnected(synPermConnected); + for (UInt i = 0; i < 100; i++) { + Real permVal = sp.initPermNonConnected_(); + ASSERT_GE(permVal, 0); + ASSERT_LE(permVal, synPermConnected); + } +} - TEST(SpatialPoolerTest, testInitPermNonConnected) +TEST(SpatialPoolerTest, testMapColumn) { { - SpatialPooler sp; - Real synPermConnected = 0.2; - sp.setSynPermConnected(synPermConnected); - for (UInt i = 0; i < 100; i++) { - Real permVal = sp.initPermNonConnected_(); - ASSERT_GE(permVal, 0); - ASSERT_LE(permVal, 
synPermConnected); - } + // Test 1D. + SpatialPooler sp( + /*inputDimensions*/ {12}, + /*columnDimensions*/ {4}); + + EXPECT_EQ(1, sp.mapColumn_(0)); + EXPECT_EQ(4, sp.mapColumn_(1)); + EXPECT_EQ(7, sp.mapColumn_(2)); + EXPECT_EQ(10, sp.mapColumn_(3)); } - TEST(SpatialPoolerTest, testMapColumn) { - { - // Test 1D. - SpatialPooler sp( - /*inputDimensions*/{12}, - /*columnDimensions*/{4}); - - EXPECT_EQ(1, sp.mapColumn_(0)); - EXPECT_EQ(4, sp.mapColumn_(1)); - EXPECT_EQ(7, sp.mapColumn_(2)); - EXPECT_EQ(10, sp.mapColumn_(3)); - } - - { - // Test 1D with same dimensions of columns and inputs. - SpatialPooler sp( - /*inputDimensions*/{4}, - /*columnDimensions*/{4}); - - EXPECT_EQ(0, sp.mapColumn_(0)); - EXPECT_EQ(1, sp.mapColumn_(1)); - EXPECT_EQ(2, sp.mapColumn_(2)); - EXPECT_EQ(3, sp.mapColumn_(3)); - } - - { - // Test 1D with dimensions of length 1. - SpatialPooler sp( - /*inputDimensions*/{1}, - /*columnDimensions*/{1}); - - EXPECT_EQ(0, sp.mapColumn_(0)); - } - - { - // Test 2D. - SpatialPooler sp( - /*inputDimensions*/{36, 12}, - /*columnDimensions*/{12, 4}); - - EXPECT_EQ(13, sp.mapColumn_(0)); - EXPECT_EQ(49, sp.mapColumn_(4)); - EXPECT_EQ(52, sp.mapColumn_(5)); - EXPECT_EQ(58, sp.mapColumn_(7)); - EXPECT_EQ(418, sp.mapColumn_(47)); - } + // Test 1D with same dimensions of columns and inputs. + SpatialPooler sp( + /*inputDimensions*/ {4}, + /*columnDimensions*/ {4}); + + EXPECT_EQ(0, sp.mapColumn_(0)); + EXPECT_EQ(1, sp.mapColumn_(1)); + EXPECT_EQ(2, sp.mapColumn_(2)); + EXPECT_EQ(3, sp.mapColumn_(3)); + } - { - // Test 2D, some input dimensions smaller than column dimensions. - SpatialPooler sp( - /*inputDimensions*/{3, 5}, - /*columnDimensions*/{4, 4}); + { + // Test 1D with dimensions of length 1. + SpatialPooler sp( + /*inputDimensions*/ {1}, + /*columnDimensions*/ {1}); - EXPECT_EQ(0, sp.mapColumn_(0)); - EXPECT_EQ(4, sp.mapColumn_(3)); - EXPECT_EQ(14, sp.mapColumn_(15)); - } + EXPECT_EQ(0, sp.mapColumn_(0)); } - TEST(SpatialPoolerTest, testMapPotential1D) { - vector inputDim, columnDim; - inputDim.push_back(12); - columnDim.push_back(4); - UInt potentialRadius = 2; + // Test 2D. + SpatialPooler sp( + /*inputDimensions*/ {36, 12}, + /*columnDimensions*/ {12, 4}); + + EXPECT_EQ(13, sp.mapColumn_(0)); + EXPECT_EQ(49, sp.mapColumn_(4)); + EXPECT_EQ(52, sp.mapColumn_(5)); + EXPECT_EQ(58, sp.mapColumn_(7)); + EXPECT_EQ(418, sp.mapColumn_(47)); + } - SpatialPooler sp; - sp.initialize(inputDim, columnDim); - sp.setPotentialRadius(potentialRadius); + { + // Test 2D, some input dimensions smaller than column dimensions. 
+ SpatialPooler sp( + /*inputDimensions*/ {3, 5}, + /*columnDimensions*/ {4, 4}); + + EXPECT_EQ(0, sp.mapColumn_(0)); + EXPECT_EQ(4, sp.mapColumn_(3)); + EXPECT_EQ(14, sp.mapColumn_(15)); + } +} - vector mask; +TEST(SpatialPoolerTest, testMapPotential1D) { + vector inputDim, columnDim; + inputDim.push_back(12); + columnDim.push_back(4); + UInt potentialRadius = 2; - // Test without wrapAround and potentialPct = 1 - sp.setPotentialPct(1.0); + SpatialPooler sp; + sp.initialize(inputDim, columnDim); + sp.setPotentialRadius(potentialRadius); - UInt expectedMask1[12] = {1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}; - mask = sp.mapPotential_(0, false); - ASSERT_TRUE(check_vector_eq(expectedMask1, mask)); + vector mask; - UInt expectedMask2[12] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0}; - mask = sp.mapPotential_(2, false); - ASSERT_TRUE(check_vector_eq(expectedMask2, mask)); + // Test without wrapAround and potentialPct = 1 + sp.setPotentialPct(1.0); - // Test with wrapAround and potentialPct = 1 - sp.setPotentialPct(1.0); + UInt expectedMask1[12] = {1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}; + mask = sp.mapPotential_(0, false); + ASSERT_TRUE(check_vector_eq(expectedMask1, mask)); - UInt expectedMask3[12] = {1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1}; - mask = sp.mapPotential_(0, true); - ASSERT_TRUE(check_vector_eq(expectedMask3, mask)); + UInt expectedMask2[12] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0}; + mask = sp.mapPotential_(2, false); + ASSERT_TRUE(check_vector_eq(expectedMask2, mask)); - UInt expectedMask4[12] = {1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1}; - mask = sp.mapPotential_(3, true); - ASSERT_TRUE(check_vector_eq(expectedMask4, mask)); + // Test with wrapAround and potentialPct = 1 + sp.setPotentialPct(1.0); - // Test with potentialPct < 1 - sp.setPotentialPct(0.5); - UInt supersetMask1[12] = {1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1}; - mask = sp.mapPotential_(0, true); - ASSERT_TRUE(sum(mask) == 3); + UInt expectedMask3[12] = {1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1}; + mask = sp.mapPotential_(0, true); + ASSERT_TRUE(check_vector_eq(expectedMask3, mask)); - UInt unionMask1[12] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - for (UInt i = 0; i < 12; i++) { - unionMask1[i] = supersetMask1[i] | mask.at(i); - } + UInt expectedMask4[12] = {1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1}; + mask = sp.mapPotential_(3, true); + ASSERT_TRUE(check_vector_eq(expectedMask4, mask)); - ASSERT_TRUE(check_vector_eq(unionMask1, supersetMask1, 12)); - } + // Test with potentialPct < 1 + sp.setPotentialPct(0.5); + UInt supersetMask1[12] = {1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1}; + mask = sp.mapPotential_(0, true); + ASSERT_TRUE(sum(mask) == 3); - TEST(SpatialPoolerTest, testMapPotential2D) - { - vector inputDim, columnDim; - inputDim.push_back(6); - inputDim.push_back(12); - columnDim.push_back(2); - columnDim.push_back(4); - UInt potentialRadius = 1; - Real potentialPct = 1.0; - - SpatialPooler sp; - sp.initialize(inputDim, columnDim); - sp.setPotentialRadius(potentialRadius); - sp.setPotentialPct(potentialPct); - - vector mask; - - // Test without wrapAround - UInt expectedMask1[72] = { - 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - }; - mask = sp.mapPotential_(0, false); - ASSERT_TRUE(check_vector_eq(expectedMask1, mask)); - - UInt expectedMask2[72] = { - 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - }; - mask = sp.mapPotential_(2, false); - ASSERT_TRUE(check_vector_eq(expectedMask2, mask)); - - // Test with wrapAround - potentialRadius = 2; - sp.setPotentialRadius(potentialRadius); - UInt expectedMask3[72] = { - 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1 - }; - mask = sp.mapPotential_(0, true); - ASSERT_TRUE(check_vector_eq(expectedMask3, mask)); - - UInt expectedMask4[72] = { - 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, - 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, - 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, - 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1 - }; - mask = sp.mapPotential_(3, true); - ASSERT_TRUE(check_vector_eq(expectedMask4, mask)); + UInt unionMask1[12] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + for (UInt i = 0; i < 12; i++) { + unionMask1[i] = supersetMask1[i] | mask.at(i); } - TEST(SpatialPoolerTest, testStripUnlearnedColumns) + ASSERT_TRUE(check_vector_eq(unionMask1, supersetMask1, 12)); +} + +TEST(SpatialPoolerTest, testMapPotential2D) { + vector inputDim, columnDim; + inputDim.push_back(6); + inputDim.push_back(12); + columnDim.push_back(2); + columnDim.push_back(4); + UInt potentialRadius = 1; + Real potentialPct = 1.0; + + SpatialPooler sp; + sp.initialize(inputDim, columnDim); + sp.setPotentialRadius(potentialRadius); + sp.setPotentialPct(potentialPct); + + vector mask; + + // Test without wrapAround + UInt expectedMask1[72] = { + 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + mask = sp.mapPotential_(0, false); + ASSERT_TRUE(check_vector_eq(expectedMask1, mask)); + + UInt expectedMask2[72] = { + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + mask = sp.mapPotential_(2, false); + ASSERT_TRUE(check_vector_eq(expectedMask2, mask)); + + // Test with wrapAround + potentialRadius = 2; + sp.setPotentialRadius(potentialRadius); + UInt expectedMask3[72] = { + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1}; + mask = sp.mapPotential_(0, true); + ASSERT_TRUE(check_vector_eq(expectedMask3, mask)); + + UInt expectedMask4[72] = { + 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, + 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1}; + mask = sp.mapPotential_(3, true); + ASSERT_TRUE(check_vector_eq(expectedMask4, mask)); +} + +TEST(SpatialPoolerTest, testStripUnlearnedColumns) { + SpatialPooler sp; + vector inputDim, columnDim; + inputDim.push_back(5); + columnDim.push_back(3); + sp.initialize(inputDim, columnDim); + + // None learned, none active { - SpatialPooler sp; - vector inputDim, columnDim; - inputDim.push_back(5); - columnDim.push_back(3); - sp.initialize(inputDim, columnDim); - - // None learned, none active - { - Real 
activeDutyCycles[3] = {0, 0, 0}; - UInt activeArray[3] = {0, 0, 0}; - UInt expected[3] = {0, 0, 0}; - - sp.setActiveDutyCycles(activeDutyCycles); - sp.stripUnlearnedColumns(activeArray); - - ASSERT_TRUE(check_vector_eq(activeArray, expected, 3)); - } - - // None learned, some active - { - Real activeDutyCycles[3] = {0, 0, 0}; - UInt activeArray[3] = {1, 0, 1}; - UInt expected[3] = {0, 0, 0}; - - sp.setActiveDutyCycles(activeDutyCycles); - sp.stripUnlearnedColumns(activeArray); - - ASSERT_TRUE(check_vector_eq(activeArray, expected, 3)); - } - - // Some learned, none active - { - Real activeDutyCycles[3] = {1, 1, 0}; - UInt activeArray[3] = {0, 0, 0}; - UInt expected[3] = {0, 0, 0}; - - sp.setActiveDutyCycles(activeDutyCycles); - sp.stripUnlearnedColumns(activeArray); - - ASSERT_TRUE(check_vector_eq(activeArray, expected, 3)); - } - - // Some learned, some active - { - Real activeDutyCycles[3] = {1, 1, 0}; - UInt activeArray[3] = {1, 0, 1}; - UInt expected[3] = {1, 0, 0}; + Real activeDutyCycles[3] = {0, 0, 0}; + UInt activeArray[3] = {0, 0, 0}; + UInt expected[3] = {0, 0, 0}; - sp.setActiveDutyCycles(activeDutyCycles); - sp.stripUnlearnedColumns(activeArray); + sp.setActiveDutyCycles(activeDutyCycles); + sp.stripUnlearnedColumns(activeArray); - ASSERT_TRUE(check_vector_eq(activeArray, expected, 3)); - } + ASSERT_TRUE(check_vector_eq(activeArray, expected, 3)); } - TEST(SpatialPoolerTest, getOverlaps) + // None learned, some active { - SpatialPooler sp; - const vector inputDim = {5}; - const vector columnDim = {3}; - sp.initialize(inputDim, columnDim); - - UInt potential[5] = {1, 1, 1, 1, 1}; - sp.setPotential(0, potential); - sp.setPotential(1, potential); - sp.setPotential(2, potential); - - Real permanence0[5] = {0.0, 0.0, 0.0, 0.0, 0.0}; - sp.setPermanence(0, permanence0); - Real permanence1[5] = {1.0, 1.0, 1.0, 0.0, 0.0}; - sp.setPermanence(1, permanence1); - Real permanence2[5] = {1.0, 1.0, 1.0, 1.0, 1.0}; - sp.setPermanence(2, permanence2); - - vector boostFactors = {1.0, 2.0, 3.0}; - sp.setBoostFactors(boostFactors.data()); - - vector input = {1, 1, 1, 1, 1}; - vector activeColumns = {0, 0, 0}; - sp.compute(input.data(), true, activeColumns.data()); - - const vector& overlaps = sp.getOverlaps(); - const vector expectedOverlaps = {0, 3, 5}; - EXPECT_EQ(expectedOverlaps, overlaps); - - const vector& boostedOverlaps = sp.getBoostedOverlaps(); - const vector expectedBoostedOverlaps = {0.0, 6.0, 15.0}; - EXPECT_EQ(expectedBoostedOverlaps, boostedOverlaps); - } + Real activeDutyCycles[3] = {0, 0, 0}; + UInt activeArray[3] = {1, 0, 1}; + UInt expected[3] = {0, 0, 0}; - TEST(SpatialPoolerTest, ZeroOverlap_NoStimulusThreshold_GlobalInhibition) - { - const UInt inputSize = 10; - const UInt nColumns = 20; - - SpatialPooler sp({inputSize}, - {nColumns}, - /*potentialRadius*/ 10, - /*potentialPct*/ 0.5, - /*globalInhibition*/ true, - /*localAreaDensity*/ -1.0, - /*numActiveColumnsPerInhArea*/ 3, - /*stimulusThreshold*/ 0, - /*synPermInactiveDec*/ 0.008, - /*synPermActiveInc*/ 0.05, - /*synPermConnected*/ 0.1, - /*minPctOverlapDutyCycles*/ 0.001, - /*dutyCyclePeriod*/ 1000, - /*boostStrength*/ 10.0, - /*seed*/ 1, - /*spVerbosity*/ 0, - /*wrapAround*/ true); - - vector input(inputSize, 0); - vector activeColumns(nColumns, 0); - sp.compute(input.data(), true, activeColumns.data()); - - EXPECT_EQ(3, countNonzero(activeColumns)); - } + sp.setActiveDutyCycles(activeDutyCycles); + sp.stripUnlearnedColumns(activeArray); - TEST(SpatialPoolerTest, ZeroOverlap_StimulusThreshold_GlobalInhibition) - { - const 
UInt inputSize = 10; - const UInt nColumns = 20; - - SpatialPooler sp({inputSize}, - {nColumns}, - /*potentialRadius*/ 5, - /*potentialPct*/ 0.5, - /*globalInhibition*/ true, - /*localAreaDensity*/ -1.0, - /*numActiveColumnsPerInhArea*/ 1, - /*stimulusThreshold*/ 1, - /*synPermInactiveDec*/ 0.008, - /*synPermActiveInc*/ 0.05, - /*synPermConnected*/ 0.1, - /*minPctOverlapDutyCycles*/ 0.001, - /*dutyCyclePeriod*/ 1000, - /*boostStrength*/ 10.0, - /*seed*/ 1, - /*spVerbosity*/ 0, - /*wrapAround*/ true); - - vector input(inputSize, 0); - vector activeColumns(nColumns, 0); - sp.compute(input.data(), true, activeColumns.data()); - - EXPECT_EQ(0, countNonzero(activeColumns)); + ASSERT_TRUE(check_vector_eq(activeArray, expected, 3)); } - TEST(SpatialPoolerTest, ZeroOverlap_NoStimulusThreshold_LocalInhibition) + // Some learned, none active { - const UInt inputSize = 10; - const UInt nColumns = 20; - - SpatialPooler sp({inputSize}, - {nColumns}, - /*potentialRadius*/ 5, - /*potentialPct*/ 0.5, - /*globalInhibition*/ false, - /*localAreaDensity*/ -1.0, - /*numActiveColumnsPerInhArea*/ 1, - /*stimulusThreshold*/ 0, - /*synPermInactiveDec*/ 0.008, - /*synPermActiveInc*/ 0.05, - /*synPermConnected*/ 0.1, - /*minPctOverlapDutyCycles*/ 0.001, - /*dutyCyclePeriod*/ 1000, - /*boostStrength*/ 10.0, - /*seed*/ 1, - /*spVerbosity*/ 0, - /*wrapAround*/ true); - - vector input(inputSize, 0); - vector activeColumns(nColumns, 0); - sp.compute(input.data(), true, activeColumns.data()); - - // This exact number of active columns is determined by the inhibition - // radius, which changes based on the random synapses (i.e. weird math). - EXPECT_GT(countNonzero(activeColumns), 2); - EXPECT_LT(countNonzero(activeColumns), 10); - } + Real activeDutyCycles[3] = {1, 1, 0}; + UInt activeArray[3] = {0, 0, 0}; + UInt expected[3] = {0, 0, 0}; - TEST(SpatialPoolerTest, ZeroOverlap_StimulusThreshold_LocalInhibition) - { - const UInt inputSize = 10; - const UInt nColumns = 20; - - SpatialPooler sp({inputSize}, - {nColumns}, - /*potentialRadius*/ 10, - /*potentialPct*/ 0.5, - /*globalInhibition*/ false, - /*localAreaDensity*/ -1.0, - /*numActiveColumnsPerInhArea*/ 3, - /*stimulusThreshold*/ 1, - /*synPermInactiveDec*/ 0.008, - /*synPermActiveInc*/ 0.05, - /*synPermConnected*/ 0.1, - /*minPctOverlapDutyCycles*/ 0.001, - /*dutyCyclePeriod*/ 1000, - /*boostStrength*/ 10.0, - /*seed*/ 1, - /*spVerbosity*/ 0, - /*wrapAround*/ true); - - vector input(inputSize, 0); - vector activeColumns(nColumns, 0); - sp.compute(input.data(), true, activeColumns.data()); - - EXPECT_EQ(0, countNonzero(activeColumns)); - } + sp.setActiveDutyCycles(activeDutyCycles); + sp.stripUnlearnedColumns(activeArray); - TEST(SpatialPoolerTest, testSaveLoad) - { - const char* filename = "SpatialPoolerSerialization.tmp"; - SpatialPooler sp1, sp2; - UInt numInputs = 6; - UInt numColumns = 12; - setup(sp1, numInputs, numColumns); - - ofstream outfile; - outfile.open(filename); - sp1.save(outfile); - outfile.close(); - - ifstream infile (filename); - sp2.load(infile); - infile.close(); - - ASSERT_NO_FATAL_FAILURE( - check_spatial_eq(sp1, sp2)); - - int ret = ::remove(filename); - ASSERT_TRUE(ret == 0) << "Failed to delete " << filename; + ASSERT_TRUE(check_vector_eq(activeArray, expected, 3)); } - TEST(SpatialPoolerTest, testWriteRead) + // Some learned, some active { - const char* filename = "SpatialPoolerSerialization.tmp"; - SpatialPooler sp1, sp2; - UInt numInputs = 6; - UInt numColumns = 12; - setup(sp1, numInputs, numColumns); - - ofstream os(filename, 
ios::binary); - sp1.write(os); - os.close(); + Real activeDutyCycles[3] = {1, 1, 0}; + UInt activeArray[3] = {1, 0, 1}; + UInt expected[3] = {1, 0, 0}; - ifstream is(filename, ios::binary); - sp2.read(is); - is.close(); + sp.setActiveDutyCycles(activeDutyCycles); + sp.stripUnlearnedColumns(activeArray); - ASSERT_NO_FATAL_FAILURE( - check_spatial_eq(sp1, sp2)); - - int ret = ::remove(filename); - ASSERT_TRUE(ret == 0) << "Failed to delete " << filename; + ASSERT_TRUE(check_vector_eq(activeArray, expected, 3)); } - - TEST(SpatialPoolerTest, testConstructorVsInitialize) - { - // Initialize SP using the constructor - SpatialPooler sp1( - /*inputDimensions*/{100}, - /*columnDimensions*/{100}, +} + +TEST(SpatialPoolerTest, getOverlaps) { + SpatialPooler sp; + const vector inputDim = {5}; + const vector columnDim = {3}; + sp.initialize(inputDim, columnDim); + + UInt potential[5] = {1, 1, 1, 1, 1}; + sp.setPotential(0, potential); + sp.setPotential(1, potential); + sp.setPotential(2, potential); + + Real permanence0[5] = {0.0, 0.0, 0.0, 0.0, 0.0}; + sp.setPermanence(0, permanence0); + Real permanence1[5] = {1.0, 1.0, 1.0, 0.0, 0.0}; + sp.setPermanence(1, permanence1); + Real permanence2[5] = {1.0, 1.0, 1.0, 1.0, 1.0}; + sp.setPermanence(2, permanence2); + + vector boostFactors = {1.0, 2.0, 3.0}; + sp.setBoostFactors(boostFactors.data()); + + vector input = {1, 1, 1, 1, 1}; + vector activeColumns = {0, 0, 0}; + sp.compute(input.data(), true, activeColumns.data()); + + const vector &overlaps = sp.getOverlaps(); + const vector expectedOverlaps = {0, 3, 5}; + EXPECT_EQ(expectedOverlaps, overlaps); + + const vector &boostedOverlaps = sp.getBoostedOverlaps(); + const vector expectedBoostedOverlaps = {0.0, 6.0, 15.0}; + EXPECT_EQ(expectedBoostedOverlaps, boostedOverlaps); +} + +TEST(SpatialPoolerTest, ZeroOverlap_NoStimulusThreshold_GlobalInhibition) { + const UInt inputSize = 10; + const UInt nColumns = 20; + + SpatialPooler sp({inputSize}, {nColumns}, + /*potentialRadius*/ 10, + /*potentialPct*/ 0.5, + /*globalInhibition*/ true, + /*localAreaDensity*/ -1.0, + /*numActiveColumnsPerInhArea*/ 3, + /*stimulusThreshold*/ 0, + /*synPermInactiveDec*/ 0.008, + /*synPermActiveInc*/ 0.05, + /*synPermConnected*/ 0.1, + /*minPctOverlapDutyCycles*/ 0.001, + /*dutyCyclePeriod*/ 1000, + /*boostStrength*/ 10.0, + /*seed*/ 1, + /*spVerbosity*/ 0, + /*wrapAround*/ true); + + vector input(inputSize, 0); + vector activeColumns(nColumns, 0); + sp.compute(input.data(), true, activeColumns.data()); + + EXPECT_EQ(3, countNonzero(activeColumns)); +} + +TEST(SpatialPoolerTest, ZeroOverlap_StimulusThreshold_GlobalInhibition) { + const UInt inputSize = 10; + const UInt nColumns = 20; + + SpatialPooler sp({inputSize}, {nColumns}, + /*potentialRadius*/ 5, + /*potentialPct*/ 0.5, + /*globalInhibition*/ true, + /*localAreaDensity*/ -1.0, + /*numActiveColumnsPerInhArea*/ 1, + /*stimulusThreshold*/ 1, + /*synPermInactiveDec*/ 0.008, + /*synPermActiveInc*/ 0.05, + /*synPermConnected*/ 0.1, + /*minPctOverlapDutyCycles*/ 0.001, + /*dutyCyclePeriod*/ 1000, + /*boostStrength*/ 10.0, + /*seed*/ 1, + /*spVerbosity*/ 0, + /*wrapAround*/ true); + + vector input(inputSize, 0); + vector activeColumns(nColumns, 0); + sp.compute(input.data(), true, activeColumns.data()); + + EXPECT_EQ(0, countNonzero(activeColumns)); +} + +TEST(SpatialPoolerTest, ZeroOverlap_NoStimulusThreshold_LocalInhibition) { + const UInt inputSize = 10; + const UInt nColumns = 20; + + SpatialPooler sp({inputSize}, {nColumns}, + /*potentialRadius*/ 5, + /*potentialPct*/ 0.5, + 
/*globalInhibition*/ false, + /*localAreaDensity*/ -1.0, + /*numActiveColumnsPerInhArea*/ 1, + /*stimulusThreshold*/ 0, + /*synPermInactiveDec*/ 0.008, + /*synPermActiveInc*/ 0.05, + /*synPermConnected*/ 0.1, + /*minPctOverlapDutyCycles*/ 0.001, + /*dutyCyclePeriod*/ 1000, + /*boostStrength*/ 10.0, + /*seed*/ 1, + /*spVerbosity*/ 0, + /*wrapAround*/ true); + + vector input(inputSize, 0); + vector activeColumns(nColumns, 0); + sp.compute(input.data(), true, activeColumns.data()); + + // This exact number of active columns is determined by the inhibition + // radius, which changes based on the random synapses (i.e. weird math). + EXPECT_GT(countNonzero(activeColumns), 2); + EXPECT_LT(countNonzero(activeColumns), 10); +} + +TEST(SpatialPoolerTest, ZeroOverlap_StimulusThreshold_LocalInhibition) { + const UInt inputSize = 10; + const UInt nColumns = 20; + + SpatialPooler sp({inputSize}, {nColumns}, + /*potentialRadius*/ 10, + /*potentialPct*/ 0.5, + /*globalInhibition*/ false, + /*localAreaDensity*/ -1.0, + /*numActiveColumnsPerInhArea*/ 3, + /*stimulusThreshold*/ 1, + /*synPermInactiveDec*/ 0.008, + /*synPermActiveInc*/ 0.05, + /*synPermConnected*/ 0.1, + /*minPctOverlapDutyCycles*/ 0.001, + /*dutyCyclePeriod*/ 1000, + /*boostStrength*/ 10.0, + /*seed*/ 1, + /*spVerbosity*/ 0, + /*wrapAround*/ true); + + vector input(inputSize, 0); + vector activeColumns(nColumns, 0); + sp.compute(input.data(), true, activeColumns.data()); + + EXPECT_EQ(0, countNonzero(activeColumns)); +} + +TEST(SpatialPoolerTest, testSaveLoad) { + const char *filename = "SpatialPoolerSerialization.tmp"; + SpatialPooler sp1, sp2; + UInt numInputs = 6; + UInt numColumns = 12; + setup(sp1, numInputs, numColumns); + + ofstream outfile; + outfile.open(filename); + sp1.save(outfile); + outfile.close(); + + ifstream infile(filename); + sp2.load(infile); + infile.close(); + + ASSERT_NO_FATAL_FAILURE(check_spatial_eq(sp1, sp2)); + + int ret = ::remove(filename); + ASSERT_TRUE(ret == 0) << "Failed to delete " << filename; +} + +TEST(SpatialPoolerTest, testWriteRead) { + const char *filename = "SpatialPoolerSerialization.tmp"; + SpatialPooler sp1, sp2; + UInt numInputs = 6; + UInt numColumns = 12; + setup(sp1, numInputs, numColumns); + + ofstream os(filename, ios::binary); + sp1.write(os); + os.close(); + + ifstream is(filename, ios::binary); + sp2.read(is); + is.close(); + + ASSERT_NO_FATAL_FAILURE(check_spatial_eq(sp1, sp2)); + + int ret = ::remove(filename); + ASSERT_TRUE(ret == 0) << "Failed to delete " << filename; +} + +TEST(SpatialPoolerTest, testConstructorVsInitialize) { + // Initialize SP using the constructor + SpatialPooler sp1( + /*inputDimensions*/ {100}, + /*columnDimensions*/ {100}, /*potentialRadius*/ 16, /*potentialPct*/ 0.5, /*globalInhibition*/ true, @@ -2303,11 +2134,11 @@ namespace { /*spVerbosity*/ 0, /*wrapAround*/ true); - // Initialize SP using the "initialize" method - SpatialPooler sp2; - sp2.initialize( - /*inputDimensions*/{100}, - /*columnDimensions*/{100}, + // Initialize SP using the "initialize" method + SpatialPooler sp2; + sp2.initialize( + /*inputDimensions*/ {100}, + /*columnDimensions*/ {100}, /*potentialRadius*/ 16, /*potentialPct*/ 0.5, /*globalInhibition*/ true, @@ -2324,8 +2155,8 @@ namespace { /*spVerbosity*/ 0, /*wrapAround*/ true); - // The two SP should be the same - check_spatial_eq(sp1, sp2); - } + // The two SP should be the same + check_spatial_eq(sp1, sp2); +} } // end anonymous namespace diff --git a/src/test/unit/algorithms/TemporalMemoryTest.cpp 
b/src/test/unit/algorithms/TemporalMemoryTest.cpp index af70bd66be..e0750626c8 100644 --- a/src/test/unit/algorithms/TemporalMemoryTest.cpp +++ b/src/test/unit/algorithms/TemporalMemoryTest.cpp @@ -26,13 +26,13 @@ #include #include -#include #include #include #include +#include -#include #include "gtest/gtest.h" +#include using namespace nupic::algorithms::temporal_memory; using namespace std; @@ -41,24 +41,22 @@ using namespace std; namespace { - TEST(TemporalMemoryTest, testInitInvalidParams) - { - // Invalid columnDimensions - vector columnDim = {}; - TemporalMemory tm1; - EXPECT_THROW(tm1.initialize(columnDim, 32), exception); - - // Invalid cellsPerColumn - columnDim.push_back(2048); - EXPECT_THROW(tm1.initialize(columnDim, 0), exception); - } +TEST(TemporalMemoryTest, testInitInvalidParams) { + // Invalid columnDimensions + vector columnDim = {}; + TemporalMemory tm1; + EXPECT_THROW(tm1.initialize(columnDim, 32), exception); - /** - * If you call compute with unsorted input, it should throw an exception. - */ - TEST(TemporalMemoryTest, testCheckInputs_UnsortedColumns) - { - TemporalMemory tm( + // Invalid cellsPerColumn + columnDim.push_back(2048); + EXPECT_THROW(tm1.initialize(columnDim, 0), exception); +} + +/** + * If you call compute with unsorted input, it should throw an exception. + */ +TEST(TemporalMemoryTest, testCheckInputs_UnsortedColumns) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -69,21 +67,19 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); + /*seed*/ 42); - const UInt activeColumns[4] = {1, 3, 2, 4}; + const UInt activeColumns[4] = {1, 3, 2, 4}; - EXPECT_THROW(tm.compute(4, activeColumns), exception); - } + EXPECT_THROW(tm.compute(4, activeColumns), exception); +} - /** - * If you call compute with a binary vector rather than a list of indices, it - * should throw an exception. - */ - TEST(TemporalMemoryTest, testCheckInputs_BinaryArray) - { - TemporalMemory tm( +/** + * If you call compute with a binary vector rather than a list of indices, it + * should throw an exception. + */ +TEST(TemporalMemoryTest, testCheckInputs_BinaryArray) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -94,22 +90,20 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); + /*seed*/ 42); - // Use an input that will pass an `is_sorted` check. - const UInt activeColumns[5] = {0, 0, 0, 1, 1}; + // Use an input that will pass an `is_sorted` check. + const UInt activeColumns[5] = {0, 0, 0, 1, 1}; - EXPECT_THROW(tm.compute(5, activeColumns), exception); - } + EXPECT_THROW(tm.compute(5, activeColumns), exception); +} - /** - * When a predicted column is activated, only the predicted cells in the - * columns should be activated. - */ - TEST(TemporalMemoryTest, ActivateCorrectlyPredictiveCells) - { - TemporalMemory tm( +/** + * When a predicted column is activated, only the predicted cells in the + * columns should be activated. 
+ */ +TEST(TemporalMemoryTest, ActivateCorrectlyPredictiveCells) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -120,36 +114,33 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); - - const UInt numActiveColumns = 1; - const UInt previousActiveColumns[1] = {0}; - const UInt activeColumns[1] = {1}; - const vector previousActiveCells = {0, 1, 2, 3}; - const vector expectedActiveCells = {4}; - - Segment activeSegment = - tm.createSegment(expectedActiveCells[0]); - tm.connections.createSynapse(activeSegment, previousActiveCells[0], 0.5); - tm.connections.createSynapse(activeSegment, previousActiveCells[1], 0.5); - tm.connections.createSynapse(activeSegment, previousActiveCells[2], 0.5); - tm.connections.createSynapse(activeSegment, previousActiveCells[3], 0.5); - - tm.compute(numActiveColumns, previousActiveColumns, true); - ASSERT_EQ(expectedActiveCells, tm.getPredictiveCells()); - tm.compute(numActiveColumns, activeColumns, true); - - EXPECT_EQ(expectedActiveCells, tm.getActiveCells()); - } - - /** - * When an unpredicted column is activated, every cell in the column should - * become active. - */ - TEST(TemporalMemoryTest, BurstUnpredictedColumns) - { - TemporalMemory tm( + /*seed*/ 42); + + const UInt numActiveColumns = 1; + const UInt previousActiveColumns[1] = {0}; + const UInt activeColumns[1] = {1}; + const vector previousActiveCells = {0, 1, 2, 3}; + const vector expectedActiveCells = {4}; + + Segment activeSegment = tm.createSegment(expectedActiveCells[0]); + tm.connections.createSynapse(activeSegment, previousActiveCells[0], 0.5); + tm.connections.createSynapse(activeSegment, previousActiveCells[1], 0.5); + tm.connections.createSynapse(activeSegment, previousActiveCells[2], 0.5); + tm.connections.createSynapse(activeSegment, previousActiveCells[3], 0.5); + + tm.compute(numActiveColumns, previousActiveColumns, true); + ASSERT_EQ(expectedActiveCells, tm.getPredictiveCells()); + tm.compute(numActiveColumns, activeColumns, true); + + EXPECT_EQ(expectedActiveCells, tm.getActiveCells()); +} + +/** + * When an unpredicted column is activated, every cell in the column should + * become active. + */ +TEST(TemporalMemoryTest, BurstUnpredictedColumns) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -160,25 +151,23 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); + /*seed*/ 42); - const UInt activeColumns[1] = {0}; - const vector burstingCells = {0, 1, 2, 3}; + const UInt activeColumns[1] = {0}; + const vector burstingCells = {0, 1, 2, 3}; - tm.compute(1, activeColumns, true); + tm.compute(1, activeColumns, true); - EXPECT_EQ(burstingCells, tm.getActiveCells()); - } + EXPECT_EQ(burstingCells, tm.getActiveCells()); +} - /** - * When the TemporalMemory receives zero active columns, it should still - * compute the active cells, winner cells, and predictive cells. All should be - * empty. - */ - TEST(TemporalMemoryTest, ZeroActiveColumns) - { - TemporalMemory tm( +/** + * When the TemporalMemory receives zero active columns, it should still + * compute the active cells, winner cells, and predictive cells. All should be + * empty. 
+ */ +TEST(TemporalMemoryTest, ZeroActiveColumns) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -189,40 +178,38 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.02, - /*seed*/ 42 - ); - - // Make some cells predictive. - const UInt previousActiveColumns[1] = {0}; - const vector previousActiveCells = {0, 1, 2, 3}; - const vector expectedActiveCells = {4}; - - Segment segment = tm.createSegment(expectedActiveCells[0]); - tm.connections.createSynapse(segment, previousActiveCells[0], 0.5); - tm.connections.createSynapse(segment, previousActiveCells[1], 0.5); - tm.connections.createSynapse(segment, previousActiveCells[2], 0.5); - tm.connections.createSynapse(segment, previousActiveCells[3], 0.5); - - tm.compute(1, previousActiveColumns, true); - ASSERT_FALSE(tm.getActiveCells().empty()); - ASSERT_FALSE(tm.getWinnerCells().empty()); - ASSERT_FALSE(tm.getPredictiveCells().empty()); - - const UInt zeroColumns[0] = {}; - tm.compute(0, zeroColumns, true); - - EXPECT_TRUE(tm.getActiveCells().empty()); - EXPECT_TRUE(tm.getWinnerCells().empty()); - EXPECT_TRUE(tm.getPredictiveCells().empty()); - } - - /** - * All predicted active cells are winner cells, even when learning is - * disabled. - */ - TEST(TemporalMemoryTest, PredictedActiveCellsAreAlwaysWinners) - { - TemporalMemory tm( + /*seed*/ 42); + + // Make some cells predictive. + const UInt previousActiveColumns[1] = {0}; + const vector previousActiveCells = {0, 1, 2, 3}; + const vector expectedActiveCells = {4}; + + Segment segment = tm.createSegment(expectedActiveCells[0]); + tm.connections.createSynapse(segment, previousActiveCells[0], 0.5); + tm.connections.createSynapse(segment, previousActiveCells[1], 0.5); + tm.connections.createSynapse(segment, previousActiveCells[2], 0.5); + tm.connections.createSynapse(segment, previousActiveCells[3], 0.5); + + tm.compute(1, previousActiveColumns, true); + ASSERT_FALSE(tm.getActiveCells().empty()); + ASSERT_FALSE(tm.getWinnerCells().empty()); + ASSERT_FALSE(tm.getPredictiveCells().empty()); + + const UInt zeroColumns[0] = {}; + tm.compute(0, zeroColumns, true); + + EXPECT_TRUE(tm.getActiveCells().empty()); + EXPECT_TRUE(tm.getWinnerCells().empty()); + EXPECT_TRUE(tm.getPredictiveCells().empty()); +} + +/** + * All predicted active cells are winner cells, even when learning is + * disabled. 
+ */ +TEST(TemporalMemoryTest, PredictedActiveCellsAreAlwaysWinners) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -233,40 +220,36 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); - - const UInt numActiveColumns = 1; - const UInt previousActiveColumns[1] = {0}; - const UInt activeColumns[1] = {1}; - const vector previousActiveCells = {0, 1, 2, 3}; - const vector expectedWinnerCells = {4, 6}; - - Segment activeSegment1 = - tm.createSegment(expectedWinnerCells[0]); - tm.connections.createSynapse(activeSegment1, previousActiveCells[0], 0.5); - tm.connections.createSynapse(activeSegment1, previousActiveCells[1], 0.5); - tm.connections.createSynapse(activeSegment1, previousActiveCells[2], 0.5); - - Segment activeSegment2 = - tm.createSegment(expectedWinnerCells[1]); - tm.connections.createSynapse(activeSegment2, previousActiveCells[0], 0.5); - tm.connections.createSynapse(activeSegment2, previousActiveCells[1], 0.5); - tm.connections.createSynapse(activeSegment2, previousActiveCells[2], 0.5); - - tm.compute(numActiveColumns, previousActiveColumns, false); - tm.compute(numActiveColumns, activeColumns, false); - - EXPECT_EQ(expectedWinnerCells, tm.getWinnerCells()); - } + /*seed*/ 42); - /** - * One cell in each bursting column is a winner cell, even when learning is - * disabled. - */ - TEST(TemporalMemoryTest, ChooseOneWinnerCellInBurstingColumn) - { - TemporalMemory tm( + const UInt numActiveColumns = 1; + const UInt previousActiveColumns[1] = {0}; + const UInt activeColumns[1] = {1}; + const vector previousActiveCells = {0, 1, 2, 3}; + const vector expectedWinnerCells = {4, 6}; + + Segment activeSegment1 = tm.createSegment(expectedWinnerCells[0]); + tm.connections.createSynapse(activeSegment1, previousActiveCells[0], 0.5); + tm.connections.createSynapse(activeSegment1, previousActiveCells[1], 0.5); + tm.connections.createSynapse(activeSegment1, previousActiveCells[2], 0.5); + + Segment activeSegment2 = tm.createSegment(expectedWinnerCells[1]); + tm.connections.createSynapse(activeSegment2, previousActiveCells[0], 0.5); + tm.connections.createSynapse(activeSegment2, previousActiveCells[1], 0.5); + tm.connections.createSynapse(activeSegment2, previousActiveCells[2], 0.5); + + tm.compute(numActiveColumns, previousActiveColumns, false); + tm.compute(numActiveColumns, activeColumns, false); + + EXPECT_EQ(expectedWinnerCells, tm.getWinnerCells()); +} + +/** + * One cell in each bursting column is a winner cell, even when learning is + * disabled. 
+ */ +TEST(TemporalMemoryTest, ChooseOneWinnerCellInBurstingColumn) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -277,26 +260,24 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); + /*seed*/ 42); - const UInt activeColumns[1] = {0}; - const set burstingCells = {0, 1, 2, 3}; + const UInt activeColumns[1] = {0}; + const set burstingCells = {0, 1, 2, 3}; - tm.compute(1, activeColumns, false); + tm.compute(1, activeColumns, false); - vector winnerCells = tm.getWinnerCells(); - ASSERT_EQ(1, winnerCells.size()); - EXPECT_TRUE(burstingCells.find(winnerCells[0]) != burstingCells.end()); - } + vector winnerCells = tm.getWinnerCells(); + ASSERT_EQ(1, winnerCells.size()); + EXPECT_TRUE(burstingCells.find(winnerCells[0]) != burstingCells.end()); +} - /** - * Active segments on predicted active cells should be reinforced. Active - * synapses should be reinforced, inactive synapses should be punished. - */ - TEST(TemporalMemoryTest, ReinforceCorrectlyActiveSegments) - { - TemporalMemory tm( +/** + * Active segments on predicted active cells should be reinforced. Active + * synapses should be reinforced, inactive synapses should be punished. + */ +TEST(TemporalMemoryTest, ReinforceCorrectlyActiveSegments) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -307,46 +288,44 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.08, /*predictedSegmentDecrement*/ 0.02, - /*seed*/ 42 - ); - - const UInt numActiveColumns = 1; - const UInt previousActiveColumns[1] = {0}; - const vector previousActiveCells = {0, 1, 2, 3}; - const UInt activeColumns[1] = {1}; - const vector activeCells = {5}; - const CellIdx activeCell = 5; - - Segment activeSegment = tm.createSegment(activeCell); - Synapse activeSynapse1 = + /*seed*/ 42); + + const UInt numActiveColumns = 1; + const UInt previousActiveColumns[1] = {0}; + const vector previousActiveCells = {0, 1, 2, 3}; + const UInt activeColumns[1] = {1}; + const vector activeCells = {5}; + const CellIdx activeCell = 5; + + Segment activeSegment = tm.createSegment(activeCell); + Synapse activeSynapse1 = tm.connections.createSynapse(activeSegment, previousActiveCells[0], 0.5); - Synapse activeSynapse2 = + Synapse activeSynapse2 = tm.connections.createSynapse(activeSegment, previousActiveCells[1], 0.5); - Synapse activeSynapse3 = + Synapse activeSynapse3 = tm.connections.createSynapse(activeSegment, previousActiveCells[2], 0.5); - Synapse inactiveSynapse = + Synapse inactiveSynapse = tm.connections.createSynapse(activeSegment, 81, 0.5); - tm.compute(numActiveColumns, previousActiveColumns, true); - tm.compute(numActiveColumns, activeColumns, true); - - EXPECT_NEAR(0.6, tm.connections.dataForSynapse(activeSynapse1).permanence, - EPSILON); - EXPECT_NEAR(0.6, tm.connections.dataForSynapse(activeSynapse2).permanence, - EPSILON); - EXPECT_NEAR(0.6, tm.connections.dataForSynapse(activeSynapse3).permanence, - EPSILON); - EXPECT_NEAR(0.42, tm.connections.dataForSynapse(inactiveSynapse).permanence, - EPSILON); - } - - /** - * The best matching segment in a bursting column should be reinforced. Active - * synapses should be strengthened, and inactive synapses should be weakened. 
- */ - TEST(TemporalMemoryTest, ReinforceSelectedMatchingSegmentInBurstingColumn) - { - TemporalMemory tm( + tm.compute(numActiveColumns, previousActiveColumns, true); + tm.compute(numActiveColumns, activeColumns, true); + + EXPECT_NEAR(0.6, tm.connections.dataForSynapse(activeSynapse1).permanence, + EPSILON); + EXPECT_NEAR(0.6, tm.connections.dataForSynapse(activeSynapse2).permanence, + EPSILON); + EXPECT_NEAR(0.6, tm.connections.dataForSynapse(activeSynapse3).permanence, + EPSILON); + EXPECT_NEAR(0.42, tm.connections.dataForSynapse(inactiveSynapse).permanence, + EPSILON); +} + +/** + * The best matching segment in a bursting column should be reinforced. Active + * synapses should be strengthened, and inactive synapses should be weakened. + */ +TEST(TemporalMemoryTest, ReinforceSelectedMatchingSegmentInBurstingColumn) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -357,60 +336,52 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.08, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); - - const UInt numActiveColumns = 1; - const UInt previousActiveColumns[1] = {0}; - const UInt activeColumns[1] = {1}; - const vector previousActiveCells = {0, 1, 2, 3}; - const vector burstingCells = {4, 5, 6, 7}; - - Segment selectedMatchingSegment = - tm.createSegment(burstingCells[0]); - Synapse activeSynapse1 = - tm.connections.createSynapse(selectedMatchingSegment, - previousActiveCells[0], 0.3); - Synapse activeSynapse2 = - tm.connections.createSynapse(selectedMatchingSegment, - previousActiveCells[1], 0.3); - Synapse activeSynapse3 = - tm.connections.createSynapse(selectedMatchingSegment, - previousActiveCells[2], 0.3); - Synapse inactiveSynapse = - tm.connections.createSynapse(selectedMatchingSegment, - 81, 0.3); - - // Add some competition. - Segment otherMatchingSegment = - tm.createSegment(burstingCells[1]); - tm.connections.createSynapse(otherMatchingSegment, - previousActiveCells[0], 0.3); - tm.connections.createSynapse(otherMatchingSegment, - previousActiveCells[1], 0.3); - tm.connections.createSynapse(otherMatchingSegment, - 81, 0.3); - - tm.compute(numActiveColumns, previousActiveColumns, true); - tm.compute(numActiveColumns, activeColumns, true); - - EXPECT_NEAR(0.4, tm.connections.dataForSynapse(activeSynapse1).permanence, - EPSILON); - EXPECT_NEAR(0.4, tm.connections.dataForSynapse(activeSynapse2).permanence, - EPSILON); - EXPECT_NEAR(0.4, tm.connections.dataForSynapse(activeSynapse3).permanence, - EPSILON); - EXPECT_NEAR(0.22, tm.connections.dataForSynapse(inactiveSynapse).permanence, - EPSILON); - } - - /** - * When a column bursts, don't reward or punish matching-but-not-selected - * segments. 
- */ - TEST(TemporalMemoryTest, NoChangeToNonselectedMatchingSegmentsInBurstingColumn) - { - TemporalMemory tm( + /*seed*/ 42); + + const UInt numActiveColumns = 1; + const UInt previousActiveColumns[1] = {0}; + const UInt activeColumns[1] = {1}; + const vector previousActiveCells = {0, 1, 2, 3}; + const vector burstingCells = {4, 5, 6, 7}; + + Segment selectedMatchingSegment = tm.createSegment(burstingCells[0]); + Synapse activeSynapse1 = tm.connections.createSynapse( + selectedMatchingSegment, previousActiveCells[0], 0.3); + Synapse activeSynapse2 = tm.connections.createSynapse( + selectedMatchingSegment, previousActiveCells[1], 0.3); + Synapse activeSynapse3 = tm.connections.createSynapse( + selectedMatchingSegment, previousActiveCells[2], 0.3); + Synapse inactiveSynapse = + tm.connections.createSynapse(selectedMatchingSegment, 81, 0.3); + + // Add some competition. + Segment otherMatchingSegment = tm.createSegment(burstingCells[1]); + tm.connections.createSynapse(otherMatchingSegment, previousActiveCells[0], + 0.3); + tm.connections.createSynapse(otherMatchingSegment, previousActiveCells[1], + 0.3); + tm.connections.createSynapse(otherMatchingSegment, 81, 0.3); + + tm.compute(numActiveColumns, previousActiveColumns, true); + tm.compute(numActiveColumns, activeColumns, true); + + EXPECT_NEAR(0.4, tm.connections.dataForSynapse(activeSynapse1).permanence, + EPSILON); + EXPECT_NEAR(0.4, tm.connections.dataForSynapse(activeSynapse2).permanence, + EPSILON); + EXPECT_NEAR(0.4, tm.connections.dataForSynapse(activeSynapse3).permanence, + EPSILON); + EXPECT_NEAR(0.22, tm.connections.dataForSynapse(inactiveSynapse).permanence, + EPSILON); +} + +/** + * When a column bursts, don't reward or punish matching-but-not-selected + * segments. + */ +TEST(TemporalMemoryTest, + NoChangeToNonselectedMatchingSegmentsInBurstingColumn) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -421,55 +392,47 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.08, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); - - const UInt previousActiveColumns[1] = {0}; - const UInt activeColumns[1] = {1}; - const vector previousActiveCells = {0, 1, 2, 3}; - const vector burstingCells = {4, 5, 6, 7}; - - Segment selectedMatchingSegment = - tm.createSegment(burstingCells[0]); - tm.connections.createSynapse(selectedMatchingSegment, - previousActiveCells[0], 0.3); - tm.connections.createSynapse(selectedMatchingSegment, - previousActiveCells[1], 0.3); - tm.connections.createSynapse(selectedMatchingSegment, - previousActiveCells[2], 0.3); - tm.connections.createSynapse(selectedMatchingSegment, - 81, 0.3); - - Segment otherMatchingSegment = - tm.createSegment(burstingCells[1]); - Synapse activeSynapse1 = - tm.connections.createSynapse(otherMatchingSegment, - previousActiveCells[0], 0.3); - Synapse activeSynapse2 = - tm.connections.createSynapse(otherMatchingSegment, - previousActiveCells[1], 0.3); - Synapse inactiveSynapse = - tm.connections.createSynapse(otherMatchingSegment, - 81, 0.3); - - tm.compute(1, previousActiveColumns, true); - tm.compute(1, activeColumns, true); - - EXPECT_NEAR(0.3, tm.connections.dataForSynapse(activeSynapse1).permanence, - EPSILON); - EXPECT_NEAR(0.3, tm.connections.dataForSynapse(activeSynapse2).permanence, - EPSILON); - EXPECT_NEAR(0.3, tm.connections.dataForSynapse(inactiveSynapse).permanence, - EPSILON); - } - - /** - * When a predicted column is activated, don't reward or punish - * matching-but-not-active segments anywhere in 
the column. - */ - TEST(TemporalMemoryTest, NoChangeToMatchingSegmentsInPredictedActiveColumn) - { - TemporalMemory tm( + /*seed*/ 42); + + const UInt previousActiveColumns[1] = {0}; + const UInt activeColumns[1] = {1}; + const vector previousActiveCells = {0, 1, 2, 3}; + const vector burstingCells = {4, 5, 6, 7}; + + Segment selectedMatchingSegment = tm.createSegment(burstingCells[0]); + tm.connections.createSynapse(selectedMatchingSegment, previousActiveCells[0], + 0.3); + tm.connections.createSynapse(selectedMatchingSegment, previousActiveCells[1], + 0.3); + tm.connections.createSynapse(selectedMatchingSegment, previousActiveCells[2], + 0.3); + tm.connections.createSynapse(selectedMatchingSegment, 81, 0.3); + + Segment otherMatchingSegment = tm.createSegment(burstingCells[1]); + Synapse activeSynapse1 = tm.connections.createSynapse( + otherMatchingSegment, previousActiveCells[0], 0.3); + Synapse activeSynapse2 = tm.connections.createSynapse( + otherMatchingSegment, previousActiveCells[1], 0.3); + Synapse inactiveSynapse = + tm.connections.createSynapse(otherMatchingSegment, 81, 0.3); + + tm.compute(1, previousActiveColumns, true); + tm.compute(1, activeColumns, true); + + EXPECT_NEAR(0.3, tm.connections.dataForSynapse(activeSynapse1).permanence, + EPSILON); + EXPECT_NEAR(0.3, tm.connections.dataForSynapse(activeSynapse2).permanence, + EPSILON); + EXPECT_NEAR(0.3, tm.connections.dataForSynapse(inactiveSynapse).permanence, + EPSILON); +} + +/** + * When a predicted column is activated, don't reward or punish + * matching-but-not-active segments anywhere in the column. + */ +TEST(TemporalMemoryTest, NoChangeToMatchingSegmentsInPredictedActiveColumn) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -480,61 +443,48 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); - - const UInt previousActiveColumns[1] = {0}; - const UInt activeColumns[1] = {1}; - const vector previousActiveCells = {0, 1, 2, 3}; - const vector expectedActiveCells = {4}; - const vector otherBurstingCells = {5, 6, 7}; - - Segment activeSegment = - tm.createSegment(expectedActiveCells[0]); - tm.connections.createSynapse(activeSegment, previousActiveCells[0], 0.5); - tm.connections.createSynapse(activeSegment, previousActiveCells[1], 0.5); - tm.connections.createSynapse(activeSegment, previousActiveCells[2], 0.5); - tm.connections.createSynapse(activeSegment, previousActiveCells[3], 0.5); - - Segment matchingSegmentOnSameCell = - tm.createSegment(expectedActiveCells[0]); - Synapse synapse1 = - tm.connections.createSynapse(matchingSegmentOnSameCell, - previousActiveCells[0], 0.3); - Synapse synapse2 = - tm.connections.createSynapse(matchingSegmentOnSameCell, - previousActiveCells[1], 0.3); - - Segment matchingSegmentOnOtherCell = - tm.createSegment(otherBurstingCells[0]); - Synapse synapse3 = - tm.connections.createSynapse(matchingSegmentOnOtherCell, - previousActiveCells[0], 0.3); - Synapse synapse4 = - tm.connections.createSynapse(matchingSegmentOnOtherCell, - previousActiveCells[1], 0.3); - - tm.compute(1, previousActiveColumns, true); - ASSERT_EQ(expectedActiveCells, tm.getPredictiveCells()); - tm.compute(1, activeColumns, true); - - EXPECT_NEAR(0.3, tm.connections.dataForSynapse(synapse1).permanence, - EPSILON); - EXPECT_NEAR(0.3, tm.connections.dataForSynapse(synapse2).permanence, - EPSILON); - EXPECT_NEAR(0.3, tm.connections.dataForSynapse(synapse3).permanence, - EPSILON); - EXPECT_NEAR(0.3, 
tm.connections.dataForSynapse(synapse4).permanence, - EPSILON); - } - - /** - * When growing a new segment, if there are no previous winner cells, don't - * even grow the segment. It will never match. - */ - TEST(TemporalMemoryTest, NoNewSegmentIfNotEnoughWinnerCells) - { - TemporalMemory tm( + /*seed*/ 42); + + const UInt previousActiveColumns[1] = {0}; + const UInt activeColumns[1] = {1}; + const vector previousActiveCells = {0, 1, 2, 3}; + const vector expectedActiveCells = {4}; + const vector otherBurstingCells = {5, 6, 7}; + + Segment activeSegment = tm.createSegment(expectedActiveCells[0]); + tm.connections.createSynapse(activeSegment, previousActiveCells[0], 0.5); + tm.connections.createSynapse(activeSegment, previousActiveCells[1], 0.5); + tm.connections.createSynapse(activeSegment, previousActiveCells[2], 0.5); + tm.connections.createSynapse(activeSegment, previousActiveCells[3], 0.5); + + Segment matchingSegmentOnSameCell = tm.createSegment(expectedActiveCells[0]); + Synapse synapse1 = tm.connections.createSynapse(matchingSegmentOnSameCell, + previousActiveCells[0], 0.3); + Synapse synapse2 = tm.connections.createSynapse(matchingSegmentOnSameCell, + previousActiveCells[1], 0.3); + + Segment matchingSegmentOnOtherCell = tm.createSegment(otherBurstingCells[0]); + Synapse synapse3 = tm.connections.createSynapse(matchingSegmentOnOtherCell, + previousActiveCells[0], 0.3); + Synapse synapse4 = tm.connections.createSynapse(matchingSegmentOnOtherCell, + previousActiveCells[1], 0.3); + + tm.compute(1, previousActiveColumns, true); + ASSERT_EQ(expectedActiveCells, tm.getPredictiveCells()); + tm.compute(1, activeColumns, true); + + EXPECT_NEAR(0.3, tm.connections.dataForSynapse(synapse1).permanence, EPSILON); + EXPECT_NEAR(0.3, tm.connections.dataForSynapse(synapse2).permanence, EPSILON); + EXPECT_NEAR(0.3, tm.connections.dataForSynapse(synapse3).permanence, EPSILON); + EXPECT_NEAR(0.3, tm.connections.dataForSynapse(synapse4).permanence, EPSILON); +} + +/** + * When growing a new segment, if there are no previous winner cells, don't + * even grow the segment. It will never match. + */ +TEST(TemporalMemoryTest, NoNewSegmentIfNotEnoughWinnerCells) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -545,25 +495,23 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); + /*seed*/ 42); - const UInt zeroColumns[0] = {}; - const UInt activeColumns[1] = {0}; + const UInt zeroColumns[0] = {}; + const UInt activeColumns[1] = {0}; - tm.compute(0, zeroColumns); - tm.compute(1, activeColumns); + tm.compute(0, zeroColumns); + tm.compute(1, activeColumns); - EXPECT_EQ(0, tm.connections.numSegments()); - } + EXPECT_EQ(0, tm.connections.numSegments()); +} - /** - * When growing a new segment, if the number of previous winner cells is above - * maxNewSynapseCount, grow maxNewSynapseCount synapses. - */ - TEST(TemporalMemoryTest, NewSegmentAddSynapsesToSubsetOfWinnerCells) - { - TemporalMemory tm( +/** + * When growing a new segment, if the number of previous winner cells is above + * maxNewSynapseCount, grow maxNewSynapseCount synapses. 
+ */ +TEST(TemporalMemoryTest, NewSegmentAddSynapsesToSubsetOfWinnerCells) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -574,43 +522,39 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); + /*seed*/ 42); - const UInt previousActiveColumns[3] = {0, 1, 2}; - const UInt activeColumns[1] = {4}; + const UInt previousActiveColumns[3] = {0, 1, 2}; + const UInt activeColumns[1] = {4}; - tm.compute(3, previousActiveColumns); + tm.compute(3, previousActiveColumns); - vector prevWinnerCells = tm.getWinnerCells(); - ASSERT_EQ(3, prevWinnerCells.size()); + vector prevWinnerCells = tm.getWinnerCells(); + ASSERT_EQ(3, prevWinnerCells.size()); - tm.compute(1, activeColumns); - - vector winnerCells = tm.getWinnerCells(); - ASSERT_EQ(1, winnerCells.size()); - vector segments = tm.connections.segmentsForCell(winnerCells[0]); - ASSERT_EQ(1, segments.size()); - vector synapses = tm.connections.synapsesForSegment(segments[0]); - ASSERT_EQ(2, synapses.size()); - for (Synapse synapse : synapses) - { - SynapseData synapseData = tm.connections.dataForSynapse(synapse); - EXPECT_NEAR(0.21, synapseData.permanence, EPSILON); - EXPECT_TRUE(synapseData.presynapticCell == prevWinnerCells[0] || - synapseData.presynapticCell == prevWinnerCells[1] || - synapseData.presynapticCell == prevWinnerCells[2]); - } + tm.compute(1, activeColumns); + vector winnerCells = tm.getWinnerCells(); + ASSERT_EQ(1, winnerCells.size()); + vector segments = tm.connections.segmentsForCell(winnerCells[0]); + ASSERT_EQ(1, segments.size()); + vector synapses = tm.connections.synapsesForSegment(segments[0]); + ASSERT_EQ(2, synapses.size()); + for (Synapse synapse : synapses) { + SynapseData synapseData = tm.connections.dataForSynapse(synapse); + EXPECT_NEAR(0.21, synapseData.permanence, EPSILON); + EXPECT_TRUE(synapseData.presynapticCell == prevWinnerCells[0] || + synapseData.presynapticCell == prevWinnerCells[1] || + synapseData.presynapticCell == prevWinnerCells[2]); } +} - /** - * When growing a new segment, if the number of previous winner cells is below - * maxNewSynapseCount, grow synapses to all of the previous winner cells. - */ - TEST(TemporalMemoryTest, NewSegmentAddSynapsesToAllWinnerCells) - { - TemporalMemory tm( +/** + * When growing a new segment, if the number of previous winner cells is below + * maxNewSynapseCount, grow synapses to all of the previous winner cells. 
+ */ +TEST(TemporalMemoryTest, NewSegmentAddSynapsesToAllWinnerCells) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -621,45 +565,42 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); + /*seed*/ 42); - const UInt previousActiveColumns[3] = {0, 1, 2}; - const UInt activeColumns[1] = {4}; + const UInt previousActiveColumns[3] = {0, 1, 2}; + const UInt activeColumns[1] = {4}; - tm.compute(3, previousActiveColumns); + tm.compute(3, previousActiveColumns); - vector prevWinnerCells = tm.getWinnerCells(); - ASSERT_EQ(3, prevWinnerCells.size()); + vector prevWinnerCells = tm.getWinnerCells(); + ASSERT_EQ(3, prevWinnerCells.size()); - tm.compute(1, activeColumns); + tm.compute(1, activeColumns); - vector winnerCells = tm.getWinnerCells(); - ASSERT_EQ(1, winnerCells.size()); - vector segments = tm.connections.segmentsForCell(winnerCells[0]); - ASSERT_EQ(1, segments.size()); - vector synapses = tm.connections.synapsesForSegment(segments[0]); - ASSERT_EQ(3, synapses.size()); + vector winnerCells = tm.getWinnerCells(); + ASSERT_EQ(1, winnerCells.size()); + vector segments = tm.connections.segmentsForCell(winnerCells[0]); + ASSERT_EQ(1, segments.size()); + vector synapses = tm.connections.synapsesForSegment(segments[0]); + ASSERT_EQ(3, synapses.size()); - vector presynapticCells; - for (Synapse synapse : synapses) - { - SynapseData synapseData = tm.connections.dataForSynapse(synapse); - EXPECT_NEAR(0.21, synapseData.permanence, EPSILON); - presynapticCells.push_back(synapseData.presynapticCell); - } - std::sort(presynapticCells.begin(), presynapticCells.end()); - EXPECT_EQ(prevWinnerCells, presynapticCells); + vector presynapticCells; + for (Synapse synapse : synapses) { + SynapseData synapseData = tm.connections.dataForSynapse(synapse); + EXPECT_NEAR(0.21, synapseData.permanence, EPSILON); + presynapticCells.push_back(synapseData.presynapticCell); } - - /** - * When adding synapses to a matching segment, the final number of active - * synapses on the segment should be maxNewSynapseCount, assuming there are - * enough previous winner cells available to connect to. - */ - TEST(TemporalMemoryTest, MatchingSegmentAddSynapsesToSubsetOfWinnerCells) - { - TemporalMemory tm( + std::sort(presynapticCells.begin(), presynapticCells.end()); + EXPECT_EQ(prevWinnerCells, presynapticCells); +} + +/** + * When adding synapses to a matching segment, the final number of active + * synapses on the segment should be maxNewSynapseCount, assuming there are + * enough previous winner cells available to connect to. + */ +TEST(TemporalMemoryTest, MatchingSegmentAddSynapsesToSubsetOfWinnerCells) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 1, /*activationThreshold*/ 3, @@ -670,43 +611,40 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); + /*seed*/ 42); - // Use 1 cell per column so that we have easy control over the winner cells. - const UInt previousActiveColumns[4] = {0, 1, 2, 3}; - const vector prevWinnerCells = {0, 1, 2, 3}; - const UInt activeColumns[1] = {4}; + // Use 1 cell per column so that we have easy control over the winner cells. 
+ const UInt previousActiveColumns[4] = {0, 1, 2, 3}; + const vector prevWinnerCells = {0, 1, 2, 3}; + const UInt activeColumns[1] = {4}; - Segment matchingSegment = tm.createSegment(4); - tm.connections.createSynapse(matchingSegment, 0, 0.5); + Segment matchingSegment = tm.createSegment(4); + tm.connections.createSynapse(matchingSegment, 0, 0.5); - tm.compute(4, previousActiveColumns); + tm.compute(4, previousActiveColumns); - ASSERT_EQ(prevWinnerCells, tm.getWinnerCells()); + ASSERT_EQ(prevWinnerCells, tm.getWinnerCells()); - tm.compute(1, activeColumns); + tm.compute(1, activeColumns); - vector synapses = tm.connections.synapsesForSegment(matchingSegment); - ASSERT_EQ(3, synapses.size()); - for (SynapseIdx i = 1; i < synapses.size(); i++) - { - SynapseData synapseData = tm.connections.dataForSynapse(synapses[i]); - EXPECT_NEAR(0.21, synapseData.permanence, EPSILON); - EXPECT_TRUE(synapseData.presynapticCell == prevWinnerCells[1] || - synapseData.presynapticCell == prevWinnerCells[2] || - synapseData.presynapticCell == prevWinnerCells[3]); - } + vector synapses = tm.connections.synapsesForSegment(matchingSegment); + ASSERT_EQ(3, synapses.size()); + for (SynapseIdx i = 1; i < synapses.size(); i++) { + SynapseData synapseData = tm.connections.dataForSynapse(synapses[i]); + EXPECT_NEAR(0.21, synapseData.permanence, EPSILON); + EXPECT_TRUE(synapseData.presynapticCell == prevWinnerCells[1] || + synapseData.presynapticCell == prevWinnerCells[2] || + synapseData.presynapticCell == prevWinnerCells[3]); } +} - /** - * When adding synapses to a matching segment, if the number of previous - * winner cells is lower than (maxNewSynapseCount - nActiveSynapsesOnSegment), - * grow synapses to all the previous winner cells. - */ - TEST(TemporalMemoryTest, MatchingSegmentAddSynapsesToAllWinnerCells) - { - TemporalMemory tm( +/** + * When adding synapses to a matching segment, if the number of previous + * winner cells is lower than (maxNewSynapseCount - nActiveSynapsesOnSegment), + * grow synapses to all the previous winner cells. + */ +TEST(TemporalMemoryTest, MatchingSegmentAddSynapsesToAllWinnerCells) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 1, /*activationThreshold*/ 3, @@ -717,40 +655,38 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); + /*seed*/ 42); - // Use 1 cell per column so that we have easy control over the winner cells. - const UInt previousActiveColumns[2] = {0, 1}; - const vector prevWinnerCells = {0, 1}; - const UInt activeColumns[1] = {4}; + // Use 1 cell per column so that we have easy control over the winner cells. 
+ const UInt previousActiveColumns[2] = {0, 1}; + const vector prevWinnerCells = {0, 1}; + const UInt activeColumns[1] = {4}; - Segment matchingSegment = tm.createSegment(4); - tm.connections.createSynapse(matchingSegment, 0, 0.5); + Segment matchingSegment = tm.createSegment(4); + tm.connections.createSynapse(matchingSegment, 0, 0.5); - tm.compute(2, previousActiveColumns); + tm.compute(2, previousActiveColumns); - ASSERT_EQ(prevWinnerCells, tm.getWinnerCells()); + ASSERT_EQ(prevWinnerCells, tm.getWinnerCells()); - tm.compute(1, activeColumns); + tm.compute(1, activeColumns); - vector synapses = tm.connections.synapsesForSegment(matchingSegment); - ASSERT_EQ(2, synapses.size()); + vector synapses = tm.connections.synapsesForSegment(matchingSegment); + ASSERT_EQ(2, synapses.size()); - SynapseData synapseData = tm.connections.dataForSynapse(synapses[1]); - EXPECT_NEAR(0.21, synapseData.permanence, EPSILON); - EXPECT_EQ(prevWinnerCells[1], synapseData.presynapticCell); - } + SynapseData synapseData = tm.connections.dataForSynapse(synapses[1]); + EXPECT_NEAR(0.21, synapseData.permanence, EPSILON); + EXPECT_EQ(prevWinnerCells[1], synapseData.presynapticCell); +} - /** - * When a segment becomes active, grow synapses to previous winner cells. - * - * The number of grown synapses is calculated from the "matching segment" - * overlap, not the "active segment" overlap. - */ - TEST(TemporalMemoryTest, ActiveSegmentGrowSynapsesAccordingToPotentialOverlap) - { - TemporalMemory tm( +/** + * When a segment becomes active, grow synapses to previous winner cells. + * + * The number of grown synapses is calculated from the "matching segment" + * overlap, not the "active segment" overlap. + */ +TEST(TemporalMemoryTest, ActiveSegmentGrowSynapsesAccordingToPotentialOverlap) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 1, /*activationThreshold*/ 2, @@ -761,42 +697,40 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); + /*seed*/ 42); - // Use 1 cell per column so that we have easy control over the winner cells. - const UInt previousActiveColumns[5] = {0, 1, 2, 3, 4}; - const vector prevWinnerCells = {0, 1, 2, 3, 4}; - const UInt activeColumns[1] = {5}; + // Use 1 cell per column so that we have easy control over the winner cells. 
+ const UInt previousActiveColumns[5] = {0, 1, 2, 3, 4}; + const vector prevWinnerCells = {0, 1, 2, 3, 4}; + const UInt activeColumns[1] = {5}; - Segment activeSegment = tm.createSegment(5); - tm.connections.createSynapse(activeSegment, 0, 0.5); - tm.connections.createSynapse(activeSegment, 1, 0.5); - tm.connections.createSynapse(activeSegment, 2, 0.2); + Segment activeSegment = tm.createSegment(5); + tm.connections.createSynapse(activeSegment, 0, 0.5); + tm.connections.createSynapse(activeSegment, 1, 0.5); + tm.connections.createSynapse(activeSegment, 2, 0.2); - tm.compute(5, previousActiveColumns); + tm.compute(5, previousActiveColumns); - ASSERT_EQ(prevWinnerCells, tm.getWinnerCells()); + ASSERT_EQ(prevWinnerCells, tm.getWinnerCells()); - tm.compute(1, activeColumns); + tm.compute(1, activeColumns); - vector synapses = tm.connections.synapsesForSegment(activeSegment); + vector synapses = tm.connections.synapsesForSegment(activeSegment); - ASSERT_EQ(4, synapses.size()); + ASSERT_EQ(4, synapses.size()); - SynapseData synapseData = tm.connections.dataForSynapse(synapses[3]); - EXPECT_NEAR(0.21, synapseData.permanence, EPSILON); - EXPECT_TRUE(synapseData.presynapticCell == prevWinnerCells[3] || - synapseData.presynapticCell == prevWinnerCells[4]); - } + SynapseData synapseData = tm.connections.dataForSynapse(synapses[3]); + EXPECT_NEAR(0.21, synapseData.permanence, EPSILON); + EXPECT_TRUE(synapseData.presynapticCell == prevWinnerCells[3] || + synapseData.presynapticCell == prevWinnerCells[4]); +} - /** - * When a synapse is punished for contributing to a wrong prediction, if its - * permanence falls to 0 it should be destroyed. - */ - TEST(TemporalMemoryTest, DestroyWeakSynapseOnWrongPrediction) - { - TemporalMemory tm( +/** + * When a synapse is punished for contributing to a wrong prediction, if its + * permanence falls to 0 it should be destroyed. + */ +TEST(TemporalMemoryTest, DestroyWeakSynapseOnWrongPrediction) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -807,36 +741,34 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.02, - /*seed*/ 42 - ); + /*seed*/ 42); - const UInt numActiveColumns = 1; - const UInt previousActiveColumns[1] = {0}; - const vector previousActiveCells = {0, 1, 2, 3}; - const UInt activeColumns[1] = {2}; - const CellIdx expectedActiveCell = 5; + const UInt numActiveColumns = 1; + const UInt previousActiveColumns[1] = {0}; + const vector previousActiveCells = {0, 1, 2, 3}; + const UInt activeColumns[1] = {2}; + const CellIdx expectedActiveCell = 5; - Segment activeSegment = tm.createSegment(expectedActiveCell); - tm.connections.createSynapse(activeSegment, previousActiveCells[0], 0.5); - tm.connections.createSynapse(activeSegment, previousActiveCells[1], 0.5); - tm.connections.createSynapse(activeSegment, previousActiveCells[2], 0.5); + Segment activeSegment = tm.createSegment(expectedActiveCell); + tm.connections.createSynapse(activeSegment, previousActiveCells[0], 0.5); + tm.connections.createSynapse(activeSegment, previousActiveCells[1], 0.5); + tm.connections.createSynapse(activeSegment, previousActiveCells[2], 0.5); - // Weak synapse. - tm.connections.createSynapse(activeSegment, previousActiveCells[3], 0.015); + // Weak synapse. 
+ tm.connections.createSynapse(activeSegment, previousActiveCells[3], 0.015); - tm.compute(numActiveColumns, previousActiveColumns, true); - tm.compute(numActiveColumns, activeColumns, true); + tm.compute(numActiveColumns, previousActiveColumns, true); + tm.compute(numActiveColumns, activeColumns, true); - EXPECT_EQ(3, tm.connections.numSynapses(activeSegment)); - } + EXPECT_EQ(3, tm.connections.numSynapses(activeSegment)); +} - /** - * When a synapse is punished for not contributing to a right prediction, if - * its permanence falls to 0 it should be destroyed. - */ - TEST(TemporalMemoryTest, DestroyWeakSynapseOnActiveReinforce) - { - TemporalMemory tm( +/** + * When a synapse is punished for not contributing to a right prediction, if + * its permanence falls to 0 it should be destroyed. + */ +TEST(TemporalMemoryTest, DestroyWeakSynapseOnActiveReinforce) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -847,36 +779,34 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.02, - /*seed*/ 42 - ); + /*seed*/ 42); - const UInt numActiveColumns = 1; - const UInt previousActiveColumns[1] = {0}; - const vector previousActiveCells = {0, 1, 2, 3}; - const UInt activeColumns[1] = {1}; - const CellIdx activeCell = 5; + const UInt numActiveColumns = 1; + const UInt previousActiveColumns[1] = {0}; + const vector previousActiveCells = {0, 1, 2, 3}; + const UInt activeColumns[1] = {1}; + const CellIdx activeCell = 5; - Segment activeSegment = tm.createSegment(activeCell); - tm.connections.createSynapse(activeSegment, previousActiveCells[0], 0.5); - tm.connections.createSynapse(activeSegment, previousActiveCells[1], 0.5); - tm.connections.createSynapse(activeSegment, previousActiveCells[2], 0.5); + Segment activeSegment = tm.createSegment(activeCell); + tm.connections.createSynapse(activeSegment, previousActiveCells[0], 0.5); + tm.connections.createSynapse(activeSegment, previousActiveCells[1], 0.5); + tm.connections.createSynapse(activeSegment, previousActiveCells[2], 0.5); - // Weak inactive synapse. - tm.connections.createSynapse(activeSegment, 81, 0.09); + // Weak inactive synapse. + tm.connections.createSynapse(activeSegment, 81, 0.09); - tm.compute(numActiveColumns, previousActiveColumns, true); - tm.compute(numActiveColumns, activeColumns, true); + tm.compute(numActiveColumns, previousActiveColumns, true); + tm.compute(numActiveColumns, activeColumns, true); - EXPECT_EQ(3, tm.connections.numSynapses(activeSegment)); - } + EXPECT_EQ(3, tm.connections.numSynapses(activeSegment)); +} - /** - * When a segment adds synapses and it runs over maxSynapsesPerSegment, it - * should make room by destroying synapses with the lowest permanence. - */ - TEST(TemporalMemoryTest, RecycleWeakestSynapseToMakeRoomForNewSynapse) - { - TemporalMemory tm( +/** + * When a segment adds synapses and it runs over maxSynapsesPerSegment, it + * should make room by destroying synapses with the lowest permanence. + */ +TEST(TemporalMemoryTest, RecycleWeakestSynapseToMakeRoomForNewSynapse) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 1, /*activationThreshold*/ 3, @@ -889,57 +819,55 @@ namespace { /*predictedSegmentDecrement*/ 0.0, /*seed*/ 42, /*maxSegmentsPerCell*/ 255, - /*maxSynapsesPerSegment*/ 4 - ); + /*maxSynapsesPerSegment*/ 4); - // Use 1 cell per column so that we have easy control over the winner cells. 
- const UInt previousActiveColumns[3] = {1, 2, 3}; - const vector prevWinnerCells = {1, 2, 3}; - const UInt activeColumns[1] = {4}; + // Use 1 cell per column so that we have easy control over the winner cells. + const UInt previousActiveColumns[3] = {1, 2, 3}; + const vector prevWinnerCells = {1, 2, 3}; + const UInt activeColumns[1] = {4}; - Segment matchingSegment = tm.createSegment(4); + Segment matchingSegment = tm.createSegment(4); - // Create a weak synapse. Make sure it's not so weak that - // permanenceDecrement destroys it. - tm.connections.createSynapse(matchingSegment, 0, 0.11); + // Create a weak synapse. Make sure it's not so weak that + // permanenceDecrement destroys it. + tm.connections.createSynapse(matchingSegment, 0, 0.11); - // Create a synapse that will match. - tm.connections.createSynapse(matchingSegment, 1, 0.20); + // Create a synapse that will match. + tm.connections.createSynapse(matchingSegment, 1, 0.20); - // Create a synapse with a high permanence. - tm.connections.createSynapse(matchingSegment, 31, 0.6); + // Create a synapse with a high permanence. + tm.connections.createSynapse(matchingSegment, 31, 0.6); - // Activate a synapse on the segment, making it "matching". - tm.compute(3, previousActiveColumns); + // Activate a synapse on the segment, making it "matching". + tm.compute(3, previousActiveColumns); - ASSERT_EQ(prevWinnerCells, tm.getWinnerCells()); + ASSERT_EQ(prevWinnerCells, tm.getWinnerCells()); - // Now mark the segment as "correct" by activating its cell. - tm.compute(1, activeColumns); + // Now mark the segment as "correct" by activating its cell. + tm.compute(1, activeColumns); - // There should now be 3 synapses, and none of them should be to cell 0. - const vector& synapses = + // There should now be 3 synapses, and none of them should be to cell 0. + const vector &synapses = tm.connections.synapsesForSegment(matchingSegment); - ASSERT_EQ(4, synapses.size()); + ASSERT_EQ(4, synapses.size()); - std::set presynapticCells; - for (Synapse synapse : synapses) - { - presynapticCells.insert( + std::set presynapticCells; + for (Synapse synapse : synapses) { + presynapticCells.insert( tm.connections.dataForSynapse(synapse).presynapticCell); - } - - std::set expected = {1, 2, 3, 31}; - EXPECT_EQ(expected, presynapticCells); } - /** - * When a cell adds a segment and it runs over maxSegmentsPerCell, it should - * make room by destroying the least recently active segment. - */ - TEST(TemporalMemoryTest, RecycleLeastRecentlyActiveSegmentToMakeRoomForNewSegment) - { - TemporalMemory tm( + std::set expected = {1, 2, 3, 31}; + EXPECT_EQ(expected, presynapticCells); +} + +/** + * When a cell adds a segment and it runs over maxSegmentsPerCell, it should + * make room by destroying the least recently active segment. 
+ */ +TEST(TemporalMemoryTest, + RecycleLeastRecentlyActiveSegmentToMakeRoomForNewSegment) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 1, /*activationThreshold*/ 3, @@ -951,68 +879,63 @@ namespace { /*permanenceDecrement*/ 0.02, /*predictedSegmentDecrement*/ 0.0, /*seed*/ 42, - /*maxSegmentsPerCell*/ 2 - ); + /*maxSegmentsPerCell*/ 2); - const UInt previousActiveColumns1[3] = {0, 1, 2}; - const UInt previousActiveColumns2[3] = {3, 4, 5}; - const UInt previousActiveColumns3[3] = {6, 7, 8}; - const UInt activeColumns[1] = {9}; + const UInt previousActiveColumns1[3] = {0, 1, 2}; + const UInt previousActiveColumns2[3] = {3, 4, 5}; + const UInt previousActiveColumns3[3] = {6, 7, 8}; + const UInt activeColumns[1] = {9}; - tm.compute(3, previousActiveColumns1); - tm.compute(1, activeColumns); + tm.compute(3, previousActiveColumns1); + tm.compute(1, activeColumns); - ASSERT_EQ(1, tm.connections.numSegments(9)); - Segment oldestSegment = tm.connections.segmentsForCell(9)[0]; + ASSERT_EQ(1, tm.connections.numSegments(9)); + Segment oldestSegment = tm.connections.segmentsForCell(9)[0]; - tm.reset(); - tm.compute(3, previousActiveColumns2); - tm.compute(1, activeColumns); + tm.reset(); + tm.compute(3, previousActiveColumns2); + tm.compute(1, activeColumns); - ASSERT_EQ(2, tm.connections.numSegments(9)); + ASSERT_EQ(2, tm.connections.numSegments(9)); - set oldPresynaptic; - for (Synapse synapse : tm.connections.synapsesForSegment(oldestSegment)) - { - oldPresynaptic.insert( + set oldPresynaptic; + for (Synapse synapse : tm.connections.synapsesForSegment(oldestSegment)) { + oldPresynaptic.insert( tm.connections.dataForSynapse(synapse).presynapticCell); - } + } - tm.reset(); - tm.compute(3, previousActiveColumns3); - tm.compute(1, activeColumns); + tm.reset(); + tm.compute(3, previousActiveColumns3); + tm.compute(1, activeColumns); - ASSERT_EQ(2, tm.connections.numSegments(9)); + ASSERT_EQ(2, tm.connections.numSegments(9)); - // Verify none of the segments are connected to the cells the old segment - // was connected to. + // Verify none of the segments are connected to the cells the old segment + // was connected to. - for (Segment segment : tm.connections.segmentsForCell(9)) - { - set newPresynaptic; - for (Synapse synapse : tm.connections.synapsesForSegment(segment)) - { - newPresynaptic.insert( + for (Segment segment : tm.connections.segmentsForCell(9)) { + set newPresynaptic; + for (Synapse synapse : tm.connections.synapsesForSegment(segment)) { + newPresynaptic.insert( tm.connections.dataForSynapse(synapse).presynapticCell); - } + } - vector intersection; - std::set_intersection(oldPresynaptic.begin(), oldPresynaptic.end(), - newPresynaptic.begin(), newPresynaptic.end(), - std::back_inserter(intersection)); + vector intersection; + std::set_intersection(oldPresynaptic.begin(), oldPresynaptic.end(), + newPresynaptic.begin(), newPresynaptic.end(), + std::back_inserter(intersection)); - vector expected = {}; - EXPECT_EQ(expected, intersection); - } + vector expected = {}; + EXPECT_EQ(expected, intersection); } +} - /** - * When a segment's number of synapses falls to 0, the segment should be - * destroyed. - */ - TEST(TemporalMemoryTest, DestroySegmentsWithTooFewSynapsesToBeMatching) - { - TemporalMemory tm( +/** + * When a segment's number of synapses falls to 0, the segment should be + * destroyed. 
+ */ +TEST(TemporalMemoryTest, DestroySegmentsWithTooFewSynapsesToBeMatching) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -1023,39 +946,37 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.02, - /*seed*/ 42 - ); + /*seed*/ 42); - const UInt numActiveColumns = 1; - const UInt previousActiveColumns[1] = {0}; - const vector previousActiveCells = {0, 1, 2, 3}; - const UInt activeColumns[1] = {2}; - const CellIdx expectedActiveCell = 5; + const UInt numActiveColumns = 1; + const UInt previousActiveColumns[1] = {0}; + const vector previousActiveCells = {0, 1, 2, 3}; + const UInt activeColumns[1] = {2}; + const CellIdx expectedActiveCell = 5; - Segment matchingSegment = tm.createSegment(expectedActiveCell); - tm.connections.createSynapse(matchingSegment, previousActiveCells[0], 0.015); - tm.connections.createSynapse(matchingSegment, previousActiveCells[1], 0.015); - tm.connections.createSynapse(matchingSegment, previousActiveCells[2], 0.015); - tm.connections.createSynapse(matchingSegment, previousActiveCells[3], 0.015); + Segment matchingSegment = tm.createSegment(expectedActiveCell); + tm.connections.createSynapse(matchingSegment, previousActiveCells[0], 0.015); + tm.connections.createSynapse(matchingSegment, previousActiveCells[1], 0.015); + tm.connections.createSynapse(matchingSegment, previousActiveCells[2], 0.015); + tm.connections.createSynapse(matchingSegment, previousActiveCells[3], 0.015); - tm.compute(numActiveColumns, previousActiveColumns, true); - tm.compute(numActiveColumns, activeColumns, true); + tm.compute(numActiveColumns, previousActiveColumns, true); + tm.compute(numActiveColumns, activeColumns, true); - EXPECT_EQ(0, tm.connections.numSegments(expectedActiveCell)); - } + EXPECT_EQ(0, tm.connections.numSegments(expectedActiveCell)); +} - /** - * When a column with a matching segment isn't activated, punish the matching - * segment. - * - * To exercise the implementation: - * - * - Use cells before, between, and after the active columns. - * - Use segments that are matching-but-not-active and matching-and-active. - */ - TEST(TemporalMemoryTest, PunishMatchingSegmentsInInactiveColumns) - { - TemporalMemory tm( +/** + * When a column with a matching segment isn't activated, punish the matching + * segment. + * + * To exercise the implementation: + * + * - Use cells before, between, and after the active columns. + * - Use segments that are matching-but-not-active and matching-and-active. 
+ */ +TEST(TemporalMemoryTest, PunishMatchingSegmentsInInactiveColumns) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -1066,63 +987,60 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.02, - /*seed*/ 42 - ); + /*seed*/ 42); - const UInt numActiveColumns = 1; - const UInt previousActiveColumns[1] = {0}; - const vector previousActiveCells = {0, 1, 2, 3}; - const UInt activeColumns[1] = {1}; - const CellIdx previousInactiveCell = 81; + const UInt numActiveColumns = 1; + const UInt previousActiveColumns[1] = {0}; + const vector previousActiveCells = {0, 1, 2, 3}; + const UInt activeColumns[1] = {1}; + const CellIdx previousInactiveCell = 81; - Segment activeSegment = tm.createSegment(42); - Synapse activeSynapse1 = + Segment activeSegment = tm.createSegment(42); + Synapse activeSynapse1 = tm.connections.createSynapse(activeSegment, previousActiveCells[0], 0.5); - Synapse activeSynapse2 = + Synapse activeSynapse2 = tm.connections.createSynapse(activeSegment, previousActiveCells[1], 0.5); - Synapse activeSynapse3 = + Synapse activeSynapse3 = tm.connections.createSynapse(activeSegment, previousActiveCells[2], 0.5); - Synapse inactiveSynapse1 = + Synapse inactiveSynapse1 = tm.connections.createSynapse(activeSegment, previousInactiveCell, 0.5); - Segment matchingSegment = tm.createSegment(43); - Synapse activeSynapse4 = - tm.connections.createSynapse(matchingSegment, previousActiveCells[0], 0.5); - Synapse activeSynapse5 = - tm.connections.createSynapse(matchingSegment, previousActiveCells[1], 0.5); - Synapse inactiveSynapse2 = + Segment matchingSegment = tm.createSegment(43); + Synapse activeSynapse4 = tm.connections.createSynapse( + matchingSegment, previousActiveCells[0], 0.5); + Synapse activeSynapse5 = tm.connections.createSynapse( + matchingSegment, previousActiveCells[1], 0.5); + Synapse inactiveSynapse2 = tm.connections.createSynapse(matchingSegment, previousInactiveCell, 0.5); - tm.compute(numActiveColumns, previousActiveColumns, true); - tm.compute(numActiveColumns, activeColumns, true); - - EXPECT_NEAR(0.48, tm.connections.dataForSynapse(activeSynapse1).permanence, - EPSILON); - EXPECT_NEAR(0.48, tm.connections.dataForSynapse(activeSynapse2).permanence, - EPSILON); - EXPECT_NEAR(0.48, tm.connections.dataForSynapse(activeSynapse3).permanence, - EPSILON); - EXPECT_NEAR(0.48, tm.connections.dataForSynapse(activeSynapse4).permanence, - EPSILON); - EXPECT_NEAR(0.48, tm.connections.dataForSynapse(activeSynapse5).permanence, - EPSILON); - EXPECT_NEAR(0.50, tm.connections.dataForSynapse(inactiveSynapse1).permanence, - EPSILON); - EXPECT_NEAR(0.50, tm.connections.dataForSynapse(inactiveSynapse2).permanence, - EPSILON); - } - - /** - * In a bursting column with no matching segments, a segment should be added - * to the cell with the fewest segments. When there's a tie, choose randomly. 
- */ - TEST(TemporalMemoryTest, AddSegmentToCellWithFewestSegments) - { - bool grewOnCell1 = false; - bool grewOnCell2 = false; - for (UInt seed = 0; seed < 100; seed++) - { - TemporalMemory tm( + tm.compute(numActiveColumns, previousActiveColumns, true); + tm.compute(numActiveColumns, activeColumns, true); + + EXPECT_NEAR(0.48, tm.connections.dataForSynapse(activeSynapse1).permanence, + EPSILON); + EXPECT_NEAR(0.48, tm.connections.dataForSynapse(activeSynapse2).permanence, + EPSILON); + EXPECT_NEAR(0.48, tm.connections.dataForSynapse(activeSynapse3).permanence, + EPSILON); + EXPECT_NEAR(0.48, tm.connections.dataForSynapse(activeSynapse4).permanence, + EPSILON); + EXPECT_NEAR(0.48, tm.connections.dataForSynapse(activeSynapse5).permanence, + EPSILON); + EXPECT_NEAR(0.50, tm.connections.dataForSynapse(inactiveSynapse1).permanence, + EPSILON); + EXPECT_NEAR(0.50, tm.connections.dataForSynapse(inactiveSynapse2).permanence, + EPSILON); +} + +/** + * In a bursting column with no matching segments, a segment should be added + * to the cell with the fewest segments. When there's a tie, choose randomly. + */ +TEST(TemporalMemoryTest, AddSegmentToCellWithFewestSegments) { + bool grewOnCell1 = false; + bool grewOnCell2 = false; + for (UInt seed = 0; seed < 100; seed++) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -1133,77 +1051,72 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.02, - /*seed*/ seed - ); - - // enough for 4 winner cells - const UInt previousActiveColumns[4] ={1, 2, 3, 4}; - const UInt activeColumns[1] = {0}; - const vector previousActiveCells = - {4, 5, 6, 7}; // (there are more) - vector nonmatchingCells = {0, 3}; - vector activeCells = {0, 1, 2, 3}; - - Segment segment1 = tm.createSegment(nonmatchingCells[0]); - tm.connections.createSynapse(segment1, previousActiveCells[0], 0.5); - Segment segment2 = tm.createSegment(nonmatchingCells[1]); - tm.connections.createSynapse(segment2, previousActiveCells[1], 0.5); - - tm.compute(4, previousActiveColumns, true); - tm.compute(1, activeColumns, true); - - ASSERT_EQ(activeCells, tm.getActiveCells()); - - EXPECT_EQ(3, tm.connections.numSegments()); - EXPECT_EQ(1, tm.connections.segmentsForCell(0).size()); - EXPECT_EQ(1, tm.connections.segmentsForCell(3).size()); - EXPECT_EQ(1, tm.connections.numSynapses(segment1)); - EXPECT_EQ(1, tm.connections.numSynapses(segment2)); - - vector segments = tm.connections.segmentsForCell(1); - if (segments.empty()) - { - vector segments2 = tm.connections.segmentsForCell(2); - EXPECT_FALSE(segments2.empty()); - grewOnCell2 = true; - segments.insert(segments.end(), segments2.begin(), segments2.end()); - } - else - { - grewOnCell1 = true; - } - - ASSERT_EQ(1, segments.size()); - vector synapses = tm.connections.synapsesForSegment(segments[0]); - EXPECT_EQ(4, synapses.size()); - - set columnChecklist(previousActiveColumns, previousActiveColumns+4); - - for (Synapse synapse : synapses) - { - SynapseData synapseData = tm.connections.dataForSynapse(synapse); - EXPECT_NEAR(0.2, synapseData.permanence, EPSILON); - - UInt32 column = (UInt)tm.columnForCell(synapseData.presynapticCell); - auto position = columnChecklist.find(column); - EXPECT_NE(columnChecklist.end(), position); - columnChecklist.erase(position); - } - EXPECT_TRUE(columnChecklist.empty()); + /*seed*/ seed); + + // enough for 4 winner cells + const UInt previousActiveColumns[4] = {1, 2, 3, 4}; + const UInt activeColumns[1] = {0}; + const vector 
previousActiveCells = {4, 5, 6, + 7}; // (there are more) + vector nonmatchingCells = {0, 3}; + vector activeCells = {0, 1, 2, 3}; + + Segment segment1 = tm.createSegment(nonmatchingCells[0]); + tm.connections.createSynapse(segment1, previousActiveCells[0], 0.5); + Segment segment2 = tm.createSegment(nonmatchingCells[1]); + tm.connections.createSynapse(segment2, previousActiveCells[1], 0.5); + + tm.compute(4, previousActiveColumns, true); + tm.compute(1, activeColumns, true); + + ASSERT_EQ(activeCells, tm.getActiveCells()); + + EXPECT_EQ(3, tm.connections.numSegments()); + EXPECT_EQ(1, tm.connections.segmentsForCell(0).size()); + EXPECT_EQ(1, tm.connections.segmentsForCell(3).size()); + EXPECT_EQ(1, tm.connections.numSynapses(segment1)); + EXPECT_EQ(1, tm.connections.numSynapses(segment2)); + + vector segments = tm.connections.segmentsForCell(1); + if (segments.empty()) { + vector segments2 = tm.connections.segmentsForCell(2); + EXPECT_FALSE(segments2.empty()); + grewOnCell2 = true; + segments.insert(segments.end(), segments2.begin(), segments2.end()); + } else { + grewOnCell1 = true; } - EXPECT_TRUE(grewOnCell1); - EXPECT_TRUE(grewOnCell2); + ASSERT_EQ(1, segments.size()); + vector synapses = tm.connections.synapsesForSegment(segments[0]); + EXPECT_EQ(4, synapses.size()); + + set columnChecklist(previousActiveColumns, + previousActiveColumns + 4); + + for (Synapse synapse : synapses) { + SynapseData synapseData = tm.connections.dataForSynapse(synapse); + EXPECT_NEAR(0.2, synapseData.permanence, EPSILON); + + UInt32 column = (UInt)tm.columnForCell(synapseData.presynapticCell); + auto position = columnChecklist.find(column); + EXPECT_NE(columnChecklist.end(), position); + columnChecklist.erase(position); + } + EXPECT_TRUE(columnChecklist.empty()); } - /** - * When the best matching segment has more than maxNewSynapseCount matching - * synapses, don't grow new synapses. This test is specifically aimed at - * unexpected behavior with negative numbers and unsigned integers. - */ - TEST(TemporalMemoryTest, MaxNewSynapseCountOverflow) - { - TemporalMemory tm( + EXPECT_TRUE(grewOnCell1); + EXPECT_TRUE(grewOnCell2); +} + +/** + * When the best matching segment has more than maxNewSynapseCount matching + * synapses, don't grow new synapses. This test is specifically aimed at + * unexpected behavior with negative numbers and unsigned integers. 
+ */ +TEST(TemporalMemoryTest, MaxNewSynapseCountOverflow) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -1214,42 +1127,40 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.02, - /*seed*/ 42 - ); + /*seed*/ 42); - Segment segment = tm.createSegment(8); - tm.connections.createSynapse(segment, 0, 0.2); - tm.connections.createSynapse(segment, 1, 0.2); - tm.connections.createSynapse(segment, 2, 0.2); - tm.connections.createSynapse(segment, 3, 0.2); - tm.connections.createSynapse(segment, 4, 0.2); - Synapse sampleSynapse = tm.connections.createSynapse(segment, 5, 0.2); - tm.connections.createSynapse(segment, 6, 0.2); - tm.connections.createSynapse(segment, 7, 0.2); + Segment segment = tm.createSegment(8); + tm.connections.createSynapse(segment, 0, 0.2); + tm.connections.createSynapse(segment, 1, 0.2); + tm.connections.createSynapse(segment, 2, 0.2); + tm.connections.createSynapse(segment, 3, 0.2); + tm.connections.createSynapse(segment, 4, 0.2); + Synapse sampleSynapse = tm.connections.createSynapse(segment, 5, 0.2); + tm.connections.createSynapse(segment, 6, 0.2); + tm.connections.createSynapse(segment, 7, 0.2); - const UInt previousActiveColumns[4] = {0, 1, 3, 4}; - tm.compute(4, previousActiveColumns); + const UInt previousActiveColumns[4] = {0, 1, 3, 4}; + tm.compute(4, previousActiveColumns); - ASSERT_EQ(1, tm.getMatchingSegments().size()); + ASSERT_EQ(1, tm.getMatchingSegments().size()); - const UInt activeColumns[1] = {2}; - tm.compute(1, activeColumns); + const UInt activeColumns[1] = {2}; + tm.compute(1, activeColumns); - // Make sure the segment has learned. - ASSERT_NEAR(0.3, tm.connections.dataForSynapse(sampleSynapse).permanence, - EPSILON); + // Make sure the segment has learned. + ASSERT_NEAR(0.3, tm.connections.dataForSynapse(sampleSynapse).permanence, + EPSILON); - EXPECT_EQ(8, tm.connections.numSynapses(segment)); - } + EXPECT_EQ(8, tm.connections.numSynapses(segment)); +} - /** - * With learning disabled, generate some predicted active columns, predicted - * inactive columns, and nonpredicted active columns. The connections should - * not change. - */ - TEST(TemporalMemoryTest, ConnectionsNeverChangeWhenLearningDisabled) - { - TemporalMemory tm( +/** + * With learning disabled, generate some predicted active columns, predicted + * inactive columns, and nonpredicted active columns. The connections should + * not change. 
+ */ +TEST(TemporalMemoryTest, ConnectionsNeverChangeWhenLearningDisabled) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -1260,50 +1171,46 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.02, - /*seed*/ 42 - ); + /*seed*/ 42); - const UInt previousActiveColumns[1] = {0}; - const vector previousActiveCells = {0, 1, 2, 3}; - const UInt activeColumns[2] = { + const UInt previousActiveColumns[1] = {0}; + const vector previousActiveCells = {0, 1, 2, 3}; + const UInt activeColumns[2] = { 1, // predicted 2 // bursting - }; - const CellIdx previousInactiveCell = 81; - const vector expectedActiveCells = {4}; - - Segment correctActiveSegment = - tm.createSegment(expectedActiveCells[0]); - tm.connections.createSynapse(correctActiveSegment, - previousActiveCells[0], 0.5); - tm.connections.createSynapse(correctActiveSegment, - previousActiveCells[1], 0.5); - tm.connections.createSynapse(correctActiveSegment, - previousActiveCells[2], 0.5); - - Segment wrongMatchingSegment = tm.createSegment(43); - tm.connections.createSynapse(wrongMatchingSegment, - previousActiveCells[0], 0.5); - tm.connections.createSynapse(wrongMatchingSegment, - previousActiveCells[1], 0.5); - tm.connections.createSynapse(wrongMatchingSegment, - previousInactiveCell, 0.5); - - Connections before = tm.connections; - - tm.compute(1, previousActiveColumns, false); - tm.compute(2, activeColumns, false); - - EXPECT_EQ(before, tm.connections); - } - - /** - * Destroy some segments then verify that the maxSegmentsPerCell is still - * correctly applied. - */ - TEST(TemporalMemoryTest, DestroySegmentsThenReachLimit) - { - TemporalMemory tm( + }; + const CellIdx previousInactiveCell = 81; + const vector expectedActiveCells = {4}; + + Segment correctActiveSegment = tm.createSegment(expectedActiveCells[0]); + tm.connections.createSynapse(correctActiveSegment, previousActiveCells[0], + 0.5); + tm.connections.createSynapse(correctActiveSegment, previousActiveCells[1], + 0.5); + tm.connections.createSynapse(correctActiveSegment, previousActiveCells[2], + 0.5); + + Segment wrongMatchingSegment = tm.createSegment(43); + tm.connections.createSynapse(wrongMatchingSegment, previousActiveCells[0], + 0.5); + tm.connections.createSynapse(wrongMatchingSegment, previousActiveCells[1], + 0.5); + tm.connections.createSynapse(wrongMatchingSegment, previousInactiveCell, 0.5); + + Connections before = tm.connections; + + tm.compute(1, previousActiveColumns, false); + tm.compute(2, activeColumns, false); + + EXPECT_EQ(before, tm.connections); +} + +/** + * Destroy some segments then verify that the maxSegmentsPerCell is still + * correctly applied. 
+ */ +TEST(TemporalMemoryTest, DestroySegmentsThenReachLimit) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 1, /*activationThreshold*/ 3, @@ -1315,37 +1222,35 @@ namespace { /*permanenceDecrement*/ 0.02, /*predictedSegmentDecrement*/ 0.0, /*seed*/ 42, - /*maxSegmentsPerCell*/ 2 - ); - - { - Segment segment1 = tm.createSegment(11); - Segment segment2 = tm.createSegment(11); - ASSERT_EQ(2, tm.connections.numSegments()); - tm.connections.destroySegment(segment1); - tm.connections.destroySegment(segment2); - ASSERT_EQ(0, tm.connections.numSegments()); - } + /*maxSegmentsPerCell*/ 2); - { - tm.createSegment(11); - EXPECT_EQ(1, tm.connections.numSegments()); - tm.createSegment(11); - EXPECT_EQ(2, tm.connections.numSegments()); - tm.createSegment(11); - EXPECT_EQ(2, tm.connections.numSegments()); - EXPECT_EQ(2, tm.connections.numSegments(11)); - } + { + Segment segment1 = tm.createSegment(11); + Segment segment2 = tm.createSegment(11); + ASSERT_EQ(2, tm.connections.numSegments()); + tm.connections.destroySegment(segment1); + tm.connections.destroySegment(segment2); + ASSERT_EQ(0, tm.connections.numSegments()); } - /** - * Creates many segments on a cell, until hits segment limit. Then creates - * another segment, and checks that it destroyed the least recently used - * segment and created a new one in its place. - */ - TEST(TemporalMemoryTest, CreateSegmentDestroyOld) { - TemporalMemory tm( + tm.createSegment(11); + EXPECT_EQ(1, tm.connections.numSegments()); + tm.createSegment(11); + EXPECT_EQ(2, tm.connections.numSegments()); + tm.createSegment(11); + EXPECT_EQ(2, tm.connections.numSegments()); + EXPECT_EQ(2, tm.connections.numSegments(11)); + } +} + +/** + * Creates many segments on a cell, until hits segment limit. Then creates + * another segment, and checks that it destroyed the least recently used + * segment and created a new one in its place. + */ +TEST(TemporalMemoryTest, CreateSegmentDestroyOld) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 1, /*activationThreshold*/ 3, @@ -1357,54 +1262,52 @@ namespace { /*permanenceDecrement*/ 0.02, /*predictedSegmentDecrement*/ 0.0, /*seed*/ 42, - /*maxSegmentsPerCell*/ 2 - ); + /*maxSegmentsPerCell*/ 2); - Segment segment1 = tm.createSegment(12); + Segment segment1 = tm.createSegment(12); - tm.connections.createSynapse(segment1, 1, 0.5); - tm.connections.createSynapse(segment1, 2, 0.5); - tm.connections.createSynapse(segment1, 3, 0.5); + tm.connections.createSynapse(segment1, 1, 0.5); + tm.connections.createSynapse(segment1, 2, 0.5); + tm.connections.createSynapse(segment1, 3, 0.5); - // Let some time pass. - tm.compute(0, nullptr); - tm.compute(0, nullptr); - tm.compute(0, nullptr); + // Let some time pass. + tm.compute(0, nullptr); + tm.compute(0, nullptr); + tm.compute(0, nullptr); - // Create a segment with 1 synapse. - Segment segment2 = tm.createSegment(12); - tm.connections.createSynapse(segment2, 1, 0.5); + // Create a segment with 1 synapse. + Segment segment2 = tm.createSegment(12); + tm.connections.createSynapse(segment2, 1, 0.5); - tm.compute(0, nullptr); + tm.compute(0, nullptr); - // Give the first segment some activity. - const UInt activeColumns[3] = {1, 2, 3}; - tm.compute(3, activeColumns); + // Give the first segment some activity. + const UInt activeColumns[3] = {1, 2, 3}; + tm.compute(3, activeColumns); - // Create a new segment with no synapses. - tm.createSegment(12); + // Create a new segment with no synapses. 
+ tm.createSegment(12); - vector segments = tm.connections.segmentsForCell(12); - ASSERT_EQ(2, segments.size()); + vector segments = tm.connections.segmentsForCell(12); + ASSERT_EQ(2, segments.size()); - // Verify first segment is still there with the same synapses. - vector synapses1 = tm.connections.synapsesForSegment(segments[0]); - ASSERT_EQ(3, synapses1.size()); - ASSERT_EQ(1, tm.connections.dataForSynapse(synapses1[0]).presynapticCell); - ASSERT_EQ(2, tm.connections.dataForSynapse(synapses1[1]).presynapticCell); - ASSERT_EQ(3, tm.connections.dataForSynapse(synapses1[2]).presynapticCell); + // Verify first segment is still there with the same synapses. + vector synapses1 = tm.connections.synapsesForSegment(segments[0]); + ASSERT_EQ(3, synapses1.size()); + ASSERT_EQ(1, tm.connections.dataForSynapse(synapses1[0]).presynapticCell); + ASSERT_EQ(2, tm.connections.dataForSynapse(synapses1[1]).presynapticCell); + ASSERT_EQ(3, tm.connections.dataForSynapse(synapses1[2]).presynapticCell); - // Verify second segment has been replaced. - ASSERT_EQ(0, tm.connections.numSynapses(segments[1])); - } + // Verify second segment has been replaced. + ASSERT_EQ(0, tm.connections.numSynapses(segments[1])); +} - /** - * Hit the maxSegmentsPerCell threshold multiple times. Make sure it works - * more than once. - */ - TEST(ConnectionsTest, ReachSegmentLimitMultipleTimes) - { - TemporalMemory tm( +/** + * Hit the maxSegmentsPerCell threshold multiple times. Make sure it works + * more than once. + */ +TEST(ConnectionsTest, ReachSegmentLimitMultipleTimes) { + TemporalMemory tm( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 1, /*activationThreshold*/ 3, @@ -1416,221 +1319,177 @@ namespace { /*permanenceDecrement*/ 0.02, /*predictedSegmentDecrement*/ 0.0, /*seed*/ 42, - /*maxSegmentsPerCell*/ 2 - ); - - tm.createSegment(10); - ASSERT_EQ(1, tm.connections.numSegments()); - tm.createSegment(10); - ASSERT_EQ(2, tm.connections.numSegments()); - tm.createSegment(10); - ASSERT_EQ(2, tm.connections.numSegments()); - tm.createSegment(10); - EXPECT_EQ(2, tm.connections.numSegments()); - } - - TEST(TemporalMemoryTest, testColumnForCell1D) - { - TemporalMemory tm; - tm.initialize(vector{2048}, 5); - - ASSERT_EQ(0, tm.columnForCell(0)); - ASSERT_EQ(0, tm.columnForCell(4)); - ASSERT_EQ(1, tm.columnForCell(5)); - ASSERT_EQ(2047, tm.columnForCell(10239)); - } - - TEST(TemporalMemoryTest, testColumnForCell2D) - { - TemporalMemory tm; - tm.initialize(vector{64, 64}, 4); - - ASSERT_EQ(0, tm.columnForCell(0)); - ASSERT_EQ(0, tm.columnForCell(3)); - ASSERT_EQ(1, tm.columnForCell(4)); - ASSERT_EQ(4095, tm.columnForCell(16383)); - } - - TEST(TemporalMemoryTest, testColumnForCellInvalidCell) - { - TemporalMemory tm; - tm.initialize(vector{64, 64}, 4); - - EXPECT_NO_THROW(tm.columnForCell(16383)); - EXPECT_THROW(tm.columnForCell(16384), std::exception); - EXPECT_THROW(tm.columnForCell(-1), std::exception); - } - - TEST(TemporalMemoryTest, testNumberOfColumns) - { - TemporalMemory tm; - tm.initialize(vector{64, 64}, 32); - - int numOfColumns = tm.numberOfColumns(); - ASSERT_EQ(numOfColumns, 64 * 64); - } - - TEST(TemporalMemoryTest, testNumberOfCells) - { - TemporalMemory tm; - tm.initialize(vector{64, 64}, 32); - - Int numberOfCells = tm.numberOfCells(); - ASSERT_EQ(numberOfCells, 64 * 64 * 32); - } - - void serializationTestPrepare(TemporalMemory& tm) - { - // Create an active segment and a two matching segments. - // Destroy a few to exercise the code. 
- Segment destroyMe1 = tm.createSegment(4); - tm.connections.destroySegment(destroyMe1); - - Segment activeSegment = tm.createSegment(4); - tm.connections.createSynapse(activeSegment, 0, 0.5); - tm.connections.createSynapse(activeSegment, 1, 0.5); - Synapse destroyMe2 = tm.connections.createSynapse(activeSegment, 42, 0.5); - tm.connections.destroySynapse(destroyMe2); - tm.connections.createSynapse(activeSegment, 2, 0.5); - tm.connections.createSynapse(activeSegment, 3, 0.5); - - Segment matchingSegment1 = tm.createSegment(8); - tm.connections.createSynapse(matchingSegment1, 0, 0.4); - tm.connections.createSynapse(matchingSegment1, 1, 0.4); - tm.connections.createSynapse(matchingSegment1, 2, 0.4); - - Segment matchingSegment2 = tm.createSegment(9); - tm.connections.createSynapse(matchingSegment2, 0, 0.4); - tm.connections.createSynapse(matchingSegment2, 1, 0.4); - tm.connections.createSynapse(matchingSegment2, 2, 0.4); - tm.connections.createSynapse(matchingSegment2, 3, 0.4); - - UInt activeColumns[] = {0}; - tm.compute(1, activeColumns); - - ASSERT_EQ(1, tm.getActiveSegments().size()); - ASSERT_EQ(3, tm.getMatchingSegments().size()); - } - - void serializationTestVerify(TemporalMemory& tm) - { - // Activate 3 columns. One has an active segment, one has two - // matching segments, and one has none. One column should be - // predicted, the others should burst, there should be four - // segments total, and they should have the correct permanences - // and synapse counts. - - const vector prevWinnerCells = tm.getWinnerCells(); - ASSERT_EQ(1, prevWinnerCells.size()); - - UInt activeColumns[] = {1, 2, 3}; - tm.compute(3, activeColumns); - - // Verify the correct cells were activated. - EXPECT_EQ((vector{4, 8, 9, 10, 11, 12, 13, 14, 15}), - tm.getActiveCells()); - const vector winnerCells = tm.getWinnerCells(); - ASSERT_EQ(3, winnerCells.size()); - EXPECT_EQ(4, winnerCells[0]); - EXPECT_EQ(9, winnerCells[1]); - - EXPECT_EQ(4, tm.connections.numSegments()); - - // Verify the active segment learned. 
- ASSERT_EQ(1, tm.connections.numSegments(4)); - Segment activeSegment = tm.connections.segmentsForCell(4)[0]; - const vector syns1 = + /*maxSegmentsPerCell*/ 2); + + tm.createSegment(10); + ASSERT_EQ(1, tm.connections.numSegments()); + tm.createSegment(10); + ASSERT_EQ(2, tm.connections.numSegments()); + tm.createSegment(10); + ASSERT_EQ(2, tm.connections.numSegments()); + tm.createSegment(10); + EXPECT_EQ(2, tm.connections.numSegments()); +} + +TEST(TemporalMemoryTest, testColumnForCell1D) { + TemporalMemory tm; + tm.initialize(vector{2048}, 5); + + ASSERT_EQ(0, tm.columnForCell(0)); + ASSERT_EQ(0, tm.columnForCell(4)); + ASSERT_EQ(1, tm.columnForCell(5)); + ASSERT_EQ(2047, tm.columnForCell(10239)); +} + +TEST(TemporalMemoryTest, testColumnForCell2D) { + TemporalMemory tm; + tm.initialize(vector{64, 64}, 4); + + ASSERT_EQ(0, tm.columnForCell(0)); + ASSERT_EQ(0, tm.columnForCell(3)); + ASSERT_EQ(1, tm.columnForCell(4)); + ASSERT_EQ(4095, tm.columnForCell(16383)); +} + +TEST(TemporalMemoryTest, testColumnForCellInvalidCell) { + TemporalMemory tm; + tm.initialize(vector{64, 64}, 4); + + EXPECT_NO_THROW(tm.columnForCell(16383)); + EXPECT_THROW(tm.columnForCell(16384), std::exception); + EXPECT_THROW(tm.columnForCell(-1), std::exception); +} + +TEST(TemporalMemoryTest, testNumberOfColumns) { + TemporalMemory tm; + tm.initialize(vector{64, 64}, 32); + + int numOfColumns = tm.numberOfColumns(); + ASSERT_EQ(numOfColumns, 64 * 64); +} + +TEST(TemporalMemoryTest, testNumberOfCells) { + TemporalMemory tm; + tm.initialize(vector{64, 64}, 32); + + Int numberOfCells = tm.numberOfCells(); + ASSERT_EQ(numberOfCells, 64 * 64 * 32); +} + +void serializationTestPrepare(TemporalMemory &tm) { + // Create an active segment and a two matching segments. + // Destroy a few to exercise the code. + Segment destroyMe1 = tm.createSegment(4); + tm.connections.destroySegment(destroyMe1); + + Segment activeSegment = tm.createSegment(4); + tm.connections.createSynapse(activeSegment, 0, 0.5); + tm.connections.createSynapse(activeSegment, 1, 0.5); + Synapse destroyMe2 = tm.connections.createSynapse(activeSegment, 42, 0.5); + tm.connections.destroySynapse(destroyMe2); + tm.connections.createSynapse(activeSegment, 2, 0.5); + tm.connections.createSynapse(activeSegment, 3, 0.5); + + Segment matchingSegment1 = tm.createSegment(8); + tm.connections.createSynapse(matchingSegment1, 0, 0.4); + tm.connections.createSynapse(matchingSegment1, 1, 0.4); + tm.connections.createSynapse(matchingSegment1, 2, 0.4); + + Segment matchingSegment2 = tm.createSegment(9); + tm.connections.createSynapse(matchingSegment2, 0, 0.4); + tm.connections.createSynapse(matchingSegment2, 1, 0.4); + tm.connections.createSynapse(matchingSegment2, 2, 0.4); + tm.connections.createSynapse(matchingSegment2, 3, 0.4); + + UInt activeColumns[] = {0}; + tm.compute(1, activeColumns); + + ASSERT_EQ(1, tm.getActiveSegments().size()); + ASSERT_EQ(3, tm.getMatchingSegments().size()); +} + +void serializationTestVerify(TemporalMemory &tm) { + // Activate 3 columns. One has an active segment, one has two + // matching segments, and one has none. One column should be + // predicted, the others should burst, there should be four + // segments total, and they should have the correct permanences + // and synapse counts. + + const vector prevWinnerCells = tm.getWinnerCells(); + ASSERT_EQ(1, prevWinnerCells.size()); + + UInt activeColumns[] = {1, 2, 3}; + tm.compute(3, activeColumns); + + // Verify the correct cells were activated. 
+ EXPECT_EQ((vector{4, 8, 9, 10, 11, 12, 13, 14, 15}), + tm.getActiveCells()); + const vector winnerCells = tm.getWinnerCells(); + ASSERT_EQ(3, winnerCells.size()); + EXPECT_EQ(4, winnerCells[0]); + EXPECT_EQ(9, winnerCells[1]); + + EXPECT_EQ(4, tm.connections.numSegments()); + + // Verify the active segment learned. + ASSERT_EQ(1, tm.connections.numSegments(4)); + Segment activeSegment = tm.connections.segmentsForCell(4)[0]; + const vector syns1 = tm.connections.synapsesForSegment(activeSegment); - ASSERT_EQ(4, syns1.size()); - EXPECT_EQ(0, - tm.connections.dataForSynapse(syns1[0]).presynapticCell); - EXPECT_NEAR(0.6, - tm.connections.dataForSynapse(syns1[0]).permanence, - EPSILON); - EXPECT_EQ(1, - tm.connections.dataForSynapse(syns1[1]).presynapticCell); - EXPECT_NEAR(0.6, - tm.connections.dataForSynapse(syns1[1]).permanence, - EPSILON); - EXPECT_EQ(2, - tm.connections.dataForSynapse(syns1[2]).presynapticCell); - EXPECT_NEAR(0.6, - tm.connections.dataForSynapse(syns1[2]).permanence, - EPSILON); - EXPECT_EQ(3, - tm.connections.dataForSynapse(syns1[3]).presynapticCell); - EXPECT_NEAR(0.6, - tm.connections.dataForSynapse(syns1[3]).permanence, - EPSILON); - - // Verify the non-best matching segment is unchanged. - ASSERT_EQ(1, tm.connections.numSegments(8)); - Segment matchingSegment1 = tm.connections.segmentsForCell(8)[0]; - const vector syns2 = + ASSERT_EQ(4, syns1.size()); + EXPECT_EQ(0, tm.connections.dataForSynapse(syns1[0]).presynapticCell); + EXPECT_NEAR(0.6, tm.connections.dataForSynapse(syns1[0]).permanence, EPSILON); + EXPECT_EQ(1, tm.connections.dataForSynapse(syns1[1]).presynapticCell); + EXPECT_NEAR(0.6, tm.connections.dataForSynapse(syns1[1]).permanence, EPSILON); + EXPECT_EQ(2, tm.connections.dataForSynapse(syns1[2]).presynapticCell); + EXPECT_NEAR(0.6, tm.connections.dataForSynapse(syns1[2]).permanence, EPSILON); + EXPECT_EQ(3, tm.connections.dataForSynapse(syns1[3]).presynapticCell); + EXPECT_NEAR(0.6, tm.connections.dataForSynapse(syns1[3]).permanence, EPSILON); + + // Verify the non-best matching segment is unchanged. + ASSERT_EQ(1, tm.connections.numSegments(8)); + Segment matchingSegment1 = tm.connections.segmentsForCell(8)[0]; + const vector syns2 = tm.connections.synapsesForSegment(matchingSegment1); - ASSERT_EQ(3, syns2.size()); - EXPECT_EQ(0, - tm.connections.dataForSynapse(syns2[0]).presynapticCell); - EXPECT_NEAR(0.4, - tm.connections.dataForSynapse(syns2[0]).permanence, - EPSILON); - EXPECT_EQ(1, - tm.connections.dataForSynapse(syns2[1]).presynapticCell); - EXPECT_NEAR(0.4, - tm.connections.dataForSynapse(syns2[1]).permanence, - EPSILON); - EXPECT_EQ(2, - tm.connections.dataForSynapse(syns2[2]).presynapticCell); - EXPECT_NEAR(0.4, - tm.connections.dataForSynapse(syns2[2]).permanence, - EPSILON); - - // Verify the best matching segment learned. - ASSERT_EQ(1, tm.connections.numSegments(9)); - Segment matchingSegment2 = tm.connections.segmentsForCell(9)[0]; - const vector syns3 = + ASSERT_EQ(3, syns2.size()); + EXPECT_EQ(0, tm.connections.dataForSynapse(syns2[0]).presynapticCell); + EXPECT_NEAR(0.4, tm.connections.dataForSynapse(syns2[0]).permanence, EPSILON); + EXPECT_EQ(1, tm.connections.dataForSynapse(syns2[1]).presynapticCell); + EXPECT_NEAR(0.4, tm.connections.dataForSynapse(syns2[1]).permanence, EPSILON); + EXPECT_EQ(2, tm.connections.dataForSynapse(syns2[2]).presynapticCell); + EXPECT_NEAR(0.4, tm.connections.dataForSynapse(syns2[2]).permanence, EPSILON); + + // Verify the best matching segment learned. 
+ ASSERT_EQ(1, tm.connections.numSegments(9)); + Segment matchingSegment2 = tm.connections.segmentsForCell(9)[0]; + const vector syns3 = tm.connections.synapsesForSegment(matchingSegment2); - ASSERT_EQ(4, syns3.size()); - EXPECT_EQ(0, - tm.connections.dataForSynapse(syns3[0]).presynapticCell); - EXPECT_NEAR(0.5, - tm.connections.dataForSynapse(syns3[0]).permanence, - EPSILON); - EXPECT_EQ(1, - tm.connections.dataForSynapse(syns3[1]).presynapticCell); - EXPECT_NEAR(0.5, - tm.connections.dataForSynapse(syns3[1]).permanence, - EPSILON); - EXPECT_EQ(2, - tm.connections.dataForSynapse(syns3[2]).presynapticCell); - EXPECT_NEAR(0.5, - tm.connections.dataForSynapse(syns3[2]).permanence, - EPSILON); - EXPECT_EQ(3, - tm.connections.dataForSynapse(syns3[3]).presynapticCell); - EXPECT_NEAR(0.5, - tm.connections.dataForSynapse(syns3[3]).permanence, - EPSILON); - - // Verify the winner cell in the last column grew a segment. - const UInt winnerCell = winnerCells[2]; - EXPECT_GE(winnerCell, 12); - EXPECT_LT(winnerCell, 16); - ASSERT_EQ(1, tm.connections.numSegments(winnerCell)); - Segment newSegment = tm.connections.segmentsForCell(winnerCell)[0]; - const vector syns4 = - tm.connections.synapsesForSegment(newSegment); - ASSERT_EQ(1, syns4.size()); - EXPECT_EQ(prevWinnerCells[0], - tm.connections.dataForSynapse(syns4[0]).presynapticCell); - EXPECT_NEAR(0.21, - tm.connections.dataForSynapse(syns4[0]).permanence, - EPSILON); - } - - TEST(TemporalMemoryTest, testSaveLoad) - { - TemporalMemory tm1( + ASSERT_EQ(4, syns3.size()); + EXPECT_EQ(0, tm.connections.dataForSynapse(syns3[0]).presynapticCell); + EXPECT_NEAR(0.5, tm.connections.dataForSynapse(syns3[0]).permanence, EPSILON); + EXPECT_EQ(1, tm.connections.dataForSynapse(syns3[1]).presynapticCell); + EXPECT_NEAR(0.5, tm.connections.dataForSynapse(syns3[1]).permanence, EPSILON); + EXPECT_EQ(2, tm.connections.dataForSynapse(syns3[2]).presynapticCell); + EXPECT_NEAR(0.5, tm.connections.dataForSynapse(syns3[2]).permanence, EPSILON); + EXPECT_EQ(3, tm.connections.dataForSynapse(syns3[3]).presynapticCell); + EXPECT_NEAR(0.5, tm.connections.dataForSynapse(syns3[3]).permanence, EPSILON); + + // Verify the winner cell in the last column grew a segment. 
+ const UInt winnerCell = winnerCells[2]; + EXPECT_GE(winnerCell, 12); + EXPECT_LT(winnerCell, 16); + ASSERT_EQ(1, tm.connections.numSegments(winnerCell)); + Segment newSegment = tm.connections.segmentsForCell(winnerCell)[0]; + const vector syns4 = tm.connections.synapsesForSegment(newSegment); + ASSERT_EQ(1, syns4.size()); + EXPECT_EQ(prevWinnerCells[0], + tm.connections.dataForSynapse(syns4[0]).presynapticCell); + EXPECT_NEAR(0.21, tm.connections.dataForSynapse(syns4[0]).permanence, + EPSILON); +} + +TEST(TemporalMemoryTest, testSaveLoad) { + TemporalMemory tm1( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -1641,25 +1500,23 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); + /*seed*/ 42); - serializationTestPrepare(tm1); + serializationTestPrepare(tm1); - stringstream ss; - tm1.save(ss); + stringstream ss; + tm1.save(ss); - TemporalMemory tm2; - tm2.load(ss); + TemporalMemory tm2; + tm2.load(ss); - ASSERT_TRUE(tm1 == tm2); + ASSERT_TRUE(tm1 == tm2); - serializationTestVerify(tm2); - } + serializationTestVerify(tm2); +} - TEST(TemporalMemoryTest, testWrite) - { - TemporalMemory tm1( +TEST(TemporalMemoryTest, testWrite) { + TemporalMemory tm1( /*columnDimensions*/ {32}, /*cellsPerColumn*/ 4, /*activationThreshold*/ 3, @@ -1670,95 +1527,94 @@ namespace { /*permanenceIncrement*/ 0.10, /*permanenceDecrement*/ 0.10, /*predictedSegmentDecrement*/ 0.0, - /*seed*/ 42 - ); - - serializationTestPrepare(tm1); - - // Write and read back the proto - stringstream ss; - tm1.write(ss); - - TemporalMemory tm2; - tm2.read(ss); - - ASSERT_TRUE(tm1 == tm2); - - serializationTestVerify(tm2); - } - - // Uncomment these tests individually to save/load from a file. - // This is useful for ad-hoc testing of backwards-compatibility. 
- - // TEST(TemporalMemoryTest, saveTestFile) - // { - // TemporalMemory tm( - // /*columnDimensions*/ {32}, - // /*cellsPerColumn*/ 4, - // /*activationThreshold*/ 3, - // /*initialPermanence*/ 0.21, - // /*connectedPermanence*/ 0.50, - // /*minThreshold*/ 2, - // /*maxNewSynapseCount*/ 3, - // /*permanenceIncrement*/ 0.10, - // /*permanenceDecrement*/ 0.10, - // /*predictedSegmentDecrement*/ 0.0, - // /*seed*/ 42 - // ); - // - // serializationTestPrepare(tm); - // - // const char* filename = "TemporalMemorySerializationSave.tmp"; - // ofstream outfile; - // outfile.open(filename, ios::binary); - // tm.save(outfile); - // outfile.close(); - // } - - // TEST(TemporalMemoryTest, loadTestFile) - // { - // TemporalMemory tm; - // const char* filename = "TemporalMemorySerializationSave.tmp"; - // ifstream infile(filename, ios::binary); - // tm.load(infile); - // infile.close(); - // - // serializationTestVerify(tm); - // } - - // TEST(TemporalMemoryTest, writeTestFile) - // { - // TemporalMemory tm( - // /*columnDimensions*/ {32}, - // /*cellsPerColumn*/ 4, - // /*activationThreshold*/ 3, - // /*initialPermanence*/ 0.21, - // /*connectedPermanence*/ 0.50, - // /*minThreshold*/ 2, - // /*maxNewSynapseCount*/ 3, - // /*permanenceIncrement*/ 0.10, - // /*permanenceDecrement*/ 0.10, - // /*predictedSegmentDecrement*/ 0.0, - // /*seed*/ 42 - // ); - - // serializationTestPrepare(tm); - - // const char* filename = "TemporalMemorySerializationWrite.tmp"; - // ofstream outfile; - // outfile.open(filename, ios::binary); - // tm.write(outfile); - // outfile.close(); - // } - - // TEST(TemporalMemoryTest, readTestFile) - // { - // TemporalMemory tm; - // const char* filename = "TemporalMemorySerializationWrite.tmp"; - // ifstream infile(filename, ios::binary); - // tm.read(infile); - // infile.close(); - - // serializationTestVerify(tm); - // } -} // end namespace nupic + /*seed*/ 42); + + serializationTestPrepare(tm1); + + // Write and read back the proto + stringstream ss; + tm1.write(ss); + + TemporalMemory tm2; + tm2.read(ss); + + ASSERT_TRUE(tm1 == tm2); + + serializationTestVerify(tm2); +} + +// Uncomment these tests individually to save/load from a file. +// This is useful for ad-hoc testing of backwards-compatibility. 
+ +// TEST(TemporalMemoryTest, saveTestFile) +// { +// TemporalMemory tm( +// /*columnDimensions*/ {32}, +// /*cellsPerColumn*/ 4, +// /*activationThreshold*/ 3, +// /*initialPermanence*/ 0.21, +// /*connectedPermanence*/ 0.50, +// /*minThreshold*/ 2, +// /*maxNewSynapseCount*/ 3, +// /*permanenceIncrement*/ 0.10, +// /*permanenceDecrement*/ 0.10, +// /*predictedSegmentDecrement*/ 0.0, +// /*seed*/ 42 +// ); +// +// serializationTestPrepare(tm); +// +// const char* filename = "TemporalMemorySerializationSave.tmp"; +// ofstream outfile; +// outfile.open(filename, ios::binary); +// tm.save(outfile); +// outfile.close(); +// } + +// TEST(TemporalMemoryTest, loadTestFile) +// { +// TemporalMemory tm; +// const char* filename = "TemporalMemorySerializationSave.tmp"; +// ifstream infile(filename, ios::binary); +// tm.load(infile); +// infile.close(); +// +// serializationTestVerify(tm); +// } + +// TEST(TemporalMemoryTest, writeTestFile) +// { +// TemporalMemory tm( +// /*columnDimensions*/ {32}, +// /*cellsPerColumn*/ 4, +// /*activationThreshold*/ 3, +// /*initialPermanence*/ 0.21, +// /*connectedPermanence*/ 0.50, +// /*minThreshold*/ 2, +// /*maxNewSynapseCount*/ 3, +// /*permanenceIncrement*/ 0.10, +// /*permanenceDecrement*/ 0.10, +// /*predictedSegmentDecrement*/ 0.0, +// /*seed*/ 42 +// ); + +// serializationTestPrepare(tm); + +// const char* filename = "TemporalMemorySerializationWrite.tmp"; +// ofstream outfile; +// outfile.open(filename, ios::binary); +// tm.write(outfile); +// outfile.close(); +// } + +// TEST(TemporalMemoryTest, readTestFile) +// { +// TemporalMemory tm; +// const char* filename = "TemporalMemorySerializationWrite.tmp"; +// ifstream infile(filename, ios::binary); +// tm.read(infile); +// infile.close(); + +// serializationTestVerify(tm); +// } +} // namespace diff --git a/src/test/unit/encoders/ScalarEncoderTest.cpp b/src/test/unit/encoders/ScalarEncoderTest.cpp index bcc8142193..bb4289a095 100644 --- a/src/test/unit/encoders/ScalarEncoderTest.cpp +++ b/src/test/unit/encoders/ScalarEncoderTest.cpp @@ -24,64 +24,55 @@ * Unit tests for the ScalarEncoder and PeriodicScalarEncoder */ +#include "gtest/gtest.h" +#include #include #include -#include -#include "gtest/gtest.h" using namespace nupic; -template -std::string vec2str(std::vector vec) -{ +template std::string vec2str(std::vector vec) { std::ostringstream oss(""); for (size_t i = 0; i < vec.size(); i++) oss << vec[i]; return oss.str(); } -std::vector getEncoding(ScalarEncoderBase& e, Real64 input) -{ +std::vector getEncoding(ScalarEncoderBase &e, Real64 input) { auto actualOutput = std::vector(e.getOutputWidth()); e.encodeIntoArray(input, &actualOutput[0]); return actualOutput; } -struct ScalarValueCase -{ +struct ScalarValueCase { Real64 input; std::vector expectedOutput; }; -std::vector patternFromNZ(int n, std::vector patternNZ) -{ +std::vector patternFromNZ(int n, std::vector patternNZ) { auto v = std::vector(n, 0); - for (auto it = patternNZ.begin(); it != patternNZ.end(); it++) - { - v[*it] = 1; - } + for (auto it = patternNZ.begin(); it != patternNZ.end(); it++) { + v[*it] = 1; + } return v; } -void doScalarValueCases(ScalarEncoderBase& e, std::vector cases) -{ - for (auto c = cases.begin(); c != cases.end(); c++) - { - auto actualOutput = getEncoding(e, c->input); - for (int i = 0; i < e.getOutputWidth(); i++) - { - EXPECT_EQ(c->expectedOutput[i], actualOutput[i]) - << "For input " << c->input << " and index " << i << std::endl - << "EXPECTED:" << std::endl - << vec2str(c->expectedOutput) << std::endl - << 
"ACTUAL:" << std::endl - << vec2str(actualOutput); - } +void doScalarValueCases(ScalarEncoderBase &e, + std::vector cases) { + for (auto c = cases.begin(); c != cases.end(); c++) { + auto actualOutput = getEncoding(e, c->input); + for (int i = 0; i < e.getOutputWidth(); i++) { + EXPECT_EQ(c->expectedOutput[i], actualOutput[i]) + << "For input " << c->input << " and index " << i << std::endl + << "EXPECTED:" << std::endl + << vec2str(c->expectedOutput) << std::endl + << "ACTUAL:" << std::endl + << vec2str(actualOutput); } + } } -TEST(ScalarEncoder, ValidScalarInputs) -{ +TEST(ScalarEncoder, ValidScalarInputs) { const int n = 10; const int w = 2; const double minValue = 10; @@ -91,7 +82,8 @@ TEST(ScalarEncoder, ValidScalarInputs) { const bool clipInput = false; - ScalarEncoder encoder(w, minValue, maxValue, n, radius, resolution, clipInput); + ScalarEncoder encoder(w, minValue, maxValue, n, radius, resolution, + clipInput); EXPECT_THROW(getEncoding(encoder, 9.9), std::exception); EXPECT_NO_THROW(getEncoding(encoder, 10.0)); @@ -101,15 +93,15 @@ TEST(ScalarEncoder, ValidScalarInputs) { const bool clipInput = true; - ScalarEncoder encoder(w, minValue, maxValue, n, radius, resolution, clipInput); + ScalarEncoder encoder(w, minValue, maxValue, n, radius, resolution, + clipInput); EXPECT_NO_THROW(getEncoding(encoder, 9.9)); EXPECT_NO_THROW(getEncoding(encoder, 20.1)); } } -TEST(PeriodicScalarEncoder, ValidScalarInputs) -{ +TEST(PeriodicScalarEncoder, ValidScalarInputs) { const int n = 10; const int w = 2; const double minValue = 10; @@ -124,8 +116,7 @@ TEST(PeriodicScalarEncoder, ValidScalarInputs) EXPECT_THROW(getEncoding(encoder, 20.0), std::exception); } -TEST(ScalarEncoder, NonIntegerBucketWidth) -{ +TEST(ScalarEncoder, NonIntegerBucketWidth) { const int n = 7; const int w = 3; const double minValue = 10; @@ -133,17 +124,16 @@ TEST(ScalarEncoder, NonIntegerBucketWidth) const double radius = 0; const double resolution = 0; const bool clipInput = false; - ScalarEncoder encoder(w, minValue, maxValue, n, radius, resolution, clipInput); + ScalarEncoder encoder(w, minValue, maxValue, n, radius, resolution, + clipInput); - std::vector cases = - {{10.0, patternFromNZ(n, {0, 1, 2})}, - {20.0, patternFromNZ(n, {4, 5, 6})}}; + std::vector cases = {{10.0, patternFromNZ(n, {0, 1, 2})}, + {20.0, patternFromNZ(n, {4, 5, 6})}}; doScalarValueCases(encoder, cases); } -TEST(PeriodicScalarEncoder, NonIntegerBucketWidth) -{ +TEST(PeriodicScalarEncoder, NonIntegerBucketWidth) { const int n = 7; const int w = 3; const double minValue = 10; @@ -152,15 +142,13 @@ TEST(PeriodicScalarEncoder, NonIntegerBucketWidth) const double resolution = 0; PeriodicScalarEncoder encoder(w, minValue, maxValue, n, radius, resolution); - std::vector cases = - {{10.0, patternFromNZ(n, {6, 0, 1})}, - {19.9, patternFromNZ(n, {5, 6, 0})}}; + std::vector cases = {{10.0, patternFromNZ(n, {6, 0, 1})}, + {19.9, patternFromNZ(n, {5, 6, 0})}}; doScalarValueCases(encoder, cases); } -TEST(ScalarEncoder, RoundToNearestMultipleOfResolution) -{ +TEST(ScalarEncoder, RoundToNearestMultipleOfResolution) { const int n_in = 0; const int w = 3; const double minValue = 10; @@ -168,53 +156,53 @@ TEST(ScalarEncoder, RoundToNearestMultipleOfResolution) const double radius = 0; const double resolution = 1; const bool clipInput = false; - ScalarEncoder encoder(w, minValue, maxValue, n_in, radius, resolution, clipInput); + ScalarEncoder encoder(w, minValue, maxValue, n_in, radius, resolution, + clipInput); const int n = 13; ASSERT_EQ(n, encoder.getOutputWidth()); 
- std::vector cases = - {{10.00, patternFromNZ(n, {0, 1, 2})}, - {10.49, patternFromNZ(n, {0, 1, 2})}, - {10.50, patternFromNZ(n, {1, 2, 3})}, - {11.49, patternFromNZ(n, {1, 2, 3})}, - {11.50, patternFromNZ(n, {2, 3, 4})}, - {14.49, patternFromNZ(n, {4, 5, 6})}, - {14.50, patternFromNZ(n, {5, 6, 7})}, - {15.49, patternFromNZ(n, {5, 6, 7})}, - {15.50, patternFromNZ(n, {6, 7, 8})}, - {19.49, patternFromNZ(n, {9, 10, 11})}, - {19.50, patternFromNZ(n, {10, 11, 12})}, - {20.00, patternFromNZ(n, {10, 11, 12})}}; + std::vector cases = { + {10.00, patternFromNZ(n, {0, 1, 2})}, + {10.49, patternFromNZ(n, {0, 1, 2})}, + {10.50, patternFromNZ(n, {1, 2, 3})}, + {11.49, patternFromNZ(n, {1, 2, 3})}, + {11.50, patternFromNZ(n, {2, 3, 4})}, + {14.49, patternFromNZ(n, {4, 5, 6})}, + {14.50, patternFromNZ(n, {5, 6, 7})}, + {15.49, patternFromNZ(n, {5, 6, 7})}, + {15.50, patternFromNZ(n, {6, 7, 8})}, + {19.49, patternFromNZ(n, {9, 10, 11})}, + {19.50, patternFromNZ(n, {10, 11, 12})}, + {20.00, patternFromNZ(n, {10, 11, 12})}}; doScalarValueCases(encoder, cases); } -TEST(PeriodicScalarEncoder, FloorToNearestMultipleOfResolution) -{ +TEST(PeriodicScalarEncoder, FloorToNearestMultipleOfResolution) { const int n_in = 0; const int w = 3; const double minValue = 10; const double maxValue = 20; const double radius = 0; const double resolution = 1; - PeriodicScalarEncoder encoder(w, minValue, maxValue, n_in, radius, resolution); + PeriodicScalarEncoder encoder(w, minValue, maxValue, n_in, radius, + resolution); const int n = 10; ASSERT_EQ(n, encoder.getOutputWidth()); - std::vector cases = - {{10.00, patternFromNZ(n, {9, 0, 1})}, - {10.99, patternFromNZ(n, {9, 0, 1})}, - {11.00, patternFromNZ(n, {0, 1, 2})}, - {11.99, patternFromNZ(n, {0, 1, 2})}, - {12.00, patternFromNZ(n, {1, 2, 3})}, - {14.00, patternFromNZ(n, {3, 4, 5})}, - {14.99, patternFromNZ(n, {3, 4, 5})}, - {15.00, patternFromNZ(n, {4, 5, 6})}, - {15.99, patternFromNZ(n, {4, 5, 6})}, - {19.00, patternFromNZ(n, {8, 9, 0})}, - {19.99, patternFromNZ(n, {8, 9, 0})}}; + std::vector cases = {{10.00, patternFromNZ(n, {9, 0, 1})}, + {10.99, patternFromNZ(n, {9, 0, 1})}, + {11.00, patternFromNZ(n, {0, 1, 2})}, + {11.99, patternFromNZ(n, {0, 1, 2})}, + {12.00, patternFromNZ(n, {1, 2, 3})}, + {14.00, patternFromNZ(n, {3, 4, 5})}, + {14.99, patternFromNZ(n, {3, 4, 5})}, + {15.00, patternFromNZ(n, {4, 5, 6})}, + {15.99, patternFromNZ(n, {4, 5, 6})}, + {19.00, patternFromNZ(n, {8, 9, 0})}, + {19.99, patternFromNZ(n, {8, 9, 0})}}; doScalarValueCases(encoder, cases); } diff --git a/src/test/unit/engine/InputTest.cpp b/src/test/unit/engine/InputTest.cpp index ba1c851a10..e171164299 100644 --- a/src/test/unit/engine/InputTest.cpp +++ b/src/test/unit/engine/InputTest.cpp @@ -24,41 +24,40 @@ * Implementation of Input test */ +#include "gtest/gtest.h" #include #include #include #include #include #include -#include "gtest/gtest.h" using namespace nupic; -TEST(InputTest, BasicNetworkConstruction) -{ +TEST(InputTest, BasicNetworkConstruction) { Network net; - Region * r1 = net.addRegion("r1", "TestNode", ""); - Region * r2 = net.addRegion("r2", "TestNode", ""); + Region *r1 = net.addRegion("r1", "TestNode", ""); + Region *r2 = net.addRegion("r2", "TestNode", ""); - //Test constructor + // Test constructor Input x(*r1, NTA_BasicType_Int32, true); Input y(*r2, NTA_BasicType_Byte, false); EXPECT_THROW(Input i(*r1, (NTA_BasicType)(NTA_BasicType_Last + 1), true), std::exception); - //test getRegion() + // test getRegion() ASSERT_EQ(r1, &(x.getRegion())); ASSERT_EQ(r2, 
&(y.getRegion())); - //test isRegionLevel() + // test isRegionLevel() ASSERT_TRUE(x.isRegionLevel()); - ASSERT_TRUE(! y.isRegionLevel()); + ASSERT_TRUE(!y.isRegionLevel()); - //test isInitialized() - ASSERT_TRUE(! x.isInitialized()); - ASSERT_TRUE(! y.isInitialized()); + // test isInitialized() + ASSERT_TRUE(!x.isInitialized()); + ASSERT_TRUE(!y.isInitialized()); - //test one case of initialize() + // test one case of initialize() EXPECT_THROW(x.initialize(), std::exception); EXPECT_THROW(y.initialize(), std::exception); @@ -75,85 +74,83 @@ TEST(InputTest, BasicNetworkConstruction) x.initialize(); y.initialize(); - //test evaluateLinks() - //should return 0 because x is initialized + // test evaluateLinks() + // should return 0 because x is initialized ASSERT_EQ(0u, x.evaluateLinks()); - //should return 0 because there are no links + // should return 0 because there are no links ASSERT_EQ(0u, y.evaluateLinks()); - //test getData() - const ArrayBase * pa = &(y.getData()); + // test getData() + const ArrayBase *pa = &(y.getData()); ASSERT_EQ(0u, pa->getCount()); - Real64* buf = (Real64*)(pa->getBuffer()); + Real64 *buf = (Real64 *)(pa->getBuffer()); ASSERT_TRUE(buf != nullptr); } - -TEST(InputTest, SplitterMap) -{ +TEST(InputTest, SplitterMap) { Network net; - Region * region1 = net.addRegion("region1", "TestNode", ""); - Region * region2 = net.addRegion("region2", "TestNode", ""); + Region *region1 = net.addRegion("region1", "TestNode", ""); + Region *region2 = net.addRegion("region2", "TestNode", ""); Dimensions d1; d1.push_back(8); d1.push_back(4); region1->setDimensions(d1); - //test addLink() indirectly - it is called by Network::link() + // test addLink() indirectly - it is called by Network::link() net.link("region1", "region2", "TestFanIn2", ""); - //test initialize(), which is called by net.initialize() + // test initialize(), which is called by net.initialize() net.initialize(); Dimensions d2 = region2->getDimensions(); - Input * in1 = region1->getInput("bottomUpIn"); - Input * in2 = region2->getInput("bottomUpIn"); - Output * out1 = region1->getOutput("bottomUpOut"); + Input *in1 = region1->getInput("bottomUpIn"); + Input *in2 = region2->getInput("bottomUpIn"); + Output *out1 = region1->getOutput("bottomUpOut"); - //test isInitialized() + // test isInitialized() ASSERT_TRUE(in1->isInitialized()); ASSERT_TRUE(in2->isInitialized()); - //test evaluateLinks(), in1 already initialized + // test evaluateLinks(), in1 already initialized ASSERT_EQ(0u, in1->evaluateLinks()); ASSERT_EQ(0u, in2->evaluateLinks()); - //test prepare + // test prepare { - //set in2 to all zeroes - const ArrayBase * ai2 = &(in2->getData()); - Real64* idata = (Real64*)(ai2->getBuffer()); + // set in2 to all zeroes + const ArrayBase *ai2 = &(in2->getData()); + Real64 *idata = (Real64 *)(ai2->getBuffer()); for (UInt i = 0; i < 64; i++) idata[i] = 0; - //set out1 to all 10's - const ArrayBase * ao1 = &(out1->getData()); - idata = (Real64*)(ao1->getBuffer()); + // set out1 to all 10's + const ArrayBase *ao1 = &(out1->getData()); + idata = (Real64 *)(ao1->getBuffer()); for (UInt i = 0; i < 64; i++) idata[i] = 10; - //confirm that in2 is still all zeroes + // confirm that in2 is still all zeroes ai2 = &(in2->getData()); - idata = (Real64*)(ai2->getBuffer()); - //only test 4 instead of 64 to cut down on number of tests + idata = (Real64 *)(ai2->getBuffer()); + // only test 4 instead of 64 to cut down on number of tests for (UInt i = 0; i < 4; i++) ASSERT_EQ(0, idata[i]); in2->prepare(); - //confirm that in2 is now all 
10's + // confirm that in2 is now all 10's ai2 = &(in2->getData()); - idata = (Real64*)(ai2->getBuffer()); - //only test 4 instead of 64 to cut down on number of tests + idata = (Real64 *)(ai2->getBuffer()); + // only test 4 instead of 64 to cut down on number of tests for (UInt i = 0; i < 4; i++) ASSERT_EQ(10, idata[i]); } net.run(2); - //test getSplitterMap() - std::vector< std::vector > sm; + // test getSplitterMap() + std::vector> sm; sm = in2->getSplitterMap(); ASSERT_EQ(8u, sm.size()); ASSERT_EQ(8u, sm[0].size()); @@ -161,7 +158,7 @@ TEST(InputTest, SplitterMap) ASSERT_EQ(12u, sm[3][0]); ASSERT_EQ(31u, sm[3][7]); - //test getInputForNode() + // test getInputForNode() std::vector input; in2->getInputForNode(0, input); ASSERT_EQ(1, input[0]); @@ -173,10 +170,10 @@ TEST(InputTest, SplitterMap) ASSERT_EQ(6, input[1]); ASSERT_EQ(15, input[7]); - //test getData() - const ArrayBase * pa = &(in2->getData()); + // test getData() + const ArrayBase *pa = &(in2->getData()); ASSERT_EQ(64u, pa->getCount()); - Real64* data = (Real64*)(pa->getBuffer()); + Real64 *data = (Real64 *)(pa->getBuffer()); ASSERT_EQ(1, data[0]); ASSERT_EQ(0, data[1]); ASSERT_EQ(1, data[30]); @@ -184,12 +181,11 @@ TEST(InputTest, SplitterMap) ASSERT_EQ(31, data[63]); } -TEST(InputTest, LinkTwoRegionsOneInput) -{ +TEST(InputTest, LinkTwoRegionsOneInput) { Network net; - Region * region1 = net.addRegion("region1", "TestNode", ""); - Region * region2 = net.addRegion("region2", "TestNode", ""); - Region * region3 = net.addRegion("region3", "TestNode", ""); + Region *region1 = net.addRegion("region1", "TestNode", ""); + Region *region2 = net.addRegion("region2", "TestNode", ""); + Region *region3 = net.addRegion("region3", "TestNode", ""); Dimensions d1; d1.push_back(8); @@ -203,7 +199,7 @@ TEST(InputTest, LinkTwoRegionsOneInput) net.initialize(); Dimensions d3 = region3->getDimensions(); - Input * in3 = region3->getInput("bottomUpIn"); + Input *in3 = region3->getInput("bottomUpIn"); ASSERT_EQ(2u, d3.size()); ASSERT_EQ(4u, d3[0]); @@ -211,8 +207,8 @@ TEST(InputTest, LinkTwoRegionsOneInput) net.run(2); - //test getSplitterMap() - std::vector< std::vector > sm; + // test getSplitterMap() + std::vector> sm; sm = in3->getSplitterMap(); ASSERT_EQ(8u, sm.size()); ASSERT_EQ(16u, sm[0].size()); @@ -220,7 +216,7 @@ TEST(InputTest, LinkTwoRegionsOneInput) ASSERT_EQ(12u, sm[3][0]); ASSERT_EQ(31u, sm[3][7]); - //test getInputForNode() + // test getInputForNode() std::vector input; in3->getInputForNode(0, input); ASSERT_EQ(1, input[0]); @@ -232,10 +228,10 @@ TEST(InputTest, LinkTwoRegionsOneInput) ASSERT_EQ(6, input[1]); ASSERT_EQ(15, input[7]); - //test getData() - const ArrayBase * pa = &(in3->getData()); + // test getData() + const ArrayBase *pa = &(in3->getData()); ASSERT_EQ(128u, pa->getCount()); - Real64* data = (Real64*)(pa->getBuffer()); + Real64 *data = (Real64 *)(pa->getBuffer()); ASSERT_EQ(1, data[0]); ASSERT_EQ(0, data[1]); ASSERT_EQ(1, data[30]); @@ -246,5 +242,4 @@ TEST(InputTest, LinkTwoRegionsOneInput) ASSERT_EQ(1, data[94]); ASSERT_EQ(15, data[95]); ASSERT_EQ(31, data[127]); - } diff --git a/src/test/unit/engine/LinkTest.cpp b/src/test/unit/engine/LinkTest.cpp index 8c53b1f984..980de98da5 100644 --- a/src/test/unit/engine/LinkTest.cpp +++ b/src/test/unit/engine/LinkTest.cpp @@ -26,6 +26,7 @@ #include +#include "gtest/gtest.h" #include #include #include @@ -38,16 +39,13 @@ #include #include #include -#include "gtest/gtest.h" using namespace nupic; - -TEST(LinkTest, Links) -{ +TEST(LinkTest, Links) { Network net; - Region * 
region1 = net.addRegion("region1", "TestNode", ""); - Region * region2 = net.addRegion("region2", "TestNode", ""); + Region *region1 = net.addRegion("region1", "TestNode", ""); + Region *region2 = net.addRegion("region2", "TestNode", ""); Dimensions d1; d1.push_back(8); @@ -56,20 +54,20 @@ TEST(LinkTest, Links) net.link("region1", "region2", "TestFanIn2", ""); - //test initialize(), which is called by net.initialize() - //also test evaluateLinks() which is called here + // test initialize(), which is called by net.initialize() + // also test evaluateLinks() which is called here net.initialize(); net.run(1); - //test that region has correct induced dimensions + // test that region has correct induced dimensions Dimensions d2 = region2->getDimensions(); ASSERT_EQ(2u, d2.size()); ASSERT_EQ(4u, d2[0]); ASSERT_EQ(2u, d2[1]); - //test getName() and setName() - Input * in1 = region1->getInput("bottomUpIn"); - Input * in2 = region2->getInput("bottomUpIn"); + // test getName() and setName() + Input *in1 = region1->getInput("bottomUpIn"); + Input *in2 = region2->getInput("bottomUpIn"); EXPECT_STREQ("bottomUpIn", in1->getName().c_str()); EXPECT_STREQ("bottomUpIn", in2->getName().c_str()); @@ -77,69 +75,57 @@ TEST(LinkTest, Links) EXPECT_STREQ("uselessName", in1->getName().c_str()); in1->setName("bottomUpIn"); - //test isInitialized() + // test isInitialized() ASSERT_TRUE(in1->isInitialized()); ASSERT_TRUE(in2->isInitialized()); - //test getLinks() - std::vector links = in2->getLinks(); + // test getLinks() + std::vector links = in2->getLinks(); ASSERT_EQ(1u, links.size()); - for(auto & link : links) { - //do something to make sure l[i] is a valid Link* + for (auto &link : links) { + // do something to make sure l[i] is a valid Link* ASSERT_TRUE(link != nullptr); - //should fail because regions are initialized + // should fail because regions are initialized EXPECT_THROW(in2->removeLink(link), std::exception); } - //test findLink() - Link * l1 = in1->findLink("region1", "bottomUpOut"); + // test findLink() + Link *l1 = in1->findLink("region1", "bottomUpOut"); ASSERT_TRUE(l1 == nullptr); - Link * l2 = in2->findLink("region1", "bottomUpOut"); + Link *l2 = in2->findLink("region1", "bottomUpOut"); ASSERT_TRUE(l2 != nullptr); - - //test removeLink(), uninitialize() - //uninitialize() is called internally from removeLink() + // test removeLink(), uninitialize() + // uninitialize() is called internally from removeLink() { - //can't remove link b/c region1 initialized + // can't remove link b/c region1 initialized EXPECT_THROW(in2->removeLink(l2), std::exception); - //can't remove region b/c region1 has links + // can't remove region b/c region1 has links EXPECT_THROW(net.removeRegion("region1"), std::exception); region1->uninitialize(); region2->uninitialize(); EXPECT_THROW(in1->removeLink(l2), std::exception); in2->removeLink(l2); EXPECT_THROW(in2->removeLink(l2), std::exception); - //l1 == NULL + // l1 == NULL EXPECT_THROW(in1->removeLink(l1), std::exception); } } - -TEST(LinkTest, DelayedLink) -{ - class MyTestNode : public TestNode - { +TEST(LinkTest, DelayedLink) { + class MyTestNode : public TestNode { public: - MyTestNode(const ValueMap& params, Region *region) - : TestNode(params, region) - {} + MyTestNode(const ValueMap ¶ms, Region *region) + : TestNode(params, region) {} - MyTestNode(BundleIO& bundle, Region* region) - : TestNode(bundle, region) - {} + MyTestNode(BundleIO &bundle, Region *region) : TestNode(bundle, region) {} - MyTestNode(capnp::AnyPointer::Reader& proto, Region* region) - : 
TestNode(proto, region) - {} + MyTestNode(capnp::AnyPointer::Reader &proto, Region *region) + : TestNode(proto, region) {} - std::string getNodeType() - { - return "MyTestNode"; - } + std::string getNodeType() { return "MyTestNode"; } - void compute() override - { + void compute() override { // Replace with no-op to preserve output } }; @@ -148,8 +134,8 @@ TEST(LinkTest, DelayedLink) new RegisteredRegionImpl()); Network net; - Region * region1 = net.addRegion("region1", "MyTestNode", ""); - Region * region2 = net.addRegion("region2", "TestNode", ""); + Region *region1 = net.addRegion("region1", "MyTestNode", ""); + Region *region2 = net.addRegion("region2", "TestNode", ""); RegionImplFactory::unregisterCPPRegion("MyTestNode"); @@ -160,35 +146,35 @@ TEST(LinkTest, DelayedLink) // NOTE: initial delayed values are set to all 0's net.link("region1", "region2", "TestFanIn2", "", "", "", - 2/*propagationDelay*/); + 2 /*propagationDelay*/); - //test initialize(), which is called by net.initialize() + // test initialize(), which is called by net.initialize() net.initialize(); - Input * in1 = region1->getInput("bottomUpIn"); - Input * in2 = region2->getInput("bottomUpIn"); - Output * out1 = region1->getOutput("bottomUpOut"); + Input *in1 = region1->getInput("bottomUpIn"); + Input *in2 = region2->getInput("bottomUpIn"); + Output *out1 = region1->getOutput("bottomUpOut"); - //test isInitialized() + // test isInitialized() ASSERT_TRUE(in1->isInitialized()); ASSERT_TRUE(in2->isInitialized()); - //test evaluateLinks(), in1 already initialized + // test evaluateLinks(), in1 already initialized ASSERT_EQ(0u, in1->evaluateLinks()); ASSERT_EQ(0u, in2->evaluateLinks()); - //set in2 to all 1's, to detect if net.run fails to update the input. + // set in2 to all 1's, to detect if net.run fails to update the input. 
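// The assertions below depend only on the FIFO behaviour of a delayed link:
// with propagationDelay = 2 the link is primed with two all-zero buffers, so
// the destination sees zeros for two runs, then the 10's, then the 100's.
// A minimal, self-contained illustration of that invariant follows;
// DelayLine is a hypothetical type, not the NuPIC Link/Input implementation.
#include <cstddef>
#include <deque>
#include <vector>

struct DelayLine {
  // delay = number of runs a value is held back; width = buffer length
  DelayLine(std::size_t delay, std::size_t width)
      : queue_(delay, std::vector<double>(width, 0.0)) {} // primed with zeros

  // Push the value the source produced this run; return what the destination
  // input observes, i.e. the value produced `delay` runs ago.
  std::vector<double> step(const std::vector<double> &srcOut) {
    queue_.push_back(srcOut);
    std::vector<double> dstIn = queue_.front();
    queue_.pop_front();
    return dstIn;
  }

  std::deque<std::vector<double>> queue_;
};
// With delay = 2: step(10's) -> zeros, step(100's) -> zeros,
// step(100's) -> the 10's, step(100's) -> the 100's, matching the runs below.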
{ - const ArrayBase * ai2 = &(in2->getData()); - Real64* idata = (Real64*)(ai2->getBuffer()); + const ArrayBase *ai2 = &(in2->getData()); + Real64 *idata = (Real64 *)(ai2->getBuffer()); for (UInt i = 0; i < 64; i++) idata[i] = 1; } - //set out1 to all 10's + // set out1 to all 10's { - const ArrayBase * ao1 = &(out1->getData()); - Real64* idata = (Real64*)(ao1->getBuffer()); + const ArrayBase *ao1 = &(out1->getData()); + Real64 *idata = (Real64 *)(ao1->getBuffer()); for (UInt i = 0; i < 64; i++) idata[i] = 10; } @@ -198,32 +184,30 @@ TEST(LinkTest, DelayedLink) // This run should also pick up the 10s net.run(1); - //confirm that in2 is all zeroes - const ArrayBase * ai2 = &(in2->getData()); - Real64* idata = (Real64*)(ai2->getBuffer()); - //only test 4 instead of 64 to cut down on number of tests + // confirm that in2 is all zeroes + const ArrayBase *ai2 = &(in2->getData()); + Real64 *idata = (Real64 *)(ai2->getBuffer()); + // only test 4 instead of 64 to cut down on number of tests for (UInt i = 0; i < 4; i++) ASSERT_EQ(0, idata[i]); } - - //set out1 to all 100's + // set out1 to all 100's { - const ArrayBase * ao1 = &(out1->getData()); - Real64* idata = (Real64*)(ao1->getBuffer()); + const ArrayBase *ao1 = &(out1->getData()); + Real64 *idata = (Real64 *)(ao1->getBuffer()); for (UInt i = 0; i < 64; i++) idata[i] = 100; } - // Check extraction of second delayed value { net.run(1); - //confirm that in2 is all zeroes - const ArrayBase * ai2 = &(in2->getData()); - Real64* idata = (Real64*)(ai2->getBuffer()); - //only test 4 instead of 64 to cut down on number of tests + // confirm that in2 is all zeroes + const ArrayBase *ai2 = &(in2->getData()); + Real64 *idata = (Real64 *)(ai2->getBuffer()); + // only test 4 instead of 64 to cut down on number of tests for (UInt i = 0; i < 4; i++) ASSERT_EQ(0, idata[i]); } @@ -232,10 +216,10 @@ TEST(LinkTest, DelayedLink) { net.run(1); - //confirm that in2 is now all 10's - const ArrayBase * ai2 = &(in2->getData()); - Real64* idata = (Real64*)(ai2->getBuffer()); - //only test 4 instead of 64 to cut down on number of tests + // confirm that in2 is now all 10's + const ArrayBase *ai2 = &(in2->getData()); + Real64 *idata = (Real64 *)(ai2->getBuffer()); + // only test 4 instead of 64 to cut down on number of tests for (UInt i = 0; i < 4; i++) ASSERT_EQ(10, idata[i]); } @@ -244,43 +228,31 @@ TEST(LinkTest, DelayedLink) { net.run(1); - //confirm that in2 is now all 100's - const ArrayBase * ai2 = &(in2->getData()); - Real64* idata = (Real64*)(ai2->getBuffer()); - //only test 4 instead of 64 to cut down on number of tests + // confirm that in2 is now all 100's + const ArrayBase *ai2 = &(in2->getData()); + Real64 *idata = (Real64 *)(ai2->getBuffer()); + // only test 4 instead of 64 to cut down on number of tests for (UInt i = 0; i < 4; i++) ASSERT_EQ(100, idata[i]); } - } - -TEST(LinkTest, DelayedLinkCapnpSerialization) -{ +TEST(LinkTest, DelayedLinkCapnpSerialization) { // Cap'n Proto serialization test of delayed link. 
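// The property this test checks is that serializing the Network (done here
// with Cap'n Proto) captures the delayed link's pending buffers -- the 10's
// and then the 100's -- so the deserialized network replays them in order.
// A schema-free sketch of just that invariant, using an iostream round trip
// (illustrative only; roundTripPending is a hypothetical helper):
#include <cassert>
#include <deque>
#include <sstream>

inline void roundTripPending() {
  std::deque<double> pending = {10.0, 100.0}; // oldest pending value first

  std::stringstream archive;                  // "serialize"
  for (double v : pending)
    archive << v << ' ';

  std::deque<double> restored;                // "deserialize"
  for (double v; archive >> v;)
    restored.push_back(v);

  assert(restored == pending);          // order of pending values survives
  assert(restored.front() == 10.0);     // first post-load run yields the 10's
}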
- class MyTestNode : public TestNode - { + class MyTestNode : public TestNode { public: - MyTestNode(const ValueMap& params, Region *region) - : TestNode(params, region) - {} + MyTestNode(const ValueMap ¶ms, Region *region) + : TestNode(params, region) {} - MyTestNode(BundleIO& bundle, Region* region) - : TestNode(bundle, region) - {} + MyTestNode(BundleIO &bundle, Region *region) : TestNode(bundle, region) {} - MyTestNode(capnp::AnyPointer::Reader& proto, Region* region) - : TestNode(proto, region) - {} + MyTestNode(capnp::AnyPointer::Reader &proto, Region *region) + : TestNode(proto, region) {} - std::string getNodeType() - { - return "MyTestNode"; - }; + std::string getNodeType() { return "MyTestNode"; }; - void compute() override - { + void compute() override { // Replace with no-op to preserve output } }; @@ -289,8 +261,8 @@ TEST(LinkTest, DelayedLinkCapnpSerialization) new RegisteredRegionImpl()); Network net; - Region * region1 = net.addRegion("region1", "MyTestNode", ""); - Region * region2 = net.addRegion("region2", "TestNode", ""); + Region *region1 = net.addRegion("region1", "MyTestNode", ""); + Region *region2 = net.addRegion("region2", "TestNode", ""); Dimensions d1; d1.push_back(8); @@ -299,34 +271,34 @@ TEST(LinkTest, DelayedLinkCapnpSerialization) // NOTE: initial delayed values are set to all 0's net.link("region1", "region2", "TestFanIn2", "", "", "", - 2/*propagationDelay*/); + 2 /*propagationDelay*/); net.initialize(); - Input * in1 = region1->getInput("bottomUpIn"); - Input * in2 = region2->getInput("bottomUpIn"); - Output * out1 = region1->getOutput("bottomUpOut"); + Input *in1 = region1->getInput("bottomUpIn"); + Input *in2 = region2->getInput("bottomUpIn"); + Output *out1 = region1->getOutput("bottomUpOut"); - //test isInitialized() + // test isInitialized() ASSERT_TRUE(in1->isInitialized()); ASSERT_TRUE(in2->isInitialized()); - //test evaluateLinks(), in1 already initialized + // test evaluateLinks(), in1 already initialized ASSERT_EQ(0u, in1->evaluateLinks()); ASSERT_EQ(0u, in2->evaluateLinks()); - //set in2 to all 1's, to detect if net.run fails to update the input. + // set in2 to all 1's, to detect if net.run fails to update the input. 
{ - const ArrayBase * ai2 = &(in2->getData()); - Real64* idata = (Real64*)(ai2->getBuffer()); + const ArrayBase *ai2 = &(in2->getData()); + Real64 *idata = (Real64 *)(ai2->getBuffer()); for (UInt i = 0; i < 64; i++) idata[i] = 1; } - //set out1 to all 10's + // set out1 to all 10's { - const ArrayBase * ao1 = &(out1->getData()); - Real64* idata = (Real64*)(ao1->getBuffer()); + const ArrayBase *ao1 = &(out1->getData()); + Real64 *idata = (Real64 *)(ao1->getBuffer()); for (UInt i = 0; i < 64; i++) idata[i] = 10; } @@ -336,37 +308,34 @@ TEST(LinkTest, DelayedLinkCapnpSerialization) // This run should also pick up the 10s net.run(1); - //confirm that in2 is all zeroes - const ArrayBase * ai2 = &(in2->getData()); - Real64* idata = (Real64*)(ai2->getBuffer()); - //only test 4 instead of 64 to cut down on number of tests + // confirm that in2 is all zeroes + const ArrayBase *ai2 = &(in2->getData()); + Real64 *idata = (Real64 *)(ai2->getBuffer()); + // only test 4 instead of 64 to cut down on number of tests for (UInt i = 0; i < 4; i++) ASSERT_EQ(0, idata[i]); } - - //set out1 to all 100's + // set out1 to all 100's { - const ArrayBase * ao1 = &(out1->getData()); - Real64* idata = (Real64*)(ao1->getBuffer()); + const ArrayBase *ao1 = &(out1->getData()); + Real64 *idata = (Real64 *)(ao1->getBuffer()); for (UInt i = 0; i < 64; i++) idata[i] = 100; } - // Check extraction of second delayed value { net.run(1); - //confirm that in2 is all zeroes - const ArrayBase * ai2 = &(in2->getData()); - Real64* idata = (Real64*)(ai2->getBuffer()); - //only test 4 instead of 64 to cut down on number of tests + // confirm that in2 is all zeroes + const ArrayBase *ai2 = &(in2->getData()); + Real64 *idata = (Real64 *)(ai2->getBuffer()); + // only test 4 instead of 64 to cut down on number of tests for (UInt i = 0; i < 4; i++) ASSERT_EQ(0, idata[i]); } - // We should now have two delayed array values: 10's and 100's // Serialize the current net @@ -384,15 +353,14 @@ TEST(LinkTest, DelayedLinkCapnpSerialization) in2 = region2->getInput("bottomUpIn"); - // Check extraction of first "generated" value { net2.run(1); - //confirm that in2 is now all 10's - const ArrayBase * ai2 = &(in2->getData()); - Real64* idata = (Real64*)(ai2->getBuffer()); - //only test 4 instead of 64 to cut down on number of tests + // confirm that in2 is now all 10's + const ArrayBase *ai2 = &(in2->getData()); + Real64 *idata = (Real64 *)(ai2->getBuffer()); + // only test 4 instead of 64 to cut down on number of tests for (UInt i = 0; i < 4; i++) ASSERT_EQ(10, idata[i]); } @@ -401,73 +369,53 @@ TEST(LinkTest, DelayedLinkCapnpSerialization) { net2.run(1); - //confirm that in2 is now all 100's - const ArrayBase * ai2 = &(in2->getData()); - Real64* idata = (Real64*)(ai2->getBuffer()); - //only test 4 instead of 64 to cut down on number of tests + // confirm that in2 is now all 100's + const ArrayBase *ai2 = &(in2->getData()); + Real64 *idata = (Real64 *)(ai2->getBuffer()); + // only test 4 instead of 64 to cut down on number of tests for (UInt i = 0; i < 4; i++) ASSERT_EQ(100, idata[i]); } - RegionImplFactory::unregisterCPPRegion("MyTestNode"); } - /** * Base class for region implementations in this test module. See also * L2TestRegion and L4TestRegion. 
*/ -class TestRegionBase: public RegionImpl -{ +class TestRegionBase : public RegionImpl { public: - TestRegionBase(const ValueMap& params, Region *region) : - RegionImpl(region) - { + TestRegionBase(const ValueMap ¶ms, Region *region) : RegionImpl(region) { outputElementCount_ = 1; } - TestRegionBase(BundleIO& bundle, Region* region) : - RegionImpl(region) - { - } + TestRegionBase(BundleIO &bundle, Region *region) : RegionImpl(region) {} - TestRegionBase(capnp::AnyPointer::Reader& proto, Region* region) : - RegionImpl(region) - { - } + TestRegionBase(capnp::AnyPointer::Reader &proto, Region *region) + : RegionImpl(region) {} - virtual ~TestRegionBase() - { - } + virtual ~TestRegionBase() {} // Serialize state. - void serialize(BundleIO& bundle) override - { - } + void serialize(BundleIO &bundle) override {} // De-serialize state. Must be called from deserializing constructor - void deserialize(BundleIO& bundle) override - { - } + void deserialize(BundleIO &bundle) override {} // Serialize state with capnp using RegionImpl::write; - void write(capnp::AnyPointer::Builder& anyProto) const override - { - } + void write(capnp::AnyPointer::Builder &anyProto) const override {} // Deserialize state from capnp. Must be called from deserializing // constructor. using RegionImpl::read; - void read(capnp::AnyPointer::Reader& anyProto) override - { - } + void read(capnp::AnyPointer::Reader &anyProto) override {} // Execute a command - std::string executeCommand(const std::vector& args, Int64 index) override - { + std::string executeCommand(const std::vector &args, + Int64 index) override { return ""; } @@ -475,13 +423,12 @@ class TestRegionBase: public RegionImpl // For per-region outputs, it is the total element count. // This method is called only for outputs whose size is not // specified in the nodespec. - size_t getNodeOutputElementCount(const std::string& outputName) override - { - if (outputName == "out") - { + size_t getNodeOutputElementCount(const std::string &outputName) override { + if (outputName == "out") { return outputElementCount_; } - NTA_THROW << "TestRegionBase::getOutputSize -- unknown output " << outputName; + NTA_THROW << "TestRegionBase::getOutputSize -- unknown output " + << outputName; } /** @@ -494,11 +441,8 @@ class TestRegionBase: public RegionImpl * @param index A node index. (-1) indicates a region-level parameter * */ - void getParameterFromBuffer(const std::string& name, - Int64 index, - IWriteBuffer& value) override - { - } + void getParameterFromBuffer(const std::string &name, Int64 index, + IWriteBuffer &value) override {} /** * Set a parameter from a read buffer. @@ -508,11 +452,8 @@ class TestRegionBase: public RegionImpl * * @param index A node index. 
(-1) indicates a region-level parameter */ - void setParameterFromBuffer(const std::string& name, - Int64 index, - IReadBuffer& value) override - { - } + void setParameterFromBuffer(const std::string &name, Int64 index, + IReadBuffer &value) override {} private: TestRegionBase(); @@ -521,76 +462,54 @@ class TestRegionBase: public RegionImpl UInt32 outputElementCount_; }; - /* * This region's output is computed as: feedForwardIn + lateralIn */ -class L2TestRegion: public TestRegionBase -{ +class L2TestRegion : public TestRegionBase { public: - L2TestRegion(const ValueMap& params, Region *region) : - TestRegionBase(params, region) - { - } + L2TestRegion(const ValueMap ¶ms, Region *region) + : TestRegionBase(params, region) {} - L2TestRegion(BundleIO& bundle, Region* region) : - TestRegionBase(bundle, region) - { - } + L2TestRegion(BundleIO &bundle, Region *region) + : TestRegionBase(bundle, region) {} - L2TestRegion(capnp::AnyPointer::Reader& proto, Region* region) : - TestRegionBase(proto, region) - { - } + L2TestRegion(capnp::AnyPointer::Reader &proto, Region *region) + : TestRegionBase(proto, region) {} - virtual ~L2TestRegion() - { - } + virtual ~L2TestRegion() {} - std::string getNodeType() - { - return "L2TestRegion"; - } + std::string getNodeType() { return "L2TestRegion"; } // Used by RegionImplFactory to create and cache // a nodespec. Ownership is transferred to the caller. - static Spec* createSpec() - { + static Spec *createSpec() { auto ns = new Spec; /* ----- inputs ------- */ - ns->inputs.add( - "feedForwardIn", - InputSpec( - "Feed-forward input for the node", - NTA_BasicType_UInt64, - 0, // count. omit? - true, // required? - false, // isRegionLevel, - false // isDefaultInput - )); - - ns->inputs.add( - "lateralIn", - InputSpec( - "Lateral input for the node", - NTA_BasicType_UInt64, - 0, // count. omit? - true, // required? - false, // isRegionLevel, - false // isDefaultInput - )); + ns->inputs.add("feedForwardIn", InputSpec("Feed-forward input for the node", + NTA_BasicType_UInt64, + 0, // count. omit? + true, // required? + false, // isRegionLevel, + false // isDefaultInput + )); + + ns->inputs.add("lateralIn", + InputSpec("Lateral input for the node", NTA_BasicType_UInt64, + 0, // count. omit? + true, // required? 
+ false, // isRegionLevel, + false // isDefaultInput + )); /* ----- outputs ------ */ - ns->outputs.add( - "out", - OutputSpec( - "Primary output for the node", - NTA_BasicType_UInt64, - 3, // 1st is output; 2nd is the given feedForwardIn; 3rd is lateralIn - false, // isRegionLevel - true // isDefaultOutput - )); + ns->outputs.add("out", OutputSpec("Primary output for the node", + NTA_BasicType_UInt64, + 3, // 1st is output; 2nd is the given + // feedForwardIn; 3rd is lateralIn + false, // isRegionLevel + true // isDefaultOutput + )); return ns; } @@ -599,8 +518,7 @@ class L2TestRegion: public TestRegionBase * Inputs/Outputs are made available in initialize() * It is always called after the constructor (or load from serialized state) */ - void initialize() override - { + void initialize() override { nodeCount_ = getDimensions().getCount(); out_ = getOutput("out"); feedForwardIn_ = getInput("feedForwardIn"); @@ -608,14 +526,13 @@ class L2TestRegion: public TestRegionBase } // Compute outputs from inputs and internal state - void compute() override - { + void compute() override { NTA_DEBUG << "> Computing: " << getName() << " <"; - const Array & outputArray = out_->getData(); + const Array &outputArray = out_->getData(); NTA_CHECK(outputArray.getCount() == 3); NTA_CHECK(outputArray.getType() == NTA_BasicType_UInt64); - UInt64 *baseOutputBuffer = (UInt64*)outputArray.getBuffer(); + UInt64 *baseOutputBuffer = (UInt64 *)outputArray.getBuffer(); std::vector ffInput; feedForwardIn_->getInputForNode(0, ffInput); @@ -651,81 +568,58 @@ class L2TestRegion: public TestRegionBase const Input *feedForwardIn_; const Input *lateralIn_; const Output *out_; - }; - -class L4TestRegion: public TestRegionBase -{ +class L4TestRegion : public TestRegionBase { public: /* * This region's output is computed as: k + feedbackIn */ - L4TestRegion(const ValueMap& params, Region *region) : - TestRegionBase(params, region), - k_(params.getScalarT("k")) - { - } + L4TestRegion(const ValueMap ¶ms, Region *region) + : TestRegionBase(params, region), k_(params.getScalarT("k")) {} - L4TestRegion(BundleIO& bundle, Region* region) : - TestRegionBase(bundle, region), - k_(0) - { - } + L4TestRegion(BundleIO &bundle, Region *region) + : TestRegionBase(bundle, region), k_(0) {} - L4TestRegion(capnp::AnyPointer::Reader& proto, Region* region) : - TestRegionBase(proto, region), - k_(0) - { - } + L4TestRegion(capnp::AnyPointer::Reader &proto, Region *region) + : TestRegionBase(proto, region), k_(0) {} - virtual ~L4TestRegion() - { - } + virtual ~L4TestRegion() {} - std::string getNodeType() - { - return "L4TestRegion"; - } + std::string getNodeType() { return "L4TestRegion"; } // Used by RegionImplFactory to create and cache // a nodespec. Ownership is transferred to the caller. - static Spec* createSpec() - { + static Spec *createSpec() { auto ns = new Spec; /* ---- parameters ------ */ ns->parameters.add( - "k", - ParameterSpec( - "Constant k value for output computation", // description - NTA_BasicType_UInt64, - 1, // elementCount - "", // constraints - "", // defaultValue - ParameterSpec::ReadWriteAccess)); + "k", + ParameterSpec("Constant k value for output computation", // description + NTA_BasicType_UInt64, + 1, // elementCount + "", // constraints + "", // defaultValue + ParameterSpec::ReadWriteAccess)); /* ----- inputs ------- */ - ns->inputs.add( - "feedbackIn", - InputSpec( - "Feedback input for the node", - NTA_BasicType_UInt64, - 0, // count. omit? - true, // required? 
- false, // isRegionLevel, - false // isDefaultInput - )); + ns->inputs.add("feedbackIn", InputSpec("Feedback input for the node", + NTA_BasicType_UInt64, + 0, // count. omit? + true, // required? + false, // isRegionLevel, + false // isDefaultInput + )); /* ----- outputs ------ */ ns->outputs.add( - "out", - OutputSpec( - "Primary output for the node", - NTA_BasicType_UInt64, - 2, // 2 elements: 1st is output; 2nd is the given feedbackIn value - false, // isRegionLevel - true // isDefaultOutput - )); + "out", + OutputSpec( + "Primary output for the node", NTA_BasicType_UInt64, + 2, // 2 elements: 1st is output; 2nd is the given feedbackIn value + false, // isRegionLevel + true // isDefaultOutput + )); return ns; } @@ -734,8 +628,7 @@ class L4TestRegion: public TestRegionBase * Inputs/Outputs are made available in initialize() * It is always called after the constructor (or load from serialized state) */ - void initialize() override - { + void initialize() override { nodeCount_ = getDimensions().getCount(); NTA_CHECK(nodeCount_ == 1); out_ = getOutput("out"); @@ -743,14 +636,13 @@ class L4TestRegion: public TestRegionBase } // Compute outputs from inputs and internal state - void compute() override - { + void compute() override { NTA_DEBUG << "> Computing: " << getName() << " <"; - const Array & outputArray = out_->getData(); + const Array &outputArray = out_->getData(); NTA_CHECK(outputArray.getCount() == 2); NTA_CHECK(outputArray.getType() == NTA_BasicType_UInt64); - UInt64 *baseOutputBuffer = (UInt64*)outputArray.getBuffer(); + UInt64 *baseOutputBuffer = (UInt64 *)outputArray.getBuffer(); std::vector nodeInput; feedbackIn_->getInputForNode(0, nodeInput); @@ -779,12 +671,9 @@ class L4TestRegion: public TestRegionBase // Input/output buffers for the whole region const Input *feedbackIn_; const Output *out_; - }; - -TEST(LinkTest, L2L4WithDelayedLinksAndPhases) -{ +TEST(LinkTest, L2L4WithDelayedLinksAndPhases) { // This test simulates a network with L2 and L4, structured as follows: // o R1/R2 ("L4") are in phase 1; R3/R4 ("L2") are in phase 2; // o feed-forward links with delay=0 from R1/R2 to R3/R4, respectively; @@ -793,16 +682,16 @@ TEST(LinkTest, L2L4WithDelayedLinksAndPhases) Network net; - RegionImplFactory::registerCPPRegion("L4TestRegion", - new RegisteredRegionImpl()); - Region * r1 = net.addRegion("R1", "L4TestRegion", "{\"k\": 1}"); - Region * r2 = net.addRegion("R2", "L4TestRegion", "{\"k\": 5}"); + RegionImplFactory::registerCPPRegion( + "L4TestRegion", new RegisteredRegionImpl()); + Region *r1 = net.addRegion("R1", "L4TestRegion", "{\"k\": 1}"); + Region *r2 = net.addRegion("R2", "L4TestRegion", "{\"k\": 5}"); RegionImplFactory::unregisterCPPRegion("L4TestRegion"); - RegionImplFactory::registerCPPRegion("L2TestRegion", - new RegisteredRegionImpl()); - Region * r3 = net.addRegion("R3", "L2TestRegion", ""); - Region * r4 = net.addRegion("R4", "L2TestRegion", ""); + RegionImplFactory::registerCPPRegion( + "L2TestRegion", new RegisteredRegionImpl()); + Region *r3 = net.addRegion("R3", "L2TestRegion", ""); + Region *r4 = net.addRegion("R4", "L2TestRegion", ""); RegionImplFactory::unregisterCPPRegion("L2TestRegion"); // NOTE Dimensions must be multiples of 2 @@ -828,76 +717,70 @@ TEST(LinkTest, L2L4WithDelayedLinksAndPhases) /* Link up the network */ // R1 output - net.link( - "R1", // srcName - "R3", // destName - "UniformLink", // linkType - "", // linkParams - "out", // srcOutput - "feedForwardIn", // destInput - 0 //propagationDelay + net.link("R1", // srcName + "R3", // 
destName + "UniformLink", // linkType + "", // linkParams + "out", // srcOutput + "feedForwardIn", // destInput + 0 // propagationDelay ); // R2 output - net.link( - "R2", // srcName - "R4", // destName - "UniformLink", // linkType - "", // linkParams - "out", // srcOutput - "feedForwardIn", // destInput - 0 //propagationDelay + net.link("R2", // srcName + "R4", // destName + "UniformLink", // linkType + "", // linkParams + "out", // srcOutput + "feedForwardIn", // destInput + 0 // propagationDelay ); // R3 outputs - net.link( - "R3", // srcName - "R1", // destName - "UniformLink", // linkType - "", // linkParams - "out", // srcOutput - "feedbackIn", // destInput - 1 //propagationDelay + net.link("R3", // srcName + "R1", // destName + "UniformLink", // linkType + "", // linkParams + "out", // srcOutput + "feedbackIn", // destInput + 1 // propagationDelay ); - net.link( - "R3", // srcName - "R4", // destName - "UniformLink", // linkType - "", // linkParams - "out", // srcOutput - "lateralIn", // destInput - 1 //propagationDelay + net.link("R3", // srcName + "R4", // destName + "UniformLink", // linkType + "", // linkParams + "out", // srcOutput + "lateralIn", // destInput + 1 // propagationDelay ); // R4 outputs - net.link( - "R4", // srcName - "R2", // destName - "UniformLink", // linkType - "", // linkParams - "out", // srcOutput - "feedbackIn", // destInput - 1 //propagationDelay + net.link("R4", // srcName + "R2", // destName + "UniformLink", // linkType + "", // linkParams + "out", // srcOutput + "feedbackIn", // destInput + 1 // propagationDelay ); - net.link( - "R4", // srcName - "R3", // destName - "UniformLink", // linkType - "", // linkParams - "out", // srcOutput - "lateralIn", // destInput - 1 //propagationDelay + net.link("R4", // srcName + "R3", // destName + "UniformLink", // linkType + "", // linkParams + "out", // srcOutput + "lateralIn", // destInput + 1 // propagationDelay ); // Initialize the network net.initialize(); - UInt64* r1OutBuf = (UInt64*)(r1->getOutput("out")->getData().getBuffer()); - UInt64* r2OutBuf = (UInt64*)(r2->getOutput("out")->getData().getBuffer()); - UInt64* r3OutBuf = (UInt64*)(r3->getOutput("out")->getData().getBuffer()); - UInt64* r4OutBuf = (UInt64*)(r4->getOutput("out")->getData().getBuffer()); + UInt64 *r1OutBuf = (UInt64 *)(r1->getOutput("out")->getData().getBuffer()); + UInt64 *r2OutBuf = (UInt64 *)(r2->getOutput("out")->getData().getBuffer()); + UInt64 *r3OutBuf = (UInt64 *)(r3->getOutput("out")->getData().getBuffer()); + UInt64 *r4OutBuf = (UInt64 *)(r4->getOutput("out")->getData().getBuffer()); /* ITERATION #1 */ net.run(1); @@ -920,7 +803,6 @@ TEST(LinkTest, L2L4WithDelayedLinksAndPhases) ASSERT_EQ(0u, r4OutBuf[2]); // lateralIn from R3; delay=1 ASSERT_EQ(5u, r4OutBuf[0]); // out (feedForwardIn + lateralIn) - /* ITERATION #2 */ net.run(1); @@ -942,7 +824,6 @@ TEST(LinkTest, L2L4WithDelayedLinksAndPhases) ASSERT_EQ(1u, r4OutBuf[2]); // lateralIn from R3; delay=1 ASSERT_EQ(11u, r4OutBuf[0]); // out (feedForwardIn + lateralIn) - /* ITERATION #3 */ net.run(1); @@ -965,9 +846,7 @@ TEST(LinkTest, L2L4WithDelayedLinksAndPhases) ASSERT_EQ(23u, r4OutBuf[0]); // out (feedForwardIn + lateralIn) } - -TEST(LinkTest, L2L4With1ColDelayedLinksAndPhase1OnOffOn) -{ +TEST(LinkTest, L2L4With1ColDelayedLinksAndPhase1OnOffOn) { // Validates processing of incoming delayed and outgoing non-delayed link in // the context of a region within a suppressed phase. 
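// The ITERATION #1..#3 expectations of L2L4WithDelayedLinksAndPhases above
// follow from a simple recurrence: within one run, phase-1 regions (R1, R2)
// compute before phase-2 regions (R3, R4); feed-forward links have delay 0,
// while feedback and lateral links deliver the previous run's value
// (propagationDelay = 1). A standalone cross-check of the asserted values
// (crossCheckL2L4 is a hypothetical helper, not NuPIC API):
#include <cassert>

inline void crossCheckL2L4() {
  unsigned r3Prev = 0, r4Prev = 0;        // delayed outputs, primed with 0
  unsigned expectedR4[3] = {5, 11, 23};   // asserted above for r4OutBuf[0]
  for (int iter = 0; iter < 3; ++iter) {
    unsigned r1 = 1 + r3Prev;             // L4TestRegion R1: k=1 + feedbackIn
    unsigned r2 = 5 + r4Prev;             // L4TestRegion R2: k=5 + feedbackIn
    unsigned r3 = r1 + r4Prev;            // L2TestRegion R3: ff + lateralIn
    unsigned r4 = r2 + r3Prev;            // L2TestRegion R4: ff + lateralIn
    assert(r4 == expectedR4[iter]);
    r3Prev = r3;                          // seen by the others one run later
    r4Prev = r4;
  }
}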
// @@ -985,14 +864,14 @@ TEST(LinkTest, L2L4With1ColDelayedLinksAndPhase1OnOffOn) Network net; - RegionImplFactory::registerCPPRegion("L4TestRegion", - new RegisteredRegionImpl()); - Region * r1 = net.addRegion("R1", "L4TestRegion", "{\"k\": 1}"); + RegionImplFactory::registerCPPRegion( + "L4TestRegion", new RegisteredRegionImpl()); + Region *r1 = net.addRegion("R1", "L4TestRegion", "{\"k\": 1}"); RegionImplFactory::unregisterCPPRegion("L4TestRegion"); - RegionImplFactory::registerCPPRegion("L2TestRegion", - new RegisteredRegionImpl()); - Region * r3 = net.addRegion("R3", "L2TestRegion", ""); + RegionImplFactory::registerCPPRegion( + "L2TestRegion", new RegisteredRegionImpl()); + Region *r3 = net.addRegion("R3", "L2TestRegion", ""); RegionImplFactory::unregisterCPPRegion("L2TestRegion"); // NOTE Dimensions must be multiples of 2 @@ -1014,43 +893,39 @@ TEST(LinkTest, L2L4With1ColDelayedLinksAndPhase1OnOffOn) /* Link up the network */ // R1 output - net.link( - "R1", // srcName - "R3", // destName - "UniformLink", // linkType - "", // linkParams - "out", // srcOutput - "feedForwardIn", // destInput - 0 //propagationDelay + net.link("R1", // srcName + "R3", // destName + "UniformLink", // linkType + "", // linkParams + "out", // srcOutput + "feedForwardIn", // destInput + 0 // propagationDelay ); // R3 outputs - net.link( - "R3", // srcName - "R1", // destName - "UniformLink", // linkType - "", // linkParams - "out", // srcOutput - "feedbackIn", // destInput - 1 //propagationDelay + net.link("R3", // srcName + "R1", // destName + "UniformLink", // linkType + "", // linkParams + "out", // srcOutput + "feedbackIn", // destInput + 1 // propagationDelay ); - net.link( - "R3", // srcName - "R3", // destName - "UniformLink", // linkType - "", // linkParams - "out", // srcOutput - "lateralIn", // destInput - 1 //propagationDelay + net.link("R3", // srcName + "R3", // destName + "UniformLink", // linkType + "", // linkParams + "out", // srcOutput + "lateralIn", // destInput + 1 // propagationDelay ); - // Initialize the network net.initialize(); - UInt64* r1OutBuf = (UInt64*)(r1->getOutput("out")->getData().getBuffer()); - UInt64* r3OutBuf = (UInt64*)(r3->getOutput("out")->getData().getBuffer()); + UInt64 *r1OutBuf = (UInt64 *)(r1->getOutput("out")->getData().getBuffer()); + UInt64 *r3OutBuf = (UInt64 *)(r3->getOutput("out")->getData().getBuffer()); /* ITERATION #1 with all phases enabled */ net.run(1); @@ -1064,7 +939,6 @@ TEST(LinkTest, L2L4With1ColDelayedLinksAndPhase1OnOffOn) ASSERT_EQ(0u, r3OutBuf[2]); // lateralIn loopback from R3; delay=1 ASSERT_EQ(1u, r3OutBuf[0]); // out (feedForwardIn + lateralIn) - /* Disable Phase 1, containing R1 */ net.setMinEnabledPhase(2); @@ -1080,7 +954,6 @@ TEST(LinkTest, L2L4With1ColDelayedLinksAndPhase1OnOffOn) ASSERT_EQ(1u, r3OutBuf[2]); // lateralIn loopback from R3; delay=1 ASSERT_EQ(2u, r3OutBuf[0]); // out (feedForwardIn + lateralIn) - /* ITERATION #3 with Phase 1 disabled */ net.run(1); @@ -1093,7 +966,6 @@ TEST(LinkTest, L2L4With1ColDelayedLinksAndPhase1OnOffOn) ASSERT_EQ(2u, r3OutBuf[2]); // lateralIn loopback from R3; delay=1 ASSERT_EQ(3u, r3OutBuf[0]); // out (feedForwardIn + lateralIn) - /* Enable Phase 1, containing R1 */ net.setMinEnabledPhase(1); @@ -1117,15 +989,12 @@ TEST(LinkTest, L2L4With1ColDelayedLinksAndPhase1OnOffOn) ASSERT_EQ(8u, r1OutBuf[0]); // out (1 + feedbackIn) // Validate R3 - ASSERT_EQ(8u, r3OutBuf[1]); // feedForwardIn from R1; delay=0 - ASSERT_EQ(7u, r3OutBuf[2]); // lateralIn loopback from R3; delay=1 + ASSERT_EQ(8u, 
r3OutBuf[1]); // feedForwardIn from R1; delay=0 + ASSERT_EQ(7u, r3OutBuf[2]); // lateralIn loopback from R3; delay=1 ASSERT_EQ(15u, r3OutBuf[0]); // out (feedForwardIn + lateralIn) - } - -TEST(LinkTest, SingleL4RegionWithDelayedLoopbackInAndPhaseOnOffOn) -{ +TEST(LinkTest, SingleL4RegionWithDelayedLoopbackInAndPhaseOnOffOn) { // Validates processing of outgoing/incoming delayed link in the context of a // region within a disabled phase. // @@ -1141,9 +1010,9 @@ TEST(LinkTest, SingleL4RegionWithDelayedLoopbackInAndPhaseOnOffOn) Network net; - RegionImplFactory::registerCPPRegion("L4TestRegion", - new RegisteredRegionImpl()); - Region * r1 = net.addRegion("R1", "L4TestRegion", "{\"k\": 1}"); + RegionImplFactory::registerCPPRegion( + "L4TestRegion", new RegisteredRegionImpl()); + Region *r1 = net.addRegion("R1", "L4TestRegion", "{\"k\": 1}"); RegionImplFactory::unregisterCPPRegion("L4TestRegion"); // NOTE Dimensions must be multiples of 2 @@ -1160,21 +1029,19 @@ TEST(LinkTest, SingleL4RegionWithDelayedLoopbackInAndPhaseOnOffOn) /* Link up the network */ // R1 output (loopback) - net.link( - "R1", // srcName - "R1", // destName - "UniformLink", // linkType - "", // linkParams - "out", // srcOutput - "feedbackIn", // destInput - 1 //propagationDelay + net.link("R1", // srcName + "R1", // destName + "UniformLink", // linkType + "", // linkParams + "out", // srcOutput + "feedbackIn", // destInput + 1 // propagationDelay ); - // Initialize the network net.initialize(); - UInt64* r1OutBuf = (UInt64*)(r1->getOutput("out")->getData().getBuffer()); + UInt64 *r1OutBuf = (UInt64 *)(r1->getOutput("out")->getData().getBuffer()); /* ITERATION #1 with phase 1 enabled */ net.run(1); @@ -1183,7 +1050,6 @@ TEST(LinkTest, SingleL4RegionWithDelayedLoopbackInAndPhaseOnOffOn) ASSERT_EQ(0u, r1OutBuf[1]); // feedbackIn from R3; delay=1 ASSERT_EQ(1u, r1OutBuf[0]); // out (1 + feedbackIn) - /* Disable Phase 1, containing R1 */ net.setMaxEnabledPhase(0); @@ -1201,7 +1067,6 @@ TEST(LinkTest, SingleL4RegionWithDelayedLoopbackInAndPhaseOnOffOn) ASSERT_EQ(0u, r1OutBuf[1]); // feedbackIn ASSERT_EQ(1u, r1OutBuf[0]); // out - /* Enable Phase 1, containing R1 */ net.setMaxEnabledPhase(1); diff --git a/src/test/unit/engine/NetworkTest.cpp b/src/test/unit/engine/NetworkTest.cpp index 4bbc3491d8..e7156273ed 100644 --- a/src/test/unit/engine/NetworkTest.cpp +++ b/src/test/unit/engine/NetworkTest.cpp @@ -34,55 +34,57 @@ using namespace nupic; -#define SHOULDFAIL_WITH_MESSAGE(statement, message) \ - { \ - bool caughtException = false; \ - try { \ - statement; \ - } catch(nupic::LoggingException& e) { \ - caughtException = true; \ - EXPECT_STREQ(message, e.getMessage()) << "statement '" #statement "' should fail with message \"" \ - << message << "\", but failed with message \"" << e.getMessage() << "\""; \ - } catch(...) { \ - FAIL() << "statement '" #statement "' did not generate a logging exception"; \ - } \ - EXPECT_EQ(true, caughtException) << "statement '" #statement "' should fail"; \ +#define SHOULDFAIL_WITH_MESSAGE(statement, message) \ + { \ + bool caughtException = false; \ + try { \ + statement; \ + } catch (nupic::LoggingException & e) { \ + caughtException = true; \ + EXPECT_STREQ(message, e.getMessage()) \ + << "statement '" #statement "' should fail with message \"" \ + << message << "\", but failed with message \"" << e.getMessage() \ + << "\""; \ + } catch (...) 
{ \ + FAIL() << "statement '" #statement \ + "' did not generate a logging exception"; \ + } \ + EXPECT_EQ(true, caughtException) \ + << "statement '" #statement "' should fail"; \ } - -TEST(NetworkTest, AutoInitialization) -{ +TEST(NetworkTest, AutoInitialization) { // Uninitialize NuPIC since this test checks auto-initialization - // If shutdown fails, there is probably a problem with another test which - // is not cleaning up its networks. + // If shutdown fails, there is probably a problem with another test which + // is not cleaning up its networks. if (NuPIC::isInitialized()) NuPIC::shutdown(); ASSERT_TRUE(!NuPIC::isInitialized()); - + // creating a network should auto-initialize NuPIC { Network net; ASSERT_TRUE(NuPIC::isInitialized()); Region *l1 = net.addRegion("level1", "TestNode", ""); - + // Use l1 to avoid a compiler warning EXPECT_STREQ("level1", l1->getName().c_str()); - - // Network still exists, so this should fail. + + // Network still exists, so this should fail. EXPECT_THROW(NuPIC::shutdown(), std::exception); } // net destructor has been called so we should be able to shut down NuPIC now NuPIC::shutdown(); } -TEST(NetworkTest, RegionAccess) -{ +TEST(NetworkTest, RegionAccess) { Network net; - EXPECT_THROW( net.addRegion("level1", "nonexistent_nodetype", ""), std::exception ); + EXPECT_THROW(net.addRegion("level1", "nonexistent_nodetype", ""), + std::exception); - // Should be able to add a region + // Should be able to add a region Region *l1 = net.addRegion("level1", "TestNode", ""); ASSERT_TRUE(l1->getNetwork() == &net); @@ -92,23 +94,19 @@ TEST(NetworkTest, RegionAccess) // Make sure partial matches don't work EXPECT_THROW(net.getRegions().getByName("level"), std::exception); - Region* l1a = net.getRegions().getByName("level1"); + Region *l1a = net.getRegions().getByName("level1"); ASSERT_TRUE(l1a == l1); // Should not be able to add a second region with the same name EXPECT_THROW(net.addRegion("level1", "TestNode", ""), std::exception); - } - -TEST(NetworkTest, InitializationBasic) -{ +TEST(NetworkTest, InitializationBasic) { Network net; net.initialize(); } -TEST(NetworkTest, InitializationNoRegions) -{ +TEST(NetworkTest, InitializationNoRegions) { Network net; Region *l1 = net.addRegion("level1", "TestNode", ""); @@ -132,11 +130,9 @@ TEST(NetworkTest, InitializationNoRegions) l2->setDimensions(d); net.run(1); - } -TEST(NetworkTest, Modification) -{ +TEST(NetworkTest, Modification) { NTA_DEBUG << "Running network modification tests"; Network net; @@ -159,17 +155,16 @@ TEST(NetworkTest, Modification) ASSERT_EQ((UInt32)1, phases.size()); ASSERT_TRUE(phases.find(1) != phases.end()); - net.link("level1", "level2", "TestFanIn2", ""); - const Collection& regions = net.getRegions(); + const Collection ®ions = net.getRegions(); ASSERT_EQ((UInt32)2, regions.getCount()); // Should succeed since dimensions are now set net.initialize(); net.run(1); - Region* l2 = regions.getByName("level2"); + Region *l2 = regions.getByName("level2"); Dimensions d2 = l2->getDimensions(); ASSERT_EQ((UInt32)2, d2.size()); ASSERT_EQ((UInt32)2, d2[0]); @@ -209,9 +204,9 @@ TEST(NetworkTest, Modification) ASSERT_EQ((UInt32)2, d2.size()); ASSERT_EQ((UInt32)2, d2[0]); ASSERT_EQ((UInt32)2, d2[1]); - + // add a third region - Region* l3 = net.addRegion("level3", "TestNode", ""); + Region *l3 = net.addRegion("level3", "TestNode", ""); // should have been added at phase 2 phases = net.getPhases("level3"); @@ -247,10 +242,10 @@ TEST(NetworkTest, Modification) net.removeRegion("level1"); 
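// The dimension checks in these NetworkTest cases all follow the TestFanIn2
// rule exercised throughout this file: each induced destination dimension is
// the corresponding source dimension divided by 2 (e.g. [4 2] -> [2 1], as
// quoted in the removeLink error message in the Unlinking test below, and
// [4,4] -> [2,2] here). A hypothetical helper stating that rule explicitly
// (not the actual link-policy code):
#include <cassert>
#include <vector>

inline std::vector<unsigned> fanIn2Dims(const std::vector<unsigned> &src) {
  std::vector<unsigned> dest;
  for (unsigned d : src)
    dest.push_back(d / 2); // 2 source nodes per destination node per dimension
  return dest;
}
// Example: fanIn2Dims({4, 2}) yields {2, 1}.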
ASSERT_EQ((UInt32)0, regions.getCount()); - // build up the network again -- slightly differently with + // build up the network again -- slightly differently with // l1->l2 and l1->l3 l1 = net.addRegion("level1", "TestNode", ""); - l1->setDimensions(d); + l1->setDimensions(d); net.addRegion("level2", "TestNode", ""); net.addRegion("level3", "TestNode", ""); net.link("level1", "level2", "TestFanIn2", ""); @@ -279,11 +274,9 @@ TEST(NetworkTest, Modification) ASSERT_EQ((UInt32)2, d2[1]); // now let the destructor remove everything - } -TEST(NetworkTest, Unlinking) -{ +TEST(NetworkTest, Unlinking) { NTA_DEBUG << "Running unlinking tests"; Network net; net.addRegion("level1", "TestNode", ""); @@ -294,27 +287,35 @@ TEST(NetworkTest, Unlinking) net.getRegions().getByName("level1")->setDimensions(d); net.link("level1", "level2", "TestFanIn2", ""); - ASSERT_TRUE(net.getRegions().getByName("level2")->getDimensions().isUnspecified()); - - EXPECT_THROW(net.removeLink("level1", "level2", "outputdoesnotexist", "bottomUpIn"), std::exception); - EXPECT_THROW(net.removeLink("level1", "level2", "bottomUpOut", "inputdoesnotexist"), std::exception); + ASSERT_TRUE( + net.getRegions().getByName("level2")->getDimensions().isUnspecified()); + + EXPECT_THROW( + net.removeLink("level1", "level2", "outputdoesnotexist", "bottomUpIn"), + std::exception); + EXPECT_THROW( + net.removeLink("level1", "level2", "bottomUpOut", "inputdoesnotexist"), + std::exception); EXPECT_THROW(net.removeLink("level1", "leveldoesnotexist"), std::exception); EXPECT_THROW(net.removeLink("leveldoesnotexist", "level2"), std::exception); // remove the link from the uninitialized network net.removeLink("level1", "level2"); - ASSERT_TRUE(net.getRegions().getByName("level2")->getDimensions().isUnspecified()); + ASSERT_TRUE( + net.getRegions().getByName("level2")->getDimensions().isUnspecified()); EXPECT_THROW(net.removeLink("level1", "level2"), std::exception); // remove, specifying output/input names net.link("level1", "level2", "TestFanIn2", ""); net.removeLink("level1", "level2", "bottomUpOut", "bottomUpIn"); - EXPECT_THROW(net.removeLink("level1", "level2", "bottomUpOut", "bottomUpIn"), std::exception); + EXPECT_THROW(net.removeLink("level1", "level2", "bottomUpOut", "bottomUpIn"), + std::exception); net.link("level1", "level2", "TestFanIn2", ""); net.removeLink("level1", "level2", "bottomUpOut"); - EXPECT_THROW(net.removeLink("level1", "level2", "bottomUpOut"), std::exception); + EXPECT_THROW(net.removeLink("level1", "level2", "bottomUpOut"), + std::exception); // add the link back and initialize (inducing dimensions) net.link("level1", "level2", "TestFanIn2", ""); @@ -325,37 +326,33 @@ TEST(NetworkTest, Unlinking) ASSERT_EQ((UInt32)2, d[0]); ASSERT_EQ((UInt32)1, d[1]); - // remove the link. This will fail because we can't + // remove the link. This will fail because we can't // remove a link to an initialized region - SHOULDFAIL_WITH_MESSAGE(net.removeLink("level1", "level2"), - "Cannot remove link [level1.bottomUpOut (region dims: [4 2]) to level2.bottomUpIn (region dims: [2 1]) type: TestFanIn2] because destination region level2 is initialized. Remove the region first."); - + SHOULDFAIL_WITH_MESSAGE( + net.removeLink("level1", "level2"), + "Cannot remove link [level1.bottomUpOut (region dims: [4 2]) to " + "level2.bottomUpIn (region dims: [2 1]) type: TestFanIn2] because " + "destination region level2 is initialized. 
Remove the region first."); } typedef std::vector callbackData; callbackData mydata; -void testCallback(Network* net, UInt64 iteration, void* data) -{ - callbackData& thedata = *(static_cast(data)); +void testCallback(Network *net, UInt64 iteration, void *data) { + callbackData &thedata = *(static_cast(data)); // push region names onto callback data - const nupic::Collection& regions = net->getRegions(); - for (size_t i = 0; i < regions.getCount(); i++) - { + const nupic::Collection ®ions = net->getRegions(); + for (size_t i = 0; i < regions.getCount(); i++) { thedata.push_back(regions.getByIndex(i).first); } } - std::vector computeHistory; -static void recordCompute(const std::string& name) -{ +static void recordCompute(const std::string &name) { computeHistory.push_back(name); } - -TEST(NetworkTest, Phases) -{ +TEST(NetworkTest, Phases) { Network net; // should auto-initialize with max phase @@ -367,7 +364,6 @@ TEST(NetworkTest, Phases) ASSERT_EQ((UInt32)1, phaseSet.size()); ASSERT_TRUE(phaseSet.find(0) != phaseSet.end()); - Region *l2 = net.addRegion("level2", "TestNode", ""); EXPECT_STREQ("level2", l2->getName().c_str()); phaseSet = net.getPhases("level2"); @@ -402,8 +398,7 @@ TEST(NetworkTest, Phases) net.setPhases("level1", phaseSet); net.run(2); ASSERT_EQ((UInt32)6, computeHistory.size()); - if (computeHistory.size() == 6) - { + if (computeHistory.size() == 6) { EXPECT_STREQ("level1", computeHistory.at(0).c_str()); EXPECT_STREQ("level2", computeHistory.at(1).c_str()); EXPECT_STREQ("level1", computeHistory.at(2).c_str()); @@ -414,8 +409,7 @@ TEST(NetworkTest, Phases) computeHistory.clear(); } -TEST(NetworkTest, MinMaxPhase) -{ +TEST(NetworkTest, MinMaxPhase) { Network n; UInt32 minPhase = n.getMinPhase(); UInt32 maxPhase = n.getMaxPhase(); @@ -456,7 +450,6 @@ TEST(NetworkTest, MinMaxPhase) EXPECT_STREQ("level2", computeHistory.at(4).c_str()); EXPECT_STREQ("level3", computeHistory.at(5).c_str()); - n.setMinEnabledPhase(0); n.setMaxEnabledPhase(1); computeHistory.clear(); @@ -481,8 +474,7 @@ TEST(NetworkTest, MinMaxPhase) computeHistory.clear(); n.run(2); ASSERT_EQ((UInt32)6, computeHistory.size()); - if (computeHistory.size() == 6) - { + if (computeHistory.size() == 6) { EXPECT_STREQ("level1", computeHistory.at(0).c_str()); EXPECT_STREQ("level2", computeHistory.at(1).c_str()); EXPECT_STREQ("level3", computeHistory.at(2).c_str()); @@ -523,11 +515,9 @@ TEST(NetworkTest, MinMaxPhase) EXPECT_STREQ("level3", computeHistory.at(3).c_str()); EXPECT_STREQ("level2", computeHistory.at(4).c_str()); EXPECT_STREQ("level2", computeHistory.at(5).c_str()); - } -TEST(NetworkTest, Callback) -{ +TEST(NetworkTest, Callback) { Network n; n.addRegion("level1", "TestNode", ""); n.addRegion("level2", "TestNode", ""); @@ -538,9 +528,8 @@ TEST(NetworkTest, Callback) n.getRegions().getByName("level2")->setDimensions(d); n.getRegions().getByName("level3")->setDimensions(d); - - Collection& callbacks = n.getCallbacks(); - Network::callbackItem callback(testCallback, (void*)(&mydata)); + Collection &callbacks = n.getCallbacks(); + Network::callbackItem callback(testCallback, (void *)(&mydata)); callbacks.add("Test Callback", callback); n.run(2); @@ -551,5 +540,4 @@ TEST(NetworkTest, Callback) EXPECT_STREQ("level1", mydata[3].c_str()); EXPECT_STREQ("level2", mydata[4].c_str()); EXPECT_STREQ("level3", mydata[5].c_str()); - } diff --git a/src/test/unit/engine/UniformLinkPolicyTest.cpp b/src/test/unit/engine/UniformLinkPolicyTest.cpp index 0c10f44c1d..fa58ecf310 100644 --- 
a/src/test/unit/engine/UniformLinkPolicyTest.cpp +++ b/src/test/unit/engine/UniformLinkPolicyTest.cpp @@ -26,33 +26,24 @@ #include "gtest/gtest.h" -#include #include #include +#include #include #include using namespace nupic; -enum LinkSide -{ - srcLinkSide, - destLinkSide -}; +enum LinkSide { srcLinkSide, destLinkSide }; -struct CoordBounds -{ +struct CoordBounds { Coordinate coord; size_t dimension; std::pair bounds; - CoordBounds(Coordinate c, size_t dim, std::pair b) : - coord(std::move(c)), - dimension(dim), - bounds(std::move(b)) - { - } + CoordBounds(Coordinate c, size_t dim, std::pair b) + : coord(std::move(c)), dimension(dim), bounds(std::move(b)) {} }; // --- @@ -61,238 +52,199 @@ struct CoordBounds // class in UniformLinkPolicy. // --- namespace nupic { - class UniformLinkPolicyInspector - { - public: - bool setAndCheckDimensions(LinkSide setLinkSide, - Dimensions setDimensions, - Dimensions checkDimensions, - std::string linkParams, - size_t elementCount = 1) - { - Link dummyLink("UnitTestLink", "", "", ""); - UniformLinkPolicy test(linkParams, &dummyLink); - - // --- - // Since we're a unit test working in isolation, the infrastructure won't - // invoke setNodeOutputElementCount() for us; consequently we'll do that - // directly here. - // --- - test.setNodeOutputElementCount(elementCount); - - setLinkSide == srcLinkSide ? test.setSrcDimensions(setDimensions) : - test.setDestDimensions(setDimensions); - - Dimensions destDims = test.getDestDimensions(); - Dimensions srcDims = test.getSrcDimensions(); - - bool wasExpectedDimensions; - - setLinkSide == srcLinkSide ? (wasExpectedDimensions = - (srcDims == setDimensions && - destDims == checkDimensions)) : - (wasExpectedDimensions = - (srcDims == checkDimensions && - destDims == setDimensions)); - - return(wasExpectedDimensions); - } +class UniformLinkPolicyInspector { +public: + bool setAndCheckDimensions(LinkSide setLinkSide, Dimensions setDimensions, + Dimensions checkDimensions, std::string linkParams, + size_t elementCount = 1) { + Link dummyLink("UnitTestLink", "", "", ""); + UniformLinkPolicy test(linkParams, &dummyLink); + + // --- + // Since we're a unit test working in isolation, the infrastructure won't + // invoke setNodeOutputElementCount() for us; consequently we'll do that + // directly here. + // --- + test.setNodeOutputElementCount(elementCount); + + setLinkSide == srcLinkSide ? test.setSrcDimensions(setDimensions) + : test.setDestDimensions(setDimensions); + + Dimensions destDims = test.getDestDimensions(); + Dimensions srcDims = test.getSrcDimensions(); + + bool wasExpectedDimensions; + + setLinkSide == srcLinkSide + ? (wasExpectedDimensions = + (srcDims == setDimensions && destDims == checkDimensions)) + : (wasExpectedDimensions = + (srcDims == checkDimensions && destDims == setDimensions)); + + return (wasExpectedDimensions); + } + + bool setDimensionsAndCheckBounds(LinkSide setLinkSide, + Dimensions setDimensions, + std::vector checkBoundsVec, + std::string linkParams, + size_t elementCount = 1) { + Link dummyLink("UnitTestLink", "", "", ""); + UniformLinkPolicy test(linkParams, &dummyLink); + + // --- + // Since we're a unit test working in isolation, the infrastructure won't + // invoke setNodeOutputElementCount() for us; consequently we'll do that + // directly here. + // --- + test.setNodeOutputElementCount(elementCount); + + setLinkSide == srcLinkSide ? 
test.setSrcDimensions(setDimensions) + : test.setDestDimensions(setDimensions); - bool setDimensionsAndCheckBounds( - LinkSide setLinkSide, - Dimensions setDimensions, - std::vector checkBoundsVec, - std::string linkParams, - size_t elementCount = 1) - { - Link dummyLink("UnitTestLink", "", "", ""); - UniformLinkPolicy test(linkParams, &dummyLink); - - // --- - // Since we're a unit test working in isolation, the infrastructure won't - // invoke setNodeOutputElementCount() for us; consequently we'll do that - // directly here. - // --- - test.setNodeOutputElementCount(elementCount); - - setLinkSide == srcLinkSide ? test.setSrcDimensions(setDimensions) : - test.setDestDimensions(setDimensions); - - // --- - // Since we're a unit test working in isolation, the infrastructure won't - // invoke initialize() for us; consequently we'll do that directly here. - // --- - test.initialize(); - - bool allBoundsEqual = true; - - for(auto & elem : checkBoundsVec) - { - std::pair testBounds; - - testBounds = test.getInputBoundsForNode(elem.coord, - elem.dimension); - - if(testBounds != elem.bounds) - { - allBoundsEqual = false; - } - } - - return(allBoundsEqual); + // --- + // Since we're a unit test working in isolation, the infrastructure won't + // invoke initialize() for us; consequently we'll do that directly here. + // --- + test.initialize(); + + bool allBoundsEqual = true; + + for (auto &elem : checkBoundsVec) { + std::pair testBounds; + + testBounds = test.getInputBoundsForNode(elem.coord, elem.dimension); + + if (testBounds != elem.bounds) { + allBoundsEqual = false; } - }; + } + + return (allBoundsEqual); + } +}; } // end namespace nupic UniformLinkPolicyInspector inspector; -Coordinate makeCoordinate(size_t x, size_t y) -{ +Coordinate makeCoordinate(size_t x, size_t y) { Coordinate coord; coord.push_back(x); coord.push_back(y); - return(coord); + return (coord); } -TEST(UniformLinkPolicyTest, StrictMappingOddSource) -{ +TEST(UniformLinkPolicyTest, StrictMappingOddSource) { // --- // Check that a strict mapping with an rfSize of 2 fails on odd source // dimensions // --- - EXPECT_THROW( - inspector.setAndCheckDimensions(srcLinkSide, - Dimensions(9,6), - Dimensions(0,0), - "{mapping: in, " - "rfSize: [2]}"), - std::exception); + EXPECT_THROW(inspector.setAndCheckDimensions(srcLinkSide, Dimensions(9, 6), + Dimensions(0, 0), + "{mapping: in, " + "rfSize: [2]}"), + std::exception); } -TEST(UniformLinkPolicyTest, StrictMappingDimensions) -{ +TEST(UniformLinkPolicyTest, StrictMappingDimensions) { // --- // Check that a strict mapping with an rfSize of 2 calculates proper // dimensions when setting the source // --- - EXPECT_TRUE( - inspector.setAndCheckDimensions(srcLinkSide, - Dimensions(8,6), - Dimensions(4,3), - "{mapping: in, " - "rfSize: [2]}")); + EXPECT_TRUE(inspector.setAndCheckDimensions(srcLinkSide, Dimensions(8, 6), + Dimensions(4, 3), + "{mapping: in, " + "rfSize: [2]}")); } -TEST(UniformLinkPolicyTest, SpanNoImpactSource) -{ +TEST(UniformLinkPolicyTest, SpanNoImpactSource) { // --- // Check that adding in a span with size equal to the source dimensions has // no impact on the calculated destination dimensions when setting the source // --- - EXPECT_TRUE( - inspector.setAndCheckDimensions(srcLinkSide, - Dimensions(8,6), - Dimensions(4,3), - "{mapping: in, " - "rfSize: [2], " - "span: [8,6]}")); + EXPECT_TRUE(inspector.setAndCheckDimensions(srcLinkSide, Dimensions(8, 6), + Dimensions(4, 3), + "{mapping: in, " + "rfSize: [2], " + "span: [8,6]}")); } -TEST(UniformLinkPolicyTest, 
StrictMappingDestination) -{ +TEST(UniformLinkPolicyTest, StrictMappingDestination) { // --- // Check that a strict mapping with an rfSize of 2 calculates proper // dimensions when setting the destination // --- - EXPECT_TRUE( - inspector.setAndCheckDimensions(destLinkSide, - Dimensions(4,3), - Dimensions(8,6), - "{mapping: in, " - "rfSize: [2]}")); + EXPECT_TRUE(inspector.setAndCheckDimensions(destLinkSide, Dimensions(4, 3), + Dimensions(8, 6), + "{mapping: in, " + "rfSize: [2]}")); } -TEST(UniformLinkPolicyTest, SpanNoImpactDestination) -{ +TEST(UniformLinkPolicyTest, SpanNoImpactDestination) { // --- // Check that adding in a span with size equal to the source dimensions has // no impact on the calculated destination dimensions when setting the // destination // --- - EXPECT_TRUE( - inspector.setAndCheckDimensions(destLinkSide, - Dimensions(4,3), - Dimensions(8,6), - "{mapping: in, " - "rfSize: [2], " - "span: [8,6]}")); + EXPECT_TRUE(inspector.setAndCheckDimensions(destLinkSide, Dimensions(4, 3), + Dimensions(8, 6), + "{mapping: in, " + "rfSize: [2], " + "span: [8,6]}")); } -TEST(UniformLinkPolicyTest, StrictMappingGranularityDestFails) -{ +TEST(UniformLinkPolicyTest, StrictMappingGranularityDestFails) { // --- // Check that using a fractional rfSize with a granularity of elements fails // when the number of elements is inconsistent with a strict mapping // --- - EXPECT_THROW( - inspector.setAndCheckDimensions(destLinkSide, - Dimensions(7), - Dimensions(10), - "{mapping: in, " - "rfSize: [1.42857], " - "rfGranularity: elements}", - 1), - std::exception); + EXPECT_THROW(inspector.setAndCheckDimensions(destLinkSide, Dimensions(7), + Dimensions(10), + "{mapping: in, " + "rfSize: [1.42857], " + "rfGranularity: elements}", + 1), + std::exception); } -TEST(UniformLinkPolicyTest, StrictMappingGranularityDestPasses) -{ +TEST(UniformLinkPolicyTest, StrictMappingGranularityDestPasses) { // --- // Check that when using a compatible number of elements, the above test // passes // --- - EXPECT_TRUE( - inspector.setAndCheckDimensions(destLinkSide, - Dimensions(7), - Dimensions(10), - "{mapping: in, " - "rfSize: [1.42857], " - "rfGranularity: elements}", - 7)); + EXPECT_TRUE(inspector.setAndCheckDimensions(destLinkSide, Dimensions(7), + Dimensions(10), + "{mapping: in, " + "rfSize: [1.42857], " + "rfGranularity: elements}", + 7)); } -TEST(UniformLinkPolicyTest, StrictMappingGranularitySourceFails) -{ +TEST(UniformLinkPolicyTest, StrictMappingGranularitySourceFails) { // --- // Repeat the above two tests setting the source instead of the destination // --- - EXPECT_THROW( - inspector.setAndCheckDimensions(srcLinkSide, - Dimensions(10), - Dimensions(7), - "{mapping: in, " - "rfSize: [1.42857], " - "rfGranularity: elements}", - 1), - std::exception); + EXPECT_THROW(inspector.setAndCheckDimensions(srcLinkSide, Dimensions(10), + Dimensions(7), + "{mapping: in, " + "rfSize: [1.42857], " + "rfGranularity: elements}", + 1), + std::exception); } -TEST(UniformLinkPolicyTest, StrictMappingGranularitySourcePasses) -{ - EXPECT_TRUE( - inspector.setAndCheckDimensions(srcLinkSide, - Dimensions(10), - Dimensions(7), - "{mapping: in, " - "rfSize: [1.42857], " - "rfGranularity: elements}", - 7)); +TEST(UniformLinkPolicyTest, StrictMappingGranularitySourcePasses) { + EXPECT_TRUE(inspector.setAndCheckDimensions(srcLinkSide, Dimensions(10), + Dimensions(7), + "{mapping: in, " + "rfSize: [1.42857], " + "rfGranularity: elements}", + 7)); } -TEST(UniformLinkPolicyTest, NonStrictMappingSourcePasses) -{ 
+TEST(UniformLinkPolicyTest, NonStrictMappingSourcePasses) { // --- // Check that a non-strict mapping with an rfSize of 2 succeeds on odd source // dimensions and returns the expected values. Specifically, when working in @@ -301,17 +253,14 @@ TEST(UniformLinkPolicyTest, NonStrictMappingSourcePasses) // for source dimensions of [9, 6] and a rfSize of [2] we would expect // dimensions of [4, 3] instead of [5, 3]. // --- - EXPECT_TRUE( - inspector.setAndCheckDimensions(srcLinkSide, - Dimensions(9,6), - Dimensions(4,3), - "{mapping: in, " - "rfSize: [2], " - "strict: false}")); + EXPECT_TRUE(inspector.setAndCheckDimensions(srcLinkSide, Dimensions(9, 6), + Dimensions(4, 3), + "{mapping: in, " + "rfSize: [2], " + "strict: false}")); } -TEST(UniformLinkPolicyTest, NonStrictMappingExpectedDimensions) -{ +TEST(UniformLinkPolicyTest, NonStrictMappingExpectedDimensions) { // --- // Check that a non-strict mapping with overlap and a span has the expected // dimensions. @@ -325,37 +274,31 @@ TEST(UniformLinkPolicyTest, NonStrictMappingExpectedDimensions) // nodes in a given destination node than fewer, be packed into one of the // two spans. Therefore we expect the first dimension to be of size 4. // --- - EXPECT_TRUE( - inspector.setAndCheckDimensions(srcLinkSide, - Dimensions(9,6), - Dimensions(4,2), - "{mapping: in, " - "rfSize: [3], " - "rfOverlap: [2, 0], " - "span: [4, 0], " - "strict: false}")); + EXPECT_TRUE(inspector.setAndCheckDimensions(srcLinkSide, Dimensions(9, 6), + Dimensions(4, 2), + "{mapping: in, " + "rfSize: [3], " + "rfOverlap: [2, 0], " + "span: [4, 0], " + "strict: false}")); } -TEST(UniformLinkPolicyTest, NonStrictMappingExpectedDimensions2) -{ +TEST(UniformLinkPolicyTest, NonStrictMappingExpectedDimensions2) { // --- // Repeat the above test using source dimensions of [10, 6]. In this case // The remaining 9th and 10th node should, be packed into one each of the // two spans. Therefore we expect the first dimension to be of size 4. // --- - EXPECT_TRUE( - inspector.setAndCheckDimensions(srcLinkSide, - Dimensions(10,6), - Dimensions(4,2), - "{mapping: in, " - "rfSize: [3], " - "rfOverlap: [2, 0], " - "span: [4, 0], " - "strict: false}")); + EXPECT_TRUE(inspector.setAndCheckDimensions(srcLinkSide, Dimensions(10, 6), + Dimensions(4, 2), + "{mapping: in, " + "rfSize: [3], " + "rfOverlap: [2, 0], " + "span: [4, 0], " + "strict: false}")); } -TEST(UniformLinkPolicyTest, NonStrictMappingExpectedDimensionsEdge) -{ +TEST(UniformLinkPolicyTest, NonStrictMappingExpectedDimensionsEdge) { // --- // Check the same condition as above, but setting the destination and // inducing the source dimensions. We will test using destination dimensions @@ -382,19 +325,16 @@ TEST(UniformLinkPolicyTest, NonStrictMappingExpectedDimensionsEdge) // being honored due to strict=false, will result in the 5th destination // node receiving no input. // --- - EXPECT_TRUE( - inspector.setAndCheckDimensions(destLinkSide, - Dimensions(5,2), - Dimensions(10,6), - "{mapping: in, " - "rfSize: [3], " - "rfOverlap: [2, 0], " - "span: [4, 0], " - "strict: false}")); + EXPECT_TRUE(inspector.setAndCheckDimensions(destLinkSide, Dimensions(5, 2), + Dimensions(10, 6), + "{mapping: in, " + "rfSize: [3], " + "rfOverlap: [2, 0], " + "span: [4, 0], " + "strict: false}")); } -TEST(UniformLinkPolicyTest, NonStrictMappingSourceDimensions) -{ +TEST(UniformLinkPolicyTest, NonStrictMappingSourceDimensions) { // --- // Test basic non-strict mapping when setting source dimensions. 
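// For the plain non-strict cases in this file (no overlap, overhang or span),
// the induced destination dimension is the number of whole receptive fields
// that fit, i.e. floor(srcDim / rfSize), with any remainder folded into the
// last field because a destination node should receive more source nodes
// rather than fewer. A small cross-check of the values quoted in these
// comments (illustrative only, not the UniformLinkPolicy implementation):
#include <cassert>
#include <cmath>

inline unsigned nonStrictDestDim(double srcDim, double rfSize) {
  return static_cast<unsigned>(std::floor(srcDim / rfSize));
}

inline void checkQuotedDims() { // hypothetical helper
  assert(nonStrictDestDim(9, 2.0) == 4);  // [9,6], rfSize [2]   -> [4,3]
  assert(nonStrictDestDim(6, 2.0) == 3);
  assert(nonStrictDestDim(8, 1.7) == 4);  // [8,6], rfSize [1.7] -> [4,3]
  assert(nonStrictDestDim(6, 1.7) == 3);
}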
// @@ -403,46 +343,37 @@ TEST(UniformLinkPolicyTest, NonStrictMappingSourceDimensions) // fewer; consequently we expect dimensions of [4, 3] instead of [5, 4] for // the following settings. // --- - EXPECT_TRUE( - inspector.setAndCheckDimensions(srcLinkSide, - Dimensions(8,6), - Dimensions(4,3), - "{mapping: in, " - "rfSize: [1.7], " - "strict: false}")); + EXPECT_TRUE(inspector.setAndCheckDimensions(srcLinkSide, Dimensions(8, 6), + Dimensions(4, 3), + "{mapping: in, " + "rfSize: [1.7], " + "strict: false}")); } -TEST(UniformLinkPolicyTest, NonStrictMappingDestinationDimensions) -{ +TEST(UniformLinkPolicyTest, NonStrictMappingDestinationDimensions) { // --- // Test basic non-strict mapping when setting destination dimensions. // --- - EXPECT_TRUE( - inspector.setAndCheckDimensions(destLinkSide, - Dimensions(4,3), - Dimensions(7,6), - "{mapping: in, " - "rfSize: [1.7], " - "strict: false}")); + EXPECT_TRUE(inspector.setAndCheckDimensions(destLinkSide, Dimensions(4, 3), + Dimensions(7, 6), + "{mapping: in, " + "rfSize: [1.7], " + "strict: false}")); } -TEST(UniformLinkPolicyTest, OverlapOverhangRealisticDimensions) -{ +TEST(UniformLinkPolicyTest, OverlapOverhangRealisticDimensions) { // --- // Test overhang and overlap while using realistic image size dimensions. // --- - EXPECT_TRUE( - inspector.setAndCheckDimensions(srcLinkSide, - Dimensions(320,240), - Dimensions(41,31), - "{mapping: in, " - "rfSize: [16], " - "rfOverlap: [8], " - "overhang: [8]}")); + EXPECT_TRUE(inspector.setAndCheckDimensions(srcLinkSide, Dimensions(320, 240), + Dimensions(41, 31), + "{mapping: in, " + "rfSize: [16], " + "rfOverlap: [8], " + "overhang: [8]}")); } -TEST(UniformLinkPolicyTest, StrictMappingSplitOverReceptiveFields) -{ +TEST(UniformLinkPolicyTest, StrictMappingSplitOverReceptiveFields) { // --- // Test a strict mapping to make sure the elements are split across // receptive fields as expected @@ -450,36 +381,25 @@ TEST(UniformLinkPolicyTest, StrictMappingSplitOverReceptiveFields) std::vector expectedBoundVec; expectedBoundVec.push_back( - CoordBounds(makeCoordinate(0,0), - 0, - std::pair(0, 1))); + CoordBounds(makeCoordinate(0, 0), 0, std::pair(0, 1))); expectedBoundVec.push_back( - CoordBounds(makeCoordinate(1,0), - 0, - std::pair(2, 3))); + CoordBounds(makeCoordinate(1, 0), 0, std::pair(2, 3))); expectedBoundVec.push_back( - CoordBounds(makeCoordinate(2,0), - 0, - std::pair(4, 5))); + CoordBounds(makeCoordinate(2, 0), 0, std::pair(4, 5))); expectedBoundVec.push_back( - CoordBounds(makeCoordinate(3,0), - 0, - std::pair(6, 7))); - - EXPECT_TRUE( - inspector.setDimensionsAndCheckBounds(srcLinkSide, - Dimensions(8,6), - expectedBoundVec, - "{mapping: in, " - "rfSize: [2], " - "strict: false}")); + CoordBounds(makeCoordinate(3, 0), 0, std::pair(6, 7))); + + EXPECT_TRUE(inspector.setDimensionsAndCheckBounds( + srcLinkSide, Dimensions(8, 6), expectedBoundVec, + "{mapping: in, " + "rfSize: [2], " + "strict: false}")); } -TEST(UniformLinkPolicyTest, NonStrictMappingSplitOverReceptiveFields) -{ +TEST(UniformLinkPolicyTest, NonStrictMappingSplitOverReceptiveFields) { // --- // Test a non-strict mapping to make sure the elements are split across // receptive fields as expected @@ -487,30 +407,20 @@ TEST(UniformLinkPolicyTest, NonStrictMappingSplitOverReceptiveFields) std::vector expectedBoundVec; expectedBoundVec.push_back( - CoordBounds(makeCoordinate(0,0), - 0, - std::pair(0, 1))); + CoordBounds(makeCoordinate(0, 0), 0, std::pair(0, 1))); expectedBoundVec.push_back( - CoordBounds(makeCoordinate(1,0), - 0, - 
std::pair(2, 3))); + CoordBounds(makeCoordinate(1, 0), 0, std::pair(2, 3))); expectedBoundVec.push_back( - CoordBounds(makeCoordinate(2,0), - 0, - std::pair(4, 5))); + CoordBounds(makeCoordinate(2, 0), 0, std::pair(4, 5))); expectedBoundVec.push_back( - CoordBounds(makeCoordinate(3,0), - 0, - std::pair(6, 8))); - - EXPECT_TRUE( - inspector.setDimensionsAndCheckBounds(srcLinkSide, - Dimensions(9,6), - expectedBoundVec, - "{mapping: in, " - "rfSize: [2], " - "strict: false}")); + CoordBounds(makeCoordinate(3, 0), 0, std::pair(6, 8))); + + EXPECT_TRUE(inspector.setDimensionsAndCheckBounds( + srcLinkSide, Dimensions(9, 6), expectedBoundVec, + "{mapping: in, " + "rfSize: [2], " + "strict: false}")); } diff --git a/src/test/unit/engine/YAMLUtilsTest.cpp b/src/test/unit/engine/YAMLUtilsTest.cpp index 2e1caffe90..87420a8e22 100644 --- a/src/test/unit/engine/YAMLUtilsTest.cpp +++ b/src/test/unit/engine/YAMLUtilsTest.cpp @@ -26,17 +26,16 @@ #include "gtest/gtest.h" #include -#include #include +#include using namespace nupic; -TEST(YAMLUtilsTest, toValueTestInt) -{ - const char* s1 = "10"; +TEST(YAMLUtilsTest, toValueTestInt) { + const char *s1 = "10"; Value v = YAMLUtils::toValue(s1, NTA_BasicType_Int32); - EXPECT_TRUE(v.isScalar()) << "assertion v.isScalar() failed at " - << __FILE__ << ":" << __LINE__ ; + EXPECT_TRUE(v.isScalar()) + << "assertion v.isScalar() failed at " << __FILE__ << ":" << __LINE__; ASSERT_EQ(v.getType(), NTA_BasicType_Int32); Int32 i = v.getScalarT(); ASSERT_EQ(10, i); @@ -45,42 +44,39 @@ TEST(YAMLUtilsTest, toValueTestInt) ASSERT_EQ(10, i); } -TEST(YAMLUtilsTest, toValueTestReal32) -{ - const char* s1 = "10.1"; +TEST(YAMLUtilsTest, toValueTestReal32) { + const char *s1 = "10.1"; Value v = YAMLUtils::toValue(s1, NTA_BasicType_Real32); - EXPECT_TRUE(v.isScalar()) << "assertion v.isScalar() failed at " - << __FILE__ << ":" << __LINE__ ; + EXPECT_TRUE(v.isScalar()) + << "assertion v.isScalar() failed at " << __FILE__ << ":" << __LINE__; ASSERT_EQ(v.getType(), NTA_BasicType_Real32); Real32 x = v.getScalarT(); - EXPECT_NEAR(10.1, x, 0.000001) << "assertion 10.1 == " << x - << "\" failed at " << __FILE__ << ":" << __LINE__; + EXPECT_NEAR(10.1, x, 0.000001) << "assertion 10.1 == " << x << "\" failed at " + << __FILE__ << ":" << __LINE__; boost::shared_ptr s = v.getScalar(); x = s->value.real32; - EXPECT_NEAR(10.1, x, 0.000001) << "assertion 10.1 == " << x - << "\" failed at " << __FILE__ << ":" << __LINE__; + EXPECT_NEAR(10.1, x, 0.000001) << "assertion 10.1 == " << x << "\" failed at " + << __FILE__ << ":" << __LINE__; } -TEST(YAMLUtilsTest, toValueTestByte) -{ - const char* s1 = "this is a string"; +TEST(YAMLUtilsTest, toValueTestByte) { + const char *s1 = "this is a string"; Value v = YAMLUtils::toValue(s1, NTA_BasicType_Byte); - EXPECT_TRUE(!v.isScalar()) << "assertion !v.isScalar() failed at " - << __FILE__ << ":" << __LINE__ ; - EXPECT_TRUE(v.isString()) << "assertion v.isScalar() failed at " - << __FILE__ << ":" << __LINE__ ; + EXPECT_TRUE(!v.isScalar()) + << "assertion !v.isScalar() failed at " << __FILE__ << ":" << __LINE__; + EXPECT_TRUE(v.isString()) + << "assertion v.isScalar() failed at " << __FILE__ << ":" << __LINE__; ASSERT_EQ(v.getType(), NTA_BasicType_Byte); std::string s = *v.getString(); EXPECT_STREQ(s1, s.c_str()); } -TEST(YAMLUtilsTest, toValueTestBool) -{ - const char* s1 = "true"; +TEST(YAMLUtilsTest, toValueTestBool) { + const char *s1 = "true"; Value v = YAMLUtils::toValue(s1, NTA_BasicType_Bool); - EXPECT_TRUE(v.isScalar()) << "assertion v.isScalar() failed 
at " - << __FILE__ << ":" << __LINE__ ; + EXPECT_TRUE(v.isScalar()) + << "assertion v.isScalar() failed at " << __FILE__ << ":" << __LINE__; ASSERT_EQ(v.getType(), NTA_BasicType_Bool); bool b = v.getScalarT(); ASSERT_EQ(true, b); @@ -89,158 +85,115 @@ TEST(YAMLUtilsTest, toValueTestBool) ASSERT_EQ(true, b); } -TEST(YAMLUtilsTest, ParameterSpec) -{ +TEST(YAMLUtilsTest, ParameterSpec) { Collection ps; - ps.add( - "int32Param", - ParameterSpec( - "Int32 scalar parameter", // description - NTA_BasicType_Int32, - 1, // elementCount - "", // constraints - "32", // defaultValue - ParameterSpec::ReadWriteAccess)); + ps.add("int32Param", + ParameterSpec("Int32 scalar parameter", // description + NTA_BasicType_Int32, + 1, // elementCount + "", // constraints + "32", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ps.add("uint32Param", + ParameterSpec("UInt32 scalar parameter", // description + NTA_BasicType_UInt32, + 1, // elementCount + "", // constraints + "33", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ps.add("int64Param", + ParameterSpec("Int64 scalar parameter", // description + NTA_BasicType_Int64, + 1, // elementCount + "", // constraints + "64", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ps.add("uint64Param", + ParameterSpec("UInt64 scalar parameter", // description + NTA_BasicType_UInt64, + 1, // elementCount + "", // constraints + "65", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ps.add("real32Param", + ParameterSpec("Real32 scalar parameter", // description + NTA_BasicType_Real32, + 1, // elementCount + "", // constraints + "32.1", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ps.add("real64Param", + ParameterSpec("Real64 scalar parameter", // description + NTA_BasicType_Real64, + 1, // elementCount + "", // constraints + "64.1", // defaultValue + ParameterSpec::ReadWriteAccess)); + + ps.add("real32ArrayParam", + ParameterSpec("int32 array parameter", NTA_BasicType_Real32, + 0, // array + "", "", ParameterSpec::ReadWriteAccess)); + + ps.add("int64ArrayParam", + ParameterSpec("int64 array parameter", NTA_BasicType_Int64, + 0, // array + "", "", ParameterSpec::ReadWriteAccess)); ps.add( - "uint32Param", - ParameterSpec( - "UInt32 scalar parameter", // description - NTA_BasicType_UInt32, - 1, // elementCount - "", // constraints - "33", // defaultValue - ParameterSpec::ReadWriteAccess)); + "computeCallback", + ParameterSpec("address of a function that is called at every compute()", + NTA_BasicType_Handle, 1, "", + "", // handles must not have a default value + ParameterSpec::ReadWriteAccess)); - ps.add( - "int64Param", - ParameterSpec( - "Int64 scalar parameter", // description - NTA_BasicType_Int64, - 1, // elementCount - "", // constraints - "64", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ps.add( - "uint64Param", - ParameterSpec( - "UInt64 scalar parameter", // description - NTA_BasicType_UInt64, - 1, // elementCount - "", // constraints - "65", // defaultValue - ParameterSpec::ReadWriteAccess)); + ps.add("stringParam", + ParameterSpec("string parameter", NTA_BasicType_Byte, + 0, // length=0 required for strings + "", "default value", ParameterSpec::ReadWriteAccess)); - ps.add( - "real32Param", - ParameterSpec( - "Real32 scalar parameter", // description - NTA_BasicType_Real32, - 1, // elementCount - "", // constraints - "32.1", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ps.add( - "real64Param", - ParameterSpec( - "Real64 scalar parameter", // description - NTA_BasicType_Real64, - 1, // elementCount - "", 
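For readers new to this API, the toValue tests above all follow one round-trip pattern, sketched here under the assumption that the same headers and using-declarations as YAMLUtilsTest.cpp are in scope (plus <stdexcept> for the error path); the helper name is made up. A YAML scalar string plus an NTA_BasicType tag yields a typed Value, which reports isScalar()/isString() and its type, and can be read back with getScalarT<T>().

// Sketch only, not part of the patch; assumes the test file's headers plus
// <stdexcept>.
Real32 parseReal32(const char *yaml) {
  Value v = YAMLUtils::toValue(yaml, NTA_BasicType_Real32);
  if (!v.isScalar() || v.getType() != NTA_BasicType_Real32)
    throw std::runtime_error("expected a Real32 scalar"); // strings would report isString()
  return v.getScalarT<Real32>(); // e.g. "10.1" parses to 10.1f
}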
// constraints - "64.1", // defaultValue - ParameterSpec::ReadWriteAccess)); - - ps.add( - "real32ArrayParam", - ParameterSpec( - "int32 array parameter", - NTA_BasicType_Real32, - 0, // array - "", - "", - ParameterSpec::ReadWriteAccess)); - - ps.add( - "int64ArrayParam", - ParameterSpec( - "int64 array parameter", - NTA_BasicType_Int64, - 0, // array - "", - "", - ParameterSpec::ReadWriteAccess)); - - ps.add( - "computeCallback", - ParameterSpec( - "address of a function that is called at every compute()", - NTA_BasicType_Handle, - 1, - "", - "", // handles must not have a default value - ParameterSpec::ReadWriteAccess)); - - ps.add( - "stringParam", - ParameterSpec( - "string parameter", - NTA_BasicType_Byte, - 0, // length=0 required for strings - "", - "default value", - ParameterSpec::ReadWriteAccess)); - - ps.add( - "boolParam", - ParameterSpec( - "bool parameter", - NTA_BasicType_Bool, - 1, - "", - "false", - ParameterSpec::ReadWriteAccess)); + ps.add("boolParam", ParameterSpec("bool parameter", NTA_BasicType_Bool, 1, "", + "false", ParameterSpec::ReadWriteAccess)); NTA_DEBUG << "ps count: " << ps.getCount(); ValueMap vm = YAMLUtils::toValueMap("", ps); - EXPECT_TRUE(vm.contains("int32Param")) - << "assertion vm.contains(\"int32Param\") failed at " - << __FILE__ << ":" << __LINE__ ; + EXPECT_TRUE(vm.contains("int32Param")) + << "assertion vm.contains(\"int32Param\") failed at " << __FILE__ << ":" + << __LINE__; ASSERT_EQ((Int32)32, vm.getScalarT("int32Param")); EXPECT_TRUE(vm.contains("boolParam")) - << "assertion vm.contains(\"boolParam\") failed at " - << __FILE__ << ":" << __LINE__ ; + << "assertion vm.contains(\"boolParam\") failed at " << __FILE__ << ":" + << __LINE__; ASSERT_EQ(false, vm.getScalarT("boolParam")); // disabled until we fix default string params // TEST(vm.contains("stringParam")); // EXPECT_STREQ("default value", vm.getString("stringParam")->c_str()); - // Test error message in case of invalid parameter with and without nodeType and regionName - try - { + // Test error message in case of invalid parameter with and without nodeType + // and regionName + try { YAMLUtils::toValueMap("{ blah: True }", ps, "nodeType", "regionName"); - } - catch (nupic::Exception & e) - { + } catch (nupic::Exception &e) { std::string s("Unknown parameter 'blah' for region 'regionName'"); EXPECT_TRUE(std::string(e.getMessage()).find(s) == 0) - << "assertion std::string(e.getMessage()).find(s) == 0 failed at " - << __FILE__ << ":" << __LINE__ ; + << "assertion std::string(e.getMessage()).find(s) == 0 failed at " + << __FILE__ << ":" << __LINE__; } - try - { + try { YAMLUtils::toValueMap("{ blah: True }", ps); - } - catch (nupic::Exception & e) - { + } catch (nupic::Exception &e) { std::string s("Unknown parameter 'blah'\nValid"); EXPECT_TRUE(std::string(e.getMessage()).find(s) == 0) - << "assertion std::string(e.getMessage()).find(s) == 0 failed at " - << __FILE__ << ":" << __LINE__ ; + << "assertion std::string(e.getMessage()).find(s) == 0 failed at " + << __FILE__ << ":" << __LINE__; } } diff --git a/src/test/unit/math/DenseTensorUnitTest.cpp b/src/test/unit/math/DenseTensorUnitTest.cpp index faecb4d7b5..84f3a52f0c 100644 --- a/src/test/unit/math/DenseTensorUnitTest.cpp +++ b/src/test/unit/math/DenseTensorUnitTest.cpp @@ -19,11 +19,11 @@ * http://numenta.org/licenses/ * --------------------------------------------------------------------- */ - + /** @file * Implementation of unit testing for class DenseTensor - */ - + */ + //#include // #include // #include "DenseTensorUnitTest.hpp" 
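The ParameterSpec test above also documents the default-value path: with an empty YAML string, YAMLUtils::toValueMap() falls back to each parameter's declared default, and an unknown key raises nupic::Exception. A minimal sketch of that flow, assuming the same headers and using-declarations as the test (the function name is made up):

// Sketch only, not part of the patch.
Int32 defaultInt32() {
  Collection<ParameterSpec> ps;
  ps.add("int32Param",
         ParameterSpec("Int32 scalar parameter", // description
                       NTA_BasicType_Int32,
                       1,    // elementCount
                       "",   // constraints
                       "32", // defaultValue
                       ParameterSpec::ReadWriteAccess));
  ValueMap vm = YAMLUtils::toValueMap("", ps); // empty YAML: defaults apply
  return vm.getScalarT<Int32>("int32Param");   // 32, taken from defaultValue
}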
@@ -42,25 +42,27 @@ // D3::TensorIndex idx(i, j, k); // Test("DenseTensor bounds constructor", d.get(idx), (Real)0); // } -// +// // Test("DenseTensor getNNonZeros 1", d.getNNonZeros(), (UInt)0); -// Test("DenseTensor getBounds 1", d.getBounds(), D3::TensorIndex(4, 4, 4)); +// Test("DenseTensor getBounds 1", d.getBounds(), D3::TensorIndex(4, 4, +// 4)); // } // -// { +// { // ITER_3(3, 3, 3) { // D3::TensorIndex idx(i, j, k); // d.set(idx, Real(i*16+j*4+k)); // Test("DenseTensor bounds get", d.get(idx), i*16+j*4+k); // } -// +// // Test("DenseTensor getNNonZeros 1", d.getNNonZeros(), (UInt)3*3*3-1); -// Test("DenseTensor getBounds 1", d.getBounds(), D3::TensorIndex(4, 4, 4)); +// Test("DenseTensor getBounds 1", d.getBounds(), D3::TensorIndex(4, 4, +// 4)); // } // // D4 d4(5, 4, 3, 2); -// Test("DenseTensor list constructor", d4.getBounds(), Index(5, 4, 3, 2)); -// ITER_4(5, 4, 3, 2) { +// Test("DenseTensor list constructor", d4.getBounds(), Index(5, 4, +// 3, 2)); ITER_4(5, 4, 3, 2) { // Index i4(i, j, k, l); // Test("DenseTensor list constructor", d4.get(i4), (Real)0); // } @@ -84,9 +86,9 @@ // void DenseTensorUnitTest::unitTestGetSet() // { // I3 ub(5, 4, 3); -// +// // D3 d3(ub), d32(ub); -// +// // ITER_3(ub[0], ub[1], ub[2]) { // d3.set(I3(i, j, k), Real(i*ub[1]*ub[2]+j*ub[2]+k)); // d32.set(i, j, k, d3.get(i, j, k)); @@ -101,7 +103,7 @@ // Test("DenseTensor get 3", d3(i, j, k), d3.get(I3(i, j, k))); // } // } -// +// // //-------------------------------------------------------------------------------- // void DenseTensorUnitTest::unitTestIsSymmetric() // { @@ -114,7 +116,7 @@ // // d32.set(0, 0, 1, (Real).5); // Test("DenseTensor isSymmetric 4", d32.isSymmetric(I3(0, 2, 1)), false); -// +// // d32.set(0, 1, 0, (Real).5); // Test("DenseTensor isSymmetric 5", d32.isSymmetric(I3(0, 2, 1)), true); // @@ -131,19 +133,19 @@ // { // { // D2 d2(3, 4), ref(4, 3); -// +// // ITER_2(3, 4) { // d2.set(I2(i, j), Real(i*4+j+1)); // ref.set(I2(j, i), d2.get(i, j)); // } -// +// // d2.permute(I2(1, 0)); // Test("DenseTensor permute 1", d2, ref); // } // // { // D3 d3(3, 4, 5); -// ITER_3(3, 4, 5) +// ITER_3(3, 4, 5) // d3.set(I3(i, j, k), Real(i*20+j*4+k+1)); // // D3 ref3(3, 5, 4); @@ -156,7 +158,7 @@ // // { // D3 d3(3, 4, 5); -// ITER_3(3, 4, 5) +// ITER_3(3, 4, 5) // d3.set(I3(i, j, k), Real(i*20+j*4+k+1)); // // D3 ref3(4, 5, 3); @@ -169,7 +171,7 @@ // // { // D3 d3(3, 4, 5); -// ITER_3(3, 4, 5) +// ITER_3(3, 4, 5) // d3.set(I3(i, j, k), Real(i*20+j*4+k+1)); // // D3 ref3(4, 3, 5); @@ -184,7 +186,7 @@ // //-------------------------------------------------------------------------------- // void DenseTensorUnitTest::unitTestResize() // { -// { +// { // D2 d2(3, 4); // ITER_2(3, 4) d2.set(I2(i, j), Real(i*4+j)); // @@ -223,26 +225,26 @@ // { // { // D2 d2(3, 4), d2r(3, 4), ref(3, 4); -// -// ITER_2(3, 4) +// +// ITER_2(3, 4) // d2.set(I2(i, j), Real(i*4+j)); -// +// // ITER_2(3, 4) // ref.set(I2(i, j), Real(i*4+j)); -// +// // d2.reshape(d2r); // Test("DenseTensor reshape 0", d2r, ref); // } // // { // D2 d2(3, 4), d2r(2, 6), ref(2, 6); -// -// ITER_2(3, 4) +// +// ITER_2(3, 4) // d2.set(I2(i, j), Real(i*4+j)); -// +// // ITER_2(2, 6) // ref.set(I2(i, j), Real(i*6+j)); -// +// // d2.reshape(d2r); // Test("DenseTensor reshape 1", d2r, ref); // } @@ -250,13 +252,13 @@ // { // D2 d2(3, 4); // D3 d3r(2, 2, 3), ref(2, 2, 3); -// +// // ITER_2(3, 4) // d2.set(I2(i, j), Real(i*4+j)); -// +// // ITER_3(2, 2, 3) // ref.set(I3(i, j, k), Real(i*6+j*3+k)); -// +// // d2.reshape(d3r); // Test("DenseTensor 
reshape 2", d3r, ref); // } @@ -264,13 +266,13 @@ // { // D3 d3(2, 2, 3); // D2 d2r(3, 4), ref(3, 4); -// +// // ITER_3(2, 2, 3) // d3.set(I3(i, j, k), Real(i*6+j*3+k)); // // ITER_2(3, 4) // ref.set(I2(i, j), Real(i*4+j)); -// +// // d3.reshape(d2r); // Test("DenseTensor reshape 3", d2r, ref); // } @@ -278,11 +280,11 @@ // // //-------------------------------------------------------------------------------- // void DenseTensorUnitTest::unitTestSlice() -// { +// { // D3 d(D3::TensorIndex(5, 4, 3)); // // { -// ITER_3(5, 4, 3) +// ITER_3(5, 4, 3) // d.set(D3::TensorIndex(i, j, k), Real(i*16+j*4+k+1)); // } // @@ -298,7 +300,7 @@ // D2 d2ref(D2::TensorIndex(4, 3)); // ITER_2(4, 3) // d2ref.set(D2::TensorIndex(i, j), Real(u*16+i*4+j+1)); -// +// // Domain dom(D3::TensorIndex(u, 0, 0), D3::TensorIndex(u, 4, 3)); // D2 d2(D2::TensorIndex(4, 3)); // d.getSlice(dom, d2); @@ -311,7 +313,7 @@ // D2 d2ref(D2::TensorIndex(5, 3)); // ITER_2(5, 3) // d2ref.set(D2::TensorIndex(i, j), Real(i*16+u*4+j+1)); -// +// // Domain dom(D3::TensorIndex(0, u, 0), D3::TensorIndex(5, u, 3)); // D2 d2(D2::TensorIndex(5, 3)); // d.getSlice(dom, d2); @@ -324,7 +326,7 @@ // D2 d2ref(D2::TensorIndex(5, 4)); // ITER_2(5, 4) // d2ref.set(D2::TensorIndex(i, j), Real(i*16+j*4+u+1)); -// +// // Domain dom(D3::TensorIndex(0, 0, u), D3::TensorIndex(5, 4, u)); // D2 d2(D2::TensorIndex(5, 4)); // d.getSlice(dom, d2); @@ -338,10 +340,9 @@ // D1 d1ref(D1::TensorIndex(5)); // ITER_1(5) // d1ref.set(D1::TensorIndex(i), Real(i*16+u1*4+u2+1)); -// -// Domain dom(D3::TensorIndex(0, u1, u2), D3::TensorIndex(5, u1, u2)); -// D1 d1(D1::TensorIndex(5)); -// d.getSlice(dom, d1); +// +// Domain dom(D3::TensorIndex(0, u1, u2), D3::TensorIndex(5, u1, +// u2)); D1 d1(D1::TensorIndex(5)); d.getSlice(dom, d1); // Test("DenseTensor getSlice 5", (d1 == d1ref), true); // } // } @@ -352,10 +353,9 @@ // D1 d1ref(D1::TensorIndex(4)); // ITER_1(4) // d1ref.set(D1::TensorIndex(i), Real(u1*16+i*4+u2+1)); -// -// Domain dom(D3::TensorIndex(u1, 0, u2), D3::TensorIndex(u1, 4, u2)); -// D1 d1(D1::TensorIndex(4)); -// d.getSlice(dom, d1); +// +// Domain dom(D3::TensorIndex(u1, 0, u2), D3::TensorIndex(u1, 4, +// u2)); D1 d1(D1::TensorIndex(4)); d.getSlice(dom, d1); // Test("DenseTensor getSlice 6", (d1 == d1ref), true); // } // } @@ -366,13 +366,12 @@ // D1 d1ref(D1::TensorIndex(3)); // ITER_1(3) // d1ref.set(D1::TensorIndex(i), Real(u1*16+u2*4+i+1)); -// -// Domain dom(D3::TensorIndex(u1, u2, 0), D3::TensorIndex(u1, u2, 3)); -// D1 d1(D1::TensorIndex(3)); -// d.getSlice(dom, d1); +// +// Domain dom(D3::TensorIndex(u1, u2, 0), D3::TensorIndex(u1, u2, +// 3)); D1 d1(D1::TensorIndex(3)); d.getSlice(dom, d1); // Test("DenseTensor getSlice 7", (d1 == d1ref), true); // } -// } +// } // } // // //-------------------------------------------------------------------------------- @@ -380,7 +379,7 @@ // { // I4 ub4(5, 4, 3, 2); // D4 dA(ub4), dB(ub4), dC(ub4), dref(ub4); -// +// // ITER_4(ub4[0], ub4[1], ub4[2], ub4[3]) { // Index idx(i, j, k, l); // dA.set(idx, Real(i*5*3*7+j*3*7+k*7+l+1)); @@ -401,13 +400,13 @@ // } // // dA.element_apply(dB, dA, std::plus()); -// Test("DenseTensor element_apply 2", (dA == dref), true); +// Test("DenseTensor element_apply 2", (dA == dref), true); // // ITER_4(ub4[0], ub4[1], ub4[2], ub4[3]) { // Index idx(i, j, k, l); // dref.set(idx, dA.get(idx) * dB.get(idx)); -// } -// +// } +// // dA.element_apply(dB, dC, nupic::Multiplies()); // Test("DenseTensor element_apply 3", (dC == dref), true); // @@ -415,9 +414,9 @@ // Index idx(i, j, k, l); 
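The commented-out reshape tests above encode one property worth stating explicitly: reshaping keeps elements in their flat order, so a (3, 4) tensor filled with its own flat index reshapes to a (2, 6) tensor that is again filled with its flat index. The standalone sketch below uses a plain vector in row-major layout, which is the layout the test values imply, instead of DenseTensor.

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const std::size_t R1 = 3, C1 = 4, R2 = 2, C2 = 6;
  std::vector<double> flat(R1 * C1);
  for (std::size_t i = 0; i < R1; ++i)
    for (std::size_t j = 0; j < C1; ++j)
      flat[i * C1 + j] = double(i * C1 + j); // value == flat index, as in "reshape 1"

  // "Reshaping" only reinterprets the same flat buffer with new strides,
  // so the (2, 6) view is also filled with its own flat index.
  for (std::size_t i = 0; i < R2; ++i)
    for (std::size_t j = 0; j < C2; ++j)
      assert(flat[i * C2 + j] == double(i * C2 + j));
  return 0;
}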
// dref.set(idx, dA.get(idx) * dB.get(idx)); // } -// +// // dA.element_apply(dB, dA, nupic::Multiplies()); -// Test("DenseTensor element_apply 4", (dA == dref), true); +// Test("DenseTensor element_apply 4", (dA == dref), true); // // ITER_4(ub4[0], ub4[1], ub4[2], ub4[3]) { // Index idx(i, j, k, l); @@ -428,12 +427,12 @@ // //-------------------------------------------------------------------------------- // void DenseTensorUnitTest::unitTestFactorApply() // { -// { +// { // D2 d2(3, 4), c(3, 4), ref(3, 4); -// +// // ITER_2(3, 4) d2.set(I2(i, j), Real(i*4+j+1)); // -// { +// { // ITER_2(3, 4) ref.set(I2(i, j), Real((i*4+j+1) * (j+2))); // D1 d1(4); ITER_1(4) d1.set(I1((UInt)i), Real(i+2)); // d2.factor_apply(I1(1), d1, c, nupic::Multiplies()); @@ -446,18 +445,18 @@ // d2.factor_apply(I1((UInt)0), d1, c, nupic::Multiplies()); // Test("DenseTensor factor_apply 2", c, ref); // } -// } -// +// } +// // { // I3 ub(4, 2, 3); // D3 d3(ub), c(ub), ref(ub); -// -// ITER_3(ub[0], ub[1], ub[2]) +// +// ITER_3(ub[0], ub[1], ub[2]) // { I3 i3(i, j, k); d3.set(i3, Real(i3.ordinal(ub))); } // // { // D1 d1(ub[0]); ITER_1(ub[0]) d1.set(I1((UInt)i), Real(i+2)); -// ITER_3(ub[0], ub[1], ub[2]) +// ITER_3(ub[0], ub[1], ub[2]) // { I3 i3(i, j, k); ref.set(i3, Real(i3.ordinal(ub)*(i+2))); } // d3.factor_apply(I1((UInt)0), d1, c, nupic::Multiplies()); // Test("DenseTensor factor_apply 3", c, ref); @@ -465,7 +464,7 @@ // // { // D1 d1(ub[1]); ITER_1(ub[1]) d1.set(I1((UInt)i), Real(i+2)); -// ITER_3(ub[0], ub[1], ub[2]) +// ITER_3(ub[0], ub[1], ub[2]) // { I3 i3(i, j, k); ref.set(i3, Real(i3.ordinal(ub)*(j+2))); } // d3.factor_apply(I1((UInt)1), d1, c, nupic::Multiplies()); // Test("DenseTensor factor_apply 4", c, ref); @@ -473,44 +472,44 @@ // // { // D1 d1(ub[2]); ITER_1(ub[2]) d1.set(I1((UInt)i), Real(i+2)); -// ITER_3(ub[0], ub[1], ub[2]) +// ITER_3(ub[0], ub[1], ub[2]) // { I3 i3(i, j, k); ref.set(i3, Real(i3.ordinal(ub)*(k+2))); } // d3.factor_apply(I1((UInt)2), d1, c, nupic::Multiplies()); // Test("DenseTensor factor_apply 5", c, ref); -// } +// } // -// { -// D2 d2(ub[1], ub[2]); +// { +// D2 d2(ub[1], ub[2]); // ITER_2(ub[1], ub[2]) d2.set(I2(i, j), Real(i*ub[2]+j+2)); -// ITER_3(ub[0], ub[1], ub[2]) +// ITER_3(ub[0], ub[1], ub[2]) // { I3 i3(i, j, k); ref.set(i3, Real(i3.ordinal(ub)*(j*ub[2]+k+2))); } // d3.factor_apply(I2(1, 2), d2, c, nupic::Multiplies()); // Test("DenseTensor factor_apply 6", c, ref); // } // // { -// D2 d2(ub[0], ub[2]); +// D2 d2(ub[0], ub[2]); // ITER_2(ub[0], ub[2]) d2.set(I2(i, j), Real(i*ub[2]+j+2)); -// ITER_3(ub[0], ub[1], ub[2]) +// ITER_3(ub[0], ub[1], ub[2]) // { I3 i3(i, j, k); ref.set(i3, Real(i3.ordinal(ub)*(i*ub[2]+k+2))); } // d3.factor_apply(I2(0, 2), d2, c, nupic::Multiplies()); // Test("DenseTensor factor_apply 7", c, ref); // } // // { -// D2 d2(ub[0], ub[1]); +// D2 d2(ub[0], ub[1]); // ITER_2(ub[0], ub[1]) d2.set(I2(i, j), Real(i*ub[1]+j+2)); -// ITER_3(ub[0], ub[1], ub[2]) +// ITER_3(ub[0], ub[1], ub[2]) // { I3 i3(i, j, k); ref.set(i3, Real(i3.ordinal(ub)*(i*ub[1]+j+2))); } // d3.factor_apply(I2(0, 1), d2, c, nupic::Multiplies()); // Test("DenseTensor factor_apply 8", c, ref); -// } -// +// } +// // { // D3 d32(ub[0], ub[1], ub[2]); -// ITER_3(ub[0], ub[1], ub[2]) { -// I3 i3(i, j, k); -// d32.set(i3, Real(i3.ordinal(ub)+1)); +// ITER_3(ub[0], ub[1], ub[2]) { +// I3 i3(i, j, k); +// d32.set(i3, Real(i3.ordinal(ub)+1)); // ref.set(i3, Real(i3.ordinal(ub) * (i3.ordinal(ub)+1))); // } // d3.factor_apply(I3(0, 1, 2), d32, c, nupic::Multiplies()); @@ -526,80 
+525,81 @@ // D2 d2(3, 4); ITER_2(3, 4) d2.set(I2(i, j), Real(i*4+j+1)); // // { -// D1 d1(3), ref(3); -// ITER_2(3, 4) ref.set(I1((UInt)i), ref.get(I1((UInt)i)) + d2.get(I2(i, j))); -// d2.accumulate(I1((UInt)1), d1, std::plus()); +// D1 d1(3), ref(3); +// ITER_2(3, 4) ref.set(I1((UInt)i), ref.get(I1((UInt)i)) + d2.get(I2(i, +// j))); d2.accumulate(I1((UInt)1), d1, std::plus()); // Test("DenseTensor accumulate 1", d1, ref); // } // // { -// D1 d1(4), ref(4); -// ITER_2(3, 4) ref.set(I1((UInt)j), ref.get(I1((UInt)j)) + d2.get(I2(i, j))); -// d2.accumulate(I1((UInt)0), d1, std::plus()); +// D1 d1(4), ref(4); +// ITER_2(3, 4) ref.set(I1((UInt)j), ref.get(I1((UInt)j)) + d2.get(I2(i, +// j))); d2.accumulate(I1((UInt)0), d1, std::plus()); // Test("DenseTensor accumulate 2", d1, ref); // } // } // // { -// D3 d3(3, 4, 5); ITER_3(3, 4, 5) d3.set(I3(i, j, k), Real(i*4*5+j*5+k+1)); -// +// D3 d3(3, 4, 5); ITER_3(3, 4, 5) d3.set(I3(i, j, k), +// Real(i*4*5+j*5+k+1)); +// // { // D2 d2(4, 5), ref(4, 5); -// ITER_3(3, 4, 5) ref.set(I2(j, k), ref.get(I2(j, k)) + d3.get(I3(i, j, k))); -// d3.accumulate(I1((UInt)0), d2, std::plus()); +// ITER_3(3, 4, 5) ref.set(I2(j, k), ref.get(I2(j, k)) + d3.get(I3(i, j, +// k))); d3.accumulate(I1((UInt)0), d2, std::plus()); // Test("DenseTensor accumulate 3", d2, ref); // } -// +// // { // D2 d2(3, 5), ref(3, 5); -// ITER_3(3, 4, 5) ref.set(I2(i, k), ref.get(I2(i, k)) + d3.get(I3(i, j, k))); -// d3.accumulate(I1((UInt)1), d2, std::plus()); +// ITER_3(3, 4, 5) ref.set(I2(i, k), ref.get(I2(i, k)) + d3.get(I3(i, j, +// k))); d3.accumulate(I1((UInt)1), d2, std::plus()); // Test("DenseTensor accumulate 4", d2, ref); // } // // { // D2 d2(3, 4), ref(3, 4); -// ITER_3(3, 4, 5) ref.set(I2(i, j), ref.get(I2(i, j)) + d3.get(I3(i, j, k))); -// d3.accumulate(I1((UInt)2), d2, std::plus()); +// ITER_3(3, 4, 5) ref.set(I2(i, j), ref.get(I2(i, j)) + d3.get(I3(i, j, +// k))); d3.accumulate(I1((UInt)2), d2, std::plus()); // Test("DenseTensor accumulate 5", d2, ref); // } // // { // D1 d1(3), ref(3); -// ITER_3(3, 4, 5) ref.set(I1((UInt)i), ref.get(I1((UInt)i)) + d3.get(I3(i, j, k))); -// d3.accumulate(I2(1, 2), d1, std::plus()); +// ITER_3(3, 4, 5) ref.set(I1((UInt)i), ref.get(I1((UInt)i)) + +// d3.get(I3(i, j, k))); d3.accumulate(I2(1, 2), d1, std::plus()); // Test("DenseTensor accumulate 6", d1, ref); // } // // { // D1 d1(4), ref(4); -// ITER_3(3, 4, 5) ref.set(I1((UInt)j), ref.get(I1((UInt)j)) + d3.get(I3(i, j, k))); -// d3.accumulate(I2(0, 2), d1, std::plus()); +// ITER_3(3, 4, 5) ref.set(I1((UInt)j), ref.get(I1((UInt)j)) + +// d3.get(I3(i, j, k))); d3.accumulate(I2(0, 2), d1, std::plus()); // Test("DenseTensor accumulate 7", d1, ref); // } -// +// // { // D1 d1(5), ref(5); -// ITER_3(3, 4, 5) ref.set(I1((UInt)k), ref.get(I1((UInt)k)) + d3.get(I3(i, j, k))); -// d3.accumulate(I2(0, 1), d1, std::plus()); +// ITER_3(3, 4, 5) ref.set(I1((UInt)k), ref.get(I1((UInt)k)) + +// d3.get(I3(i, j, k))); d3.accumulate(I2(0, 1), d1, std::plus()); // Test("DenseTensor accumulate 8", d1, ref); // } // } // // { // Max // D2 d2(3, 4); ITER_2(3, 4) d2.set(I2(i, j), Real(i*4+j+1)); -// -// { -// D1 d1(3), ref(3); +// +// { +// D1 d1(3), ref(3); // ref.set(I1((UInt)0), 4); // ref.set(I1((UInt)1), 8); // ref.set(I1((UInt)2), 12); // d2.accumulate(I1((UInt)1), d1, nupic::Max()); // Test("DenseTensor max 1", d1, ref); // } -// -// { -// D1 d1(4), ref(4); +// +// { +// D1 d1(4), ref(4); // ref.set(I1((UInt)0), 9); // ref.set(I1((UInt)1), 10); // ref.set(I1((UInt)2), 11); @@ -611,14 +611,14 @@ // // { // 
multiplication // D2 d2(3, 4); ITER_2(3, 4) d2.set(i, j, (Real)(i*4+j+1)); -// +// // D1 d1(3), ref(3); // ref.setAll(1); // ITER_2(3, 4) ref.set(i, d2(i, j) * ref(i)); // d2.accumulate(I1((UInt)1), d1, nupic::Multiplies(), 1); // Test("DenseTensor accumulate 9", d1, ref); // } -// } +// } // // //-------------------------------------------------------------------------------- // void DenseTensorUnitTest::unitTestOuterProduct() @@ -646,7 +646,7 @@ // //-------------------------------------------------------------------------------- // void DenseTensorUnitTest::unitTestContract() // { -// D3 d3(4, 3, 3); +// D3 d3(4, 3, 3); // ITER_3(4, 3, 3) d3.set(i, j, k, (Real)(i*9+j*3+k+1)); // // { @@ -668,18 +668,18 @@ // //-------------------------------------------------------------------------------- // void DenseTensorUnitTest::unitTestInnerProduct() // { -// D2 d2A(3, 4), d2B(4, 3), d2C(3, 3), d2ref(3, 3); -// -// ITER_2(3, 4) +// D2 d2A(3, 4), d2B(4, 3), d2C(3, 3), d2ref(3, 3); +// +// ITER_2(3, 4) // d2A(i, j) = d2B(j, i) = Real(i*4+j+1); // // ITER_3(3, 4, 3) // d2ref(i, k) += d2A(i, j) * d2B(j, k); -// -// d2A.inner_product(1, 0, d2B, d2C, nupic::Multiplies(), std::plus(), 0); -// Test("DenseTensor inner product 1", d2C, d2ref); // -// D4 o(3, 4, 4, 3); +// d2A.inner_product(1, 0, d2B, d2C, nupic::Multiplies(), +// std::plus(), 0); Test("DenseTensor inner product 1", d2C, d2ref); +// +// D4 o(3, 4, 4, 3); // d2A.outer_product(d2B, o, nupic::Multiplies()); // // D2 d2D(3, 3); @@ -687,38 +687,37 @@ // Test("DenseTensor inner product 2", d2C, d2D); // // D3 d3A(3, 4, 5), d3B(3, 3, 5); -// ITER_3(3, 4, 5) d3A.set(i, j, k, (Real)(I3(i, j, k).ordinal(I3(3, 4, 5)) + 1)); -// d2A.inner_product(1, 1, d3A, d3B, nupic::Multiplies(), std::plus()); +// ITER_3(3, 4, 5) d3A.set(i, j, k, (Real)(I3(i, j, k).ordinal(I3(3, 4, 5)) + +// 1)); d2A.inner_product(1, 1, d3A, d3B, nupic::Multiplies(), +// std::plus()); // // D5 o2(3, 4, 3, 4, 5); // d2A.outer_product(d3A, o2, nupic::Multiplies()); -// +// // D3 d3D(3, 3, 5); // o2.contract(1, 3, d3D, std::plus()); // Test("DenseTensor inner product 3", d3B, d3D); // } // - //-------------------------------------------------------------------------------- - // void DenseTensorUnitTest::RunTests() - // { - - //unitTestConstructor(); - //unitTestGetSet(); - //unitTestIsSymmetric(); - //unitTestPermute(); - //unitTestResize(); - //unitTestReshape(); - //unitTestSlice(); - //unitTestElementApply(); - //unitTestFactorApply(); - //unitTestAccumulate(); - //unitTestOuterProduct(); - //unitTestContract(); - //unitTestInnerProduct(); - // } +//-------------------------------------------------------------------------------- +// void DenseTensorUnitTest::RunTests() +// { - //-------------------------------------------------------------------------------- - -// } // namespace nupic +// unitTestConstructor(); +// unitTestGetSet(); +// unitTestIsSymmetric(); +// unitTestPermute(); +// unitTestResize(); +// unitTestReshape(); +// unitTestSlice(); +// unitTestElementApply(); +// unitTestFactorApply(); +// unitTestAccumulate(); +// unitTestOuterProduct(); +// unitTestContract(); +// unitTestInnerProduct(); +// } +//-------------------------------------------------------------------------------- +// } // namespace nupic diff --git a/src/test/unit/math/DomainUnitTest.cpp b/src/test/unit/math/DomainUnitTest.cpp index e2924103b1..31ff4b216f 100644 --- a/src/test/unit/math/DomainUnitTest.cpp +++ b/src/test/unit/math/DomainUnitTest.cpp @@ -19,11 +19,11 @@ * http://numenta.org/licenses/ * 
--------------------------------------------------------------------- */ - -/** @file + +/** @file * Implementation of unit testing for class Domain - */ - + */ + //#include // #include "DomainUnitTest.hpp" // @@ -56,7 +56,7 @@ // Test("DimRange includes 2", dr.includes(1), true); // Test("DimRange includes 3", dr.includes(2), true); // Test("DimRange includes 4", dr.includes(3), false); -// +// // DimRange dr2(0, 1, 1); // Test("DimRange includes 4", dr2.includes(0), false); // Test("DimRange includes 5", dr2.includes(1), true); @@ -150,7 +150,8 @@ // } // // { -// Domain d(Index(1, 2, 4, 5, 6), Index(1, 5, 7, 5, 9)); +// Domain d(Index(1, 2, 4, 5, 6), Index(1, 5, 7, 5, +// 9)); // // Index openDims; // Test("Domain getOpenDims 1", d.getNOpenDims(), (UInt)3); @@ -164,19 +165,20 @@ // d.getClosedDims(closedDims); // Test("Domain getClosedDims 2", closedDims[0], (UInt) 0); // Test("Domain getClosedDims 3", closedDims[1], (UInt) 3); -// } +// } // // { -// Domain d(Index(1, 2, 4, 5, 6), Index(1, 5, 7, 5, 9)); +// Domain d(Index(1, 2, 4, 5, 6), Index(1, 5, 7, 5, +// 9)); // // Index idx, lb(1, 2, 4, 5, 6), ub(1, 5, 7, 5, 9); -// +// // do { // if (lb <= idx && idx < ub) // Test("Domain includes 1", d.includes(idx), true); // else // Test("Domain includes 2", d.includes(idx), false); -// +// // } while (idx.increment(ub)); // // Domain d1(Index(0,0,0,0), Index(1,0,1,0)); @@ -211,7 +213,7 @@ // } // // { -// DimRange r; +// DimRange r; // r.set(1, 4, 7); // Test("DimRange set 1", r.getDim(), (UInt)1); // Test("DimRange set 2", r.getLB(), (UInt)4); @@ -251,14 +253,12 @@ // Domain d24(I2(1, 4), I2(9, 4)); // Test("Domain size_elts 4", d24.size_elts(), (UInt)0); // Test("Domain empty 2", d24.empty(), true); -// +// // Domain d31(I3(0, 1, 2), I3(10, 9, 8)); // Test("Domain size_elts 5", d31.size_elts(), (UInt)480); // } - // } - - //-------------------------------------------------------------------------------- - -// } // namespace nupic +// } +//-------------------------------------------------------------------------------- +// } // namespace nupic diff --git a/src/test/unit/math/IndexUnitTest.cpp b/src/test/unit/math/IndexUnitTest.cpp index 89feee79d5..6e2913448b 100644 --- a/src/test/unit/math/IndexUnitTest.cpp +++ b/src/test/unit/math/IndexUnitTest.cpp @@ -19,11 +19,11 @@ * http://numenta.org/licenses/ * --------------------------------------------------------------------- */ - -/** @file + +/** @file * Implementation of unit testing for class Index - */ - + */ + //#include // #include // @@ -52,7 +52,7 @@ // for (UInt i = 0; i < 5; ++i) // Test("Index constructor from array", i1[i], i+1); // } -// +// // { // Constructor with ellipsis // const I5 i2(5, 4, 3, 2, 1); // for (int i = 5; i > 0; --i) @@ -64,7 +64,7 @@ // const I5 idx2(idx1); // for (UInt i = 0; i < 5; ++i) // Test("Index copy constructor 1", idx2[i], (UInt)0); -// +// // const I5 i3(5, 4, 3, 2, 1); // I5 i4(i3); // for (UInt i = 0; i < 5; ++i) @@ -93,7 +93,7 @@ // const I5 i4(5, 4, 3, 2, 1); // for (UInt i = 0; i < 5; ++i) // Test("Index operator[] const", i4[i], 5-i); -// } +// } // // { // begin(), end() // I5 idx(5, 4, 3, 2, 1); @@ -132,7 +132,7 @@ // I5 i1(5, 4, 3, 2, 1); // I5 i2(1, 2, 3, 4, 5); // Test("Index operator== 1 ", ! (i1 == i2), true); -// +// // I5 i3(5, 4, 3, 2, 1); // Test("Index operator== 2 ", (i1 == i3), true); // } @@ -141,21 +141,21 @@ // I5 i1(5, 4, 3, 2, 1); // I5 i2(1, 2, 3, 4, 5); // Test("Index operator!= 1 ", (i1 != i2), true); -// +// // I5 i3(5, 4, 3, 2, 1); // Test("Index operator!= 2 ", ! 
(i1 != i3), true); // } // -// { // ordinal/setFromOrdinal +// { // ordinal/setFromOrdinal // I1 ub((UInt)5); // for (UInt i = 0; i < 5; ++i) { // I1 i1((UInt)i); // Test("Index<1> ordinal 1", ordinal(ub, i1), i); // I1 i2; setFromOrdinal(ub, i, i2); // Test("Index<1> setFromOrdinal 1", i1, i2); -// } +// } // } -// +// // { // increment/ordinal // UInt i = 0; // I1 i1, i2, ub((UInt)5); @@ -170,7 +170,7 @@ // { // increment/ordinal/setFromOrdinal // I5 ub(6, 7, 5, 3, 4), idx; // UInt i = 0; -// do { +// do { // UInt n = ordinal(ub, idx); // I5 idx2(ub, n); // Test("Index increment/ordinal 1", idx, idx2); @@ -181,11 +181,11 @@ // ++i; // } while (increment(ub, idx)); // } -// -// { // incrementing between bounds +// +// { // incrementing between bounds // I5 lb(4, 5, 3, 1, 2), ub(6, 7, 5, 3, 4), idx(lb); // do { -// } while (increment(lb, ub, idx)); +// } while (increment(lb, ub, idx)); // } // // { // incrementing between bounds, over and under => exceptions @@ -201,7 +201,7 @@ // Test("Index setToZero 1", i1[i], (UInt)0); // // I5 i2(5, 4, 3, 2, 1); -// setToZero(i2); +// setToZero(i2); // for (UInt i = 0; i < 5; ++i) // Test("Index setToZero 2", i2[i], (UInt)0); // @@ -235,27 +235,27 @@ // } // // { // complement -// { +// { // I3 i3(0, 2, 4), c; // complement(i3, c); // Test("Index complement 1", c, I3(1, 3, 5)); // } // // { -// I2 i2(0, 2); +// I2 i2(0, 2); // I1 c; // complement(i2, c); // Test("Index complement 2", c, I1(1)); // } -// +// // { -// I1 i1((UInt)0), c; +// I1 i1((UInt)0), c; // complement(i1, c); // Test("Index complement 3", c, I1(1)); // } // // { -// I2 i2(0, 1); +// I2 i2(0, 1); // I1 c; // complement(i2, c); // Test("Index complement 4", c, I1(2)); @@ -269,7 +269,7 @@ // } // } // -// +// // // { // project // const I3 i3(9, 7, 3); @@ -295,8 +295,8 @@ // Test("Index project 6", i4[i], (UInt)3-i); // } // } -// -// { // embed +// +// { // embed // { // I3 i3A(1, 2, 3), dims(0, 1, 2), i3B; // embed(dims, i3A, i3B); @@ -305,8 +305,8 @@ // // { // I6 i6; -// const I3 dims(1, 3, 5), i3(9, 7, 3); -// embed(dims, i3, i6); +// const I3 dims(1, 3, 5), i3(9, 7, 3); +// embed(dims, i3, i6); // Test("Index embed 1A", i6, I6(0, 9, 0, 7, 0, 3)); // // I6 i6B; @@ -316,7 +316,7 @@ // // { // I6 i6; -// const I3 dims(0, 2, 4), i3(9, 7, 3); +// const I3 dims(0, 2, 4), i3(9, 7, 3); // embed(dims, i3, i6); // Test("Index embed 2", i6, I6(9, 0, 7, 0, 3, 0)); // } @@ -334,16 +334,16 @@ // // { // embed is the reciprocal of project // I6 i6; -// +// // { -// I3 dims(1, 3, 5), i3A(9, 7, 3), i3B; -// embed(dims, i3A, i6); +// I3 dims(1, 3, 5), i3A(9, 7, 3), i3B; +// embed(dims, i3A, i6); // project(dims, i6, i3B); // Test("Index embed 5", i3A, i3B); // } // // { -// I3 dims(0, 2, 4), i3A(9, 7, 3), i3B; +// I3 dims(0, 2, 4), i3A(9, 7, 3), i3B; // embed(dims, i3A, i6); // project(dims, i6, i3B); // Test("Index embed 6", i3A, i3B); @@ -384,7 +384,7 @@ // I5 i0(1, 2, 3, 4, 5); // I5 i1(2, 3, 4, 5, 6); // I5 i2(3, 4, 5, 6, 7); -// +// // Test("Index operator<= 1", (i0 <= i0), true); // Test("Index operator<= 2", (i0 <= i1), true); // Test("Index operator<= 3", (i1 <= i2), true); @@ -412,10 +412,11 @@ // I3 ub(1, 3, 4); // Test("Index positiveInBounds 4", positiveInBounds(I3(0,0,0), ub), true); // Test("Index positiveInBounds 1", positiveInBounds(I3(0,2,3), ub), true); -// Test("Index positiveInBounds 1", positiveInBounds(I3(0,3,3), ub), false); -// Test("Index positiveInBounds 2", positiveInBounds(I3(0,2,4), ub), false); -// Test("Index positiveInBounds 3", positiveInBounds(I3(4,5,6), ub), 
false); -// Test("Index positiveInBounds 3", positiveInBounds(I3(1,2,3), ub), false); +// Test("Index positiveInBounds 1", positiveInBounds(I3(0,3,3), ub), +// false); Test("Index positiveInBounds 2", positiveInBounds(I3(0,2,4), +// ub), false); Test("Index positiveInBounds 3", +// positiveInBounds(I3(4,5,6), ub), false); Test("Index positiveInBounds +// 3", positiveInBounds(I3(1,2,3), ub), false); // } // } // @@ -440,10 +441,10 @@ // void IndexUnitTest::unitTestDynamicIndex() // { // typedef std::vector Idx; -// -// { +// +// { // Idx i3(5); i3[0] = 5; i3[1] = 4; i3[2] = 3; i3[3] = 2; i3[4] = 1; -// setToZero(i3); +// setToZero(i3); // for (UInt i = 0; i < 5; ++i) // Test("Dynamic Index setToZero", i3[i], (UInt)0); // } @@ -458,7 +459,7 @@ // { // Idx i(3), ub(3), j(3), iprev(3); // ub[0] = 3; ub[1] = 2; ub[2] = 5; -// setToZero(i); +// setToZero(i); // UInt n = 0; // do { // Test("Dynamic Index ordinal", ordinal(ub, i), n); @@ -471,7 +472,7 @@ // } while (increment(ub, i)); // } // -// { +// { // Idx i3(3); i3[0] = 9; i3[1] = 7; i3[2] = 3; // Test("Dynamic Index product", product(i3), (UInt)9*7*3); // } @@ -482,7 +483,7 @@ // Test("Dynamic Index complement 1", Compare(c, I3(1, 3, 5)), true); // } // -// { +// { // Idx i3(3); i3[0] = 9; i3[1] = 7; i3[2] = 3; // // Idx i3p(3); @@ -492,8 +493,8 @@ // Test("Dynamic Index project 2", i3p[1], (UInt)7); // Test("Dynamic Index project 3", i3p[2], (UInt)3); // -// Idx i2(2); -// const I2 dims1(0, 2); +// Idx i2(2); +// const I2 dims1(0, 2); // project(dims1, i3, i2); // Test("Dynamic Index project 4", i2[0], (UInt)9); // Test("Dynamic Index project 5", i2[1], (UInt)3); @@ -505,23 +506,24 @@ // project(dims2, i4, i1); // Test("Dynamic Index project 6", i4[i], (UInt)3-i); // } -// } +// } // -// { +// { // Idx i1(3); i1[0] = 5; i1[1] = 4; i1[2] = 3; // Idx i2(3); i2[0] = 2; i2[1] = 1; i2[2] = 0; -// Idx i3(6); +// Idx i3(6); // i3 = concatenate(i1, i2); // for (int i = 5; i >= 0; --i) // Test("Dynamic Index concatenate", i3[5-i], (UInt)i); -// } -// +// } +// // { // permutations // Idx i5(5); i5[0] = 1; i5[1] = 2; i5[2] = 3; i5[3] = 4; i5[4] = 5; // Idx ind(5); ind[0] = 1; ind[1] = 2; ind[2] = 3; ind[3] = 4; ind[4] = 0; -// Idx perm(5); +// Idx perm(5); // nupic::permute(ind, i5, perm); -// Test("Dynamic Index permutation 1", Compare(perm, I5(2, 3, 4, 5, 1)), true); +// Test("Dynamic Index permutation 1", Compare(perm, I5(2, 3, 4, 5, 1)), +// true); // } // // { // Can it be a key in a std::map? 
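The commented-out Index tests in this hunk lean on one invariant: ordinal() and increment() agree with each other, so walking every index below an upper bound visits ordinals 0, 1, 2, ... exactly once. The standalone sketch below uses std::vector<unsigned> and assumes a row-major convention; the disabled tests only require the round-trip, not any particular order.

#include <cassert>
#include <cstddef>
#include <vector>

// Mixed-radix (row-major) ordinal of idx under upper bounds ub.
static unsigned ordinalOf(const std::vector<unsigned> &ub,
                          const std::vector<unsigned> &idx) {
  unsigned ord = 0;
  for (std::size_t d = 0; d < ub.size(); ++d)
    ord = ord * ub[d] + idx[d];
  return ord;
}

// Advance idx through every index below ub; returns false after the last one.
static bool incrementIdx(const std::vector<unsigned> &ub,
                         std::vector<unsigned> &idx) {
  for (std::size_t d = ub.size(); d-- > 0;) {
    if (++idx[d] < ub[d])
      return true;
    idx[d] = 0;
  }
  return false;
}

int main() {
  std::vector<unsigned> ub = {3, 2, 5}, idx(3, 0); // same bounds as the dynamic-index test
  unsigned n = 0;
  do {
    assert(ordinalOf(ub, idx) == n); // ordinals are visited in order
    ++n;
  } while (incrementIdx(ub, idx));
  assert(n == 3 * 2 * 5); // every index below ub visited exactly once
  return 0;
}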
@@ -547,22 +549,20 @@ // std::map, UInt>::const_iterator it; // UInt i; // for (i = 0, it = m.begin(); i < 9*7*3; ++i, ++it) { -// Test("Dynamic Index operator< 8", indexEq(I3(ub, i), it->first), true); -// Test("Dynamic Index operator< 9", i, it->second); +// Test("Dynamic Index operator< 8", indexEq(I3(ub, i), it->first), +// true); Test("Dynamic Index operator< 9", i, it->second); // } -// } +// } // } // - //-------------------------------------------------------------------------------- - // void IndexUnitTest::RunTests() - // { - - //unitTestFixedIndex(); - //unitTestDynamicIndex(); - // } +//-------------------------------------------------------------------------------- +// void IndexUnitTest::RunTests() +// { - //-------------------------------------------------------------------------------- - -// } // namespace nupic +// unitTestFixedIndex(); +// unitTestDynamicIndex(); +// } +//-------------------------------------------------------------------------------- +// } // namespace nupic diff --git a/src/test/unit/math/MathsTest.cpp b/src/test/unit/math/MathsTest.cpp index 226d1de5d6..a4ece366fb 100644 --- a/src/test/unit/math/MathsTest.cpp +++ b/src/test/unit/math/MathsTest.cpp @@ -53,43 +53,48 @@ // //-------------------------------------------------------------------------------- // void MathsTest::unitTestNearlyEqual() // { -// Test("nearlyEqual Reals 1", true, nearlyEqual(Real(0.0), Real(0.0000000001))); -// Test("nearlyEqual Reals 2", true, nearlyEqual(0.0, 0.0)); -// Test("nearlyEqual Reals 3", false, nearlyEqual(0.0, 1.0)); +// Test("nearlyEqual Reals 1", true, nearlyEqual(Real(0.0), +// Real(0.0000000001))); Test("nearlyEqual Reals 2", true, nearlyEqual(0.0, +// 0.0)); Test("nearlyEqual Reals 3", false, nearlyEqual(0.0, 1.0)); // Test("nearlyEqual Reals 4", false, nearlyEqual(0.0, 0.01)); // Test("nearlyEqual Reals 5", false, nearlyEqual(0.0, -0.01)); // Test("nearlyEqual Reals 6", false, nearlyEqual(0.0, -1.0)); -// Test("nearlyEqual Reals 7", true, nearlyEqual(Real(2.0), Real(2.000000001))); -// Test("nearlyEqual Reals 8", true, nearlyEqual(Real(2.0), Real(1.999999999))); -// Test("nearlyEqual Reals 9", true, nearlyEqual(Real(-2.0), Real(-2.000000001))); -// Test("nearlyEqual Reals 10", true, nearlyEqual(Real(-2.0), Real(-1.999999999))); +// Test("nearlyEqual Reals 7", true, nearlyEqual(Real(2.0), +// Real(2.000000001))); Test("nearlyEqual Reals 8", true, +// nearlyEqual(Real(2.0), Real(1.999999999))); Test("nearlyEqual Reals 9", +// true, nearlyEqual(Real(-2.0), Real(-2.000000001))); Test("nearlyEqual +// Reals 10", true, nearlyEqual(Real(-2.0), Real(-1.999999999))); // } // // //-------------------------------------------------------------------------------- // void MathsTest::unitTestNearlyEqualVector() -// { +// { // vector v1, v2; -// +// // { -// Test("nearlyEqualVector, empty vectors", true, nearlyEqualVector(v1, v2)); -// +// Test("nearlyEqualVector, empty vectors", true, nearlyEqualVector(v1, +// v2)); +// // v2.push_back(1); -// Test("nearlyEqualVector, different sizes", false, nearlyEqualVector(v1, v2)); -// +// Test("nearlyEqualVector, different sizes", false, nearlyEqualVector(v1, +// v2)); +// // v1.push_back(1); // Test("nearlyEqualVector, 1 element", true, nearlyEqualVector(v1, v2)); -// +// //#if 0 // for(UInt i=0; i<2048; ++i) { // Real v = Real(rng_->get() % 256)/256.0; // v1.push_back(v); // v2.push_back(v); // } -// Test("nearlyEqualVector, 2049 elements 1", true, nearlyEqualVector(v1, v2)); -// +// Test("nearlyEqualVector, 2049 elements 1", true, 
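The commented-out nearlyEqual cases pin down the intended behaviour (differences on the order of 1e-9 compare equal, differences of 0.01 or more do not) without showing the implementation in this hunk. A plain absolute-tolerance check reproduces every listed outcome; the 1e-6 below is an assumption for illustration, not nupic's actual tolerance.

#include <cassert>
#include <cmath>

// Illustration only: approximate floating-point equality with an absolute
// tolerance, matching the outcomes the disabled tests above expect.
static bool nearlyEqualSketch(double a, double b, double eps = 1e-6) {
  return std::fabs(a - b) <= eps;
}

int main() {
  assert(nearlyEqualSketch(0.0, 0.0000000001));
  assert(nearlyEqualSketch(2.0, 2.000000001));
  assert(nearlyEqualSketch(-2.0, -1.999999999));
  assert(!nearlyEqualSketch(0.0, 0.01));
  assert(!nearlyEqualSketch(0.0, 1.0));
  return 0;
}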
nearlyEqualVector(v1, +// v2)); +// // v2[512] += 1.0; -// Test("nearlyEqualVector, 2049 elements 2", false, nearlyEqualVector(v1, v2)); -// +// Test("nearlyEqualVector, 2049 elements 2", false, nearlyEqualVector(v1, +// v2)); +// // v1.clear(); v2.clear(); // Test("nearlyEqualVector, after clear", true, nearlyEqualVector(v1, v2)); //#endif @@ -105,61 +110,64 @@ // { // vector empty1; // normalize(empty1.begin(), empty1.end()); -// Test("Normalize vector, empty", true, nearlyEqualVector(empty1, empty1)); +// Test("Normalize vector, empty", true, nearlyEqualVector(empty1, +// empty1)); // } -// +// // { // v1[0] = Real(0.0); v1[1] = Real(0.0); v1[2] = Real(0.0); // answer[0] = Real(0.0); answer[1] = Real(0.0); answer[2] = Real(0.0); // normalize(v1.begin(), v1.end()); // Test("Normalize vector 11", true, nearlyZero(sum(v1))); -// Test("Normalize vector 12", true, nearlyEqualVector(v1, answer)); +// Test("Normalize vector 12", true, nearlyEqualVector(v1, +// answer)); // } // // { // v1[0] = Real(1.0); v1[1] = Real(1.0); v1[2] = Real(1.0); -// answer[0] = Real(1.0/3.0); answer[1] = Real(1.0/3.0); answer[2] = Real(1.0/3.0); -// normalize(v1.begin(), v1.end()); -// Real s = sum(v1); +// answer[0] = Real(1.0/3.0); answer[1] = Real(1.0/3.0); answer[2] = +// Real(1.0/3.0); normalize(v1.begin(), v1.end()); Real s = sum(v1); // Test("Normalize vector 21", true, nearlyEqual(s, Real(1.0))); -// Test("Normalize vector 22", true, nearlyEqualVector(v1, answer)); +// Test("Normalize vector 22", true, nearlyEqualVector(v1, +// answer)); // } // // { // v1[0] = Real(0.5); v1[1] = Real(0.5); v1[2] = Real(0.5); -// answer[0] = Real(0.5/1.5); answer[1] = Real(0.5/1.5); answer[2] = Real(0.5/1.5); -// normalize(v1.begin(), v1.end()); -// Real s = sum(v1); +// answer[0] = Real(0.5/1.5); answer[1] = Real(0.5/1.5); answer[2] = +// Real(0.5/1.5); normalize(v1.begin(), v1.end()); Real s = sum(v1); // Test("Normalize vector 31", true, nearlyEqual(s, Real(1.0))); -// Test("Normalize vector 32", true, nearlyEqualVector(v1, answer)); +// Test("Normalize vector 32", true, nearlyEqualVector(v1, +// answer)); // } // // { // v1[0] = Real(1.0); v1[1] = Real(0.5); v1[2] = Real(1.0); -// answer[0] = Real(1.0/2.5); answer[1] = Real(0.5/2.5); answer[2] = Real(1.0/2.5); -// normalize(v1.begin(), v1.end()); -// Real s = sum(v1); +// answer[0] = Real(1.0/2.5); answer[1] = Real(0.5/2.5); answer[2] = +// Real(1.0/2.5); normalize(v1.begin(), v1.end()); Real s = sum(v1); // Test("Normalize vector 41", true, nearlyEqual(s, Real(1.0))); -// Test("Normalize vector 42", true, nearlyEqualVector(v1, answer)); +// Test("Normalize vector 42", true, nearlyEqualVector(v1, +// answer)); // } // // { // Test normalizing to non-1.0 // v1[0] = Real(1.0); v1[1] = Real(0.5); v1[2] = Real(1.0); -// answer[0] = Real(3.0/2.5); answer[1] = Real(1.5/2.5); answer[2] = Real(3.0/2.5); -// normalize(v1.begin(), v1.end(), 1.0, 3.0); -// Real s = sum(v1); -// Test("Normalize vector 51", true, nearlyEqual(s, Real(3.0))); -// Test("Normalize vector 52", true, nearlyEqualVector(v1, answer)); +// answer[0] = Real(3.0/2.5); answer[1] = Real(1.5/2.5); answer[2] = +// Real(3.0/2.5); normalize(v1.begin(), v1.end(), 1.0, 3.0); Real s = +// sum(v1); Test("Normalize vector 51", true, nearlyEqual(s, +// Real(3.0))); Test("Normalize vector 52", true, +// nearlyEqualVector(v1, answer)); // } // } -// +// // { // normalize // std::vector v1(3), answer(3); // // { // std::vector empty1; // normalize(empty1.begin(), empty1.end()); -// Test("Normalize VectorType, empty", true, 
nearlyEqualVector(empty1, empty1)); +// Test("Normalize VectorType, empty", true, nearlyEqualVector(empty1, +// empty1)); // } // // { @@ -172,27 +180,24 @@ // // { // v1[0] = Real(1.0); v1[1] = Real(1.0); v1[2] = Real(1.0); -// answer[0] = Real(1.0/3.0); answer[1] = Real(1.0/3.0); answer[2] = Real(1.0/3.0); -// normalize(v1.begin(), v1.end()); -// Real s = sum(v1); +// answer[0] = Real(1.0/3.0); answer[1] = Real(1.0/3.0); answer[2] = +// Real(1.0/3.0); normalize(v1.begin(), v1.end()); Real s = sum(v1); // Test("Normalize VectorType 21", true, nearlyEqual(s, Real(1.0))); // Test("Normalize VectorType 22", true, nearlyEqualVector(v1, answer)); // } // // { // v1[0] = Real(0.5); v1[1] = Real(0.5); v1[2] = Real(0.5); -// answer[0] = Real(0.5/1.5); answer[1] = Real(0.5/1.5); answer[2] = Real(0.5/1.5); -// normalize(v1.begin(), v1.end()); -// Real s = sum(v1); +// answer[0] = Real(0.5/1.5); answer[1] = Real(0.5/1.5); answer[2] = +// Real(0.5/1.5); normalize(v1.begin(), v1.end()); Real s = sum(v1); // Test("Normalize VectorType 31", true, nearlyEqual(s, Real(1.0))); // Test("Normalize VectorType 32", true, nearlyEqualVector(v1, answer)); // } // // { // v1[0] = Real(1.0); v1[1] = Real(0.5); v1[2] = Real(1.0); -// answer[0] = Real(1.0/2.5); answer[1] = Real(0.5/2.5); answer[2] = Real(1.0/2.5); -// normalize(v1.begin(), v1.end()); -// Real s = sum(v1); +// answer[0] = Real(1.0/2.5); answer[1] = Real(0.5/2.5); answer[2] = +// Real(1.0/2.5); normalize(v1.begin(), v1.end()); Real s = sum(v1); // Test("Normalize VectorType 41", true, nearlyEqual(s, Real(1.0))); // Test("Normalize VectorType 42", true, nearlyEqualVector(v1, answer)); // } @@ -246,7 +251,7 @@ // bool comp = s.str() == answer.str(); // Test("vector quiet_NaN to stream", true, comp); // } -// +// // { // vector v1(3); // v1[0] = 1.0; v1[1] = numeric_limits::signaling_NaN(); v1[2] = 3.0; @@ -376,11 +381,11 @@ // nchildren = 21; // w = 125; // nreps = 1000; -// +// // vector boundaries(nchildren, 0); -// +// // boundaries[0] = rng_->getUInt32(w) + 1; -// for (UInt i = 1; i < (nchildren-1); ++i) +// for (UInt i = 1; i < (nchildren-1); ++i) // boundaries[i] = boundaries[i-1] + (rng_->getUInt32(w) + 1); // ncols = nchildren * w; // boundaries[nchildren-1] = ncols; @@ -392,17 +397,17 @@ // x[j] = rng_->getReal64(); // // winnerTakesAll2(boundaries, x.begin(), v.begin()); -// +// // UInt k2 = 0; // for (UInt k1 = 0; k1 < nchildren; ++k1) { // -// vector::iterator it = +// vector::iterator it = // max_element(x.begin() + k2, x.begin() + boundaries[k1]); // Test("Maths winnerTakesAll2 1", v[it - x.begin()], 1); // -// Real s = accumulate(v.begin() + k2, v.begin() + boundaries[k1], (Real)0); -// Test("Maths winnerTakesAll2 2", s, (Real) 1); -// +// Real s = accumulate(v.begin() + k2, v.begin() + boundaries[k1], +// (Real)0); Test("Maths winnerTakesAll2 2", s, (Real) 1); +// // k2 = boundaries[k1]; // } // @@ -425,11 +430,11 @@ // x[0] = 1; // normalize_max(x.begin(), x.end(), 1); // Test("scale 2", nearlyEqual(x[0], (Real)1), true); -// +// // x[0] = 2; // normalize_max(x.begin(), x.end(), 1); // Test("scale 3", nearlyEqual(x[0], (Real)1), true); -// +// // normalize_max(x.begin(), x.end(), .5); // Test("scale 4", nearlyEqual(x[0], (Real).5), true); // @@ -456,7 +461,7 @@ // normalize_max(x.begin(), x.end(), 10); // Test("scale 8a", nearlyEqual(x[0], (Real)7), true); // Test("scale 8b", nearlyEqual(x[1], (Real)10), true); -// } +// } // // { // const UInt N = 256; @@ -473,7 +478,7 @@ // ans[j] /= max; // // normalize_max(x.begin(), x.end(), 1); -// 
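The disabled Normalize tests above describe sum-normalization: an all-zero vector is left untouched, otherwise the elements are scaled so they sum to 1, or to an explicit target such as 3.0. The sketch below mirrors the expected values quoted in the tests using plain STL; nupic's normalize() takes iterator and extra arguments that are not reproduced here.

#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>

// Scale v so its elements sum to `target`; leave an all-zero vector alone.
static void normalizeSum(std::vector<double> &v, double target = 1.0) {
  double s = std::accumulate(v.begin(), v.end(), 0.0);
  if (s == 0.0)
    return; // "Normalize vector 11/12": zeros stay zeros
  for (double &x : v)
    x *= target / s;
}

int main() {
  std::vector<double> v = {1.0, 0.5, 1.0};
  normalizeSum(v); // "Normalize vector 41/42": {1.0/2.5, 0.5/2.5, 1.0/2.5}
  assert(std::fabs(v[0] - 0.4) < 1e-12 && std::fabs(v[1] - 0.2) < 1e-12);

  std::vector<double> w = {1.0, 0.5, 1.0};
  normalizeSum(w, 3.0); // "Normalize vector 51/52": the sum becomes 3.0
  assert(std::fabs(std::accumulate(w.begin(), w.end(), 0.0) - 3.0) < 1e-12);
  return 0;
}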
+// // bool identical = true; // for (UInt j = 0; j < N && identical; ++j) // if (!nearlyEqual(x[j], ans[j])) @@ -491,46 +496,55 @@ // UInt n = 1000; // Changing n will change the error! // // { // Constructor with F only -// n = 1000; -// double lb = -1, ub = 1, t_lb = -2, t_ub = 2, t_step = (t_ub - t_lb)/(10*n); -// +// n = 1000; +// double lb = -1, ub = 1, t_lb = -2, t_ub = 2, t_step = (t_ub - +// t_lb)/(10*n); +// // QSI > q_e1(lb, ub, n, Exp()); -// Test("qsi f 1", q_e1.max_error(t_lb, t_ub, t_step).second <= 1e-8, true); +// Test("qsi f 1", q_e1.max_error(t_lb, t_ub, t_step).second <= 1e-8, +// true); // // QSI > q_e2(lb, ub, n, Exp2()); -// Test("qsi f 2", q_e2.max_error(t_lb, t_ub, t_step).second <= 1e-7, true); +// Test("qsi f 2", q_e2.max_error(t_lb, t_ub, t_step).second <= 1e-7, +// true); // } // -// { -// n = 2000; -// Real lb = -1, ub = 1, t_lb = -2, t_ub = 2, t_step = (t_ub - t_lb)/(10*n); +// { +// n = 2000; +// Real lb = -1, ub = 1, t_lb = -2, t_ub = 2, t_step = (t_ub - +// t_lb)/(10*n); // // QSI > q_e1(lb, ub, n, Exp()); -// Test("qsi f 3", q_e1.max_error(t_lb, t_ub, t_step).second <= 1e-3, true); +// Test("qsi f 3", q_e1.max_error(t_lb, t_ub, t_step).second <= 1e-3, +// true); // // QSI > q_e2(lb, ub, n, Exp2()); -// Test("qsi f 4", q_e2.max_error(t_lb, t_ub, t_step).second <= 1e-3, true); +// Test("qsi f 4", q_e2.max_error(t_lb, t_ub, t_step).second <= 1e-3, +// true); // } -// +// // { // Constructor with F and derivative // n = 1000; -// double lb = -1, ub = 1, t_lb = -2, t_ub = 2, t_step = (t_ub - t_lb)/(10*n); +// double lb = -1, ub = 1, t_lb = -2, t_ub = 2, t_step = (t_ub - +// t_lb)/(10*n); // -// QSI > q_e1(lb, ub, n, Exp(1, 2), Exp(2, 2)); -// Test("qsi f ff 1", q_e1.max_error(t_lb, t_ub, t_step).second <= 1e-7, true); +// QSI > q_e1(lb, ub, n, Exp(1, 2), +// Exp(2, 2)); Test("qsi f ff 1", q_e1.max_error(t_lb, t_ub, +// t_step).second <= 1e-7, true); // } -// -// { +// +// { // n = 2000; -// Real lb = -1, ub = 1, t_lb = -2, t_ub = 2, t_step = (t_ub - t_lb)/(10*n); -// QSI > q_e2(lb, ub, n, Exp(1, 2), Exp(2, 2)); -// Test("qsi f ff 2", q_e2.max_error(t_lb, t_ub, t_step).second <= 1e-5, true); +// Real lb = -1, ub = 1, t_lb = -2, t_ub = 2, t_step = (t_ub - +// t_lb)/(10*n); QSI > q_e2(lb, ub, n, Exp(1, 2), +// Exp(2, 2)); Test("qsi f ff 2", q_e2.max_error(t_lb, t_ub, +// t_step).second <= 1e-5, true); // } // // { // Vector to vector // n = 1000; -// double v, lb = -1, ub = 1, t_lb = -2, t_ub = 2, t_step = (t_ub - t_lb)/(n); -// vector x(n), y(n), yref(n); +// double v, lb = -1, ub = 1, t_lb = -2, t_ub = 2, t_step = (t_ub - +// t_lb)/(n); vector x(n), y(n), yref(n); // // v = lb; // @@ -539,46 +553,44 @@ // yref[i] = exp(v); // } // -// QSI > q_e1(lb, ub, n, Exp(1,1), Exp(1,1)); -// q_e1(x.begin(), x.end(), y.begin()); -// Test("qsi vector 1", nearlyEqualRange(y.begin(), y.end(), yref.begin()), true); -// } -// +// QSI > q_e1(lb, ub, n, Exp(1,1), +// Exp(1,1)); q_e1(x.begin(), x.end(), y.begin()); Test("qsi vector +// 1", nearlyEqualRange(y.begin(), y.end(), yref.begin()), true); +// } +// // { // Vector to itself // n = 1000; -// double v, lb = -1, ub = 1, t_lb = -2, t_ub = 2, t_step = (t_ub - t_lb)/(n); -// vector vv(n), yref(n); +// double v, lb = -1, ub = 1, t_lb = -2, t_ub = 2, t_step = (t_ub - +// t_lb)/(n); vector vv(n), yref(n); // // v = lb; // // for (UInt i = 0; i < n; ++i, v += t_step) { // vv[i] = v; -// yref[i] = exp(v); +// yref[i] = exp(v); // } // -// QSI > q_e1(lb, ub, n, Exp(1,1), Exp(1,1)); -// q_e1(vv.begin(), vv.end()); -// Test("qsi vector 
1", nearlyEqualRange(vv.begin(), vv.end(), yref.begin()), true); +// QSI > q_e1(lb, ub, n, Exp(1,1), +// Exp(1,1)); q_e1(vv.begin(), vv.end()); Test("qsi vector 1", +// nearlyEqualRange(vv.begin(), vv.end(), yref.begin()), true); // } // */ // } // - //-------------------------------------------------------------------------------- - // void MathsTest::RunTests() - // { - // - //unitTestNearlyZero(); - //unitTestNearlyEqual(); - //unitTestNearlyEqualVector(); - //unitTestNormalize(); - ////unitTestVectorToStream(); - //unitTestElemOps(); - //unitTestWinnerTakesAll(); - //unitTestScale(); - ////unitTestQSI(); // HEAP CORRUPTION on Windows - // } - - //---------------------------------------------------------------------- -// } // end namespace nupic - +//-------------------------------------------------------------------------------- +// void MathsTest::RunTests() +// { +// +// unitTestNearlyZero(); +// unitTestNearlyEqual(); +// unitTestNearlyEqualVector(); +// unitTestNormalize(); +////unitTestVectorToStream(); +// unitTestElemOps(); +// unitTestWinnerTakesAll(); +// unitTestScale(); +////unitTestQSI(); // HEAP CORRUPTION on Windows +// } +//---------------------------------------------------------------------- +// } // end namespace nupic diff --git a/src/test/unit/math/SegmentMatrixAdapterTest.cpp b/src/test/unit/math/SegmentMatrixAdapterTest.cpp index 776aeba754..c4bc040da0 100644 --- a/src/test/unit/math/SegmentMatrixAdapterTest.cpp +++ b/src/test/unit/math/SegmentMatrixAdapterTest.cpp @@ -20,212 +20,184 @@ * ---------------------------------------------------------------------- */ -#include -#include #include "gtest/gtest.h" +#include +#include -using std::vector; using nupic::SegmentMatrixAdapter; using nupic::SparseMatrix; using nupic::UInt32; +using std::vector; namespace { - /** - * The SparseMatrix should contain one row for each added segment. - */ - TEST(SegmentMatrixAdapterTest, addRows) - { - SegmentMatrixAdapter> ssm(2048, 1000); - EXPECT_EQ(0, ssm.matrix.nRows()); - - ssm.createSegment(42); - EXPECT_EQ(1, ssm.matrix.nRows()); - - UInt32 cells[] = {42, 43, 44}; - UInt32 segmentsOut[3]; - ssm.createSegments(cells, cells + 3, segmentsOut); - EXPECT_EQ(4, ssm.matrix.nRows()); - } - - /** - * When you destroy a segment and then create a segment, the number of rows in - * the SparseMatrix should stay constant. - * - * This test doesn't prescribe whether the SegmentMatrixAdapter should - * accomplish this by keeping a list of "destroyed segments" or by simply - * removing rows from the SparseMatrix. - */ - TEST(SegmentMatrixAdapterTest, noRowLeaks) - { - SegmentMatrixAdapter> ssm(2048, 1000); - - // Create 5 segments - UInt32 cells1[] = {42, 43, 44, 45, 46}; - vector created(5); - ssm.createSegments(cells1, cells1 + 5, created.data()); - ASSERT_EQ(5, ssm.matrix.nRows()); - - // Destroy 3 segments, covering both destroy APIs - ssm.destroySegment(created[1]); - - vector toDestroy = {created[2], created[3]}; - ssm.destroySegments(toDestroy.begin(), toDestroy.end()); - - // Create 4 segments, covering both create APIs, and making sure - // createSegments has to reuse destroyed segments *and* add rows in one - // call. 
- ssm.createSegment(50); - - UInt32 cells2[] = {51, 52, 53}; - UInt32 segmentsOut[3]; - ssm.createSegments(cells2, cells2 + 3, segmentsOut); - - EXPECT_EQ(6, ssm.matrix.nRows()); - } - - /** - * Prepare: - * - Cell that has multiple segments - * - Cell that had multiple segments, then lost some of them - * - Cell that has a single segment - * - Cell that has had segments, then lost them - * - Cell that has never had a segment - * - * Use both create APIs and both destroy APIs. - * - * Verify that getSegmentCounts gets the up-to-date count for each. - */ - TEST(SegmentMatrixAdapterTest, getSegmentCounts) - { - SegmentMatrixAdapter> ssm(2048, 1000); - - ssm.createSegment(42); - - vector cells = {42, 43, 44, 45}; - vector created(cells.size()); - ssm.createSegments(cells.begin(), cells.end(), created.data()); - - ssm.createSegment(43); - - vector destroy = {created[1]}; - ssm.destroySegments(destroy.begin(), destroy.end()); - - ssm.destroySegment(created[3]); - - vector queriedCells = {42, 43, 44, 45, 46}; - vector counts(queriedCells.size()); - ssm.getSegmentCounts(queriedCells.begin(), queriedCells.end(), - counts.begin()); - - vector expected = {2, 1, 1, 0, 0}; - EXPECT_EQ(expected, counts); - } - - TEST(SegmentMatrixAdapterTest, sortSegmentsByCell) - { - SegmentMatrixAdapter> ssm(2048, 1000); - - UInt32 segment1 = ssm.createSegment(42); - UInt32 segment2 = ssm.createSegment(41); - UInt32 segment3 = ssm.createSegment(49); - UInt32 segment4 = ssm.createSegment(45); - UInt32 segment5 = ssm.createSegment(0); - UInt32 segment6 = ssm.createSegment(2047); - const vector sorted = {segment5, - segment2, - segment1, - segment4, - segment3, - segment6}; - - vector mySegments = {segment1, - segment2, - segment3, - segment4, - segment5, - segment6}; - ssm.sortSegmentsByCell(mySegments.begin(), mySegments.end()); - - EXPECT_EQ(sorted, mySegments); - } - - TEST(SegmentMatrixAdapterTest, filterSegmentsByCell) - { - SegmentMatrixAdapter> ssm(2048, 1000); - - // Don't create them in order -- we don't want the segment numbers - // to be ordered in a meaningful way. 
- - // Shuffled - // {42, 42, 42, 43, 46, 47, 48} - const vector cellsWithSegments = - {47, 42, 46, 43, 42, 48, 42}; - - - vector createdSegments(cellsWithSegments.size()); - ssm.createSegments(cellsWithSegments.begin(), - cellsWithSegments.end(), - createdSegments.begin()); - ssm.sortSegmentsByCell(createdSegments.begin(), createdSegments.end()); - - // Include everything - const vector everything = {42, 42, 42, 43, 46, 47, 48}; - EXPECT_EQ(createdSegments, - ssm.filterSegmentsByCell( - createdSegments.begin(), createdSegments.end(), - everything.begin(), everything.end())); - - // Subset, one cell with multiple segments - const vector subset1 = {42, 43, 48}; - const vector expected1 = {createdSegments[0], - createdSegments[1], - createdSegments[2], - createdSegments[3], - createdSegments[6]}; - EXPECT_EQ(expected1, - ssm.filterSegmentsByCell( - createdSegments.begin(), createdSegments.end(), - subset1.begin(), subset1.end())); - - // Subset, some cells without segments - const vector subset2 = {43, 44, 45, 48}; - const vector expected2 = {createdSegments[3], - createdSegments[6]}; - EXPECT_EQ(expected2, - ssm.filterSegmentsByCell( - createdSegments.begin(), createdSegments.end(), - subset2.begin(), subset2.end())); - } - - TEST(SegmentMatrixAdapterTest, mapSegmentsToCells) - { - SegmentMatrixAdapter> ssm(2048, 1000); - - const vector cellsWithSegments = - {42, 42, 42, 43, 44, 45}; - - vector createdSegments(cellsWithSegments.size()); - ssm.createSegments(cellsWithSegments.begin(), - cellsWithSegments.end(), - createdSegments.begin()); - - // Map everything - vector cells1(createdSegments.size()); - ssm.mapSegmentsToCells(createdSegments.begin(), - createdSegments.end(), - cells1.begin()); - EXPECT_EQ(cellsWithSegments, cells1); - - // Map subset, including duplicates - vector segmentSubset = {createdSegments[3], - createdSegments[3], - createdSegments[0]}; - vector expectedCells2 = {43, 43, 42}; - vector cells2(segmentSubset.size()); - ssm.mapSegmentsToCells(segmentSubset.begin(), - segmentSubset.end(), - cells2.begin()); - EXPECT_EQ(expectedCells2, cells2); - } +/** + * The SparseMatrix should contain one row for each added segment. + */ +TEST(SegmentMatrixAdapterTest, addRows) { + SegmentMatrixAdapter> ssm(2048, 1000); + EXPECT_EQ(0, ssm.matrix.nRows()); + + ssm.createSegment(42); + EXPECT_EQ(1, ssm.matrix.nRows()); + + UInt32 cells[] = {42, 43, 44}; + UInt32 segmentsOut[3]; + ssm.createSegments(cells, cells + 3, segmentsOut); + EXPECT_EQ(4, ssm.matrix.nRows()); +} + +/** + * When you destroy a segment and then create a segment, the number of rows in + * the SparseMatrix should stay constant. + * + * This test doesn't prescribe whether the SegmentMatrixAdapter should + * accomplish this by keeping a list of "destroyed segments" or by simply + * removing rows from the SparseMatrix. + */ +TEST(SegmentMatrixAdapterTest, noRowLeaks) { + SegmentMatrixAdapter> ssm(2048, 1000); + + // Create 5 segments + UInt32 cells1[] = {42, 43, 44, 45, 46}; + vector created(5); + ssm.createSegments(cells1, cells1 + 5, created.data()); + ASSERT_EQ(5, ssm.matrix.nRows()); + + // Destroy 3 segments, covering both destroy APIs + ssm.destroySegment(created[1]); + + vector toDestroy = {created[2], created[3]}; + ssm.destroySegments(toDestroy.begin(), toDestroy.end()); + + // Create 4 segments, covering both create APIs, and making sure + // createSegments has to reuse destroyed segments *and* add rows in one + // call. 
+ ssm.createSegment(50); + + UInt32 cells2[] = {51, 52, 53}; + UInt32 segmentsOut[3]; + ssm.createSegments(cells2, cells2 + 3, segmentsOut); + + EXPECT_EQ(6, ssm.matrix.nRows()); +} + +/** + * Prepare: + * - Cell that has multiple segments + * - Cell that had multiple segments, then lost some of them + * - Cell that has a single segment + * - Cell that has had segments, then lost them + * - Cell that has never had a segment + * + * Use both create APIs and both destroy APIs. + * + * Verify that getSegmentCounts gets the up-to-date count for each. + */ +TEST(SegmentMatrixAdapterTest, getSegmentCounts) { + SegmentMatrixAdapter> ssm(2048, 1000); + + ssm.createSegment(42); + + vector cells = {42, 43, 44, 45}; + vector created(cells.size()); + ssm.createSegments(cells.begin(), cells.end(), created.data()); + + ssm.createSegment(43); + + vector destroy = {created[1]}; + ssm.destroySegments(destroy.begin(), destroy.end()); + + ssm.destroySegment(created[3]); + + vector queriedCells = {42, 43, 44, 45, 46}; + vector counts(queriedCells.size()); + ssm.getSegmentCounts(queriedCells.begin(), queriedCells.end(), + counts.begin()); + + vector expected = {2, 1, 1, 0, 0}; + EXPECT_EQ(expected, counts); +} + +TEST(SegmentMatrixAdapterTest, sortSegmentsByCell) { + SegmentMatrixAdapter> ssm(2048, 1000); + + UInt32 segment1 = ssm.createSegment(42); + UInt32 segment2 = ssm.createSegment(41); + UInt32 segment3 = ssm.createSegment(49); + UInt32 segment4 = ssm.createSegment(45); + UInt32 segment5 = ssm.createSegment(0); + UInt32 segment6 = ssm.createSegment(2047); + const vector sorted = {segment5, segment2, segment1, + segment4, segment3, segment6}; + + vector mySegments = {segment1, segment2, segment3, + segment4, segment5, segment6}; + ssm.sortSegmentsByCell(mySegments.begin(), mySegments.end()); + + EXPECT_EQ(sorted, mySegments); +} + +TEST(SegmentMatrixAdapterTest, filterSegmentsByCell) { + SegmentMatrixAdapter> ssm(2048, 1000); + + // Don't create them in order -- we don't want the segment numbers + // to be ordered in a meaningful way. 
+ + // Shuffled + // {42, 42, 42, 43, 46, 47, 48} + const vector cellsWithSegments = {47, 42, 46, 43, 42, 48, 42}; + + vector createdSegments(cellsWithSegments.size()); + ssm.createSegments(cellsWithSegments.begin(), cellsWithSegments.end(), + createdSegments.begin()); + ssm.sortSegmentsByCell(createdSegments.begin(), createdSegments.end()); + + // Include everything + const vector everything = {42, 42, 42, 43, 46, 47, 48}; + EXPECT_EQ(createdSegments, ssm.filterSegmentsByCell( + createdSegments.begin(), createdSegments.end(), + everything.begin(), everything.end())); + + // Subset, one cell with multiple segments + const vector subset1 = {42, 43, 48}; + const vector expected1 = {createdSegments[0], createdSegments[1], + createdSegments[2], createdSegments[3], + createdSegments[6]}; + EXPECT_EQ(expected1, ssm.filterSegmentsByCell( + createdSegments.begin(), createdSegments.end(), + subset1.begin(), subset1.end())); + + // Subset, some cells without segments + const vector subset2 = {43, 44, 45, 48}; + const vector expected2 = {createdSegments[3], createdSegments[6]}; + EXPECT_EQ(expected2, ssm.filterSegmentsByCell( + createdSegments.begin(), createdSegments.end(), + subset2.begin(), subset2.end())); +} + +TEST(SegmentMatrixAdapterTest, mapSegmentsToCells) { + SegmentMatrixAdapter> ssm(2048, 1000); + + const vector cellsWithSegments = {42, 42, 42, 43, 44, 45}; + + vector createdSegments(cellsWithSegments.size()); + ssm.createSegments(cellsWithSegments.begin(), cellsWithSegments.end(), + createdSegments.begin()); + + // Map everything + vector cells1(createdSegments.size()); + ssm.mapSegmentsToCells(createdSegments.begin(), createdSegments.end(), + cells1.begin()); + EXPECT_EQ(cellsWithSegments, cells1); + + // Map subset, including duplicates + vector segmentSubset = {createdSegments[3], createdSegments[3], + createdSegments[0]}; + vector expectedCells2 = {43, 43, 42}; + vector cells2(segmentSubset.size()); + ssm.mapSegmentsToCells(segmentSubset.begin(), segmentSubset.end(), + cells2.begin()); + EXPECT_EQ(expectedCells2, cells2); } +} // namespace diff --git a/src/test/unit/math/SparseBinaryMatrixTest.cpp b/src/test/unit/math/SparseBinaryMatrixTest.cpp index 0c35d5acfa..53268e6e66 100644 --- a/src/test/unit/math/SparseBinaryMatrixTest.cpp +++ b/src/test/unit/math/SparseBinaryMatrixTest.cpp @@ -26,18 +26,16 @@ #include #include -#include #include +#include #include #include #include - using namespace nupic; -TEST(SparseBinaryMatrixReadWrite, EmptyMatrix) -{ +TEST(SparseBinaryMatrixReadWrite, EmptyMatrix) { SparseBinaryMatrix m1, m2; m1.resize(3, 4); @@ -46,7 +44,8 @@ TEST(SparseBinaryMatrixReadWrite, EmptyMatrix) // write capnp::MallocMessageBuilder message1; - SparseBinaryMatrixProto::Builder protoBuilder = message1.initRoot(); + SparseBinaryMatrixProto::Builder protoBuilder = + message1.initRoot(); m1.write(protoBuilder); kj::std::StdOutputStream out(ss); capnp::writeMessage(out, message1); @@ -54,7 +53,8 @@ TEST(SparseBinaryMatrixReadWrite, EmptyMatrix) // read kj::std::StdInputStream in(ss); capnp::InputStreamMessageReader message2(in); - SparseBinaryMatrixProto::Reader protoReader = message2.getRoot(); + SparseBinaryMatrixProto::Reader protoReader = + message2.getRoot(); m2.read(protoReader); // compare @@ -62,8 +62,7 @@ TEST(SparseBinaryMatrixReadWrite, EmptyMatrix) ASSERT_EQ(m1.nCols(), m2.nCols()) << "Number of columns don't match"; } -TEST(SparseBinaryMatrixReadWrite, Basic) -{ +TEST(SparseBinaryMatrixReadWrite, Basic) { SparseBinaryMatrix m1, m2; m1.resize(3, 4); @@ -73,7 +72,8 @@ 
TEST(SparseBinaryMatrixReadWrite, Basic) // write capnp::MallocMessageBuilder message1; - SparseBinaryMatrixProto::Builder protoBuilder = message1.initRoot(); + SparseBinaryMatrixProto::Builder protoBuilder = + message1.initRoot(); m1.write(protoBuilder); kj::std::StdOutputStream out(ss); capnp::writeMessage(out, message1); @@ -81,7 +81,8 @@ TEST(SparseBinaryMatrixReadWrite, Basic) // read kj::std::StdInputStream in(ss); capnp::InputStreamMessageReader message2(in); - SparseBinaryMatrixProto::Reader protoReader = message2.getRoot(); + SparseBinaryMatrixProto::Reader protoReader = + message2.getRoot(); m2.read(protoReader); // compare @@ -96,4 +97,3 @@ TEST(SparseBinaryMatrixReadWrite, Basic) ASSERT_EQ(m1r1[0], 1) << "Invalid col index in original matrix"; ASSERT_EQ(m1r1[0], m2r1[0]) << "Invalid col index in copied matrix"; } - diff --git a/src/test/unit/math/SparseMatrix01UnitTest.cpp b/src/test/unit/math/SparseMatrix01UnitTest.cpp index 1440d6477c..55b0e65a8a 100644 --- a/src/test/unit/math/SparseMatrix01UnitTest.cpp +++ b/src/test/unit/math/SparseMatrix01UnitTest.cpp @@ -19,11 +19,11 @@ * http://numenta.org/licenses/ * --------------------------------------------------------------------- */ - + /** @file * Implementation of unit testing for class SparseMatrix01 - */ - + */ + //#include // #include @@ -35,10 +35,9 @@ // // namespace nupic { // -#define TEST_LOOP(M) \ - for (nrows = 0, ncols = M, zr = 15; \ - nrows < M; \ - nrows += M/10, ncols -= M/10, zr = ncols/10) \ +#define TEST_LOOP(M) \ + for (nrows = 0, ncols = M, zr = 15; nrows < M; \ + nrows += M / 10, ncols -= M / 10, zr = ncols / 10) #define M 256 // @@ -61,7 +60,7 @@ // // Tests: // // all constructors, destructors, nNonZeros, nCols, nRows // // toDense, compact, isZero, nNonZerosRow -// UInt ncols, nrows, zr; +// UInt ncols, nrows, zr; // // { // Rectangular shape, no zeros // nrows = 3; ncols = 4; @@ -72,7 +71,7 @@ // sm.compact(); // Compare(dense, sm, "ctor 1 - compact"); // Test("isZero 1 - compact", sm.isZero(), false); -// } +// } // // { // Rectangular shape, zeros // nrows = 3; ncols = 4; @@ -83,9 +82,9 @@ // sm.compact(); // Compare(dense, sm, "ctor 2 - compact"); // Test("isZero 2 - compact", sm.isZero(), false); -// } -// -// { // Rectangular the other way, no zeros +// } +// +// { // Rectangular the other way, no zeros // nrows = 4; ncols = 3; // Dense01 dense(nrows, ncols, 0); // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); @@ -127,7 +126,7 @@ // sm.compact(); // Compare(dense, sm, "ctor 6 - compact"); // Test("isZero 6 - compact", sm.isZero(), false); -// } +// } // // { // Small values, zeros and empty rows // nrows = 7; ncols = 5; @@ -138,7 +137,7 @@ // sm.compact(); // Compare(dense, sm, "ctor 7 - compact"); // Test("isZero 7 - compact", sm.isZero(), false); -// } +// } // // { // Small values, zeros and empty rows, other constructor // nrows = 10; ncols = 10; @@ -167,26 +166,26 @@ // } // // { // Small values, zeros and empty rows, other constructor -// nrows = 10; ncols = 10; +// nrows = 10; ncols = 10; // Dense01 dense(nrows, ncols, 2, true, true); // SparseMatrix01 sm(ncols, 2); // for (UInt i = 0; i < nrows; ++i) // sm.addRow(dense.begin(i)); // Compare(dense, sm, "ctor 10"); -// Test("isZero 10", sm.isZero(), false); +// Test("isZero 10", sm.isZero(), false); // sm.compact(); // Compare(dense, sm, "ctor 10 - compact"); // Test("isZero 10 - compact", sm.isZero(), false); // } -// +// // { // Empty // Dense01 dense(10, 10, 10); // SparseMatrix01 sm(10, 10, dense.begin(), 0); // Compare(dense, sm, 
"ctor from empty dense - non compact"); -// Test("isZero 11", sm.isZero(), true); +// Test("isZero 11", sm.isZero(), true); // sm.compact(); // Compare(dense, sm, "ctor from empty dense - compact"); -// Test("isZero 11 - compact", sm.isZero(), true); +// Test("isZero 11 - compact", sm.isZero(), true); // } // // { // Empty, other constructor @@ -195,10 +194,10 @@ // for (UInt i = 0; i < nrows; ++i) // sm.addRow(dense.begin(i)); // Compare(dense, sm, "ctor from empty dense - non compact"); -// Test("isZero 12", sm.isZero(), true); +// Test("isZero 12", sm.isZero(), true); // sm.compact(); // Compare(dense, sm, "ctor from empty dense - compact"); -// Test("isZero 12 - compact", sm.isZero(), true); +// Test("isZero 12 - compact", sm.isZero(), true); // } // // { // Full @@ -213,13 +212,13 @@ // // { // Various rectangular sizes // TEST_LOOP(M) { -// +// // Dense01 dense(nrows, ncols, zr); // SparseMatrix01 sm(ncols, nrows); -// +// // for (UInt i = 0; i < nrows; ++i) // sm.addRow(dense.begin(i)); -// +// // { // stringstream str; // str << "ctor " << nrows << "X" << ncols << "/" << zr @@ -240,46 +239,58 @@ // // try { // SparseMatrix01 sme1(0, 0); -// Test("SparseMatrix01::SparseMatrix01(Int, Int) exception 1", true, false); +// Test("SparseMatrix01::SparseMatrix01(Int, Int) exception 1", true, +// false); // } catch (std::exception&) { -// Test("SparseMatrix01::SparseMatrix01(Int, Int) exception 1", true, true); +// Test("SparseMatrix01::SparseMatrix01(Int, Int) exception 1", true, +// true); // } // // try { // SparseMatrix01 sme1(-1, 0); -// Test("SparseMatrix01::SparseMatrix01(Int, Int) exception 2", true, false); +// Test("SparseMatrix01::SparseMatrix01(Int, Int) exception 2", true, +// false); // } catch (std::exception&) { -// Test("SparseMatrix01::SparseMatrix01(Int, Int) exception 2", true, true); +// Test("SparseMatrix01::SparseMatrix01(Int, Int) exception 2", true, +// true); // } // // try { // SparseMatrix01 sme1(1, -1); -// Test("SparseMatrix01::SparseMatrix01(Int, Int) exception 3", true, false); +// Test("SparseMatrix01::SparseMatrix01(Int, Int) exception 3", true, +// false); // } catch (std::exception&) { -// Test("SparseMatrix01::SparseMatrix01(Int, Int) exception 3", true, true); +// Test("SparseMatrix01::SparseMatrix01(Int, Int) exception 3", true, +// true); // } // // std::vector mat(16, 0); -// +// // try { // SparseMatrix01 sme1(-1, 1, mat.begin(), 0); -// Test("SparseMatrix01::SparseMatrix01(Int, Int, Iter) exception 1", true, false); +// Test("SparseMatrix01::SparseMatrix01(Int, Int, Iter) exception 1", true, +// false); // } catch (std::exception&) { -// Test("SparseMatrix01::SparseMatrix01(Int, Iter) exception 1", true, true); -// } +// Test("SparseMatrix01::SparseMatrix01(Int, Iter) exception 1", true, +// true); +// } // // try { // SparseMatrix01 sme1(1, -1, mat.begin(), 0); -// Test("SparseMatrix01::SparseMatrix01(Int, Int, Iter) exception 2", true, false); +// Test("SparseMatrix01::SparseMatrix01(Int, Int, Iter) exception 2", true, +// false); // } catch (std::exception&) { -// Test("SparseMatrix01::SparseMatrix01(Int, Iter) exception 2", true, true); +// Test("SparseMatrix01::SparseMatrix01(Int, Iter) exception 2", true, +// true); // } -// +// // try { // SparseMatrix01 sme1(1, 0, mat.begin(), 0); -// Test("SparseMatrix01::SparseMatrix01(Int, Int, Iter) exception 3", true, false); +// Test("SparseMatrix01::SparseMatrix01(Int, Int, Iter) exception 3", true, +// false); // } catch (std::exception&) { -// Test("SparseMatrix01::SparseMatrix01(Int, Iter) 
exception 3", true, true); +// Test("SparseMatrix01::SparseMatrix01(Int, Iter) exception 3", true, +// true); // } // } // @@ -287,28 +298,28 @@ // void SparseMatrix01UnitTest::unit_test_csr() // { // UInt nrows, ncols, zr; -// +// // { // TEST_LOOP(M) { -// +// // Dense01 dense3(nrows, ncols, zr); // SparseMatrix01 sm3(nrows, ncols, dense3.begin(), 0); -// +// // stringstream buf; // sm3.toCSR(buf); // sm3.fromCSR(buf); -// -// { +// +// { // stringstream str; // str << "toCSR/fromCSR A " << nrows << "X" << ncols << "/" << zr; // Compare(dense3, sm3, str.str().c_str()); -// } +// } // // SparseMatrix01 sm4(ncols, nrows); // stringstream buf1; // sm3.toCSR(buf1); // sm4.fromCSR(buf1); -// +// // { // stringstream str; // str << "toCSR/fromCSR B " << nrows << "X" << ncols << "/" << zr; @@ -320,7 +331,7 @@ // sm3.toCSR(buf2); // sm4.fromCSR(buf2); // -// { +// { // stringstream str; // str << "toCSR/fromCSR C " << nrows << "X" << ncols << "/" << zr; // Compare(dense3, sm4, str.str().c_str()); @@ -341,7 +352,7 @@ // { // Is resizing happening correctly? // Dense01 dense(3, 4, 2); // SparseMatrix01 sm(3, 4, dense.begin(), 0); -// +// // { // Smaller size // stringstream buf1, buf2; // buf1 << "csr01 0 3 3 9 0 3 0 1 2 3 0 1 2 3 0 1 2"; @@ -390,7 +401,7 @@ // // Exceptions // SparseMatrix01 sme1(1, 1); // -// { +// { // stringstream s1; // s1 << "ijv"; // try { @@ -422,7 +433,7 @@ // Test("SparseMatrix01::fromCSR() exception 3", true, true); // } // } -// +// // { // stringstream s1; // s1 << "csr01 0 1 0"; @@ -472,7 +483,7 @@ // s1 << "csr01 0 2 3 1 0 1 -1"; // try { // sme1.fromCSR(s1); -// Test("SparseMatrix01::fromCSR() exception 8", true, false); +// Test("SparseMatrix01::fromCSR() exception 8", true, false); // } catch (runtime_error&) { // Test("SparseMatrix01::fromCSR() exception 8", true, true); // } @@ -497,7 +508,7 @@ // ncols = 5; // nrows = 7; // zr = 2; -// +// // Dense01 dense(nrows, ncols, zr); // SparseMatrix01 sm4(ncols, nrows); // sm4.fromDense(nrows, ncols, dense.begin()); @@ -517,14 +528,14 @@ // Compare(dense2, sm5, "fromDense 4"); // // std::vector mat((nrows+1)*(ncols+1), 0); -// +// // sm5.toDense(mat.begin()); // sm5.fromDense(nrows+1, ncols+1, mat.begin()); // Compare(dense2, sm5, "toDense 1"); -// +// // { // TEST_LOOP(M) { -// +// // Dense01 dense3(nrows, ncols, zr); // SparseMatrix01 sm3(nrows, ncols, dense3.begin(), 0); // std::vector mat3(nrows*ncols, 0); @@ -555,37 +566,37 @@ // Dense01 dense(nrows, ncols, zr); // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); // std::vector mat3(nrows*ncols, 0); -// +// // sm.toDense(mat3.begin()); // sm.fromDense(nrows, ncols, mat3.begin()); -// +// // Compare(dense, sm, "toDense/fromDense from dense"); -// } +// } // // { // What happens if dense matrix is empty? // nrows = ncols = 10; zr = 10; // Dense01 dense(nrows, ncols, zr); // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); // std::vector mat3(nrows*ncols, 0); -// +// // sm.toDense(mat3.begin()); // sm.fromDense(nrows, ncols, mat3.begin()); -// +// // Compare(dense, sm, "toDense/fromDense from dense"); // } // // { // What happens if there are empty rows? 
// nrows = ncols = 10; zr = 2; // Dense01 dense(nrows, ncols, zr); -// for (UInt i = 0; i < ncols; ++i) +// for (UInt i = 0; i < ncols; ++i) // dense.at(2,i) = dense.at(4,i) = dense.at(9,i) = 0; -// +// // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); // std::vector mat3(nrows*ncols, 0); -// +// // sm.toDense(mat3.begin()); // sm.fromDense(nrows, ncols, mat3.begin()); -// +// // Compare(dense, sm, "toDense/fromDense from dense"); // } // @@ -605,10 +616,10 @@ // sm.fromDense(10, 10, dense4.begin()); // Compare(dense4, sm, "fromDense/redim/3"); // } -// +// // // Exceptions // SparseMatrix01 sme1(1, 1); -// +// // try { // sme1.fromDense(-1, 0, dense.begin()); // Test("SparseMatrix01::fromDense() exception 1", true, false); @@ -641,23 +652,23 @@ // // Dense01 dense(nrows, ncols, zr); // SparseMatrix01 sm4(nrows, ncols, dense.begin(), 0); -// +// // sm4.decompact(); // Compare(dense, sm4, "decompact 1"); -// +// // sm4.compact(); // Compare(dense, sm4, "compact 1"); // // sm4.decompact(); // Compare(dense, sm4, "decompact 2"); -// +// // sm4.compact(); // Compare(dense, sm4, "compact 2"); // // sm4.decompact(); // sm4.decompact(); // Compare(dense, sm4, "decompact twice"); -// +// // sm4.compact(); // sm4.compact(); // Compare(dense, sm4, "compact twice"); @@ -671,7 +682,7 @@ // SparseMatrix01 sm3(nrows, ncols, dense3.begin(), 0); // // sm3.decompact(); -// +// // { // stringstream str; // str << "compact/decompact A " << nrows << "X" << ncols << "/" << zr @@ -695,7 +706,7 @@ // Dense01 dense(nrows, ncols, zr); // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); // std::vector mat3(nrows*ncols, 0); -// +// // sm.decompact(); // Compare(dense, sm, "decompact on dense"); // @@ -708,38 +719,38 @@ // void SparseMatrix01UnitTest::unit_test_getRowSparse() // { // UInt ncols, nrows, zr, i, k; -// +// // { // TEST_LOOP(M) { -// +// // Dense01 dense(nrows, ncols, zr); // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); // // for (i = 0; i < nrows; ++i) { -// +// // stringstream str; -// str << "getRowSparse A " << nrows << "X" << ncols +// str << "getRowSparse A " << nrows << "X" << ncols // << "/" << zr << " " << i; -// +// // vector ind; ; // sm.getRowSparse(i, back_inserter(ind)); -// +// // std::vector d(ncols, 0); // for (k = 0; k < ind.size(); ++k) // d[ind[k]] = 1.0; -// +// // CompareVectors(ncols, d.begin(), dense.begin(i), str.str().c_str()); // } -// } -// } +// } +// } // } // // //-------------------------------------------------------------------------------- // void SparseMatrix01UnitTest::unit_test_addRow() // { // // addRow, compact -// UInt nrows, ncols, zr; -// +// UInt nrows, ncols, zr; +// // { // TEST_LOOP(M) { // @@ -750,16 +761,16 @@ // sm.addRow(dense.begin(i)); // sm.compact(); // } -// +// // sm.decompact(); -// +// // { // stringstream str; // str << "addRow A " << nrows << "X" << ncols << "/" << zr // << " - non compact"; // Compare(dense, sm, str.str().c_str()); // } -// +// // sm.compact(); // // { @@ -770,8 +781,8 @@ // } // } // } -// -// // These tests compiled conditionally, because they are +// +// // These tests compiled conditionally, because they are // // based on asserts rather than checks // //#ifdef NTA_ASSERTIONS_ON @@ -814,7 +825,8 @@ // // //-------------------------------------------------------------------------------- // /** -// * The vector of all zeros is indistinguishable from the vector where the maxima +// * The vector of all zeros is indistinguishable from the vector where the +// maxima // * are the first elements of each section! 
// */ // void SparseMatrix01UnitTest::unit_test_addUniqueFilteredRow() @@ -828,42 +840,44 @@ // UInt winner, winner_ref; // // typedef map, pair > Check; -// Check control; Check::iterator it; -// +// Check control; Check::iterator it; +// // for (UInt i = 0; i < nreps; ++i) { -// -// for (UInt j = 0; j < ncols; ++j) +// +// for (UInt j = 0; j < ncols; ++j) // x[j] = rng_->getReal64(); // // winnerTakesAll2(boundaries, x.begin(), v.begin()); // -// it = control.find(v); -// if (it == control.end()) { +// it = control.find(v); +// if (it == control.end()) { // winner_ref = (UInt)control.size(); -// control[v] = make_pair(winner_ref, 1); -// } else { +// control[v] = make_pair(winner_ref, 1); +// } else { // winner_ref = control[v].first; -// control[v].second += 1; -// } +// control[v].second += 1; +// } // // winner = sm01.addUniqueFilteredRow(boundaries.begin(), x.begin()); // Test("SparseMatrix01 addUniqueFilteredRow 1", winner, winner_ref); // } -// -// Test("SparseMatrix01 addUniqueFilteredRow 2", sm01.nRows(), control.size()); +// +// Test("SparseMatrix01 addUniqueFilteredRow 2", sm01.nRows(), +// control.size()); // // SparseMatrix01::RowCounts rc = sm01.getRowCounts(); -// +// // for (UInt i = 0; i < rc.size(); ++i) { // sm01.getRow(rc[i].first, x.begin()); -// Test("SparseMatrix01 addUniqueFilteredRow 3", rc[i].second, control[x].second); +// Test("SparseMatrix01 addUniqueFilteredRow 3", rc[i].second, +// control[x].second); // } // // /* // x[0]=.5;x[1]=.7;x[2]=.3;x[3]=.1;x[4]=.7;x[5]=.1;x[6]=.1;x[7]=0; // sm01.addUniqueFilteredRow(boundaries.begin(), x.begin()); -// cout << sm01 << endl; -// +// cout << sm01 << endl; +// // x[0]=.9;x[1]=.7;x[2]=.3;x[3]=.1;x[4]=.9;x[5]=.1;x[6]=.1;x[7]=0; // sm01.addUniqueFilteredRow(boundaries.begin(), x.begin()); // cout << sm01 << endl; @@ -877,7 +891,7 @@ // cout << sm01 << endl; // */ // } -// +// // //-------------------------------------------------------------------------------- // void SparseMatrix01UnitTest::unit_test_addMinHamming() // { @@ -892,20 +906,20 @@ // UInt winner, winner_ref, hamming, min_hamming, max_distance; // // typedef map, pair > Check; -// Check control; Check::iterator it, arg_it; -// +// Check control; Check::iterator it, arg_it; +// // for (UInt i = 0; i < nreps; ++i) { -// +// // max_distance = rng_->getUInt32(nnzr); // -// for (UInt j = 0; j < ncols; ++j) +// for (UInt j = 0; j < ncols; ++j) // x[j] = rng_->getUInt32(100); // // winnerTakesAll2(boundaries, x.begin(), v.begin()); // // min_hamming = std::numeric_limits::max(); // arg_it = control.begin(); -// +// // for (it = control.begin(); it != control.end(); ++it) { // hamming = 0; // for (UInt k = 0; k < ncols; ++k) { @@ -917,8 +931,8 @@ // min_hamming = hamming; // arg_it = it; // } -// } -// +// } +// // if (min_hamming <= max_distance) { // ++ (arg_it->second.second); // winner_ref = arg_it->second.first; @@ -927,21 +941,21 @@ // control[v] = std::make_pair(winner_ref, 1); // } // -// winner = sm01.addMinHamming(boundaries.begin(), x.begin(), max_distance); -// Test("SparseMatrix01 addMinHamming 1", winner, winner_ref); -// if (winner != winner_ref) { +// winner = sm01.addMinHamming(boundaries.begin(), x.begin(), +// max_distance); Test("SparseMatrix01 addMinHamming 1", winner, +// winner_ref); if (winner != winner_ref) { // cout << winner << " - " << winner_ref << endl << endl; // for (UInt k = 0; k < ncols; ++k) // cout << v[k] << " "; // cout << endl << endl; -// cout << sm01 << endl; +// cout << sm01 << endl; // } // } -// +// // Test("SparseMatrix01 
addMinHamming 2", sm01.nRows(), control.size()); // // SparseMatrix01::RowCounts rc = sm01.getRowCounts(); -// +// // for (UInt i = 0; i < rc.size(); ++i) { // sm01.getRow(rc[i].first, x.begin()); // Test("SparseMatrix01 addMinHamming 3", rc[i].second, control[x].second); @@ -950,11 +964,11 @@ // /* // x[0]=.5;x[1]=.7;x[2]=.3;x[3]=.1;x[4]=.7;x[5]=.1;x[6]=.1;x[7]=0; // sm01.addMinHamming(boundaries.begin(), x.begin(), 2); -// cout << sm01 << endl; +// cout << sm01 << endl; // // sm01.addMinHamming(boundaries.begin(), x.begin(), 2); -// cout << sm01 << endl; -// +// cout << sm01 << endl; +// // x[0]=.9;x[1]=.7;x[2]=.3;x[3]=.1;x[4]=.9;x[5]=.1;x[6]=.1;x[7]=0; // sm01.addMinHamming(boundaries.begin(), x.begin(), 2); // cout << sm01 << endl; @@ -971,10 +985,10 @@ // // //-------------------------------------------------------------------------------- // void SparseMatrix01UnitTest::unit_test_deleteRows() -// { +// { // { // Empty matrix // UInt nrows = 3, ncols = 3; -// +// // { // Empty matrix, empty del // SparseMatrix01 sm(ncols, nrows); // vector del; @@ -1002,7 +1016,7 @@ // // TEST_LOOP(M) { // -// Dense01 dense(nrows, ncols, zr); +// Dense01 dense(nrows, ncols, zr); // // { // Empty del // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); @@ -1018,18 +1032,18 @@ // for (UInt j = 0; j < ncols; ++j) // dense2.at(i,j) = 0; // } -// } +// } // SparseMatrix01 sm(nrows, ncols, dense2.begin(), 0); // vector del; // if (nrows > 2) { -// for (UInt i = 2; i < nrows-2; i += 2) +// for (UInt i = 2; i < nrows-2; i += 2) // del.push_back(i); // sm.deleteRows(del.begin(), del.end()); // dense2.deleteRows(del.begin(), del.end()); // Compare(dense2, sm, "SparseMatrix01::deleteRows() 5A"); // } // } -// +// // { // Rows of all zeros 2 // Dense01 dense2(nrows, ncols, zr); // ITER_1(nrows) { @@ -1041,7 +1055,7 @@ // SparseMatrix01 sm(nrows, ncols, dense2.begin(), 0); // vector del; // if (nrows > 2) { -// for (UInt i = 1; i < nrows-2; i += 2) +// for (UInt i = 1; i < nrows-2; i += 2) // del.push_back(i); // sm.deleteRows(del.begin(), del.end()); // dense2.deleteRows(del.begin(), del.end()); @@ -1054,7 +1068,7 @@ // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); // Dense01 dense2(nrows, ncols, zr); // vector del; -// for (UInt i = 2; i < nrows-2; ++i) +// for (UInt i = 2; i < nrows-2; ++i) // del.push_back(i); // sm.deleteRows(del.begin(), del.end()); // dense2.deleteRows(del.begin(), del.end()); @@ -1067,7 +1081,7 @@ // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); // Dense01 dense2(nrows, ncols, zr); // UInt* del = new UInt[nrows-1]; -// for (UInt i = 0; i < nrows-1; ++i) +// for (UInt i = 0; i < nrows-1; ++i) // del[i] = i + 1; // sm.deleteRows(del, del + nrows-2); // dense2.deleteRows(del, del + nrows-2); @@ -1089,7 +1103,7 @@ // // { // All rows // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); -// vector del; +// vector del; // for (UInt i = 0; i < nrows; ++i) // del.push_back(i); // sm.deleteRows(del.begin(), del.end()); @@ -1099,7 +1113,7 @@ // { // More than all rows => exceptions in assert mode // /* // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); -// vector del; +// vector del; // for (UInt i = 0; i < 2*nrows; ++i) // del.push_back(i); // sm.deleteRows(del.begin(), del.end()); @@ -1112,12 +1126,13 @@ // for (UInt i = 0; i < nrows; ++i) { // vector del(1); del[0] = 0; // sm.deleteRows(del.begin(), del.end()); -// Test("SparseMatrix01::deleteRows() 10", sm.nRows(), UInt(nrows-i-1)); +// Test("SparseMatrix01::deleteRows() 10", sm.nRows(), +// UInt(nrows-i-1)); // } // } // } -// } 
-// +// } +// // { // Test with unique rows // UInt nrows = 10; // std::vector boundaries(2); @@ -1128,21 +1143,21 @@ // // for (UInt i = 0; i < nrows; ++i) { // -// for (UInt j = 0; j < ncols; ++j) +// for (UInt j = 0; j < ncols; ++j) // x[j] = rng_->getReal64(); // // sm01.addUniqueFilteredRow(boundaries.begin(), x.begin()); // } // // SparseMatrix01::RowCounts rc = sm01.getRowCounts(); -// vector del; del.push_back(1); del.push_back(3); -// UInt nrows_new = sm01.nRows() - del.size(); +// vector del; del.push_back(1); del.push_back(3); +// UInt nrows_new = sm01.nRows() - del.size(); // sm01.deleteRows(del.begin(), del.end()); // Test("SparseMatrix01::deleteRows 11", sm01.nRows(), nrows_new); // rc = sm01.getRowCounts(); // UInt s = (UInt)rc.size(); // Test("SparseMatrix01::deleteRows 12", s, nrows_new); -// +// // // Remove last row // del.clear(); // del.push_back(sm01.nRows()-1); @@ -1151,7 +1166,7 @@ // Test("SparseMatrix01::deleteRows 13", sm01.nRows(), nrows_new); // rc = sm01.getRowCounts(); // s = (UInt)rc.size(); -// Test("SparseMatrix01::deleteRows 14", s, nrows_new); +// Test("SparseMatrix01::deleteRows 14", s, nrows_new); // } // // { // Delete with threshold @@ -1164,7 +1179,7 @@ // // for (UInt i = 0; i < nrows; ++i) { // -// for (UInt j = 0; j < ncols; ++j) +// for (UInt j = 0; j < ncols; ++j) // x[j] = rng_->getReal64(); // // sm01.addUniqueFilteredRow(boundaries.begin(), x.begin()); @@ -1178,8 +1193,9 @@ // ++keep; // vector > del_rows; // sm01.deleteRows(threshold, back_inserter(del_rows)); -// Test("SparseMatrix01::deleteRows(threshold) 15", del_rows.size(), nrows - keep); -// Test("SparseMatrix01::deleteRows(threshold) 16", sm01.nRows(), keep); +// Test("SparseMatrix01::deleteRows(threshold) 15", del_rows.size(), nrows +// - keep); Test("SparseMatrix01::deleteRows(threshold) 16", sm01.nRows(), +// keep); // } // // { // Delete with threshold - make sure counts are adjusted @@ -1200,7 +1216,7 @@ // // vector > del_rows; // sm01.deleteRows(2, back_inserter(del_rows)); -// +// // Test("SparseMatrix01::deleteRows(threshold) 17", sm01.nRows(), UInt(1)); // SparseMatrix01::RowCounts rc = sm01.getRowCounts(); // UInt idx = rc[0].first, count = rc[0].second; @@ -1225,13 +1241,13 @@ // Test("SparseMatrix01::deleteRows(threshold) 22", count, UInt(3)); // } // } -// +// // //-------------------------------------------------------------------------------- // void SparseMatrix01UnitTest::unit_test_deleteColumns() -// { +// { // { // Empty matrix // UInt nrows = 3, ncols = 3; -// +// // { // Empty matrix, empty del // SparseMatrix01 sm(ncols, nrows); // vector del; @@ -1244,23 +1260,23 @@ // vector del(1); del[0] = 0; // sm.deleteColumns(del.begin(), del.end()); // Test("SparseMatrix01::deleteColumns() 2", sm.nCols(), UInt(2)); -// } -// +// } +// // { // Empty matrix, many dels // SparseMatrix01 sm(ncols, nrows); // vector del(2); del[0] = 0; del[1] = 2; // sm.deleteColumns(del.begin(), del.end()); // Test("SparseMatrix01::deleteColumns() 3", sm.nCols(), UInt(1)); // } -// } // End empty matrix -// +// } // End empty matrix +// // { -// UInt nrows, ncols, zr; -// +// UInt nrows, ncols, zr; +// // TEST_LOOP(M) { // -// Dense01 dense(nrows, ncols, zr); -// +// Dense01 dense(nrows, ncols, zr); +// // { // Empty del // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); // vector del; @@ -1273,14 +1289,14 @@ // Dense01 dense2(nrows, ncols, zr); // vector del; // if (ncols > 2) { -// for (UInt i = 2; i < ncols-2; ++i) +// for (UInt i = 2; i < ncols-2; ++i) // del.push_back(i); // 
sm.deleteColumns(del.begin(), del.end()); // dense2.deleteColumns(del.begin(), del.end()); // Compare(dense2, sm, "SparseMatrix01::deleteColumns() 6"); // } // } -// +// // { // Many dels discontiguous // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); // Dense01 dense2(nrows, ncols, zr); @@ -1294,7 +1310,7 @@ // // { // All rows // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); -// vector del; +// vector del; // for (UInt i = 0; i < ncols; ++i) // del.push_back(i); // sm.deleteColumns(del.begin(), del.end()); @@ -1304,7 +1320,7 @@ // { // More than all rows => exception in assert mode // /* // SparseMatrix01 sm(nrows, ncols, dense.begin(), 0); -// vector del; +// vector del; // for (UInt i = 0; i < 2*ncols; ++i) // del.push_back(i); // sm.deleteColumns(del.begin(), del.end()); @@ -1317,15 +1333,16 @@ // for (UInt i = 0; i < ncols; ++i) { // vector del(1); del[0] = 0; // sm.deleteColumns(del.begin(), del.end()); -// Test("SparseMatrix01::deleteColumns() 10", sm.nCols(), UInt(ncols-i-1)); +// Test("SparseMatrix01::deleteColumns() 10", sm.nCols(), +// UInt(ncols-i-1)); // } // } // } -// } -// +// } +// // // Test with unique rows // { -// UInt nrows = 10; +// UInt nrows = 10; // std::vector boundaries(2); // boundaries[0] = 4; boundaries[1] = 8; // UInt nnzr = (UInt)boundaries.size(), ncols = boundaries[nnzr-1]; @@ -1334,19 +1351,20 @@ // // for (UInt i = 0; i < nrows; ++i) { // -// for (UInt j = 0; j < ncols; ++j) +// for (UInt j = 0; j < ncols; ++j) // x[j] = rng_->getReal64(); // // sm01.addUniqueFilteredRow(boundaries.begin(), x.begin()); // } -// +// // nrows = sm01.nRows(); // SparseMatrix01::RowCounts rc = sm01.getRowCounts(); // vector del; del.push_back(1); del.push_back(3); del.push_back(5); -// UInt ncols_new = sm01.nCols() - del.size(); +// UInt ncols_new = sm01.nCols() - del.size(); // sm01.deleteColumns(del.begin(), del.end()); // Test("SparseMatrix01::deleteColumns 11", sm01.nCols(), ncols_new); -// Test("SparseMatrix01::deleteColumns 12", sm01.getRowCounts().size(), nrows); +// Test("SparseMatrix01::deleteColumns 12", sm01.getRowCounts().size(), +// nrows); // } // } // @@ -1369,19 +1387,22 @@ // SparseMatrix01 smc(nrows, ncols, dense.begin(), 0); // vector y(nrows, 0); // smc.vecDistSquared(x.begin(), y.begin()); -// CompareVectors(nrows, y.begin(), yref.begin(), "vecDistSquared compact 1"); +// CompareVectors(nrows, y.begin(), yref.begin(), "vecDistSquared compact +// 1"); // // SparseMatrix01 smnc(ncols, nrows); // for (UInt i = 0; i < nrows; ++i) // smnc.addRow(dense.begin(i)); // fill(y.begin(), y.end(), Real(0)); // smnc.vecDistSquared(x.begin(), y.begin()); -// CompareVectors(nrows, y.begin(), yref.begin(), "vecDistSquared non-compact"); +// CompareVectors(nrows, y.begin(), yref.begin(), "vecDistSquared +// non-compact"); // // smnc.compact(); // fill(y.begin(), y.end(), Real(0)); // smnc.vecDistSquared(x.begin(), y.begin()); -// CompareVectors(nrows, y.begin(), yref.begin(), "vecDistSquared compact 2"); +// CompareVectors(nrows, y.begin(), yref.begin(), "vecDistSquared compact +// 2"); // // { // TEST_LOOP(M) { @@ -1481,7 +1502,7 @@ // << " - non compact"; // CompareVectors(nrows, y2.begin(), yref2.begin(), str.str().c_str()); // } -// +// // sm2.compact(); // fill(y2.begin(), y2.end(), Real(0)); // sm2.vecDist(x2.begin(), y2.begin()); @@ -1515,7 +1536,8 @@ // for (UInt i = 0; i < nrows; ++i) // smnc.addRow(dense.begin(i)); // smnc.vecMaxProd(x.begin(), y.begin()); -// CompareVectors(nrows, y.begin(), yref.begin(), "vecMaxProd non compact 1"); +// 
CompareVectors(nrows, y.begin(), yref.begin(), "vecMaxProd non compact +// 1"); // // smnc.compact(); // fill(y.begin(), y.end(), Real(0)); @@ -1580,7 +1602,8 @@ // for (UInt i = 0; i < nrows; ++i) // smnc.addRow(dense.begin(i)); // smnc.rightVecProd(x.begin(), y.begin()); -// CompareVectors(nrows, y.begin(), yref.begin(), "rightVecProd non compact 1"); +// CompareVectors(nrows, y.begin(), yref.begin(), "rightVecProd non compact +// 1"); // // smnc.compact(); // fill(y.begin(), y.end(), Real(0)); @@ -1728,7 +1751,7 @@ // res.first = 0; res.second = 0; // res = smnc.closestEuclidean(x.begin()); // ComparePair(res, ref, "closestEuclidean compact 2"); -// +// // { // TEST_LOOP(M) { // @@ -1743,7 +1766,7 @@ // ref = dense2.closestEuclidean(x2.begin()); // res.first = 0; res.second = 0; // res = sm2.closestEuclidean(x2.begin()); -// { +// { // stringstream str; // str << "closestEuclidean A " << nrows << "X" << ncols << "/" << zr // << " - non compact"; @@ -1824,13 +1847,13 @@ // ncols = 5; // nrows = 7; // zr = 2; -// +// // Dense01 dense(nrows, ncols, zr); // SparseMatrix01 sm4c(nrows, ncols, dense.begin(), 0); // // { // TEST_LOOP(M) { -// +// // Dense01 dense2(nrows, ncols, zr); // SparseMatrix01 sm2(nrows, ncols, dense2.begin(), 0); // @@ -1841,14 +1864,14 @@ // sm2.decompact(); // dense2.rowMax(x2.begin(), y2.begin()); // sm2.rowMax(x2.begin(), yref2.begin()); -// +// // { // stringstream str; // str << "rowMax A " << nrows << "X" << ncols << "/" << zr // << " - non compact"; // CompareVectors(nrows, y2.begin(), yref2.begin(), str.str().c_str()); // } -// +// // sm2.compact(); // dense2.rowMax(x2.begin(), y2.begin()); // sm2.rowMax(x2.begin(), yref2.begin()); @@ -1867,13 +1890,13 @@ // { // UInt ncols, nrows, zr, i; // ncols = 5; -// nrows = 7; +// nrows = 7; // zr = 2; // // Dense01 dense(nrows, ncols, zr); // SparseMatrix01 sm4c(nrows, ncols, dense.begin(), 0); // -// { +// { // TEST_LOOP(M) { // // Dense01 dense2(nrows, ncols, zr); @@ -1886,14 +1909,14 @@ // sm2.decompact(); // dense2.rowProd(x2.begin(), y2.begin()); // sm2.rowProd(x2.begin(), yref2.begin()); -// -// { +// +// { // stringstream str; // str << "rowProd A " << nrows << "X" << ncols << "/" << zr // << " - non compact"; // CompareVectors(nrows, y2.begin(), yref2.begin(), str.str().c_str()); -// } -// +// } +// // sm2.compact(); // dense2.rowProd(x2.begin(), y2.begin()); // sm2.rowProd(x2.begin(), yref2.begin()); @@ -1906,20 +1929,20 @@ // } // } // -// -// } +// +// } // // //-------------------------------------------------------------------------------- // struct RowCounter : public map, UInt> // { // inline void addRow(vector& v) { -// +// // iterator it = find(v); // if (it == end()) // (*this)[v] = 1; // else // ++ it->second; -// } +// } // // inline void checkRowCounts(Tester& t, SparseMatrix01& sm01, // const char* str) @@ -1938,7 +1961,7 @@ // stringstream buf2; // buf2 << "SparseMatrix01 row counts equal counts " << str; // t.Test(buf2.str(), rc[i].second, it->second); -// } +// } // } // }; // @@ -1948,18 +1971,18 @@ // { // Testing the comparison function // const UInt rowSize = 4; // UInt row1[rowSize], row2[rowSize]; -// +// // RowCompare comp(rowSize); -// -// for (UInt i = 0; i < rowSize; ++i) +// +// for (UInt i = 0; i < rowSize; ++i) // row1[i] = row2[i] = i; -// +// // Test("SparseMatrix01 row counts 1", comp(row1, row2), false); -// +// // row2[1] = 3; row2[2] = 5; row2[3] = 7; // Test("SparseMatrix01 row counts 2", comp(row1, row2), false); // Test("SparseMatrix01 row counts 3", comp(row2, row1), 
true); -// } +// } // // { // Testing the comparison function with a set // UInt ncols, nrows, nchildren, w, n; @@ -1968,9 +1991,9 @@ // n = 1000; // // vector boundaries(nchildren, 0); -// +// // boundaries[0] = rng_->getUInt32(w) + 1; -// for (UInt i = 1; i < nchildren-1; ++i) +// for (UInt i = 1; i < nchildren-1; ++i) // boundaries[i] = boundaries[i-1] + (rng_->getUInt32(w) + 1); // boundaries[nchildren-1] = ncols; // @@ -1994,9 +2017,10 @@ // comp[k++] = j; // myset.insert(comp); // } -// -// for (UInt i = 0; i < v.size(); ++i) -// Test("SparseMatrix01 row counts 4", (myset.find(v[i]) != myset.end()), true); +// +// for (UInt i = 0; i < v.size(); ++i) +// Test("SparseMatrix01 row counts 4", (myset.find(v[i]) != myset.end()), +// true); // // for (UInt i = 0; i < v.size(); ++i) // delete [] v[i]; @@ -2007,21 +2031,21 @@ // // RowCounter control; // SparseMatrix01 sm(ncols, 1, nnzr); -// +// // for (UInt n = 0; n < nreps; ++n) { -// vector v(ncols, 0); +// vector v(ncols, 0); // GenerateRand01Vector(rng_, nnzr, v); -// sm.addRow(v.begin()); +// sm.addRow(v.begin()); // control.addRow(v); -// } -// +// } +// // control.checkRowCounts(*this, sm, "1"); // // Test("SparseMatrix01 row counts ", sm.getRowCounts().size() > 0, true); // // // checking that counts are intact after decompact and recompact // // that manipulate row pointers -// sm.decompact(); +// sm.decompact(); // // control.checkRowCounts(*this, sm, "2"); // @@ -2030,7 +2054,7 @@ // GenerateRand01Vector(rng_, nnzr, v); // sm.addRow(v.begin()); // control.addRow(v); -// } +// } // // sm.compact(); // @@ -2045,41 +2069,41 @@ // // control.checkRowCounts(*this, sm2, "4"); // } -// } +// } // // //-------------------------------------------------------------------------------- // void SparseMatrix01UnitTest::unit_test_print() // { -// UInt nrows, ncols, zr; +// UInt nrows, ncols, zr; // nrows = 225; ncols = 31; zr = 30; // // TEST_LOOP(M) { // std::stringstream buf1, buf2; -// +// // Dense01 d(nrows, ncols, zr); // SparseMatrix01 sm(nrows, ncols, d.begin(), 0); -// +// // std::vector d2(nrows*ncols); // sm.toDense(d2.begin()); // for (UInt i = 0; i < nrows; ++i) { // for (UInt j = 0; j < ncols; ++j) // buf1 << d2[i*ncols+j] << " "; // buf1 << endl; -// } -// +// } +// // sm.print(buf2); // Test("SparseMatrix01 print 1", buf1.str(), buf2.str()); // } -// } +// } // // //-------------------------------------------------------------------------------- // /* Performance -// UInt nrows = 30000, ncols = 128, zr = 0; -// Dense01 d(nrows, ncols, zr); +// UInt nrows = 30000, ncols = 128, zr = 0; +// Dense01 d(nrows, ncols, zr); // SparseMatrix01 sm01(nrows, ncols, d.begin(), 0); // std::vector x(ncols, 1), y(nrows, 0); -// sm01.compact(); -// +// sm01.compact(); +// // { // 3000 iterations, 30000x128, 69.6% of total time, darwin86 // // 22.6 without hand unrolling by 4 of loop on k // // 15.62s straight loops @@ -2087,16 +2111,16 @@ // // 14.85s with iterator++ instead of indexing into x // // 14.79s with iterator++ and 4-unrolled loop on i // // 14.75s with compact -// // 14.89s with pointer while, 4-unrolled loop on i +// // 14.89s with pointer while, 4-unrolled loop on i // // 14.82s with pointer while, 2-unrolled loop on i // // 14.78s with pointer while, not unrolled loop on i // // 12.39s with pre-fetching x into nzb_ -// boost::timer t; -// ITER_1(3000) +// boost::timer t; +// ITER_1(3000) // sm01.rowProd(x.begin(), y.begin()); -// cout << t.elapsed() << endl; +// cout << t.elapsed() << endl; // } -// */ +// */ // // 
//-------------------------------------------------------------------------------- // void SparseMatrix01UnitTest::unit_test_numerical_accuracy() @@ -2105,7 +2129,7 @@ // // { // Tests for accuracy/numerical stability/underflows // nrows = 1; zr = 0; -// +// // for (ncols = 1; ncols < 128; ++ncols) { // vector y_d_prod(nrows), y_d_sum(nrows); // vector y_f_prod(nrows), y_f_sum(nrows); @@ -2120,7 +2144,7 @@ // x_d[i] = Real(ncols-i + 1e-5)/Real(ncols); // sm_d.rowProd(x_d.begin(), y_d_prod.begin()); // sm_d.rightVecProd(x_d.begin(), y_d_sum.begin()); -// } +// } // // { // float // Dense01 dense_f(nrows, ncols, zr); @@ -2140,9 +2164,9 @@ // // Real prod_rel_err = fabs((y_d_prod[0] - y_f_prod[0]) / y_d_prod[0]); // Real sum_rel_err = fabs((y_d_sum[0] - y_f_sum[0]) / y_d_sum[0]); -// +// // cout << setprecision(0); -// cout << "ncols = " << ncols +// cout << "ncols = " << ncols // // << " Real = " << float(y_d_prod[0]) // // << " float = " << y_f[0] // // << " error = " << (y_d[0] - y_f[0]) @@ -2151,37 +2175,38 @@ // << endl; // } // } -// } +// } // // //-------------------------------------------------------------------------------- // /** -// * A generator function object, that generates random numbers between 0 and 256. -// * It also has a threshold to control the sparsity of the vectors generated. +// * A generator function object, that generates random numbers between 0 and +// 256. +// * It also has a threshold to control the sparsity of the vectors generated. // */ // template -// struct rand_init +// struct rand_init // { // Random *r_; // int threshold_; -// +// // inline rand_init(Random *r, int threshold =100) // : r_(r), threshold_(threshold) // {} -// -// inline T operator()() -// { +// +// inline T operator()() +// { // return T(r_->getUInt32(100) > threshold_ ? 0 : 1 + r_->getUInt32(255)); // } // }; -// -// +// +// // //-------------------------------------------------------------------------------- // // THERE IS NO addRow WRITTEN YET for Dense01 // // THERE IS NO rowProd WRITTEN YET for Dense01 // void SparseMatrix01UnitTest::unit_test_usage() // { // /* -// SparseMatrix01* sm = +// SparseMatrix01* sm = // new SparseMatrix01(16,0); // // Dense01* smDense = @@ -2192,13 +2217,13 @@ // size_t r = 1000; // while(r==1000 || r==5) // r = rng_->get() % 18; -// +// // if (r == 0) { // // sm->compact(); // // no compact for Dense // -// } else if (r == 1) { +// } else if (r == 1) { // // sm->decompact(); // // no decompact for Dense @@ -2219,7 +2244,7 @@ // } // Compare(*smDense, *sm, "deleteRows"); // -// } else if (r == 3) { +// } else if (r == 3) { // // vector del; // if (rng_->get() % 100 < 90) { @@ -2235,25 +2260,26 @@ // } // Compare(*smDense, *sm, "deleteColumns"); // -// } +// } // else if (r == 4) { -// +// // vector new_row(sm->nCols(), 0); // size_t n = rng_->get() % 16; -// for (size_t z = 0; z < n; ++z) { +// for (size_t z = 0; z < n; ++z) { // if (rng_->get() % 100 < 90) { // for (size_t ii = 0; ii < new_row.size(); ++ii) -// new_row[ii] = (float) (rng_->get() % 100 > 70 ? 0 : rng_->get() % 256); -// } +// new_row[ii] = (float) (rng_->get() % 100 > 70 ? 
0 : rng_->get() % +// 256); +// } // sm->addRow(new_row.begin()); // // THERE IS NO addRow WRITTEN YET // //smDense->addRow(ne_row.begin()); // //Compare(*smDense, *sm, "addRow"); // } -// +// // } else if (r == 5) { -// -// size_t nrows = rng_->get() % 32, ncols = rng_->get() % 32+1; +// +// size_t nrows = rng_->get() % 32, ncols = rng_->get() % 32+1; // delete sm; // delete smDense; // sm = new SparseMatrix01(ncols, nrows); @@ -2261,14 +2287,14 @@ // Compare(*smDense, *sm, "constructor(ncols,nrows)"); // // } else if (r == 6) { -// +// // delete sm; // delete smDense; // sm = new SparseMatrix01(16,0); // smDense = new Dense01(16,0); // Compare(*smDense, *sm, "constructor(16,0)"); -// -// } else if (r == 8) { +// +// } else if (r == 8) { // // vector x(sm->nCols()), y(sm->nRows()); // generate(x.begin(), x.end(), rand_init(rng_, 50)); @@ -2277,7 +2303,7 @@ // Compare(*smDense, *sm, "vecDistSquared"); // // } else if (r == 9) { -// +// // if(sm->nCols()==smDense->ncols && sm->nRows()==smDense->nrows){ // vector x(sm->nCols()), y(sm->nRows()); // // vector xDense(smDense->ncols), yDense(smDense->nrows); @@ -2298,9 +2324,9 @@ // smDense->rowDistSquared(randInt, x.begin()); // Compare(*smDense, *sm, "rowDistSquared"); // } -// +// // } else if (r == 11) { -// +// // vector x(sm->nCols()); // generate(x.begin(), x.end(), rand_init(rng_, 50)); // sm->closestEuclidean(x.begin()); @@ -2311,9 +2337,9 @@ // // vector x(sm->nCols()); // generate(x.begin(), x.end(), rand_init(rng_, 50)); -// for (size_t n = 0; n < sm->nCols(); ++n) +// for (size_t n = 0; n < sm->nCols(); ++n) // x.push_back(float(rng_->get() % 256)); -// sm->closestDot(x.begin()); +// sm->closestDot(x.begin()); // smDense->closestDot(x.begin()); // Compare(*smDense, *sm, "closestDot"); // @@ -2328,7 +2354,7 @@ // } // // } else if (r == 14) { -// +// // if(sm->nCols()==smDense->ncols && sm->nRows()==smDense->nrows){ // vector x(sm->nCols()), y(sm->nRows()); // generate(x.begin(), x.end(), rand_init(rng_, 50)); @@ -2336,95 +2362,94 @@ // smDense->vecMaxProd(x.begin(), y.begin()); // Compare(*smDense, *sm, "vecMaxProd"); // } -// +// // } else if (r == 15) { -// +// // if(sm->nCols()==smDense->ncols && sm->nRows()==smDense->nrows){ // vector x(sm->nCols()), y(sm->nRows()); // generate(x.begin(), x.end(), rand_init(rng_, 50)); // sm->rowMax(x.begin(), y.begin()); -// smDense->rowMax(x.begin(), y.begin()); +// smDense->rowMax(x.begin(), y.begin()); // Compare(*smDense, *sm, "rowMax"); // } // // } else if (r == 16) { -// +// // if(sm->nCols()==smDense->ncols && sm->nRows()==smDense->nrows){ // vector x(sm->nCols()), y(sm->nRows()); // generate(x.begin(), x.end(), rand_init(rng_, 50)); -// sm->rowProd(x.begin(), y.begin()); -// smDense->rowProd(x.begin(), y.begin()); +// sm->rowProd(x.begin(), y.begin()); +// smDense->rowProd(x.begin(), y.begin()); // Compare(*smDense, *sm, "rowProd(x.begin(), y.begin())"); // } -// +// // } else if (r == 17) { -// +// // vector x(sm->nCols()), y(sm->nRows()); // generate(x.begin(), x.end(), rand_init(rng_, 50)); // float theRandom = float(rng_->get() % 256); // sm->rowProd(x.begin(), y.begin(), theRandom); // // THERE IS NO rowProd WRITTEN YET // //smDense->rowProd(x.begin(), y.begin(), theRandom); -// //Compare(*smDense, *sm, "rowProd(x.begin(), y.begin(), float(rng_->get() % 256))"); -// } -// -// } -// -// // transpose(SparseMatrix01& tr) -// // vecDistSquared(InIter x, OutIter y) -// // minVecDistSquared(InIter x, -// // vecDist(InIter x, OutIter y) -// // value_type rowDistSquared( size_type& row, 
InIter x) -// // pair closestEuclidean(InIter x) -// // pair closestDot(InIter x) -// // rightVecProd(InIter x, OutIter y) -// // vecMaxProd(InIter x, OutIter y) -// // vecMaxProd01(InIter x, OutIter y) +// //Compare(*smDense, *sm, "rowProd(x.begin(), y.begin(), +// float(rng_->get() % 256))"); +// } +// +// } +// +// // transpose(SparseMatrix01& tr) +// // vecDistSquared(InIter x, OutIter y) +// // minVecDistSquared(InIter x, +// // vecDist(InIter x, OutIter y) +// // value_type rowDistSquared( size_type& row, InIter x) +// // pair closestEuclidean(InIter x) +// // pair closestDot(InIter x) +// // rightVecProd(InIter x, OutIter y) +// // vecMaxProd(InIter x, OutIter y) +// // vecMaxProd01(InIter x, OutIter y) // // axby_2( size_type& row, value_type a, value_type b, InIter x) // // axby_3(value_type a, value_type b, InIter x) -// // rowMax(OutIter maxima) -// // colMax(OutIter maxima) +// // rowMax(OutIter maxima) +// // colMax(OutIter maxima) // // normalizeRows(bool exact =false) -// // value_type accumulate_nz( size_type& row, binary_functor f, -// // value_type accumulate( size_type& row, binary_functor f, -// // multiply( SparseMatrix01& B, SparseMatrix01& C) +// // value_type accumulate_nz( size_type& row, binary_functor f, +// // value_type accumulate( size_type& row, binary_functor f, +// // multiply( SparseMatrix01& B, SparseMatrix01& C) // // size_type findRow( size_type nnzr, IndIt ind_it, NzIt nz_it) -// // findRows(F f, MatchIt m_it) -// // map( SparseMatrix01& B, SparseMatrix01& C) +// // findRows(F f, MatchIt m_it) +// // map( SparseMatrix01& B, SparseMatrix01& C) // */ -// } -// - //-------------------------------------------------------------------------------- - // void SparseMatrix01UnitTest::RunTests() - // { - - //unit_test_construction(); - //unit_test_fromDense(); - //unit_test_csr(); - //unit_test_compact(); - //unit_test_getRowSparse(); - //unit_test_addRow(); - //unit_test_addUniqueFilteredRow(); - //unit_test_addMinHamming(); - //unit_test_deleteRows(); - //unit_test_deleteColumns(); - //unit_test_rowDistSquared(); - //unit_test_vecDistSquared(); - //unit_test_vecDist(); - //unit_test_closestEuclidean(); - //unit_test_closestDot(); - //unit_test_vecMaxProd(); - //unit_test_vecProd(); - //unit_test_rowMax(); - //unit_test_rowProd(); - //unit_test_row_counts(); - //unit_test_print(); - ////unit_test_usage(); - ////unit_test_numerical_accuracy(); - // } +// } +// +//-------------------------------------------------------------------------------- +// void SparseMatrix01UnitTest::RunTests() +// { - //-------------------------------------------------------------------------------- - -// } // namespace nupic +// unit_test_construction(); +// unit_test_fromDense(); +// unit_test_csr(); +// unit_test_compact(); +// unit_test_getRowSparse(); +// unit_test_addRow(); +// unit_test_addUniqueFilteredRow(); +// unit_test_addMinHamming(); +// unit_test_deleteRows(); +// unit_test_deleteColumns(); +// unit_test_rowDistSquared(); +// unit_test_vecDistSquared(); +// unit_test_vecDist(); +// unit_test_closestEuclidean(); +// unit_test_closestDot(); +// unit_test_vecMaxProd(); +// unit_test_vecProd(); +// unit_test_rowMax(); +// unit_test_rowProd(); +// unit_test_row_counts(); +// unit_test_print(); +////unit_test_usage(); +////unit_test_numerical_accuracy(); +// } +//-------------------------------------------------------------------------------- +// } // namespace nupic diff --git a/src/test/unit/math/SparseMatrix01UnitTest.hpp b/src/test/unit/math/SparseMatrix01UnitTest.hpp index 
2dce76d026..0cfb472754 100644 --- a/src/test/unit/math/SparseMatrix01UnitTest.hpp +++ b/src/test/unit/math/SparseMatrix01UnitTest.hpp @@ -51,7 +51,8 @@ // : nrows(nr), ncols(nc), m(nr*nc, 0) // {} // -// Dense01(Int nr, Int nc, Int nzr, bool small =false, bool emptyRows =false) +// Dense01(Int nr, Int nc, Int nzr, bool small =false, bool emptyRows +// =false) // : nrows(nr), ncols(nc), // m(nr*nc, 0) // { @@ -75,7 +76,8 @@ // inline iterator begin() { return m.begin(); } // inline const_iterator begin() const { return m.begin(); } // inline iterator begin(const Int i) { return m.begin() + i*ncols; } -// inline const_iterator begin(const Int i) const { return m.begin() + i*ncols; } +// inline const_iterator begin(const Int i) const { return m.begin() + +// i*ncols; } // // inline Float& at(const Int i, const Int j) // { @@ -385,7 +387,8 @@ // if (sparse.nNonZerosRow(i) != dense.nNonZerosRow(i)) { // std::stringstream str5; // str5 << str << " nNonZerosRow(" << i << ")"; -// Test(str5.str().c_str(), sparse.nNonZerosRow(i), dense.nNonZerosRow(i)); +// Test(str5.str().c_str(), sparse.nNonZerosRow(i), +// dense.nNonZerosRow(i)); // } // // ITER_2(nrows, ncols) diff --git a/src/test/unit/math/SparseMatrixTest.cpp b/src/test/unit/math/SparseMatrixTest.cpp index 5e0adc68db..9ac2c1b0a1 100644 --- a/src/test/unit/math/SparseMatrixTest.cpp +++ b/src/test/unit/math/SparseMatrixTest.cpp @@ -26,18 +26,16 @@ #include #include -#include #include +#include #include #include #include - using namespace nupic; -TEST(SparseMatrixReadWrite, EmptyMatrix) -{ +TEST(SparseMatrixReadWrite, EmptyMatrix) { SparseMatrix m1, m2; m1.resize(3, 4); @@ -46,7 +44,8 @@ TEST(SparseMatrixReadWrite, EmptyMatrix) // write capnp::MallocMessageBuilder message1; - SparseMatrixProto::Builder protoBuilder = message1.initRoot(); + SparseMatrixProto::Builder protoBuilder = + message1.initRoot(); m1.write(protoBuilder); kj::std::StdOutputStream out(ss); capnp::writeMessage(out, message1); @@ -62,8 +61,7 @@ TEST(SparseMatrixReadWrite, EmptyMatrix) ASSERT_EQ(m1.nCols(), m2.nCols()) << "Number of columns don't match"; } -TEST(SparseMatrixReadWrite, Basic) -{ +TEST(SparseMatrixReadWrite, Basic) { SparseMatrix m1, m2; m1.resize(3, 4); @@ -73,7 +71,8 @@ TEST(SparseMatrixReadWrite, Basic) // write capnp::MallocMessageBuilder message1; - SparseMatrixProto::Builder protoBuilder = message1.initRoot(); + SparseMatrixProto::Builder protoBuilder = + message1.initRoot(); m1.write(protoBuilder); kj::std::StdOutputStream out(ss); capnp::writeMessage(out, message1); @@ -88,15 +87,16 @@ TEST(SparseMatrixReadWrite, Basic) ASSERT_EQ(m1.nRows(), m2.nRows()) << "Number of rows don't match"; ASSERT_EQ(m1.nCols(), m2.nCols()) << "Number of columns don't match"; - std::vector > m1r1(m1.nNonZerosOnRow(1)); + std::vector> m1r1(m1.nNonZerosOnRow(1)); m1.getRowToSparse(1, m1r1.begin()); ASSERT_EQ(m1r1.size(), 1) << "Invalid # of elements in original matrix"; - std::vector > m2r1(m2.nNonZerosOnRow(1)); + std::vector> m2r1(m2.nNonZerosOnRow(1)); m2.getRowToSparse(1, m2r1.begin()); ASSERT_EQ(m2r1.size(), 1) << "Invalid # of elements in copied matrix"; ASSERT_EQ(m1r1[0].first, 1) << "Invalid col index in original matrix"; - ASSERT_EQ(m1r1[0].first, m2r1[0].first) << "Invalid col index in copied matrix"; + ASSERT_EQ(m1r1[0].first, m2r1[0].first) + << "Invalid col index in copied matrix"; ASSERT_EQ(m1r1[0].second, 3.0) << "Invalid value in original matrix"; ASSERT_EQ(m1r1[0].second, m2r1[0].second) << "Invalid value in copied matrix"; } diff --git 
a/src/test/unit/math/SparseMatrixUnitTest.cpp b/src/test/unit/math/SparseMatrixUnitTest.cpp index 2c2dbd7b3d..e742ac490e 100644 --- a/src/test/unit/math/SparseMatrixUnitTest.cpp +++ b/src/test/unit/math/SparseMatrixUnitTest.cpp @@ -27,817 +27,619 @@ #include #include -#include #include "gtest/gtest.h" +#include using std::string; using std::vector; using namespace nupic; -namespace -{ - struct IncrementNonZerosOnOuterTest - { - string name; - UInt32 nrows; - UInt32 ncols; - vector before; - vector outerRows; - vector outerCols; - Real32 delta; - vector expected; - }; +namespace { +struct IncrementNonZerosOnOuterTest { + string name; + UInt32 nrows; + UInt32 ncols; + vector before; + vector outerRows; + vector outerCols; + Real32 delta; + vector expected; +}; - TEST(SparseMatrixTest, incrementNonZerosOnOuter) - { - vector tests = { - { - "Test 1", - // Dimensions - 4, 4, - // Before - {0, 1, 0, 1, - 2, 0, 2, 0, - 0, 1, 0, 1, - 2, 0, 2, 0}, - // Selection - {0, 2, 3}, - {0, 1}, - // Delta - 40, - // Expected - {0, 41, 0, 1, - 2, 0, 2, 0, - 0, 41, 0, 1, - 42, 0, 2, 0} - }, - { - "Test 2", - // Dimensions - 4, 4, - // Before - {1,1,1,1, - 1,1,1,1, - 1,1,1,1, - 1,1,1,1}, - // Selection - {0, 3}, - {0, 3}, - // Delta - 41, - // Expected - {42,1,1,42, - 1,1,1,1, - 1,1,1,1, - 42,1,1,42} - }, - { - "Test 3", - // Dimensions - 4, 4, - // Before - {0,1,1,0, - 1,1,1,1, - 1,1,1,1, - 0,1,1,0}, - // Selection - {0, 3}, - {0, 3}, - // Delta - 41, - // Expected - {0,1,1,0, - 1,1,1,1, - 1,1,1,1, - 0,1,1,0} - } - }; +TEST(SparseMatrixTest, incrementNonZerosOnOuter) { + vector tests = { + {"Test 1", + // Dimensions + 4, + 4, + // Before + {0, 1, 0, 1, 2, 0, 2, 0, 0, 1, 0, 1, 2, 0, 2, 0}, + // Selection + {0, 2, 3}, + {0, 1}, + // Delta + 40, + // Expected + {0, 41, 0, 1, 2, 0, 2, 0, 0, 41, 0, 1, 42, 0, 2, 0}}, + {"Test 2", + // Dimensions + 4, + 4, + // Before + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + // Selection + {0, 3}, + {0, 3}, + // Delta + 41, + // Expected + {42, 1, 1, 42, 1, 1, 1, 1, 1, 1, 1, 1, 42, 1, 1, 42}}, + {"Test 3", + // Dimensions + 4, + 4, + // Before + {0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0}, + // Selection + {0, 3}, + {0, 3}, + // Delta + 41, + // Expected + {0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0}}}; - for (const IncrementNonZerosOnOuterTest& test : tests) - { - NTA_INFO << "Test: " << test.name; - SparseMatrix<> m(test.nrows, test.ncols, test.before.begin()); + for (const IncrementNonZerosOnOuterTest &test : tests) { + NTA_INFO << "Test: " << test.name; + SparseMatrix<> m(test.nrows, test.ncols, test.before.begin()); - m.incrementNonZerosOnOuter(test.outerRows.begin(), test.outerRows.end(), - test.outerCols.begin(), test.outerCols.end(), - test.delta); + m.incrementNonZerosOnOuter(test.outerRows.begin(), test.outerRows.end(), + test.outerCols.begin(), test.outerCols.end(), + test.delta); - vector actual(test.nrows * test.ncols); - m.toDense(actual.begin()); + vector actual(test.nrows * test.ncols); + m.toDense(actual.begin()); - EXPECT_EQ(test.expected, actual); - } + EXPECT_EQ(test.expected, actual); } +} +struct IncrementNonZerosOnRowsExcludingColsTest { + string name; + UInt32 nrows; + UInt32 ncols; + vector before; + vector outerRows; + vector outerCols; + Real32 delta; + vector expected; +}; - struct IncrementNonZerosOnRowsExcludingColsTest - { - string name; - UInt32 nrows; - UInt32 ncols; - vector before; - vector outerRows; - vector outerCols; - Real32 delta; - vector expected; - }; +TEST(SparseMatrixTest, incrementNonZerosOnRowsExcludingCols) { + 
vector tests = { + {"Test 1", + // Dimensions + 4, + 4, + // Before + {0, 1, 0, 1, 2, 0, 2, 0, 0, 1, 0, 1, 2, 0, 2, 0}, + // Selection + {0, 2, 3}, + {0, 1}, + // Delta + 40, + // Expected + {0, 1, 0, 41, 2, 0, 2, 0, 0, 1, 0, 41, 2, 0, 42, 0}}, + {"Test 2", + // Dimensions + 4, + 4, + // Before + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + // Selection + {0, 3}, + {0, 3}, + // Delta + 41, + // Expected + {1, 42, 42, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 42, 42, 1}}, + {"Test 3", + // Dimensions + 4, + 4, + // Before + {1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1}, + // Selection + {0, 3}, + {0, 3}, + // Delta + 41, + // Expected + {1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1}}}; - TEST(SparseMatrixTest, incrementNonZerosOnRowsExcludingCols) - { - vector tests = { - { - "Test 1", - // Dimensions - 4, 4, - // Before - {0, 1, 0, 1, - 2, 0, 2, 0, - 0, 1, 0, 1, - 2, 0, 2, 0}, - // Selection - {0, 2, 3}, - {0, 1}, - // Delta - 40, - // Expected - {0, 1, 0, 41, - 2, 0, 2, 0, - 0, 1, 0, 41, - 2, 0, 42, 0} - }, - { - "Test 2", - // Dimensions - 4, 4, - // Before - {1,1,1,1, - 1,1,1,1, - 1,1,1,1, - 1,1,1,1}, - // Selection - {0, 3}, - {0, 3}, - // Delta - 41, - // Expected - {1,42,42,1, - 1,1,1,1, - 1,1,1,1, - 1,42,42,1} - }, - { - "Test 3", - // Dimensions - 4, 4, - // Before - {1,0,0,1, - 1,1,1,1, - 1,1,1,1, - 1,0,0,1}, - // Selection - {0, 3}, - {0, 3}, - // Delta - 41, - // Expected - {1,0,0,1, - 1,1,1,1, - 1,1,1,1, - 1,0,0,1} - } - }; - - for (const IncrementNonZerosOnRowsExcludingColsTest& test : tests) - { - NTA_INFO << "Test: " << test.name; - SparseMatrix<> m(test.nrows, test.ncols, test.before.begin()); + for (const IncrementNonZerosOnRowsExcludingColsTest &test : tests) { + NTA_INFO << "Test: " << test.name; + SparseMatrix<> m(test.nrows, test.ncols, test.before.begin()); - m.incrementNonZerosOnRowsExcludingCols( - test.outerRows.begin(), test.outerRows.end(), - test.outerCols.begin(), test.outerCols.end(), - test.delta); + m.incrementNonZerosOnRowsExcludingCols( + test.outerRows.begin(), test.outerRows.end(), test.outerCols.begin(), + test.outerCols.end(), test.delta); - vector actual(test.nrows * test.ncols); - m.toDense(actual.begin()); + vector actual(test.nrows * test.ncols); + m.toDense(actual.begin()); - EXPECT_EQ(test.expected, actual); - } + EXPECT_EQ(test.expected, actual); } +} +struct SetZerosOnOuterTest { + string name; + UInt32 nrows; + UInt32 ncols; + vector before; + vector outerRows; + vector outerCols; + Real32 value; + vector expected; +}; - struct SetZerosOnOuterTest - { - string name; - UInt32 nrows; - UInt32 ncols; - vector before; - vector outerRows; - vector outerCols; - Real32 value; - vector expected; - }; +TEST(SparseMatrixTest, setZerosOnOuter) { + vector tests = { + {"Test 1", + // Dimensions + 4, + 4, + // Before + {0, 1, 0, 1, 2, 0, 2, 0, 0, 1, 0, 1, 2, 0, 2, 0}, + // Selection + {0, 2, 3}, + {0, 1}, + // Value + 42, + // Expected + {42, 1, 0, 1, 2, 0, 2, 0, 42, 1, 0, 1, 2, 42, 2, 0}}, + {"Test 2", + // Dimensions + 4, + 4, + // Before + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + // Selection + {0, 3}, + {0, 3}, + // Value + 42, + // Expected + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}}, + {"Test 3", + // Dimensions + 4, + 4, + // Before + {1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1}, + // Selection + {0, 3}, + {1, 2}, + // Value + 42, + // Expected + {1, 42, 42, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 42, 42, 1}}}; - TEST(SparseMatrixTest, setZerosOnOuter) - { - vector tests = { - { - "Test 1", - // Dimensions - 4, 4, - // Before - {0, 1, 0, 
1, - 2, 0, 2, 0, - 0, 1, 0, 1, - 2, 0, 2, 0}, - // Selection - {0, 2, 3}, - {0, 1}, - // Value - 42, - // Expected - {42, 1, 0, 1, - 2, 0, 2, 0, - 42, 1, 0, 1, - 2, 42, 2, 0} - }, - { - "Test 2", - // Dimensions - 4, 4, - // Before - {1,1,1,1, - 1,1,1,1, - 1,1,1,1, - 1,1,1,1}, - // Selection - {0, 3}, - {0, 3}, - // Value - 42, - // Expected - {1,1,1,1, - 1,1,1,1, - 1,1,1,1, - 1,1,1,1} - }, - { - "Test 3", - // Dimensions - 4, 4, - // Before - {1,0,0,1, - 1,1,1,1, - 1,1,1,1, - 1,0,0,1}, - // Selection - {0, 3}, - {1, 2}, - // Value - 42, - // Expected - {1,42,42,1, - 1,1,1,1, - 1,1,1,1, - 1,42,42,1} - } - }; + for (const SetZerosOnOuterTest &test : tests) { + NTA_INFO << "Test: " << test.name; + SparseMatrix<> m(test.nrows, test.ncols, test.before.begin()); - for (const SetZerosOnOuterTest& test : tests) - { - NTA_INFO << "Test: " << test.name; - SparseMatrix<> m(test.nrows, test.ncols, test.before.begin()); + m.setZerosOnOuter(test.outerRows.begin(), test.outerRows.end(), + test.outerCols.begin(), test.outerCols.end(), test.value); - m.setZerosOnOuter( - test.outerRows.begin(), test.outerRows.end(), - test.outerCols.begin(), test.outerCols.end(), - test.value); + vector actual(test.nrows * test.ncols); + m.toDense(actual.begin()); - vector actual(test.nrows * test.ncols); - m.toDense(actual.begin()); - - EXPECT_EQ(test.expected, actual); - } + EXPECT_EQ(test.expected, actual); } +} +struct SetRandomZerosOnOuterTest_single { + string name; + UInt32 nrows; + UInt32 ncols; + vector before; + vector outerRows; + vector outerCols; + Int32 numNewNonZerosPerRow; + Real32 value; +}; - struct SetRandomZerosOnOuterTest_single - { - string name; - UInt32 nrows; - UInt32 ncols; - vector before; - vector outerRows; - vector outerCols; - Int32 numNewNonZerosPerRow; - Real32 value; - }; +TEST(SparseMatrixTest, setRandomZerosOnOuter_single) { + Random rng; - TEST(SparseMatrixTest, setRandomZerosOnOuter_single) - { - Random rng; - - vector tests = { - { - "Test 1", - // Dimensions - 8, 6, - // Before - {1, 1, 0, 0, 1, 1, - 0, 0, 1, 1, 0, 0, - 0, 0, 1, 0, 0, 1, - 1, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, - 0, 0, 1, 1, 0, 1}, - // Selection - {0, 3, 4, 5, 6, 7}, - {0, 3, 4}, - // numNewNonZerosPerRow - 2, - // value - 42 - }, - { - "No selected rows", - // Dimensions - 8, 6, - // Before - {1, 1, 0, 0, 1, 1, - 0, 0, 1, 1, 0, 0, - 0, 0, 1, 0, 0, 1, - 1, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, - 0, 0, 1, 1, 0, 1}, - // Selection - {}, - {0, 3, 4}, - // numNewNonZerosPerRow - 2, - // value - 42 - }, - { - "No selected cols", - // Dimensions - 8, 6, - // Before - {1, 1, 0, 0, 1, 1, - 0, 0, 1, 1, 0, 0, - 0, 0, 1, 0, 0, 1, - 1, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, - 0, 0, 1, 1, 0, 1}, - // Selection - {0, 3, 4, 5, 6, 7}, - {}, - // numNewNonZerosPerRow - 2, - // value - 42 - } - }; + vector tests = { + {"Test 1", + // Dimensions + 8, + 6, + // Before + {1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1}, + // Selection + {0, 3, 4, 5, 6, 7}, + {0, 3, 4}, + // numNewNonZerosPerRow + 2, + // value + 42}, + {"No selected rows", + // Dimensions + 8, + 6, + // Before + {1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1}, + // Selection + {}, + {0, 3, 4}, + // numNewNonZerosPerRow + 2, + // value + 42}, + {"No selected cols", + 
// Dimensions + 8, + 6, + // Before + {1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1}, + // Selection + {0, 3, 4, 5, 6, 7}, + {}, + // numNewNonZerosPerRow + 2, + // value + 42}}; - for (const SetRandomZerosOnOuterTest_single& test : tests) - { - NTA_INFO << "Test: " << test.name; + for (const SetRandomZerosOnOuterTest_single &test : tests) { + NTA_INFO << "Test: " << test.name; - SparseMatrix<> m(test.nrows, test.ncols, test.before.begin()); + SparseMatrix<> m(test.nrows, test.ncols, test.before.begin()); - m.setRandomZerosOnOuter( - test.outerRows.begin(), test.outerRows.end(), - test.outerCols.begin(), test.outerCols.end(), - test.numNewNonZerosPerRow, test.value, rng); + m.setRandomZerosOnOuter(test.outerRows.begin(), test.outerRows.end(), + test.outerCols.begin(), test.outerCols.end(), + test.numNewNonZerosPerRow, test.value, rng); - vector actual(test.nrows*test.ncols); - m.toDense(actual.begin()); + vector actual(test.nrows * test.ncols); + m.toDense(actual.begin()); - for (UInt32 row = 0; row < test.nrows; row++) - { - Int32 numSelectedZeros = 0; - Int32 numConvertedZeros = 0; + for (UInt32 row = 0; row < test.nrows; row++) { + Int32 numSelectedZeros = 0; + Int32 numConvertedZeros = 0; - for (UInt32 col = 0; col < test.ncols; col++) - { - UInt32 i = row*test.ncols + col; - if (std::binary_search(test.outerRows.begin(), test.outerRows.end(), row) && - std::binary_search(test.outerCols.begin(), test.outerCols.end(), col)) - { - if (test.before[i] == 0) - { - numSelectedZeros++; + for (UInt32 col = 0; col < test.ncols; col++) { + UInt32 i = row * test.ncols + col; + if (std::binary_search(test.outerRows.begin(), test.outerRows.end(), + row) && + std::binary_search(test.outerCols.begin(), test.outerCols.end(), + col)) { + if (test.before[i] == 0) { + numSelectedZeros++; - if (actual[i] != 0) - { - // Should replace zeros with the specified value. - EXPECT_EQ(test.value, actual[i]); - numConvertedZeros++; - } + if (actual[i] != 0) { + // Should replace zeros with the specified value. + EXPECT_EQ(test.value, actual[i]); + numConvertedZeros++; } } - else - { - // Every value not in the selection should be unchanged. - EXPECT_EQ(test.before[i], actual[i]); - } - - if (test.before[i] != 0) - { - // Every value that was nonzero should not have changed. - EXPECT_EQ(test.before[i], actual[i]); - } + } else { + // Every value not in the selection should be unchanged. + EXPECT_EQ(test.before[i], actual[i]); } - // Should replace numNewNonZerosPerRow, or all of them if there aren't - // that many. - EXPECT_EQ(std::min(test.numNewNonZerosPerRow, numSelectedZeros), - numConvertedZeros); + if (test.before[i] != 0) { + // Every value that was nonzero should not have changed. + EXPECT_EQ(test.before[i], actual[i]); + } } + + // Should replace numNewNonZerosPerRow, or all of them if there aren't + // that many. 
+ EXPECT_EQ(std::min(test.numNewNonZerosPerRow, numSelectedZeros), + numConvertedZeros); } } +} +struct SetRandomZerosOnOuterTest_multi { + string name; + UInt32 nrows; + UInt32 ncols; + vector before; + vector outerRows; + vector outerCols; + vector numNewNonZerosPerRow; + Real32 value; +}; - struct SetRandomZerosOnOuterTest_multi - { - string name; - UInt32 nrows; - UInt32 ncols; - vector before; - vector outerRows; - vector outerCols; - vector numNewNonZerosPerRow; - Real32 value; - }; +TEST(SparseMatrixTest, setRandomZerosOnOuter_multi) { + Random rng; - TEST(SparseMatrixTest, setRandomZerosOnOuter_multi) - { - Random rng; - - vector tests = { - { - "Test 1", - // Dimensions - 8, 6, - // Before - {1, 1, 0, 0, 1, 1, - 0, 0, 1, 1, 0, 0, - 0, 0, 1, 0, 0, 1, - 1, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, - 0, 0, 1, 1, 0, 1}, - // Selection - {0, 3, 4, 5, 6, 7}, - {0, 3, 4}, - // numNewNonZerosPerRow - {2, 2, 2, 2, 2, 2}, - // value - 42 - }, - { - "No selected rows", - // Dimensions - 8, 6, - // Before - {1, 1, 0, 0, 1, 1, - 0, 0, 1, 1, 0, 0, - 0, 0, 1, 0, 0, 1, - 1, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, - 0, 0, 1, 1, 0, 1}, - // Selection - {}, - {0, 3, 4}, - // numNewNonZerosPerRow - {}, - // value - 42 - }, - { - "No selected cols", - // Dimensions - 8, 6, - // Before - {1, 1, 0, 0, 1, 1, - 0, 0, 1, 1, 0, 0, - 0, 0, 1, 0, 0, 1, - 1, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, - 0, 0, 1, 1, 0, 1}, - // Selection - {0, 3, 4, 5, 6, 7}, - {}, - // numNewNonZerosPerRow - {2, 2, 2, 2, 2, 2}, - // value - 42 - } - }; + vector tests = { + {"Test 1", + // Dimensions + 8, + 6, + // Before + {1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1}, + // Selection + {0, 3, 4, 5, 6, 7}, + {0, 3, 4}, + // numNewNonZerosPerRow + {2, 2, 2, 2, 2, 2}, + // value + 42}, + {"No selected rows", + // Dimensions + 8, + 6, + // Before + {1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1}, + // Selection + {}, + {0, 3, 4}, + // numNewNonZerosPerRow + {}, + // value + 42}, + {"No selected cols", + // Dimensions + 8, + 6, + // Before + {1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1}, + // Selection + {0, 3, 4, 5, 6, 7}, + {}, + // numNewNonZerosPerRow + {2, 2, 2, 2, 2, 2}, + // value + 42}}; - for (const SetRandomZerosOnOuterTest_multi& test : tests) - { - NTA_INFO << "Test: " << test.name; + for (const SetRandomZerosOnOuterTest_multi &test : tests) { + NTA_INFO << "Test: " << test.name; - SparseMatrix<> m(test.nrows, test.ncols, test.before.begin()); + SparseMatrix<> m(test.nrows, test.ncols, test.before.begin()); - m.setRandomZerosOnOuter( - test.outerRows.begin(), test.outerRows.end(), - test.outerCols.begin(), test.outerCols.end(), - test.numNewNonZerosPerRow.begin(), test.numNewNonZerosPerRow.end(), - test.value, rng); + m.setRandomZerosOnOuter(test.outerRows.begin(), test.outerRows.end(), + test.outerCols.begin(), test.outerCols.end(), + test.numNewNonZerosPerRow.begin(), + test.numNewNonZerosPerRow.end(), test.value, rng); - vector actual(test.nrows*test.ncols); - m.toDense(actual.begin()); + vector actual(test.nrows * test.ncols); + m.toDense(actual.begin()); - for (UInt32 row = 0; row < test.nrows; row++) - { - Int32 
numSelectedZeros = 0; - Int32 numConvertedZeros = 0; + for (UInt32 row = 0; row < test.nrows; row++) { + Int32 numSelectedZeros = 0; + Int32 numConvertedZeros = 0; - auto rowInSelection = std::find(test.outerRows.begin(), - test.outerRows.end(), row); - const Int32 requestedNumNewZeros = + auto rowInSelection = + std::find(test.outerRows.begin(), test.outerRows.end(), row); + const Int32 requestedNumNewZeros = rowInSelection != test.outerRows.end() - ? test.numNewNonZerosPerRow[rowInSelection - test.outerRows.begin()] - : 0; + ? test.numNewNonZerosPerRow[rowInSelection - + test.outerRows.begin()] + : 0; - for (UInt32 col = 0; col < test.ncols; col++) - { - UInt32 i = row*test.ncols + col; - if (rowInSelection != test.outerRows.end() && - std::binary_search(test.outerCols.begin(), test.outerCols.end(), col)) - { - if (test.before[i] == 0) - { - numSelectedZeros++; + for (UInt32 col = 0; col < test.ncols; col++) { + UInt32 i = row * test.ncols + col; + if (rowInSelection != test.outerRows.end() && + std::binary_search(test.outerCols.begin(), test.outerCols.end(), + col)) { + if (test.before[i] == 0) { + numSelectedZeros++; - if (actual[i] != 0) - { - // Should replace zeros with the specified value. - EXPECT_EQ(test.value, actual[i]); - numConvertedZeros++; - } + if (actual[i] != 0) { + // Should replace zeros with the specified value. + EXPECT_EQ(test.value, actual[i]); + numConvertedZeros++; } } - else - { - // Every value not in the selection should be unchanged. - EXPECT_EQ(test.before[i], actual[i]); - } - - if (test.before[i] != 0) - { - // Every value that was nonzero should not have changed. - EXPECT_EQ(test.before[i], actual[i]); - } + } else { + // Every value not in the selection should be unchanged. + EXPECT_EQ(test.before[i], actual[i]); } - // Should replace numNewNonZerosPerRow, or all of them if there aren't - // that many. - EXPECT_EQ(std::min(requestedNumNewZeros, numSelectedZeros), - numConvertedZeros); + if (test.before[i] != 0) { + // Every value that was nonzero should not have changed. + EXPECT_EQ(test.before[i], actual[i]); + } } + + // Should replace numNewNonZerosPerRow, or all of them if there aren't + // that many. 
+ EXPECT_EQ(std::min(requestedNumNewZeros, numSelectedZeros), + numConvertedZeros); } } +} +struct IncreaseRowNonZeroCountsOnOuterToTest { + string name; + UInt32 nrows; + UInt32 ncols; + vector before; + vector outerRows; + vector outerCols; + int numDesiredNonzeros; + Real32 value; +}; - struct IncreaseRowNonZeroCountsOnOuterToTest - { - string name; - UInt32 nrows; - UInt32 ncols; - vector before; - vector outerRows; - vector outerCols; - int numDesiredNonzeros; - Real32 value; - }; - - TEST(SparseMatrixTest, increaseRowNonZeroCountsOnOuterTo) - { - Random rng; +TEST(SparseMatrixTest, increaseRowNonZeroCountsOnOuterTo) { + Random rng; - vector tests = { - { - "Test 1", - // Dimensions - 8, 6, - // Before - {1, 1, 0, 0, 1, 1, - 0, 0, 1, 1, 0, 0, - 0, 0, 1, 0, 0, 1, - 1, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, - 0, 0, 1, 1, 0, 1}, - // Selection - {0, 3, 4, 5, 6, 7}, - {0, 3, 4}, - // numDesiredNonzeros - 2, - // value - 42 - }, - { - "No selected rows", - // Dimensions - 8, 6, - // Before - {1, 1, 0, 0, 1, 1, - 0, 0, 1, 1, 0, 0, - 0, 0, 1, 0, 0, 1, - 1, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, - 0, 0, 1, 1, 0, 1}, - // Selection - {}, - {0, 3, 4}, - // numDesiredNonzeros - 2, - // value - 42 - }, - { - "No selected cols", - // Dimensions - 8, 6, - // Before - {1, 1, 0, 0, 1, 1, - 0, 0, 1, 1, 0, 0, - 0, 0, 1, 0, 0, 1, - 1, 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, - 0, 0, 1, 1, 0, 1}, - // Selection - {0, 3, 4, 5, 6, 7}, - {}, - // numDesiredNonzeros - 2, - // value - 42 - }, - { - "Try to catch unsigned integer bugs", - // Dimensions - 2, 4, - // Before - {1, 1, 0, 0, - 1, 1, 1, 0}, - // Selection - {0, 1}, - {0, 1, 2, 3}, - // numDesiredNonzeros - 2, - // value - 42 - } - }; + vector tests = { + {"Test 1", + // Dimensions + 8, + 6, + // Before + {1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1}, + // Selection + {0, 3, 4, 5, 6, 7}, + {0, 3, 4}, + // numDesiredNonzeros + 2, + // value + 42}, + {"No selected rows", + // Dimensions + 8, + 6, + // Before + {1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1}, + // Selection + {}, + {0, 3, 4}, + // numDesiredNonzeros + 2, + // value + 42}, + {"No selected cols", + // Dimensions + 8, + 6, + // Before + {1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1}, + // Selection + {0, 3, 4, 5, 6, 7}, + {}, + // numDesiredNonzeros + 2, + // value + 42}, + {"Try to catch unsigned integer bugs", + // Dimensions + 2, + 4, + // Before + {1, 1, 0, 0, 1, 1, 1, 0}, + // Selection + {0, 1}, + {0, 1, 2, 3}, + // numDesiredNonzeros + 2, + // value + 42}}; - for (const IncreaseRowNonZeroCountsOnOuterToTest& test : tests) - { - NTA_INFO << "Test: " << test.name; + for (const IncreaseRowNonZeroCountsOnOuterToTest &test : tests) { + NTA_INFO << "Test: " << test.name; - SparseMatrix<> m(test.nrows, test.ncols, test.before.begin()); + SparseMatrix<> m(test.nrows, test.ncols, test.before.begin()); - m.increaseRowNonZeroCountsOnOuterTo( - test.outerRows.begin(), test.outerRows.end(), - test.outerCols.begin(), test.outerCols.end(), - test.numDesiredNonzeros, test.value, rng); + m.increaseRowNonZeroCountsOnOuterTo( + test.outerRows.begin(), test.outerRows.end(), test.outerCols.begin(), + 
test.outerCols.end(), test.numDesiredNonzeros, test.value, rng); - vector actual(test.nrows*test.ncols); - m.toDense(actual.begin()); + vector actual(test.nrows * test.ncols); + m.toDense(actual.begin()); - for (UInt32 row = 0; row < test.nrows; row++) - { - int numSelectedZeros = 0; - int numConvertedZeros = 0; + for (UInt32 row = 0; row < test.nrows; row++) { + int numSelectedZeros = 0; + int numConvertedZeros = 0; - for (UInt32 col = 0; col < test.ncols; col++) - { - UInt32 i = row*test.ncols + col; - if (std::binary_search(test.outerRows.begin(), test.outerRows.end(), row) && - std::binary_search(test.outerCols.begin(), test.outerCols.end(), col)) - { - if (test.before[i] == 0) - { - numSelectedZeros++; + for (UInt32 col = 0; col < test.ncols; col++) { + UInt32 i = row * test.ncols + col; + if (std::binary_search(test.outerRows.begin(), test.outerRows.end(), + row) && + std::binary_search(test.outerCols.begin(), test.outerCols.end(), + col)) { + if (test.before[i] == 0) { + numSelectedZeros++; - if (actual[i] != 0) - { - // Should replace zeros with the specified value. - EXPECT_EQ(test.value, actual[i]); - numConvertedZeros++; - } + if (actual[i] != 0) { + // Should replace zeros with the specified value. + EXPECT_EQ(test.value, actual[i]); + numConvertedZeros++; } } - else - { - // Every value not in the selection should be unchanged. - EXPECT_EQ(test.before[i], actual[i]); - } - - if (test.before[i] != 0) - { - // Every value that was nonzero should not have changed. - EXPECT_EQ(test.before[i], actual[i]); - } + } else { + // Every value not in the selection should be unchanged. + EXPECT_EQ(test.before[i], actual[i]); } - int numSelectedNonZeros = test.outerCols.size() - numSelectedZeros; - int numDesiredToAdd = std::max(0, - test.numDesiredNonzeros - - numSelectedNonZeros); - int numToAdd = std::min(numSelectedZeros, numDesiredToAdd); - EXPECT_EQ(numToAdd, numConvertedZeros); + if (test.before[i] != 0) { + // Every value that was nonzero should not have changed. 
+ EXPECT_EQ(test.before[i], actual[i]); + } } + + int numSelectedNonZeros = test.outerCols.size() - numSelectedZeros; + int numDesiredToAdd = + std::max(0, test.numDesiredNonzeros - numSelectedNonZeros); + int numToAdd = std::min(numSelectedZeros, numDesiredToAdd); + EXPECT_EQ(numToAdd, numConvertedZeros); } } +} +struct ClipRowsBelowAndAboveTest { + string name; + UInt32 nrows; + UInt32 ncols; + vector before; + vector selectedRows; + Real32 lower; + Real32 upper; + vector expected; +}; - struct ClipRowsBelowAndAboveTest - { - string name; - UInt32 nrows; - UInt32 ncols; - vector before; - vector selectedRows; - Real32 lower; - Real32 upper; - vector expected; - }; - - TEST(SparseMatrixTest, clipRowsBelowAndAbove) - { - vector tests = { - { - "Test 1", - // Dimensions - 3, 5, - // Before - {-5, -4, 0.5, 4, 5, - -5, -4, 0.5, 4, 5, - -5, -4, 0.5, 4, 5}, - // Selection - {0, 2}, - // Boundaries - -4, 4, - // Expected - {-4, -4, 0.5, 4, 4, - -5, -4, 0.5, 4, 5, - -4, -4, 0.5, 4, 4} - }, - { - "Test 2", - // Dimensions - 3, 5, - // Before - {-5, -4, 0.5, 4, 5, - -5, -4, 0.5, 4, 5, - -5, -4, 0.5, 4, 5}, - // Selection - {1, 2}, - // Boundaries - 0, 3, - // Expected - {-5, -4, 0.5, 4, 5, - 0, 0, 0.5, 3, 3, - 0, 0, 0.5, 3, 3} - } - }; +TEST(SparseMatrixTest, clipRowsBelowAndAbove) { + vector tests = { + {"Test 1", + // Dimensions + 3, + 5, + // Before + {-5, -4, 0.5, 4, 5, -5, -4, 0.5, 4, 5, -5, -4, 0.5, 4, 5}, + // Selection + {0, 2}, + // Boundaries + -4, + 4, + // Expected + {-4, -4, 0.5, 4, 4, -5, -4, 0.5, 4, 5, -4, -4, 0.5, 4, 4}}, + {"Test 2", + // Dimensions + 3, + 5, + // Before + {-5, -4, 0.5, 4, 5, -5, -4, 0.5, 4, 5, -5, -4, 0.5, 4, 5}, + // Selection + {1, 2}, + // Boundaries + 0, + 3, + // Expected + {-5, -4, 0.5, 4, 5, 0, 0, 0.5, 3, 3, 0, 0, 0.5, 3, 3}}}; - for (const ClipRowsBelowAndAboveTest& test : tests) - { - NTA_INFO << "Test: " << test.name; - SparseMatrix<> m(test.nrows, test.ncols, test.before.begin()); + for (const ClipRowsBelowAndAboveTest &test : tests) { + NTA_INFO << "Test: " << test.name; + SparseMatrix<> m(test.nrows, test.ncols, test.before.begin()); - m.clipRowsBelowAndAbove( - test.selectedRows.begin(), test.selectedRows.end(), - test.lower, test.upper); + m.clipRowsBelowAndAbove(test.selectedRows.begin(), test.selectedRows.end(), + test.lower, test.upper); - vector actual(test.nrows * test.ncols); - m.toDense(actual.begin()); + vector actual(test.nrows * test.ncols); + m.toDense(actual.begin()); - EXPECT_EQ(test.expected, actual); - } + EXPECT_EQ(test.expected, actual); } +} } // end anonymous namespace - //#include // #include // #include "SparseMatrixUnitTest.hpp" @@ -846,3842 +648,3873 @@ namespace // namespace nupic { // -#define TEST_LOOP(M) \ - for (nrows = 0, ncols = M, zr = 15; \ - nrows < M; \ - nrows += M/10, ncols -= M/10, zr = ncols/10) \ -// -// #define M 64 -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_construction() -// { -// UInt ncols, nrows, zr; -// -// { // Deallocate an empty matrix -// SparseMat sm; -// Test("empty matrix 1", sm.isZero(), true); -// } -// -// { // Compact and deallocate an empty matrix -// SparseMat sm; -// Test("empty matrix 2", sm.isZero(), true); -// sm.compact(); -// Test("empty matrix 2 - compact", sm.isZero(), true); -// } -// -// { // De-compact and deallocate an empty matrix -// SparseMat sm; -// Test("empty matrix 3", sm.isZero(), true); -// sm.decompact(); -// Test("empty matrix 3 - decompact", sm.isZero(), true); -// } -// -// { // 
De-compact/compact and deallocate an empty matrix -// SparseMat sm; -// Test("empty matrix 4", sm.isZero(), true); -// sm.decompact(); -// Test("empty matrix 4 - decompact", sm.isZero(), true); -// sm.compact(); -// Test("empty matrix 4 - compact", sm.isZero(), true); -// } -// -// { // Compact and deallocate an empty matrix -// SparseMat sm(0, 0); -// Test("empty matrix 5", sm.isZero(), true); -// sm.compact(); -// Test("empty matrix 5 - compact", sm.isZero(), true); -// } -// -// { // De-compact and deallocate an empty matrix -// SparseMat sm(0, 0); -// Test("empty matrix 6", sm.isZero(), true); -// sm.decompact(); -// Test("empty matrix 6 - decompact", sm.isZero(), true); -// } -// -// { // De-compact/compact and deallocate an empty matrix -// SparseMat sm(0, 0); -// Test("empty matrix 7", sm.isZero(), true); -// sm.decompact(); -// Test("empty matrix 7 - decompact", sm.isZero(), true); -// sm.compact(); -// Test("empty matrix 7 - compact", sm.isZero(), true); -// } -// -// { // Rectangular shape, no zeros -// nrows = 3; ncols = 4; -// DenseMat dense(nrows, ncols, 0); -// SparseMat sm(nrows, ncols, dense.begin()); -// Compare(dense, sm, "ctor 1"); -// Test("isZero 1", sm.isZero(), false); -// sm.compact(); -// Compare(dense, sm, "ctor 1 - compact"); -// Test("isZero 1 - compact", sm.isZero(), false); -// } -// -// { // Rectangular shape, zeros -// nrows = 3; ncols = 4; -// DenseMat dense(nrows, ncols, 2); -// SparseMat sm(nrows, ncols, dense.begin()); -// Compare(dense, sm, "ctor 2"); -// Test("isZero 2", sm.isZero(), false); -// sm.compact(); -// Compare(dense, sm, "ctor 2 - compact"); -// Test("isZero 2 - compact", sm.isZero(), false); -// } -// -// { // Rectangular the other way, no zeros -// nrows = 4; ncols = 3; -// DenseMat dense(nrows, ncols, 0); -// SparseMat sm(nrows, ncols, dense.begin()); -// Compare(dense, sm, "ctor 3"); -// Test("isZero 3", sm.isZero(), false); -// sm.compact(); -// Compare(dense, sm, "ctor 3 - compact"); -// Test("isZero 3 - compact", sm.isZero(), false); -// } -// -// { // Rectangular the other way, zeros -// nrows = 6; ncols = 5; -// DenseMat dense(nrows, ncols, 2); -// SparseMat sm(nrows, ncols, dense.begin()); -// Compare(dense, sm, "ctor 4"); -// Test("isZero 4", sm.isZero(), false); -// sm.compact(); -// Compare(dense, sm, "ctor 4 - compact"); -// Test("isZero 4 - compact", sm.isZero(), false); -// } -// -// { // Empty rows in the middle and zeros -// nrows = 3; ncols = 4; -// DenseMat dense(nrows, ncols, 2, false, true); -// SparseMat sm(nrows, ncols, dense.begin()); -// Compare(dense, sm, "ctor 5"); -// Test("isZero 5", sm.isZero(), false); -// sm.compact(); -// Compare(dense, sm, "ctor 5 - compact"); -// Test("isZero 5 - compact", sm.isZero(), false); -// } -// -// { // Empty rows in the middle and zeros -// nrows = 7; ncols = 5; -// DenseMat dense(nrows, ncols, 2, false, true); -// SparseMat sm(nrows, ncols, dense.begin()); -// Compare(dense, sm, "ctor 6"); -// Test("isZero 6", sm.isZero(), false); -// sm.compact(); -// Compare(dense, sm, "ctor 6 - compact"); -// Test("isZero 6 - compact", sm.isZero(), false); -// } -// -// { // Small values, zeros and empty rows -// nrows = 7; ncols = 5; -// DenseMat dense(nrows, ncols, 2, true, true, rng_); -// SparseMat sm(nrows, ncols, dense.begin()); -// Compare(dense, sm, "ctor 7"); -// Test("isZero 7", sm.isZero(), false); -// sm.compact(); -// Compare(dense, sm, "ctor 7 - compact"); -// Test("isZero 7 - compact", sm.isZero(), false); -// } -// -// { // Small values, zeros and empty rows, other constructor 
-// nrows = 10; ncols = 10; -// DenseMat dense(nrows, ncols, 2, true, true, rng_); -// SparseMat sm(0, ncols); -// for (UInt i = 0; i < nrows; ++i) -// sm.addRow(dense.begin(i)); -// Compare(dense, sm, "ctor 8"); -// Test("isZero 8", sm.isZero(), false); -// sm.compact(); -// Compare(dense, sm, "ctor 8 - compact"); -// Test("isZero 8 - compact", sm.isZero(), false); -// } -// -// { // Zero first row -// nrows = 10; ncols = 10; -// DenseMat dense(nrows, ncols, 2, true, true, rng_); -// for (UInt i = 0; i < ncols; ++i) -// dense.at(0, i) = 0; -// SparseMat sm(0, ncols); -// for (UInt i = 0; i < nrows; ++i) -// sm.addRow(dense.begin(i)); -// Compare(dense, sm, "ctor 8B"); -// Test("isZero 8B", sm.isZero(), false); -// sm.compact(); -// Compare(dense, sm, "ctor 8B - compact"); -// Test("isZero 8B - compact", sm.isZero(), false); -// } -// -// { // Small values, zeros and empty rows, other constructor -// nrows = 10; ncols = 10; -// DenseMat dense(nrows, ncols, 2, true, true, rng_); -// SparseMat sm(0, ncols); -// for (UInt i = 0; i < nrows; ++i) -// sm.addRow(dense.begin(i)); -// Compare(dense, sm, "ctor 9"); -// Test("isZero 9", sm.isZero(), false); -// sm.compact(); -// Compare(dense, sm, "ctor 9 - compact"); -// Test("isZero 9 - compact", sm.isZero(), false); -// } -// -// { // Small values, zeros and empty rows, other constructor -// nrows = 10; ncols = 10; -// DenseMat dense(nrows, ncols, 2, true, true, rng_); -// SparseMat sm(0, ncols); -// for (UInt i = 0; i < nrows; ++i) -// sm.addRow(dense.begin(i)); -// Compare(dense, sm, "ctor 10"); -// Test("isZero 10", sm.isZero(), false); -// sm.compact(); -// Compare(dense, sm, "ctor 10 - compact"); -// Test("isZero 10 - compact", sm.isZero(), false); -// } -// -// { // Empty -// DenseMat dense(10, 10, 10); -// SparseMat sm(10, 10, dense.begin()); -// Compare(dense, sm, "ctor from empty dense - non compact"); -// Test("isZero 11", sm.isZero(), true); -// sm.compact(); -// Compare(dense, sm, "ctor from empty dense - compact"); -// Test("isZero 11 - compact", sm.isZero(), true); -// } -// -// { // Empty, other constructor -// DenseMat dense(10, 10, 10); -// SparseMat sm(0, 10); -// for (UInt i = 0; i < nrows; ++i) -// sm.addRow(dense.begin(i)); -// Compare(dense, sm, "ctor from empty dense - non compact"); -// Test("isZero 12", sm.isZero(), true); -// sm.compact(); -// Compare(dense, sm, "ctor from empty dense - compact"); -// Test("isZero 12 - compact", sm.isZero(), true); -// } -// -// { // Full -// DenseMat dense(10, 10, 0); -// SparseMat sm(10, 10, dense.begin()); -// Compare(dense, sm, "ctor from full dense - non compact"); -// Test("isZero 13", sm.isZero(), false); -// sm.compact(); -// Compare(dense, sm, "ctor from full dense - compact"); -// Test("isZero 13 - compact", sm.isZero(), false); -// } -// -// { // Various rectangular sizes -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sm(nrows, ncols, dense.begin()); -// -// sm.decompact(); -// -// { -// std::stringstream str; -// str << "ctor A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// Compare(dense, sm, str.str().c_str()); -// } -// -// sm.compact(); -// -// { -// std::stringstream str; -// str << "ctor B " << nrows << "X" << ncols << "/" << zr -// << " - compact"; -// Compare(dense, sm, str.str().c_str()); -// } -// } -// } -// -// /* -// try { -// SparseMatrix sme1(-1, 0); -// Test("SparseMatrix::SparseMatrix(Int, Int) exception 2", true, false); -// } catch (std::exception&) { -// Test("SparseMatrix::SparseMatrix(Int, Int) exception 
2", true, true); -// } -// -// try { -// SparseMatrix sme1(1, -1); -// Test("SparseMatrix::SparseMatrix(Int, Int) exception 3", true, false); -// } catch (std::exception&) { -// Test("SparseMatrix::SparseMatrix(Int, Int) exception 3", true, true); -// } -// -// try { -// SparseMatrix sme1(1, -1); -// Test("SparseMatrix::SparseMatrix(Int, Int) exception 4", true, false); -// } catch (std::exception&) { -// Test("SparseMatrix::SparseMatrix(Int, Int) exception 4", true, true); -// } -// -// std::vector mat(16, 0); -// -// try { -// SparseMatrix sme1(-1, 1, mat.begin()); -// Test("SparseMatrix::SparseMatrix(Int, Int, Iter) exception 1", true, false); -// } catch (std::exception&) { -// Test("SparseMatrix::SparseMatrix(Int, Iter) exception 1", true, true); -// } -// -// try { -// SparseMatrix sme1(1, -1, mat.begin()); -// Test("SparseMatrix::SparseMatrix(Int, Int, Iter) exception 2", true, false); -// } catch (std::exception&) { -// Test("SparseMatrix::SparseMatrix(Int, Iter) exception 2", true, true); -// } -// */ -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_copy() -// { -// { -// SparseMat sm, sm2; -// DenseMat dense, dense2; -// sm2.copy(sm); -// dense2.copy(dense); -// Compare(dense2, sm2, "SparseMatrix::copy - empty matrix"); -// } -// -// { -// SparseMat sm(0,0), sm2; -// DenseMat dense(0,0), dense2; -// sm2.copy(sm); -// dense2.copy(dense); -// Compare(dense2, sm2, "SparseMatrix::copy - empty matrix 2"); -// } -// -// { -// SparseMat sm(5, 4), sm2; -// DenseMat dense(5, 4), dense2; -// sm2.copy(sm); -// dense2.copy(dense); -// Compare(dense2, sm2, "SparseMatrix::copy - empty matrix 3"); -// } -// -// { -// DenseMat dense(5, 4, 2, false, false), dense2; -// SparseMat sm(5, 4, dense.begin()), sm2; -// sm2.copy(sm); -// dense2.copy(dense); -// Compare(dense2, sm2, "SparseMatrix::copy - 1"); -// } -// -// { -// DenseMat dense(5, 4, 2, false, true), dense2; -// SparseMat sm(5, 4, dense.begin()), sm2; -// sm2.copy(sm); -// dense2.copy(dense); -// Compare(dense2, sm2, "SparseMatrix::copy - 1"); -// } -// -// { -// DenseMat dense(5, 4, 2, true, false, rng_), dense2; -// SparseMat sm(5, 4, dense.begin()), sm2; -// sm2.copy(sm); -// dense2.copy(dense); -// Compare(dense2, sm2, "SparseMatrix::copy - 1"); -// } -// -// { -// DenseMat dense(5, 4, 2, true, true, rng_), dense2; -// SparseMat sm(5, 4, dense.begin()), sm2; -// sm2.copy(sm); -// dense2.copy(dense); -// Compare(dense2, sm2, "SparseMatrix::copy - 1"); -// } -// } -// -// //-------------------------------------------------------------------------------- -// /** -// * TC: Dense::toCSR matches SparseMatrix::toCSR (in stress test) -// * TC: Dense::fromCSR matches SparseMatrix::fromCSR (in stress test) -// * TC: reading in smaller matrix resizes the sparse matrix correctly -// * TC: reading in larger matrix resizes the sparse matrix correctly -// * TC: empty rows are stored correctly in stream -// * TC: empty rows are read correctly from stream -// * TC: empty matrix is written and read correctly -// * TC: values below epsilon are handled correctly in toCSR -// * TC: values below epsilon are handled correctly in fromCSR -// * TC: toCSR exception if bad stream -// * TC: fromCSR exception if bad stream -// * TC: fromCSR exception if bad 'csr' tag -// * TC: fromCSR exception if nrows < 0 -// * TC: fromCSR exception if ncols <= 0 -// * TC: fromCSR exception if nnz < 0 or nnz > nrows * ncols -// * TC: fromCSR exception if nnzr < 0 or nnzr > ncols -// * TC: 
fromCSR exception if j < 0 or j >= ncols -// * TC: stress test -// * TC: allocate_ exceptions -// * TC: addRow exceptions -// * TC: compact exceptions -// */ -// void SparseMatrixUnitTest::unit_test_csr() -// { -// UInt ncols, nrows, zr; -// -// { // Empty matrix -// // ... is written correctly -// SparseMat sm(3, 4); -// std::stringstream buf; -// sm.toCSR(buf); -// Test("SparseMatrix::toCSR empty", buf.str() == "sm_csr_1.5 12 3 4 0 0 0 0 ", true); -// -// // ... is read correctly -// SparseMat sm2; -// sm2.fromCSR(buf); -// std::stringstream buf2; -// buf2 << "csr 3 4 0 0 0 0"; -// DenseMat dense; -// dense.fromCSR(buf2); -// Compare(dense, sm2, "fromCSR/empty"); -// } -// -// { // Is resizing happening correctly? -// DenseMat dense(3, 4, 2); -// SparseMat sm(3, 4, dense.begin()); -// -// { // When reading in smaller size matrix? -// std::stringstream buf1, buf2; -// buf1 << "csr -1 3 3 9 3 0 1 1 2 2 3 3 0 11 1 12 2 13 3 0 21 1 22 2 23"; -// sm.fromCSR(buf1); -// buf2 << "csr 3 3 9 3 0 1 1 2 2 3 3 0 11 1 12 2 13 3 0 21 1 22 2 23"; -// dense.fromCSR(buf2); -// Compare(dense, sm, "fromCSR/redim/1"); -// } -// -// { // When reading in larger size matrix? -// std::stringstream buf1, buf2; -// buf1 << "csr -1 4 5 20 " -// "5 0 1 1 2 2 3 3 4 4 5 " -// "5 0 11 1 12 2 13 3 14 4 15 " -// "5 0 21 1 22 2 23 3 24 4 25 " -// "5 0 31 1 32 2 33 3 34 4 35"; -// sm.fromCSR(buf1); -// buf2 << "csr 4 5 20 " -// "5 0 1 1 2 2 3 3 4 4 5 " -// "5 0 11 1 12 2 13 3 14 4 15 " -// "5 0 21 1 22 2 23 3 24 4 25 " -// "5 0 31 1 32 2 33 3 34 4 35"; -// dense.fromCSR(buf2); -// Compare(dense, sm, "fromCSR/redim/2"); -// } -// -// { // Empty rows are read in correctly -// std::stringstream buf1, buf2; -// buf1 << "csr -1 4 5 15 " -// "5 0 1 1 2 2 3 3 4 4 5 " -// "0 " -// "5 0 21 1 22 2 23 3 24 4 25 " -// "5 0 31 1 32 2 33 3 34 4 35"; -// sm.fromCSR(buf1); -// buf2 << "csr 4 5 15 " -// "5 0 1 1 2 2 3 3 4 4 5 " -// "0 " -// "5 0 21 1 22 2 23 3 24 4 25 " -// "5 0 31 1 32 2 33 3 34 4 35"; -// dense.fromCSR(buf2); -// Compare(dense, sm, "fromCSR/redim/3"); -// } -// } -// -// { // Initialize fromDenseMat then again fromCSR -// DenseMat dense(3, 4, 2); -// SparseMat sm(3, 4, dense.begin()); -// std::stringstream buf1; -// buf1 << "csr -1 3 3 9 3 0 1 1 2 2 3 3 0 11 1 12 2 13 3 0 21 1 22 2 23"; -// sm.fromCSR(buf1); -// } -// -// { // ... and vice-versa, fromCSR, followed by fromDense -// DenseMat dense(3, 4, 2); -// SparseMat sm(3, 4); -// std::stringstream buf1; -// buf1 << "csr -1 3 3 9 3 0 1 1 2 2 3 3 0 11 1 12 2 13 3 0 21 1 22 2 23"; -// sm.fromCSR(buf1); -// sm.fromDense(3, 4, dense.begin()); -// } -// -// { // Values below epsilon -// -// // ... are written correctly (not written) -// nrows = 128; ncols = 256; -// UInt nnz = ncols/2; -// DenseMat dense(nrows, ncols, nnz, true, true, rng_); -// ITER_2(128, 256) -// dense.at(i,j) /= 1000; -// SparseMat sm(nrows, ncols, dense.begin()); -// std::stringstream buf; -// sm.toCSR(buf); -// std::string tag; buf >> tag; -// buf >> nrows >> ncols >> nnz; -// ITER_1(nrows) { -// buf >> nnz; -// UInt j; Real val; -// ITER_1(nnz) { -// buf >> j >> val; -// if (nupic::nearlyZero(val)) -// Test("SparseMatrix::toCSR/small values", true, false); -// } -// } -// -// // ... 
are read correctly -// std::stringstream buf1; -// buf1 << "csr -1 3 4 6 " -// << "2 0 " << nupic::Epsilon/2 << " 1 1 " -// << "2 0 " << nupic::Epsilon/2 << " 1 " << nupic::Epsilon/2 << " " -// << "2 0 1 1 1"; -// SparseMat sm2(4, 4); -// sm2.fromCSR(buf1); -// } -// -// { // stress test, matching against Dense::toCSR and Dense::fromCSR -// TEST_LOOP(M) { -// -// DenseMat dense3(nrows, ncols, zr); -// SparseMat sm3(nrows, ncols, dense3.begin()); -// -// std::stringstream buf; -// sm3.toCSR(buf); -// sm3.fromCSR(buf); -// -// { -// std::stringstream str; -// str << "toCSR/fromCSR A " << nrows << "X" << ncols << "/" << zr; -// Compare(dense3, sm3, str.str().c_str()); -// } -// -// SparseMat sm4(nrows, ncols); -// std::stringstream buf1; -// sm3.toCSR(buf1); -// sm4.fromCSR(buf1); -// -// { -// std::stringstream str; -// str << "toCSR/fromCSR B " << nrows << "X" << ncols << "/" << zr; -// Compare(dense3, sm4, str.str().c_str()); -// } -// -// sm4.decompact(); -// std::stringstream buf2; -// sm3.toCSR(buf2); -// sm4.fromCSR(buf2); -// -// { -// std::stringstream str; -// str << "toCSR/fromCSR C " << nrows << "X" << ncols << "/" << zr; -// Compare(dense3, sm4, str.str().c_str()); -// } -// -// std::stringstream buf3; -// sm4.toCSR(buf3); -// sm4.fromCSR(buf3); -// -// { -// std::stringstream str; -// str << "toCSR/fromCSR D " << nrows << "X" << ncols << "/" << zr; -// Compare(dense3, sm4, str.str().c_str()); -// } -// } -// } -// -// /* -// // Exceptions -// SparseMatrix sme1(1, 1); -// -// { -// stringstream s1; -// s1 << "ijv"; -// try { -// sme1.fromCSR(s1); -// Test("SparseMatrix::fromCSR() exception 1", true, false); -// } catch (std::runtime_error&) { -// Test("SparseMatrix::fromCSR() exception 1", true, true); -// } -// } -// -// { -// stringstream s1; -// s1 << "csr -1 -1"; -// try { -// sme1.fromCSR(s1); -// Test("SparseMatrix::fromCSR() exception 2", true, false); -// } catch (std::runtime_error&) { -// Test("SparseMatrix::fromCSR() exception 2", true, true); -// } -// } -// -// { -// stringstream s1; -// s1 << "csr -1 1 -1"; -// try { -// sme1.fromCSR(s1); -// Test("SparseMatrix::fromCSR() exception 3", true, false); -// } catch (std::runtime_error&) { -// Test("SparseMatrix::fromCSR() exception 3", true, true); -// } -// } -// -// { -// stringstream s1; -// s1 << "csr -1 1 0"; -// try { -// sme1.fromCSR(s1); -// Test("SparseMatrix::fromCSR() exception 4", true, false); -// } catch (std::runtime_error&) { -// Test("SparseMatrix::fromCSR() exception 4", true, true); -// } -// } -// -// { -// stringstream s1; -// s1 << "csr -1 4 3 -1"; -// try { -// sme1.fromCSR(s1); -// Test("SparseMatrix::fromCSR() exception 5", true, false); -// } catch (std::runtime_error&) { -// Test("SparseMatrix::fromCSR() exception 5", true, true); -// } -// } -// -// { -// stringstream s1; -// s1 << "csr -1 4 3 15"; -// try { -// sme1.fromCSR(s1); -// Test("SparseMatrix::fromCSR() exception 6", true, false); -// } catch (std::runtime_error&) { -// Test("SparseMatrix::fromCSR() exception 6", true, true); -// } -// } -// -// { -// stringstream s1; -// s1 << "csr -1 2 3 1 5"; -// try { -// sme1.fromCSR(s1); -// Test("SparseMatrix::fromCSR() exception 7", true, false); -// } catch (std::runtime_error&) { -// Test("SparseMatrix::fromCSR() exception 7", true, true); -// } -// } -// -// { -// stringstream s1; -// s1 << "csr -1 2 3 1 0 1 -1"; -// try { -// sme1.fromCSR(s1); -// Test("SparseMatrix::fromCSR() exception 8", true, false); -// } catch (std::runtime_error&) { -// Test("SparseMatrix::fromCSR() exception 8", 
true, true); -// } -// } -// -// { -// stringstream s1; -// s1 << "csr -1 2 3 1 0 1 4"; -// try { -// sme1.fromCSR(s1); -// Test("SparseMatrix::fromCSR() exception 9", true, false); -// } catch (std::runtime_error&) { -// Test("SparseMatrix::fromCSR() exception 9", true, true); -// } -// } -// */ -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_dense() -// { -// UInt ncols = 5, nrows = 7, zr = 2; -// -// DenseMat dense(nrows, ncols, zr); -// DenseMat dense2(nrows+1, ncols+1, zr+1); -// -// { // fromDense -// SparseMat sparse(nrows, ncols); -// sparse.fromDense(nrows, ncols, dense.begin()); -// Compare(dense, sparse, "fromDenseMat 1"); -// } -// -// { // fromDense -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// sparse.fromDense(nrows+1, ncols+1, dense2.begin()); -// Compare(dense2, sparse, "fromDenseMat 2"); -// -// sparse.decompact(); -// sparse.fromDense(nrows, ncols, dense.begin()); -// Compare(dense, sparse, "fromDenseMat 3"); -// -// sparse.compact(); -// sparse.fromDense(nrows+1, ncols+1, dense2.begin()); -// Compare(dense2, sparse, "fromDenseMat 4"); -// -// std::vector mat((nrows+1)*(ncols+1), 0); -// -// sparse.toDense(mat.begin()); -// sparse.fromDense(nrows+1, ncols+1, mat.begin()); -// Compare(dense2, sparse, "toDenseMat 1"); -// } -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense3(nrows, ncols, zr); -// SparseMat sm3(nrows, ncols, dense3.begin()); -// std::vector mat3(nrows*ncols, 0); -// -// sm3.toDense(mat3.begin()); -// sm3.fromDense(nrows, ncols, mat3.begin()); -// -// { -// std::stringstream str; -// str << "toDense/fromDenseMat A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// Compare(dense3, sm3, str.str().c_str()); -// } -// -// sm3.compact(); -// -// { -// std::stringstream str; -// str << "toDense/fromDenseMat B " << nrows << "X" << ncols << "/" << zr -// << " - compact"; -// Compare(dense3, sm3, str.str().c_str()); -// } -// } -// } -// -// { // What happens if dense matrix is full? -// nrows = ncols = 10; zr = 0; -// DenseMat dense(nrows, ncols, zr); -// SparseMat sm(nrows, ncols, dense.begin()); -// std::vector mat3(nrows*ncols, 0); -// -// sm.toDense(mat3.begin()); -// sm.fromDense(nrows, ncols, mat3.begin()); -// -// Compare(dense, sm, "toDense/fromDenseMat from dense"); -// } -// -// { // What happens if dense matrix is empty? -// nrows = ncols = 10; zr = 10; -// DenseMat dense(nrows, ncols, zr); -// SparseMat sm(nrows, ncols, dense.begin()); -// std::vector mat3(nrows*ncols, 0); -// -// sm.toDense(mat3.begin()); -// sm.fromDense(nrows, ncols, mat3.begin()); -// -// Compare(dense, sm, "toDense/fromDenseMat from dense"); -// } -// -// { // What happens if there are empty rows? -// nrows = ncols = 10; zr = 2; -// DenseMat dense(nrows, ncols, zr); -// for (UInt i = 0; i < ncols; ++i) -// dense.at(2,i) = dense.at(4,i) = dense.at(9,i) = 0; -// -// SparseMat sm(nrows, ncols, dense.begin()); -// std::vector mat3(nrows*ncols, 0); -// -// sm.toDense(mat3.begin()); -// sm.fromDense(nrows, ncols, mat3.begin()); -// -// Compare(dense, sm, "toDense/fromDenseMat from dense"); -// } -// -// { // Is resizing happening correctly? 
-// DenseMat dense(3, 4, 2); -// SparseMat sm(3, 4, dense.begin()); -// -// DenseMat dense2(5, 5, 4); -// sm.fromDense(5, 5, dense2.begin()); -// Compare(dense2, sm, "fromDense/redim/1"); -// -// DenseMat dense3(2, 2, 2); -// sm.fromDense(2, 2, dense3.begin()); -// Compare(dense3, sm, "fromDense/redim/2"); -// -// DenseMat dense4(10, 10, 8); -// sm.fromDense(10, 10, dense4.begin()); -// Compare(dense4, sm, "fromDense/redim/3"); -// } -// -// /* -// // Exceptions -// SparseMatrix sme1(1, 1); -// -// try { -// sme1.fromDense(-1, 0, dense.begin()); -// Test("SparseMatrix::fromDense() exception 1", true, false); -// } catch (std::exception&) { -// Test("SparseMatrix::fromDense() exception 1", true, true); -// } -// -// try { -// sme1.fromDense(1, -1, dense.begin()); -// Test("SparseMatrix::fromDense() exception 3", true, false); -// } catch (std::exception&) { -// Test("SparseMatrix::fromDense() exception 3", true, true); -// } -// */ -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_compact() -// { -// UInt ncols, nrows, zr; -// ncols = 5; -// nrows = 7; -// zr = 2; -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sm4(nrows, ncols, dense.begin()); -// -// sm4.decompact(); -// Compare(dense, sm4, "decompact 1"); -// -// sm4.compact(); -// Compare(dense, sm4, "compact 1"); -// -// sm4.decompact(); -// Compare(dense, sm4, "decompact 2"); -// -// sm4.compact(); -// Compare(dense, sm4, "compact 2"); -// -// sm4.decompact(); -// sm4.decompact(); -// Compare(dense, sm4, "decompact twice"); -// -// sm4.compact(); -// sm4.compact(); -// Compare(dense, sm4, "compact twice"); -// -// SparseMat sm5(nrows, ncols, dense.begin()); -// DenseMat dense2(nrows+1, ncols+1, zr+1); -// sm5.fromDense(nrows+1, ncols+1, dense2.begin()); -// sm5.compact(); -// Compare(dense2, sm5, "compact 3"); -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense3(nrows, ncols, zr); -// SparseMat sm3(nrows, ncols, dense3.begin()); -// -// sm3.decompact(); -// -// { -// std::stringstream str; -// str << "compact/decompact A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// Compare(dense3, sm3, str.str().c_str()); -// } -// -// sm3.compact(); -// -// { -// std::stringstream str; -// str << "compact/decompact B " << nrows << "X" << ncols << "/" << zr -// << " - compact"; -// Compare(dense3, sm3, str.str().c_str()); -// } -// } -// } -// -// { -// nrows = ncols = 10; zr = 0; -// DenseMat dense(nrows, ncols, zr); -// SparseMat sm(nrows, ncols, dense.begin()); -// std::vector mat3(nrows*ncols, 0); -// -// sm.decompact(); -// Compare(dense, sm, "decompact on dense"); -// -// sm.compact(); -// Compare(dense, sm, "compact on dense"); -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_threshold() -// { -// UInt nrows = 7, ncols = 5, zr = 2; -// -// if (0) { // Visual tests -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// cout << "Before thresholding at 50" << endl; -// cout << sparse << endl; -// sparse.threshold(50); -// cout << "After:" << endl; -// cout << sparse << endl; -// } -// -// { -// SparseMat sm; -// DenseMat dense; -// sm.threshold(Real(1.0)); -// dense.threshold(Real(1.0)); -// Compare(dense, sm, "threshold 0A"); -// } -// -// { -// SparseMat sm(0, 0); -// DenseMat dense(0, 0); -// sm.threshold(Real(1.0)); -// dense.threshold(Real(1.0)); -// Compare(dense, sm, "threshold 0B"); 
-// } -// -// { -// SparseMat sm(nrows, ncols); -// DenseMat dense(nrows, ncols); -// sm.threshold(Real(1.0)); -// dense.threshold(Real(1.0)); -// Compare(dense, sm, "threshold 0C"); -// } -// -// { -// DenseMat dense(nrows, ncols, zr); -// for (UInt i = 0; i < nrows; ++i) -// for (UInt j = 0; j < ncols; ++j) -// dense.at(i,j) = rng_->getReal64(); -// -// SparseMat sm4c(nrows, ncols, dense.begin()); -// -// dense.threshold(Real(.8)); -// sm4c.threshold(Real(.8)); -// Compare(dense, sm4c, "threshold 1"); -// -// sm4c.decompact(); -// sm4c.compact(); -// dense.threshold(Real(.7)); -// sm4c.threshold(Real(.7)); -// Compare(dense, sm4c, "threshold 2"); -// } -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sm(nrows, ncols, dense.begin()); -// -// sm.decompact(); -// dense.threshold(Real(.8)); -// sm.threshold(Real(.8)); -// -// { -// std::stringstream str; -// str << "threshold A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// Compare(dense, sm, str.str().c_str()); -// } -// -// sm.compact(); -// dense.threshold(Real(.7)); -// sm.threshold(Real(.7)); -// -// { -// std::stringstream str; -// str << "threshold B " << nrows << "X" << ncols << "/" << zr -// << " - compact"; -// Compare(dense, sm, str.str().c_str()); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_getRow() -// { -// UInt nrows = 5, ncols = 7, zr = 3, i = 0, k = 0; -// -// if (0) { // Tests for visual inspection -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// cout << sparse << endl; -// for (i = 0; i != nrows; ++i) { -// std::vector dense_row(ncols); -// sparse.getRowToDense(i, dense_row.begin()); -// cout << dense_row << endl; -// } -// } -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sm(nrows, ncols, dense.begin()); -// -// for (i = 0; i < nrows; ++i) { -// -// std::stringstream str; -// str << "getRowToSparseMat A " << nrows << "X" << ncols -// << "/" << zr << " " << i; -// -// std::vector ind; std::vector nz; -// sm.getRowToSparse(i, back_inserter(ind), back_inserter(nz)); -// -// std::vector d(ncols, 0); -// for (k = 0; k < ind.size(); ++k) -// d[ind[k]] = nz[k]; -// -// CompareVectors(ncols, d.begin(), dense.begin(i), str.str().c_str()); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_getCol() -// { -// UInt nrows = 5, ncols = 7, zr = 3, i = 0, k = 0; -// -// if (0) { // Tests for visual inspection -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// cout << sparse << endl; -// for (i = 0; i != ncols; ++i) { -// std::vector dense_col(nrows); -// sparse.getColToDense(i, dense_col.begin()); -// cout << dense_col << endl; -// } -// } -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sm(nrows, ncols, dense.begin()); -// -// for (i = 0; i < nrows; ++i) { -// -// std::stringstream str; -// str << "getRowToSparseMat A " << nrows << "X" << ncols -// << "/" << zr << " " << i; -// -// std::vector ind; std::vector nz; -// sm.getRowToSparse(i, back_inserter(ind), back_inserter(nz)); -// -// std::vector d(ncols, 0); -// for (k = 0; k < ind.size(); ++k) -// d[ind[k]] = nz[k]; -// -// CompareVectors(ncols, d.begin(), dense.begin(i), str.str().c_str()); -// } -// } -// } -// } -// -// 
//-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_transpose() -// { -// UInt ncols, nrows, zr; -// -// { -// nrows = 8; ncols = 4; zr = ncols - 2; -// Dense dense(nrows, ncols, zr, false, true); -// Dense dense2(ncols, nrows); -// SparseMatrix sm(nrows, ncols, dense.begin()); -// SparseMatrix sm2(ncols, nrows); -// dense.transpose(dense2); -// sm.transpose(sm2); -// Compare(dense2, sm2, "transpose 1"); -// } -// -// { -// for (nrows = 1, zr = 15; nrows < 256; nrows += 25, zr = ncols/10) { -// -// ncols = nrows; -// -// DenseMat dense(nrows, ncols, zr); -// DenseMat dense2(ncols, nrows, zr); -// SparseMat sm(nrows, ncols, dense.begin()); -// SparseMat sm2(ncols, nrows, dense2.begin()); -// -// { -// std::stringstream str; -// str << "transpose A " << nrows << "X" << ncols << "/" << zr; -// -// dense.transpose(dense2); -// sm.transpose(sm2); -// -// Compare(dense2, sm2, str.str().c_str()); -// } -// -// { -// std::stringstream str; -// str << "transpose B " << nrows << "X" << ncols << "/" << zr; -// -// dense2.transpose(dense); -// sm2.transpose(sm); -// -// Compare(dense, sm, str.str().c_str()); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_addRowCol() -// { -// // addRow, compact -// UInt nrows = 5, ncols = 7, zr = 3; -// -// if (0) { // Visual, keep -// -// { // Add dense row -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// for (UInt i = 0; i != nrows; ++i) { -// std::vector nz; -// dense.getRowToDense(i, back_inserter(nz)); -// sparse.addRow(nz.begin()); -// } -// -// cout << sparse << endl; -// } -// -// { // Add sparse row -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// for (UInt i = 0; i != nrows; ++i) { -// std::vector ind; -// std::vector nz; -// dense.getRowToSparse(i, back_inserter(ind), back_inserter(nz)); -// sparse.addRow(ind.begin(), ind.end(), nz.begin()); -// } -// -// cout << sparse << endl; -// } -// -// { // Add dense col -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// for (UInt i = 0; i != ncols; ++i) { -// std::vector nz; -// dense.getColToDense(i, back_inserter(nz)); -// cout << "Adding: " << nz << endl; -// sparse.addCol(nz.begin()); -// } -// -// cout << "After adding columns:" << endl; -// cout << sparse << endl; -// } -// -// { // Add sparse col -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// for (UInt i = 0; i != ncols; ++i) { -// std::vector ind; -// std::vector nz; -// dense.getColToSparse(i, back_inserter(ind), back_inserter(nz)); -// sparse.addCol(ind.begin(), ind.end(), nz.begin()); -// } -// -// cout << sparse << endl; -// } -// } -// -// /* -// TEST_LOOP(M) { -// -// { // Add dense row -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// for (UInt i = 0; i != nrows; ++i) { -// std::vector nz; -// dense.getRowToDense(i, back_inserter(nz)); -// sparse.addRow(nz.begin()); -// } -// -// { -// std::stringstream str; -// str << "addRow A " << nrows << "X" << ncols << "/" << zr; -// Compare(dense, sparse, str.str().c_str()); -// } -// } -// -// { // Add sparse row -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// for (UInt i = 0; i != nrows; ++i) { -// std::vector ind; -// std::vector nz; -// 
dense.getRowToSparse(i, back_inserter(ind), back_inserter(nz)); -// sparse.addRow(ind.begin(), ind.end(), nz.begin()); -// } -// -// { -// std::stringstream str; -// str << "addRow B " << nrows << "X" << ncols << "/" << zr; -// Compare(dense, sparse, str.str().c_str()); -// } -// } -// -// { // Add dense col -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// for (UInt i = 0; i != ncols; ++i) { -// std::vector nz; -// dense.getColToDense(i, back_inserter(nz)); -// sparse.addCol(nz.begin()); -// } -// -// { -// std::stringstream str; -// str << "addCol A " << nrows << "X" << ncols << "/" << zr; -// Compare(dense, sparse, str.str().c_str()); -// } -// } -// -// { // Add sparse col -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// for (UInt i = 0; i != ncols; ++i) { -// std::vector ind; -// std::vector nz; -// dense.getColToSparse(i, back_inserter(ind), back_inserter(nz)); -// sparse.addCol(ind.begin(), ind.end(), nz.begin()); -// } -// -// { -// std::stringstream str; -// str << "addCol B " << nrows << "X" << ncols << "/" << zr; -// Compare(dense, sparse, str.str().c_str()); -// } -// } -// } -// */ -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(0, ncols); -// -// for (UInt i = 0; i < nrows; ++i) { -// sparse.addRow(dense.begin(i)); -// sparse.compact(); -// } -// -// sparse.decompact(); -// -// { -// std::stringstream str; -// str << "addRow C " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// Compare(dense, sparse, str.str().c_str()); -// } -// -// sparse.compact(); -// -// { -// std::stringstream str; -// str << "addRow D " << nrows << "X" << ncols << "/" << zr -// << " - compact"; -// Compare(dense, sparse, str.str().c_str()); -// } -// } -// } -// -// { // Test that negative numbers are handled correctly -// nrows = 4; ncols = 8; zr = 2; -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(0, ncols); -// for (UInt i = 0; i < nrows; ++i) -// for (UInt j = 0; j < ncols; ++j) -// dense.at(i,j) *= -1; -// -// for (UInt i = 0; i < nrows; ++i) { -// sparse.addRow(dense.begin(i)); -// sparse.compact(); -// } -// -// { -// std::stringstream str; -// str << "addRow w/ negative numbers A " -// << nrows << "X" << ncols << "/" << zr -// << " - compact"; -// Compare(dense, sparse, str.str().c_str()); -// } -// -// sparse.decompact(); -// -// { -// std::stringstream str; -// str << "addRow w/ negative numbers A " -// << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// Compare(dense, sparse, str.str().c_str()); -// } -// } -// -// // These tests compiled conditionally, because they are -// // based on asserts rather than checks -// -//#ifdef NTA_ASSERTIONS_ON -// -// /* -// { // "Dirty" rows tests -// UInt ncols = 4; -// SparseMat sm(0, ncols); -// std::vector > dirty_col(ncols); -// -// // Duplicate zeros (assertion) -// for (UInt i = 0; i < ncols; ++i) -// dirty_col[i] = make_pair(0, 0); -// try { -// sm.addRow(dirty_col.begin(), dirty_col.end()); -// Test("SparseMatrix dirty cols 1", true, false); -// } catch (std::exception&) { -// Test("SparseMatrix dirty cols 1", true, true); -// } -// -// // Out of order indices (assertion) -// dirty_col[0].first = 3; -// try { -// sm.addRow(dirty_col.begin(), dirty_col.end()); -// Test("SparseMatrix dirty cols 2", true, false); -// } catch (std::exception&) { -// Test("SparseMatrix dirty cols 2", true, true); -// } -// -// // Indices out of range (assertion) -// dirty_col[0].first = 9; 
-// try { -// sm.addRow(dirty_col.begin(), dirty_col.end()); -// Test("SparseMatrix dirty cols 3", true, false); -// } catch (std::exception&) { -// Test("SparseMatrix dirty cols 3", true, true); -// } -// -// // Passed in zero (assertion) -// dirty_col[0].second = 0; -// try { -// sm.addRow(dirty_col.begin(), dirty_col.end()); -// Test("SparseMatrix dirty cols 4", true, false); -// } catch (std::exception&) { -// Test("SparseMatrix dirty cols 4", true, true); -// } -// } -// */ -//#endif -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_resize() -// { -// SparseMat sm; -// DenseMat dense; -// -// sm.resize(3,3); dense.resize(3,3); -// ITER_2(3,3) { -// sm.setNonZero(i,j,Real(i*3+j+1)); -// dense.at(i,j) = Real(i*3+j+1); -// } -// Compare(dense, sm, "SparseMatrix::resize() 1"); -// -// sm.resize(1,1); -// dense.resize(1,1); -// Compare(dense, sm, "SparseMatrix::resize() 2"); -// -// sm.resize(3,3); -// dense.resize(3,3); -// Compare(dense, sm, "SparseMatrix::resize() 3"); -// -// sm.resize(3,4); -// dense.resize(3,4); -// ITER_1(3) { -// sm.setNonZero(i,3,1); -// dense.at(i,3) = 1; -// } -// Compare(dense, sm, "SparseMatrix::resize() 4"); -// -// sm.resize(4,4); -// dense.resize(4,4); -// ITER_1(4) { -// sm.setNonZero(3,i,2); -// dense.at(3,i) = 2; -// } -// Compare(dense, sm, "SparseMatrix::resize() 5"); -// -// sm.resize(5,5); -// dense.resize(5,5); -// ITER_1(5) { -// sm.setNonZero(4,i,3); -// sm.setNonZero(i,4,4); -// dense.at(4,i) = 3; -// dense.at(i,4) = 4; -// } -// Compare(dense, sm, "SparseMatrix::resize() 6"); -// -// sm.resize(7,5); -// dense.resize(7,5); -// ITER_1(5) { -// sm.setNonZero(6,i,5); -// dense.at(6,i) = 5; -// } -// Compare(dense, sm, "SparseMatrix::resize() 7"); -// -// sm.resize(7, 7); -// dense.resize(7,7); -// ITER_1(7) { -// sm.setNonZero(i,6,6); -// dense.at(i,6) = 6; -// } -// Compare(dense, sm, "SparseMatrix::resize() 8"); -// -// // Stress test to see the interaction with deleteRows and deleteCols -// for (UInt i = 0; i < 20; ++i) { -// sm.resize(rng_->getUInt32(256), rng_->getUInt32(256)); -// vector del_r; -// for (UInt ii = 0; ii < sm.nRows()/4; ++ii) -// del_r.push_back(2*ii); -// sm.deleteRows(del_r.begin(), del_r.end()); -// vector del_c; -// for (UInt ii = 0; ii < sm.nCols()/4; ++ii) -// del_c.push_back(2*ii); -// sm.deleteCols(del_c.begin(), del_c.end()); -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_deleteRows() -// { -// { // Empty matrix -// UInt nrows = 3, ncols = 3; -// -// { // Empty matrix, empty del -// SparseMat sm; -// vector del; -// sm.deleteRows(del.begin(), del.end()); -// Test("SparseMatrix::deleteRows() 1", sm.nRows(), UInt(0)); -// } -// -// { // Empty matrix, empty del -// SparseMat sm(0,0); -// vector del; -// sm.deleteRows(del.begin(), del.end()); -// Test("SparseMatrix::deleteRows() 2", sm.nRows(), UInt(0)); -// } -// -// { // Empty matrix, empty del -// SparseMat sm(nrows, ncols); -// vector del; -// sm.deleteRows(del.begin(), del.end()); -// Test("SparseMatrix::deleteRows() 3", sm.nRows(), UInt(nrows)); -// } -// -// { // Empty matrix, 1 del -// SparseMat sm(nrows, ncols); -// vector del(1); del[0] = 0; -// sm.deleteRows(del.begin(), del.end()); -// Test("SparseMatrix::deleteRows() 4", sm.nRows(), UInt(2)); -// } -// -// { // Empty matrix, many dels -// SparseMat sm(nrows, ncols); -// vector del(2); del[0] = 0; del[1] = 2; -// 
sm.deleteRows(del.begin(), del.end()); -// Test("SparseMatrix::deleteRows() 5", sm.nRows(), UInt(1)); -// } -// } // End empty matrix -// -// { // matrix with only 1 row -// { // 1 row, 1 del -// SparseMat sm(0, 3); -// vector del(1); del[0] = 0; -// std::vector v(3); v[0] = 1.5; v[1] = 2.5; v[2] = 3.5; -// -// sm.addRow(v.begin()); -// sm.deleteRows(del.begin(), del.end()); -// Test("SparseMatrix::deleteRows() 1 row A", sm.nRows(), UInt(0)); -// -// // Test that it is harmless to delete an empty matrix -// sm.deleteRows(del.begin(), del.end()); -// Test("SparseMatrix::deleteRows() 1 row B", sm.nRows(), UInt(0)); -// -// sm.addRow(v.begin()); -// sm.deleteRows(del.begin(), del.end()); -// Test("SparseMatrix::deleteRows() 1 row C", sm.nRows(), UInt(0)); -// -// // Again, test that it is harmless to delete an empty matrix -// sm.deleteRows(del.begin(), del.end()); -// Test("SparseMatrix::deleteRows() 1 row D", sm.nRows(), UInt(0)); -// } -// -// { // PLG-68: was failing when adding again because -// // deleteRows was not updating nrows_max_ properly -// SparseMatrix tam; -// vector x(4), del(1, 0); -// x[0] = .5; x[1] = .75; x[2] = 1.0; x[3] = 1.25; -// -// tam.resize(1, 4); -// tam.elementRowApply(0, std::plus(), x.begin()); -// tam.deleteRows(del.begin(), del.end()); -// -// tam.resize(1, 4); -// tam.elementRowApply(0, std::plus(), x.begin()); -// } -// } -// -// { -// UInt nrows, ncols, zr; -// -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// -// { // Empty del -// SparseMat sm(nrows, ncols, dense.begin()); -// vector del; -// sm.deleteRows(del.begin(), del.end()); -// Test("SparseMatrix::deleteRows() 6A", sm.nRows(), nrows); -// } -// -// { // Rows of all zeros 1 -// if (nrows > 2) { -// DenseMat dense2(nrows, ncols, zr); -// ITER_1(nrows) { -// if (i % 2 == 0) { -// for (UInt j = 0; j < ncols; ++j) -// dense2.at(i,j) = 0; -// } -// } -// SparseMat sm(nrows, ncols, dense2.begin()); -// vector del; -// for (UInt i = 2; i < nrows-2; i += 2) -// del.push_back(i); -// sm.deleteRows(del.begin(), del.end()); -// dense2.deleteRows(del.begin(), del.end()); -// Compare(dense2, sm, "SparseMatrix::deleteRows() 6B"); -// } -// } -// -// { // Rows of all zeros 2 -// if (nrows > 2) { -// DenseMat dense2(nrows, ncols, zr); -// ITER_1(nrows) { -// if (i % 2 == 0) { -// for (UInt j = 0; j < ncols; ++j) -// dense2.at(i,j) = 0; -// } -// } -// SparseMat sm(nrows, ncols, dense2.begin()); -// vector del; -// for (UInt i = 1; i < nrows-2; i += 2) -// del.push_back(i); -// sm.deleteRows(del.begin(), del.end()); -// dense2.deleteRows(del.begin(), del.end()); -// Compare(dense2, sm, "SparseMatrix::deleteRows() 6C"); -// } -// } -// -// { // Many dels contiguous -// if (nrows > 2) { -// SparseMat sm(nrows, ncols, dense.begin()); -// DenseMat dense2(nrows, ncols, zr); -// vector del; -// for (UInt i = 2; i < nrows-2; ++i) -// del.push_back(i); -// sm.deleteRows(del.begin(), del.end()); -// dense2.deleteRows(del.begin(), del.end()); -// Compare(dense2, sm, "SparseMatrix::deleteRows() 6D"); -// } -// } -// -// { // Make sure we stop at the end of the dels! 
-// if (nrows > 2) { -// SparseMat sm(nrows, ncols, dense.begin()); -// DenseMat dense2(nrows, ncols, zr); -// UInt* del = new UInt[nrows-1]; -// for (UInt i = 0; i < nrows-1; ++i) -// del[i] = i + 1; -// sm.deleteRows(del, del + nrows-2); -// dense2.deleteRows(del, del + nrows-2); -// Compare(dense2, sm, "SparseMatrix::deleteRows() 6E"); -// delete [] del; -// } -// } -// -// { // Many dels discontiguous -// SparseMat sm(nrows, ncols, dense.begin()); -// DenseMat dense2(nrows, ncols, zr); -// vector del; -// for (UInt i = 0; i < nrows; i += 2) -// del.push_back(i); -// sm.deleteRows(del.begin(), del.end()); -// dense2.deleteRows(del.begin(), del.end()); -// Compare(dense2, sm, "SparseMatrix::deleteRows() 7"); -// } -// -// { // All rows -// SparseMat sm(nrows, ncols, dense.begin()); -// vector del; -// for (UInt i = 0; i < nrows; ++i) -// del.push_back(i); -// sm.deleteRows(del.begin(), del.end()); -// Test("SparseMatrix::deleteRows() 8", sm.nRows(), UInt(0)); -// } -// -// /* -// { // More than all rows => exception in assert mode -// SparseMat sm(nrows, ncols, dense.begin()); -// vector del; -// for (UInt i = 0; i < 2*nrows; ++i) -// del.push_back(i); -// sm.deleteRows(del.begin(), del.end()); -// Test("SparseMatrix::deleteRows() 9", sm.nRows(), UInt(0)); -// } -// */ -// -// { // Several dels in a row till empty -// SparseMat sm(nrows, ncols, dense.begin()); -// for (UInt i = 0; i < nrows; ++i) { -// vector del(1); del[0] = 0; -// sm.deleteRows(del.begin(), del.end()); -// Test("SparseMatrix::deleteRows() 10", sm.nRows(), UInt(nrows-i-1)); -// } -// } -// -// { // deleteRows and re-resize it -// SparseMat sm(nrows, ncols, dense.begin()); -// vector del(1); del[0] = nrows-1; -// sm.deleteRows(del.begin(), del.end()); -// sm.resize(nrows, ncols); -// Test("SparseMatrix::deleteRows() 11", sm.nRows(), UInt(nrows)); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_deleteCols() -// { -// { // Empty matrix -// UInt nrows = 3, ncols = 3; -// -// { // Empty matrix, empty del -// SparseMat sm(nrows, ncols); -// vector del; -// sm.deleteCols(del.begin(), del.end()); -// Test("SparseMatrix::deleteCols() 1", sm.nCols(), UInt(3)); -// } -// -// { // Empty matrix, 1 del -// SparseMat sm(nrows, ncols); -// vector del(1); del[0] = 0; -// sm.deleteCols(del.begin(), del.end()); -// Test("SparseMatrix::deleteCols() 2", sm.nCols(), UInt(2)); -// } -// -// { // Empty matrix, many dels -// SparseMat sm(nrows, ncols); -// vector del(2); del[0] = 0; del[1] = 2; -// sm.deleteCols(del.begin(), del.end()); -// Test("SparseMatrix::deleteCols() 3", sm.nCols(), UInt(1)); -// } -// } // End empty matrix -// -// { // For visual inspection -// UInt nrows = 3, ncols = 5; -// DenseMat dense(nrows, ncols, 2); -// SparseMat sm(nrows, ncols, dense.begin()); -// //cout << sm << endl; -// vector del; del.push_back(0); -// sm.deleteCols(del.begin(), del.end()); -// //cout << sm << endl; -// sm.deleteCols(del.begin(), del.end()); -// //cout << sm << endl; -// } -// -// { // deleteCols on matrix of all-zeros -// SparseMat sm(7, 3); -// vector row(3, 0); -// for (UInt i = 0; i < 7; ++i) -// sm.addRow(row.begin()); -// //cout << sm << endl << endl; -// vector del(1, 0); -// sm.deleteCols(del.begin(), del.end()); -// //cout << sm << endl; -// } -// -// { -// UInt nrows, ncols, zr; -// -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// -// { // Empty del -// SparseMat sm(nrows, ncols, dense.begin()); -// vector 
del; -// sm.deleteCols(del.begin(), del.end()); -// Test("SparseMatrix::deleteCols() 4", sm.nCols(), ncols); -// } -// -// { // Many dels contiguous -// SparseMat sm(nrows, ncols, dense.begin()); -// DenseMat dense2(nrows, ncols, zr); -// vector del; -// if (ncols > 2) { -// for (UInt i = 2; i < ncols-2; ++i) -// del.push_back(i); -// sm.deleteCols(del.begin(), del.end()); -// dense2.deleteCols(del.begin(), del.end()); -// Compare(dense2, sm, "SparseMatrix::deleteCols() 6"); -// } -// } -// -// { // Many dels discontiguous -// SparseMat sm(nrows, ncols, dense.begin()); -// DenseMat dense2(nrows, ncols, zr); -// vector del; -// for (UInt i = 0; i < ncols; i += 2) -// del.push_back(i); -// sm.deleteCols(del.begin(), del.end()); -// dense2.deleteCols(del.begin(), del.end()); -// Compare(dense2, sm, "SparseMatrix::deleteCols() 7"); -// } -// -// { // All rows -// SparseMat sm(nrows, ncols, dense.begin()); -// vector del; -// for (UInt i = 0; i < ncols; ++i) -// del.push_back(i); -// sm.deleteCols(del.begin(), del.end()); -// Test("SparseMatrix::deleteCols() 8", sm.nCols(), UInt(0)); -// } -// -// { // More than all rows => exception in assert mode -// /* -// SparseMat sm(nrows, ncols, dense.begin()); -// vector del; -// for (UInt i = 0; i < 2*ncols; ++i) -// del.push_back(i); -// sm.deleteCols(del.begin(), del.end()); -// Test("SparseMatrix::deleteCols() 9", sm.nCols(), UInt(0)); -// */ -// } -// -// { // Several dels in a row till empty -// SparseMat sm(nrows, ncols, dense.begin()); -// for (UInt i = 0; i < ncols; ++i) { -// vector del(1); del[0] = 0; -// sm.deleteCols(del.begin(), del.end()); -// Test("SparseMatrix::deleteCols() 10", sm.nCols(), UInt(ncols-i-1)); -// } -// } -// -// { // deleteCols and re-resize it -// SparseMat sm(nrows, ncols, dense.begin()); -// vector del(1); del[0] = ncols-1; -// sm.deleteCols(del.begin(), del.end()); -// sm.resize(nrows, ncols); -// Test("SparseMatrix::deleteCols() 11", sm.nCols(), UInt(ncols)); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_set() -// { -// UInt nrows, ncols, nnzr; -// -// if (0) { // Visual tests -// -// // setZero -// nrows = 5; ncols = 7; nnzr = 3; -// DenseMat dense(nrows, ncols, nnzr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// cout << "Initial matrix" << endl; -// cout << sparse << endl; -// -// cout << endl << "Setting all elements to zero one by one" << endl; -// ITER_2(nrows, ncols) -// sparse.setZero(i, j); -// cout << "After:" << endl << sparse << endl; -// -// // setNonZero -// cout << endl << "Setting all elements one by one to:" << endl; -// cout << dense << endl; -// ITER_2(nrows, ncols) { -// sparse.setNonZero(i, j, dense.at(i,j)+1); -// dense.at(i,j) = dense.at(i,j) + 1; -// } -// cout << "After:" << endl << sparse << endl; -// -// // set -// cout << endl << "Setting all elements" << endl; -// ITER_2(nrows, ncols) { -// Real val = (Real) ((i+j) % 5); -// sparse.set(i, j, val); -// dense.at(i,j) = val; -// } -// cout << "After:" << endl << sparse << endl; -// cout << "Should be:" << endl << dense << endl; -// -// } // End visual tests -// -// // Automated tests for set(i,j,val), which exercises both -// // setNonZero and setToZero -// for (nrows = 1; nrows < 64; nrows += 3) -// for (ncols = 1; ncols < 64; ncols += 3) -// { -// SparseMat sm(nrows, ncols); -// DenseMat dense(nrows, ncols); -// -// ITER_2(nrows, ncols) { -// Real val = Real((i*ncols+j+1)%5); -// sm.set(i, j, val); -// dense.at(i, j) 
= val; -// } -// bool correct = true; -// ITER_2(nrows, ncols) { -// Real val = Real((i*ncols+j+1)%5); -// if (sm.get(i, j) != val) -// correct = false; -// } -// Test("SparseMatrix set/get 1", correct, true); -// -// ITER_1(nrows) { -// dense.at(i, 0) = Real(i+1); -// sm.set(i, 0, Real(i+1)); -// } -// Compare(dense, sm, "SparseMatrix set/get 2"); -// -// ITER_1(ncols) { -// dense.at(0, i) = Real(i+1); -// sm.set(0, i, Real(i+1)); -// } -// Compare(dense, sm, "SparseMatrix set/get 3"); -// -// sm.set(nrows-1, ncols-1, 1); -// dense.at(nrows-1, ncols-1) = 1; -// Compare(dense, sm, "SparseMatrix set/get 4"); -// sm.set(nrows-1, ncols-1, 2); -// dense.at(nrows-1, ncols-1) = 2; -// Compare(dense, sm, "SparseMatrix set/get 5"); -// -// for (UInt k = 0; k != 20; ++k) { -// UInt i = rng_->getUInt32(nrows); -// UInt j = rng_->getUInt32(ncols); -// Real val = Real(1+rng_->getUInt32()); -// sm.set(i, j, Real(val)); -// Test("SparseMatrix set/get 7", sm.get(i, j), val); -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_setRowColToZero() -// { -// UInt nrows, ncols, zr; -// -// if (0) { // Visual tests -// -// // setRowToZero -// nrows = 5; ncols = 7; zr = 3; -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// cout << "Initial matrix" << endl; -// cout << sparse << endl; -// -// cout << endl << "Setting all rows to zero" << endl; -// for (UInt i = 0; i != nrows; ++i) { -// cout << "isRowZero(" << i << ")= " -// << (sparse.isRowZero(i) ? "YES" : "NO") -// << endl; -// sparse.setRowToZero(i); -// cout << "Zeroing row " << i << ":" << endl -// << sparse << endl; -// cout << "isRowZero(" << i << ")= " -// << (sparse.isRowZero(i) ? "YES" : "NO") -// << endl; -// cout << endl; -// } -// -// // setColToZero -// cout << endl << "Setting all columns to zero - 1" << endl; -// ITER_2(nrows, ncols) -// sparse.set(i, j, dense.at(i,j)); -// cout << "Initially: " << endl << sparse << endl; -// for (UInt j = 0; j != ncols; ++j) { -// cout << "isColZero(" << j << ")= " -// << (sparse.isColZero(j) ? "YES" : "NO") -// << endl; -// sparse.setColToZero(j); -// cout << "Zeroing column " << j << ":" << endl -// << sparse << endl; -// cout << "isColZero(" << j << ")= " -// << (sparse.isColZero(j) ? "YES" : "NO") -// << endl; -// cout << endl; -// } -// -// // Again, with a dense matrix, so we can see what happens -// // to the first and last columns -// cout << endl << "Setting all columns to zero - 2" << endl; -// ITER_2(nrows, ncols) -// sparse.set(i,j,(Real)(i+j)); -// cout << "Initially: " << endl << sparse << endl; -// for (UInt j = 0; j != ncols; ++j) { -// cout << "isColZero(" << j << ")= " -// << (sparse.isColZero(j) ? "YES" : "NO") -// << endl; -// sparse.setColToZero(j); -// cout << "Zeroing column " << j << ":" << endl -// << sparse << endl; -// cout << "isColZero(" << j << ")= " -// << (sparse.isColZero(j) ? 
"YES" : "NO") -// << endl; -// cout << endl; -// } -// } // End visual tests -// -// // Automated tests -// for (nrows = 0; nrows < 16; nrows += 3) -// for (ncols = 0; ncols < 16; ncols += 3) -// for (zr = 0; zr < 16; zr += 3) -// { -// { // compact - remove rows -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// for (UInt i = 0; i != nrows; ++i) { -// sparse.setRowToZero(i); -// dense.setRowToZero(i); -// Compare(dense, sparse, "SparseMatrix setRowToZero 1"); -// } -// } -// -// { // decompact - remove rows -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// sparse.decompact(); -// -// for (UInt i = 0; i != nrows; ++i) { -// sparse.setRowToZero(i); -// dense.setRowToZero(i); -// Compare(dense, sparse, "SparseMatrix setRowToZero 2"); -// } -// } -// -// { // compact - remove columns -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// for (UInt j = 0; j != ncols; ++j) { -// sparse.setColToZero(j); -// dense.setColToZero(j); -// Compare(dense, sparse, "SparseMatrix setColToZero 1"); -// } -// } -// -// { // decompact - remove columns -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// for (UInt j = 0; j != ncols; ++j) { -// sparse.setColToZero(j); -// dense.setColToZero(j); -// Compare(dense, sparse, "SparseMatrix setColToZero 2"); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_vecMaxProd() -// { -// UInt ncols, nrows, zr, i; -// ncols = 5; -// nrows = 7; -// zr = 2; -// -// DenseMat dense(nrows, ncols, zr); -// -// std::vector x(ncols), y(nrows, 0), yref(nrows, 0); -// for (i = 0; i < ncols; ++i) -// x[i] = Real(i); -// -// dense.vecMaxProd(x.begin(), yref.begin()); -// -// SparseMat smnc(nrows, ncols, dense.begin()); -// smnc.decompact(); -// smnc.vecMaxProd(x.begin(), y.begin()); -// CompareVectors(nrows, y.begin(), yref.begin(), "vecMaxProd non compact 1"); -// -// smnc.compact(); -// std::fill(y.begin(), y.end(), Real(0)); -// smnc.vecMaxProd(x.begin(), y.begin()); -// CompareVectors(nrows, y.begin(), yref.begin(), "vecMaxProd compact 1"); -// -// SparseMat smc(nrows, ncols, dense.begin()); -// std::fill(y.begin(), y.end(), Real(0)); -// smc.vecMaxProd(x.begin(), y.begin()); -// CompareVectors(nrows, y.begin(), yref.begin(), "vecMaxProd compact 2"); -// -// { -// TEST_LOOP(M) { -// -// Dense dense2(nrows, ncols, zr); -// SparseMatrix sm2(nrows, ncols, dense2.begin()); -// -// std::vector x2(ncols, 0), yref2(nrows, 0), y2(nrows, 0); -// for (i = 0; i < ncols; ++i) -// x2[i] = Real(i); -// -// sm2.decompact(); -// dense2.vecMaxProd(x2.begin(), yref2.begin()); -// sm2.vecMaxProd(x2.begin(), y2.begin()); -// { -// std::stringstream str; -// str << "vecMaxProd A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// CompareVectors(nrows, y2.begin(), yref2.begin(), str.str().c_str()); -// } -// -// sm2.compact(); -// std::fill(y2.begin(), y2.end(), Real(0)); -// sm2.vecMaxProd(x2.begin(), y2.begin()); -// { -// std::stringstream str; -// str << "vecMaxProd B " << nrows << "X" << ncols << "/" << zr -// << " - compact"; -// CompareVectors(nrows, y2.begin(), yref2.begin(), str.str().c_str()); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_vecProd() -// { -// UInt ncols = 5, nrows = 7, zr = 2; 
-// -// DenseMat dense(nrows, ncols, zr); -// -// std::vector x(ncols), y(nrows, 0), yref(nrows, 0); -// for (UInt i = 0; i < ncols; ++i) -// x[i] = Real(i); -// -// dense.rightVecProd(x.begin(), yref.begin()); -// -// SparseMat smnc(nrows, ncols, dense.begin()); -// smnc.decompact(); -// smnc.rightVecProd(x.begin(), y.begin()); -// CompareVectors(nrows, y.begin(), yref.begin(), "rightVecProd non compact 1"); -// -// smnc.compact(); -// std::fill(y.begin(), y.end(), Real(0)); -// smnc.rightVecProd(x.begin(), y.begin()); -// CompareVectors(nrows, y.begin(), yref.begin(), "rightVecProd compact 1"); -// -// SparseMat smc(nrows, ncols, dense.begin()); -// std::fill(y.begin(), y.end(), Real(0)); -// smc.rightVecProd(x.begin(), y.begin()); -// CompareVectors(nrows, y.begin(), yref.begin(), "rightVecProd compact 2"); -// -// { -// TEST_LOOP(M) { -// -// Dense dense2(nrows, ncols, zr); -// SparseMatrix sm2(nrows, ncols, dense2.begin()); -// -// std::vector x2(ncols, 0), yref2(nrows, 0), y2(nrows, 0); -// for (UInt i = 0; i < ncols; ++i) -// x2[i] = Real(i); -// -// sm2.decompact(); -// dense2.rightVecProd(x2.begin(), yref2.begin()); -// sm2.rightVecProd(x2.begin(), y2.begin()); -// { -// std::stringstream str; -// str << "rightVecProd A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// CompareVectors(nrows, y2.begin(), yref2.begin(), str.str().c_str()); -// } -// -// sm2.compact(); -// std::fill(y2.begin(), y2.end(), Real(0)); -// sm2.rightVecProd(x2.begin(), y2.begin()); -// { -// std::stringstream str; -// str << "rightVecProd B " << nrows << "X" << ncols << "/" << zr -// << " - compact"; -// CompareVectors(nrows, y2.begin(), yref2.begin(), str.str().c_str()); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_axby() -// { -// UInt ncols, nrows, zr, i; -// ncols = 5; -// nrows = 7; -// zr = 2; -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sm4c(nrows, ncols, dense.begin()); -// -// std::vector x(ncols, 0); -// for (i = 0; i < ncols; ++i) -// x[i] = Real(20*i + 1); -// -// { // compact, b = 0 -// dense.axby(3, .5, 0, x.begin()); -// sm4c.axby(3, .5, 0, x.begin()); -// Compare(dense, sm4c, "axby, b = 0"); -// } -// -// { // compact, a = 0, with reallocation -// dense.axby(2, 0, .5, x.begin()); -// sm4c.axby(2, 0, .5, x.begin()); -// Compare(dense, sm4c, "axby, a = 0 /1"); -// } -// -// { // compact, a = 0, without reallocation -// dense.axby(3, 0, .5, x.begin()); -// sm4c.axby(3, 0, .5, x.begin()); -// Compare(dense, sm4c, "axby, a = 0 /2"); -// } -// -// { // compact, a != 0, b != 0, without reallocation -// dense.axby(3, .5, .5, x.begin()); -// sm4c.axby(3, .5, .5, x.begin()); -// Compare(dense, sm4c, "axby, a, b != 0 /1"); -// } -// -// { // compact, a != 0, b != 0, with reallocation -// dense.axby(4, .5, .5, x.begin()); -// sm4c.axby(4, .5, .5, x.begin()); -// Compare(dense, sm4c, "axby, a, b != 0 /2"); -// } -// -// { -// TEST_LOOP(M) { -// -// Dense dense2(nrows, ncols, zr); -// SparseMatrix sm2(nrows, ncols, dense2.begin()); -// -// std::vector x2(ncols, 0), yref2(nrows, 0), y2(nrows, 0); -// for (i = 0; i < ncols; ++i) -// x2[i] = Real(i); -// -// for (i = 0; i < nrows; i += 5) { -// -// dense2.axby(i, (Real).6, (Real).4, x2.begin()); -// sm2.axby(i, (Real).6, (Real).4, x2.begin()); -// { -// std::stringstream str; -// str << "axby " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// Compare(dense2, sm2, str.str().c_str()); -// } -// } -// } 
-// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_axby_3() -// { -// UInt ncols, nrows, zr, i; -// ncols = 5; -// nrows = 7; -// zr = 2; -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sm4c(nrows, ncols, dense.begin()); -// -// std::vector x(ncols, 0); -// for (i = 0; i < ncols; ++i) -// x[i] = i % 2 == 0 ? Real(20*i + 1) : Real(0); -// -// { // compact, b = 0 -// dense.axby(.5, 0, x.begin()); -// sm4c.axby(.5, 0, x.begin()); -// Compare(dense, sm4c, "axby, b = 0"); -// } -// -// { // compact, a = 0, with reallocation -// dense.axby(0, .5, x.begin()); -// sm4c.axby(0, .5, x.begin()); -// Compare(dense, sm4c, "axby, a = 0 /1"); -// } -// -// { // compact, a = 0, without reallocation -// dense.axby(0, .5, x.begin()); -// sm4c.axby(0, .5, x.begin()); -// Compare(dense, sm4c, "axby, a = 0 /2"); -// } -// -// { // compact, a != 0, b != 0, without reallocation -// dense.axby(.5, .5, x.begin()); -// sm4c.axby(.5, .5, x.begin()); -// Compare(dense, sm4c, "axby, a, b != 0 /1"); -// } -// -// { // compact, a != 0, b != 0, with reallocation -// dense.axby(.5, .5, x.begin()); -// sm4c.axby(.5, .5, x.begin()); -// Compare(dense, sm4c, "axby, a, b != 0 /2"); -// } -// -// { -// TEST_LOOP(M) { -// -// Dense dense2(nrows, ncols, zr); -// SparseMatrix sm2(nrows, ncols, dense2.begin()); -// -// std::vector x2(ncols, 0), yref2(nrows, 0), y2(nrows, 0); -// for (i = 0; i < ncols; ++i) -// x2[i] = i % 2 == 0 ? Real(i) : Real(0); -// -// dense2.axby((Real).6, (Real).4, x2.begin()); -// sm2.axby((Real).6, (Real).4, x2.begin()); -// { -// std::stringstream str; -// str << "axby " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// Compare(dense2, sm2, str.str().c_str()); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_rowMax() -// { -// UInt ncols, nrows, zr, i; -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense2(nrows, ncols, zr); -// SparseMat sm2(nrows, ncols, dense2.begin()); -// -// std::vector x2(ncols, 0), yref2(nrows, 0), y2(nrows, 0); -// for (i = 0; i < ncols; ++i) -// x2[i] = Real(i); -// -// sm2.decompact(); -// dense2.threshold(Real(1./nrows)); -// dense2.xMaxAtNonZero(x2.begin(), y2.begin()); -// sm2.threshold(Real(1./nrows)); -// sm2.vecMaxAtNZ(x2.begin(), yref2.begin()); -// -// { -// std::stringstream str; -// str << "xMaxAtNonZero A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// CompareVectors(nrows, y2.begin(), yref2.begin(), str.str().c_str()); -// } -// -// sm2.compact(); -// dense2.xMaxAtNonZero(x2.begin(), y2.begin()); -// sm2.vecMaxAtNZ(x2.begin(), yref2.begin()); -// { -// std::stringstream str; -// str << "xMaxAtNonZero B " << nrows << "X" << ncols << "/" << zr -// << " - compact"; -// CompareVectors(nrows, y2.begin(), yref2.begin(), str.str().c_str()); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_maxima() -// { -// UInt ncols, nrows, zr; -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// std::vector > -// rowMaxDense(nrows), rowMaxSparse(nrows), -// colMaxDense(ncols), colMaxSparse(ncols); -// -// dense.rowMax(rowMaxDense.begin()); -// dense.colMax(colMaxDense.begin()); -// sparse.rowMax(rowMaxSparse.begin()); -// sparse.colMax(colMaxSparse.begin()); -// -// 
{ -// std::stringstream str; -// str << "rowMax " << nrows << "X" << ncols << "/" << zr; -// Compare(rowMaxDense, rowMaxSparse, str.str().c_str()); -// } -// -// { -// std::stringstream str; -// str << "colMax " << nrows << "X" << ncols << "/" << zr; -// Compare(colMaxDense, colMaxSparse, str.str().c_str()); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_normalize() -// { -// UInt nrows = 7, ncols = 5, zr = 2; -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// if (0) { // Visual tests -// -// cout << "Before normalizing rows: " << endl; -// cout << sparse << endl; -// dense.normalizeRows(); -// sparse.normalizeRows(); -// cout << "After normalizing rows: " << endl; -// cout << "Sparse: " << endl << sparse << endl; -// cout << "Dense: " << endl << dense << endl; -// -// cout << "Before normalizing columns: " << endl; -// cout << sparse << endl; -// dense.normalizeCols(); -// sparse.normalizeCols(); -// cout << "After normalizing columns: " << endl; -// cout << "Sparse: " << endl << sparse << endl; -// cout << "Dense: " << endl << dense << endl; -// -// } -// -// if (1) // Automated tests -// { -// TEST_LOOP(M) { -// -// Dense dense2(nrows, ncols, zr); -// SparseMatrix sm2(nrows, ncols, dense2.begin()); -// -// dense2.threshold(double(1./nrows)); -// dense2.normalizeRows(true); -// sm2.decompact(); -// sm2.threshold(double(1./nrows)); -// sm2.normalizeRows(true); -// -// { -// std::stringstream str; -// str << "normalizeRows A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// Compare(dense2, sm2, str.str().c_str()); -// } -// -// dense2.normalizeRows(true); -// sm2.compact(); -// sm2.normalizeRows(true); -// -// { -// std::stringstream str; -// str << "normalizeRows B " << nrows << "X" << ncols << "/" << zr -// << " - compact"; -// Compare(dense2, sm2, str.str().c_str()); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_rowProd() -// { -// UInt ncols, nrows, zr, i; -// -// { -// TEST_LOOP(M) { -// -// Dense dense2(nrows, ncols, zr); -// SparseMatrix sm2(nrows, ncols, dense2.begin()); -// -// std::vector x2(ncols, 0), yref2(nrows, 0), y2(nrows, 0); -// for (i = 0; i < ncols; ++i) -// x2[i] = double(i)/double(ncols); -// -// sm2.decompact(); -// dense2.threshold(1./double(nrows)); -// dense2.rowProd(x2.begin(), y2.begin()); -// sm2.threshold(1./double(nrows)); -// sm2.rightVecProdAtNZ(x2.begin(), yref2.begin()); -// -// { -// std::stringstream str; -// str << "rowProd A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// CompareVectors(nrows, y2.begin(), yref2.begin(), str.str().c_str()); -// } -// -// sm2.compact(); -// dense2.rowProd(x2.begin(), y2.begin()); -// sm2.rightVecProdAtNZ(x2.begin(), yref2.begin()); -// { -// std::stringstream str; -// str << "rowProd B " << nrows << "X" << ncols << "/" << zr -// << " - compact"; -// CompareVectors(nrows, y2.begin(), yref2.begin(), str.str().c_str()); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_lerp() -// { -// UInt ncols, nrows, zr; -// nrows = 5; ncols = 7; zr = 4; -// -// { -// DenseMat dense(nrows, ncols, zr); -// DenseMat denseB(nrows, ncols, zr); -// for (UInt i = 0; i < nrows; ++i) -// for (UInt j = 0; j < ncols; ++j) -// 
denseB.at(i,j) += 2; -// -// SparseMat sm(nrows, ncols, dense.begin()); -// SparseMat smB(nrows, ncols, denseB.begin()); -// -// Real a, b; -// a = b = 1; -// -// dense.lerp(a, b, denseB); -// sm.lerp(a, b, smB); -// -// std::stringstream str; -// str << "lerp " << nrows << "X" << ncols << "/" << zr -// << " " << a << " " << b; -// Compare(dense, sm, str.str().c_str()); -// } -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// DenseMat denseB(nrows, ncols, zr); -// for (UInt i = 0; i < nrows; ++i) -// for (UInt j = 0; j < ncols; ++j) -// denseB.at(i,j) += 2; -// -// SparseMat sm(nrows, ncols, dense.begin()); -// SparseMat smB(nrows, ncols, denseB.begin()); -// -// for (Real a = -2; a < 2; a += 1) { -// for (Real b = -2; b < 2; b += 1) { -// dense.lerp(a, b, denseB); -// sm.lerp(a, b, smB); -// std::stringstream str; -// str << "lerp " << nrows << "X" << ncols << "/" << zr -// << " " << a << " " << b; -// Compare(dense, sm, str.str().c_str()); -// } -// } -// } -// } -// -//#ifdef NTA_ASSERTIONS_ON -// nrows = 5; ncols = 7; zr = 4; -// // Exceptions -// { -// DenseMat dense(nrows, ncols, zr); -// DenseMat denseB(nrows+1, ncols, zr); -// SparseMat sm(nrows, ncols, dense.begin()); -// SparseMat smB(nrows+1, ncols, denseB.begin()); -// -// try { -// sm.lerp(1, 1, smB); -// Test("lerp exception 1", 0, 1); -// } catch (std::runtime_error&) { -// Test("lerp exception 1", 1, 1); -// } -// } -// -// { -// DenseMat dense(nrows, ncols, zr); -// DenseMat denseB(nrows, ncols+1, zr); -// SparseMat sm(nrows, ncols, dense.begin()); -// SparseMat smB(nrows, ncols+1, denseB.begin()); -// -// try { -// sm.lerp(1, 1, smB); -// Test("lerp exception 2", 0, 1); -// } catch (std::runtime_error&) { -// Test("lerp exception 2", 1, 1); -// } -// } -//#endif -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_small_values() -// { -// UInt nrows, ncols, zr; -// -// { -// nrows = 200; ncols = 100; zr = ncols - 64; -// DenseMat dense(nrows, ncols, zr, true, true, rng_); -// SparseMat sm(nrows, ncols, dense.begin()); -// DenseMat A(nrows, ncols); -// -// sm.toDense(A.begin()); -// sm.fromDense(nrows, ncols, A.begin()); -// Compare(dense, sm, "to/from Dense, small values"); -// } -// -// { -// nrows = 200; ncols = 100; zr = ncols - 64; -// DenseMat dense(nrows, ncols, zr, true, true, rng_); -// SparseMat sm(nrows, ncols, dense.begin()); -// std::stringstream str1; -// sm.toCSR(str1); -// sm.fromCSR(str1); -// Compare(dense, sm, "to/from CSR, small values"); -// } -// -// { -// nrows = 200; ncols = 100; zr = ncols - 64; -// DenseMat dense(nrows, ncols, zr, true, true, rng_); -// SparseMat sm(nrows, ncols, dense.begin()); -// sm.compact(); -// Compare(dense, sm, "compact, small values"); -// } -// -// { -// nrows = 200; ncols = 100; zr = ncols - 64; -// Dense dense(nrows, ncols, zr, true, true, rng_); -// SparseMatrix sm(nrows, ncols, dense.begin()); -// sm.threshold(4 * nupic::Epsilon); -// dense.threshold(4 * nupic::Epsilon); -// Compare(dense, sm, "threshold, small values 1"); -// sm.threshold(2 * nupic::Epsilon); -// dense.threshold(2 * nupic::Epsilon); -// Compare(dense, sm, "threshold, small values 2"); -// } -// -// { -// nrows = 200; ncols = 100; zr = ncols - 64; -// Dense dense(nrows, ncols, zr, true, true, rng_); -// SparseMatrix sm(nrows, ncols, dense.begin()); -// Compare(dense, sm, "addRow, small values"); -// } -// -// { -// nrows = 8; ncols = 4; zr = ncols - 2; -// Dense dense(nrows, ncols, zr, true, 
true, rng_); -// Dense dense2(ncols, nrows); -// SparseMatrix sm(nrows, ncols, dense.begin()); -// SparseMatrix sm2(ncols, nrows); -// dense.transpose(dense2); -// sm.transpose(sm2); -// Compare(dense2, sm2, "transpose, small values"); -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_accumulate() -// { -// UInt nrows = 7, ncols = 5, zr = 2; -// -// if (0) { // Visual tests -// -// Dense dense(nrows, ncols, zr); -// SparseMatrix sparse(nrows, ncols, dense.begin()); -// -// std::vector row_sums(nrows), col_sums(ncols); -// -// cout << sparse << endl; -// -// sparse.accumulateAllRowsNZ(row_sums.begin(), std::plus()); -// sparse.accumulateAllColsNZ(col_sums.begin(), std::plus()); -// -// cout << "Row sums = " << row_sums << endl; -// cout << "Col sums = " << col_sums << endl; -// } -// -// /* -// TEST_LOOP(M) { -// -// Dense denseA(nrows, ncols, zr); -// SparseMatrix smA(nrows, ncols, denseA.begin()); -// -// for (UInt r = 0; r < nrows; r += 5) { -// -// { -// double r1 = denseA.accumulate(r, multiplies(), 1); -// double r2 = smA.accumulateRowNZ(r, multiplies(), 1); -// std::stringstream str; -// str << "accumulateRowNZ * " << nrows << "X" << ncols << "/" << zr; -// Test(str.str().c_str(), r1, r2); -// } -// -// { -// double r1 = denseA.accumulate(r, multiplies(), 1); -// double r2 = smA.accumulate(r, multiplies(), 1); -// std::stringstream str; -// str << "accumulate * " << nrows << "X" << ncols << "/" << zr; -// Test(str.str().c_str(), r1, r2); -// } -// -// { -// double r1 = denseA.accumulate(r, plus()); -// double r2 = smA.accumulateRowNZ(r, plus()); -// std::stringstream str; -// str << "accumulateRowNZ + " << nrows << "X" << ncols << "/" << zr; -// Test(str.str().c_str(), r1, r2); -// } -// -// { -// double r1 = denseA.accumulate(r, plus()); -// double r2 = smA.accumulate(r, plus()); -// std::stringstream str; -// str << "accumulate + " << nrows << "X" << ncols << "/" << zr; -// Test(str.str().c_str(), r1, r2); -// } -// -// { -// double r1 = denseA.accumulate(r, nupic::Max, 0); -// double r2 = smA.accumulateRowNZ(r, nupic::Max, 0); -// std::stringstream str; -// str << "accumulateRowNZ max " << nrows << "X" << ncols << "/" << zr; -// Test(str.str().c_str(), r1, r2); -// } -// -// { -// double r1 = denseA.accumulate(r, nupic::Max, 0); -// double r2 = smA.accumulate(r, nupic::Max, 0); -// std::stringstream str; -// str << "accumulate max " << nrows << "X" << ncols << "/" << zr; -// Test(str.str().c_str(), r1, r2); -// } -// } -// } -// */ -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_multiply() -// { -// UInt nrows, ncols, zr, nrows1, ncols1, ncols2, zr1, zr2; -// -// if (0) { // Visual test, keep -// -// DenseMat dense(4, 5, 2); -// SparseMat sparse1(dense.nRows(), dense.nCols(), dense.begin()); -// SparseMat sparse2(sparse1); -// sparse2.transpose(); -// SparseMat sparse3(0,0); -// -// cout << sparse1 << endl << endl << sparse2 << endl << endl; -// sparse1.multiply(sparse2, sparse3); -// cout << sparse3 << endl; -// -// return; -// } -// -// TEST_LOOP(M) { -// -// nrows1 = nrows; ncols1 = ncols; zr1 = zr; -// ncols2 = 2*nrows+1; zr2 = zr1; -// -// Dense denseA(nrows1, ncols1, zr1); -// SparseMatrix smA(nrows1, ncols1, denseA.begin()); -// -// Dense denseB(ncols1, ncols2, zr2); -// SparseMatrix smB(ncols1, ncols2, denseB.begin()); -// -// Dense denseC(nrows1, ncols2, zr2); -// SparseMatrix smC(nrows1, 
ncols2, denseC.begin()); -// -// { -// denseC.clear(); -// denseA.multiply(denseB, denseC); -// smA.multiply(smB, smC); -// -// std::stringstream str; -// str << "multiply " << nrows << "X" << ncols << "/" << zr; -// Compare(denseC, smC, str.str().c_str()); -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_argMax() -// { -// UInt ncols, nrows, zr; -// UInt m_i_sparse, m_j_sparse, m_i_dense, m_j_dense; -// Real m_val_sparse, m_val_dense; -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// dense.max(m_i_dense, m_j_dense, m_val_dense); -// -// sparse.decompact(); -// sparse.max(m_i_sparse, m_j_sparse, m_val_sparse); -// -// { -// std::stringstream str; -// str << "argMax A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// if (m_i_sparse != m_i_dense -// || m_j_sparse != m_j_dense -// || !nupic::nearlyEqual(m_val_sparse, m_val_dense)) -// Test(str.str().c_str(), 0, 1); -// } -// -// sparse.compact(); -// sparse.max(m_i_sparse, m_j_sparse, m_val_sparse); -// -// { -// std::stringstream str; -// str << "argMax B " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// if (m_i_sparse != m_i_dense -// || m_j_sparse != m_j_dense -// || !nupic::nearlyEqual(m_val_sparse, m_val_dense)) -// Test(str.str().c_str(), 0, 1); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_argMin() -// { -// UInt ncols, nrows, zr; -// UInt m_i_sparse, m_j_sparse, m_i_dense, m_j_dense; -// Real m_val_sparse, m_val_dense; -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// dense.min(m_i_dense, m_j_dense, m_val_dense); -// -// sparse.decompact(); -// sparse.min(m_i_sparse, m_j_sparse, m_val_sparse); -// -// { -// std::stringstream str; -// str << "argMin A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// if (m_i_sparse != m_i_dense -// || m_j_sparse != m_j_dense -// || !nupic::nearlyEqual(m_val_sparse, m_val_dense)) { -// Test(str.str().c_str(), 0, 1); -// } -// } -// -// sparse.compact(); -// sparse.min(m_i_sparse, m_j_sparse, m_val_sparse); -// -// { -// std::stringstream str; -// str << "argMin B " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// if (m_i_sparse != m_i_dense -// || m_j_sparse != m_j_dense -// || !nupic::nearlyEqual(m_val_sparse, m_val_dense)) -// Test(str.str().c_str(), 0, 1); -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_rowMax_2() -// { -// UInt ncols, nrows, zr; -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// std::vector > optima_sparse(nrows), optima_dense(nrows); -// -// dense.rowMax(optima_dense.begin()); -// -// sparse.decompact(); -// -// for (UInt i = 0; i != nrows; ++i) { -// -// std::pair res_sparse; -// sparse.rowMax(i, res_sparse.first, res_sparse.second); -// -// std::stringstream str; -// str << "rowMax 2 A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// if (optima_dense[i].first != res_sparse.first -// || !nearlyEqual(optima_dense[i].second, res_sparse.second)) -// Test(str.str().c_str(), 0, 1); -// } -// -// sparse.rowMax(optima_sparse.begin()); -// -// { -// 
std::stringstream str; -// str << "rowMax 2 B " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// for (UInt i = 0; i != nrows; ++i) { -// if (optima_dense[i].first != optima_sparse[i].first -// || !nupic::nearlyEqual(optima_dense[i].second, optima_sparse[i].second)) -// Test(str.str().c_str(), 0, 1); -// } -// } -// -// sparse.compact(); -// -// for (UInt i = 0; i != nrows; ++i) { -// -// std::pair res_sparse; -// sparse.rowMax(i, res_sparse.first, res_sparse.second); -// -// std::stringstream str; -// str << "rowMax 2 C " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// if (optima_dense[i].first != res_sparse.first -// || !nearlyEqual(optima_dense[i].second, res_sparse.second)) -// Test(str.str().c_str(), 0, 1); -// } -// -// sparse.rowMax(optima_sparse.begin()); -// -// { -// std::stringstream str; -// str << "rowMax 2 D " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// for (UInt i = 0; i != nrows; ++i) { -// if (optima_dense[i].first != optima_sparse[i].first -// || !nupic::nearlyEqual(optima_dense[i].second, optima_sparse[i].second)) -// Test(str.str().c_str(), 0, 1); -// } -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_rowMin() -// { -// UInt ncols, nrows, zr; -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// std::vector > optima_sparse(nrows), optima_dense(nrows); -// -// dense.rowMin(optima_dense.begin()); -// -// sparse.decompact(); -// -// for (UInt i = 0; i != nrows; ++i) { -// -// std::pair res_sparse; -// sparse.rowMin(i, res_sparse.first, res_sparse.second); -// -// std::stringstream str; -// str << "rowMin A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// if (optima_dense[i].first != res_sparse.first -// || !nearlyEqual(optima_dense[i].second, res_sparse.second)) -// Test(str.str().c_str(), 0, 1); -// } -// -// sparse.rowMin(optima_sparse.begin()); -// -// { -// std::stringstream str; -// str << "rowMin B " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// for (UInt i = 0; i != nrows; ++i) { -// if (optima_dense[i].first != optima_sparse[i].first -// || !nupic::nearlyEqual(optima_dense[i].second, optima_sparse[i].second)) -// Test(str.str().c_str(), 0, 1); -// } -// } -// -// sparse.compact(); -// -// for (UInt i = 0; i != nrows; ++i) { -// -// std::pair res_sparse; -// sparse.rowMin(i, res_sparse.first, res_sparse.second); -// -// std::stringstream str; -// str << "rowMin C " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// if (optima_dense[i].first != res_sparse.first -// || !nearlyEqual(optima_dense[i].second, res_sparse.second)) -// Test(str.str().c_str(), 0, 1); -// } -// -// sparse.rowMin(optima_sparse.begin()); -// -// { -// std::stringstream str; -// str << "rowMin D " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// for (UInt i = 0; i != nrows; ++i) { -// if (optima_dense[i].first != optima_sparse[i].first -// || !nupic::nearlyEqual(optima_dense[i].second, optima_sparse[i].second)) -// Test(str.str().c_str(), 0, 1); -// } -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_colMax() -// { -// UInt ncols = 7, nrows = 9, zr = 3; -// -// if (0) { -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// cout << sparse << endl; 
-// for (UInt j = 0; j != ncols; ++j) { -// UInt col_max_i; -// Real col_max; -// sparse.colMax(j, col_max_i, col_max); -// cout << j << " " << col_max_i << " " << col_max << endl; -// } -// } -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// std::vector > optima_sparse(ncols), optima_dense(ncols); -// -// dense.colMax(optima_dense.begin()); -// -// sparse.decompact(); -// -// for (UInt j = 0; j != ncols; ++j) { -// -// std::pair res_sparse; -// sparse.colMax(j, res_sparse.first, res_sparse.second); -// -// std::stringstream str; -// str << "colMax A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// if (optima_dense[j].first != res_sparse.first -// || !nearlyEqual(optima_dense[j].second, res_sparse.second)) -// Test(str.str().c_str(), 0, 1); -// } -// -// sparse.colMax(optima_sparse.begin()); -// -// { -// std::stringstream str; -// str << "colMax B " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// for (UInt j = 0; j != ncols; ++j) { -// if (optima_dense[j].first != optima_sparse[j].first -// || !nupic::nearlyEqual(optima_dense[j].second, optima_sparse[j].second)) -// Test(str.str().c_str(), 0, 1); -// } -// } -// -// sparse.compact(); -// -// for (UInt i = 0; i != ncols; ++i) { -// -// std::pair res_sparse; -// sparse.colMax(i, res_sparse.first, res_sparse.second); -// -// std::stringstream str; -// str << "colMax C " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// if (optima_dense[i].first != res_sparse.first -// || !nearlyEqual(optima_dense[i].second, res_sparse.second)) -// Test(str.str().c_str(), 0, 1); -// } -// -// sparse.colMax(optima_sparse.begin()); -// -// { -// std::stringstream str; -// str << "colMax D " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// for (UInt i = 0; i != ncols; ++i) { -// if (optima_dense[i].first != optima_sparse[i].first -// || !nupic::nearlyEqual(optima_dense[i].second, optima_sparse[i].second)) -// Test(str.str().c_str(), 0, 1); -// } -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_colMin() -// { -// UInt ncols, nrows, zr; -// -// { -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// std::vector > optima_sparse(ncols), optima_dense(ncols); -// -// dense.colMin(optima_dense.begin()); -// -// sparse.decompact(); -// -// for (UInt i = 0; i != ncols; ++i) { -// -// std::pair res_sparse; -// sparse.colMin(i, res_sparse.first, res_sparse.second); -// -// std::stringstream str; -// str << "rowMax 2 A " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// if (optima_dense[i].first != res_sparse.first -// || !nearlyEqual(optima_dense[i].second, res_sparse.second)) -// Test(str.str().c_str(), 0, 1); -// } -// -// sparse.colMin(optima_sparse.begin()); -// -// { -// std::stringstream str; -// str << "rowMax 2 B " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// for (UInt i = 0; i != ncols; ++i) { -// if (optima_dense[i].first != optima_sparse[i].first -// || !nupic::nearlyEqual(optima_dense[i].second, optima_sparse[i].second)) -// Test(str.str().c_str(), 0, 1); -// } -// } -// -// sparse.compact(); -// -// for (UInt i = 0; i != ncols; ++i) { -// -// std::pair res_sparse; -// sparse.colMin(i, res_sparse.first, res_sparse.second); -// -// std::stringstream str; -// str << "rowMax 2 C " << nrows << "X" << ncols << "/" 
<< zr -// << " - non compact"; -// if (optima_dense[i].first != res_sparse.first -// || !nearlyEqual(optima_dense[i].second, res_sparse.second)) -// Test(str.str().c_str(), 0, 1); -// } -// -// sparse.colMin(optima_sparse.begin()); -// -// { -// std::stringstream str; -// str << "rowMax 2 D " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// for (UInt i = 0; i != ncols; ++i) { -// if (optima_dense[i].first != optima_sparse[i].first -// || !nupic::nearlyEqual(optima_dense[i].second, optima_sparse[i].second)) -// Test(str.str().c_str(), 0, 1); -// } -// } -// } -// } -// } -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_nNonZeros() -// { -// UInt ncols, nrows, zr; -// -// TEST_LOOP(M) { -// -// DenseMat dense(nrows, ncols, zr); -// SparseMat sparse(nrows, ncols, dense.begin()); -// -// UInt n_s, n_d; -// -// { -// std::vector nrows_s(nrows), nrows_d(nrows); -// std::vector ncols_s(ncols), ncols_d(ncols); -// -// sparse.decompact(); -// -// n_d = dense.nNonZeros(); -// n_s = sparse.nNonZeros(); -// -// { -// std::stringstream str; -// str << "nNonZeros A1 " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// if (n_d != n_s) -// Test(str.str().c_str(), 0, 1); -// } -// -// for (UInt i = 0; i != nrows; ++i) { -// -// n_d = dense.nNonZerosOnRow(i); -// n_s = sparse.nNonZerosOnRow(i); -// -// { -// std::stringstream str; -// str << "nNonZeros B1 " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// if (n_d != n_s) -// Test(str.str().c_str(), 0, 1); -// } -// } -// -// for (UInt i = 0; i != ncols; ++i) { -// -// n_d = dense.nNonZerosOnCol(i); -// n_s = sparse.nNonZerosOnCol(i); -// -// { -// std::stringstream str; -// str << "nNonZeros C1 " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// if (n_d != n_s) -// Test(str.str().c_str(), 0, 1); -// } -// } -// -// dense.nNonZerosPerRow(nrows_d.begin()); -// sparse.nNonZerosPerRow(nrows_s.begin()); -// -// { -// std::stringstream str; -// str << "nNonZeros D1 " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// CompareVectors(nrows, nrows_d.begin(), nrows_s.begin(), str.str().c_str()); -// } -// -// dense.nNonZerosPerCol(ncols_d.begin()); -// sparse.nNonZerosPerCol(ncols_s.begin()); -// -// { -// std::stringstream str; -// str << "nNonZeros E1 " << nrows << "X" << ncols << "/" << zr -// << " - non compact"; -// CompareVectors(ncols, ncols_d.begin(), ncols_s.begin(), str.str().c_str()); -// } -// } -// -// { -// std::vector nrows_s(nrows), nrows_d(nrows); -// std::vector ncols_s(ncols), ncols_d(ncols); -// sparse.compact(); -// -// n_d = dense.nNonZeros(); -// n_s = sparse.nNonZeros(); -// -// { -// std::stringstream str; -// str << "nNonZeros A2 " << nrows << "X" << ncols << "/" << zr -// << " - compact"; -// if (n_d != n_s) -// Test(str.str().c_str(), 0, 1); -// } -// -// for (UInt i = 0; i != nrows; ++i) { -// -// n_d = dense.nNonZerosOnRow(i); -// n_s = sparse.nNonZerosOnRow(i); -// -// { -// std::stringstream str; -// str << "nNonZeros B2 " << nrows << "X" << ncols << "/" << zr -// << " - compact"; -// if (n_d != n_s) -// Test(str.str().c_str(), 0, 1); -// } -// } -// -// for (UInt i = 0; i != ncols; ++i) { -// -// n_d = dense.nNonZerosOnCol(i); -// n_s = sparse.nNonZerosOnCol(i); -// -// { -// std::stringstream str; -// str << "nNonZeros C2 " << nrows << "X" << ncols << "/" << zr -// << " - compact"; -// if (n_d != n_s) -// Test(str.str().c_str(), 0, 1); -// } -// } -// -// 
dense.nNonZerosPerRow(nrows_d.begin());
-// sparse.nNonZerosPerRow(nrows_s.begin());
-//
-// {
-// std::stringstream str;
-// str << "nNonZeros D2 " << nrows << "X" << ncols << "/" << zr
-// << " - compact";
-// CompareVectors(nrows, nrows_d.begin(), nrows_s.begin(), str.str().c_str());
-// }
-//
-// dense.nNonZerosPerCol(ncols_d.begin());
-// sparse.nNonZerosPerCol(ncols_s.begin());
-//
-// {
-// std::stringstream str;
-// str << "nNonZeros E2 " << nrows << "X" << ncols << "/" << zr
-// << " - compact";
-// CompareVectors(ncols, ncols_d.begin(), ncols_s.begin(), str.str().c_str());
-// }
-// }
-// }
-// }
-//
-// //--------------------------------------------------------------------------------
-// void SparseMatrixUnitTest::unit_test_extract()
-// {
-// if (1) { // Visual tests
-//
-// DenseMat dense(5, 7, 2);
-// SparseMat sparse(5, 7, dense.begin());
-//
-// /*
-// cout << "Sparse:" << endl << sparse << endl;
-//
-// { // Extract domain
-// Domain2D dom(0,4,0,4);
-// SparseMatrix extracted(4,4);
-// sparse.get(dom, extracted);
-// cout << extracted << endl;
-// }
-// */
-// }
-// }
-//
-// //--------------------------------------------------------------------------------
-// void SparseMatrixUnitTest::unit_test_deleteRow()
-// {
-// // This is regression test for an off-by-one memory corruption bug
-// // found in deleteRow the symptom of the bug is a seg fault so there
-// // is no explicit test here.
-// {
-// SparseMat* sm = new SparseMat(11, 1);
-// sm->deleteRow(3);
-// delete sm;
-//
-// sm = new SparseMat(11, 1);
-// sm->deleteRow(3);
-// delete sm;
-// }
-// }
-//
-// //--------------------------------------------------------------------------------
-// /**
-// * A generator function object, that generates random numbers between 0 and 256.
-// * It also has a threshold to control the sparsity of the vectors generated.
-// */
-// template
-// struct rand_init
-// {
-// Random *r_;
-// T threshold_;
-//
-// inline rand_init(Random *r, T threshold =100)
-// : r_(r), threshold_(threshold)
-// {}
-//
-// inline T operator()()
-// {
-// return T((T)(r_->getUInt32(100)) > threshold_ ?
0 : .001 + (r_->getReal64())); -// } -// }; -// -// //-------------------------------------------------------------------------------- -// void SparseMatrixUnitTest::unit_test_usage() -// { -// using namespace std; -// -// typedef UInt size_type; -// typedef double value_type; -// typedef SparseMatrix SM; -// typedef Dense DM; -// -// size_type maxMatrixSize = 30; -// size_type nrows = 20, ncols = 30, nzr = 20; -// -// DM* dense = new DM(nrows, ncols, nzr, true, true, rng_); -// SM* sparse = new SM(nrows, ncols, dense->begin()); -// -// for (long int a = 0; a < 10000; ++a) { -// -// // Rectify to stop propagation of small errors -// ITER_2(sparse->nRows(), sparse->nCols()) -// if (::fabs(dense->at(i,j) - sparse->get(i,j)) < 1e-6) -// dense->at(i,j) = sparse->get(i,j); -// -// size_type r = rng_->getUInt32(37); -// -// if (r == 0) { -// -// sparse->compact(); -// // no compact for Dense -// -// } else if (r == 1) { -// -// sparse->decompact(); -// // no decompact for Dense -// -// } else if (r == 2) { -// -// if (rng_->getReal64() < 0.90) { -// size_type nrows = sparse->nRows() + rng_->getUInt32(4); -// size_type ncols = sparse->nCols() + rng_->getUInt32(4); -// sparse->resize(nrows, ncols); -// dense->resize(nrows, ncols); -// Compare(*dense, *sparse, "resize, bigger"); -// -// } else { -// if (sparse->nRows() > 2 && sparse->nCols() > 2) { -// size_type nrows = rng_->getUInt32(sparse->nRows()); -// size_type ncols = rng_->getUInt32(sparse->nCols()); -// sparse->resize(nrows, ncols); -// dense->resize(nrows, ncols); -// Compare(*dense, *sparse, "resize, smaller"); -// } -// } -// -// } else if (r == 3) { -// -// vector del; -// -// if (rng_->getReal64() < 0.90) { -// for (size_type ii = 0; ii < sparse->nRows() / 4; ++ii) -// del.push_back(2*ii); -// sparse->deleteRows(del.begin(), del.end()); -// dense->deleteRows(del.begin(), del.end()); -// } else { -// for (size_type ii = 0; ii < sparse->nRows(); ++ii) -// del.push_back(ii); -// sparse->deleteRows(del.begin(), del.end()); -// dense->deleteRows(del.begin(), del.end()); -// } -// -// Compare(*dense, *sparse, "deleteRows"); -// -// } else if (r == 4) { -// -// vector del; -// if (rng_->getReal64() < 0.90) { -// for (size_type ii = 0; ii < sparse->nCols() / 4; ++ii) -// del.push_back(2*ii); -// sparse->deleteCols(del.begin(), del.end()); -// dense->deleteCols(del.begin(), del.end()); -// } else { -// for (size_type ii = 0; ii < sparse->nCols(); ++ii) -// del.push_back(ii); -// sparse->deleteCols(del.begin(), del.end()); -// dense->deleteCols(del.begin(), del.end()); -// } -// Compare(*dense, *sparse, "deleteCols"); -// -// } else if (r == 5) { -// -// SM sparse2(1, 1); -// DM sm2Dense(1, 1); -// Compare(sm2Dense,sparse2, "constructor(1, 1)"); -// -// sparse2.copy(*sparse); -// sparse->copy(sparse2); -// -// sm2Dense.copy(*dense); -// dense->copy(sm2Dense); -// Compare(*dense, *sparse, "copy"); -// -// } else if (r == 6) { -// -// vector row(sparse->nCols()); -// size_type n = rng_->getUInt32(16); -// for (size_type z = 0; z < n; ++z) { -// if (rng_->getReal64() < 0.90) -// generate(row.begin(), row.end(), rand_init(rng_, 70)); -// sparse->addRow(row.begin()); -// dense->addRow(row.begin()); -// Compare(*dense, *sparse, "addRow"); -// } -// -// } else if (r == 7) { -// -// if (sparse->nRows() > 0 && sparse->nCols() > 0) { -// size_type m = sparse->nRows() * sparse->nCols() / 2; -// for (size_type z = 0; z < m; ++z) { -// size_type i = rng_->getUInt32(sparse->nRows()); -// size_type j = rng_->getUInt32(sparse->nCols()); -// value_type v = 
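The rand_init functor above produces 0 for draws above its threshold and a small positive value otherwise, so std::generate fills rows that are sparse by construction. A sketch of the same idea with the standard <random> facilities instead of nupic's Random (RandInit is an illustrative name, not the original functor):

#include <algorithm>
#include <iostream>
#include <random>
#include <vector>

// A draw above `threshold` yields 0, anything else yields a value slightly
// above zero, mirroring the rand_init behaviour used by the tests.
struct RandInit {
  std::mt19937 &rng;
  unsigned threshold;
  double operator()() {
    std::uniform_int_distribution<unsigned> coin(0, 99);
    std::uniform_real_distribution<double> value(0.0, 1.0);
    return coin(rng) > threshold ? 0.0 : 0.001 + value(rng);
  }
};

int main() {
  std::mt19937 rng(42);
  std::vector<double> row(30);
  std::generate(row.begin(), row.end(), RandInit{rng, 70});
  for (double v : row)
    std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}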
1+value_type(rng_->getReal64()); -// sparse->setNonZero(i, j, v); -// dense->setNonZero(i, j, v); -// Compare(*dense, *sparse, "setNonZero"); -// } -// } -// -// } else if (r == 8) { -// -// value_type v = value_type(128 + rng_->getUInt32(128)); -// sparse->threshold(v); -// dense->threshold(v); -// Compare(*dense, *sparse, "threshold"); -// -// } else if (r == 9) { -// -// if (sparse->nCols() > 0 && sparse->nRows() > 0) { -// -// SM B(0, sparse->nCols()); -// DM BDense(0, dense->ncols); -// -// vector row(sparse->nCols()); -// -// for (size_type iii = 0; iii < sparse->nRows(); ++iii) { -// -// if (rng_->getUInt32(100) < 90) -// generate(row.begin(), row.end(), rand_init(rng_, 70)); -// else -// fill(row.begin(), row.end(), value_type(0)); -// -// B.addRow(row.begin()); -// BDense.addRow(row.begin()); -// } -// -// value_type r1=value_type(rng_->getUInt32(5)), r2=value_type(rng_->getUInt32(5)); -// -// sparse->lerp(r1, r2, B); -// dense->lerp(r1, r2, BDense); -// Compare(*dense, *sparse, "lerp", 1e-4); -// } -// -// } else if (r == 10) { -// -// delete sparse; -// delete dense; -// size_type nrows = rng_->getUInt32(maxMatrixSize), ncols = rng_->getUInt32(maxMatrixSize); -// sparse = new SM(ncols, nrows); -// dense = new DM(ncols, nrows); -// Compare(*dense, *sparse, "constructor(rng_->get() % 32, rng_->get() % 32)"); -// -// } else if (r == 11) { -// -// delete sparse; -// delete dense; -// sparse = new SM(); -// dense = new DM(); -// Compare(*dense, *sparse, "constructor()"); -// -// } else if (r == 12) { -// -// delete sparse; -// delete dense; -// sparse = new SM(0,0); -// dense = new DM(0,0); -// Compare(*dense, *sparse, "constructor(0,0)"); -// -// } else if (r == 13) { -// -// SM sm2(sparse->nRows(), sparse->nCols()); -// DM sm2Dense(dense->nrows, dense->ncols); -// Compare(sm2Dense, sm2, "constructor(dense->nRows(), dense->nCols())"); -// -// ITER_2(sm2.nRows(), sm2.nCols()) { -// value_type r = 1+rng_->getUInt32(256); -// sm2.setNonZero(i,j, r); -// sm2Dense.setNonZero(i,j, r); -// } -// sparse->elementApply(sm2, std::plus()); -// dense->add(sm2Dense); -// Compare(*dense, *sparse, "add"); -// -// } else if (r == 14) { -// -// if (sparse->nRows() > 0) { -// vector row(sparse->nCols()); -// generate(row.begin(), row.end(), rand_init(rng_, 70)); -// size_type r = rng_->getUInt32(sparse->nRows()); -// sparse->elementRowApply(r, std::plus(), row.begin()); -// dense->add(r, row.begin()); -// Compare(*dense, *sparse, "add(randomR, row.begin())"); -// } -// -// } else if (r == 15) { -// -// SM B(sparse->nCols(), sparse->nRows()); -// DM BDense(dense->ncols, dense->nrows); -// Compare(BDense, B, "constructor(sm->nCols(), sm->nRows())"); -// sparse->transpose(B); -// dense->transpose(BDense); -// Compare(*dense, *sparse, "transpose"); -// -// } else if (r == 16) { -// -// /* -// vector x(sparse->nCols()), y(sparse->nRows()); -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// sparse->L2Dist(x.begin(), y.begin()); -// dense->L2Dist(x.begin(), y.begin()); -// Compare(*dense, *sparse, "L2Dist", 1e-4); -// */ -// -// } else if (r == 17) { -// -// /* -// vector x(sparse->nCols()); -// pair closest; -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// sparse->L2Nearest(x.begin(), closest); -// dense->L2Nearest(x.begin(), closest); -// Compare(*dense, *sparse, "L2Nearest", 1e-4); -// */ -// -// } else if (r == 18) { -// -// /* -// vector x(sparse->nCols()), y(sparse->nRows()); -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// sparse->vecDist(x.begin(), y.begin()); -// 
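unit_test_usage is a randomized differential test: it keeps a SparseMatrix and a Dense reference holding the same data, applies a randomly chosen operation to both on every iteration, and compares them afterwards. A stripped-down, self-contained sketch of that loop structure, with a row-major buffer and a coordinate map as stand-ins for the two nupic classes and only two operations:

#include <cassert>
#include <cstddef>
#include <map>
#include <random>
#include <utility>
#include <vector>

int main() {
  const std::size_t nrows = 8, ncols = 10;
  std::vector<double> dense(nrows * ncols, 0.0);
  std::map<std::pair<std::size_t, std::size_t>, double> sparse;

  std::mt19937 rng(7);
  std::uniform_int_distribution<std::size_t> pick_i(0, nrows - 1);
  std::uniform_int_distribution<std::size_t> pick_j(0, ncols - 1);
  std::uniform_real_distribution<double> pick_v(0.1, 1.0);
  std::uniform_int_distribution<int> pick_op(0, 1);

  for (int step = 0; step < 10000; ++step) {
    if (pick_op(rng) == 0) { // "setNonZero"
      std::size_t i = pick_i(rng), j = pick_j(rng);
      double v = pick_v(rng);
      dense[i * ncols + j] = v;
      sparse[{i, j}] = v;
    } else {                 // "threshold": drop entries below 0.5
      for (auto &d : dense)
        if (d < 0.5) d = 0.0;
      for (auto it = sparse.begin(); it != sparse.end();)
        it = (it->second < 0.5) ? sparse.erase(it) : ++it;
    }
    // Compare the two representations element by element after every step.
    for (std::size_t i = 0; i < nrows; ++i)
      for (std::size_t j = 0; j < ncols; ++j) {
        auto it = sparse.find({i, j});
        double s = (it == sparse.end()) ? 0.0 : it->second;
        assert(dense[i * ncols + j] == s);
      }
  }
  return 0;
}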
dense->vecDist(x.begin(), y.begin()); -// Compare(*dense, *sparse, "vecDist", 1e-4); -// */ -// -// } else if(r == 19) { -// -// /* -// if(sparse->nRows() > 0) { -// vector x(sparse->nCols()); -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// size_type randInt=rng_->get() % sparse->nRows(); -// sparse->rowDistSquared(randInt, x.begin()); -// dense->rowDistSquared(randInt, x.begin()); -// Compare(*dense, *sparse, "rowDistSquared", 1e-4); -// } -// */ -// -// } else if (r == 20) { -// -// /* -// vector x(sparse->nCols()); -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// sparse->closestEuclidean(x.begin()); -// dense->closestEuclidean(x.begin()); -// Compare(*dense, *sparse, "closestEuclidean", 1e-4); -// */ -// -// } else if (r== 21) { -// -// /* -// vector x(sparse->nCols()); -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// for (size_type n = 0; n < sparse->nCols(); ++n) -// x.push_back(value_type(rng_->get() % 256)); -// sparse->dotNearest(x.begin()); -// dense->dotNearest(x.begin()); -// Compare(*dense, *sparse, "dotNearest", 1e-4); -// */ -// -// } else if (r == 22) { -// -// vector x(sparse->nCols()), y(sparse->nRows()); -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// sparse->rightVecProd(x.begin(), y.begin()); -// dense->rightVecProd(x.begin(), y.begin()); -// Compare(*dense, *sparse, "rightVecProd", 1e-4); -// -// } else if (r == 23) { -// -// vector x(sparse->nCols()), y(sparse->nRows()); -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// sparse->vecMaxProd(x.begin(), y.begin()); -// dense->vecMaxProd(x.begin(), y.begin()); -// Compare(*dense, *sparse, "vecMaxProd", 1e-4); -// -// } else if (r == 24) { -// -// vector x(sparse->nCols()), y(sparse->nRows()); -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// sparse->vecMaxAtNZ(x.begin(), y.begin()); -// dense->vecMaxAtNZ(x.begin(), y.begin()); -// Compare(*dense, *sparse, "vecMaxAtNZ", 1e-4); -// -// } else if (r == 25) { -// -// if (sparse->nRows() > 0) { -// vector x(sparse->nCols()); -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// size_type row = rng_->getUInt32(sparse->nRows()); -// value_type r1=value_type(rng_->getUInt32(256)), r2=value_type(rng_->getUInt32(256)); -// sparse->axby(row, r1, r2, x.begin()); -// dense->axby(row, r1, r2, x.begin()); -// Compare(*dense, *sparse, "axby", 1e-4); -// } -// -// } else if (r == 26) { -// -// vector x(sparse->nCols()); -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// value_type r1=value_type(rng_->getUInt32(256)), r2=value_type(rng_->getUInt32(256)); -// sparse->axby(r1, r2, x.begin()); -// dense->axby(r1, r2, x.begin()); -// Compare(*dense, *sparse, "axby 2", 1e-4); -// -// } else if (r == 27) { -// -// /* -// vector x(sparse->nCols()), y(sparse->nRows()); -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// sparse->rowMax(x.begin(), y.begin()); -// dense->rowMax(x.begin(), y.begin()); -// Compare(*dense, *sparse, "rowMax"); -// */ -// -// } else if (r == 28) { -// -// vector< pair > y(sparse->nRows()); -// sparse->rowMax(y.begin()); -// dense->rowMax(y.begin()); -// Compare(*dense, *sparse, "rowMax 2"); -// -// } else if (r == 29) { -// -// vector< pair > y(sparse->nCols()); -// sparse->colMax(y.begin()); -// dense->colMax(y.begin()); -// Compare(*dense, *sparse, "colMax"); -// -// } else if (r == 30) { -// -// bool exact = true; -// sparse->normalizeRows(exact); -// dense->normalizeRows(exact); -// Compare(*dense, *sparse, "normalizeRows", 1e-4); -// -// } else if (r == 31) { -// -// vector 
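The rightVecProd branch compares a matrix-vector product computed from the dense reference with the one computed from the sparse nonzeros, within the 1e-4 tolerance that Compare uses. A minimal version of that check with stand-in containers:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

int main() {
  const std::size_t nrows = 3, ncols = 4;
  std::vector<double> dense = {1, 0, 2, 0,
                               0, 0, 0, 3,
                               4, 5, 0, 0};
  std::vector<double> x = {1, 2, 3, 4};

  // Per-row sparse view of the same matrix.
  std::vector<std::vector<std::pair<std::size_t, double> > > rows(nrows);
  for (std::size_t i = 0; i < nrows; ++i)
    for (std::size_t j = 0; j < ncols; ++j)
      if (dense[i * ncols + j] != 0)
        rows[i].push_back({j, dense[i * ncols + j]});

  // y = A * x computed both ways must agree within the tolerance.
  for (std::size_t i = 0; i < nrows; ++i) {
    double y_dense = 0, y_sparse = 0;
    for (std::size_t j = 0; j < ncols; ++j)
      y_dense += dense[i * ncols + j] * x[j];
    for (const auto &e : rows[i])
      y_sparse += e.second * x[e.first];
    assert(std::fabs(y_dense - y_sparse) < 1e-4);
  }
  return 0;
}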
x(sparse->nCols()), y(sparse->nRows()); -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// sparse->rightVecProdAtNZ(x.begin(), y.begin()); -// dense->rowProd(x.begin(), y.begin()); -// Compare(*dense, *sparse, "rowProd", 1e-4); -// -// } else if (r == 32) { -// -// vector x(sparse->nCols()), y(sparse->nRows()); -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// value_type theRandom=value_type(rng_->getUInt32(256)); -// sparse->rightVecProdAtNZ(x.begin(), y.begin(), theRandom); -// dense->rowProd(x.begin(), y.begin(), theRandom); -// Compare(*dense, *sparse, "rowProd 2", 1e-4); -// -// } else if (r == 33) { -// -// //size_type row; -// //value_type init; -// -// if (sparse->nRows() != 0) { -// -// /* -// row = rng_->get() % sparse->nRows(); -// init = (rng_->get() % 32768)/32768.0 + .001; -// -// size_type switcher = rng_->get() % 4; -// -// if (switcher == 0) { -// sparse->accumulateRowNZ(row, multiplies(), init); -// dense->accumulateRowNZ(row, multiplies(), init); -// Compare(*dense, *sparse, "accumulateRowNZ with multiplies", 1e-4); -// } else if (switcher == 1) { -// sparse->accumulateRowNZ(row, plus(), init); -// dense->accumulateRowNZ(row, plus(), init); -// Compare(*dense, *sparse, "accumulateRowNZ with plus", 1e-4); -// } else if (switcher == 2) { -// sparse->accumulateRowNZ(row, minus(), init); -// dense->accumulateRowNZ(row, minus(), init); -// Compare(*dense, *sparse, "accumulateRowNZ with minus", 1e-4); -// } else if (switcher == 3) { -// sparse->accumulateRowNZ(row, nupic::Max, init); -// dense->accumulateRowNZ(row, nupic::Max, init); -// Compare(*dense, *sparse, "accumulateRowNZ with Max", 1e-4); -// } -// */ -// } -// -// } else if (r == 34) { -// -// //size_type row; -// //value_type init; -// -// if (sparse->nRows() != 0) { -// /* -// row = rng_->get() % sparse->nRows(); -// init = (rng_->get() % 32768)/32768.0 + .001; -// -// size_type switcher = rng_->get() % 4; -// -// if (switcher == 0) { -// sparse->accumulate(row, multiplies(), init); -// dense->accumulate(row, multiplies(), init); -// Compare(*dense, *sparse, "accumulateRowNZ with multiplies", 1e-4); -// } else if (switcher == 1) { -// sparse->accumulate(row, plus(), init); -// dense->accumulate(row, plus(), init); -// Compare(*dense, *sparse, "accumulateRowNZ with plus", 1e-4); -// } else if (switcher == 2) { -// sparse->accumulate(row, minus(), init); -// dense->accumulate(row, minus(), init); -// Compare(*dense, *sparse, "accumulateRowNZ with minus", 1e-4); -// } else if (switcher == 3) { -// sparse->accumulate(row, nupic::Max, init); -// dense->accumulate(row, nupic::Max, init); -// Compare(*dense, *sparse, "accumulateRowNZ with Max", 1e-4); -// } -// */ -// } -// -// } else if (r == 35) { -// -// if(dense->ncols > 0 && dense->nrows > 0) { -// -// size_type randomTemp = rng_->getUInt32(maxMatrixSize); -// SM B(0, randomTemp); -// SM C(sparse->nRows(), randomTemp); -// DM BDense(0, randomTemp); -// DM CDense(dense->nrows, randomTemp); -// -// vector x(randomTemp); -// -// for (size_type n=0; n < sparse->nCols(); n++) { -// generate(x.begin(), x.end(), rand_init(rng_, 50)); -// B.addRow(x.begin()); -// BDense.addRow(x.begin()); -// } -// -// sparse->multiply( B, C); -// dense->multiply( BDense, CDense); -// Compare(*dense, *sparse, "multiply", 1e-4); -// } -// -// } else if (r == 36) { -// -// if (sparse->nRows() > 0 && sparse->nCols() > 0) { -// -// vector indices, indicesDense; -// vector values, valuesDense; -// -// size_type r = rng_->getUInt32(sparse->nRows()); -// -// 
sparse->getRowToSparse(r, back_inserter(indices), -// back_inserter(values)); -// -// dense->getRowToSparse(r, back_inserter(indicesDense), -// back_inserter(valuesDense)); -// -// sparse->findRow((size_type)indices.size(), -// indices.begin(), -// values.begin()); -// -// dense->findRow((size_type)indicesDense.size(), -// indicesDense.begin(), -// valuesDense.begin()); -// -// CompareVectors((size_type)indices.size(), indices.begin(), -// indicesDense.begin(), -// "findRow indices"); -// -// CompareVectors((size_type)values.size(), values.begin(), -// valuesDense.begin(), -// "findRow values"); -// } -// } -// } -// -// delete sparse; -// delete dense; -// } -// +#define TEST_LOOP(M) \ + for (nrows = 0, ncols = M, zr = 15; nrows < M; \ + nrows += M / 10, ncols -= M / 10, zr = ncols / 10) \ + // + // #define M 64 + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_construction() + // { + // UInt ncols, nrows, zr; + // + // { // Deallocate an empty matrix + // SparseMat sm; + // Test("empty matrix 1", sm.isZero(), true); + // } + // + // { // Compact and deallocate an empty matrix + // SparseMat sm; + // Test("empty matrix 2", sm.isZero(), true); + // sm.compact(); + // Test("empty matrix 2 - compact", sm.isZero(), true); + // } + // + // { // De-compact and deallocate an empty matrix + // SparseMat sm; + // Test("empty matrix 3", sm.isZero(), true); + // sm.decompact(); + // Test("empty matrix 3 - decompact", sm.isZero(), true); + // } + // + // { // De-compact/compact and deallocate an empty matrix + // SparseMat sm; + // Test("empty matrix 4", sm.isZero(), true); + // sm.decompact(); + // Test("empty matrix 4 - decompact", sm.isZero(), true); + // sm.compact(); + // Test("empty matrix 4 - compact", sm.isZero(), true); + // } + // + // { // Compact and deallocate an empty matrix + // SparseMat sm(0, 0); + // Test("empty matrix 5", sm.isZero(), true); + // sm.compact(); + // Test("empty matrix 5 - compact", sm.isZero(), true); + // } + // + // { // De-compact and deallocate an empty matrix + // SparseMat sm(0, 0); + // Test("empty matrix 6", sm.isZero(), true); + // sm.decompact(); + // Test("empty matrix 6 - decompact", sm.isZero(), true); + // } + // + // { // De-compact/compact and deallocate an empty matrix + // SparseMat sm(0, 0); + // Test("empty matrix 7", sm.isZero(), true); + // sm.decompact(); + // Test("empty matrix 7 - decompact", sm.isZero(), true); + // sm.compact(); + // Test("empty matrix 7 - compact", sm.isZero(), true); + // } + // + // { // Rectangular shape, no zeros + // nrows = 3; ncols = 4; + // DenseMat dense(nrows, ncols, 0); + // SparseMat sm(nrows, ncols, dense.begin()); + // Compare(dense, sm, "ctor 1"); + // Test("isZero 1", sm.isZero(), false); + // sm.compact(); + // Compare(dense, sm, "ctor 1 - compact"); + // Test("isZero 1 - compact", sm.isZero(), false); + // } + // + // { // Rectangular shape, zeros + // nrows = 3; ncols = 4; + // DenseMat dense(nrows, ncols, 2); + // SparseMat sm(nrows, ncols, dense.begin()); + // Compare(dense, sm, "ctor 2"); + // Test("isZero 2", sm.isZero(), false); + // sm.compact(); + // Compare(dense, sm, "ctor 2 - compact"); + // Test("isZero 2 - compact", sm.isZero(), false); + // } + // + // { // Rectangular the other way, no zeros + // nrows = 4; ncols = 3; + // DenseMat dense(nrows, ncols, 0); + // SparseMat sm(nrows, ncols, dense.begin()); + // Compare(dense, sm, "ctor 3"); + // Test("isZero 3", sm.isZero(), false); + // sm.compact(); + // 
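The TEST_LOOP(M) macro defined just below sweeps a family of matrix shapes: nrows grows from 0 toward M in steps of M/10 while ncols shrinks from M by the same step, and zr (the number of zeros used to build the dense reference) tracks ncols/10. Printing the iterations for M = 64 makes the shapes it visits explicit:

#include <iostream>

int main() {
  const unsigned M = 64;
  unsigned nrows, ncols, zr;
  // Same loop header as TEST_LOOP(M), written out explicitly.
  for (nrows = 0, ncols = M, zr = 15; nrows < M;
       nrows += M / 10, ncols -= M / 10, zr = ncols / 10)
    std::cout << nrows << " x " << ncols << " with zr = " << zr << '\n';
  return 0;
}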
Compare(dense, sm, "ctor 3 - compact"); + // Test("isZero 3 - compact", sm.isZero(), false); + // } + // + // { // Rectangular the other way, zeros + // nrows = 6; ncols = 5; + // DenseMat dense(nrows, ncols, 2); + // SparseMat sm(nrows, ncols, dense.begin()); + // Compare(dense, sm, "ctor 4"); + // Test("isZero 4", sm.isZero(), false); + // sm.compact(); + // Compare(dense, sm, "ctor 4 - compact"); + // Test("isZero 4 - compact", sm.isZero(), false); + // } + // + // { // Empty rows in the middle and zeros + // nrows = 3; ncols = 4; + // DenseMat dense(nrows, ncols, 2, false, true); + // SparseMat sm(nrows, ncols, dense.begin()); + // Compare(dense, sm, "ctor 5"); + // Test("isZero 5", sm.isZero(), false); + // sm.compact(); + // Compare(dense, sm, "ctor 5 - compact"); + // Test("isZero 5 - compact", sm.isZero(), false); + // } + // + // { // Empty rows in the middle and zeros + // nrows = 7; ncols = 5; + // DenseMat dense(nrows, ncols, 2, false, true); + // SparseMat sm(nrows, ncols, dense.begin()); + // Compare(dense, sm, "ctor 6"); + // Test("isZero 6", sm.isZero(), false); + // sm.compact(); + // Compare(dense, sm, "ctor 6 - compact"); + // Test("isZero 6 - compact", sm.isZero(), false); + // } + // + // { // Small values, zeros and empty rows + // nrows = 7; ncols = 5; + // DenseMat dense(nrows, ncols, 2, true, true, rng_); + // SparseMat sm(nrows, ncols, dense.begin()); + // Compare(dense, sm, "ctor 7"); + // Test("isZero 7", sm.isZero(), false); + // sm.compact(); + // Compare(dense, sm, "ctor 7 - compact"); + // Test("isZero 7 - compact", sm.isZero(), false); + // } + // + // { // Small values, zeros and empty rows, other constructor + // nrows = 10; ncols = 10; + // DenseMat dense(nrows, ncols, 2, true, true, rng_); + // SparseMat sm(0, ncols); + // for (UInt i = 0; i < nrows; ++i) + // sm.addRow(dense.begin(i)); + // Compare(dense, sm, "ctor 8"); + // Test("isZero 8", sm.isZero(), false); + // sm.compact(); + // Compare(dense, sm, "ctor 8 - compact"); + // Test("isZero 8 - compact", sm.isZero(), false); + // } + // + // { // Zero first row + // nrows = 10; ncols = 10; + // DenseMat dense(nrows, ncols, 2, true, true, rng_); + // for (UInt i = 0; i < ncols; ++i) + // dense.at(0, i) = 0; + // SparseMat sm(0, ncols); + // for (UInt i = 0; i < nrows; ++i) + // sm.addRow(dense.begin(i)); + // Compare(dense, sm, "ctor 8B"); + // Test("isZero 8B", sm.isZero(), false); + // sm.compact(); + // Compare(dense, sm, "ctor 8B - compact"); + // Test("isZero 8B - compact", sm.isZero(), false); + // } + // + // { // Small values, zeros and empty rows, other constructor + // nrows = 10; ncols = 10; + // DenseMat dense(nrows, ncols, 2, true, true, rng_); + // SparseMat sm(0, ncols); + // for (UInt i = 0; i < nrows; ++i) + // sm.addRow(dense.begin(i)); + // Compare(dense, sm, "ctor 9"); + // Test("isZero 9", sm.isZero(), false); + // sm.compact(); + // Compare(dense, sm, "ctor 9 - compact"); + // Test("isZero 9 - compact", sm.isZero(), false); + // } + // + // { // Small values, zeros and empty rows, other constructor + // nrows = 10; ncols = 10; + // DenseMat dense(nrows, ncols, 2, true, true, rng_); + // SparseMat sm(0, ncols); + // for (UInt i = 0; i < nrows; ++i) + // sm.addRow(dense.begin(i)); + // Compare(dense, sm, "ctor 10"); + // Test("isZero 10", sm.isZero(), false); + // sm.compact(); + // Compare(dense, sm, "ctor 10 - compact"); + // Test("isZero 10 - compact", sm.isZero(), false); + // } + // + // { // Empty + // DenseMat dense(10, 10, 10); + // SparseMat sm(10, 10, dense.begin()); + // 
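The "ctor 8"/"ctor 9" cases build the same matrix two ways, once directly from a dense buffer and once incrementally with addRow, and both must pass isZero and Compare. A sketch of why the two paths agree, using MiniCsr as an illustrative stand-in rather than the nupic SparseMatrix:

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// MiniCsr only keeps per-row (column, value) pairs, which is enough to show
// the equivalence of the two construction paths.
struct MiniCsr {
  std::size_t ncols = 0;
  std::vector<std::vector<std::pair<std::size_t, double> > > rows;

  void fromDense(std::size_t nr, std::size_t nc, const double *d) {
    ncols = nc;
    rows.assign(nr, {});
    for (std::size_t i = 0; i < nr; ++i)
      for (std::size_t j = 0; j < nc; ++j)
        if (d[i * nc + j] != 0)
          rows[i].push_back({j, d[i * nc + j]});
  }
  void addRow(const double *d) {               // append one dense row
    rows.push_back({});
    for (std::size_t j = 0; j < ncols; ++j)
      if (d[j] != 0)
        rows.back().push_back({j, d[j]});
  }
  bool isZero() const {
    for (std::size_t i = 0; i < rows.size(); ++i)
      if (!rows[i].empty())
        return false;
    return true;
  }
};

int main() {
  const std::size_t nrows = 5, ncols = 4;
  std::vector<double> dense(nrows * ncols, 0.0);
  dense[1] = 1.5; dense[6] = 2.5; dense[11] = 3.5; dense[18] = 4.5;

  MiniCsr a, b;
  a.fromDense(nrows, ncols, dense.data());     // build from the whole buffer
  b.ncols = ncols;                             // start with zero rows ...
  for (std::size_t i = 0; i < nrows; ++i)
    b.addRow(&dense[i * ncols]);               // ... and addRow row by row

  assert(a.rows == b.rows && !a.isZero());

  std::vector<double> zeros(9, 0.0);
  MiniCsr z;
  z.fromDense(3, 3, zeros.data());             // all-zero input
  assert(z.isZero());
  return 0;
}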
Compare(dense, sm, "ctor from empty dense - non compact"); + // Test("isZero 11", sm.isZero(), true); + // sm.compact(); + // Compare(dense, sm, "ctor from empty dense - compact"); + // Test("isZero 11 - compact", sm.isZero(), true); + // } + // + // { // Empty, other constructor + // DenseMat dense(10, 10, 10); + // SparseMat sm(0, 10); + // for (UInt i = 0; i < nrows; ++i) + // sm.addRow(dense.begin(i)); + // Compare(dense, sm, "ctor from empty dense - non compact"); + // Test("isZero 12", sm.isZero(), true); + // sm.compact(); + // Compare(dense, sm, "ctor from empty dense - compact"); + // Test("isZero 12 - compact", sm.isZero(), true); + // } + // + // { // Full + // DenseMat dense(10, 10, 0); + // SparseMat sm(10, 10, dense.begin()); + // Compare(dense, sm, "ctor from full dense - non compact"); + // Test("isZero 13", sm.isZero(), false); + // sm.compact(); + // Compare(dense, sm, "ctor from full dense - compact"); + // Test("isZero 13 - compact", sm.isZero(), false); + // } + // + // { // Various rectangular sizes + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sm(nrows, ncols, dense.begin()); + // + // sm.decompact(); + // + // { + // std::stringstream str; + // str << "ctor A " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // Compare(dense, sm, str.str().c_str()); + // } + // + // sm.compact(); + // + // { + // std::stringstream str; + // str << "ctor B " << nrows << "X" << ncols << "/" << zr + // << " - compact"; + // Compare(dense, sm, str.str().c_str()); + // } + // } + // } + // + // /* + // try { + // SparseMatrix sme1(-1, 0); + // Test("SparseMatrix::SparseMatrix(Int, Int) exception 2", true, false); + // } catch (std::exception&) { + // Test("SparseMatrix::SparseMatrix(Int, Int) exception 2", true, true); + // } + // + // try { + // SparseMatrix sme1(1, -1); + // Test("SparseMatrix::SparseMatrix(Int, Int) exception 3", true, false); + // } catch (std::exception&) { + // Test("SparseMatrix::SparseMatrix(Int, Int) exception 3", true, true); + // } + // + // try { + // SparseMatrix sme1(1, -1); + // Test("SparseMatrix::SparseMatrix(Int, Int) exception 4", true, false); + // } catch (std::exception&) { + // Test("SparseMatrix::SparseMatrix(Int, Int) exception 4", true, true); + // } + // + // std::vector mat(16, 0); + // + // try { + // SparseMatrix sme1(-1, 1, mat.begin()); + // Test("SparseMatrix::SparseMatrix(Int, Int, Iter) exception 1", true, + // false); + // } catch (std::exception&) { + // Test("SparseMatrix::SparseMatrix(Int, Iter) exception 1", true, true); + // } + // + // try { + // SparseMatrix sme1(1, -1, mat.begin()); + // Test("SparseMatrix::SparseMatrix(Int, Int, Iter) exception 2", true, + // false); + // } catch (std::exception&) { + // Test("SparseMatrix::SparseMatrix(Int, Iter) exception 2", true, true); + // } + // */ + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_copy() + // { + // { + // SparseMat sm, sm2; + // DenseMat dense, dense2; + // sm2.copy(sm); + // dense2.copy(dense); + // Compare(dense2, sm2, "SparseMatrix::copy - empty matrix"); + // } + // + // { + // SparseMat sm(0,0), sm2; + // DenseMat dense(0,0), dense2; + // sm2.copy(sm); + // dense2.copy(dense); + // Compare(dense2, sm2, "SparseMatrix::copy - empty matrix 2"); + // } + // + // { + // SparseMat sm(5, 4), sm2; + // DenseMat dense(5, 4), dense2; + // sm2.copy(sm); + // dense2.copy(dense); + // Compare(dense2, sm2, "SparseMatrix::copy - empty 
matrix 3"); + // } + // + // { + // DenseMat dense(5, 4, 2, false, false), dense2; + // SparseMat sm(5, 4, dense.begin()), sm2; + // sm2.copy(sm); + // dense2.copy(dense); + // Compare(dense2, sm2, "SparseMatrix::copy - 1"); + // } + // + // { + // DenseMat dense(5, 4, 2, false, true), dense2; + // SparseMat sm(5, 4, dense.begin()), sm2; + // sm2.copy(sm); + // dense2.copy(dense); + // Compare(dense2, sm2, "SparseMatrix::copy - 1"); + // } + // + // { + // DenseMat dense(5, 4, 2, true, false, rng_), dense2; + // SparseMat sm(5, 4, dense.begin()), sm2; + // sm2.copy(sm); + // dense2.copy(dense); + // Compare(dense2, sm2, "SparseMatrix::copy - 1"); + // } + // + // { + // DenseMat dense(5, 4, 2, true, true, rng_), dense2; + // SparseMat sm(5, 4, dense.begin()), sm2; + // sm2.copy(sm); + // dense2.copy(dense); + // Compare(dense2, sm2, "SparseMatrix::copy - 1"); + // } + // } + // + // //-------------------------------------------------------------------------------- + // /** + // * TC: Dense::toCSR matches SparseMatrix::toCSR (in stress test) + // * TC: Dense::fromCSR matches SparseMatrix::fromCSR (in stress test) + // * TC: reading in smaller matrix resizes the sparse matrix correctly + // * TC: reading in larger matrix resizes the sparse matrix correctly + // * TC: empty rows are stored correctly in stream + // * TC: empty rows are read correctly from stream + // * TC: empty matrix is written and read correctly + // * TC: values below epsilon are handled correctly in toCSR + // * TC: values below epsilon are handled correctly in fromCSR + // * TC: toCSR exception if bad stream + // * TC: fromCSR exception if bad stream + // * TC: fromCSR exception if bad 'csr' tag + // * TC: fromCSR exception if nrows < 0 + // * TC: fromCSR exception if ncols <= 0 + // * TC: fromCSR exception if nnz < 0 or nnz > nrows * ncols + // * TC: fromCSR exception if nnzr < 0 or nnzr > ncols + // * TC: fromCSR exception if j < 0 or j >= ncols + // * TC: stress test + // * TC: allocate_ exceptions + // * TC: addRow exceptions + // * TC: compact exceptions + // */ + // void SparseMatrixUnitTest::unit_test_csr() + // { + // UInt ncols, nrows, zr; + // + // { // Empty matrix + // // ... is written correctly + // SparseMat sm(3, 4); + // std::stringstream buf; + // sm.toCSR(buf); + // Test("SparseMatrix::toCSR empty", buf.str() == "sm_csr_1.5 12 3 4 0 0 + // 0 0 ", true); + // + // // ... is read correctly + // SparseMat sm2; + // sm2.fromCSR(buf); + // std::stringstream buf2; + // buf2 << "csr 3 4 0 0 0 0"; + // DenseMat dense; + // dense.fromCSR(buf2); + // Compare(dense, sm2, "fromCSR/empty"); + // } + // + // { // Is resizing happening correctly? + // DenseMat dense(3, 4, 2); + // SparseMat sm(3, 4, dense.begin()); + // + // { // When reading in smaller size matrix? + // std::stringstream buf1, buf2; + // buf1 << "csr -1 3 3 9 3 0 1 1 2 2 3 3 0 11 1 12 2 13 3 0 21 1 22 2 + // 23"; sm.fromCSR(buf1); buf2 << "csr 3 3 9 3 0 1 1 2 2 3 3 0 11 1 + // 12 2 13 3 0 21 1 22 2 23"; dense.fromCSR(buf2); Compare(dense, sm, + // "fromCSR/redim/1"); + // } + // + // { // When reading in larger size matrix? 
+ // std::stringstream buf1, buf2; + // buf1 << "csr -1 4 5 20 " + // "5 0 1 1 2 2 3 3 4 4 5 " + // "5 0 11 1 12 2 13 3 14 4 15 " + // "5 0 21 1 22 2 23 3 24 4 25 " + // "5 0 31 1 32 2 33 3 34 4 35"; + // sm.fromCSR(buf1); + // buf2 << "csr 4 5 20 " + // "5 0 1 1 2 2 3 3 4 4 5 " + // "5 0 11 1 12 2 13 3 14 4 15 " + // "5 0 21 1 22 2 23 3 24 4 25 " + // "5 0 31 1 32 2 33 3 34 4 35"; + // dense.fromCSR(buf2); + // Compare(dense, sm, "fromCSR/redim/2"); + // } + // + // { // Empty rows are read in correctly + // std::stringstream buf1, buf2; + // buf1 << "csr -1 4 5 15 " + // "5 0 1 1 2 2 3 3 4 4 5 " + // "0 " + // "5 0 21 1 22 2 23 3 24 4 25 " + // "5 0 31 1 32 2 33 3 34 4 35"; + // sm.fromCSR(buf1); + // buf2 << "csr 4 5 15 " + // "5 0 1 1 2 2 3 3 4 4 5 " + // "0 " + // "5 0 21 1 22 2 23 3 24 4 25 " + // "5 0 31 1 32 2 33 3 34 4 35"; + // dense.fromCSR(buf2); + // Compare(dense, sm, "fromCSR/redim/3"); + // } + // } + // + // { // Initialize fromDenseMat then again fromCSR + // DenseMat dense(3, 4, 2); + // SparseMat sm(3, 4, dense.begin()); + // std::stringstream buf1; + // buf1 << "csr -1 3 3 9 3 0 1 1 2 2 3 3 0 11 1 12 2 13 3 0 21 1 22 2 + // 23"; sm.fromCSR(buf1); + // } + // + // { // ... and vice-versa, fromCSR, followed by fromDense + // DenseMat dense(3, 4, 2); + // SparseMat sm(3, 4); + // std::stringstream buf1; + // buf1 << "csr -1 3 3 9 3 0 1 1 2 2 3 3 0 11 1 12 2 13 3 0 21 1 22 2 + // 23"; sm.fromCSR(buf1); sm.fromDense(3, 4, dense.begin()); + // } + // + // { // Values below epsilon + // + // // ... are written correctly (not written) + // nrows = 128; ncols = 256; + // UInt nnz = ncols/2; + // DenseMat dense(nrows, ncols, nnz, true, true, rng_); + // ITER_2(128, 256) + // dense.at(i,j) /= 1000; + // SparseMat sm(nrows, ncols, dense.begin()); + // std::stringstream buf; + // sm.toCSR(buf); + // std::string tag; buf >> tag; + // buf >> nrows >> ncols >> nnz; + // ITER_1(nrows) { + // buf >> nnz; + // UInt j; Real val; + // ITER_1(nnz) { + // buf >> j >> val; + // if (nupic::nearlyZero(val)) + // Test("SparseMatrix::toCSR/small values", true, false); + // } + // } + // + // // ... 
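The csr streams used in these cases follow a simple text layout in their body: nrows ncols nnz, then for each row its nonzero count followed by (column, value) pairs. A minimal writer/reader for that body layout, ignoring the version tag and size field the real toCSR emits, and using a vector-of-pairs stand-in for the matrix:

#include <cassert>
#include <cstddef>
#include <istream>
#include <ostream>
#include <sstream>
#include <utility>
#include <vector>

typedef std::vector<std::vector<std::pair<std::size_t, double> > > Rows;

// Write "nrows ncols nnz" then, per row, its count and (column, value) pairs.
void toCSR(const Rows &rows, std::size_t ncols, std::ostream &out) {
  std::size_t nnz = 0;
  for (const auto &r : rows) nnz += r.size();
  out << rows.size() << ' ' << ncols << ' ' << nnz;
  for (const auto &r : rows) {
    out << ' ' << r.size();
    for (const auto &e : r) out << ' ' << e.first << ' ' << e.second;
  }
}

void fromCSR(std::istream &in, Rows &rows, std::size_t &ncols) {
  std::size_t nrows, nnz;
  in >> nrows >> ncols >> nnz;
  rows.assign(nrows, {});
  for (std::size_t i = 0; i < nrows; ++i) {
    std::size_t nnzr; in >> nnzr;
    for (std::size_t k = 0; k < nnzr; ++k) {
      std::size_t j; double v; in >> j >> v;
      rows[i].push_back({j, v});
    }
  }
}

int main() {
  Rows a = {{{0, 1.0}, {2, 3.0}}, {}, {{1, 2.5}}}; // includes an empty row
  std::stringstream buf;
  toCSR(a, 4, buf);        // serialize ...
  Rows b; std::size_t nc = 0;
  fromCSR(buf, b, nc);     // ... and read back
  assert(b == a && nc == 4);
  return 0;
}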
are read correctly + // std::stringstream buf1; + // buf1 << "csr -1 3 4 6 " + // << "2 0 " << nupic::Epsilon/2 << " 1 1 " + // << "2 0 " << nupic::Epsilon/2 << " 1 " << nupic::Epsilon/2 << " " + // << "2 0 1 1 1"; + // SparseMat sm2(4, 4); + // sm2.fromCSR(buf1); + // } + // + // { // stress test, matching against Dense::toCSR and Dense::fromCSR + // TEST_LOOP(M) { + // + // DenseMat dense3(nrows, ncols, zr); + // SparseMat sm3(nrows, ncols, dense3.begin()); + // + // std::stringstream buf; + // sm3.toCSR(buf); + // sm3.fromCSR(buf); + // + // { + // std::stringstream str; + // str << "toCSR/fromCSR A " << nrows << "X" << ncols << "/" << zr; + // Compare(dense3, sm3, str.str().c_str()); + // } + // + // SparseMat sm4(nrows, ncols); + // std::stringstream buf1; + // sm3.toCSR(buf1); + // sm4.fromCSR(buf1); + // + // { + // std::stringstream str; + // str << "toCSR/fromCSR B " << nrows << "X" << ncols << "/" << zr; + // Compare(dense3, sm4, str.str().c_str()); + // } + // + // sm4.decompact(); + // std::stringstream buf2; + // sm3.toCSR(buf2); + // sm4.fromCSR(buf2); + // + // { + // std::stringstream str; + // str << "toCSR/fromCSR C " << nrows << "X" << ncols << "/" << zr; + // Compare(dense3, sm4, str.str().c_str()); + // } + // + // std::stringstream buf3; + // sm4.toCSR(buf3); + // sm4.fromCSR(buf3); + // + // { + // std::stringstream str; + // str << "toCSR/fromCSR D " << nrows << "X" << ncols << "/" << zr; + // Compare(dense3, sm4, str.str().c_str()); + // } + // } + // } + // + // /* + // // Exceptions + // SparseMatrix sme1(1, 1); + // + // { + // stringstream s1; + // s1 << "ijv"; + // try { + // sme1.fromCSR(s1); + // Test("SparseMatrix::fromCSR() exception 1", true, false); + // } catch (std::runtime_error&) { + // Test("SparseMatrix::fromCSR() exception 1", true, true); + // } + // } + // + // { + // stringstream s1; + // s1 << "csr -1 -1"; + // try { + // sme1.fromCSR(s1); + // Test("SparseMatrix::fromCSR() exception 2", true, false); + // } catch (std::runtime_error&) { + // Test("SparseMatrix::fromCSR() exception 2", true, true); + // } + // } + // + // { + // stringstream s1; + // s1 << "csr -1 1 -1"; + // try { + // sme1.fromCSR(s1); + // Test("SparseMatrix::fromCSR() exception 3", true, false); + // } catch (std::runtime_error&) { + // Test("SparseMatrix::fromCSR() exception 3", true, true); + // } + // } + // + // { + // stringstream s1; + // s1 << "csr -1 1 0"; + // try { + // sme1.fromCSR(s1); + // Test("SparseMatrix::fromCSR() exception 4", true, false); + // } catch (std::runtime_error&) { + // Test("SparseMatrix::fromCSR() exception 4", true, true); + // } + // } + // + // { + // stringstream s1; + // s1 << "csr -1 4 3 -1"; + // try { + // sme1.fromCSR(s1); + // Test("SparseMatrix::fromCSR() exception 5", true, false); + // } catch (std::runtime_error&) { + // Test("SparseMatrix::fromCSR() exception 5", true, true); + // } + // } + // + // { + // stringstream s1; + // s1 << "csr -1 4 3 15"; + // try { + // sme1.fromCSR(s1); + // Test("SparseMatrix::fromCSR() exception 6", true, false); + // } catch (std::runtime_error&) { + // Test("SparseMatrix::fromCSR() exception 6", true, true); + // } + // } + // + // { + // stringstream s1; + // s1 << "csr -1 2 3 1 5"; + // try { + // sme1.fromCSR(s1); + // Test("SparseMatrix::fromCSR() exception 7", true, false); + // } catch (std::runtime_error&) { + // Test("SparseMatrix::fromCSR() exception 7", true, true); + // } + // } + // + // { + // stringstream s1; + // s1 << "csr -1 2 3 1 0 1 -1"; + // try { + // sme1.fromCSR(s1); + 
// Test("SparseMatrix::fromCSR() exception 8", true, false); + // } catch (std::runtime_error&) { + // Test("SparseMatrix::fromCSR() exception 8", true, true); + // } + // } + // + // { + // stringstream s1; + // s1 << "csr -1 2 3 1 0 1 4"; + // try { + // sme1.fromCSR(s1); + // Test("SparseMatrix::fromCSR() exception 9", true, false); + // } catch (std::runtime_error&) { + // Test("SparseMatrix::fromCSR() exception 9", true, true); + // } + // } + // */ + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_dense() + // { + // UInt ncols = 5, nrows = 7, zr = 2; + // + // DenseMat dense(nrows, ncols, zr); + // DenseMat dense2(nrows+1, ncols+1, zr+1); + // + // { // fromDense + // SparseMat sparse(nrows, ncols); + // sparse.fromDense(nrows, ncols, dense.begin()); + // Compare(dense, sparse, "fromDenseMat 1"); + // } + // + // { // fromDense + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // sparse.fromDense(nrows+1, ncols+1, dense2.begin()); + // Compare(dense2, sparse, "fromDenseMat 2"); + // + // sparse.decompact(); + // sparse.fromDense(nrows, ncols, dense.begin()); + // Compare(dense, sparse, "fromDenseMat 3"); + // + // sparse.compact(); + // sparse.fromDense(nrows+1, ncols+1, dense2.begin()); + // Compare(dense2, sparse, "fromDenseMat 4"); + // + // std::vector mat((nrows+1)*(ncols+1), 0); + // + // sparse.toDense(mat.begin()); + // sparse.fromDense(nrows+1, ncols+1, mat.begin()); + // Compare(dense2, sparse, "toDenseMat 1"); + // } + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense3(nrows, ncols, zr); + // SparseMat sm3(nrows, ncols, dense3.begin()); + // std::vector mat3(nrows*ncols, 0); + // + // sm3.toDense(mat3.begin()); + // sm3.fromDense(nrows, ncols, mat3.begin()); + // + // { + // std::stringstream str; + // str << "toDense/fromDenseMat A " << nrows << "X" << ncols << "/" + // << zr + // << " - non compact"; + // Compare(dense3, sm3, str.str().c_str()); + // } + // + // sm3.compact(); + // + // { + // std::stringstream str; + // str << "toDense/fromDenseMat B " << nrows << "X" << ncols << "/" + // << zr + // << " - compact"; + // Compare(dense3, sm3, str.str().c_str()); + // } + // } + // } + // + // { // What happens if dense matrix is full? + // nrows = ncols = 10; zr = 0; + // DenseMat dense(nrows, ncols, zr); + // SparseMat sm(nrows, ncols, dense.begin()); + // std::vector mat3(nrows*ncols, 0); + // + // sm.toDense(mat3.begin()); + // sm.fromDense(nrows, ncols, mat3.begin()); + // + // Compare(dense, sm, "toDense/fromDenseMat from dense"); + // } + // + // { // What happens if dense matrix is empty? + // nrows = ncols = 10; zr = 10; + // DenseMat dense(nrows, ncols, zr); + // SparseMat sm(nrows, ncols, dense.begin()); + // std::vector mat3(nrows*ncols, 0); + // + // sm.toDense(mat3.begin()); + // sm.fromDense(nrows, ncols, mat3.begin()); + // + // Compare(dense, sm, "toDense/fromDenseMat from dense"); + // } + // + // { // What happens if there are empty rows? + // nrows = ncols = 10; zr = 2; + // DenseMat dense(nrows, ncols, zr); + // for (UInt i = 0; i < ncols; ++i) + // dense.at(2,i) = dense.at(4,i) = dense.at(9,i) = 0; + // + // SparseMat sm(nrows, ncols, dense.begin()); + // std::vector mat3(nrows*ncols, 0); + // + // sm.toDense(mat3.begin()); + // sm.fromDense(nrows, ncols, mat3.begin()); + // + // Compare(dense, sm, "toDense/fromDenseMat from dense"); + // } + // + // { // Is resizing happening correctly? 
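The exception cases above all follow the same try/catch shape: feed a malformed header to fromCSR and require that it throws. A sketch of that expect-throw pattern with a tiny stand-in parser (its header fields and checks are simplified relative to the real format):

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

void parseHeader(std::istream &in) {
  std::string tag;
  long nrows, ncols, nnz;
  if (!(in >> tag >> nrows >> ncols >> nnz) || tag != "csr")
    throw std::runtime_error("bad csr header");
  if (nrows < 0 || ncols <= 0 || nnz < 0 || nnz > nrows * ncols)
    throw std::runtime_error("bad csr sizes");
}

bool throws(const std::string &text) {
  std::stringstream s(text);
  try {
    parseHeader(s);
    return false;              // parsed cleanly: the test would fail here
  } catch (std::runtime_error &) {
    return true;               // expected path for malformed input
  }
}

int main() {
  std::cout << throws("ijv") << ' '          // wrong tag
            << throws("csr -1 -1 0") << ' '  // negative dimensions
            << throws("csr 4 3 15") << ' '   // nnz > nrows * ncols
            << throws("csr 2 3 1") << '\n';  // well-formed header: prints 0
  return 0;
}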
+ // DenseMat dense(3, 4, 2); + // SparseMat sm(3, 4, dense.begin()); + // + // DenseMat dense2(5, 5, 4); + // sm.fromDense(5, 5, dense2.begin()); + // Compare(dense2, sm, "fromDense/redim/1"); + // + // DenseMat dense3(2, 2, 2); + // sm.fromDense(2, 2, dense3.begin()); + // Compare(dense3, sm, "fromDense/redim/2"); + // + // DenseMat dense4(10, 10, 8); + // sm.fromDense(10, 10, dense4.begin()); + // Compare(dense4, sm, "fromDense/redim/3"); + // } + // + // /* + // // Exceptions + // SparseMatrix sme1(1, 1); + // + // try { + // sme1.fromDense(-1, 0, dense.begin()); + // Test("SparseMatrix::fromDense() exception 1", true, false); + // } catch (std::exception&) { + // Test("SparseMatrix::fromDense() exception 1", true, true); + // } + // + // try { + // sme1.fromDense(1, -1, dense.begin()); + // Test("SparseMatrix::fromDense() exception 3", true, false); + // } catch (std::exception&) { + // Test("SparseMatrix::fromDense() exception 3", true, true); + // } + // */ + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_compact() + // { + // UInt ncols, nrows, zr; + // ncols = 5; + // nrows = 7; + // zr = 2; + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sm4(nrows, ncols, dense.begin()); + // + // sm4.decompact(); + // Compare(dense, sm4, "decompact 1"); + // + // sm4.compact(); + // Compare(dense, sm4, "compact 1"); + // + // sm4.decompact(); + // Compare(dense, sm4, "decompact 2"); + // + // sm4.compact(); + // Compare(dense, sm4, "compact 2"); + // + // sm4.decompact(); + // sm4.decompact(); + // Compare(dense, sm4, "decompact twice"); + // + // sm4.compact(); + // sm4.compact(); + // Compare(dense, sm4, "compact twice"); + // + // SparseMat sm5(nrows, ncols, dense.begin()); + // DenseMat dense2(nrows+1, ncols+1, zr+1); + // sm5.fromDense(nrows+1, ncols+1, dense2.begin()); + // sm5.compact(); + // Compare(dense2, sm5, "compact 3"); + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense3(nrows, ncols, zr); + // SparseMat sm3(nrows, ncols, dense3.begin()); + // + // sm3.decompact(); + // + // { + // std::stringstream str; + // str << "compact/decompact A " << nrows << "X" << ncols << "/" << + // zr + // << " - non compact"; + // Compare(dense3, sm3, str.str().c_str()); + // } + // + // sm3.compact(); + // + // { + // std::stringstream str; + // str << "compact/decompact B " << nrows << "X" << ncols << "/" << + // zr + // << " - compact"; + // Compare(dense3, sm3, str.str().c_str()); + // } + // } + // } + // + // { + // nrows = ncols = 10; zr = 0; + // DenseMat dense(nrows, ncols, zr); + // SparseMat sm(nrows, ncols, dense.begin()); + // std::vector mat3(nrows*ncols, 0); + // + // sm.decompact(); + // Compare(dense, sm, "decompact on dense"); + // + // sm.compact(); + // Compare(dense, sm, "compact on dense"); + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_threshold() + // { + // UInt nrows = 7, ncols = 5, zr = 2; + // + // if (0) { // Visual tests + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // cout << "Before thresholding at 50" << endl; + // cout << sparse << endl; + // sparse.threshold(50); + // cout << "After:" << endl; + // cout << sparse << endl; + // } + // + // { + // SparseMat sm; + // DenseMat dense; + // sm.threshold(Real(1.0)); + // dense.threshold(Real(1.0)); + // Compare(dense, sm, "threshold 0A"); + // } + // + // { + 
// SparseMat sm(0, 0); + // DenseMat dense(0, 0); + // sm.threshold(Real(1.0)); + // dense.threshold(Real(1.0)); + // Compare(dense, sm, "threshold 0B"); + // } + // + // { + // SparseMat sm(nrows, ncols); + // DenseMat dense(nrows, ncols); + // sm.threshold(Real(1.0)); + // dense.threshold(Real(1.0)); + // Compare(dense, sm, "threshold 0C"); + // } + // + // { + // DenseMat dense(nrows, ncols, zr); + // for (UInt i = 0; i < nrows; ++i) + // for (UInt j = 0; j < ncols; ++j) + // dense.at(i,j) = rng_->getReal64(); + // + // SparseMat sm4c(nrows, ncols, dense.begin()); + // + // dense.threshold(Real(.8)); + // sm4c.threshold(Real(.8)); + // Compare(dense, sm4c, "threshold 1"); + // + // sm4c.decompact(); + // sm4c.compact(); + // dense.threshold(Real(.7)); + // sm4c.threshold(Real(.7)); + // Compare(dense, sm4c, "threshold 2"); + // } + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sm(nrows, ncols, dense.begin()); + // + // sm.decompact(); + // dense.threshold(Real(.8)); + // sm.threshold(Real(.8)); + // + // { + // std::stringstream str; + // str << "threshold A " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // Compare(dense, sm, str.str().c_str()); + // } + // + // sm.compact(); + // dense.threshold(Real(.7)); + // sm.threshold(Real(.7)); + // + // { + // std::stringstream str; + // str << "threshold B " << nrows << "X" << ncols << "/" << zr + // << " - compact"; + // Compare(dense, sm, str.str().c_str()); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_getRow() + // { + // UInt nrows = 5, ncols = 7, zr = 3, i = 0, k = 0; + // + // if (0) { // Tests for visual inspection + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // cout << sparse << endl; + // for (i = 0; i != nrows; ++i) { + // std::vector dense_row(ncols); + // sparse.getRowToDense(i, dense_row.begin()); + // cout << dense_row << endl; + // } + // } + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sm(nrows, ncols, dense.begin()); + // + // for (i = 0; i < nrows; ++i) { + // + // std::stringstream str; + // str << "getRowToSparseMat A " << nrows << "X" << ncols + // << "/" << zr << " " << i; + // + // std::vector ind; std::vector nz; + // sm.getRowToSparse(i, back_inserter(ind), back_inserter(nz)); + // + // std::vector d(ncols, 0); + // for (k = 0; k < ind.size(); ++k) + // d[ind[k]] = nz[k]; + // + // CompareVectors(ncols, d.begin(), dense.begin(i), + // str.str().c_str()); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_getCol() + // { + // UInt nrows = 5, ncols = 7, zr = 3, i = 0, k = 0; + // + // if (0) { // Tests for visual inspection + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // cout << sparse << endl; + // for (i = 0; i != ncols; ++i) { + // std::vector dense_col(nrows); + // sparse.getColToDense(i, dense_col.begin()); + // cout << dense_col << endl; + // } + // } + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sm(nrows, ncols, dense.begin()); + // + // for (i = 0; i < nrows; ++i) { + // + // std::stringstream str; + // str << "getRowToSparseMat A " << nrows << "X" << ncols + // << "/" << zr << " " << i; + // + // std::vector ind; std::vector nz; + // 
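unit_test_getRow checks that scattering the (index, value) pairs returned by getRowToSparse into a zeroed buffer reproduces the dense row. The same check in isolation, with plain std::vector stand-ins:

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

int main() {
  const std::size_t ncols = 7;
  std::vector<double> dense_row = {0, 4, 0, 0, 2.5, 0, 1};

  // "getRowToSparse": collect the nonzeros of the row.
  std::vector<std::pair<std::size_t, double> > nz;
  for (std::size_t j = 0; j < ncols; ++j)
    if (dense_row[j] != 0)
      nz.push_back({j, dense_row[j]});

  // Scatter them back and compare element by element (CompareVectors).
  std::vector<double> rebuilt(ncols, 0.0);
  for (const auto &e : nz)
    rebuilt[e.first] = e.second;
  assert(rebuilt == dense_row);
  return 0;
}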
sm.getRowToSparse(i, back_inserter(ind), back_inserter(nz)); + // + // std::vector d(ncols, 0); + // for (k = 0; k < ind.size(); ++k) + // d[ind[k]] = nz[k]; + // + // CompareVectors(ncols, d.begin(), dense.begin(i), + // str.str().c_str()); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_transpose() + // { + // UInt ncols, nrows, zr; + // + // { + // nrows = 8; ncols = 4; zr = ncols - 2; + // Dense dense(nrows, ncols, zr, false, true); + // Dense dense2(ncols, nrows); + // SparseMatrix sm(nrows, ncols, dense.begin()); + // SparseMatrix sm2(ncols, nrows); + // dense.transpose(dense2); + // sm.transpose(sm2); + // Compare(dense2, sm2, "transpose 1"); + // } + // + // { + // for (nrows = 1, zr = 15; nrows < 256; nrows += 25, zr = ncols/10) { + // + // ncols = nrows; + // + // DenseMat dense(nrows, ncols, zr); + // DenseMat dense2(ncols, nrows, zr); + // SparseMat sm(nrows, ncols, dense.begin()); + // SparseMat sm2(ncols, nrows, dense2.begin()); + // + // { + // std::stringstream str; + // str << "transpose A " << nrows << "X" << ncols << "/" << zr; + // + // dense.transpose(dense2); + // sm.transpose(sm2); + // + // Compare(dense2, sm2, str.str().c_str()); + // } + // + // { + // std::stringstream str; + // str << "transpose B " << nrows << "X" << ncols << "/" << zr; + // + // dense2.transpose(dense); + // sm2.transpose(sm); + // + // Compare(dense, sm, str.str().c_str()); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_addRowCol() + // { + // // addRow, compact + // UInt nrows = 5, ncols = 7, zr = 3; + // + // if (0) { // Visual, keep + // + // { // Add dense row + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // for (UInt i = 0; i != nrows; ++i) { + // std::vector nz; + // dense.getRowToDense(i, back_inserter(nz)); + // sparse.addRow(nz.begin()); + // } + // + // cout << sparse << endl; + // } + // + // { // Add sparse row + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // for (UInt i = 0; i != nrows; ++i) { + // std::vector ind; + // std::vector nz; + // dense.getRowToSparse(i, back_inserter(ind), back_inserter(nz)); + // sparse.addRow(ind.begin(), ind.end(), nz.begin()); + // } + // + // cout << sparse << endl; + // } + // + // { // Add dense col + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // for (UInt i = 0; i != ncols; ++i) { + // std::vector nz; + // dense.getColToDense(i, back_inserter(nz)); + // cout << "Adding: " << nz << endl; + // sparse.addCol(nz.begin()); + // } + // + // cout << "After adding columns:" << endl; + // cout << sparse << endl; + // } + // + // { // Add sparse col + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // for (UInt i = 0; i != ncols; ++i) { + // std::vector ind; + // std::vector nz; + // dense.getColToSparse(i, back_inserter(ind), back_inserter(nz)); + // sparse.addCol(ind.begin(), ind.end(), nz.begin()); + // } + // + // cout << sparse << endl; + // } + // } + // + // /* + // TEST_LOOP(M) { + // + // { // Add dense row + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // for (UInt i = 0; i != nrows; ++i) { + // std::vector nz; + // dense.getRowToDense(i, back_inserter(nz)); + // 
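The transpose tests transpose both representations and compare the results. A stand-in version: transpose a row-major dense buffer directly, transpose the per-row sparse view by routing each (i, j, v) to row j of the result, and check the two agree:

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

int main() {
  const std::size_t nrows = 3, ncols = 5;
  std::vector<double> a = {1, 0, 0, 2, 0,
                           0, 3, 0, 0, 0,
                           0, 0, 0, 4, 5};

  // Dense transpose: at(j, i) = a(i, j).
  std::vector<double> at(ncols * nrows, 0.0);
  for (std::size_t i = 0; i < nrows; ++i)
    for (std::size_t j = 0; j < ncols; ++j)
      at[j * nrows + i] = a[i * ncols + j];

  // Sparse transpose: every nonzero (i, j, v) lands in row j of the result.
  std::vector<std::vector<std::pair<std::size_t, double> > > t_rows(ncols);
  for (std::size_t i = 0; i < nrows; ++i)
    for (std::size_t j = 0; j < ncols; ++j)
      if (a[i * ncols + j] != 0)
        t_rows[j].push_back({i, a[i * ncols + j]});

  // Compare by scattering the sparse transpose into a dense buffer.
  std::vector<double> from_sparse(ncols * nrows, 0.0);
  for (std::size_t j = 0; j < ncols; ++j)
    for (const auto &e : t_rows[j])
      from_sparse[j * nrows + e.first] = e.second;
  assert(from_sparse == at);
  return 0;
}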
sparse.addRow(nz.begin()); + // } + // + // { + // std::stringstream str; + // str << "addRow A " << nrows << "X" << ncols << "/" << zr; + // Compare(dense, sparse, str.str().c_str()); + // } + // } + // + // { // Add sparse row + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // for (UInt i = 0; i != nrows; ++i) { + // std::vector ind; + // std::vector nz; + // dense.getRowToSparse(i, back_inserter(ind), back_inserter(nz)); + // sparse.addRow(ind.begin(), ind.end(), nz.begin()); + // } + // + // { + // std::stringstream str; + // str << "addRow B " << nrows << "X" << ncols << "/" << zr; + // Compare(dense, sparse, str.str().c_str()); + // } + // } + // + // { // Add dense col + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // for (UInt i = 0; i != ncols; ++i) { + // std::vector nz; + // dense.getColToDense(i, back_inserter(nz)); + // sparse.addCol(nz.begin()); + // } + // + // { + // std::stringstream str; + // str << "addCol A " << nrows << "X" << ncols << "/" << zr; + // Compare(dense, sparse, str.str().c_str()); + // } + // } + // + // { // Add sparse col + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // for (UInt i = 0; i != ncols; ++i) { + // std::vector ind; + // std::vector nz; + // dense.getColToSparse(i, back_inserter(ind), back_inserter(nz)); + // sparse.addCol(ind.begin(), ind.end(), nz.begin()); + // } + // + // { + // std::stringstream str; + // str << "addCol B " << nrows << "X" << ncols << "/" << zr; + // Compare(dense, sparse, str.str().c_str()); + // } + // } + // } + // */ + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(0, ncols); + // + // for (UInt i = 0; i < nrows; ++i) { + // sparse.addRow(dense.begin(i)); + // sparse.compact(); + // } + // + // sparse.decompact(); + // + // { + // std::stringstream str; + // str << "addRow C " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // Compare(dense, sparse, str.str().c_str()); + // } + // + // sparse.compact(); + // + // { + // std::stringstream str; + // str << "addRow D " << nrows << "X" << ncols << "/" << zr + // << " - compact"; + // Compare(dense, sparse, str.str().c_str()); + // } + // } + // } + // + // { // Test that negative numbers are handled correctly + // nrows = 4; ncols = 8; zr = 2; + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(0, ncols); + // for (UInt i = 0; i < nrows; ++i) + // for (UInt j = 0; j < ncols; ++j) + // dense.at(i,j) *= -1; + // + // for (UInt i = 0; i < nrows; ++i) { + // sparse.addRow(dense.begin(i)); + // sparse.compact(); + // } + // + // { + // std::stringstream str; + // str << "addRow w/ negative numbers A " + // << nrows << "X" << ncols << "/" << zr + // << " - compact"; + // Compare(dense, sparse, str.str().c_str()); + // } + // + // sparse.decompact(); + // + // { + // std::stringstream str; + // str << "addRow w/ negative numbers A " + // << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // Compare(dense, sparse, str.str().c_str()); + // } + // } + // + // // These tests compiled conditionally, because they are + // // based on asserts rather than checks + // + //#ifdef NTA_ASSERTIONS_ON + // + // /* + // { // "Dirty" rows tests + // UInt ncols = 4; + // SparseMat sm(0, ncols); + // std::vector > dirty_col(ncols); + // + // // Duplicate zeros (assertion) + // for (UInt i = 0; i < ncols; ++i) + // dirty_col[i] = make_pair(0, 0); + // try { 
+ // sm.addRow(dirty_col.begin(), dirty_col.end()); + // Test("SparseMatrix dirty cols 1", true, false); + // } catch (std::exception&) { + // Test("SparseMatrix dirty cols 1", true, true); + // } + // + // // Out of order indices (assertion) + // dirty_col[0].first = 3; + // try { + // sm.addRow(dirty_col.begin(), dirty_col.end()); + // Test("SparseMatrix dirty cols 2", true, false); + // } catch (std::exception&) { + // Test("SparseMatrix dirty cols 2", true, true); + // } + // + // // Indices out of range (assertion) + // dirty_col[0].first = 9; + // try { + // sm.addRow(dirty_col.begin(), dirty_col.end()); + // Test("SparseMatrix dirty cols 3", true, false); + // } catch (std::exception&) { + // Test("SparseMatrix dirty cols 3", true, true); + // } + // + // // Passed in zero (assertion) + // dirty_col[0].second = 0; + // try { + // sm.addRow(dirty_col.begin(), dirty_col.end()); + // Test("SparseMatrix dirty cols 4", true, false); + // } catch (std::exception&) { + // Test("SparseMatrix dirty cols 4", true, true); + // } + // } + // */ + //#endif + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_resize() + // { + // SparseMat sm; + // DenseMat dense; + // + // sm.resize(3,3); dense.resize(3,3); + // ITER_2(3,3) { + // sm.setNonZero(i,j,Real(i*3+j+1)); + // dense.at(i,j) = Real(i*3+j+1); + // } + // Compare(dense, sm, "SparseMatrix::resize() 1"); + // + // sm.resize(1,1); + // dense.resize(1,1); + // Compare(dense, sm, "SparseMatrix::resize() 2"); + // + // sm.resize(3,3); + // dense.resize(3,3); + // Compare(dense, sm, "SparseMatrix::resize() 3"); + // + // sm.resize(3,4); + // dense.resize(3,4); + // ITER_1(3) { + // sm.setNonZero(i,3,1); + // dense.at(i,3) = 1; + // } + // Compare(dense, sm, "SparseMatrix::resize() 4"); + // + // sm.resize(4,4); + // dense.resize(4,4); + // ITER_1(4) { + // sm.setNonZero(3,i,2); + // dense.at(3,i) = 2; + // } + // Compare(dense, sm, "SparseMatrix::resize() 5"); + // + // sm.resize(5,5); + // dense.resize(5,5); + // ITER_1(5) { + // sm.setNonZero(4,i,3); + // sm.setNonZero(i,4,4); + // dense.at(4,i) = 3; + // dense.at(i,4) = 4; + // } + // Compare(dense, sm, "SparseMatrix::resize() 6"); + // + // sm.resize(7,5); + // dense.resize(7,5); + // ITER_1(5) { + // sm.setNonZero(6,i,5); + // dense.at(6,i) = 5; + // } + // Compare(dense, sm, "SparseMatrix::resize() 7"); + // + // sm.resize(7, 7); + // dense.resize(7,7); + // ITER_1(7) { + // sm.setNonZero(i,6,6); + // dense.at(i,6) = 6; + // } + // Compare(dense, sm, "SparseMatrix::resize() 8"); + // + // // Stress test to see the interaction with deleteRows and deleteCols + // for (UInt i = 0; i < 20; ++i) { + // sm.resize(rng_->getUInt32(256), rng_->getUInt32(256)); + // vector del_r; + // for (UInt ii = 0; ii < sm.nRows()/4; ++ii) + // del_r.push_back(2*ii); + // sm.deleteRows(del_r.begin(), del_r.end()); + // vector del_c; + // for (UInt ii = 0; ii < sm.nCols()/4; ++ii) + // del_c.push_back(2*ii); + // sm.deleteCols(del_c.begin(), del_c.end()); + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_deleteRows() + // { + // { // Empty matrix + // UInt nrows = 3, ncols = 3; + // + // { // Empty matrix, empty del + // SparseMat sm; + // vector del; + // sm.deleteRows(del.begin(), del.end()); + // Test("SparseMatrix::deleteRows() 1", sm.nRows(), UInt(0)); + // } + // + // { // Empty matrix, empty del + // SparseMat sm(0,0); + 
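The "dirty rows" assertions reject sparse rows whose (index, value) pairs are duplicated, out of order, out of range, or carry an explicit zero. A stand-in validator expressing the same preconditions (isCleanSparseRow is an illustrative helper, not part of the nupic API):

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

bool isCleanSparseRow(const std::vector<std::pair<std::size_t, double> > &row,
                      std::size_t ncols) {
  for (std::size_t k = 0; k < row.size(); ++k) {
    if (row[k].first >= ncols) return false;                    // out of range
    if (k > 0 && row[k].first <= row[k - 1].first) return false; // dup/unsorted
    if (row[k].second == 0) return false;                       // explicit zero
  }
  return true;
}

int main() {
  const std::size_t ncols = 4;
  assert(isCleanSparseRow({{0, 1.0}, {2, 3.0}}, ncols));
  assert(!isCleanSparseRow({{0, 1.0}, {0, 2.0}}, ncols)); // duplicate index
  assert(!isCleanSparseRow({{3, 1.0}, {1, 2.0}}, ncols)); // out of order
  assert(!isCleanSparseRow({{9, 1.0}}, ncols));           // index out of range
  assert(!isCleanSparseRow({{1, 0.0}}, ncols));           // zero value
  return 0;
}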
// vector del; + // sm.deleteRows(del.begin(), del.end()); + // Test("SparseMatrix::deleteRows() 2", sm.nRows(), UInt(0)); + // } + // + // { // Empty matrix, empty del + // SparseMat sm(nrows, ncols); + // vector del; + // sm.deleteRows(del.begin(), del.end()); + // Test("SparseMatrix::deleteRows() 3", sm.nRows(), UInt(nrows)); + // } + // + // { // Empty matrix, 1 del + // SparseMat sm(nrows, ncols); + // vector del(1); del[0] = 0; + // sm.deleteRows(del.begin(), del.end()); + // Test("SparseMatrix::deleteRows() 4", sm.nRows(), UInt(2)); + // } + // + // { // Empty matrix, many dels + // SparseMat sm(nrows, ncols); + // vector del(2); del[0] = 0; del[1] = 2; + // sm.deleteRows(del.begin(), del.end()); + // Test("SparseMatrix::deleteRows() 5", sm.nRows(), UInt(1)); + // } + // } // End empty matrix + // + // { // matrix with only 1 row + // { // 1 row, 1 del + // SparseMat sm(0, 3); + // vector del(1); del[0] = 0; + // std::vector v(3); v[0] = 1.5; v[1] = 2.5; v[2] = 3.5; + // + // sm.addRow(v.begin()); + // sm.deleteRows(del.begin(), del.end()); + // Test("SparseMatrix::deleteRows() 1 row A", sm.nRows(), UInt(0)); + // + // // Test that it is harmless to delete an empty matrix + // sm.deleteRows(del.begin(), del.end()); + // Test("SparseMatrix::deleteRows() 1 row B", sm.nRows(), UInt(0)); + // + // sm.addRow(v.begin()); + // sm.deleteRows(del.begin(), del.end()); + // Test("SparseMatrix::deleteRows() 1 row C", sm.nRows(), UInt(0)); + // + // // Again, test that it is harmless to delete an empty matrix + // sm.deleteRows(del.begin(), del.end()); + // Test("SparseMatrix::deleteRows() 1 row D", sm.nRows(), UInt(0)); + // } + // + // { // PLG-68: was failing when adding again because + // // deleteRows was not updating nrows_max_ properly + // SparseMatrix tam; + // vector x(4), del(1, 0); + // x[0] = .5; x[1] = .75; x[2] = 1.0; x[3] = 1.25; + // + // tam.resize(1, 4); + // tam.elementRowApply(0, std::plus(), x.begin()); + // tam.deleteRows(del.begin(), del.end()); + // + // tam.resize(1, 4); + // tam.elementRowApply(0, std::plus(), x.begin()); + // } + // } + // + // { + // UInt nrows, ncols, zr; + // + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // + // { // Empty del + // SparseMat sm(nrows, ncols, dense.begin()); + // vector del; + // sm.deleteRows(del.begin(), del.end()); + // Test("SparseMatrix::deleteRows() 6A", sm.nRows(), nrows); + // } + // + // { // Rows of all zeros 1 + // if (nrows > 2) { + // DenseMat dense2(nrows, ncols, zr); + // ITER_1(nrows) { + // if (i % 2 == 0) { + // for (UInt j = 0; j < ncols; ++j) + // dense2.at(i,j) = 0; + // } + // } + // SparseMat sm(nrows, ncols, dense2.begin()); + // vector del; + // for (UInt i = 2; i < nrows-2; i += 2) + // del.push_back(i); + // sm.deleteRows(del.begin(), del.end()); + // dense2.deleteRows(del.begin(), del.end()); + // Compare(dense2, sm, "SparseMatrix::deleteRows() 6B"); + // } + // } + // + // { // Rows of all zeros 2 + // if (nrows > 2) { + // DenseMat dense2(nrows, ncols, zr); + // ITER_1(nrows) { + // if (i % 2 == 0) { + // for (UInt j = 0; j < ncols; ++j) + // dense2.at(i,j) = 0; + // } + // } + // SparseMat sm(nrows, ncols, dense2.begin()); + // vector del; + // for (UInt i = 1; i < nrows-2; i += 2) + // del.push_back(i); + // sm.deleteRows(del.begin(), del.end()); + // dense2.deleteRows(del.begin(), del.end()); + // Compare(dense2, sm, "SparseMatrix::deleteRows() 6C"); + // } + // } + // + // { // Many dels contiguous + // if (nrows > 2) { + // SparseMat sm(nrows, ncols, dense.begin()); + // 
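The deleteRows loops remove a set of row indices from both representations and compare what remains. With a dense reference held as a vector of rows, the deletion reduces to erasing entries from the back so earlier indices stay valid, which is essentially what the Dense test helper does:

#include <cassert>
#include <cstddef>
#include <vector>

void deleteRows(std::vector<std::vector<double> > &m,
                const std::vector<std::size_t> &del) {
  // `del` is sorted ascending; erase from the back to keep indices valid.
  for (std::size_t k = del.size(); k > 0; --k)
    m.erase(m.begin() + del[k - 1]);
}

int main() {
  // 5 x 3 matrix stored as rows; each row is tagged by its first entry.
  std::vector<std::vector<double> > m;
  for (std::size_t i = 0; i < 5; ++i)
    m.push_back({double(i), 0.0, 1.0});

  std::vector<std::size_t> del = {0, 2, 4}; // delete every other row
  deleteRows(m, del);

  assert(m.size() == 2);
  assert(m[0][0] == 1.0 && m[1][0] == 3.0); // rows 1 and 3 remain
  return 0;
}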
DenseMat dense2(nrows, ncols, zr); + // vector del; + // for (UInt i = 2; i < nrows-2; ++i) + // del.push_back(i); + // sm.deleteRows(del.begin(), del.end()); + // dense2.deleteRows(del.begin(), del.end()); + // Compare(dense2, sm, "SparseMatrix::deleteRows() 6D"); + // } + // } + // + // { // Make sure we stop at the end of the dels! + // if (nrows > 2) { + // SparseMat sm(nrows, ncols, dense.begin()); + // DenseMat dense2(nrows, ncols, zr); + // UInt* del = new UInt[nrows-1]; + // for (UInt i = 0; i < nrows-1; ++i) + // del[i] = i + 1; + // sm.deleteRows(del, del + nrows-2); + // dense2.deleteRows(del, del + nrows-2); + // Compare(dense2, sm, "SparseMatrix::deleteRows() 6E"); + // delete [] del; + // } + // } + // + // { // Many dels discontiguous + // SparseMat sm(nrows, ncols, dense.begin()); + // DenseMat dense2(nrows, ncols, zr); + // vector del; + // for (UInt i = 0; i < nrows; i += 2) + // del.push_back(i); + // sm.deleteRows(del.begin(), del.end()); + // dense2.deleteRows(del.begin(), del.end()); + // Compare(dense2, sm, "SparseMatrix::deleteRows() 7"); + // } + // + // { // All rows + // SparseMat sm(nrows, ncols, dense.begin()); + // vector del; + // for (UInt i = 0; i < nrows; ++i) + // del.push_back(i); + // sm.deleteRows(del.begin(), del.end()); + // Test("SparseMatrix::deleteRows() 8", sm.nRows(), UInt(0)); + // } + // + // /* + // { // More than all rows => exception in assert mode + // SparseMat sm(nrows, ncols, dense.begin()); + // vector del; + // for (UInt i = 0; i < 2*nrows; ++i) + // del.push_back(i); + // sm.deleteRows(del.begin(), del.end()); + // Test("SparseMatrix::deleteRows() 9", sm.nRows(), UInt(0)); + // } + // */ + // + // { // Several dels in a row till empty + // SparseMat sm(nrows, ncols, dense.begin()); + // for (UInt i = 0; i < nrows; ++i) { + // vector del(1); del[0] = 0; + // sm.deleteRows(del.begin(), del.end()); + // Test("SparseMatrix::deleteRows() 10", sm.nRows(), + // UInt(nrows-i-1)); + // } + // } + // + // { // deleteRows and re-resize it + // SparseMat sm(nrows, ncols, dense.begin()); + // vector del(1); del[0] = nrows-1; + // sm.deleteRows(del.begin(), del.end()); + // sm.resize(nrows, ncols); + // Test("SparseMatrix::deleteRows() 11", sm.nRows(), UInt(nrows)); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_deleteCols() + // { + // { // Empty matrix + // UInt nrows = 3, ncols = 3; + // + // { // Empty matrix, empty del + // SparseMat sm(nrows, ncols); + // vector del; + // sm.deleteCols(del.begin(), del.end()); + // Test("SparseMatrix::deleteCols() 1", sm.nCols(), UInt(3)); + // } + // + // { // Empty matrix, 1 del + // SparseMat sm(nrows, ncols); + // vector del(1); del[0] = 0; + // sm.deleteCols(del.begin(), del.end()); + // Test("SparseMatrix::deleteCols() 2", sm.nCols(), UInt(2)); + // } + // + // { // Empty matrix, many dels + // SparseMat sm(nrows, ncols); + // vector del(2); del[0] = 0; del[1] = 2; + // sm.deleteCols(del.begin(), del.end()); + // Test("SparseMatrix::deleteCols() 3", sm.nCols(), UInt(1)); + // } + // } // End empty matrix + // + // { // For visual inspection + // UInt nrows = 3, ncols = 5; + // DenseMat dense(nrows, ncols, 2); + // SparseMat sm(nrows, ncols, dense.begin()); + // //cout << sm << endl; + // vector del; del.push_back(0); + // sm.deleteCols(del.begin(), del.end()); + // //cout << sm << endl; + // sm.deleteCols(del.begin(), del.end()); + // //cout << sm << endl; + // } + // + // { // deleteCols 
on matrix of all-zeros + // SparseMat sm(7, 3); + // vector row(3, 0); + // for (UInt i = 0; i < 7; ++i) + // sm.addRow(row.begin()); + // //cout << sm << endl << endl; + // vector del(1, 0); + // sm.deleteCols(del.begin(), del.end()); + // //cout << sm << endl; + // } + // + // { + // UInt nrows, ncols, zr; + // + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // + // { // Empty del + // SparseMat sm(nrows, ncols, dense.begin()); + // vector del; + // sm.deleteCols(del.begin(), del.end()); + // Test("SparseMatrix::deleteCols() 4", sm.nCols(), ncols); + // } + // + // { // Many dels contiguous + // SparseMat sm(nrows, ncols, dense.begin()); + // DenseMat dense2(nrows, ncols, zr); + // vector del; + // if (ncols > 2) { + // for (UInt i = 2; i < ncols-2; ++i) + // del.push_back(i); + // sm.deleteCols(del.begin(), del.end()); + // dense2.deleteCols(del.begin(), del.end()); + // Compare(dense2, sm, "SparseMatrix::deleteCols() 6"); + // } + // } + // + // { // Many dels discontiguous + // SparseMat sm(nrows, ncols, dense.begin()); + // DenseMat dense2(nrows, ncols, zr); + // vector del; + // for (UInt i = 0; i < ncols; i += 2) + // del.push_back(i); + // sm.deleteCols(del.begin(), del.end()); + // dense2.deleteCols(del.begin(), del.end()); + // Compare(dense2, sm, "SparseMatrix::deleteCols() 7"); + // } + // + // { // All rows + // SparseMat sm(nrows, ncols, dense.begin()); + // vector del; + // for (UInt i = 0; i < ncols; ++i) + // del.push_back(i); + // sm.deleteCols(del.begin(), del.end()); + // Test("SparseMatrix::deleteCols() 8", sm.nCols(), UInt(0)); + // } + // + // { // More than all rows => exception in assert mode + // /* + // SparseMat sm(nrows, ncols, dense.begin()); + // vector del; + // for (UInt i = 0; i < 2*ncols; ++i) + // del.push_back(i); + // sm.deleteCols(del.begin(), del.end()); + // Test("SparseMatrix::deleteCols() 9", sm.nCols(), UInt(0)); + // */ + // } + // + // { // Several dels in a row till empty + // SparseMat sm(nrows, ncols, dense.begin()); + // for (UInt i = 0; i < ncols; ++i) { + // vector del(1); del[0] = 0; + // sm.deleteCols(del.begin(), del.end()); + // Test("SparseMatrix::deleteCols() 10", sm.nCols(), + // UInt(ncols-i-1)); + // } + // } + // + // { // deleteCols and re-resize it + // SparseMat sm(nrows, ncols, dense.begin()); + // vector del(1); del[0] = ncols-1; + // sm.deleteCols(del.begin(), del.end()); + // sm.resize(nrows, ncols); + // Test("SparseMatrix::deleteCols() 11", sm.nCols(), UInt(ncols)); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_set() + // { + // UInt nrows, ncols, nnzr; + // + // if (0) { // Visual tests + // + // // setZero + // nrows = 5; ncols = 7; nnzr = 3; + // DenseMat dense(nrows, ncols, nnzr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // cout << "Initial matrix" << endl; + // cout << sparse << endl; + // + // cout << endl << "Setting all elements to zero one by one" << endl; + // ITER_2(nrows, ncols) + // sparse.setZero(i, j); + // cout << "After:" << endl << sparse << endl; + // + // // setNonZero + // cout << endl << "Setting all elements one by one to:" << endl; + // cout << dense << endl; + // ITER_2(nrows, ncols) { + // sparse.setNonZero(i, j, dense.at(i,j)+1); + // dense.at(i,j) = dense.at(i,j) + 1; + // } + // cout << "After:" << endl << sparse << endl; + // + // // set + // cout << endl << "Setting all elements" << endl; + // ITER_2(nrows, ncols) { + // Real val = (Real) 
((i+j) % 5); + // sparse.set(i, j, val); + // dense.at(i,j) = val; + // } + // cout << "After:" << endl << sparse << endl; + // cout << "Should be:" << endl << dense << endl; + // + // } // End visual tests + // + // // Automated tests for set(i,j,val), which exercises both + // // setNonZero and setToZero + // for (nrows = 1; nrows < 64; nrows += 3) + // for (ncols = 1; ncols < 64; ncols += 3) + // { + // SparseMat sm(nrows, ncols); + // DenseMat dense(nrows, ncols); + // + // ITER_2(nrows, ncols) { + // Real val = Real((i*ncols+j+1)%5); + // sm.set(i, j, val); + // dense.at(i, j) = val; + // } + // bool correct = true; + // ITER_2(nrows, ncols) { + // Real val = Real((i*ncols+j+1)%5); + // if (sm.get(i, j) != val) + // correct = false; + // } + // Test("SparseMatrix set/get 1", correct, true); + // + // ITER_1(nrows) { + // dense.at(i, 0) = Real(i+1); + // sm.set(i, 0, Real(i+1)); + // } + // Compare(dense, sm, "SparseMatrix set/get 2"); + // + // ITER_1(ncols) { + // dense.at(0, i) = Real(i+1); + // sm.set(0, i, Real(i+1)); + // } + // Compare(dense, sm, "SparseMatrix set/get 3"); + // + // sm.set(nrows-1, ncols-1, 1); + // dense.at(nrows-1, ncols-1) = 1; + // Compare(dense, sm, "SparseMatrix set/get 4"); + // sm.set(nrows-1, ncols-1, 2); + // dense.at(nrows-1, ncols-1) = 2; + // Compare(dense, sm, "SparseMatrix set/get 5"); + // + // for (UInt k = 0; k != 20; ++k) { + // UInt i = rng_->getUInt32(nrows); + // UInt j = rng_->getUInt32(ncols); + // Real val = Real(1+rng_->getUInt32()); + // sm.set(i, j, Real(val)); + // Test("SparseMatrix set/get 7", sm.get(i, j), val); + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_setRowColToZero() + // { + // UInt nrows, ncols, zr; + // + // if (0) { // Visual tests + // + // // setRowToZero + // nrows = 5; ncols = 7; zr = 3; + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // cout << "Initial matrix" << endl; + // cout << sparse << endl; + // + // cout << endl << "Setting all rows to zero" << endl; + // for (UInt i = 0; i != nrows; ++i) { + // cout << "isRowZero(" << i << ")= " + // << (sparse.isRowZero(i) ? "YES" : "NO") + // << endl; + // sparse.setRowToZero(i); + // cout << "Zeroing row " << i << ":" << endl + // << sparse << endl; + // cout << "isRowZero(" << i << ")= " + // << (sparse.isRowZero(i) ? "YES" : "NO") + // << endl; + // cout << endl; + // } + // + // // setColToZero + // cout << endl << "Setting all columns to zero - 1" << endl; + // ITER_2(nrows, ncols) + // sparse.set(i, j, dense.at(i,j)); + // cout << "Initially: " << endl << sparse << endl; + // for (UInt j = 0; j != ncols; ++j) { + // cout << "isColZero(" << j << ")= " + // << (sparse.isColZero(j) ? "YES" : "NO") + // << endl; + // sparse.setColToZero(j); + // cout << "Zeroing column " << j << ":" << endl + // << sparse << endl; + // cout << "isColZero(" << j << ")= " + // << (sparse.isColZero(j) ? "YES" : "NO") + // << endl; + // cout << endl; + // } + // + // // Again, with a dense matrix, so we can see what happens + // // to the first and last columns + // cout << endl << "Setting all columns to zero - 2" << endl; + // ITER_2(nrows, ncols) + // sparse.set(i,j,(Real)(i+j)); + // cout << "Initially: " << endl << sparse << endl; + // for (UInt j = 0; j != ncols; ++j) { + // cout << "isColZero(" << j << ")= " + // << (sparse.isColZero(j) ? 
"YES" : "NO") + // << endl; + // sparse.setColToZero(j); + // cout << "Zeroing column " << j << ":" << endl + // << sparse << endl; + // cout << "isColZero(" << j << ")= " + // << (sparse.isColZero(j) ? "YES" : "NO") + // << endl; + // cout << endl; + // } + // } // End visual tests + // + // // Automated tests + // for (nrows = 0; nrows < 16; nrows += 3) + // for (ncols = 0; ncols < 16; ncols += 3) + // for (zr = 0; zr < 16; zr += 3) + // { + // { // compact - remove rows + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // for (UInt i = 0; i != nrows; ++i) { + // sparse.setRowToZero(i); + // dense.setRowToZero(i); + // Compare(dense, sparse, "SparseMatrix setRowToZero 1"); + // } + // } + // + // { // decompact - remove rows + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // sparse.decompact(); + // + // for (UInt i = 0; i != nrows; ++i) { + // sparse.setRowToZero(i); + // dense.setRowToZero(i); + // Compare(dense, sparse, "SparseMatrix setRowToZero 2"); + // } + // } + // + // { // compact - remove columns + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // for (UInt j = 0; j != ncols; ++j) { + // sparse.setColToZero(j); + // dense.setColToZero(j); + // Compare(dense, sparse, "SparseMatrix setColToZero 1"); + // } + // } + // + // { // decompact - remove columns + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // for (UInt j = 0; j != ncols; ++j) { + // sparse.setColToZero(j); + // dense.setColToZero(j); + // Compare(dense, sparse, "SparseMatrix setColToZero 2"); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_vecMaxProd() + // { + // UInt ncols, nrows, zr, i; + // ncols = 5; + // nrows = 7; + // zr = 2; + // + // DenseMat dense(nrows, ncols, zr); + // + // std::vector x(ncols), y(nrows, 0), yref(nrows, 0); + // for (i = 0; i < ncols; ++i) + // x[i] = Real(i); + // + // dense.vecMaxProd(x.begin(), yref.begin()); + // + // SparseMat smnc(nrows, ncols, dense.begin()); + // smnc.decompact(); + // smnc.vecMaxProd(x.begin(), y.begin()); + // CompareVectors(nrows, y.begin(), yref.begin(), "vecMaxProd non compact + // 1"); + // + // smnc.compact(); + // std::fill(y.begin(), y.end(), Real(0)); + // smnc.vecMaxProd(x.begin(), y.begin()); + // CompareVectors(nrows, y.begin(), yref.begin(), "vecMaxProd compact 1"); + // + // SparseMat smc(nrows, ncols, dense.begin()); + // std::fill(y.begin(), y.end(), Real(0)); + // smc.vecMaxProd(x.begin(), y.begin()); + // CompareVectors(nrows, y.begin(), yref.begin(), "vecMaxProd compact 2"); + // + // { + // TEST_LOOP(M) { + // + // Dense dense2(nrows, ncols, zr); + // SparseMatrix sm2(nrows, ncols, dense2.begin()); + // + // std::vector x2(ncols, 0), yref2(nrows, 0), y2(nrows, 0); + // for (i = 0; i < ncols; ++i) + // x2[i] = Real(i); + // + // sm2.decompact(); + // dense2.vecMaxProd(x2.begin(), yref2.begin()); + // sm2.vecMaxProd(x2.begin(), y2.begin()); + // { + // std::stringstream str; + // str << "vecMaxProd A " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // CompareVectors(nrows, y2.begin(), yref2.begin(), + // str.str().c_str()); + // } + // + // sm2.compact(); + // std::fill(y2.begin(), y2.end(), Real(0)); + // sm2.vecMaxProd(x2.begin(), y2.begin()); + // { + // std::stringstream str; + // str << "vecMaxProd B " << nrows << "X" 
<< ncols << "/" << zr + // << " - compact"; + // CompareVectors(nrows, y2.begin(), yref2.begin(), + // str.str().c_str()); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_vecProd() + // { + // UInt ncols = 5, nrows = 7, zr = 2; + // + // DenseMat dense(nrows, ncols, zr); + // + // std::vector x(ncols), y(nrows, 0), yref(nrows, 0); + // for (UInt i = 0; i < ncols; ++i) + // x[i] = Real(i); + // + // dense.rightVecProd(x.begin(), yref.begin()); + // + // SparseMat smnc(nrows, ncols, dense.begin()); + // smnc.decompact(); + // smnc.rightVecProd(x.begin(), y.begin()); + // CompareVectors(nrows, y.begin(), yref.begin(), "rightVecProd non compact + // 1"); + // + // smnc.compact(); + // std::fill(y.begin(), y.end(), Real(0)); + // smnc.rightVecProd(x.begin(), y.begin()); + // CompareVectors(nrows, y.begin(), yref.begin(), "rightVecProd compact + // 1"); + // + // SparseMat smc(nrows, ncols, dense.begin()); + // std::fill(y.begin(), y.end(), Real(0)); + // smc.rightVecProd(x.begin(), y.begin()); + // CompareVectors(nrows, y.begin(), yref.begin(), "rightVecProd compact + // 2"); + // + // { + // TEST_LOOP(M) { + // + // Dense dense2(nrows, ncols, zr); + // SparseMatrix sm2(nrows, ncols, dense2.begin()); + // + // std::vector x2(ncols, 0), yref2(nrows, 0), y2(nrows, 0); + // for (UInt i = 0; i < ncols; ++i) + // x2[i] = Real(i); + // + // sm2.decompact(); + // dense2.rightVecProd(x2.begin(), yref2.begin()); + // sm2.rightVecProd(x2.begin(), y2.begin()); + // { + // std::stringstream str; + // str << "rightVecProd A " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // CompareVectors(nrows, y2.begin(), yref2.begin(), + // str.str().c_str()); + // } + // + // sm2.compact(); + // std::fill(y2.begin(), y2.end(), Real(0)); + // sm2.rightVecProd(x2.begin(), y2.begin()); + // { + // std::stringstream str; + // str << "rightVecProd B " << nrows << "X" << ncols << "/" << zr + // << " - compact"; + // CompareVectors(nrows, y2.begin(), yref2.begin(), + // str.str().c_str()); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_axby() + // { + // UInt ncols, nrows, zr, i; + // ncols = 5; + // nrows = 7; + // zr = 2; + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sm4c(nrows, ncols, dense.begin()); + // + // std::vector x(ncols, 0); + // for (i = 0; i < ncols; ++i) + // x[i] = Real(20*i + 1); + // + // { // compact, b = 0 + // dense.axby(3, .5, 0, x.begin()); + // sm4c.axby(3, .5, 0, x.begin()); + // Compare(dense, sm4c, "axby, b = 0"); + // } + // + // { // compact, a = 0, with reallocation + // dense.axby(2, 0, .5, x.begin()); + // sm4c.axby(2, 0, .5, x.begin()); + // Compare(dense, sm4c, "axby, a = 0 /1"); + // } + // + // { // compact, a = 0, without reallocation + // dense.axby(3, 0, .5, x.begin()); + // sm4c.axby(3, 0, .5, x.begin()); + // Compare(dense, sm4c, "axby, a = 0 /2"); + // } + // + // { // compact, a != 0, b != 0, without reallocation + // dense.axby(3, .5, .5, x.begin()); + // sm4c.axby(3, .5, .5, x.begin()); + // Compare(dense, sm4c, "axby, a, b != 0 /1"); + // } + // + // { // compact, a != 0, b != 0, with reallocation + // dense.axby(4, .5, .5, x.begin()); + // sm4c.axby(4, .5, .5, x.begin()); + // Compare(dense, sm4c, "axby, a, b != 0 /2"); + // } + // + // { + // TEST_LOOP(M) { + // + // Dense dense2(nrows, ncols, zr); + // SparseMatrix 
sm2(nrows, ncols, dense2.begin()); + // + // std::vector x2(ncols, 0), yref2(nrows, 0), y2(nrows, 0); + // for (i = 0; i < ncols; ++i) + // x2[i] = Real(i); + // + // for (i = 0; i < nrows; i += 5) { + // + // dense2.axby(i, (Real).6, (Real).4, x2.begin()); + // sm2.axby(i, (Real).6, (Real).4, x2.begin()); + // { + // std::stringstream str; + // str << "axby " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // Compare(dense2, sm2, str.str().c_str()); + // } + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_axby_3() + // { + // UInt ncols, nrows, zr, i; + // ncols = 5; + // nrows = 7; + // zr = 2; + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sm4c(nrows, ncols, dense.begin()); + // + // std::vector x(ncols, 0); + // for (i = 0; i < ncols; ++i) + // x[i] = i % 2 == 0 ? Real(20*i + 1) : Real(0); + // + // { // compact, b = 0 + // dense.axby(.5, 0, x.begin()); + // sm4c.axby(.5, 0, x.begin()); + // Compare(dense, sm4c, "axby, b = 0"); + // } + // + // { // compact, a = 0, with reallocation + // dense.axby(0, .5, x.begin()); + // sm4c.axby(0, .5, x.begin()); + // Compare(dense, sm4c, "axby, a = 0 /1"); + // } + // + // { // compact, a = 0, without reallocation + // dense.axby(0, .5, x.begin()); + // sm4c.axby(0, .5, x.begin()); + // Compare(dense, sm4c, "axby, a = 0 /2"); + // } + // + // { // compact, a != 0, b != 0, without reallocation + // dense.axby(.5, .5, x.begin()); + // sm4c.axby(.5, .5, x.begin()); + // Compare(dense, sm4c, "axby, a, b != 0 /1"); + // } + // + // { // compact, a != 0, b != 0, with reallocation + // dense.axby(.5, .5, x.begin()); + // sm4c.axby(.5, .5, x.begin()); + // Compare(dense, sm4c, "axby, a, b != 0 /2"); + // } + // + // { + // TEST_LOOP(M) { + // + // Dense dense2(nrows, ncols, zr); + // SparseMatrix sm2(nrows, ncols, dense2.begin()); + // + // std::vector x2(ncols, 0), yref2(nrows, 0), y2(nrows, 0); + // for (i = 0; i < ncols; ++i) + // x2[i] = i % 2 == 0 ? 
Real(i) : Real(0); + // + // dense2.axby((Real).6, (Real).4, x2.begin()); + // sm2.axby((Real).6, (Real).4, x2.begin()); + // { + // std::stringstream str; + // str << "axby " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // Compare(dense2, sm2, str.str().c_str()); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_rowMax() + // { + // UInt ncols, nrows, zr, i; + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense2(nrows, ncols, zr); + // SparseMat sm2(nrows, ncols, dense2.begin()); + // + // std::vector x2(ncols, 0), yref2(nrows, 0), y2(nrows, 0); + // for (i = 0; i < ncols; ++i) + // x2[i] = Real(i); + // + // sm2.decompact(); + // dense2.threshold(Real(1./nrows)); + // dense2.xMaxAtNonZero(x2.begin(), y2.begin()); + // sm2.threshold(Real(1./nrows)); + // sm2.vecMaxAtNZ(x2.begin(), yref2.begin()); + // + // { + // std::stringstream str; + // str << "xMaxAtNonZero A " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // CompareVectors(nrows, y2.begin(), yref2.begin(), + // str.str().c_str()); + // } + // + // sm2.compact(); + // dense2.xMaxAtNonZero(x2.begin(), y2.begin()); + // sm2.vecMaxAtNZ(x2.begin(), yref2.begin()); + // { + // std::stringstream str; + // str << "xMaxAtNonZero B " << nrows << "X" << ncols << "/" << zr + // << " - compact"; + // CompareVectors(nrows, y2.begin(), yref2.begin(), + // str.str().c_str()); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_maxima() + // { + // UInt ncols, nrows, zr; + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // std::vector > + // rowMaxDense(nrows), rowMaxSparse(nrows), + // colMaxDense(ncols), colMaxSparse(ncols); + // + // dense.rowMax(rowMaxDense.begin()); + // dense.colMax(colMaxDense.begin()); + // sparse.rowMax(rowMaxSparse.begin()); + // sparse.colMax(colMaxSparse.begin()); + // + // { + // std::stringstream str; + // str << "rowMax " << nrows << "X" << ncols << "/" << zr; + // Compare(rowMaxDense, rowMaxSparse, str.str().c_str()); + // } + // + // { + // std::stringstream str; + // str << "colMax " << nrows << "X" << ncols << "/" << zr; + // Compare(colMaxDense, colMaxSparse, str.str().c_str()); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_normalize() + // { + // UInt nrows = 7, ncols = 5, zr = 2; + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // if (0) { // Visual tests + // + // cout << "Before normalizing rows: " << endl; + // cout << sparse << endl; + // dense.normalizeRows(); + // sparse.normalizeRows(); + // cout << "After normalizing rows: " << endl; + // cout << "Sparse: " << endl << sparse << endl; + // cout << "Dense: " << endl << dense << endl; + // + // cout << "Before normalizing columns: " << endl; + // cout << sparse << endl; + // dense.normalizeCols(); + // sparse.normalizeCols(); + // cout << "After normalizing columns: " << endl; + // cout << "Sparse: " << endl << sparse << endl; + // cout << "Dense: " << endl << dense << endl; + // + // } + // + // if (1) // Automated tests + // { + // TEST_LOOP(M) { + // + // Dense dense2(nrows, ncols, zr); + // SparseMatrix sm2(nrows, ncols, dense2.begin()); 
+ // + // dense2.threshold(double(1./nrows)); + // dense2.normalizeRows(true); + // sm2.decompact(); + // sm2.threshold(double(1./nrows)); + // sm2.normalizeRows(true); + // + // { + // std::stringstream str; + // str << "normalizeRows A " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // Compare(dense2, sm2, str.str().c_str()); + // } + // + // dense2.normalizeRows(true); + // sm2.compact(); + // sm2.normalizeRows(true); + // + // { + // std::stringstream str; + // str << "normalizeRows B " << nrows << "X" << ncols << "/" << zr + // << " - compact"; + // Compare(dense2, sm2, str.str().c_str()); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_rowProd() + // { + // UInt ncols, nrows, zr, i; + // + // { + // TEST_LOOP(M) { + // + // Dense dense2(nrows, ncols, zr); + // SparseMatrix sm2(nrows, ncols, dense2.begin()); + // + // std::vector x2(ncols, 0), yref2(nrows, 0), y2(nrows, 0); + // for (i = 0; i < ncols; ++i) + // x2[i] = double(i)/double(ncols); + // + // sm2.decompact(); + // dense2.threshold(1./double(nrows)); + // dense2.rowProd(x2.begin(), y2.begin()); + // sm2.threshold(1./double(nrows)); + // sm2.rightVecProdAtNZ(x2.begin(), yref2.begin()); + // + // { + // std::stringstream str; + // str << "rowProd A " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // CompareVectors(nrows, y2.begin(), yref2.begin(), + // str.str().c_str()); + // } + // + // sm2.compact(); + // dense2.rowProd(x2.begin(), y2.begin()); + // sm2.rightVecProdAtNZ(x2.begin(), yref2.begin()); + // { + // std::stringstream str; + // str << "rowProd B " << nrows << "X" << ncols << "/" << zr + // << " - compact"; + // CompareVectors(nrows, y2.begin(), yref2.begin(), + // str.str().c_str()); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_lerp() + // { + // UInt ncols, nrows, zr; + // nrows = 5; ncols = 7; zr = 4; + // + // { + // DenseMat dense(nrows, ncols, zr); + // DenseMat denseB(nrows, ncols, zr); + // for (UInt i = 0; i < nrows; ++i) + // for (UInt j = 0; j < ncols; ++j) + // denseB.at(i,j) += 2; + // + // SparseMat sm(nrows, ncols, dense.begin()); + // SparseMat smB(nrows, ncols, denseB.begin()); + // + // Real a, b; + // a = b = 1; + // + // dense.lerp(a, b, denseB); + // sm.lerp(a, b, smB); + // + // std::stringstream str; + // str << "lerp " << nrows << "X" << ncols << "/" << zr + // << " " << a << " " << b; + // Compare(dense, sm, str.str().c_str()); + // } + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // DenseMat denseB(nrows, ncols, zr); + // for (UInt i = 0; i < nrows; ++i) + // for (UInt j = 0; j < ncols; ++j) + // denseB.at(i,j) += 2; + // + // SparseMat sm(nrows, ncols, dense.begin()); + // SparseMat smB(nrows, ncols, denseB.begin()); + // + // for (Real a = -2; a < 2; a += 1) { + // for (Real b = -2; b < 2; b += 1) { + // dense.lerp(a, b, denseB); + // sm.lerp(a, b, smB); + // std::stringstream str; + // str << "lerp " << nrows << "X" << ncols << "/" << zr + // << " " << a << " " << b; + // Compare(dense, sm, str.str().c_str()); + // } + // } + // } + // } + // + //#ifdef NTA_ASSERTIONS_ON + // nrows = 5; ncols = 7; zr = 4; + // // Exceptions + // { + // DenseMat dense(nrows, ncols, zr); + // DenseMat denseB(nrows+1, ncols, zr); + // SparseMat sm(nrows, ncols, dense.begin()); + // SparseMat smB(nrows+1, 
ncols, denseB.begin()); + // + // try { + // sm.lerp(1, 1, smB); + // Test("lerp exception 1", 0, 1); + // } catch (std::runtime_error&) { + // Test("lerp exception 1", 1, 1); + // } + // } + // + // { + // DenseMat dense(nrows, ncols, zr); + // DenseMat denseB(nrows, ncols+1, zr); + // SparseMat sm(nrows, ncols, dense.begin()); + // SparseMat smB(nrows, ncols+1, denseB.begin()); + // + // try { + // sm.lerp(1, 1, smB); + // Test("lerp exception 2", 0, 1); + // } catch (std::runtime_error&) { + // Test("lerp exception 2", 1, 1); + // } + // } + //#endif + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_small_values() + // { + // UInt nrows, ncols, zr; + // + // { + // nrows = 200; ncols = 100; zr = ncols - 64; + // DenseMat dense(nrows, ncols, zr, true, true, rng_); + // SparseMat sm(nrows, ncols, dense.begin()); + // DenseMat A(nrows, ncols); + // + // sm.toDense(A.begin()); + // sm.fromDense(nrows, ncols, A.begin()); + // Compare(dense, sm, "to/from Dense, small values"); + // } + // + // { + // nrows = 200; ncols = 100; zr = ncols - 64; + // DenseMat dense(nrows, ncols, zr, true, true, rng_); + // SparseMat sm(nrows, ncols, dense.begin()); + // std::stringstream str1; + // sm.toCSR(str1); + // sm.fromCSR(str1); + // Compare(dense, sm, "to/from CSR, small values"); + // } + // + // { + // nrows = 200; ncols = 100; zr = ncols - 64; + // DenseMat dense(nrows, ncols, zr, true, true, rng_); + // SparseMat sm(nrows, ncols, dense.begin()); + // sm.compact(); + // Compare(dense, sm, "compact, small values"); + // } + // + // { + // nrows = 200; ncols = 100; zr = ncols - 64; + // Dense dense(nrows, ncols, zr, true, true, rng_); + // SparseMatrix sm(nrows, ncols, dense.begin()); + // sm.threshold(4 * nupic::Epsilon); + // dense.threshold(4 * nupic::Epsilon); + // Compare(dense, sm, "threshold, small values 1"); + // sm.threshold(2 * nupic::Epsilon); + // dense.threshold(2 * nupic::Epsilon); + // Compare(dense, sm, "threshold, small values 2"); + // } + // + // { + // nrows = 200; ncols = 100; zr = ncols - 64; + // Dense dense(nrows, ncols, zr, true, true, rng_); + // SparseMatrix sm(nrows, ncols, dense.begin()); + // Compare(dense, sm, "addRow, small values"); + // } + // + // { + // nrows = 8; ncols = 4; zr = ncols - 2; + // Dense dense(nrows, ncols, zr, true, true, rng_); + // Dense dense2(ncols, nrows); + // SparseMatrix sm(nrows, ncols, dense.begin()); + // SparseMatrix sm2(ncols, nrows); + // dense.transpose(dense2); + // sm.transpose(sm2); + // Compare(dense2, sm2, "transpose, small values"); + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_accumulate() + // { + // UInt nrows = 7, ncols = 5, zr = 2; + // + // if (0) { // Visual tests + // + // Dense dense(nrows, ncols, zr); + // SparseMatrix sparse(nrows, ncols, dense.begin()); + // + // std::vector row_sums(nrows), col_sums(ncols); + // + // cout << sparse << endl; + // + // sparse.accumulateAllRowsNZ(row_sums.begin(), std::plus()); + // sparse.accumulateAllColsNZ(col_sums.begin(), std::plus()); + // + // cout << "Row sums = " << row_sums << endl; + // cout << "Col sums = " << col_sums << endl; + // } + // + // /* + // TEST_LOOP(M) { + // + // Dense denseA(nrows, ncols, zr); + // SparseMatrix smA(nrows, ncols, denseA.begin()); + // + // for (UInt r = 0; r < nrows; r += 5) { + // + // { + // double r1 = denseA.accumulate(r, multiplies(), 1); + // 
double r2 = smA.accumulateRowNZ(r, multiplies(), 1); + // std::stringstream str; + // str << "accumulateRowNZ * " << nrows << "X" << ncols << "/" << zr; + // Test(str.str().c_str(), r1, r2); + // } + // + // { + // double r1 = denseA.accumulate(r, multiplies(), 1); + // double r2 = smA.accumulate(r, multiplies(), 1); + // std::stringstream str; + // str << "accumulate * " << nrows << "X" << ncols << "/" << zr; + // Test(str.str().c_str(), r1, r2); + // } + // + // { + // double r1 = denseA.accumulate(r, plus()); + // double r2 = smA.accumulateRowNZ(r, plus()); + // std::stringstream str; + // str << "accumulateRowNZ + " << nrows << "X" << ncols << "/" << zr; + // Test(str.str().c_str(), r1, r2); + // } + // + // { + // double r1 = denseA.accumulate(r, plus()); + // double r2 = smA.accumulate(r, plus()); + // std::stringstream str; + // str << "accumulate + " << nrows << "X" << ncols << "/" << zr; + // Test(str.str().c_str(), r1, r2); + // } + // + // { + // double r1 = denseA.accumulate(r, nupic::Max, 0); + // double r2 = smA.accumulateRowNZ(r, nupic::Max, 0); + // std::stringstream str; + // str << "accumulateRowNZ max " << nrows << "X" << ncols << "/" << + // zr; Test(str.str().c_str(), r1, r2); + // } + // + // { + // double r1 = denseA.accumulate(r, nupic::Max, 0); + // double r2 = smA.accumulate(r, nupic::Max, 0); + // std::stringstream str; + // str << "accumulate max " << nrows << "X" << ncols << "/" << zr; + // Test(str.str().c_str(), r1, r2); + // } + // } + // } + // */ + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_multiply() + // { + // UInt nrows, ncols, zr, nrows1, ncols1, ncols2, zr1, zr2; + // + // if (0) { // Visual test, keep + // + // DenseMat dense(4, 5, 2); + // SparseMat sparse1(dense.nRows(), dense.nCols(), dense.begin()); + // SparseMat sparse2(sparse1); + // sparse2.transpose(); + // SparseMat sparse3(0,0); + // + // cout << sparse1 << endl << endl << sparse2 << endl << endl; + // sparse1.multiply(sparse2, sparse3); + // cout << sparse3 << endl; + // + // return; + // } + // + // TEST_LOOP(M) { + // + // nrows1 = nrows; ncols1 = ncols; zr1 = zr; + // ncols2 = 2*nrows+1; zr2 = zr1; + // + // Dense denseA(nrows1, ncols1, zr1); + // SparseMatrix smA(nrows1, ncols1, denseA.begin()); + // + // Dense denseB(ncols1, ncols2, zr2); + // SparseMatrix smB(ncols1, ncols2, denseB.begin()); + // + // Dense denseC(nrows1, ncols2, zr2); + // SparseMatrix smC(nrows1, ncols2, denseC.begin()); + // + // { + // denseC.clear(); + // denseA.multiply(denseB, denseC); + // smA.multiply(smB, smC); + // + // std::stringstream str; + // str << "multiply " << nrows << "X" << ncols << "/" << zr; + // Compare(denseC, smC, str.str().c_str()); + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_argMax() + // { + // UInt ncols, nrows, zr; + // UInt m_i_sparse, m_j_sparse, m_i_dense, m_j_dense; + // Real m_val_sparse, m_val_dense; + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // dense.max(m_i_dense, m_j_dense, m_val_dense); + // + // sparse.decompact(); + // sparse.max(m_i_sparse, m_j_sparse, m_val_sparse); + // + // { + // std::stringstream str; + // str << "argMax A " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (m_i_sparse != m_i_dense + // || m_j_sparse != m_j_dense + // || 
!nupic::nearlyEqual(m_val_sparse, m_val_dense)) + // Test(str.str().c_str(), 0, 1); + // } + // + // sparse.compact(); + // sparse.max(m_i_sparse, m_j_sparse, m_val_sparse); + // + // { + // std::stringstream str; + // str << "argMax B " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (m_i_sparse != m_i_dense + // || m_j_sparse != m_j_dense + // || !nupic::nearlyEqual(m_val_sparse, m_val_dense)) + // Test(str.str().c_str(), 0, 1); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_argMin() + // { + // UInt ncols, nrows, zr; + // UInt m_i_sparse, m_j_sparse, m_i_dense, m_j_dense; + // Real m_val_sparse, m_val_dense; + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // dense.min(m_i_dense, m_j_dense, m_val_dense); + // + // sparse.decompact(); + // sparse.min(m_i_sparse, m_j_sparse, m_val_sparse); + // + // { + // std::stringstream str; + // str << "argMin A " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (m_i_sparse != m_i_dense + // || m_j_sparse != m_j_dense + // || !nupic::nearlyEqual(m_val_sparse, m_val_dense)) { + // Test(str.str().c_str(), 0, 1); + // } + // } + // + // sparse.compact(); + // sparse.min(m_i_sparse, m_j_sparse, m_val_sparse); + // + // { + // std::stringstream str; + // str << "argMin B " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (m_i_sparse != m_i_dense + // || m_j_sparse != m_j_dense + // || !nupic::nearlyEqual(m_val_sparse, m_val_dense)) + // Test(str.str().c_str(), 0, 1); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_rowMax_2() + // { + // UInt ncols, nrows, zr; + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // std::vector > optima_sparse(nrows), + // optima_dense(nrows); + // + // dense.rowMax(optima_dense.begin()); + // + // sparse.decompact(); + // + // for (UInt i = 0; i != nrows; ++i) { + // + // std::pair res_sparse; + // sparse.rowMax(i, res_sparse.first, res_sparse.second); + // + // std::stringstream str; + // str << "rowMax 2 A " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (optima_dense[i].first != res_sparse.first + // || !nearlyEqual(optima_dense[i].second, res_sparse.second)) + // Test(str.str().c_str(), 0, 1); + // } + // + // sparse.rowMax(optima_sparse.begin()); + // + // { + // std::stringstream str; + // str << "rowMax 2 B " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // for (UInt i = 0; i != nrows; ++i) { + // if (optima_dense[i].first != optima_sparse[i].first + // || !nupic::nearlyEqual(optima_dense[i].second, + // optima_sparse[i].second)) Test(str.str().c_str(), 0, 1); + // } + // } + // + // sparse.compact(); + // + // for (UInt i = 0; i != nrows; ++i) { + // + // std::pair res_sparse; + // sparse.rowMax(i, res_sparse.first, res_sparse.second); + // + // std::stringstream str; + // str << "rowMax 2 C " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (optima_dense[i].first != res_sparse.first + // || !nearlyEqual(optima_dense[i].second, res_sparse.second)) + // Test(str.str().c_str(), 0, 1); + // } + // + // sparse.rowMax(optima_sparse.begin()); + // + // { + // std::stringstream str; + // str << 
"rowMax 2 D " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // for (UInt i = 0; i != nrows; ++i) { + // if (optima_dense[i].first != optima_sparse[i].first + // || !nupic::nearlyEqual(optima_dense[i].second, + // optima_sparse[i].second)) Test(str.str().c_str(), 0, 1); + // } + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_rowMin() + // { + // UInt ncols, nrows, zr; + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // std::vector > optima_sparse(nrows), + // optima_dense(nrows); + // + // dense.rowMin(optima_dense.begin()); + // + // sparse.decompact(); + // + // for (UInt i = 0; i != nrows; ++i) { + // + // std::pair res_sparse; + // sparse.rowMin(i, res_sparse.first, res_sparse.second); + // + // std::stringstream str; + // str << "rowMin A " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (optima_dense[i].first != res_sparse.first + // || !nearlyEqual(optima_dense[i].second, res_sparse.second)) + // Test(str.str().c_str(), 0, 1); + // } + // + // sparse.rowMin(optima_sparse.begin()); + // + // { + // std::stringstream str; + // str << "rowMin B " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // for (UInt i = 0; i != nrows; ++i) { + // if (optima_dense[i].first != optima_sparse[i].first + // || !nupic::nearlyEqual(optima_dense[i].second, + // optima_sparse[i].second)) Test(str.str().c_str(), 0, 1); + // } + // } + // + // sparse.compact(); + // + // for (UInt i = 0; i != nrows; ++i) { + // + // std::pair res_sparse; + // sparse.rowMin(i, res_sparse.first, res_sparse.second); + // + // std::stringstream str; + // str << "rowMin C " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (optima_dense[i].first != res_sparse.first + // || !nearlyEqual(optima_dense[i].second, res_sparse.second)) + // Test(str.str().c_str(), 0, 1); + // } + // + // sparse.rowMin(optima_sparse.begin()); + // + // { + // std::stringstream str; + // str << "rowMin D " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // for (UInt i = 0; i != nrows; ++i) { + // if (optima_dense[i].first != optima_sparse[i].first + // || !nupic::nearlyEqual(optima_dense[i].second, + // optima_sparse[i].second)) Test(str.str().c_str(), 0, 1); + // } + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_colMax() + // { + // UInt ncols = 7, nrows = 9, zr = 3; + // + // if (0) { + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // cout << sparse << endl; + // for (UInt j = 0; j != ncols; ++j) { + // UInt col_max_i; + // Real col_max; + // sparse.colMax(j, col_max_i, col_max); + // cout << j << " " << col_max_i << " " << col_max << endl; + // } + // } + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // std::vector > optima_sparse(ncols), + // optima_dense(ncols); + // + // dense.colMax(optima_dense.begin()); + // + // sparse.decompact(); + // + // for (UInt j = 0; j != ncols; ++j) { + // + // std::pair res_sparse; + // sparse.colMax(j, res_sparse.first, res_sparse.second); + // + // std::stringstream str; + // str << "colMax A " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (optima_dense[j].first 
!= res_sparse.first + // || !nearlyEqual(optima_dense[j].second, res_sparse.second)) + // Test(str.str().c_str(), 0, 1); + // } + // + // sparse.colMax(optima_sparse.begin()); + // + // { + // std::stringstream str; + // str << "colMax B " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // for (UInt j = 0; j != ncols; ++j) { + // if (optima_dense[j].first != optima_sparse[j].first + // || !nupic::nearlyEqual(optima_dense[j].second, + // optima_sparse[j].second)) Test(str.str().c_str(), 0, 1); + // } + // } + // + // sparse.compact(); + // + // for (UInt i = 0; i != ncols; ++i) { + // + // std::pair res_sparse; + // sparse.colMax(i, res_sparse.first, res_sparse.second); + // + // std::stringstream str; + // str << "colMax C " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (optima_dense[i].first != res_sparse.first + // || !nearlyEqual(optima_dense[i].second, res_sparse.second)) + // Test(str.str().c_str(), 0, 1); + // } + // + // sparse.colMax(optima_sparse.begin()); + // + // { + // std::stringstream str; + // str << "colMax D " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // for (UInt i = 0; i != ncols; ++i) { + // if (optima_dense[i].first != optima_sparse[i].first + // || !nupic::nearlyEqual(optima_dense[i].second, + // optima_sparse[i].second)) Test(str.str().c_str(), 0, 1); + // } + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_colMin() + // { + // UInt ncols, nrows, zr; + // + // { + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // std::vector > optima_sparse(ncols), + // optima_dense(ncols); + // + // dense.colMin(optima_dense.begin()); + // + // sparse.decompact(); + // + // for (UInt i = 0; i != ncols; ++i) { + // + // std::pair res_sparse; + // sparse.colMin(i, res_sparse.first, res_sparse.second); + // + // std::stringstream str; + // str << "rowMax 2 A " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (optima_dense[i].first != res_sparse.first + // || !nearlyEqual(optima_dense[i].second, res_sparse.second)) + // Test(str.str().c_str(), 0, 1); + // } + // + // sparse.colMin(optima_sparse.begin()); + // + // { + // std::stringstream str; + // str << "rowMax 2 B " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // for (UInt i = 0; i != ncols; ++i) { + // if (optima_dense[i].first != optima_sparse[i].first + // || !nupic::nearlyEqual(optima_dense[i].second, + // optima_sparse[i].second)) Test(str.str().c_str(), 0, 1); + // } + // } + // + // sparse.compact(); + // + // for (UInt i = 0; i != ncols; ++i) { + // + // std::pair res_sparse; + // sparse.colMin(i, res_sparse.first, res_sparse.second); + // + // std::stringstream str; + // str << "rowMax 2 C " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (optima_dense[i].first != res_sparse.first + // || !nearlyEqual(optima_dense[i].second, res_sparse.second)) + // Test(str.str().c_str(), 0, 1); + // } + // + // sparse.colMin(optima_sparse.begin()); + // + // { + // std::stringstream str; + // str << "rowMax 2 D " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // for (UInt i = 0; i != ncols; ++i) { + // if (optima_dense[i].first != optima_sparse[i].first + // || !nupic::nearlyEqual(optima_dense[i].second, + // optima_sparse[i].second)) Test(str.str().c_str(), 0, 1); + // } + // } + // } + // } + // } + // + 
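The rowMax/rowMin/colMax/colMin cases in the commented-out tests above all follow one pattern: ask the SparseMatrix for a per-row (or per-column) optimum, then check the result against a Dense reference. A minimal standalone sketch of that pattern, offered only as an illustration: it assumes the nupic::SparseMatrix<UInt, Real> template, a "nupic/math/SparseMatrix.hpp" header path, and plain assert checks in place of the Test/Compare harness used by the test class.

#include <cassert>
#include <utility>
#include <vector>

#include "nupic/math/SparseMatrix.hpp" // header path assumed
#include "nupic/types/Types.hpp"       // header path assumed; defines UInt, Real

using nupic::Real;
using nupic::UInt;

int main() {
  // 3x4 dense buffer, row-major, with a few explicit zeros.
  const UInt nrows = 3, ncols = 4;
  std::vector<Real> dense = {0, 2, 0, 1,
                             5, 0, 0, 3,
                             0, 0, 4, 0};

  nupic::SparseMatrix<UInt, Real> sm(nrows, ncols, dense.begin());

  // Per-row maxima via the API, checked against a brute-force scan of get().
  std::vector<std::pair<UInt, Real>> row_max(nrows);
  sm.rowMax(row_max.begin());

  for (UInt i = 0; i != nrows; ++i) {
    UInt best_j = 0;
    Real best_v = sm.get(i, 0);
    for (UInt j = 1; j != ncols; ++j)
      if (sm.get(i, j) > best_v) {
        best_j = j;
        best_v = sm.get(i, j);
      }
    assert(row_max[i].first == best_j && row_max[i].second == best_v);

    // Single-row overload exercised by the tests above.
    UInt jj;
    Real vv;
    sm.rowMax(i, jj, vv);
    assert(jj == best_j && vv == best_v);
  }
  return 0;
}

The colMax/colMin variants differ only in the accessor called and in iterating columns instead of rows; the decompact()/compact() steps in the original tests rerun the same checks on both internal storage layouts.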
// //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_nNonZeros() + // { + // UInt ncols, nrows, zr; + // + // TEST_LOOP(M) { + // + // DenseMat dense(nrows, ncols, zr); + // SparseMat sparse(nrows, ncols, dense.begin()); + // + // UInt n_s, n_d; + // + // { + // std::vector nrows_s(nrows), nrows_d(nrows); + // std::vector ncols_s(ncols), ncols_d(ncols); + // + // sparse.decompact(); + // + // n_d = dense.nNonZeros(); + // n_s = sparse.nNonZeros(); + // + // { + // std::stringstream str; + // str << "nNonZeros A1 " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (n_d != n_s) + // Test(str.str().c_str(), 0, 1); + // } + // + // for (UInt i = 0; i != nrows; ++i) { + // + // n_d = dense.nNonZerosOnRow(i); + // n_s = sparse.nNonZerosOnRow(i); + // + // { + // std::stringstream str; + // str << "nNonZeros B1 " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (n_d != n_s) + // Test(str.str().c_str(), 0, 1); + // } + // } + // + // for (UInt i = 0; i != ncols; ++i) { + // + // n_d = dense.nNonZerosOnCol(i); + // n_s = sparse.nNonZerosOnCol(i); + // + // { + // std::stringstream str; + // str << "nNonZeros C1 " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // if (n_d != n_s) + // Test(str.str().c_str(), 0, 1); + // } + // } + // + // dense.nNonZerosPerRow(nrows_d.begin()); + // sparse.nNonZerosPerRow(nrows_s.begin()); + // + // { + // std::stringstream str; + // str << "nNonZeros D1 " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // CompareVectors(nrows, nrows_d.begin(), nrows_s.begin(), + // str.str().c_str()); + // } + // + // dense.nNonZerosPerCol(ncols_d.begin()); + // sparse.nNonZerosPerCol(ncols_s.begin()); + // + // { + // std::stringstream str; + // str << "nNonZeros E1 " << nrows << "X" << ncols << "/" << zr + // << " - non compact"; + // CompareVectors(ncols, ncols_d.begin(), ncols_s.begin(), + // str.str().c_str()); + // } + // } + // + // { + // std::vector nrows_s(nrows), nrows_d(nrows); + // std::vector ncols_s(ncols), ncols_d(ncols); + // sparse.compact(); + // + // n_d = dense.nNonZeros(); + // n_s = sparse.nNonZeros(); + // + // { + // std::stringstream str; + // str << "nNonZeros A2 " << nrows << "X" << ncols << "/" << zr + // << " - compact"; + // if (n_d != n_s) + // Test(str.str().c_str(), 0, 1); + // } + // + // for (UInt i = 0; i != nrows; ++i) { + // + // n_d = dense.nNonZerosOnRow(i); + // n_s = sparse.nNonZerosOnRow(i); + // + // { + // std::stringstream str; + // str << "nNonZeros B2 " << nrows << "X" << ncols << "/" << zr + // << " - compact"; + // if (n_d != n_s) + // Test(str.str().c_str(), 0, 1); + // } + // } + // + // for (UInt i = 0; i != ncols; ++i) { + // + // n_d = dense.nNonZerosOnCol(i); + // n_s = sparse.nNonZerosOnCol(i); + // + // { + // std::stringstream str; + // str << "nNonZeros C2 " << nrows << "X" << ncols << "/" << zr + // << " - compact"; + // if (n_d != n_s) + // Test(str.str().c_str(), 0, 1); + // } + // } + // + // dense.nNonZerosPerRow(nrows_d.begin()); + // sparse.nNonZerosPerRow(nrows_s.begin()); + // + // { + // std::stringstream str; + // str << "nNonZeros D2 " << nrows << "X" << ncols << "/" << zr + // << " - compact"; + // CompareVectors(nrows, nrows_d.begin(), nrows_s.begin(), + // str.str().c_str()); + // } + // + // dense.nNonZerosPerCol(ncols_d.begin()); + // sparse.nNonZerosPerCol(ncols_s.begin()); + // + // { + // std::stringstream str; + // str << "nNonZeros E2 " << nrows 
<< "X" << ncols << "/" << zr + // << " - compact"; + // CompareVectors(ncols, ncols_d.begin(), ncols_s.begin(), + // str.str().c_str()); + // } + // } + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_extract() + // { + // if (1) { // Visual tests + // + // DenseMat dense(5, 7, 2); + // SparseMat sparse(5, 7, dense.begin()); + // + // /* + // cout << "Sparse:" << endl << sparse << endl; + // + // { // Extract domain + // Domain2D dom(0,4,0,4); + // SparseMatrix extracted(4,4); + // sparse.get(dom, extracted); + // cout << extracted << endl; + // } + // */ + // } + // } + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_deleteRow() + // { + // // This is regression test for an off-by-one memory corruption bug + // // found in deleteRow the symptom of the bug is a seg fault so there + // // is no explicit test here. + // { + // SparseMat* sm = new SparseMat(11, 1); + // sm->deleteRow(3); + // delete sm; + // + // sm = new SparseMat(11, 1); + // sm->deleteRow(3); + // delete sm; + // } + // } + // + // //-------------------------------------------------------------------------------- + // /** + // * A generator function object, that generates random numbers between 0 + // and 256. + // * It also has a threshold to control the sparsity of the vectors + // generated. + // */ + // template + // struct rand_init + // { + // Random *r_; + // T threshold_; + // + // inline rand_init(Random *r, T threshold =100) + // : r_(r), threshold_(threshold) + // {} + // + // inline T operator()() + // { + // return T((T)(r_->getUInt32(100)) > threshold_ ? 0 : .001 + + // (r_->getReal64())); + // } + // }; + // + // //-------------------------------------------------------------------------------- + // void SparseMatrixUnitTest::unit_test_usage() + // { + // using namespace std; + // + // typedef UInt size_type; + // typedef double value_type; + // typedef SparseMatrix SM; + // typedef Dense DM; + // + // size_type maxMatrixSize = 30; + // size_type nrows = 20, ncols = 30, nzr = 20; + // + // DM* dense = new DM(nrows, ncols, nzr, true, true, rng_); + // SM* sparse = new SM(nrows, ncols, dense->begin()); + // + // for (long int a = 0; a < 10000; ++a) { + // + // // Rectify to stop propagation of small errors + // ITER_2(sparse->nRows(), sparse->nCols()) + // if (::fabs(dense->at(i,j) - sparse->get(i,j)) < 1e-6) + // dense->at(i,j) = sparse->get(i,j); + // + // size_type r = rng_->getUInt32(37); + // + // if (r == 0) { + // + // sparse->compact(); + // // no compact for Dense + // + // } else if (r == 1) { + // + // sparse->decompact(); + // // no decompact for Dense + // + // } else if (r == 2) { + // + // if (rng_->getReal64() < 0.90) { + // size_type nrows = sparse->nRows() + rng_->getUInt32(4); + // size_type ncols = sparse->nCols() + rng_->getUInt32(4); + // sparse->resize(nrows, ncols); + // dense->resize(nrows, ncols); + // Compare(*dense, *sparse, "resize, bigger"); + // + // } else { + // if (sparse->nRows() > 2 && sparse->nCols() > 2) { + // size_type nrows = rng_->getUInt32(sparse->nRows()); + // size_type ncols = rng_->getUInt32(sparse->nCols()); + // sparse->resize(nrows, ncols); + // dense->resize(nrows, ncols); + // Compare(*dense, *sparse, "resize, smaller"); + // } + // } + // + // } else if (r == 3) { + // + // vector del; + // + // if (rng_->getReal64() < 0.90) { + // for (size_type ii = 0; ii < sparse->nRows() / 
4; ++ii) + // del.push_back(2*ii); + // sparse->deleteRows(del.begin(), del.end()); + // dense->deleteRows(del.begin(), del.end()); + // } else { + // for (size_type ii = 0; ii < sparse->nRows(); ++ii) + // del.push_back(ii); + // sparse->deleteRows(del.begin(), del.end()); + // dense->deleteRows(del.begin(), del.end()); + // } + // + // Compare(*dense, *sparse, "deleteRows"); + // + // } else if (r == 4) { + // + // vector del; + // if (rng_->getReal64() < 0.90) { + // for (size_type ii = 0; ii < sparse->nCols() / 4; ++ii) + // del.push_back(2*ii); + // sparse->deleteCols(del.begin(), del.end()); + // dense->deleteCols(del.begin(), del.end()); + // } else { + // for (size_type ii = 0; ii < sparse->nCols(); ++ii) + // del.push_back(ii); + // sparse->deleteCols(del.begin(), del.end()); + // dense->deleteCols(del.begin(), del.end()); + // } + // Compare(*dense, *sparse, "deleteCols"); + // + // } else if (r == 5) { + // + // SM sparse2(1, 1); + // DM sm2Dense(1, 1); + // Compare(sm2Dense,sparse2, "constructor(1, 1)"); + // + // sparse2.copy(*sparse); + // sparse->copy(sparse2); + // + // sm2Dense.copy(*dense); + // dense->copy(sm2Dense); + // Compare(*dense, *sparse, "copy"); + // + // } else if (r == 6) { + // + // vector row(sparse->nCols()); + // size_type n = rng_->getUInt32(16); + // for (size_type z = 0; z < n; ++z) { + // if (rng_->getReal64() < 0.90) + // generate(row.begin(), row.end(), rand_init(rng_, 70)); + // sparse->addRow(row.begin()); + // dense->addRow(row.begin()); + // Compare(*dense, *sparse, "addRow"); + // } + // + // } else if (r == 7) { + // + // if (sparse->nRows() > 0 && sparse->nCols() > 0) { + // size_type m = sparse->nRows() * sparse->nCols() / 2; + // for (size_type z = 0; z < m; ++z) { + // size_type i = rng_->getUInt32(sparse->nRows()); + // size_type j = rng_->getUInt32(sparse->nCols()); + // value_type v = 1+value_type(rng_->getReal64()); + // sparse->setNonZero(i, j, v); + // dense->setNonZero(i, j, v); + // Compare(*dense, *sparse, "setNonZero"); + // } + // } + // + // } else if (r == 8) { + // + // value_type v = value_type(128 + rng_->getUInt32(128)); + // sparse->threshold(v); + // dense->threshold(v); + // Compare(*dense, *sparse, "threshold"); + // + // } else if (r == 9) { + // + // if (sparse->nCols() > 0 && sparse->nRows() > 0) { + // + // SM B(0, sparse->nCols()); + // DM BDense(0, dense->ncols); + // + // vector row(sparse->nCols()); + // + // for (size_type iii = 0; iii < sparse->nRows(); ++iii) { + // + // if (rng_->getUInt32(100) < 90) + // generate(row.begin(), row.end(), rand_init(rng_, 70)); + // else + // fill(row.begin(), row.end(), value_type(0)); + // + // B.addRow(row.begin()); + // BDense.addRow(row.begin()); + // } + // + // value_type r1=value_type(rng_->getUInt32(5)), + // r2=value_type(rng_->getUInt32(5)); + // + // sparse->lerp(r1, r2, B); + // dense->lerp(r1, r2, BDense); + // Compare(*dense, *sparse, "lerp", 1e-4); + // } + // + // } else if (r == 10) { + // + // delete sparse; + // delete dense; + // size_type nrows = rng_->getUInt32(maxMatrixSize), ncols = + // rng_->getUInt32(maxMatrixSize); sparse = new SM(ncols, nrows); dense = + // new DM(ncols, nrows); Compare(*dense, *sparse, + // "constructor(rng_->get() % 32, rng_->get() % 32)"); + // + // } else if (r == 11) { + // + // delete sparse; + // delete dense; + // sparse = new SM(); + // dense = new DM(); + // Compare(*dense, *sparse, "constructor()"); + // + // } else if (r == 12) { + // + // delete sparse; + // delete dense; + // sparse = new SM(0,0); + // dense = new 
DM(0,0); + // Compare(*dense, *sparse, "constructor(0,0)"); + // + // } else if (r == 13) { + // + // SM sm2(sparse->nRows(), sparse->nCols()); + // DM sm2Dense(dense->nrows, dense->ncols); + // Compare(sm2Dense, sm2, "constructor(dense->nRows(), dense->nCols())"); + // + // ITER_2(sm2.nRows(), sm2.nCols()) { + // value_type r = 1+rng_->getUInt32(256); + // sm2.setNonZero(i,j, r); + // sm2Dense.setNonZero(i,j, r); + // } + // sparse->elementApply(sm2, std::plus()); + // dense->add(sm2Dense); + // Compare(*dense, *sparse, "add"); + // + // } else if (r == 14) { + // + // if (sparse->nRows() > 0) { + // vector row(sparse->nCols()); + // generate(row.begin(), row.end(), rand_init(rng_, 70)); + // size_type r = rng_->getUInt32(sparse->nRows()); + // sparse->elementRowApply(r, std::plus(), row.begin()); + // dense->add(r, row.begin()); + // Compare(*dense, *sparse, "add(randomR, row.begin())"); + // } + // + // } else if (r == 15) { + // + // SM B(sparse->nCols(), sparse->nRows()); + // DM BDense(dense->ncols, dense->nrows); + // Compare(BDense, B, "constructor(sm->nCols(), sm->nRows())"); + // sparse->transpose(B); + // dense->transpose(BDense); + // Compare(*dense, *sparse, "transpose"); + // + // } else if (r == 16) { + // + // /* + // vector x(sparse->nCols()), y(sparse->nRows()); + // generate(x.begin(), x.end(), rand_init(rng_, 50)); + // sparse->L2Dist(x.begin(), y.begin()); + // dense->L2Dist(x.begin(), y.begin()); + // Compare(*dense, *sparse, "L2Dist", 1e-4); + // */ + // + // } else if (r == 17) { + // + // /* + // vector x(sparse->nCols()); + // pair closest; + // generate(x.begin(), x.end(), rand_init(rng_, 50)); + // sparse->L2Nearest(x.begin(), closest); + // dense->L2Nearest(x.begin(), closest); + // Compare(*dense, *sparse, "L2Nearest", 1e-4); + // */ + // + // } else if (r == 18) { + // + // /* + // vector x(sparse->nCols()), y(sparse->nRows()); + // generate(x.begin(), x.end(), rand_init(rng_, 50)); + // sparse->vecDist(x.begin(), y.begin()); + // dense->vecDist(x.begin(), y.begin()); + // Compare(*dense, *sparse, "vecDist", 1e-4); + // */ + // + // } else if(r == 19) { + // + // /* + // if(sparse->nRows() > 0) { + // vector x(sparse->nCols()); + // generate(x.begin(), x.end(), rand_init(rng_, 50)); + // size_type randInt=rng_->get() % sparse->nRows(); + // sparse->rowDistSquared(randInt, x.begin()); + // dense->rowDistSquared(randInt, x.begin()); + // Compare(*dense, *sparse, "rowDistSquared", 1e-4); + // } + // */ + // + // } else if (r == 20) { + // + // /* + // vector x(sparse->nCols()); + // generate(x.begin(), x.end(), rand_init(rng_, 50)); + // sparse->closestEuclidean(x.begin()); + // dense->closestEuclidean(x.begin()); + // Compare(*dense, *sparse, "closestEuclidean", 1e-4); + // */ + // + // } else if (r== 21) { + // + // /* + // vector x(sparse->nCols()); + // generate(x.begin(), x.end(), rand_init(rng_, 50)); + // for (size_type n = 0; n < sparse->nCols(); ++n) + // x.push_back(value_type(rng_->get() % 256)); + // sparse->dotNearest(x.begin()); + // dense->dotNearest(x.begin()); + // Compare(*dense, *sparse, "dotNearest", 1e-4); + // */ + // + // } else if (r == 22) { + // + // vector x(sparse->nCols()), y(sparse->nRows()); + // generate(x.begin(), x.end(), rand_init(rng_, 50)); + // sparse->rightVecProd(x.begin(), y.begin()); + // dense->rightVecProd(x.begin(), y.begin()); + // Compare(*dense, *sparse, "rightVecProd", 1e-4); + // + // } else if (r == 23) { + // + // vector x(sparse->nCols()), y(sparse->nRows()); + // generate(x.begin(), x.end(), rand_init(rng_, 
50)); + // sparse->vecMaxProd(x.begin(), y.begin()); + // dense->vecMaxProd(x.begin(), y.begin()); + // Compare(*dense, *sparse, "vecMaxProd", 1e-4); + // + // } else if (r == 24) { + // + // vector x(sparse->nCols()), y(sparse->nRows()); + // generate(x.begin(), x.end(), rand_init(rng_, 50)); + // sparse->vecMaxAtNZ(x.begin(), y.begin()); + // dense->vecMaxAtNZ(x.begin(), y.begin()); + // Compare(*dense, *sparse, "vecMaxAtNZ", 1e-4); + // + // } else if (r == 25) { + // + // if (sparse->nRows() > 0) { + // vector x(sparse->nCols()); + // generate(x.begin(), x.end(), rand_init(rng_, 50)); + // size_type row = rng_->getUInt32(sparse->nRows()); + // value_type r1=value_type(rng_->getUInt32(256)), + // r2=value_type(rng_->getUInt32(256)); sparse->axby(row, r1, r2, + // x.begin()); dense->axby(row, r1, r2, x.begin()); Compare(*dense, + // *sparse, + //"axby", 1e-4); + // } + // + // } else if (r == 26) { + // + // vector x(sparse->nCols()); + // generate(x.begin(), x.end(), rand_init(rng_, 50)); + // value_type r1=value_type(rng_->getUInt32(256)), + // r2=value_type(rng_->getUInt32(256)); sparse->axby(r1, r2, x.begin()); + // dense->axby(r1, r2, x.begin()); + // Compare(*dense, *sparse, "axby 2", 1e-4); + // + // } else if (r == 27) { + // + // /* + // vector x(sparse->nCols()), y(sparse->nRows()); + // generate(x.begin(), x.end(), rand_init(rng_, 50)); + // sparse->rowMax(x.begin(), y.begin()); + // dense->rowMax(x.begin(), y.begin()); + // Compare(*dense, *sparse, "rowMax"); + // */ + // + // } else if (r == 28) { + // + // vector< pair > y(sparse->nRows()); + // sparse->rowMax(y.begin()); + // dense->rowMax(y.begin()); + // Compare(*dense, *sparse, "rowMax 2"); + // + // } else if (r == 29) { + // + // vector< pair > y(sparse->nCols()); + // sparse->colMax(y.begin()); + // dense->colMax(y.begin()); + // Compare(*dense, *sparse, "colMax"); + // + // } else if (r == 30) { + // + // bool exact = true; + // sparse->normalizeRows(exact); + // dense->normalizeRows(exact); + // Compare(*dense, *sparse, "normalizeRows", 1e-4); + // + // } else if (r == 31) { + // + // vector x(sparse->nCols()), y(sparse->nRows()); + // generate(x.begin(), x.end(), rand_init(rng_, 50)); + // sparse->rightVecProdAtNZ(x.begin(), y.begin()); + // dense->rowProd(x.begin(), y.begin()); + // Compare(*dense, *sparse, "rowProd", 1e-4); + // + // } else if (r == 32) { + // + // vector x(sparse->nCols()), y(sparse->nRows()); + // generate(x.begin(), x.end(), rand_init(rng_, 50)); + // value_type theRandom=value_type(rng_->getUInt32(256)); + // sparse->rightVecProdAtNZ(x.begin(), y.begin(), theRandom); + // dense->rowProd(x.begin(), y.begin(), theRandom); + // Compare(*dense, *sparse, "rowProd 2", 1e-4); + // + // } else if (r == 33) { + // + // //size_type row; + // //value_type init; + // + // if (sparse->nRows() != 0) { + // + // /* + // row = rng_->get() % sparse->nRows(); + // init = (rng_->get() % 32768)/32768.0 + .001; + // + // size_type switcher = rng_->get() % 4; + // + // if (switcher == 0) { + // sparse->accumulateRowNZ(row, multiplies(), init); + // dense->accumulateRowNZ(row, multiplies(), init); + // Compare(*dense, *sparse, "accumulateRowNZ with multiplies", 1e-4); + // } else if (switcher == 1) { + // sparse->accumulateRowNZ(row, plus(), init); + // dense->accumulateRowNZ(row, plus(), init); + // Compare(*dense, *sparse, "accumulateRowNZ with plus", 1e-4); + // } else if (switcher == 2) { + // sparse->accumulateRowNZ(row, minus(), init); + // dense->accumulateRowNZ(row, minus(), init); + // Compare(*dense, 
*sparse, "accumulateRowNZ with minus", 1e-4); + // } else if (switcher == 3) { + // sparse->accumulateRowNZ(row, nupic::Max, init); + // dense->accumulateRowNZ(row, nupic::Max, init); + // Compare(*dense, *sparse, "accumulateRowNZ with Max", 1e-4); + // } + // */ + // } + // + // } else if (r == 34) { + // + // //size_type row; + // //value_type init; + // + // if (sparse->nRows() != 0) { + // /* + // row = rng_->get() % sparse->nRows(); + // init = (rng_->get() % 32768)/32768.0 + .001; + // + // size_type switcher = rng_->get() % 4; + // + // if (switcher == 0) { + // sparse->accumulate(row, multiplies(), init); + // dense->accumulate(row, multiplies(), init); + // Compare(*dense, *sparse, "accumulateRowNZ with multiplies", 1e-4); + // } else if (switcher == 1) { + // sparse->accumulate(row, plus(), init); + // dense->accumulate(row, plus(), init); + // Compare(*dense, *sparse, "accumulateRowNZ with plus", 1e-4); + // } else if (switcher == 2) { + // sparse->accumulate(row, minus(), init); + // dense->accumulate(row, minus(), init); + // Compare(*dense, *sparse, "accumulateRowNZ with minus", 1e-4); + // } else if (switcher == 3) { + // sparse->accumulate(row, nupic::Max, init); + // dense->accumulate(row, nupic::Max, init); + // Compare(*dense, *sparse, "accumulateRowNZ with Max", 1e-4); + // } + // */ + // } + // + // } else if (r == 35) { + // + // if(dense->ncols > 0 && dense->nrows > 0) { + // + // size_type randomTemp = rng_->getUInt32(maxMatrixSize); + // SM B(0, randomTemp); + // SM C(sparse->nRows(), randomTemp); + // DM BDense(0, randomTemp); + // DM CDense(dense->nrows, randomTemp); + // + // vector x(randomTemp); + // + // for (size_type n=0; n < sparse->nCols(); n++) { + // generate(x.begin(), x.end(), rand_init(rng_, 50)); + // B.addRow(x.begin()); + // BDense.addRow(x.begin()); + // } + // + // sparse->multiply( B, C); + // dense->multiply( BDense, CDense); + // Compare(*dense, *sparse, "multiply", 1e-4); + // } + // + // } else if (r == 36) { + // + // if (sparse->nRows() > 0 && sparse->nCols() > 0) { + // + // vector indices, indicesDense; + // vector values, valuesDense; + // + // size_type r = rng_->getUInt32(sparse->nRows()); + // + // sparse->getRowToSparse(r, back_inserter(indices), + // back_inserter(values)); + // + // dense->getRowToSparse(r, back_inserter(indicesDense), + // back_inserter(valuesDense)); + // + // sparse->findRow((size_type)indices.size(), + // indices.begin(), + // values.begin()); + // + // dense->findRow((size_type)indicesDense.size(), + // indicesDense.begin(), + // valuesDense.begin()); + // + // CompareVectors((size_type)indices.size(), indices.begin(), + // indicesDense.begin(), + // "findRow indices"); + // + // CompareVectors((size_type)values.size(), values.begin(), + // valuesDense.begin(), + // "findRow values"); + // } + // } + // } + // + // delete sparse; + // delete dense; + // } + // //-------------------------------------------------------------------------------- // void SparseMatrixUnitTest::RunTests() // { - //unit_test_construction(); - //unit_test_copy(); - //unit_test_dense(); - //unit_test_csr(); - //unit_test_compact(); - //unit_test_threshold(); - //unit_test_addRowCol(); - //unit_test_resize(); - //unit_test_deleteRows(); - //unit_test_deleteCols(); - //unit_test_set(); - //unit_test_setRowColToZero(); - //unit_test_getRow(); - //unit_test_getCol(); - //unit_test_vecMaxProd(); - //unit_test_vecProd(); - //unit_test_axby(); - //unit_test_axby_3(); - //unit_test_rowMax(); - //unit_test_maxima(); - //unit_test_normalize(); 
- //unit_test_rowProd(); - //unit_test_lerp(); - //unit_test_accumulate(); - //unit_test_transpose(); - //unit_test_multiply(); - //unit_test_small_values(); - //unit_test_argMax(); - //unit_test_argMin(); - //unit_test_rowMax_2(); - //unit_test_rowMin(); - //unit_test_colMax(); - //unit_test_colMin(); - //unit_test_nNonZeros(); - //unit_test_extract(); - //unit_test_deleteRow(); - ////unit_test_usage(); - // } - - //-------------------------------------------------------------------------------- - -// } // namespace nupic +// unit_test_construction(); +// unit_test_copy(); +// unit_test_dense(); +// unit_test_csr(); +// unit_test_compact(); +// unit_test_threshold(); +// unit_test_addRowCol(); +// unit_test_resize(); +// unit_test_deleteRows(); +// unit_test_deleteCols(); +// unit_test_set(); +// unit_test_setRowColToZero(); +// unit_test_getRow(); +// unit_test_getCol(); +// unit_test_vecMaxProd(); +// unit_test_vecProd(); +// unit_test_axby(); +// unit_test_axby_3(); +// unit_test_rowMax(); +// unit_test_maxima(); +// unit_test_normalize(); +// unit_test_rowProd(); +// unit_test_lerp(); +// unit_test_accumulate(); +// unit_test_transpose(); +// unit_test_multiply(); +// unit_test_small_values(); +// unit_test_argMax(); +// unit_test_argMin(); +// unit_test_rowMax_2(); +// unit_test_rowMin(); +// unit_test_colMax(); +// unit_test_colMin(); +// unit_test_nNonZeros(); +// unit_test_extract(); +// unit_test_deleteRow(); +////unit_test_usage(); +// } +//-------------------------------------------------------------------------------- +// } // namespace nupic diff --git a/src/test/unit/math/SparseMatrixUnitTest.hpp b/src/test/unit/math/SparseMatrixUnitTest.hpp index 28ba0d2cad..357173056b 100644 --- a/src/test/unit/math/SparseMatrixUnitTest.hpp +++ b/src/test/unit/math/SparseMatrixUnitTest.hpp @@ -82,7 +82,8 @@ // // //-------------------------------------------------------------------------------- // template -// inline void CompareVectors(UInt n, InIter1 y1, InIter2 y2, const char* str) +// inline void CompareVectors(UInt n, InIter1 y1, InIter2 y2, const char* +// str) // { // InIter1 y1_begin = y1; // InIter2 y2_begin = y2; diff --git a/src/test/unit/math/SparseTensorUnitTest.cpp b/src/test/unit/math/SparseTensorUnitTest.cpp index e7419fd131..0c1c876b3d 100644 --- a/src/test/unit/math/SparseTensorUnitTest.cpp +++ b/src/test/unit/math/SparseTensorUnitTest.cpp @@ -19,11 +19,11 @@ * http://numenta.org/licenses/ * --------------------------------------------------------------------- */ - + /** @file * Implementation of unit testing for class SparseTensor - */ - + */ + // #include "SparseTensorUnitTest.hpp" // // #include @@ -38,20 +38,21 @@ // // //-------------------------------------------------------------------------------- // template -// inline bool Compare(const SparseTensor& A, const DenseTensor& B) +// inline bool Compare(const SparseTensor& A, const DenseTensor& +// B) // { // bool ok = true; -// +// // if (A.getRank() != B.getRank()) { -// NTA_WARN << "Ranks are different: " << A.getRank() << " and " << B.getRank(); -// ok = false; +// NTA_WARN << "Ranks are different: " << A.getRank() << " and " << +// B.getRank(); ok = false; // } -// +// // if (A.isZero() != B.isZero()) { // NTA_WARN << "isZero problem"; // ok = false; // } -// +// // if (A.isDense() != B.isDense()) { // NTA_WARN << "Density problem"; // ok = false; @@ -66,14 +67,14 @@ // NTA_WARN << "Bounds are different: " // << A.getBounds() // << " and " << B.getBounds(); -// ok = false; +// ok = false; // } -// +// // if 
(A.getNNonZeros() != B.getNNonZeros()) { // NTA_WARN << "Number of non-zeros are different: " -// << A.getNNonZeros() +// << A.getNNonZeros() // << " and " << B.getNNonZeros(); -// ok = false; +// ok = false; // } // // std::vector > diffs; @@ -88,14 +89,14 @@ // } while (increment(A.getBounds(), idx)); // // if (!diffs.empty()) { -// NTA_WARN << "There are " << diffs.size() << " differences between A and B: "; -// for (UInt i = 0; i < diffs.size(); ++i) -// NTA_WARN << "Index: " << get<0>(diffs[i]) +// NTA_WARN << "There are " << diffs.size() << " differences between A and +// B: "; for (UInt i = 0; i < diffs.size(); ++i) +// NTA_WARN << "Index: " << get<0>(diffs[i]) // << " A (sparse) = " << get<1>(diffs[i]) // << " B (dense) = " << get<2>(diffs[i]); // ok = false; // } -// +// // return ok; // } // @@ -111,27 +112,27 @@ // { // Threshold(const Real& threshold_) : threshold(threshold_) {} // Real threshold; -// inline Real operator()(const Real& x) const +// inline Real operator()(const Real& x) const // { return x > threshold ? 0 : x; } // }; -// +// // //-------------------------------------------------------------------------------- // struct Plus3 // { -// inline Real operator()(const Real& x) const +// inline Real operator()(const Real& x) const // { return x + 3; } // }; // // //-------------------------------------------------------------------------------- // struct BinaryPlus3 // { -// inline Real operator()(const Real& x, const Real& y) const +// inline Real operator()(const Real& x, const Real& y) const // { return x + y + 3; } // }; // // //-------------------------------------------------------------------------------- // template -// inline void GenerateRand01(Random *r, UInt nnz, SparseTensor& s) +// inline void GenerateRand01(Random *r, UInt nnz, SparseTensor& s) // { // const Index& ub = s.getBounds(); // Index idx; @@ -145,7 +146,7 @@ // // //-------------------------------------------------------------------------------- // template -// inline void GenerateRandRand01(Random *r, SparseTensor& s) +// inline void GenerateRandRand01(Random *r, SparseTensor& s) // { // const Index& ub = s.getBounds(); // const UInt nnz = 1 + (r->getUInt32(product(ub))); @@ -176,31 +177,34 @@ // { // I3 ub(5, 4, 3); // -// { +// { // S3 s3(ub); -// -// ITER_3(ub[0], ub[1], ub[2]) -// Test("SparseTensor bounds list constructor", s3.get(i, j, k), (Real)0); -// +// +// ITER_3(ub[0], ub[1], ub[2]) +// Test("SparseTensor bounds list constructor", s3.get(i, j, k), +// (Real)0); +// // Test("SparseTensor getRank 1", s3.getRank(), (UInt)3); // Test("SparseTensor getNNonZeros 2", s3.getNNonZeros(), (UInt)0); // Test("SparseTensor isZero 3", s3.isZero(), true); // Test("SparseTensor isDense 4", s3.isDense(), false); -// Test("SparseTensor getBounds 5", s3.getBounds(), I3(ub[0], ub[1], ub[2])); +// Test("SparseTensor getBounds 5", s3.getBounds(), I3(ub[0], ub[1], +// ub[2])); // // ITER_3(ub[0], ub[1], ub[2]) { // s3.set(I3(i, j, k), Real(i*ub[1]*ub[2]+j*ub[2]+k)); -// Test("SparseTensor bounds set/get 6", s3.get(i, j, k), i*ub[1]*ub[2]+j*ub[2]+k); +// Test("SparseTensor bounds set/get 6", s3.get(i, j, k), +// i*ub[1]*ub[2]+j*ub[2]+k); // } // -// Test("SparseTensor getNNonZeros 7", s3.getNNonZeros(), (UInt)ub[0]*ub[1]*ub[2]-1); -// Test("SparseTensor isZero 8", s3.isZero(), false); -// Test("SparseTensor isDense 9", s3.isDense(), false); +// Test("SparseTensor getNNonZeros 7", s3.getNNonZeros(), +// (UInt)ub[0]*ub[1]*ub[2]-1); Test("SparseTensor isZero 8", s3.isZero(), +// false); Test("SparseTensor 
isDense 9", s3.isDense(), false); // // s3.set(I3(0, 0, 0), 1); -// Test("SparseTensor getNNonZeros 10", s3.getNNonZeros(), (UInt)ub[0]*ub[1]*ub[2]); -// Test("SparseTensor isZero 11", s3.isZero(), false); -// Test("SparseTensor isDense 12", s3.isDense(), true); +// Test("SparseTensor getNNonZeros 10", s3.getNNonZeros(), +// (UInt)ub[0]*ub[1]*ub[2]); Test("SparseTensor isZero 11", s3.isZero(), +// false); Test("SparseTensor isDense 12", s3.isDense(), true); // // s3.clear(); // Test("SparseTensor getNNonZeros 71", s3.getNNonZeros(), (UInt)0); @@ -208,70 +212,77 @@ // Test("SparseTensor isDense 91", s3.isDense(), false); // } // -// { +// { // I3 i3(ub); // S3 s3(i3); -// -// ITER_3(ub[0], ub[1], ub[2]) -// Test("SparseTensor bounds vector constructor", s3.get(i, j, k), (Real)0); -// +// +// ITER_3(ub[0], ub[1], ub[2]) +// Test("SparseTensor bounds vector constructor", s3.get(i, j, k), +// (Real)0); +// // Test("SparseTensor getRank 25", s3.getRank(), (UInt)3); // Test("SparseTensor getNNonZeros 26", s3.getNNonZeros(), (UInt)0); // Test("SparseTensor isZero 27", s3.isZero(), true); // Test("SparseTensor isDense 28", s3.isDense(), false); -// Test("SparseTensor getBounds 29", s3.getBounds(), I3(ub[0], ub[1], ub[2])); +// Test("SparseTensor getBounds 29", s3.getBounds(), I3(ub[0], ub[1], +// ub[2])); // // ITER_3(ub[0], ub[1], ub[2]) { // s3.set(I3(i, j, k), Real(i*ub[1]*ub[2]+j*ub[2]+k)); -// Test("SparseTensor bounds set/get 30", s3.get(i, j, k), i*ub[1]*ub[2]+j*ub[2]+k); +// Test("SparseTensor bounds set/get 30", s3.get(i, j, k), +// i*ub[1]*ub[2]+j*ub[2]+k); // } // -// Test("SparseTensor getNNonZeros 31", s3.getNNonZeros(), (UInt)ub[0]*ub[1]*ub[2]-1); -// Test("SparseTensor isZero 32", s3.isZero(), false); -// Test("SparseTensor isDense 33", s3.isDense(), false); +// Test("SparseTensor getNNonZeros 31", s3.getNNonZeros(), +// (UInt)ub[0]*ub[1]*ub[2]-1); Test("SparseTensor isZero 32", s3.isZero(), +// false); Test("SparseTensor isDense 33", s3.isDense(), false); // // s3.set(I3(0, 0, 0), 1); -// Test("SparseTensor getNNonZeros 34", s3.getNNonZeros(), (UInt)ub[0]*ub[1]*ub[2]); -// Test("SparseTensor isZero 35", s3.isZero(), false); -// Test("SparseTensor isDense 36", s3.isDense(), true); +// Test("SparseTensor getNNonZeros 34", s3.getNNonZeros(), +// (UInt)ub[0]*ub[1]*ub[2]); Test("SparseTensor isZero 35", s3.isZero(), +// false); Test("SparseTensor isDense 36", s3.isDense(), true); // } // -// { +// { // I3 ub3(ub); // S3 s3(ub3); // -// ITER_3(ub[0], ub[1], ub[2]) +// ITER_3(ub[0], ub[1], ub[2]) // s3.set(I3(i, j, k), Real(i*ub[1]*ub[2]+j*ub[2]+k)); // // S3 s32(s3); // // ITER_3(ub[0], ub[1], ub[2]) { -// Test("SparseTensor bounds set/get 37", s32.get(i, j, k), i*ub[1]*ub[2]+j*ub[2]+k); +// Test("SparseTensor bounds set/get 37", s32.get(i, j, k), +// i*ub[1]*ub[2]+j*ub[2]+k); // } // -// Test("SparseTensor getNNonZeros 38", s32.getNNonZeros(), (UInt)ub[0]*ub[1]*ub[2]-1); -// Test("SparseTensor isZero 39", s32.isZero(), false); -// Test("SparseTensor isDense 40", s32.isDense(), false); -// Test("SparseTensor getBounds 41", s32.getBounds(), I3(ub[0], ub[1], ub[2])); +// Test("SparseTensor getNNonZeros 38", s32.getNNonZeros(), +// (UInt)ub[0]*ub[1]*ub[2]-1); Test("SparseTensor isZero 39", s32.isZero(), +// false); Test("SparseTensor isDense 40", s32.isDense(), false); +// Test("SparseTensor getBounds 41", s32.getBounds(), I3(ub[0], ub[1], +// ub[2])); // } // -// { +// { // I3 ub3(ub); // S3 s3(ub3); // -// ITER_3(ub[0], ub[1], ub[2]) +// ITER_3(ub[0], ub[1], ub[2]) // s3.set(I3(i, j, k), 
Real(i*ub[1]*ub[2]+j*ub[2]+k)); // // S3 s32 = s3; // // ITER_3(ub[0], ub[1], ub[2]) { -// Test("SparseTensor bounds set/get 42", s32.get(i, j, k), i*ub[1]*ub[2]+j*ub[2]+k); +// Test("SparseTensor bounds set/get 42", s32.get(i, j, k), +// i*ub[1]*ub[2]+j*ub[2]+k); // } // -// Test("SparseTensor getNNonZeros 43", s32.getNNonZeros(), (UInt)ub[0]*ub[1]*ub[2]-1); -// Test("SparseTensor isZero 44", s32.isZero(), false); -// Test("SparseTensor isDense 45", s32.isDense(), false); -// Test("SparseTensor getBounds 46", s32.getBounds(), I3(ub[0], ub[1], ub[2])); +// Test("SparseTensor getNNonZeros 43", s32.getNNonZeros(), +// (UInt)ub[0]*ub[1]*ub[2]-1); Test("SparseTensor isZero 44", s32.isZero(), +// false); Test("SparseTensor isDense 45", s32.isDense(), false); +// Test("SparseTensor getBounds 46", s32.getBounds(), I3(ub[0], ub[1], +// ub[2])); // } // // { @@ -289,13 +300,14 @@ // } // // // Constructing a dimension 6, just to be sure, we had a bug -// // with va_arg(indices, UInt) (ellipsis setters/getters) -// // that showed up in dimension 6 on 64 bits only, because of the specific sizes -// // of the integer types used... +// // with va_arg(indices, UInt) (ellipsis setters/getters) +// // that showed up in dimension 6 on 64 bits only, because of the specific +// sizes +// // of the integer types used... // I6 ub6(5, 4, 3, 2, 3, 4); // D6 d6(ub6); // S6 s6(ub6); -// +// // UInt c = 1; // ITER_6(ub6[0], ub6[1], ub6[2], ub6[3], ub6[4], ub6[5]) { // d6.set(I6(i, j, k, l, m, n), Real(c)); @@ -308,13 +320,13 @@ // // This could catch uninitialized values, that used to be a problem // // on shona, revealed by valgrind only // bool correct = true; -// ITER_6(ub6[0], ub6[1], ub6[2], ub6[3], ub6[4], ub6[5]) +// ITER_6(ub6[0], ub6[1], ub6[2], ub6[3], ub6[4], ub6[5]) // if (s6.get(I6(i, j, k, l, m, n)) > 2000 // || d6.get(I6(i, j, k, l, m, n)) > 2000) { // correct = false; // break; // } -// +// // Test("SparseTensor dim 6 set/get 2", correct, true); // } // @@ -323,9 +335,9 @@ // { // I3 ub(5, 4, 3); // -// // These tests compiled conditionally, because they are +// // These tests compiled conditionally, because they are // // based on asserts rather than checks -// +// //#ifdef NTA_ASSERTIONS_ON // // { // out of bounds @@ -370,7 +382,7 @@ // //#endif // -// { +// { // S3 s3(ub), s32(ub), s33(ub); // // ITER_3(ub[0], ub[1], ub[2]) { @@ -379,21 +391,23 @@ // //s33(i, j, k) = s3(i, j, k); // } // -// Test("SparseTensor set 1", s3.getNNonZeros(), (UInt)ub[0]*ub[1]*ub[2]-1); -// Test("SparseTensor set 2", s3.isZero(), false); -// Test("SparseTensor set 3", s3.isDense(), false); -// Test("SparseTensor set 4", s32.getNNonZeros(), (UInt)ub[0]*ub[1]*ub[2]-1); -// Test("SparseTensor set 5", s32.isZero(), false); -// Test("SparseTensor set 6", s32.isDense(), false); -// //Test("SparseTensor set 7", s33.getNNonZeros(), (UInt)ub[0]*ub[1]*ub[2]-1); +// Test("SparseTensor set 1", s3.getNNonZeros(), +// (UInt)ub[0]*ub[1]*ub[2]-1); Test("SparseTensor set 2", s3.isZero(), +// false); Test("SparseTensor set 3", s3.isDense(), false); +// Test("SparseTensor set 4", s32.getNNonZeros(), +// (UInt)ub[0]*ub[1]*ub[2]-1); Test("SparseTensor set 5", s32.isZero(), +// false); Test("SparseTensor set 6", s32.isDense(), false); +// //Test("SparseTensor set 7", s33.getNNonZeros(), +// (UInt)ub[0]*ub[1]*ub[2]-1); // //Test("SparseTensor set 8", s33.isZero(), false); // //Test("SparseTensor set 9", s33.isDense(), false); // // // ITER_3(ub[0], ub[1], ub[2]) { -// Test("SparseTensor get 1", s3.get(I3(i, j, k)), i*ub[1]*ub[2]+j*ub[2]+k); 
-// Test("SparseTensor get 2", s3.get(i, j, k), s3.get(I3(i, j, k))); -// Test("SparseTensor get 3", s3(i, j, k), s3.get(I3(i, j, k))); +// Test("SparseTensor get 1", s3.get(I3(i, j, k)), +// i*ub[1]*ub[2]+j*ub[2]+k); Test("SparseTensor get 2", s3.get(i, j, k), +// s3.get(I3(i, j, k))); Test("SparseTensor get 3", s3(i, j, k), +// s3.get(I3(i, j, k))); // } // // // setZero @@ -445,7 +459,7 @@ // Test("SparseTensor setZero(Domain) 3", s3.isZero(), true); // } // } -// +// // { // S2 s2(4, 5); // D2 d2(4, 5); @@ -465,28 +479,28 @@ // Test("SparseTensor update 02", Compare(s2, d2), true); // } // -// { +// { // S3 s3(ub); // D3 d3(ub); -// +// // ITER_3(ub[0], ub[1], ub[2]) { -// s3.update(I3(i, j, k), Real(i*ub[1]*ub[2]+j*ub[2]+k), std::plus()); -// d3.set(i, j, k, (Real)(i*ub[1]*ub[2]+j*ub[2]+k)); -// } +// s3.update(I3(i, j, k), Real(i*ub[1]*ub[2]+j*ub[2]+k), +// std::plus()); d3.set(i, j, k, (Real)(i*ub[1]*ub[2]+j*ub[2]+k)); +// } // // Test("SparseTensor update 1", Compare(s3, d3), true); // // ITER_3(ub[0], ub[1], ub[2]) { // s3.update(I3(i, j, k), s3.get(i, j, k), std::plus()); // d3.set(i, j, k, (Real)(d3.get(i, j, k) + d3.get(i, j, k))); -// } +// } // // Test("SparseTensor update 2", Compare(s3, d3), true); -// -// ITER_3(ub[0], ub[1], ub[2]) { +// +// ITER_3(ub[0], ub[1], ub[2]) { // s3.update(I3(i, j, k), s3.get(i, j, k), nupic::Multiplies()); // d3.set(i, j, k, (Real)(d3.get(i, j, k) * d3.get(i, j, k))); -// } +// } // // Test("SparseTensor update 3", Compare(s3, d3), true); // } @@ -495,26 +509,26 @@ // S3 s3(ub), ref(ub); // // Test("SparseTensor clear 1", s3, ref); -// -// ITER_3(ub[0], ub[1], ub[2]) +// +// ITER_3(ub[0], ub[1], ub[2]) // s3.set(I3(i, j, k), (Real)(i*ub[1]*ub[2]+j*ub[2]+k)); -// +// // s3.clear(); // Test("SparseTensor clear 2", s3, ref); // // s3.clear(); // Test("SparseTensor clear 3", s3, ref); // -// ITER_3(ub[0], ub[1], ub[2]) +// ITER_3(ub[0], ub[1], ub[2]) // s3.set(I3(i, j, k), (Real)(i*ub[1]*ub[2]+j*ub[2]+k)); // } // // { // setAll // I4 ub4(5, 4, 3, 2); // S4 s4(ub4); -// +// // Test("SparseTensor setAll 1", s4.isZero(), true); -// +// // s4.setAll(0); // Test("SparseTensor setAll 2A", s4.isZero(), true); // Test("SparseTensor setAll 2B", s4.getNNonZeros(), (UInt)0); @@ -563,9 +577,9 @@ // // s2A.extract(0, some, s2B); // -// // Extract to some values +// // Extract to some values // some.clear(); -// some.insert(0); some.insert(3); +// some.insert(0); some.insert(3); // // s2A.extract(0, some, s2B); // } @@ -573,7 +587,8 @@ // // //-------------------------------------------------------------------------------- // /** -// * TC1: All zeros tensor should be reduced to all zeros tensor with non empty ind +// * TC1: All zeros tensor should be reduced to all zeros tensor with non +// empty ind // * TC2: All zeros tensor is reduce to null tensor with empty ind // * TC3: Full tensor should be reduced appropriately with non empty ind // * Indices should be updated properly. 
@@ -592,15 +607,15 @@ // // { // TC1 // ind.insert(UInt(1)); -// s2A.reduce(0, ind); +// s2A.reduce(0, ind); // Test("SparseTensor reduce 1A", s2A.isZero(), true); // // Let's make sure we can call it twice, and stay invariant // Test("SparseTensor reduce 1B", s2A.isZero(), true); // // s2A.reduce(1, ind); // Test("SparseTensor reduce 1C", s2A.isZero(), true); -// } -// +// } +// // { // TC2 // ind.clear(); // s2A.reduce(0, ind); @@ -650,7 +665,7 @@ // s2A.reduce(0, ind); // Test("SparseTensor reduce 11A", s2A.isNull(), true); // } -// +// // } // // //-------------------------------------------------------------------------------- @@ -670,110 +685,110 @@ // Test("SparseTensor isPositive() 1", s2.isPositive(), false); // Test("SparseTensor isNonNegative() 1", s2.isNonNegative(), true); // } -// +// // { // Full (dense) tensor 1 // S2 s2(ub2); s2.setAll(1); -// Test("SparseTensor getNNonZeros() 2", s2.getNNonZeros(), ub2.product()); -// Test("SparseTensor getNZeros() 2", s2.getNZeros(), (UInt)0); -// Test("SparseTensor isZero() 2", s2.isZero(), false); +// Test("SparseTensor getNNonZeros() 2", s2.getNNonZeros(), +// ub2.product()); Test("SparseTensor getNZeros() 2", s2.getNZeros(), +// (UInt)0); Test("SparseTensor isZero() 2", s2.isZero(), false); // Test("SparseTensor isSparse() 2", s2.isSparse(), false); // Test("SparseTensor isDense() 2", s2.isDense(), true); // Test("SparseTensor getFillRate() 2", s2.getFillRate(), (Real)1); // Test("SparseTensor isPositive() 2", s2.isPositive(), true); // Test("SparseTensor isNonNegative() 2", s2.isNonNegative(), true); -// } +// } // -// { // Full (dense) tensor 2 (negative tensor) +// { // Full (dense) tensor 2 (negative tensor) // S2 s2(ub2); s2.setAll(-1); -// Test("SparseTensor getNNonZeros() 3", s2.getNNonZeros(), ub2.product()); -// Test("SparseTensor getNZeros() 3", s2.getNZeros(), (UInt)0); -// Test("SparseTensor isZero() 3", s2.isZero(), false); +// Test("SparseTensor getNNonZeros() 3", s2.getNNonZeros(), +// ub2.product()); Test("SparseTensor getNZeros() 3", s2.getNZeros(), +// (UInt)0); Test("SparseTensor isZero() 3", s2.isZero(), false); // Test("SparseTensor isSparse() 3", s2.isSparse(), false); // Test("SparseTensor isDense() 3", s2.isDense(), true); // Test("SparseTensor getFillRate() 3", s2.getFillRate(), (Real)1); // Test("SparseTensor isPositive() 3", s2.isPositive(), false); // Test("SparseTensor isNonNegative() 3", s2.isNonNegative(), false); -// } +// } // -// { // Full (dense) tensor 2 (mixed positive and negative tensor) -// S2 s2(ub2); +// { // Full (dense) tensor 2 (mixed positive and negative tensor) +// S2 s2(ub2); // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); // UInt o = i2.ordinal(ub2); -// if (o % 2 == 0) s2.set(i2, Real(o+1)); else s2.set(i2, Real(-1)); -// } -// Test("SparseTensor getNNonZeros() 4", s2.getNNonZeros(), ub2.product()); -// Test("SparseTensor getNZeros() 4", s2.getNZeros(), (UInt)0); -// Test("SparseTensor isZero() 4", s2.isZero(), false); +// if (o % 2 == 0) s2.set(i2, Real(o+1)); else s2.set(i2, Real(-1)); +// } +// Test("SparseTensor getNNonZeros() 4", s2.getNNonZeros(), +// ub2.product()); Test("SparseTensor getNZeros() 4", s2.getNZeros(), +// (UInt)0); Test("SparseTensor isZero() 4", s2.isZero(), false); // Test("SparseTensor isSparse() 4", s2.isSparse(), false); // Test("SparseTensor isDense() 4", s2.isDense(), true); // Test("SparseTensor getFillRate() 4", s2.getFillRate(), (Real)1); // Test("SparseTensor isPositive() 4", s2.isPositive(), false); // Test("SparseTensor isNonNegative() 4", 
s2.isNonNegative(), false); -// } +// } // // { // Sparse tensor 1 -// S2 s2(ub2); +// S2 s2(ub2); // UInt nnz = 0; // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); // UInt o = i2.ordinal(ub2); // if (o % 2 == 0) { s2.set(i2, Real(o+1)); ++nnz; } -// } +// } // Test("SparseTensor getNNonZeros() 5", s2.getNNonZeros(), nnz); -// Test("SparseTensor getNZeros() 5", s2.getNZeros(), ub2.product() - nnz); -// Test("SparseTensor isZero() 5", s2.isZero(), false); +// Test("SparseTensor getNZeros() 5", s2.getNZeros(), ub2.product() - +// nnz); Test("SparseTensor isZero() 5", s2.isZero(), false); // Test("SparseTensor isSparse() 5", s2.isSparse(), true); // Test("SparseTensor isDense() 5", s2.isDense(), false); // Real fr = (Real)nnz / (Real)ub2.product(); // Test("SparseTensor getFillRate() 5", s2.getFillRate(), fr); // Test("SparseTensor isPositive() 5", s2.isPositive(), false); // Test("SparseTensor isNonNegative() 5", s2.isNonNegative(), true); -// } +// } // // { // Sparse tensor 2 (negative tensor) -// S2 s2(ub2); +// S2 s2(ub2); // UInt nnz = 0; // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); // UInt o = i2.ordinal(ub2); // if (o % 2 == 0) { s2.set(i2, -1); ++nnz; } -// } +// } // Test("SparseTensor getNNonZeros() 6", s2.getNNonZeros(), nnz); -// Test("SparseTensor getNZeros() 6", s2.getNZeros(), ub2.product() - nnz); -// Test("SparseTensor isZero() 6", s2.isZero(), false); +// Test("SparseTensor getNZeros() 6", s2.getNZeros(), ub2.product() - +// nnz); Test("SparseTensor isZero() 6", s2.isZero(), false); // Test("SparseTensor isSparse() 6", s2.isSparse(), true); // Test("SparseTensor isDense() 6", s2.isDense(), false); // Real fr = (Real)nnz / (Real)ub2.product(); // Test("SparseTensor getFillRate() 6", s2.getFillRate(), fr); // Test("SparseTensor isPositive() 6", s2.isPositive(), false); // Test("SparseTensor isNonNegative() 6", s2.isNonNegative(), false); -// } +// } // // { // Sparse tensor 3 (mixed positive and negative tensor) -// S2 s2(ub2); +// S2 s2(ub2); // UInt nnz = 0; // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); // UInt o = i2.ordinal(ub2); -// if (o % 2 == 0) { +// if (o % 2 == 0) { // if (o % 3 == 0) -// s2.set(i2, -1); +// s2.set(i2, -1); // else // s2.set(i2, Real(o)); -// ++nnz; +// ++nnz; // } -// } +// } // Test("SparseTensor getNNonZeros() 7", s2.getNNonZeros(), nnz); -// Test("SparseTensor getNZeros() 7", s2.getNZeros(), ub2.product() - nnz); -// Test("SparseTensor isZero() 7", s2.isZero(), false); +// Test("SparseTensor getNZeros() 7", s2.getNZeros(), ub2.product() - +// nnz); Test("SparseTensor isZero() 7", s2.isZero(), false); // Test("SparseTensor isSparse() 7", s2.isSparse(), true); // Test("SparseTensor isDense() 7", s2.isDense(), false); // Real fr = (Real)nnz / (Real)ub2.product(); // Test("SparseTensor getFillRate() 7", s2.getFillRate(), fr); // Test("SparseTensor isPositive() 7", s2.isPositive(), false); // Test("SparseTensor isNonNegative() 7", s2.isNonNegative(), false); -// } +// } // } // // { // Domain tests (with full domain) @@ -788,98 +803,98 @@ // Test("SparseTensor isDense(d2) 1", s2.isDense(d2), false); // Test("SparseTensor getFillRate(d2) 1", s2.getFillRate(d2), (Real)0); // } -// +// // { // Full (dense) tensor 1 // S2 s2(ub2); s2.setAll(1); -// Test("SparseTensor getNNonZeros(d2) 2", s2.getNNonZeros(d2), ub2.product()); -// Test("SparseTensor getNZeros(d2) 2", s2.getNZeros(d2), (UInt)0); -// Test("SparseTensor isZero(d2) 2", s2.isZero(d2), false); +// Test("SparseTensor getNNonZeros(d2) 2", s2.getNNonZeros(d2), +// ub2.product()); Test("SparseTensor 
getNZeros(d2) 2", s2.getNZeros(d2), +// (UInt)0); Test("SparseTensor isZero(d2) 2", s2.isZero(d2), false); // Test("SparseTensor isSparse(d2) 2", s2.isSparse(d2), false); // Test("SparseTensor isDense(d2) 2", s2.isDense(d2), true); // Test("SparseTensor getFillRate(d2) 2", s2.getFillRate(d2), (Real)1); -// } +// } // -// { // Full (dense) tensor 2 (negative tensor) +// { // Full (dense) tensor 2 (negative tensor) // S2 s2(ub2); s2.setAll(-1); -// Test("SparseTensor getNNonZeros(d2) 3", s2.getNNonZeros(d2), ub2.product()); -// Test("SparseTensor getNZeros(d2) 3", s2.getNZeros(d2), (UInt)0); -// Test("SparseTensor isZero(d2) 3", s2.isZero(d2), false); +// Test("SparseTensor getNNonZeros(d2) 3", s2.getNNonZeros(d2), +// ub2.product()); Test("SparseTensor getNZeros(d2) 3", s2.getNZeros(d2), +// (UInt)0); Test("SparseTensor isZero(d2) 3", s2.isZero(d2), false); // Test("SparseTensor isSparse(d2) 3", s2.isSparse(d2), false); // Test("SparseTensor isDense(d2) 3", s2.isDense(d2), true); // Test("SparseTensor getFillRate(d2) 3", s2.getFillRate(d2), (Real)1); -// } +// } // -// { // Full (dense) tensor 2 (mixed positive and negative tensor) -// S2 s2(ub2); +// { // Full (dense) tensor 2 (mixed positive and negative tensor) +// S2 s2(ub2); // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); // UInt o = i2.ordinal(ub2); -// if (o % 2 == 0) s2.set(i2, Real(o+1)); else s2.set(i2, Real(-1)); -// } -// Test("SparseTensor getNNonZeros(d2) 4", s2.getNNonZeros(d2), ub2.product()); -// Test("SparseTensor getNZeros(d2) 4", s2.getNZeros(d2), (UInt)0); -// Test("SparseTensor isZero(d2) 4", s2.isZero(d2), false); +// if (o % 2 == 0) s2.set(i2, Real(o+1)); else s2.set(i2, Real(-1)); +// } +// Test("SparseTensor getNNonZeros(d2) 4", s2.getNNonZeros(d2), +// ub2.product()); Test("SparseTensor getNZeros(d2) 4", s2.getNZeros(d2), +// (UInt)0); Test("SparseTensor isZero(d2) 4", s2.isZero(d2), false); // Test("SparseTensor isSparse(d2) 4", s2.isSparse(d2), false); // Test("SparseTensor isDense(d2) 4", s2.isDense(d2), true); // Test("SparseTensor getFillRate(d2) 4", s2.getFillRate(d2), (Real)1); -// } +// } // // { // Sparse tensor 1 -// S2 s2(ub2); +// S2 s2(ub2); // UInt nnz = 0; // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); // UInt o = i2.ordinal(ub2); // if (o % 2 == 0) { s2.set(i2, Real(o+1)); ++nnz; } -// } +// } // Test("SparseTensor getNNonZeros(d2) 5", s2.getNNonZeros(d2), nnz); -// Test("SparseTensor getNZeros(d2) 5", s2.getNZeros(d2), ub2.product() - nnz); -// Test("SparseTensor isZero(d2) 5", s2.isZero(d2), false); +// Test("SparseTensor getNZeros(d2) 5", s2.getNZeros(d2), ub2.product() - +// nnz); Test("SparseTensor isZero(d2) 5", s2.isZero(d2), false); // Test("SparseTensor isSparse(d2) 5", s2.isSparse(d2), true); // Test("SparseTensor isDense(d2) 5", s2.isDense(d2), false); // Real fr = (Real)nnz / (Real)d2.size_elts(); // Test("SparseTensor getFillRate(d2) 5", s2.getFillRate(d2), fr); -// } +// } // // { // Sparse tensor 2 (negative tensor) -// S2 s2(ub2); +// S2 s2(ub2); // UInt nnz = 0; // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); // UInt o = i2.ordinal(ub2); // if (o % 2 == 0) { s2.set(i2, -1); ++nnz; } -// } +// } // Test("SparseTensor getNNonZeros(d2) 6", s2.getNNonZeros(d2), nnz); -// Test("SparseTensor getNZeros(d2) 6", s2.getNZeros(d2), ub2.product() - nnz); -// Test("SparseTensor isZero(d2) 6", s2.isZero(d2), false); +// Test("SparseTensor getNZeros(d2) 6", s2.getNZeros(d2), ub2.product() - +// nnz); Test("SparseTensor isZero(d2) 6", s2.isZero(d2), false); // Test("SparseTensor isSparse(d2) 6", 
s2.isSparse(d2), true); // Test("SparseTensor isDense(d2) 6", s2.isDense(d2), false); // Real fr = (Real)nnz / (Real)d2.size_elts(); // Test("SparseTensor getFillRate(d2) 6", s2.getFillRate(d2), fr); -// } +// } // // { // Sparse tensor 3 (mixed positive and negative tensor) -// S2 s2(ub2); +// S2 s2(ub2); // UInt nnz = 0; // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); // UInt o = i2.ordinal(ub2); -// if (o % 2 == 0) { +// if (o % 2 == 0) { // if (o % 3 == 0) -// s2.set(i2, -1); +// s2.set(i2, -1); // else // s2.set(i2, Real(o)); -// ++nnz; +// ++nnz; // } -// } +// } // Test("SparseTensor getNNonZeros(d2) 7", s2.getNNonZeros(d2), nnz); -// Test("SparseTensor getNZeros(d2) 7", s2.getNZeros(d2), ub2.product() - nnz); -// Test("SparseTensor isZero(d2) 7", s2.isZero(d2), false); +// Test("SparseTensor getNZeros(d2) 7", s2.getNZeros(d2), ub2.product() - +// nnz); Test("SparseTensor isZero(d2) 7", s2.isZero(d2), false); // Test("SparseTensor isSparse(d2) 7", s2.isSparse(d2), true); // Test("SparseTensor isDense(d2) 7", s2.isDense(d2), false); // Real fr = (Real)nnz / (Real)d2.size_elts(); // Test("SparseTensor getFillRate(d2) 7", s2.getFillRate(d2), fr); -// } +// } // } // // { // Domain tests (with partial domain) @@ -887,106 +902,110 @@ // // { // Empty tensor // S2 s2(ub2); -// Test("SparseTensor getNNonZeros(d2p) 1", s2.getNNonZeros(d2p), (UInt)0); -// Test("SparseTensor getNZeros(d2p) 1", s2.getNZeros(d2p), d2p.size_elts()); -// Test("SparseTensor isZero(d2p) 1", s2.isZero(d2p), true); -// Test("SparseTensor isSparse(d2p) 1", s2.isSparse(d2p), true); +// Test("SparseTensor getNNonZeros(d2p) 1", s2.getNNonZeros(d2p), +// (UInt)0); Test("SparseTensor getNZeros(d2p) 1", s2.getNZeros(d2p), +// d2p.size_elts()); Test("SparseTensor isZero(d2p) 1", s2.isZero(d2p), +// true); Test("SparseTensor isSparse(d2p) 1", s2.isSparse(d2p), true); // Test("SparseTensor isDense(d2p) 1", s2.isDense(d2p), false); // Test("SparseTensor getFillRate(d2p) 1", s2.getFillRate(d2p), (Real)0); // } -// +// // { // Full (dense) tensor 1 // S2 s2(ub2); s2.setAll(1); -// Test("SparseTensor getNNonZeros(d2p) 2", s2.getNNonZeros(d2p), d2p.size_elts()); -// Test("SparseTensor getNZeros(d2p) 2", s2.getNZeros(d2p), (UInt)0); -// Test("SparseTensor isZero(d2p) 2", s2.isZero(d2p), false); -// Test("SparseTensor isSparse(d2p) 2", s2.isSparse(d2p), false); -// Test("SparseTensor isDense(d2p) 2", s2.isDense(d2p), true); -// Test("SparseTensor getFillRate(d2p) 2", s2.getFillRate(d2p), (Real)1); -// } -// -// { // Full (dense) tensor 2 (negative tensor) +// Test("SparseTensor getNNonZeros(d2p) 2", s2.getNNonZeros(d2p), +// d2p.size_elts()); Test("SparseTensor getNZeros(d2p) 2", +// s2.getNZeros(d2p), (UInt)0); Test("SparseTensor isZero(d2p) 2", +// s2.isZero(d2p), false); Test("SparseTensor isSparse(d2p) 2", +// s2.isSparse(d2p), false); Test("SparseTensor isDense(d2p) 2", +// s2.isDense(d2p), true); Test("SparseTensor getFillRate(d2p) 2", +// s2.getFillRate(d2p), (Real)1); +// } +// +// { // Full (dense) tensor 2 (negative tensor) // S2 s2(ub2); s2.setAll(-1); -// Test("SparseTensor getNNonZeros(d2p) 3", s2.getNNonZeros(d2p), d2p.size_elts()); -// Test("SparseTensor getNZeros(d2p) 3", s2.getNZeros(d2p), (UInt)0); -// Test("SparseTensor isZero(d2p) 3", s2.isZero(d2p), false); -// Test("SparseTensor isSparse(d2p) 3", s2.isSparse(d2p), false); -// Test("SparseTensor isDense(d2p) 3", s2.isDense(d2p), true); -// Test("SparseTensor getFillRate(d2p) 3", s2.getFillRate(d2p), (Real)1); -// } -// -// { // Full (dense) tensor 2 (mixed positive 
and negative tensor) -// S2 s2(ub2); +// Test("SparseTensor getNNonZeros(d2p) 3", s2.getNNonZeros(d2p), +// d2p.size_elts()); Test("SparseTensor getNZeros(d2p) 3", +// s2.getNZeros(d2p), (UInt)0); Test("SparseTensor isZero(d2p) 3", +// s2.isZero(d2p), false); Test("SparseTensor isSparse(d2p) 3", +// s2.isSparse(d2p), false); Test("SparseTensor isDense(d2p) 3", +// s2.isDense(d2p), true); Test("SparseTensor getFillRate(d2p) 3", +// s2.getFillRate(d2p), (Real)1); +// } +// +// { // Full (dense) tensor 2 (mixed positive and negative tensor) +// S2 s2(ub2); // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); // UInt o = i2.ordinal(ub2); -// if (o % 2 == 0) s2.set(i2, Real(o+1)); else s2.set(i2, Real(-1)); -// } -// Test("SparseTensor getNNonZeros(d2p) 4", s2.getNNonZeros(d2p), d2p.size_elts()); -// Test("SparseTensor getNZeros(d2p) 4", s2.getNZeros(d2p), (UInt)0); -// Test("SparseTensor isZero(d2p) 4", s2.isZero(d2p), false); -// Test("SparseTensor isSparse(d2p) 4", s2.isSparse(d2p), false); -// Test("SparseTensor isDense(d2p) 4", s2.isDense(d2p), true); -// Test("SparseTensor getFillRate(d2p) 4", s2.getFillRate(d2p), (Real)1); -// } +// if (o % 2 == 0) s2.set(i2, Real(o+1)); else s2.set(i2, Real(-1)); +// } +// Test("SparseTensor getNNonZeros(d2p) 4", s2.getNNonZeros(d2p), +// d2p.size_elts()); Test("SparseTensor getNZeros(d2p) 4", +// s2.getNZeros(d2p), (UInt)0); Test("SparseTensor isZero(d2p) 4", +// s2.isZero(d2p), false); Test("SparseTensor isSparse(d2p) 4", +// s2.isSparse(d2p), false); Test("SparseTensor isDense(d2p) 4", +// s2.isDense(d2p), true); Test("SparseTensor getFillRate(d2p) 4", +// s2.getFillRate(d2p), (Real)1); +// } // // { // Sparse tensor 1 -// S2 s2(ub2); +// S2 s2(ub2); // UInt nnz = 0; // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); // UInt o = i2.ordinal(ub2); -// if (o % 2 == 0) { s2.set(i2, Real(o+1)); if (d2p.includes(i2)) ++nnz; } -// } +// if (o % 2 == 0) { s2.set(i2, Real(o+1)); if (d2p.includes(i2)) +// ++nnz; } +// } // Test("SparseTensor getNNonZeros(d2p) 5", s2.getNNonZeros(d2p), nnz); -// Test("SparseTensor getNZeros(d2p) 5", s2.getNZeros(d2p), d2p.size_elts() - nnz); -// Test("SparseTensor isZero(d2p) 5", s2.isZero(d2p), false); -// Test("SparseTensor isSparse(d2p) 5", s2.isSparse(d2p), true); -// Test("SparseTensor isDense(d2p) 5", s2.isDense(d2p), false); -// Real fr = (Real)nnz / (Real)d2p.size_elts(); +// Test("SparseTensor getNZeros(d2p) 5", s2.getNZeros(d2p), +// d2p.size_elts() - nnz); Test("SparseTensor isZero(d2p) 5", +// s2.isZero(d2p), false); Test("SparseTensor isSparse(d2p) 5", +// s2.isSparse(d2p), true); Test("SparseTensor isDense(d2p) 5", +// s2.isDense(d2p), false); Real fr = (Real)nnz / (Real)d2p.size_elts(); // Test("SparseTensor getFillRate(d2p) 5", s2.getFillRate(d2p), fr); -// } +// } // // { // Sparse tensor 2 (negative tensor) -// S2 s2(ub2); +// S2 s2(ub2); // UInt nnz = 0; // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); // UInt o = i2.ordinal(ub2); // if (o % 2 == 0) { s2.set(i2, -1); if (d2p.includes(i2)) ++nnz; } -// } +// } // Test("SparseTensor getNNonZeros(d2p) 6", s2.getNNonZeros(d2p), nnz); -// Test("SparseTensor getNZeros(d2p) 6", s2.getNZeros(d2p), d2p.size_elts() - nnz); -// Test("SparseTensor isZero(d2p) 6", s2.isZero(d2p), false); -// Test("SparseTensor isSparse(d2p) 6", s2.isSparse(d2p), true); -// Test("SparseTensor isDense(d2p) 6", s2.isDense(d2p), false); -// Real fr = (Real)nnz / (Real)d2p.size_elts(); +// Test("SparseTensor getNZeros(d2p) 6", s2.getNZeros(d2p), +// d2p.size_elts() - nnz); Test("SparseTensor isZero(d2p) 6", +// 
s2.isZero(d2p), false); Test("SparseTensor isSparse(d2p) 6", +// s2.isSparse(d2p), true); Test("SparseTensor isDense(d2p) 6", +// s2.isDense(d2p), false); Real fr = (Real)nnz / (Real)d2p.size_elts(); // Test("SparseTensor getFillRate(d2p) 6", s2.getFillRate(d2p), fr); -// } +// } // // { // Sparse tensor 3 (mixed positive and negative tensor) -// S2 s2(ub2); +// S2 s2(ub2); // UInt nnz = 0; // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); // UInt o = i2.ordinal(ub2); -// if (o % 2 == 0) { +// if (o % 2 == 0) { // if (o % 3 == 0) -// s2.set(i2, -1); +// s2.set(i2, -1); // else // s2.set(i2, Real(o)); // if (d2p.includes(i2)) -// ++nnz; +// ++nnz; // } -// } +// } // Test("SparseTensor getNNonZeros(d2p) 7", s2.getNNonZeros(d2p), nnz); -// Test("SparseTensor getNZeros(d2p) 7", s2.getNZeros(d2p), d2p.size_elts() - nnz); -// Test("SparseTensor isZero(d2p) 7", s2.isZero(d2p), false); -// Test("SparseTensor isSparse(d2p) 7", s2.isSparse(d2p), true); -// Test("SparseTensor isDense(d2p) 7", s2.isDense(d2p), false); -// Real fr = (Real)nnz / (Real)d2p.size_elts(); +// Test("SparseTensor getNZeros(d2p) 7", s2.getNZeros(d2p), +// d2p.size_elts() - nnz); Test("SparseTensor isZero(d2p) 7", +// s2.isZero(d2p), false); Test("SparseTensor isSparse(d2p) 7", +// s2.isSparse(d2p), true); Test("SparseTensor isDense(d2p) 7", +// s2.isDense(d2p), false); Real fr = (Real)nnz / (Real)d2p.size_elts(); // Test("SparseTensor getFillRate(d2p) 7", s2.getFillRate(d2p), fr); -// } +// } // } // // { // Sparse tensor - domain contains only zeros @@ -1015,7 +1034,7 @@ // if (j > ub2[1]/2) { // s2.set(i2, Real(o+1)); // ++nnz; -// } +// } // } // Domain d2f(I2(0, ub2[1]/2+1), I2(ub2[0], ub2[1])); // Test("SparseTensor getNNonZeros(d2f)", s2.getNNonZeros(d2f), nnz); @@ -1027,7 +1046,7 @@ // } // // { // getNNonZeros per sub-space, sparse tensor -// S2 s2(ub2); +// S2 s2(ub2); // S1 s1ref1nz(ub2[0]), s1ref2nz(ub2[1]); // S1 s1ref1z(ub2[0]), s1ref2z(ub2[1]); // ITER_2(ub2[0], ub2[1]) { @@ -1040,11 +1059,11 @@ // } else { // s1ref1z.update(I1(UInt(i)), Real(1), std::plus()); // s1ref2z.update(I1(UInt(j)), Real(1), std::plus()); -// } +// } // } // // // Non-zeros/zeros per row -// S1 s11nz(ub2[0]), s11z(ub2[0]), fr1(ub2[0]), fr1ref(ub2[0]); +// S1 s11nz(ub2[0]), s11z(ub2[0]), fr1(ub2[0]), fr1ref(ub2[0]); // s2.getNNonZeros(I1(UInt(1)), s11nz); // Test("SparseTensor sub-space getNNonZeros 1", s11nz == s1ref1nz, true); // s2.getNZeros(I1(UInt(1)), s11z); @@ -1055,7 +1074,7 @@ // Test("SparseTensor sub-space getFillRate 1", fr1 == fr1ref, true); // // // Non-zeros/zeros per column -// S1 s12nz(ub2[1]), s12z(ub2[1]), fr2(ub2[1]), fr2ref(ub2[1]); +// S1 s12nz(ub2[1]), s12z(ub2[1]), fr2(ub2[1]), fr2ref(ub2[1]); // s2.getNNonZeros(I1(UInt(0)), s12nz); // Test("SparseTensor sub-space getNZeros 2", s12nz == s1ref2nz, true); // s2.getNZeros(I1(UInt(0)), s12z); @@ -1077,10 +1096,10 @@ // S3 s32(5, 3, 3); // Test("SparseTensor isSymmetric 2", s32.isSymmetric(I3(0, 2, 1)), true); // Test("SparseTensor isSymmetric 3", s32.isSymmetric(I3(1, 2, 0)), false); -// +// // s32.set(I3(0, 0, 1), (Real).5); // Test("SparseTensor isSymmetric 4", s32.isSymmetric(I3(0, 2, 1)), false); -// +// // s32.set(I3(0, 1, 0), (Real).5); // Test("SparseTensor isSymmetric 5", s32.isSymmetric(I3(0, 2, 1)), true); // @@ -1096,17 +1115,21 @@ // // { // isAntiSymmetric // S3 s31(5, 4, 3); -// Test("SparseTensor isAntiSymmetric 1", s31.isAntiSymmetric(I3(0, 2, 1)), false); +// Test("SparseTensor isAntiSymmetric 1", s31.isAntiSymmetric(I3(0, 2, 1)), +// false); // // S3 s32(5, 
3, 3); -// Test("SparseTensor isAntiSymmetric 2", s32.isAntiSymmetric(I3(0, 2, 1)), true); -// Test("SparseTensor isAntiSymmetric 3", s32.isAntiSymmetric(I3(1, 2, 0)), false); +// Test("SparseTensor isAntiSymmetric 2", s32.isAntiSymmetric(I3(0, 2, 1)), +// true); Test("SparseTensor isAntiSymmetric 3", s32.isAntiSymmetric(I3(1, +// 2, 0)), false); // // s32.set(I3(0, 0, 1), (Real).5); -// Test("SparseTensor isAntiSymmetric 4", s32.isAntiSymmetric(I3(0, 2, 1)), false); -// +// Test("SparseTensor isAntiSymmetric 4", s32.isAntiSymmetric(I3(0, 2, 1)), +// false); +// // s32.set(I3(0, 1, 0), (Real)-.5); -// Test("SparseTensor isAntiSymmetric 5", s32.isAntiSymmetric(I3(0, 2, 1)), true); +// Test("SparseTensor isAntiSymmetric 5", s32.isAntiSymmetric(I3(0, 2, 1)), +// true); // // S2 s21(5, 5); // for (UInt i = 0; i < 5; ++i) @@ -1115,12 +1138,14 @@ // s21.set(I2(j, i), - s21(i, j)); // } // -// Test("SparseTensor isAntiSymmetric 6", s21.isAntiSymmetric(I2(1, 0)), false); +// Test("SparseTensor isAntiSymmetric 6", s21.isAntiSymmetric(I2(1, 0)), +// false); // // for (UInt i = 0; i < 5; ++i) // s21.set(I2(i, i), (Real) 0); // -// Test("SparseTensor isAntiSymmetric 7", s21.isAntiSymmetric(I2(1, 0)), true); +// Test("SparseTensor isAntiSymmetric 7", s21.isAntiSymmetric(I2(1, 0)), +// true); // } // } // @@ -1128,7 +1153,7 @@ // void SparseTensorUnitTest::unitTestToFromDense() // { // I4 ub(5, 4, 3, 2); -// +// // { // S4 s4(ub); // D4 d4(ub); @@ -1136,8 +1161,8 @@ // ITER_4(ub[0], ub[1], ub[2], ub[3]) { // I4 idx(i, j, k, l); // s4.setNonZero(idx, Real(idx.ordinal(ub)+1)); -// } -// +// } +// // Real array[5*4*3*2]; // s4.toDense(array); // d4.clear(); @@ -1145,15 +1170,15 @@ // Test("SparseTensor toDense 2", Compare(s4, d4), true); // } // -// { +// { // S4 s4(ub); // D4 d4(ub); -// +// // ITER_4(ub[0], ub[1], ub[2], ub[3]) { // I4 idx(i, j, k, l); // d4.set(idx, Real(idx.ordinal(ub)+1)); // } -// +// // Real array[5*4*3*2]; // d4.toDense(array); // s4.clear(); @@ -1167,55 +1192,55 @@ // { // D2 d2(5, 4); // S2 s2(5, 4); -// +// // ITER_2(5, 4) { // d2.set(I2(i, j), Real(i*4+j+1)); // s2.set(I2(i, j), Real(i*4+j+1)); // } -// +// // d2.permute(I2(1, 0)); // s2.permute(I2(1, 0)); -// +// // Test("SparseTensor permute 1", Compare(s2, d2), true); // // D3 d3(5, 4, 3); // S3 s3(5, 4, 3); -// +// // ITER_3(5, 4, 3) { // d3.set(I3(i, j, k), (i*12+j*3+k) % 2 == 0 ? 
Real(i*12*j*3+k) : Real(0)); // s3.set(I3(i, j, k), d3.get(i, j, k)); // } -// +// // d3.permute(I3(1, 0, 2)); // s3.permute(I3(1, 0, 2)); -// +// // Test("SparseTensor permute 2", Compare(s3, d3), true); // // d3.permute(I3(2, 0, 1)); // s3.permute(I3(2, 0, 1)); -// +// // Test("SparseTensor permute 3", Compare(s3, d3), true); // // d3.permute(I3(1, 2, 0)); // s3.permute(I3(1, 2, 0)); -// +// // Test("SparseTensor permute 4", Compare(s3, d3), true); // // d3.permute(I3(2, 1, 0)); // s3.permute(I3(2, 1, 0)); -// +// // Test("SparseTensor permute 5", Compare(s3, d3), true); // // d3.permute(I3(0, 2, 1)); // s3.permute(I3(0, 2, 1)); -// +// // Test("SparseTensor permute 6", Compare(s3, d3), true); // } // // //-------------------------------------------------------------------------------- // void SparseTensorUnitTest::unitTestResize() // { -// { +// { // D2 d2(3, 4); S2 s2(3, 4); // ITER_2(3, 4) { // d2.set(I2(i, j), Real(i*4+j)); @@ -1254,12 +1279,12 @@ // { // D2 d2(3, 4), d2r(3, 4); // S2 s2(3, 4), s2r(3, 4); -// +// // ITER_2(3, 4) { // d2.set(I2(i, j), Real(i*4+j)); // s2.set(I2(i, j), d2(i, j)); // } -// +// // d2.reshape(d2r); // s2.reshape(s2r); // Test("SparseTensor reshape 0", Compare(s2r, d2r), true); @@ -1268,12 +1293,12 @@ // { // D2 d2(3, 4), d2r(2, 6); // S2 s2(3, 4), s2r(2, 6); -// +// // ITER_2(3, 4) { // d2.set(I2(i, j), Real(i*4+j)); // s2.set(I2(i, j), d2(i, j)); // } -// +// // d2.reshape(d2r); // s2.reshape(s2r); // Test("SparseTensor reshape 1", Compare(s2r, d2r), true); @@ -1282,12 +1307,12 @@ // { // D2 d2(3, 4); S2 s2(3, 4); // D3 d3r(2, 2, 3); S3 s3r(2, 2, 3); -// +// // ITER_2(3, 4) { // d2.set(I2(i, j), Real(i*4+j)); // s2.set(I2(i, j), d2(i, j)); // } -// +// // d2.reshape(d3r); // s2.reshape(s3r); // Test("SparseTensor reshape 2", Compare(s3r, d3r), true); @@ -1296,12 +1321,12 @@ // { // D3 d3(2, 2, 3); S3 s3(2, 2, 3); // D2 d2r(3, 4); S2 s2r(3, 4); -// +// // ITER_3(2, 2, 3) { // d3.set(I3(i, j, k), Real(i*6+j*3+k)); // s3.set(I3(i, j, k), d3(i, j, k)); // } -// +// // d3.reshape(d2r); // s3.reshape(s2r); // Test("SparseTensor reshape 3", Compare(s2r, d2r), true); @@ -1314,18 +1339,18 @@ // { // All possible slicings in 3D // I3 ub(5, 4, 3); // S3 s3(ub); -// +// // ITER_3(ub[0], ub[1], ub[2]) { // I3 i3(i, j, k); // s3.set(i3, (Real)(i3.ordinal(ub))); // } -// +// // for (UInt m = 1; m <= 2; ++m) { -// +// // { // Extract a vector from the S3 // S1 s1(ub[0]/m), ref(ub[0]/m); // ITER_2(ub[1], ub[2]) { -// for (UInt n = 0; n < ub[0]/m; ++n) +// for (UInt n = 0; n < ub[0]/m; ++n) // ref.set(I1((UInt)n), Real(n*ub[2]*ub[1] + i*ub[2]+j)); // Domain d(I3(0, i, j), I3(ub[0]/m, i, j)); // s3.getSlice(d, s1); @@ -1397,15 +1422,15 @@ // s3.getSlice(d, s32); // Test("SparseTensor getSlice 7", s32, ref); // } -// } -// -// { // Make sure the slice is correctly situated +// } +// +// { // Make sure the slice is correctly situated // I2 ub2(4, 5); // S2 s2a(ub2), s2b(I2(2, 2)); GenerateOrdered(s2a); // Domain d2(I2(2, 2), I2(4, 4)); // s2a.getSlice(d2, s2b); // } -// +// // { // Make sure the slice is correctly situated // I2 ub2(4, 5); // S2 s2a(ub2), s2b(I2(2, 2)); GenerateOrdered(s2b); @@ -1421,13 +1446,12 @@ // // ITER_4(ub4[0], ub4[1], ub4[2], ub4[3]) { // I4 i4(i, j, k, l); -// Real val = i4.ordinal(ub4) % 2 == 0 ? Real(0) : Real(i4.ordinal(ub4)+1); -// d4A.set(i4, val); -// s4A.set(i4, val); +// Real val = i4.ordinal(ub4) % 2 == 0 ? 
Real(0) : +// Real(i4.ordinal(ub4)+1); d4A.set(i4, val); s4A.set(i4, val); // } // // for (UInt i = 0; i < nreps; ++i) { -// +// // I4 lb, ub; // for (UInt j = 0; j < 4; ++j) { // lb[j] = rng_->getUInt32(ub4[j]); @@ -1439,8 +1463,8 @@ // Domain d(lb, ub); // // switch (d.getNOpenDims()) { -// -// case 4: +// +// case 4: // { // S4 s4B(ub4); D4 d4B(ub4); // // put garbage in the slices @@ -1448,8 +1472,8 @@ // Test("SparseTensor getSlice 8-1", Compare(s4B, d4B), true); // break; // } -// -// case 3: +// +// case 3: // { // UInt M = ub.max(); // I3 ub3(M, M, M); @@ -1484,19 +1508,19 @@ // } // // // setSlice -// { -// I3 ub3(3, 3, 2); +// { +// I3 ub3(3, 3, 2); // S3 s3A(ub3); // // // Setting 1D slices of zeros on empty // S1 empty(ub3[2]); -// for (UInt i = 0; i < ub3[0]; ++i) +// for (UInt i = 0; i < ub3[0]; ++i) // s3A.setSlice(Domain(I3(i,i,0), I3(i,i,ub3[2])), empty); // Test("SparseTensor setSlice 1", s3A.isZero(), true); // // // Setting 1D slices of zeros on non-empty -// GenerateRandRand01(rng_, s3A); -// for (UInt i = 0; i < ub3[0]; ++i) +// GenerateRandRand01(rng_, s3A); +// for (UInt i = 0; i < ub3[0]; ++i) // for (UInt j = 0; j < ub3[1]; ++j) // s3A.setSlice(Domain(I3(i,j,0), I3(i,j,ub3[2])), empty); // Test("SparseTensor setSlice 2", s3A.isZero(), true); @@ -1504,13 +1528,13 @@ // // Setting 2D slices of zeros on empty // s3A.setAll(0); // S2 empty2(ub3[1], ub3[2]); -// for (UInt i = 0; i < ub3[0]; ++i) +// for (UInt i = 0; i < ub3[0]; ++i) // s3A.setSlice(Domain(I3(i,0,0), I3(i,ub3[1],ub3[2])), empty2); // Test("SparseTensor setSlice 3", s3A.isZero(), true); // // // Setting 2D slices of zeros on non-empty // GenerateRandRand01(rng_, s3A); -// for (UInt i = 0; i < ub3[0]; ++i) +// for (UInt i = 0; i < ub3[0]; ++i) // s3A.setSlice(Domain(I3(i,0,0), I3(i,ub3[1],ub3[2])), empty2); // Test("SparseTensor setSlice 4", s3A.isZero(), true); // @@ -1524,10 +1548,10 @@ // GenerateRandRand01(rng_, s3A); // s3A.setSlice(Domain(I3(0,0,0), ub3), empty3); // Test("SparseTensor setSlice 6", s3A.isZero(), true); -// +// // // Setting 1D slices of non-zeros // S1 s1(ub3[2]); s1.setAll(1); -// for (UInt i = 0; i < ub3[0]; ++i) +// for (UInt i = 0; i < ub3[0]; ++i) // s3A.setSlice(Domain(I3(i,i,0), I3(i,i,ub3[2])), s1); // I3 idx; // do { @@ -1536,7 +1560,7 @@ // else // Test("SparseTensor setSlice 8", s3A.isZero(idx), true); // } while (increment(ub3, idx)); -// +// // s3A.setAll(0); setToZero(idx); // // // Setting 2D slices of non-zeros @@ -1578,10 +1602,10 @@ // void SparseTensorUnitTest::unitTestElementApply() // { // I3 ub(3, 4, 2); -// +// // S3 s3A(ub), s3B(ub), s3C(ub); // D3 d3A(ub), d3B(ub), d3C(ub); -// +// // ITER_3(ub[0], ub[1], ub[2]) { // Real v = Real(i*ub[1]*ub[2]+j*ub[2]+k); // if (I3(i, j, k).ordinal(ub) % 2 == 0) { @@ -1591,7 +1615,7 @@ // s3B.set(I3(i, j, k), v+1); // d3B.set(I3(i, j, k), v+1); // } -// } +// } // // { // Test with functor that introduces new zeros // // (to exercise deletion in map/set and iterator invalidation) @@ -1614,7 +1638,7 @@ // Test("SparseTensor unary element_apply 4A", Compare(s3A, d3A), true); // // /* -// try { +// try { // s3A.element_apply_fast(Plus3()); // Test("SparseTensor unary element_apply 4B", 0, 1); // } catch (std::runtime_error& e) { @@ -1673,7 +1697,7 @@ // d3A.element_apply(d3B, d3A, nupic::Multiplies()); // s3A.element_apply(s3B, s3A, nupic::Multiplies()); // Test("SparseTensor element_apply 4", Compare(s3A, d3A), true); -// +// // s3A.element_apply(s3B, s3C, BinaryPlus3()); // d3A.element_apply(d3B, d3C, BinaryPlus3()); // 
Test("SparseTensor element_apply 5", Compare(s3A, d3A), true); @@ -1681,13 +1705,13 @@ // // { // With lots of zeros // I2 ub(3, 4); -// +// // S2 s2(ub), sc2(ub); // D2 d2(ub), dc2(ub); // // s2.set(I2(1, 1), (Real)1); // d2.set(I2(1, 1), (Real)1); -// +// // S2 s2B(ub); s2B.set(I2(2, 2), (Real)1); // D2 d2B(ub); d2B.set(I2(2, 2), (Real)1); // @@ -1696,7 +1720,7 @@ // Test("SparseTensor element_apply 5A1", Compare(sc2, dc2), true); // Test("SparseTensor element_apply 5A2", sc2.isZero(), true); // Test("SparseTensor element_apply 5A3", dc2.isZero(), true); -// +// // s2.element_apply_nz(s2B, sc2, nupic::Multiplies()); // Test("SparseTensor element_apply 6A1", Compare(sc2, dc2), true); // Test("SparseTensor element_apply 6A2", sc2.isZero(), true); @@ -1725,7 +1749,7 @@ // // S2 s2(ub), sc2(ub); // D2 d2(ub), dc2(ub); -// +// // ITER_2(ub[0], ub[1]) { // if (I2(i, j).ordinal(ub) % 2 == 0) { // d2.set(I2(i, j), Real(i*ub[1]+j)); @@ -1734,10 +1758,10 @@ // } // // for (UInt n = 0; n < 2; ++n) { -// -// S1 s1(ub[n]); -// D1 d1(ub[n]); -// +// +// S1 s1(ub[n]); +// D1 d1(ub[n]); +// // ITER_1(ub[n]) { // if (i % 2 == 0) { // s1.set(I1((UInt)i), Real(i+2)); @@ -1756,7 +1780,7 @@ // d2.factor_apply(I1(n), d1, dc2, nupic::Multiplies()); // s2.factor_apply(I1(n), s1, sc2, nupic::Multiplies()); // Test("SparseTensor factor_apply 1C", Compare(sc2, dc2), true); -// +// // d2.factor_apply(I1(n), d1, d2, nupic::Multiplies()); // s2.factor_apply(I1(n), s1, s2, nupic::Multiplies()); // Test("SparseTensor factor_apply 2A", Compare(s2, d2), true); @@ -1777,9 +1801,9 @@ // { // s2.clear(); // d2.clear(); -// s2.set(1, 1, (Real)1); +// s2.set(1, 1, (Real)1); // d2.set(1, 1, (Real)1); -// +// // S1 s1(ub[1]); s1.set(2, (Real)1); // D1 d1(ub[1]); d1.set(2, (Real)1); // @@ -1811,28 +1835,28 @@ // I3 ub(5, 4, 3); // D3 d3(ub), dc3(ub); // S3 s3(ub), sc3(ub); -// -// ITER_3(ub[0], ub[1], ub[2]) { -// I3 i3(i, j, k); +// +// ITER_3(ub[0], ub[1], ub[2]) { +// I3 i3(i, j, k); // if (i3.ordinal(ub) % 2 == 0) { -// d3.set(i3, Real(i3.ordinal(ub))); -// s3.set(i3, Real(i3.ordinal(ub))); +// d3.set(i3, Real(i3.ordinal(ub))); +// s3.set(i3, Real(i3.ordinal(ub))); // } // } // // // 3X1 // for (UInt n = 0; n < 3; ++n) { // -// D1 d1(ub[n]); +// D1 d1(ub[n]); // S1 s1(ub[n]); -// +// // ITER_1(ub[n]) { // if (n % 2 == 0) { // d1.set(I1((UInt)i), Real(i+2)); // s1.set(I1((UInt)i), Real(i+2)); // } // } -// +// // d3.factor_apply(I1((UInt)n), d1, dc3, nupic::Multiplies()); // s3.factor_apply_fast(I1((UInt)n), s1, sc3, nupic::Multiplies()); // Test("SparseTensor factor_apply 10A", Compare(sc3, dc3), true); @@ -1862,18 +1886,18 @@ // Test("SparseTensor factor_apply 13A", Compare(s3, d3), true); // } // -// { // 3X2 +// { // 3X2 // I2 ub2(ub[1], ub[2]); -// D2 d2(ub2); +// D2 d2(ub2); // S2 s2(ub2); -// +// // ITER_2(ub[1], ub[2]) { // if (I2(i, j).ordinal(ub2) % 2 == 0) { // d2.set(I2(i, j), Real(i*ub[2]+j+2)); // s2.set(I2(i, j), Real(i*ub[2]+j+2)); // } // } -// +// // d3.factor_apply(I2(1, 2), d2, dc3, nupic::Multiplies()); // s3.factor_apply_fast(I2(1, 2), s2, sc3, nupic::Multiplies()); // Test("SparseTensor factor_apply 14A", Compare(sc3, dc3), true); @@ -1881,7 +1905,7 @@ // d3.factor_apply(I2(1, 2), d2, dc3, nupic::Multiplies()); // s3.factor_apply_nz(I2(1, 2), s2, sc3, nupic::Multiplies()); // Test("SparseTensor factor_apply 14B", Compare(sc3, dc3), true); -// +// // d3.factor_apply(I2(1, 2), d2, dc3, nupic::Multiplies()); // s3.factor_apply(I2(1, 2), s2, sc3, nupic::Multiplies()); // Test("SparseTensor factor_apply 
14C", Compare(sc3, dc3), true); @@ -1903,18 +1927,18 @@ // Test("SparseTensor factor_apply 17A", Compare(s3, d3), true); // } // -// { // 3X2 +// { // 3X2 // I2 ub2(ub[0], ub[2]); -// D2 d2(ub2); +// D2 d2(ub2); // S2 s2(ub2); -// +// // ITER_2(ub[0], ub[2]) { // if (I2(i, j).ordinal(ub2) % 2 == 0) { // d2.set(I2(i, j), Real(i*ub[2]+j+2)); // s2.set(I2(i, j), Real(i*ub[2]+j+2)); // } // } -// +// // d3.factor_apply(I2(0, 2), d2, dc3, nupic::Multiplies()); // s3.factor_apply(I2(0, 2), s2, sc3, nupic::Multiplies()); // Test("SparseTensor factor_apply 18", Compare(sc3, dc3), true); @@ -1932,18 +1956,18 @@ // Test("SparseTensor factor_apply 21", Compare(s3, d3), true); // } // -// { // 3X2 +// { // 3X2 // I2 ub2(ub[0], ub[1]); -// D2 d2(ub2); +// D2 d2(ub2); // S2 s2(ub2); -// +// // ITER_2(ub[0], ub[1]) { // if (I2(i, j).ordinal(ub2) % 2 == 0) { // d2.set(I2(i, j), Real(i*ub[2]+j+2)); // s2.set(I2(i, j), Real(i*ub[2]+j+2)); // } // } -// +// // d3.factor_apply(I2(0, 1), d2, dc3, nupic::Multiplies()); // s3.factor_apply(I2(0, 1), s2, sc3, nupic::Multiplies()); // Test("SparseTensor factor_apply 22", Compare(sc3, dc3), true); @@ -1962,25 +1986,26 @@ // } // // { // 3X3 -// D3 d32(ub); +// D3 d32(ub); // S3 s32(ub); // -// ITER_3(ub[0], ub[1], ub[2]) { -// I3 i3(i, j, k); +// ITER_3(ub[0], ub[1], ub[2]) { +// I3 i3(i, j, k); // if (i3.ordinal(ub) % 2 == 0) { -// d32.set(i3, Real(i3.ordinal(ub)+1)); +// d32.set(i3, Real(i3.ordinal(ub)+1)); // s32.set(i3, Real(i3.ordinal(ub)+1)); // } // } // // d3.factor_apply(I3(0, 1, 2), d32, dc3, nupic::Multiplies()); -// s3.factor_apply_fast(I3(0, 1, 2), s32, sc3, nupic::Multiplies()); -// Test("SparseTensor factor_apply 26A", Compare(sc3, dc3), true); +// s3.factor_apply_fast(I3(0, 1, 2), s32, sc3, +// nupic::Multiplies()); Test("SparseTensor factor_apply 26A", +// Compare(sc3, dc3), true); // // d3.factor_apply(I3(0, 1, 2), d32, dc3, nupic::Multiplies()); // s3.factor_apply_nz(I3(0, 1, 2), s32, sc3, nupic::Multiplies()); // Test("SparseTensor factor_apply 26B", Compare(sc3, dc3), true); -// +// // d3.factor_apply(I3(0, 1, 2), d32, dc3, nupic::Multiplies()); // s3.factor_apply(I3(0, 1, 2), s32, sc3, nupic::Multiplies()); // Test("SparseTensor factor_apply 26C", Compare(sc3, dc3), true); @@ -2001,8 +2026,8 @@ // s3.factor_apply(I3(0, 1, 2), s32, s3, std::plus()); // Test("SparseTensor factor_apply 29A", Compare(s3, d3), true); // } -// } -// +// } +// // { // Random multiplications // for (UInt m = 0; m < 10; ++m) { // @@ -2020,7 +2045,7 @@ // } // // I3 ub3(ub4[1], ub4[2], ub4[3]); -// +// // S3 s3B(ub3); D3 d3B(ub3); // ITER_3(ub3[0], ub3[1], ub3[2]) { // I3 i3(i, j, k); @@ -2051,7 +2076,7 @@ // for (UInt m = 0; m < 10; ++m) { // // I2 ub2; -// for (UInt j = 0; j < 2; ++j) +// for (UInt j = 0; j < 2; ++j) // ub2[j] = 1 + (rng_->getUInt32(5)); // // S2 s2A(ub2); D2 d2A(ub2); @@ -2065,26 +2090,27 @@ // } // // I1 ub1(ub2[1]); -// +// // S1 s1B(ub1); D1 d1B(ub1); // ITER_1(ub1[0]) { // if (i % 2 == 1) { // d1B.set(I1((UInt)i), rng_->getReal64()); // s1B.set(I1((UInt)i), d1B.get(i)); // } -// } +// } // // I1 dim1(1); // d2A.factor_apply(dim1, d1B, nupic::Multiplies()); // s2A.factor_apply_fast(dim1, s1B, nupic::Multiplies()); -// Test("SparseTensor factor_apply (in place) 31A", Compare(s2A, d2A), true); +// Test("SparseTensor factor_apply (in place) 31A", Compare(s2A, d2A), +// true); // } // // // Dims 4 and 2 // for (UInt m = 0; m < 10; ++m) { // // I4 ub4; -// for (UInt j = 0; j < 4; ++j) +// for (UInt j = 0; j < 4; ++j) // ub4[j] = 1 + 
(rng_->getUInt32(5)); // // S4 s4A(ub4); D4 d4A(ub4); @@ -2095,10 +2121,10 @@ // d4A.set(i4, rng_->getReal64()); // s4A.set(i4, d4A.get(i4)); // } -// } +// } // // I2 ub2(ub4[1], ub4[2]); -// +// // S2 s2B(ub2); D2 d2B(ub2); // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); @@ -2112,7 +2138,8 @@ // I2 dim2(1, 2); // d4A.factor_apply(dim2, d2B, nupic::Multiplies()); // s4A.factor_apply_fast(dim2, s2B, nupic::Multiplies()); -// Test("SparseTensor factor_apply (in place) 31A", Compare(s4A, d4A), true); +// Test("SparseTensor factor_apply (in place) 31A", Compare(s4A, d4A), +// true); // } // } // @@ -2125,7 +2152,7 @@ // GenerateRandRand01(rng_, C); // noise in C // Aref = A; // B.setAll(Real(1)); -// +// // I2 dims(0, 1); // A.factor_apply_fast(dims, B, C, nupic::Multiplies()); // Test("SparseTensor factor_apply_fast K11", C == Aref, true); @@ -2141,7 +2168,7 @@ // S3 A(ub3), Aref(ub3), C(ub3); // I2 ub2(ub3[1], ub3[2]); // S2 B(ub2); -// +// // for (UInt i = 0; i < 20; ++i) { // GenerateRandRand01(rng_, A); // Aref = A; @@ -2168,7 +2195,7 @@ // S2 s2(ub); ITER_2(ub[0], ub[1]) s2.set(I2(i, j), Real(i*4+j+1)); // // for (UInt n = 0; n < 2; ++n) { -// D1 d1(ub[n]); S1 s1(ub[n]); +// D1 d1(ub[n]); S1 s1(ub[n]); // d2.accumulate(I1((UInt)(1-n)), d1, std::plus()); // s2.accumulate(I1((UInt)(1-n)), s1, std::plus()); // Test("SparseTensor accumulate 1", Compare(s1, d1), true); @@ -2178,16 +2205,16 @@ // { // I3 ub3(3, 4, 5); // -// D3 d3(ub3); -// S3 s3(ub3); -// +// D3 d3(ub3); +// S3 s3(ub3); +// // ITER_3(ub3[0], ub3[1], ub3[2]) { // if (I3(i, j, k).ordinal(ub3) % 2 == 0) { // s3.set(I3(i, j, k), Real(i*4*5+j*5+k+1)); // d3.set(I3(i, j, k), Real(i*4*5+j*5+k+1)); // } // } -// +// // { // D2 d2(4, 5); S2 s2(4, 5); // d3.accumulate(I1((UInt)0), d2, std::plus()); @@ -2208,7 +2235,7 @@ // s3.accumulate(I1((UInt)2), s2, std::plus()); // Test("SparseTensor accumulate 4", Compare(s2, d2), true); // } -// +// // { // D1 d1(3); S1 s1(3); // d3.accumulate(I2(1, 2), d1, std::plus()); @@ -2240,7 +2267,7 @@ // d2.accumulate(I1((UInt)1), d1, nupic::Max()); // s2.accumulate(I1((UInt)1), s1, nupic::Max()); // Test("SparseTensor max 1", Compare(s1, d1), true); -// } +// } // // { // D1 d1(4); S1 s1(4); @@ -2252,7 +2279,7 @@ // // { // Multiplication // I2 ub2(7, 5); -// D2 d2(ub2); S2 s2(ub2); +// D2 d2(ub2); S2 s2(ub2); // // ITER_2(ub2[0], ub2[1]) { // if (I2(i, j).ordinal(ub2) % 2 == 0) { @@ -2260,7 +2287,7 @@ // s2.set(i, j, (Real)(i*4+j+1)); // } // } -// +// // D1 d1(ub2[0]); S1 s1(ub2[0]); // // d2.accumulate_nz(I1((UInt)1), d1, nupic::Multiplies(), 1); @@ -2280,21 +2307,21 @@ // I2 ub2; // for (UInt i = 0; i < ub2.size(); ++i) // ub2[i] = 1 + (rng_->getUInt32(5)); -// +// // S2 s2A(ub2); D2 d2A(ub2); // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); // UInt o = i2.ordinal(ub2); // if (o % 2 == 0) { -// s2A.set(i2, (Real)(o+1)); //Real(rng_->get() % 32768 / 32768.0)); -// d2A.set(i2, s2A.get(i2)); +// s2A.set(i2, (Real)(o+1)); //Real(rng_->get() % 32768 / +// 32768.0)); d2A.set(i2, s2A.get(i2)); // } // } -// +// // I1 dims1((UInt)0), ub1(ub2[dims1[0]]); // I1 compDims; dims1.complement(compDims); // S1 s1C(ub1); D1 d1C(ub1); -// +// // s2A.accumulate_nz(compDims, s1C, nupic::Multiplies(), 1); // d2A.accumulate_nz(compDims, d1C, nupic::Multiplies(), 1); // Test("SparseTensor accumulate 10A", Compare(s1C, d1C), true); @@ -2316,11 +2343,11 @@ // d4A.set(i4, s4A.get(i4)); // } // } -// +// // I2 dims2(0, 1), ub2(ub4[dims2[0]], ub4[dims2[1]]); // I2 compDims; dims2.complement(compDims); // S2 s2C(ub2); D2 d2C(ub2); -// 
+// // s4A.accumulate_nz(compDims, s2C, std::plus(), 0); // d4A.accumulate(compDims, d2C, std::plus(), 0); // Test("SparseTensor accumulate 11A", Compare(s2C, d2C), true); @@ -2337,7 +2364,7 @@ // s2C.clear(); // s4A.accumulate_nz(compDims, s2C, nupic::Multiplies(), 1); // Test("SparseTensor accumulate 11D", Compare(s2C, d2C), true); -// +// // s2C.clear(); // s4A.accumulate_nz(compDims, s2C, nupic::Max()); // d4A.accumulate(compDims, d2C, nupic::Max()); @@ -2350,8 +2377,8 @@ // } // } // -// { // accumulate and boost::lambda: check that it compiles and -// // returns appropriate result for lerp +// { // accumulate and boost::lambda: check that it compiles and +// // returns appropriate result for lerp // I1 ub1(5); // S1 s1A(ub1), s1B(ub1); GenerateOrdered(s1A); GenerateOrdered(s1B); // s1A.element_apply(1 - _1); @@ -2369,21 +2396,21 @@ // S1 s1A(ub1A), s1B(ub1B); // // D2 d2(ub1A[0], ub1B[0]); S2 s2(d2.getBounds()); -// +// // ITER_1(ub1A[0]) { // if (i % 4 == 0) { // d1A.set(I1(UInt(i)), (Real)(i+1)); // s1A.set(I1(UInt(i)), d1A(i)); // } // } -// +// // ITER_1(ub1B[0]) { // if (i % 2 == 0) { // d1B.set(I1(UInt(i)), (Real)(i+1)); // s1B.set(I1(UInt(i)), d1B(i)); // } // } -// +// // //1X1 // d1A.outer_product(d1B, d2, nupic::Multiplies()); // s1A.outer_product_nz(s1B, s2, nupic::Multiplies()); @@ -2405,7 +2432,7 @@ // Test("SparseTensor outer_product 2B", Compare(s3, d3), true); // // // 2X2 -// D4 d4(ub1A[0], ub1B[0], ub1A[0], ub1B[0]); +// D4 d4(ub1A[0], ub1B[0], ub1A[0], ub1B[0]); // S4 s4(ub1A[0], ub1B[0], ub1A[0], ub1B[0]); // // d2.outer_product(d2, d4, std::plus()); @@ -2418,13 +2445,13 @@ // { // I3 ub3(4, 3, 3); // -// D3 d3(ub3); S3 s3(ub3); -// +// D3 d3(ub3); S3 s3(ub3); +// // ITER_3(ub3[0], ub3[1], ub3[2]) { // if (I3(i, j, k).ordinal(ub3) % 2 == 0) { // d3.set(I3(i, j, k), (Real)(I3(i, j, k).ordinal(ub3)+1)); // s3.set(I3(i, j, k), d3(i, j, k)); -// } +// } // } // // { @@ -2441,36 +2468,38 @@ // // { // D1 d1(ub3[0]); S1 s1(ub3[0]); -// +// // d3.contract(1, 2, d1, std::plus()); // s3.contract(1, 2, s1, std::plus()); // Test("SparseTensor contract 1", Compare(s1, d1), true); -// } +// } // } // // //-------------------------------------------------------------------------------- // void SparseTensorUnitTest::unitTestInnerProduct() // { -// D2 d2A(3, 4), d2B(4, 3), d2C(3, 3); +// D2 d2A(3, 4), d2B(4, 3), d2C(3, 3); // S2 s2A(3, 4), s2B(4, 3), s2C(3, 3); -// +// // ITER_2(3, 4) { // if (I2(i, j).ordinal(I2(3, 4)) % 2 == 0) { // d2A(i, j) = d2B(j, i) = Real(i*4+j+1); // s2A.set(I2(i, j), d2A(i, j)); -// s2B.set(I2(j, i), d2A(i, j)); -// } +// s2B.set(I2(j, i), d2A(i, j)); +// } // } -// -// d2A.inner_product(1, 0, d2B, d2C, nupic::Multiplies(), std::plus(), 0); -// s2A.inner_product_nz(1, 0, s2B, s2C, nupic::Multiplies(), std::plus(), 0); -// Test("SparseTensor inner product 1A", Compare(s2C, d2C), true); // -// d2A.inner_product(1, 0, d2B, d2C, nupic::Multiplies(), std::plus(), 0); -// s2A.inner_product(1, 0, s2B, s2C, nupic::Multiplies(), std::plus(), 0); -// Test("SparseTensor inner product 1B", Compare(s2C, d2C), true); +// d2A.inner_product(1, 0, d2B, d2C, nupic::Multiplies(), +// std::plus(), 0); s2A.inner_product_nz(1, 0, s2B, s2C, +// nupic::Multiplies(), std::plus(), 0); Test("SparseTensor inner +// product 1A", Compare(s2C, d2C), true); +// +// d2A.inner_product(1, 0, d2B, d2C, nupic::Multiplies(), +// std::plus(), 0); s2A.inner_product(1, 0, s2B, s2C, +// nupic::Multiplies(), std::plus(), 0); Test("SparseTensor inner +// product 1B", Compare(s2C, d2C), true); // 
-// S4 o(3, 4, 4, 3); +// S4 o(3, 4, 4, 3); // s2A.outer_product(s2B, o, nupic::Multiplies()); // // S2 s2D(3, 3); @@ -2478,12 +2507,13 @@ // Test("SparseTensor inner product 2", s2C, s2D); // // S3 s3A(3, 4, 5), s3B(3, 3, 5); -// ITER_3(3, 4, 5) s3A.set(I3(i, j, k), (Real)(I3(i, j, k).ordinal(I3(3, 4, 5)) + 1)); -// s2A.inner_product(1, 1, s3A, s3B, nupic::Multiplies(), std::plus()); +// ITER_3(3, 4, 5) s3A.set(I3(i, j, k), (Real)(I3(i, j, k).ordinal(I3(3, 4, +// 5)) + 1)); s2A.inner_product(1, 1, s3A, s3B, nupic::Multiplies(), +// std::plus()); // // S5 o2(3, 4, 3, 4, 5); // s2A.outer_product(s3A, o2, nupic::Multiplies()); -// +// // S3 s3D(3, 3, 5); // o2.contract(1, 3, s3D, std::plus()); // Test("SparseTensor inner product 3", s3B, s3D); @@ -2529,7 +2559,7 @@ // // Test("SparseTensor nz_intersection 3C", Compare(inter1, inter2), true); // } -// +// // { // 1 zero/1 zero // S3 s3A(ub3), s3B(ub3); // ITER_3(ub3[0], ub3[1], ub3[2]) { @@ -2554,9 +2584,9 @@ // UInt n2 = n % 2 == 0 ? 0 : n; // s3A.set(I3(i, j, k), (Real) n1); // s3B.set(I3(i, j, k), (Real) n2); -// } +// } // std::vector inter; -// s3A.nz_intersection(s3B, inter); +// s3A.nz_intersection(s3B, inter); // Test("SparseTensor nz_intersection 5", inter.empty(), true); // } // @@ -2570,18 +2600,20 @@ // } // std::vector inter1, inter2; // s3A.nz_intersection(s3B, inter1); -// Test("SparseTensor nz_intersection 6A", inter1.size(), (ub3.product()-1)/2); +// Test("SparseTensor nz_intersection 6A", inter1.size(), +// (ub3.product()-1)/2); // // s3B.nz_intersection(s3A, inter2); -// Test("SparseTensor nz_intersection 6B", inter2.size(), (ub3.product()-1)/2); +// Test("SparseTensor nz_intersection 6B", inter2.size(), +// (ub3.product()-1)/2); // // Test("SparseTensor nz_intersection 6C", Compare(inter1, inter2), true); // } // // { // 1 out of 4, matching -// S3 s3A(ub3), s3B(ub3); +// S3 s3A(ub3), s3B(ub3); // ITER_3(ub3[0], ub3[1], ub3[2]) { -// UInt n = I3(i, j, k).ordinal(ub3); +// UInt n = I3(i, j, k).ordinal(ub3); // UInt nA = n % 2 == 0 ? n : 0; // UInt nB = n % 4 == 0 ? 
n : 0; // s3A.set(I3(i, j, k), (Real) nA); @@ -2590,27 +2622,29 @@ // // std::vector inter1, inter2; // s3A.nz_intersection(s3B, inter1); -// Test("SparseTensor nz_intersection 7A", inter1.size(), (ub3.product()-1)/4); -// +// Test("SparseTensor nz_intersection 7A", inter1.size(), +// (ub3.product()-1)/4); +// // s3B.nz_intersection(s3A, inter2); -// Test("SparseTensor nz_intersection 7B", inter2.size(), (ub3.product()-1)/4); +// Test("SparseTensor nz_intersection 7B", inter2.size(), +// (ub3.product()-1)/4); // // Test("SparseTensor nz_intersection 7C", Compare(inter1, inter2), true); // } // // // projections -// I2 ub2(2, 5); +// I2 ub2(2, 5); // // { // Intersection between a non-empty S2 and a non-empty S1 // S2 s2A(ub2); S1 s1B(ub2[1]); // s1B.set(1, (Real)1); // s1B.set(3, (Real)2); -// ITER_2(ub2[0], ub2[1]) s2A.set(I2(i, j), (Real)(I2(i, j).ordinal(ub2)+1)); -// +// ITER_2(ub2[0], ub2[1]) s2A.set(I2(i, j), (Real)(I2(i, +// j).ordinal(ub2)+1)); +// // S2::NonZeros inter1, ans; -// inter1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if we clean up -// I2 i2; I1 i1; -// do { +// inter1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to +// see if we clean up I2 i2; I1 i1; do { // i2.project(I1(1), i1); // if (!nearlyZero(s2A.get(i2)) && !nearlyZero(s1B.get(i1))) // ans.push_back(S2::Elt(i2, s2A.get(i2), i1, s1B.get(i1))); @@ -2620,10 +2654,12 @@ // // Test("SparseTensor nz_intersection 8A", inter1.size(), ans.size()); // for (UInt i = 0; i < ans.size(); ++i) { -// Test("SparseTensor nz_intersection 8B", inter1[i].getIndexA(), ans[i].getIndexA()); -// Test("SparseTensor nz_intersection 8C", inter1[i].getIndexB(), ans[i].getIndexB()); -// Test("SparseTensor nz_intersection 8D", inter1[i].getValA(), ans[i].getValA()); -// Test("SparseTensor nz_intersection 8E", inter1[i].getValB(), ans[i].getValB()); +// Test("SparseTensor nz_intersection 8B", inter1[i].getIndexA(), +// ans[i].getIndexA()); Test("SparseTensor nz_intersection 8C", +// inter1[i].getIndexB(), ans[i].getIndexB()); Test("SparseTensor +// nz_intersection 8D", inter1[i].getValA(), ans[i].getValA()); +// Test("SparseTensor nz_intersection 8E", inter1[i].getValB(), +// ans[i].getValB()); // } // } // @@ -2633,7 +2669,8 @@ // s1B.set(3, (Real)2); // // S2::NonZeros inter1; -// inter1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if we clean up +// inter1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to +// see if we clean up // // s2A.nz_intersection(I1(1), s1B, inter1); // @@ -2642,10 +2679,12 @@ // // { // Intersection between an empty S1 and a non-empty S2 // S2 s2A(ub2); S1 s1B(ub2[1]); -// ITER_2(ub2[0], ub2[1]) s2A.set(I2(i, j), (Real)(I2(i, j).ordinal(ub2)+1)); +// ITER_2(ub2[0], ub2[1]) s2A.set(I2(i, j), (Real)(I2(i, +// j).ordinal(ub2)+1)); // // S2::NonZeros inter1; -// inter1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if we clean up +// inter1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to +// see if we clean up // // s2A.nz_intersection(I1(1), s1B, inter1); // @@ -2656,7 +2695,8 @@ // S2 s2A(ub2); S1 s1B(ub2[1]); // // S2::NonZeros inter1; -// inter1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if we clean up +// inter1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to +// see if we clean up // // s2A.nz_intersection(I1(1), s1B, inter1); // @@ -2666,12 +2706,12 @@ // { // Intersection between a full S2 and a full S1 // S2 s2A(ub2); S1 s1B(ub2[1]); // ITER_1(ub2[1]) s1B.set(I1(i), (Real)i+1); -// ITER_2(ub2[0], ub2[1]) s2A.set(I2(i, j), 
(Real)(I2(i, j).ordinal(ub2)+1)); -// +// ITER_2(ub2[0], ub2[1]) s2A.set(I2(i, j), (Real)(I2(i, +// j).ordinal(ub2)+1)); +// // S2::NonZeros inter1, ans; -// inter1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if we clean up -// I2 i2; I1 i1; -// do { +// inter1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to +// see if we clean up I2 i2; I1 i1; do { // i2.project(I1(1), i1); // if (!nearlyZero(s2A.get(i2)) && !nearlyZero(s1B.get(i1))) // ans.push_back(S2::Elt(i2, s2A.get(i2), i1, s1B.get(i1))); @@ -2681,10 +2721,12 @@ // // Test("SparseTensor nz_intersection 12A", inter1.size(), ans.size()); // for (UInt i = 0; i < ans.size(); ++i) { -// Test("SparseTensor nz_intersection 12B", inter1[i].getIndexA(), ans[i].getIndexA()); -// Test("SparseTensor nz_intersection 12C", inter1[i].getIndexB(), ans[i].getIndexB()); -// Test("SparseTensor nz_intersection 12D", inter1[i].getValA(), ans[i].getValA()); -// Test("SparseTensor nz_intersection 12E", inter1[i].getValB(), ans[i].getValB()); +// Test("SparseTensor nz_intersection 12B", inter1[i].getIndexA(), +// ans[i].getIndexA()); Test("SparseTensor nz_intersection 12C", +// inter1[i].getIndexB(), ans[i].getIndexB()); Test("SparseTensor +// nz_intersection 12D", inter1[i].getValA(), ans[i].getValA()); +// Test("SparseTensor nz_intersection 12E", inter1[i].getValB(), +// ans[i].getValB()); // } // } // } @@ -2705,7 +2747,7 @@ // S3 s3A(ub3), s3B(ub3); // ITER_3(ub3[0], ub3[1], ub3[2]) // s3A.set(I3(i, j, k), (Real) I3(i, j, k).ordinal(ub3)+1); -// +// // std::vector u1, u2; // s3A.nz_union(s3B, u1); // Test("SparseTensor nz_union 2A", u1.size(), ub3.product()); @@ -2722,7 +2764,7 @@ // s3A.set(I3(i, j, k), (Real) I3(i, j, k).ordinal(ub3)+1); // s3B.set(I3(i, j, k), (Real) I3(i, j, k).ordinal(ub3)+1); // } -// +// // std::vector u1, u2; // s3A.nz_union(s3B, u1); // Test("SparseTensor nz_union 3A", u1.size(), ub3.product()); @@ -2732,14 +2774,14 @@ // // Test("SparseTensor nz_union 3C", Compare(u1, u2), true); // } -// +// // { // 1 zero/1 zero // S3 s3A(ub3), s3B(ub3); // ITER_3(ub3[0], ub3[1], ub3[2]) { // s3A.set(I3(i, j, k), (Real) I3(i, j, k).ordinal(ub3)); // s3B.set(I3(i, j, k), (Real) I3(i, j, k).ordinal(ub3)); // } -// +// // std::vector u1, u2; // s3A.nz_union(s3B, u1); // Test("SparseTensor nz_union 4A", u1.size(), ub3.product()-1); @@ -2759,7 +2801,7 @@ // s3A.set(I3(i, j, k), (Real) n1); // s3B.set(I3(i, j, k), (Real) n2); // } -// +// // std::vector u1, u2; // s3A.nz_union(s3B, u1); // Test("SparseTensor nz_union 5A", u1.size(), ub3.product()-1); @@ -2802,26 +2844,26 @@ // std::vector u1, u2; // s3A.nz_union(s3B, u1); // Test("SparseTensor nz_union 7A", u1.size(), (ub3.product()-1)/2); -// +// // s3B.nz_union(s3A, u2); // Test("SparseTensor nz_union 7B", u2.size(), (ub3.product()-1)/2); // // Test("SparseTensor nz_union 7C", Compare(u1, u2), true); -// } +// } // // // projections -// I2 ub2(2, 5); +// I2 ub2(2, 5); // // { // Union between a non-empty S2 and a non-empty S1 // S2 s2A(ub2); S1 s1B(ub2[1]); // s1B.set(1, (Real)1); // s1B.set(3, (Real)2); -// ITER_2(ub2[0], ub2[1]) s2A.set(I2(i, j), (Real)(I2(i, j).ordinal(ub2)+1)); -// +// ITER_2(ub2[0], ub2[1]) s2A.set(I2(i, j), (Real)(I2(i, +// j).ordinal(ub2)+1)); +// // S2::NonZeros u1, ans; -// u1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if we clean up -// I2 i2; I1 i1; -// do { +// u1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if +// we clean up I2 i2; I1 i1; do { // i2.project(I1(1), i1); // if (!nearlyZero(s2A.get(i2)) || 
!nearlyZero(s1B.get(i1))) // ans.push_back(S2::Elt(i2, s2A.get(i2), i1, s1B.get(i1))); @@ -2831,22 +2873,22 @@ // // Test("SparseTensor nz_union 8A", u1.size(), ans.size()); // for (UInt i = 0; i < ans.size(); ++i) { -// Test("SparseTensor nz_union 8B", u1[i].getIndexA(), ans[i].getIndexA()); -// Test("SparseTensor nz_union 8C", u1[i].getIndexB(), ans[i].getIndexB()); -// Test("SparseTensor nz_union 8D", u1[i].getValA(), ans[i].getValA()); -// Test("SparseTensor nz_union 8E", u1[i].getValB(), ans[i].getValB()); +// Test("SparseTensor nz_union 8B", u1[i].getIndexA(), +// ans[i].getIndexA()); Test("SparseTensor nz_union 8C", +// u1[i].getIndexB(), ans[i].getIndexB()); Test("SparseTensor nz_union +// 8D", u1[i].getValA(), ans[i].getValA()); Test("SparseTensor nz_union +// 8E", u1[i].getValB(), ans[i].getValB()); // } -// } +// } // // { // Union between an empty S2 and a non-empty S1 // S2 s2A(ub2); S1 s1B(ub2[1]); -// s1B.set(1, (Real)1); +// s1B.set(1, (Real)1); // s1B.set(3, (Real)2); // // S2::NonZeros u1, ans; -// u1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if we clean up -// I2 i2; I1 i1; -// do { +// u1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if +// we clean up I2 i2; I1 i1; do { // i2.project(I1(1), i1); // if (!nearlyZero(s2A.get(i2)) || !nearlyZero(s1B.get(i1))) // ans.push_back(S2::Elt(i2, s2A.get(i2), i1, s1B.get(i1))); @@ -2854,23 +2896,24 @@ // // s2A.nz_union(I1(1), s1B, u1); // -// Test("SparseTensor nz_union 9A", u1.size(), ub2[0] * s1B.getNNonZeros()); -// for (UInt i = 0; i < ans.size(); ++i) { -// Test("SparseTensor nz_union 9B", u1[i].getIndexA(), ans[i].getIndexA()); -// Test("SparseTensor nz_union 9C", u1[i].getIndexB(), ans[i].getIndexB()); -// Test("SparseTensor nz_union 9D", u1[i].getValA(), ans[i].getValA()); -// Test("SparseTensor nz_union 9E", u1[i].getValB(), ans[i].getValB()); +// Test("SparseTensor nz_union 9A", u1.size(), ub2[0] * +// s1B.getNNonZeros()); for (UInt i = 0; i < ans.size(); ++i) { +// Test("SparseTensor nz_union 9B", u1[i].getIndexA(), +// ans[i].getIndexA()); Test("SparseTensor nz_union 9C", +// u1[i].getIndexB(), ans[i].getIndexB()); Test("SparseTensor nz_union +// 9D", u1[i].getValA(), ans[i].getValA()); Test("SparseTensor nz_union +// 9E", u1[i].getValB(), ans[i].getValB()); // } // } // // { // Union between an empty S1 and a non-empty S2 // S2 s2A(ub2); S1 s1B(ub2[1]); -// ITER_2(ub2[0], ub2[1]) s2A.set(I2(i, j), (Real)(I2(i, j).ordinal(ub2)+1)); +// ITER_2(ub2[0], ub2[1]) s2A.set(I2(i, j), (Real)(I2(i, +// j).ordinal(ub2)+1)); // // S2::NonZeros u1, ans; -// u1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if we clean up -// I2 i2; I1 i1; -// do { +// u1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if +// we clean up I2 i2; I1 i1; do { // i2.project(I1(1), i1); // if (!nearlyZero(s2A.get(i2)) || !nearlyZero(s1B.get(i1))) // ans.push_back(S2::Elt(i2, s2A.get(i2), i1, s1B.get(i1))); @@ -2880,10 +2923,11 @@ // // Test("SparseTensor nz_union 10A", u1.size(), s2A.getNNonZeros()); // for (UInt i = 0; i < ans.size(); ++i) { -// Test("SparseTensor nz_union 10B", u1[i].getIndexA(), ans[i].getIndexA()); -// Test("SparseTensor nz_union 10C", u1[i].getIndexB(), ans[i].getIndexB()); -// Test("SparseTensor nz_union 10D", u1[i].getValA(), ans[i].getValA()); -// Test("SparseTensor nz_union 10E", u1[i].getValB(), ans[i].getValB()); +// Test("SparseTensor nz_union 10B", u1[i].getIndexA(), +// ans[i].getIndexA()); Test("SparseTensor nz_union 10C", +// u1[i].getIndexB(), ans[i].getIndexB()); 
Test("SparseTensor nz_union +// 10D", u1[i].getValA(), ans[i].getValA()); Test("SparseTensor nz_union +// 10E", u1[i].getValB(), ans[i].getValB()); // } // } // @@ -2891,7 +2935,8 @@ // S2 s2A(ub2); S1 s1B(ub2[1]); // // S2::NonZeros u1; -// u1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if we clean up +// u1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if +// we clean up // // s2A.nz_union(I1(1), s1B, u1); // @@ -2901,27 +2946,28 @@ // { // Union between a full S2 and a full S1 // S2 s2A(ub2); S1 s1B(ub2[1]); // ITER_1(ub2[1]) s1B.set(I1(i), (Real)i+1); -// ITER_2(ub2[0], ub2[1]) s2A.set(I2(i, j), (Real)(I2(i, j).ordinal(ub2)+1)); -// +// ITER_2(ub2[0], ub2[1]) s2A.set(I2(i, j), (Real)(I2(i, +// j).ordinal(ub2)+1)); +// // S2::NonZeros u1, ans; -// u1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if we clean up -// I2 i2; I1 i1; -// do { +// u1.push_back(S2::Elt(I2(0, 0), 1, I1(1), 2)); // fake, to see if +// we clean up I2 i2; I1 i1; do { // i2.project(I1(1), i1); // if (!nearlyZero(s2A.get(i2)) || !nearlyZero(s1B.get(i1))) // ans.push_back(S2::Elt(i2, s2A.get(i2), i1, s1B.get(i1))); // } while (i2.increment(ub2)); // -// s2A.nz_union(I1(1), s1B, u1); +// s2A.nz_union(I1(1), s1B, u1); // // Test("SparseTensor nz_union 12A", u1.size(), ans.size()); // for (UInt i = 0; i < ans.size(); ++i) { -// Test("SparseTensor nz_union 12B", u1[i].getIndexA(), ans[i].getIndexA()); -// Test("SparseTensor nz_union 12C", u1[i].getIndexB(), ans[i].getIndexB()); -// Test("SparseTensor nz_union 12D", u1[i].getValA(), ans[i].getValA()); -// Test("SparseTensor nz_union 12E", u1[i].getValB(), ans[i].getValB()); +// Test("SparseTensor nz_union 12B", u1[i].getIndexA(), +// ans[i].getIndexA()); Test("SparseTensor nz_union 12C", +// u1[i].getIndexB(), ans[i].getIndexB()); Test("SparseTensor nz_union +// 12D", u1[i].getValA(), ans[i].getValA()); Test("SparseTensor nz_union +// 12E", u1[i].getValB(), ans[i].getValB()); // } -// } +// } // } // // //-------------------------------------------------------------------------------- @@ -2938,7 +2984,8 @@ // Test("SparseTensor dynamic getNNonZeros 2", st.getNNonZeros(), (UInt)0); // Test("SparseTensor dynamic isZero 3", st.isZero(), true); // Test("SparseTensor dynamic isDense 4", st.isDense(), false); -// Test("SparseTensor dynamic getBounds 5", Compare(st.getBounds(), i3), true); +// Test("SparseTensor dynamic getBounds 5", Compare(st.getBounds(), i3), +// true); // // DST st2(st); // Test("SparseTensor dynamic 6", st == st2, true); @@ -2949,7 +2996,7 @@ // st3.clear(); // Test("SparseTensor dynamic 8", st3 == st, true); // -// std::vector perm(3); +// std::vector perm(3); // perm[0] = 1; perm[1] = 2; perm[2] = 0; // Test("SparseTensor dynamic 9", st.isSymmetric(perm), false); // Test("SparseTensor dynamic 10", st.isAntiSymmetric(perm), false); @@ -2969,26 +3016,27 @@ // Test("SparseTensor dynamic 11E", st.get(idx), o); // Test("SparseTensor dynamic 11F", st.isZero(idx), false); // } -// Test("SparseTensor dynamic 1G", +// Test("SparseTensor dynamic 1G", // st.update(idx, (double)-o, std::plus()), (double)0); // Test("SparseTensor dynamic 11H", st.get(idx), (double)0); // Test("SparseTensor dynamic 11I", st.isZero(idx), true); -// Test("SparseTensor dynamic 11J", +// Test("SparseTensor dynamic 11J", // st.update(idx, (double)o, std::plus()), (double)o); // Test("SparseTensor dynamic 11K", st.get(idx), (double)o); // Test("SparseTensor dynamic 11L", st.isZero(idx), o == 0); // } // } while (increment(st.getBounds(), idx)); -// -// 
Test("SparseTensor dynamic 12", st.getNNonZeros(), product(st.getBounds())/2-1); -// Test("SparseTensor dynamic 13", st.isZero(), false); -// Test("SparseTensor dynamic 14", st.isDense(), false); -// +// +// Test("SparseTensor dynamic 12", st.getNNonZeros(), +// product(st.getBounds())/2-1); Test("SparseTensor dynamic 13", +// st.isZero(), false); Test("SparseTensor dynamic 14", st.isDense(), +// false); +// // st.setAll(1); -// Test("SparseTensor dynamic 15", st.getNNonZeros(), product(st.getBounds())); -// Test("SparseTensor dynamic 16", st.isZero(), false); -// Test("SparseTensor dynamic 17", st.isDense(), true); -// +// Test("SparseTensor dynamic 15", st.getNNonZeros(), +// product(st.getBounds())); Test("SparseTensor dynamic 16", st.isZero(), +// false); Test("SparseTensor dynamic 17", st.isDense(), true); +// // st.clear(); // Test("SparseTensor dynamic 18", st.getNNonZeros(), (UInt)0); // Test("SparseTensor dynamic 19", st.isZero(), true); @@ -3024,7 +3072,7 @@ // { // I2 ub2(4, 4); // S2 s2(ub2), ref(ub2); -// +// // ITER_2(ub2[0], ub2[1]) { // I2 i2(i, j); // UInt o = i2.ordinal(ub2); @@ -3044,17 +3092,17 @@ // // //-------------------------------------------------------------------------------- // void SparseTensorUnitTest::unitTestNormalize() -// { +// { // I2 ub2(4, 3); // S2 s2(ub2); -// +// // { // Matrix of zeros should "normalize" to zeros // s2.normalize(I1(UInt(0))); // Test("SparseTensor normalize zero 1", s2.isZero(), true); -// +// // s2.normalize(I1(UInt(0))); // Test("SparseTensor normalize zero 2", s2.isZero(), true); -// +// // s2.normalize(); // Test("SparseTensor normalize zero 3", s2.isZero(), true); // } @@ -3066,7 +3114,8 @@ // Test("SparseTensor normalize uniform 1A", vals.size(), 1u); // bool t = nearlyEqual(vals.begin()->first, Real(1./ub2[0])); // Test("SparseTensor normalize uniform 1B", t, true); -// Test("SparseTensor normalize uniform 1C", vals.begin()->second, product(ub2)); +// Test("SparseTensor normalize uniform 1C", vals.begin()->second, +// product(ub2)); // // s2.setAll(1.0); // s2.normalize(I1(UInt(1))); // sum along rows @@ -3074,7 +3123,8 @@ // Test("SparseTensor normalize uniform 2A", vals.size(), 1u); // t = nearlyEqual(vals.begin()->first, Real(1./ub2[1])); // Test("SparseTensor normalize uniform 2B", t, true); -// Test("SparseTensor normalize uniform 2C", vals.begin()->second, product(ub2)); +// Test("SparseTensor normalize uniform 2C", vals.begin()->second, +// product(ub2)); // // s2.setAll(1.0); // s2.normalize(); @@ -3082,14 +3132,15 @@ // Test("SparseTensor normalize uniform 3A", vals.size(), 1u); // t = nearlyEqual(vals.begin()->first, Real(1./(ub2[0]*ub2[1]))); // Test("SparseTensor normalize uniform 3B", t, true); -// Test("SparseTensor normalize uniform 3C", vals.begin()->second, product(ub2)); +// Test("SparseTensor normalize uniform 3C", vals.begin()->second, +// product(ub2)); // } -// +// // { // Matrix with empty rows should not crash // s2.setAll(1.0); // for (UInt i = 0; i < ub2[1]; ++i) // s2.set(I2(0, i), 0); -// +// // s2.normalize(I1(UInt(1))); // } // @@ -3097,7 +3148,7 @@ // s2.setAll(1.0); // for (UInt i = 0; i < ub2[0]; ++i) // s2.set(I2(i, 0), 0); -// +// // s2.normalize(I1(UInt(0))); // } // @@ -3137,13 +3188,13 @@ // s1_max_col.update(I1((UInt)j), Real(o), nupic::Max()); // if (o > M) { M = Real(o); idxmax = i2; } // S += o; -// } -// } +// } +// } // -// S1 s10(ub2[1]); +// S1 s10(ub2[1]); // s2.max(I1((UInt)0), s10); // max of each column, in a vector // Test("SparseTensor max 1", s10 == s1_max_col, true); -// 
+// // S1 s11(ub2[0]); // s2.max(I1((UInt)1), s11); // max of each row, in a vector // Test("SparseTensor max 2", s11 == s1_max_row, true); @@ -3156,13 +3207,13 @@ // Test("SparseTensor sum 1", S, SS); // } // -// { // specific tensor +// { // specific tensor // S2 s2(I2(2, 1)); // s2.set(I2(0,0), 70); // s2.set(I2(1,0), 10); // Test("SparseTensor max 3", s2.max().first, I2(0,0)); // Test("SparseTensor max 4", s2.max().second, Real(70)); -// } +// } // // { // empty tensor // S2 s2(I2(2, 1)); @@ -3174,12 +3225,12 @@ // S2 s2(I2(2, 1)); // Test("SparseTensor sum 2", s2.sum(), 0); // } -// } +// } // // //-------------------------------------------------------------------------------- // void SparseTensorUnitTest::unitTestAxby() // { -// +// // } // // //-------------------------------------------------------------------------------- @@ -3195,7 +3246,7 @@ // // from the map, which invalidates iterator! // for (Real k = -2; k < 2.25; k += .25) { // { -// GenerateRandRand01(rng_, A); +// GenerateRandRand01(rng_, A); // Cref.clear(); // setToZero(idx); // do { @@ -3204,9 +3255,9 @@ // A.multiply(k); // Test("SparseTensor in place * k", A, Cref); // } -// +// // { -// GenerateRandRand01(rng_, A); +// GenerateRandRand01(rng_, A); // GenerateRandRand01(rng_, C); // noise in C // Cref.clear(); // setToZero(idx); @@ -3225,7 +3276,7 @@ // s1.multiply(Real(1e-12)); // Test("SparseTensor multiply micromegas", s1.sum(), Real(1.0)); // } -// } +// } // // //-------------------------------------------------------------------------------- // void SparseTensorUnitTest::unitTestPerformance() @@ -3233,17 +3284,17 @@ // /* // typedef Index IL4; // typedef SparseTensor SL4; -// +// // IL4 ub(100000, 100000, 100000, 100000); -// SL4 a(ub), b(ub), c(ub); -// UInt nnz = 500000; +// SL4 a(ub), b(ub), c(ub); +// UInt nnz = 500000; // GenerateRand01(rng_, nnz, a); // GenerateRand01(rng_, nnz, b); // -// timer t; -// a.add(b, c); +// timer t; +// a.add(b, c); // */ -// } +// } // // //-------------------------------------------------------------------------------- // void SparseTensorUnitTest::unitTestNumericalStability() @@ -3262,41 +3313,39 @@ // } // } // - //-------------------------------------------------------------------------------- - // void SparseTensorUnitTest::RunTests() - // { - // - //unitTestConstruction(); - //unitTestGetSet(); - //unitTestExtract(); - //unitTestReduce(); - //unitTestNonZeros(); - //unitTestIsSymmetric(); - //unitTestToFromDense(); - //unitTestPermute(); - //unitTestResize(); - //unitTestReshape(); - //unitTestSlice(); - //unitTestElementApply(); - //unitTestFactorApply(); - //unitTestAccumulate(); - //unitTestOuterProduct(); - //unitTestContract(); - //unitTestInnerProduct(); - //unitTestIntersection(); - //unitTestUnion(); - //unitTestDynamicIndex(); - //unitTestToFromStream(); - //unitTestNormalize(); - //unitTestMaxSum(); - //unitTestAxby(); - //unitTestMultiply(); - ////unitTestNumericalStability(); - ////unitTestPerformance(); - // } - - //-------------------------------------------------------------------------------- - -// } // namespace nupic +//-------------------------------------------------------------------------------- +// void SparseTensorUnitTest::RunTests() +// { +// +// unitTestConstruction(); +// unitTestGetSet(); +// unitTestExtract(); +// unitTestReduce(); +// unitTestNonZeros(); +// unitTestIsSymmetric(); +// unitTestToFromDense(); +// unitTestPermute(); +// unitTestResize(); +// unitTestReshape(); +// unitTestSlice(); +// unitTestElementApply(); +// 
unitTestFactorApply(); +// unitTestAccumulate(); +// unitTestOuterProduct(); +// unitTestContract(); +// unitTestInnerProduct(); +// unitTestIntersection(); +// unitTestUnion(); +// unitTestDynamicIndex(); +// unitTestToFromStream(); +// unitTestNormalize(); +// unitTestMaxSum(); +// unitTestAxby(); +// unitTestMultiply(); +////unitTestNumericalStability(); +////unitTestPerformance(); +// } +//-------------------------------------------------------------------------------- +// } // namespace nupic diff --git a/src/test/unit/math/SparseTensorUnitTest.hpp b/src/test/unit/math/SparseTensorUnitTest.hpp index dc45d7339e..0a47a5363b 100644 --- a/src/test/unit/math/SparseTensorUnitTest.hpp +++ b/src/test/unit/math/SparseTensorUnitTest.hpp @@ -38,7 +38,8 @@ // //-------------------------------------------------------------------------------- // /** // * @b Responsibility -// * A dense multi-dimensional array. It stores all its values, as opposed to +// * A dense multi-dimensional array. It stores all its values, as opposed +// to // * a SparseTensor that stores only the non-zero values. // * // * @b Rationale @@ -130,10 +131,11 @@ // // inline UInt getRank() const { return bounds_.size(); } // inline bool isZero() const { return getNNonZeros() == 0; } -// inline bool isDense() const { return getNNonZeros() == product(bounds_); } -// inline bool isSparse() const { return getNNonZeros() != product(bounds_); } -// inline Index getBounds() const { return bounds_; } -// inline void clear() { memset(vals_, 0, product(bounds_) * sizeof(Float)); } +// inline bool isDense() const { return getNNonZeros() == product(bounds_); +// } inline bool isSparse() const { return getNNonZeros() != +// product(bounds_); } inline Index getBounds() const { return bounds_; } +// inline void clear() { memset(vals_, 0, product(bounds_) * sizeof(Float)); +// } // // inline Index getNewIndex() const // { @@ -298,8 +300,8 @@ // } // // auto buf = new Float[product(bounds_)]; -// Index idx = getNewZeroIndex(), perm = getNewIndex(), newBounds = getNewIndex(); -// nupic::permute(ind, bounds_, newBounds); +// Index idx = getNewZeroIndex(), perm = getNewIndex(), newBounds = +// getNewIndex(); nupic::permute(ind, bounds_, newBounds); // // do { // nupic::permute(ind, idx, perm); @@ -401,7 +403,8 @@ // } // // template -// inline void element_apply(const DenseTensor& B, DenseTensor& C, binary_functor f) +// inline void element_apply(const DenseTensor& B, DenseTensor& C, +// binary_functor f) // { // { // NTA_ASSERT(getBounds() == B.getBounds()) @@ -620,7 +623,8 @@ // << " and dim: " << dim2 // << " but they have different size: " << bounds_[dim1] // << " and " << bounds_[dim2] -// << " - Can take inner product only along dimensions that have the same size"; +// << " - Can take inner product only along dimensions that have the +// same size"; // } // // Index idx1 = getNewZeroIndex(); @@ -651,10 +655,12 @@ // // // for debugging // template -// NTA_HIDDEN friend std::ostream& operator<<(std::ostream&, const DenseTensor&); +// NTA_HIDDEN friend std::ostream& operator<<(std::ostream&, const +// DenseTensor&); // // template -// NTA_HIDDEN friend bool operator==(const DenseTensor&, const DenseTensor&); +// NTA_HIDDEN friend bool operator==(const DenseTensor&, const +// DenseTensor&); // // private: // Index bounds_; @@ -665,7 +671,8 @@ // }; // // template -// inline std::ostream& operator<<(std::ostream& outStream, const DenseTensor& dense) +// inline std::ostream& operator<<(std::ostream& outStream, const +// DenseTensor& dense) // { // 
typedef typename Idx::value_type UI; // @@ -697,7 +704,8 @@ // } // // template -// inline bool operator==(const DenseTensor& A, const DenseTensor& B) +// inline bool operator==(const DenseTensor& A, const DenseTensor& B) // { // typedef typename Idx::value_type UI; // @@ -714,7 +722,8 @@ // } // // template -// inline bool operator!=(const DenseTensor& A, const DenseTensor& B) +// inline bool operator!=(const DenseTensor& A, const DenseTensor& B) // { // return ! (A == B); // } diff --git a/src/test/unit/math/TopologyTest.cpp b/src/test/unit/math/TopologyTest.cpp index 01db734b00..23aeac5d9e 100644 --- a/src/test/unit/math/TopologyTest.cpp +++ b/src/test/unit/math/TopologyTest.cpp @@ -24,357 +24,355 @@ * Unit tests for Topology.hpp */ -#include #include "gtest/gtest.h" +#include using std::vector; using namespace nupic; using namespace nupic::math::topology; namespace { - TEST(TopologyTest, IndexFromCoordinates) - { - EXPECT_EQ(0, indexFromCoordinates({0}, {100})); - EXPECT_EQ(50, indexFromCoordinates({50}, {100})); - EXPECT_EQ(99, indexFromCoordinates({99}, {100})); - - EXPECT_EQ(0, indexFromCoordinates({0, 0}, {100, 80})); - EXPECT_EQ(10, indexFromCoordinates({0, 10}, {100, 80})); - EXPECT_EQ(80, indexFromCoordinates({1, 0}, {100, 80})); - EXPECT_EQ(90, indexFromCoordinates({1, 10}, {100, 80})); - - EXPECT_EQ(0, indexFromCoordinates({0, 0, 0}, {100, 10, 8})); - EXPECT_EQ(7, indexFromCoordinates({0, 0, 7}, {100, 10, 8})); - EXPECT_EQ(8, indexFromCoordinates({0, 1, 0}, {100, 10, 8})); - EXPECT_EQ(80, indexFromCoordinates({1, 0, 0}, {100, 10, 8})); - EXPECT_EQ(88, indexFromCoordinates({1, 1, 0}, {100, 10, 8})); - EXPECT_EQ(89, indexFromCoordinates({1, 1, 1}, {100, 10, 8})); - } +TEST(TopologyTest, IndexFromCoordinates) { + EXPECT_EQ(0, indexFromCoordinates({0}, {100})); + EXPECT_EQ(50, indexFromCoordinates({50}, {100})); + EXPECT_EQ(99, indexFromCoordinates({99}, {100})); + + EXPECT_EQ(0, indexFromCoordinates({0, 0}, {100, 80})); + EXPECT_EQ(10, indexFromCoordinates({0, 10}, {100, 80})); + EXPECT_EQ(80, indexFromCoordinates({1, 0}, {100, 80})); + EXPECT_EQ(90, indexFromCoordinates({1, 10}, {100, 80})); + + EXPECT_EQ(0, indexFromCoordinates({0, 0, 0}, {100, 10, 8})); + EXPECT_EQ(7, indexFromCoordinates({0, 0, 7}, {100, 10, 8})); + EXPECT_EQ(8, indexFromCoordinates({0, 1, 0}, {100, 10, 8})); + EXPECT_EQ(80, indexFromCoordinates({1, 0, 0}, {100, 10, 8})); + EXPECT_EQ(88, indexFromCoordinates({1, 1, 0}, {100, 10, 8})); + EXPECT_EQ(89, indexFromCoordinates({1, 1, 1}, {100, 10, 8})); +} - TEST(TopologyTest, CoordinatesFromIndex) - { - EXPECT_EQ(vector({0}), coordinatesFromIndex(0, {100})); - EXPECT_EQ(vector({50}), coordinatesFromIndex(50, {100})); - EXPECT_EQ(vector({99}), coordinatesFromIndex(99, {100})); - - EXPECT_EQ(vector({0, 0}), coordinatesFromIndex(0, {100, 80})); - EXPECT_EQ(vector({0, 10}), coordinatesFromIndex(10, {100, 80})); - EXPECT_EQ(vector({1, 0}), coordinatesFromIndex(80, {100, 80})); - EXPECT_EQ(vector({1, 10}), coordinatesFromIndex(90, {100, 80})); - - EXPECT_EQ(vector({0, 0, 0}), coordinatesFromIndex(0, {100, 10, 8})); - EXPECT_EQ(vector({0, 0, 7}), coordinatesFromIndex(7, {100, 10, 8})); - EXPECT_EQ(vector({0, 1, 0}), coordinatesFromIndex(8, {100, 10, 8})); - EXPECT_EQ(vector({1, 0, 0}), coordinatesFromIndex(80, {100, 10, 8})); - EXPECT_EQ(vector({1, 1, 0}), coordinatesFromIndex(88, {100, 10, 8})); - EXPECT_EQ(vector({1, 1, 1}), coordinatesFromIndex(89, {100, 10, 8})); - } +TEST(TopologyTest, CoordinatesFromIndex) { + EXPECT_EQ(vector({0}), coordinatesFromIndex(0, {100})); 
+ EXPECT_EQ(vector({50}), coordinatesFromIndex(50, {100})); + EXPECT_EQ(vector({99}), coordinatesFromIndex(99, {100})); + + EXPECT_EQ(vector({0, 0}), coordinatesFromIndex(0, {100, 80})); + EXPECT_EQ(vector({0, 10}), coordinatesFromIndex(10, {100, 80})); + EXPECT_EQ(vector({1, 0}), coordinatesFromIndex(80, {100, 80})); + EXPECT_EQ(vector({1, 10}), coordinatesFromIndex(90, {100, 80})); + + EXPECT_EQ(vector({0, 0, 0}), coordinatesFromIndex(0, {100, 10, 8})); + EXPECT_EQ(vector({0, 0, 7}), coordinatesFromIndex(7, {100, 10, 8})); + EXPECT_EQ(vector({0, 1, 0}), coordinatesFromIndex(8, {100, 10, 8})); + EXPECT_EQ(vector({1, 0, 0}), coordinatesFromIndex(80, {100, 10, 8})); + EXPECT_EQ(vector({1, 1, 0}), coordinatesFromIndex(88, {100, 10, 8})); + EXPECT_EQ(vector({1, 1, 1}), coordinatesFromIndex(89, {100, 10, 8})); +} + +// ========================================================================== +// NEIGHBORHOOD +// ========================================================================== - // ========================================================================== - // NEIGHBORHOOD - // ========================================================================== - - void expectNeighborhoodIndices( - const vector& centerCoords, - const vector& dimensions, - UInt radius, - const vector& expected) - { - const UInt centerIndex = indexFromCoordinates(centerCoords, dimensions); - - int i = 0; - for (UInt index : Neighborhood(centerIndex, radius, dimensions)) - { - EXPECT_EQ(expected[i], index); - i++; - } - - EXPECT_EQ(expected.size(), i); +void expectNeighborhoodIndices(const vector ¢erCoords, + const vector &dimensions, UInt radius, + const vector &expected) { + const UInt centerIndex = indexFromCoordinates(centerCoords, dimensions); + + int i = 0; + for (UInt index : Neighborhood(centerIndex, radius, dimensions)) { + EXPECT_EQ(expected[i], index); + i++; } - void expectNeighborhoodCoords( - const vector& centerCoords, - const vector& dimensions, - UInt radius, - const vector >& expected) - { - const UInt centerIndex = indexFromCoordinates(centerCoords, dimensions); - - int i = 0; - for (UInt index : Neighborhood(centerIndex, radius, dimensions)) - { - EXPECT_EQ(indexFromCoordinates(expected[i], dimensions), index); - i++; - } - - EXPECT_EQ(expected.size(), i); + EXPECT_EQ(expected.size(), i); +} + +void expectNeighborhoodCoords(const vector ¢erCoords, + const vector &dimensions, UInt radius, + const vector> &expected) { + const UInt centerIndex = indexFromCoordinates(centerCoords, dimensions); + + int i = 0; + for (UInt index : Neighborhood(centerIndex, radius, dimensions)) { + EXPECT_EQ(indexFromCoordinates(expected[i], dimensions), index); + i++; } - TEST(TopologyTest, NeighborhoodOfOrigin1D) - { - expectNeighborhoodIndices( + EXPECT_EQ(expected.size(), i); +} + +TEST(TopologyTest, NeighborhoodOfOrigin1D) { + expectNeighborhoodIndices( /*centerCoords*/ {0}, /*dimensions*/ {100}, /*radius*/ 2, /*expected*/ {0, 1, 2}); - } +} - TEST(TopologyTest, NeighborhoodOfOrigin2D) - { - expectNeighborhoodCoords( +TEST(TopologyTest, NeighborhoodOfOrigin2D) { + expectNeighborhoodCoords( /*centerCoords*/ {0, 0}, /*dimensions*/ {100, 80}, /*radius*/ 2, - /*expected*/ {{0, 0}, {0, 1}, {0, 2}, - {1, 0}, {1, 1}, {1, 2}, - {2, 0}, {2, 1}, {2, 2}}); - } + /*expected*/ + {{0, 0}, {0, 1}, {0, 2}, {1, 0}, {1, 1}, {1, 2}, {2, 0}, {2, 1}, {2, 2}}); +} - TEST(TopologyTest, NeighborhoodOfOrigin3D) - { - expectNeighborhoodCoords( +TEST(TopologyTest, NeighborhoodOfOrigin3D) { + expectNeighborhoodCoords( /*centerCoords*/ {0, 
0, 0}, /*dimensions*/ {100, 80, 60}, /*radius*/ 1, - /*expected*/ {{0, 0, 0}, {0, 0, 1}, - {0, 1, 0}, {0, 1, 1}, - {1, 0, 0}, {1, 0, 1}, - {1, 1, 0}, {1, 1, 1}}); - } + /*expected*/ + {{0, 0, 0}, + {0, 0, 1}, + {0, 1, 0}, + {0, 1, 1}, + {1, 0, 0}, + {1, 0, 1}, + {1, 1, 0}, + {1, 1, 1}}); +} - TEST(TopologyTest, NeighborhoodOfMiddle1D) - { - expectNeighborhoodIndices( +TEST(TopologyTest, NeighborhoodOfMiddle1D) { + expectNeighborhoodIndices( /*centerCoords*/ {50}, /*dimensions*/ {100}, /*radius*/ 1, /*expected*/ {49, 50, 51}); - } +} - TEST(TopologyTest, NeighborhoodOfMiddle2D) - { - expectNeighborhoodCoords( +TEST(TopologyTest, NeighborhoodOfMiddle2D) { + expectNeighborhoodCoords( /*centerCoords*/ {50, 50}, /*dimensions*/ {100, 80}, /*radius*/ 1, - /*expected*/ {{49, 49}, {49, 50}, {49, 51}, - {50, 49}, {50, 50}, {50, 51}, - {51, 49}, {51, 50}, {51, 51}}); - } + /*expected*/ + {{49, 49}, + {49, 50}, + {49, 51}, + {50, 49}, + {50, 50}, + {50, 51}, + {51, 49}, + {51, 50}, + {51, 51}}); +} - TEST(TopologyTest, NeighborhoodOfEnd2D) - { - expectNeighborhoodCoords( +TEST(TopologyTest, NeighborhoodOfEnd2D) { + expectNeighborhoodCoords( /*centerCoords*/ {99, 79}, /*dimensions*/ {100, 80}, /*radius*/ 2, - /*expected*/ {{97, 77}, {97, 78}, {97, 79}, - {98, 77}, {98, 78}, {98, 79}, - {99, 77}, {99, 78}, {99, 79}}); - } + /*expected*/ + {{97, 77}, + {97, 78}, + {97, 79}, + {98, 77}, + {98, 78}, + {98, 79}, + {99, 77}, + {99, 78}, + {99, 79}}); +} - TEST(TopologyTest, NeighborhoodWiderThanWorld) - { - expectNeighborhoodCoords( +TEST(TopologyTest, NeighborhoodWiderThanWorld) { + expectNeighborhoodCoords( /*centerCoords*/ {0, 0}, /*dimensions*/ {3, 2}, /*radius*/ 3, - /*expected*/ {{0, 0}, {0, 1}, - {1, 0}, {1, 1}, - {2, 0}, {2, 1}}); - } + /*expected*/ {{0, 0}, {0, 1}, {1, 0}, {1, 1}, {2, 0}, {2, 1}}); +} - TEST(TopologyTest, NeighborhoodRadiusZero) - { - expectNeighborhoodIndices( +TEST(TopologyTest, NeighborhoodRadiusZero) { + expectNeighborhoodIndices( /*centerCoords*/ {0}, /*dimensions*/ {100}, /*radius*/ 0, /*expected*/ {0}); - expectNeighborhoodCoords( + expectNeighborhoodCoords( /*centerCoords*/ {0, 0}, /*dimensions*/ {100, 80}, /*radius*/ 0, /*expected*/ {{0, 0}}); - expectNeighborhoodCoords( + expectNeighborhoodCoords( /*centerCoords*/ {0, 0, 0}, /*dimensions*/ {100, 80, 60}, /*radius*/ 0, /*expected*/ {{0, 0, 0}}); - } +} - TEST(TopologyTest, NeighborhoodDimensionOne) - { - expectNeighborhoodCoords( +TEST(TopologyTest, NeighborhoodDimensionOne) { + expectNeighborhoodCoords( /*centerCoords*/ {5, 0}, /*dimensions*/ {10, 1}, /*radius*/ 1, /*expected*/ {{4, 0}, {5, 0}, {6, 0}}); - expectNeighborhoodCoords( + expectNeighborhoodCoords( /*centerCoords*/ {5, 0, 0}, /*dimensions*/ {10, 1, 1}, /*radius*/ 1, /*expected*/ {{4, 0, 0}, {5, 0, 0}, {6, 0, 0}}); - } +} +// ========================================================================== +// WRAPPING NEIGHBORHOOD +// ========================================================================== - // ========================================================================== - // WRAPPING NEIGHBORHOOD - // ========================================================================== +void expectWrappingNeighborhoodIndices(const vector ¢erCoords, + const vector &dimensions, + UInt radius, + const vector &expected) { + const UInt centerIndex = indexFromCoordinates(centerCoords, dimensions); - void expectWrappingNeighborhoodIndices( - const vector& centerCoords, - const vector& dimensions, - UInt radius, - const vector& expected) - { - const UInt centerIndex = 
indexFromCoordinates(centerCoords, dimensions); + int i = 0; + for (UInt index : WrappingNeighborhood(centerIndex, radius, dimensions)) { + EXPECT_EQ(expected[i], index); + i++; + } - int i = 0; - for (UInt index : WrappingNeighborhood(centerIndex, radius, dimensions)) - { - EXPECT_EQ(expected[i], index); - i++; - } + EXPECT_EQ(expected.size(), i); +} - EXPECT_EQ(expected.size(), i); - } +void expectWrappingNeighborhoodCoords(const vector ¢erCoords, + const vector &dimensions, + UInt radius, + const vector> &expected) { + const UInt centerIndex = indexFromCoordinates(centerCoords, dimensions); - void expectWrappingNeighborhoodCoords( - const vector& centerCoords, - const vector& dimensions, - UInt radius, - const vector >& expected) - { - const UInt centerIndex = indexFromCoordinates(centerCoords, dimensions); - - int i = 0; - for (UInt index : WrappingNeighborhood(centerIndex, radius, dimensions)) - { - EXPECT_EQ(indexFromCoordinates(expected[i], dimensions), index); - i++; - } - - EXPECT_EQ(expected.size(), i); + int i = 0; + for (UInt index : WrappingNeighborhood(centerIndex, radius, dimensions)) { + EXPECT_EQ(indexFromCoordinates(expected[i], dimensions), index); + i++; } - TEST(TopologyTest, WrappingNeighborhoodOfOrigin1D) - { - expectWrappingNeighborhoodIndices( + EXPECT_EQ(expected.size(), i); +} + +TEST(TopologyTest, WrappingNeighborhoodOfOrigin1D) { + expectWrappingNeighborhoodIndices( /*centerCoords*/ {0}, /*dimensions*/ {100}, /*radius*/ 1, /*expected*/ {99, 0, 1}); - } +} - TEST(TopologyTest, WrappingNeighborhoodOfOrigin2D) - { - expectWrappingNeighborhoodCoords( +TEST(TopologyTest, WrappingNeighborhoodOfOrigin2D) { + expectWrappingNeighborhoodCoords( /*centerCoords*/ {0, 0}, /*dimensions*/ {100, 80}, /*radius*/ 1, - /*expected*/ {{99, 79}, {99, 0}, {99, 1}, - {0, 79}, {0, 0}, {0, 1}, - {1, 79}, {1, 0}, {1, 1}}); - } + /*expected*/ + {{99, 79}, + {99, 0}, + {99, 1}, + {0, 79}, + {0, 0}, + {0, 1}, + {1, 79}, + {1, 0}, + {1, 1}}); +} - TEST(TopologyTest, WrappingNeighborhoodOfOrigin3D) - { - expectWrappingNeighborhoodCoords( +TEST(TopologyTest, WrappingNeighborhoodOfOrigin3D) { + expectWrappingNeighborhoodCoords( /*centerCoords*/ {0, 0, 0}, /*dimensions*/ {100, 80, 60}, /*radius*/ 1, - /*expected*/ {{99, 79, 59}, {99, 79, 0}, {99, 79, 1}, - {99, 0, 59}, {99, 0, 0}, {99, 0, 1}, - {99, 1, 59}, {99, 1, 0}, {99, 1, 1}, - {0, 79, 59}, {0, 79, 0}, {0, 79, 1}, - {0, 0, 59}, {0, 0, 0}, {0, 0, 1}, - {0, 1, 59}, {0, 1, 0}, {0, 1, 1}, - {1, 79, 59}, {1, 79, 0}, {1, 79, 1}, - {1, 0, 59}, {1, 0, 0}, {1, 0, 1}, - {1, 1, 59}, {1, 1, 0}, {1, 1, 1}}); - } + /*expected*/ {{99, 79, 59}, {99, 79, 0}, {99, 79, 1}, {99, 0, 59}, + {99, 0, 0}, {99, 0, 1}, {99, 1, 59}, {99, 1, 0}, + {99, 1, 1}, {0, 79, 59}, {0, 79, 0}, {0, 79, 1}, + {0, 0, 59}, {0, 0, 0}, {0, 0, 1}, {0, 1, 59}, + {0, 1, 0}, {0, 1, 1}, {1, 79, 59}, {1, 79, 0}, + {1, 79, 1}, {1, 0, 59}, {1, 0, 0}, {1, 0, 1}, + {1, 1, 59}, {1, 1, 0}, {1, 1, 1}}); +} - TEST(TopologyTest, WrappingNeighborhoodOfMiddle1D) - { - expectWrappingNeighborhoodIndices( +TEST(TopologyTest, WrappingNeighborhoodOfMiddle1D) { + expectWrappingNeighborhoodIndices( /*centerCoords*/ {50}, /*dimensions*/ {100}, /*radius*/ 1, /*expected*/ {49, 50, 51}); - } +} - TEST(TopologyTest, WrappingNeighborhoodOfMiddle2D) - { - expectWrappingNeighborhoodCoords( +TEST(TopologyTest, WrappingNeighborhoodOfMiddle2D) { + expectWrappingNeighborhoodCoords( /*centerCoords*/ {50, 50}, /*dimensions*/ {100, 80}, /*radius*/ 1, - /*expected*/{{49, 49}, {49, 50}, {49, 51}, - {50, 49}, {50, 50}, {50, 
51}, - {51, 49}, {51, 50}, {51, 51}}); - } + /*expected*/ + {{49, 49}, + {49, 50}, + {49, 51}, + {50, 49}, + {50, 50}, + {50, 51}, + {51, 49}, + {51, 50}, + {51, 51}}); +} - TEST(TopologyTest, WrappingNeighborhoodOfEnd2D) - { - expectWrappingNeighborhoodCoords( +TEST(TopologyTest, WrappingNeighborhoodOfEnd2D) { + expectWrappingNeighborhoodCoords( /*centerCoords*/ {99, 79}, /*dimensions*/ {100, 80}, /*radius*/ 1, - /*expected*/{{98, 78}, {98, 79}, {98, 0}, - {99, 78}, {99, 79}, {99, 0}, - {0, 78}, {0, 79}, {0, 0}}); - } + /*expected*/ + {{98, 78}, + {98, 79}, + {98, 0}, + {99, 78}, + {99, 79}, + {99, 0}, + {0, 78}, + {0, 79}, + {0, 0}}); +} - TEST(TopologyTest, WrappingNeighborhoodWiderThanWorld) - { - // The order is weird because it starts walking from {-3, -3} and avoids - // walking the same point twice. - expectWrappingNeighborhoodCoords( +TEST(TopologyTest, WrappingNeighborhoodWiderThanWorld) { + // The order is weird because it starts walking from {-3, -3} and avoids + // walking the same point twice. + expectWrappingNeighborhoodCoords( /*centerCoords*/ {0, 0}, /*dimensions*/ {3, 2}, /*radius*/ 3, - /*expected*/{{0, 1}, {0, 0}, - {1, 1}, {1, 0}, - {2, 1}, {2, 0}}); - } + /*expected*/ {{0, 1}, {0, 0}, {1, 1}, {1, 0}, {2, 1}, {2, 0}}); +} - TEST(TopologyTest, WrappingNeighborhoodRadiusZero) - { - expectWrappingNeighborhoodIndices( +TEST(TopologyTest, WrappingNeighborhoodRadiusZero) { + expectWrappingNeighborhoodIndices( /*centerCoords*/ {0}, /*dimensions*/ {100}, /*radius*/ 0, /*expected*/ {0}); - expectWrappingNeighborhoodCoords( + expectWrappingNeighborhoodCoords( /*centerCoords*/ {0, 0}, /*dimensions*/ {100, 80}, /*radius*/ 0, /*expected*/ {{0, 0}}); - expectWrappingNeighborhoodCoords( + expectWrappingNeighborhoodCoords( /*centerCoords*/ {0, 0, 0}, /*dimensions*/ {100, 80, 60}, /*radius*/ 0, /*expected*/ {{0, 0, 0}}); - } +} - TEST(TopologyTest, WrappingNeighborhoodDimensionOne) - { - expectWrappingNeighborhoodCoords( +TEST(TopologyTest, WrappingNeighborhoodDimensionOne) { + expectWrappingNeighborhoodCoords( /*centerCoords*/ {5, 0}, /*dimensions*/ {10, 1}, /*radius*/ 1, /*expected*/ {{4, 0}, {5, 0}, {6, 0}}); - expectWrappingNeighborhoodCoords( + expectWrappingNeighborhoodCoords( /*centerCoords*/ {5, 0, 0}, /*dimensions*/ {10, 1, 1}, /*radius*/ 1, /*expected*/ {{4, 0, 0}, {5, 0, 0}, {6, 0, 0}}); - } } +} // namespace diff --git a/src/test/unit/ntypes/ArrayTest.cpp b/src/test/unit/ntypes/ArrayTest.cpp index dc3b90e193..52f0033aa5 100644 --- a/src/test/unit/ntypes/ArrayTest.cpp +++ b/src/test/unit/ntypes/ArrayTest.cpp @@ -29,8 +29,8 @@ #include #include -#include #include +#include #include @@ -40,149 +40,124 @@ using namespace nupic; -struct ArrayTestParameters -{ +struct ArrayTestParameters { NTA_BasicType dataType; unsigned int dataTypeSize; - int allocationSize; //We intentionally use an int instead of a size_t for - //these tests. This is so that we can check test usage - //by a naive user who might use an int and accidentally - //pass negative values. + int allocationSize; // We intentionally use an int instead of a size_t for + // these tests. This is so that we can check test usage + // by a naive user who might use an int and accidentally + // pass negative values. 
std::string dataTypeText; bool testUsesInvalidParameters; - - ArrayTestParameters() : - dataType((NTA_BasicType) -1), - dataTypeSize(0), - allocationSize(0), - dataTypeText(""), - testUsesInvalidParameters(true) {} - + + ArrayTestParameters() + : dataType((NTA_BasicType)-1), dataTypeSize(0), allocationSize(0), + dataTypeText(""), testUsesInvalidParameters(true) {} + ArrayTestParameters(NTA_BasicType dataTypeParam, - unsigned int dataTypeSizeParam, - int allocationSizeParam, + unsigned int dataTypeSizeParam, int allocationSizeParam, std::string dataTypeTextParam, - bool testUsesInvalidParametersParam) : - dataType(dataTypeParam), - dataTypeSize(dataTypeSizeParam), - allocationSize(allocationSizeParam), - dataTypeText(std::move(dataTypeTextParam)), - testUsesInvalidParameters(testUsesInvalidParametersParam) { } + bool testUsesInvalidParametersParam) + : dataType(dataTypeParam), dataTypeSize(dataTypeSizeParam), + allocationSize(allocationSizeParam), + dataTypeText(std::move(dataTypeTextParam)), + testUsesInvalidParameters(testUsesInvalidParametersParam) {} }; +struct ArrayTest : public ::testing::Test { + std::map testCases_; -struct ArrayTest : public ::testing::Test -{ - std::map testCases_; + typedef std::map::iterator TestCaseIterator; - typedef std::map::iterator - TestCaseIterator; - void setupArrayTests(); }; - - - #ifdef NTA_INSTRUMENTED_MEMORY_GUARDED -//If we're running an appropriately instrumented build, then we're going -//to be running test code which intentionally commits access violations to -//verify proper functioning of the class; to do so, we're -//going to utilize the POSIX signal library and throw a C++ exception from -//our custom signal handler. +// If we're running an appropriately instrumented build, then we're going +// to be running test code which intentionally commits access violations to +// verify proper functioning of the class; to do so, we're +// going to utilize the POSIX signal library and throw a C++ exception from +// our custom signal handler. // -//This should be tested on Windows to verify POSIX compliance. If it does -//not work, the Microsoft C++ extensions __try and __catch can be used to -//catch an access violation on Windows. +// This should be tested on Windows to verify POSIX compliance. If it does +// not work, the Microsoft C++ extensions __try and __catch can be used to +// catch an access violation on Windows. #include +class AccessViolationError {}; -class AccessViolationError -{ -}; - -void AccessViolationHandler(int signal) -{ - throw AccessViolationError(); -} +void AccessViolationHandler(int signal) { throw AccessViolationError(); } typedef void (*AccessViolationHandlerPointer)(int); -TEST_F(ArrayTest, testMemoryOperations) -{ - //Temporarily swap out the the segv and bus handlers. +TEST_F(ArrayTest, testMemoryOperations) { + // Temporarily swap out the the segv and bus handlers. AccessViolationHandlerPointer existingSigsegvHandler; AccessViolationHandlerPointer existingSigbusHandler; existingSigsegvHandler = signal(SIGSEGV, AccessViolationHandler); existingSigbusHandler = signal(SIGBUS, AccessViolationHandler); - - //Since we're going to be testing the memory behavior of ArrayBase, we create a - //pointer here (which will be set to the ArrayBase's buffer) while putting the - //ArrayBase itself inside an artificial scope. That way, when the ArrayBase goes out - //of scope and is destructed we can test that ArrayBase doesn't leak the buffer - //memory. 
We prefer the artificial scope method to a pointer with new/delete - //as it prevents our code from leaking memory under an unhandled error case. + + // Since we're going to be testing the memory behavior of ArrayBase, we create + // a pointer here (which will be set to the ArrayBase's buffer) while putting + // the ArrayBase itself inside an artificial scope. That way, when the + // ArrayBase goes out of scope and is destructed we can test that ArrayBase + // doesn't leak the buffer memory. We prefer the artificial scope method to a + // pointer with new/delete as it prevents our code from leaking memory under + // an unhandled error case. // - //NOTE: For these tests to be consistent, the code must be built using + // NOTE: For these tests to be consistent, the code must be built using // instrumentation which intentionally guards memory and handles // allocations/deallocations immediately (such as a debugging malloc // library). This test will NOT be run unless // NTA_INSTRUMENTED_MEMORY_GUARDED is defined. - - void * ownedBufferLocation; + + void *ownedBufferLocation; { ArrayBase a(NTA_BasicType_Byte); - + a.allocateBuffer(10); ownedBufferLocation = a.getBuffer(); - - //Verify that we can write into the buffer + + // Verify that we can write into the buffer bool wasAbleToWriteToBuffer = true; - try - { - for(unsigned int i = 0; i < 10; i++) - { - ((char *) ownedBufferLocation)[i] = 'A' + i; + try { + for (unsigned int i = 0; i < 10; i++) { + ((char *)ownedBufferLocation)[i] = 'A' + i; } - } - catch(AccessViolationError exception) - { + } catch (AccessViolationError exception) { wasAbleToWriteToBuffer = false; } TEST2("Write to full length of allocated buffer should succeed", wasAbleToWriteToBuffer); - //Verify that we can read from the buffer + // Verify that we can read from the buffer char testRead = '\0'; - testRead = ((char *) ownedBufferLocation)[4]; - ASSERT_TRUE(!wasAbleToReadFromFreedBuffer) << "Read from freed buffer should fail"; + testRead = ((char *)ownedBufferLocation)[4]; + ASSERT_TRUE(!wasAbleToReadFromFreedBuffer) + << "Read from freed buffer should fail"; } bool wasAbleToReadFromFreedBuffer = true; - try - { + try { char testRead = '\0'; - testRead = ((char *) ownedBufferLocation)[4]; - } - catch(AccessViolationError exception) - { + testRead = ((char *)ownedBufferLocation)[4]; + } catch (AccessViolationError exception) { wasAbleToReadFromFreedBuffer = false; } - ASSERT_TRUE(!wasAbleToReadFromFreedBuffer) << "Read from freed buffer should fail"; + ASSERT_TRUE(!wasAbleToReadFromFreedBuffer) + << "Read from freed buffer should fail"; bool wasAbleToWriteToFreedBuffer = true; - try - { - ((char *) ownedBufferLocation)[4] = 'A'; - } - catch(AccessViolationError exception) - { + try { + ((char *)ownedBufferLocation)[4] = 'A'; + } catch (AccessViolationError exception) { wasAbleToWriteToFreedBuffer = false; } - ASSERT_TRUE(!wasAbleToWriteToFreedBuffer) << "Write to freed buffer should fail"; + ASSERT_TRUE(!wasAbleToWriteToFreedBuffer) + << "Write to freed buffer should fail"; signal(SIGSEGV, existingSigsegvHandler); signal(SIGBUS, existingSigbusHandler); @@ -190,275 +165,240 @@ TEST_F(ArrayTest, testMemoryOperations) #endif -TEST_F(ArrayTest, testArrayCreation) -{ +TEST_F(ArrayTest, testArrayCreation) { setupArrayTests(); - + boost::scoped_ptr arrayP; TestCaseIterator testCase; - - for(testCase = testCases_.begin(); testCase != testCases_.end(); testCase++) - { - char *buf = (char *) -1; - - if(testCase->second.testUsesInvalidParameters) - { + + for (testCase = 
testCases_.begin(); testCase != testCases_.end(); + testCase++) { + char *buf = (char *)-1; + + if (testCase->second.testUsesInvalidParameters) { bool caughtException = false; - try - { + try { arrayP.reset(new ArrayBase(testCase->second.dataType)); - } - catch(nupic::Exception) - { + } catch (nupic::Exception) { caughtException = true; } ASSERT_TRUE(caughtException) - << "Test case: " + - testCase->first + - " - Should throw an exception on trying to create an invalid " - "ArrayBase"; - } - else - { + << "Test case: " + testCase->first + + " - Should throw an exception on trying to create an invalid " + "ArrayBase"; + } else { arrayP.reset(new ArrayBase(testCase->second.dataType)); - buf = (char *) arrayP->getBuffer(); + buf = (char *)arrayP->getBuffer(); ASSERT_EQ(buf, nullptr) - << "Test case: " + - testCase->first + - " - When not passed a size, a newly created ArrayBase should " - "have a NULL buffer"; - ASSERT_EQ((size_t) 0, arrayP->getCount()) - << "Test case: " + - testCase->first + - " - When not passed a size, a newly created ArrayBase should " - "have a count equal to zero"; + << "Test case: " + testCase->first + + " - When not passed a size, a newly created ArrayBase should " + "have a NULL buffer"; + ASSERT_EQ((size_t)0, arrayP->getCount()) + << "Test case: " + testCase->first + + " - When not passed a size, a newly created ArrayBase should " + "have a count equal to zero"; boost::scoped_array buf2(new char[testCase->second.dataTypeSize * testCase->second.allocationSize]); - - arrayP.reset(new ArrayBase(testCase->second.dataType, - buf2.get(), - testCase->second.allocationSize)); - - buf = (char *) arrayP->getBuffer(); - ASSERT_EQ(buf, buf2.get()) - << "Test case: " + - testCase->first + - " - Preallocating a buffer for a newly created ArrayBase should " - "use the provided buffer"; - ASSERT_EQ((size_t) testCase->second.allocationSize, arrayP->getCount()) - << "Test case: " + - testCase->first + - " - Preallocating a buffer should have a count equal to our " - "allocation size"; - } + + arrayP.reset(new ArrayBase(testCase->second.dataType, buf2.get(), + testCase->second.allocationSize)); + + buf = (char *)arrayP->getBuffer(); + ASSERT_EQ(buf, buf2.get()) << "Test case: " + testCase->first + + " - Preallocating a buffer for a newly " + "created ArrayBase should " + "use the provided buffer"; + ASSERT_EQ((size_t)testCase->second.allocationSize, arrayP->getCount()) + << "Test case: " + testCase->first + + " - Preallocating a buffer should have a count equal to our " + "allocation size"; + } } } -TEST_F(ArrayTest, testBufferAllocation) -{ +TEST_F(ArrayTest, testBufferAllocation) { testCases_.clear(); testCases_["NTA_BasicType_Int32, size 0"] = - ArrayTestParameters(NTA_BasicType_Int32, 4, 0, "Int32", false); + ArrayTestParameters(NTA_BasicType_Int32, 4, 0, "Int32", false); testCases_["NTA_BasicType_Int32, size UINT_MAX"] = - ArrayTestParameters(NTA_BasicType_Int32, 4, UINT_MAX, "Int32", true); + ArrayTestParameters(NTA_BasicType_Int32, 4, UINT_MAX, "Int32", true); testCases_["NTA_BasicType_Int32, size -10"] = - ArrayTestParameters(NTA_BasicType_Int32, 4, -10, "Int32", true); + ArrayTestParameters(NTA_BasicType_Int32, 4, -10, "Int32", true); testCases_["NTA_BasicType_Int32, size 10"] = - ArrayTestParameters(NTA_BasicType_Int32, 4, 10, "Int32", false); - + ArrayTestParameters(NTA_BasicType_Int32, 4, 10, "Int32", false); + bool caughtException; TestCaseIterator testCase; - - for(testCase = testCases_.begin(); testCase != testCases_.end(); testCase++) - { + + for (testCase = 
testCases_.begin(); testCase != testCases_.end(); + testCase++) { caughtException = false; ArrayBase a(testCase->second.dataType); - try - { - a.allocateBuffer((size_t) (testCase->second.allocationSize)); - } - catch(std::exception& ) - { + try { + a.allocateBuffer((size_t)(testCase->second.allocationSize)); + } catch (std::exception &) { caughtException = true; } - - if(testCase->second.testUsesInvalidParameters) - { - ASSERT_TRUE(caughtException) - << "Test case: " + - testCase->first + - " - allocation of an ArrayBase of invalid size should raise an " - "exception"; - } - else - { + + if (testCase->second.testUsesInvalidParameters) { + ASSERT_TRUE(caughtException) << "Test case: " + testCase->first + + " - allocation of an ArrayBase of " + "invalid size should raise an " + "exception"; + } else { ASSERT_FALSE(caughtException) - << "Test case: " + testCase->first + - " - Allocation of an ArrayBase of valid size should return a " - "valid pointer"; - + << "Test case: " + testCase->first + + " - Allocation of an ArrayBase of valid size should return a " + "valid pointer"; + caughtException = false; - - try - { + + try { a.allocateBuffer(10); - } - catch(nupic::Exception) - { + } catch (nupic::Exception) { caughtException = true; } - + ASSERT_TRUE(caughtException) - << "Test case: " + testCase->first + - " - allocating a buffer when one is already allocated should " - "raise an exception"; - - ASSERT_EQ((size_t) testCase->second.allocationSize, a.getCount()) - << "Test case: " + testCase->first + - " - Size of allocated ArrayBase should match requested size"; + << "Test case: " + testCase->first + + " - allocating a buffer when one is already allocated should " + "raise an exception"; + + ASSERT_EQ((size_t)testCase->second.allocationSize, a.getCount()) + << "Test case: " + testCase->first + + " - Size of allocated ArrayBase should match requested size"; } } } -TEST_F(ArrayTest, testBufferAssignment) -{ +TEST_F(ArrayTest, testBufferAssignment) { testCases_.clear(); testCases_["NTA_BasicType_Int32, buffer assignment"] = - ArrayTestParameters(NTA_BasicType_Int32, 4, 10, "Int32", false); - + ArrayTestParameters(NTA_BasicType_Int32, 4, 10, "Int32", false); + TestCaseIterator testCase; - - for(testCase = testCases_.begin(); testCase != testCases_.end(); testCase++) - { + + for (testCase = testCases_.begin(); testCase != testCases_.end(); + testCase++) { boost::scoped_array buf(new char[testCase->second.dataTypeSize * testCase->second.allocationSize]); - + ArrayBase a(testCase->second.dataType); a.setBuffer(buf.get(), testCase->second.allocationSize); - + ASSERT_EQ(buf.get(), a.getBuffer()) - << "Test case: " + - testCase->first + - " - setBuffer() should used the assigned buffer"; + << "Test case: " + testCase->first + + " - setBuffer() should used the assigned buffer"; boost::scoped_array buf2(new char[testCase->second.dataTypeSize * testCase->second.allocationSize]); bool caughtException = false; - - try - { + + try { a.setBuffer(buf2.get(), testCase->second.allocationSize); - } - catch(nupic::Exception) - { + } catch (nupic::Exception) { caughtException = true; } - + ASSERT_TRUE(caughtException) - << "Test case: " + - testCase->first + - " - setting a buffer when one is already set should raise an " - "exception"; - } + << "Test case: " + testCase->first + + " - setting a buffer when one is already set should raise an " + "exception"; + } } -TEST_F(ArrayTest, testBufferRelease) -{ +TEST_F(ArrayTest, testBufferRelease) { testCases_.clear(); testCases_["NTA_BasicType_Int32, buffer release"] = - 
ArrayTestParameters(NTA_BasicType_Int32, 4, 10, "Int32", false); - + ArrayTestParameters(NTA_BasicType_Int32, 4, 10, "Int32", false); + TestCaseIterator testCase; - - for(testCase = testCases_.begin(); testCase != testCases_.end(); testCase++) - { + + for (testCase = testCases_.begin(); testCase != testCases_.end(); + testCase++) { boost::scoped_array buf(new char[testCase->second.dataTypeSize * testCase->second.allocationSize]); - + ArrayBase a(testCase->second.dataType); a.setBuffer(buf.get(), testCase->second.allocationSize); a.releaseBuffer(); - + ASSERT_EQ(nullptr, a.getBuffer()) - << "Test case: " + - testCase->first + - " - ArrayBase should no longer hold a reference to a locally allocated " - "buffer after calling releaseBuffer"; - } + << "Test case: " + testCase->first + + " - ArrayBase should no longer hold a reference to a locally " + "allocated " + "buffer after calling releaseBuffer"; + } } -TEST_F(ArrayTest, testArrayTyping) -{ +TEST_F(ArrayTest, testArrayTyping) { setupArrayTests(); - + TestCaseIterator testCase; - for(testCase = testCases_.begin(); testCase != testCases_.end(); testCase++) - { - //testArrayCreation() already validates that ArrayBase objects can't be created - //using invalid NTA_BasicType parameters, so we skip those test cases here - if(testCase->second.testUsesInvalidParameters) - { + for (testCase = testCases_.begin(); testCase != testCases_.end(); + testCase++) { + // testArrayCreation() already validates that ArrayBase objects can't be + // created using invalid NTA_BasicType parameters, so we skip those test + // cases here + if (testCase->second.testUsesInvalidParameters) { continue; } - + ArrayBase a(testCase->second.dataType); ASSERT_EQ(testCase->second.dataType, a.getType()) - << "Test case: " + - testCase->first + - " - the type of a created ArrayBase should match the requested " - "type"; + << "Test case: " + testCase->first + + " - the type of a created ArrayBase should match the requested " + "type"; std::string name(BasicType::getName(a.getType())); ASSERT_EQ(testCase->second.dataTypeText, name) - << "Test case: " + - testCase->first + - " - the string representation of a type contained in a " - "created ArrayBase should match the expected string"; - } + << "Test case: " + testCase->first + + " - the string representation of a type contained in a " + "created ArrayBase should match the expected string"; + } } -void ArrayTest::setupArrayTests() -{ - //we're going to test using all types that can be stored in the ArrayBase... - //the NTA_BasicType enum overrides the default incrementing values for - //some enumerated types, so we must reference them manually +void ArrayTest::setupArrayTests() { + // we're going to test using all types that can be stored in the ArrayBase... 
+ // the NTA_BasicType enum overrides the default incrementing values for + // some enumerated types, so we must reference them manually testCases_.clear(); testCases_["NTA_BasicType_Byte"] = - ArrayTestParameters(NTA_BasicType_Byte, 1, 10, "Byte", false); + ArrayTestParameters(NTA_BasicType_Byte, 1, 10, "Byte", false); testCases_["NTA_BasicType_Int16"] = - ArrayTestParameters(NTA_BasicType_Int16, 2, 10, "Int16", false); + ArrayTestParameters(NTA_BasicType_Int16, 2, 10, "Int16", false); testCases_["NTA_BasicType_UInt16"] = - ArrayTestParameters(NTA_BasicType_UInt16, 2, 10, "UInt16", false); + ArrayTestParameters(NTA_BasicType_UInt16, 2, 10, "UInt16", false); testCases_["NTA_BasicType_Int32"] = - ArrayTestParameters(NTA_BasicType_Int32, 4, 10, "Int32", false); + ArrayTestParameters(NTA_BasicType_Int32, 4, 10, "Int32", false); testCases_["NTA_BasicType_UInt32"] = - ArrayTestParameters(NTA_BasicType_UInt32, 4, 10, "UInt32", false); + ArrayTestParameters(NTA_BasicType_UInt32, 4, 10, "UInt32", false); testCases_["NTA_BasicType_Int64"] = - ArrayTestParameters(NTA_BasicType_Int64, 8, 10, "Int64", false); + ArrayTestParameters(NTA_BasicType_Int64, 8, 10, "Int64", false); testCases_["NTA_BasicType_UInt64"] = - ArrayTestParameters(NTA_BasicType_UInt64, 8, 10, "UInt64", false); + ArrayTestParameters(NTA_BasicType_UInt64, 8, 10, "UInt64", false); testCases_["NTA_BasicType_Real32"] = - ArrayTestParameters(NTA_BasicType_Real32, 4, 10, "Real32", false); + ArrayTestParameters(NTA_BasicType_Real32, 4, 10, "Real32", false); testCases_["NTA_BasicType_Real64"] = - ArrayTestParameters(NTA_BasicType_Real64, 8, 10, "Real64", false); + ArrayTestParameters(NTA_BasicType_Real64, 8, 10, "Real64", false); testCases_["NTA_BasicType_Bool"] = - ArrayTestParameters(NTA_BasicType_Bool, sizeof(bool), 10, "Bool", false); -#ifdef NTA_DOUBLE_PRECISION + ArrayTestParameters(NTA_BasicType_Bool, sizeof(bool), 10, "Bool", false); +#ifdef NTA_DOUBLE_PRECISION testCases_["NTA_BasicType_Real"] = - ArrayTestParameters(NTA_BasicType_Real, 8, 10, "Real64", false); -#else + ArrayTestParameters(NTA_BasicType_Real, 8, 10, "Real64", false); +#else testCases_["NTA_BasicType_Real"] = - ArrayTestParameters(NTA_BasicType_Real, 4, 10, "Real32", false); + ArrayTestParameters(NTA_BasicType_Real, 4, 10, "Real32", false); #endif testCases_["Non-existent NTA_BasicType"] = - ArrayTestParameters((NTA_BasicType) -1, 0, 10, "N/A", true); + ArrayTestParameters((NTA_BasicType)-1, 0, 10, "N/A", true); } diff --git a/src/test/unit/ntypes/BufferTest.cpp b/src/test/unit/ntypes/BufferTest.cpp index 98b856d70d..32d33828a7 100644 --- a/src/test/unit/ntypes/BufferTest.cpp +++ b/src/test/unit/ntypes/BufferTest.cpp @@ -24,11 +24,11 @@ * Implementation for Buffer unit tests */ +#include // strlen #include #include -#include // strlen -// This test accesses private methods. +// This test accesses private methods. 
#define private public #include #undef private @@ -38,143 +38,146 @@ using namespace nupic; -void testReadBytes_VariableSizeBufferHelper(Size buffSize) -{ +void testReadBytes_VariableSizeBufferHelper(Size buffSize) { std::vector in; std::vector out; - in.resize(buffSize+1); - out.resize(buffSize+1); - - std::fill(in.begin(), in.begin()+in.capacity(), 'I'); - std::fill(out.begin(), out.begin()+out.capacity(), 'O'); - - for (Size i = 0; i <= buffSize; ++i) - { + in.resize(buffSize + 1); + out.resize(buffSize + 1); + + std::fill(in.begin(), in.begin() + in.capacity(), 'I'); + std::fill(out.begin(), out.begin() + out.capacity(), 'O'); + + for (Size i = 0; i <= buffSize; ++i) { ASSERT_TRUE(in[i] == 'I'); ASSERT_TRUE(out[i] == 'O'); } - + // Populate the ReadBuffer with the input ReadBuffer rb(&in[0], buffSize); - + // Get the abstract interface - IReadBuffer & r = rb; - + IReadBuffer &r = rb; + // Prepare for reading from the read buffer in chunks const Size CHUNK_SIZE = 10; Size size = CHUNK_SIZE; - // Read chunks until the buffer is exhausted and write everything to out buffer + // Read chunks until the buffer is exhausted and write everything to out + // buffer Size index = 0; - while (size == CHUNK_SIZE) - { + while (size == CHUNK_SIZE) { Int32 res = r.read(&out[index], size); ASSERT_TRUE(res == 0); index += size; } - + // Verify that last index and last read size are correct ASSERT_TRUE(index == buffSize); ASSERT_TRUE(size == buffSize % CHUNK_SIZE); - + // Check corner cases ASSERT_TRUE(out[0] == 'I'); - ASSERT_TRUE(out[buffSize-1] == 'I'); + ASSERT_TRUE(out[buffSize - 1] == 'I'); ASSERT_TRUE(out[buffSize] == 'O'); - + // Check that all other values have been read correctly Size i; - for (i = 1; i < buffSize-1; ++i) + for (i = 1; i < buffSize - 1; ++i) ASSERT_TRUE(out[i] == 'I'); } -TEST(BufferTest, testReadBytes_VariableSizeBuffer) -{ - ASSERT_NO_FATAL_FAILURE( - testReadBytes_VariableSizeBufferHelper(5)); - -// testReadBytes_VariableSizeBufferHelpter(128); -// testReadBytes_VariableSizeBufferHelpter(227); -// testReadBytes_VariableSizeBufferHelpter(228); -// testReadBytes_VariableSizeBufferHelpter(229); -// testReadBytes_VariableSizeBufferHelpter(315); -// testReadBytes_VariableSizeBufferHelpter(482); -// testReadBytes_VariableSizeBufferHelpter(483); -// testReadBytes_VariableSizeBufferHelpter(484); -// testReadBytes_VariableSizeBufferHelpter(512); -// testReadBytes_VariableSizeBufferHelpter(2000); -// testReadBytes_VariableSizeBufferHelpter(20000); -} +TEST(BufferTest, testReadBytes_VariableSizeBuffer) { + ASSERT_NO_FATAL_FAILURE(testReadBytes_VariableSizeBufferHelper(5)); + // testReadBytes_VariableSizeBufferHelpter(128); + // testReadBytes_VariableSizeBufferHelpter(227); + // testReadBytes_VariableSizeBufferHelpter(228); + // testReadBytes_VariableSizeBufferHelpter(229); + // testReadBytes_VariableSizeBufferHelpter(315); + // testReadBytes_VariableSizeBufferHelpter(482); + // testReadBytes_VariableSizeBufferHelpter(483); + // testReadBytes_VariableSizeBufferHelpter(484); + // testReadBytes_VariableSizeBufferHelpter(512); + // testReadBytes_VariableSizeBufferHelpter(2000); + // testReadBytes_VariableSizeBufferHelpter(20000); +} -TEST(BufferTest, testReadBytes_SmallBuffer) -{ +TEST(BufferTest, testReadBytes_SmallBuffer) { ReadBuffer b((const Byte *)"123", 3); - IReadBuffer & reader = b; + IReadBuffer &reader = b; Byte out[5]; Size size = 0; Int32 res = 0; - + size = 2; res = reader.read(out, size); - ASSERT_TRUE(res == 0) << "BufferTest::testReadBuffer(), reader.read(2) failed"; - 
ASSERT_TRUE(size == 2) << "BufferTest::testReadBuffer(), reader.read(2) failed"; - ASSERT_TRUE(out[0] == '1') << "BufferTest::testReadBuffer(), out[0] should be 1 after reading 1,2"; - ASSERT_TRUE(out[1] == '2') << "BufferTest::testReadBuffer(), out[1] should be 2 after reading 1,2"; - + ASSERT_TRUE(res == 0) + << "BufferTest::testReadBuffer(), reader.read(2) failed"; + ASSERT_TRUE(size == 2) + << "BufferTest::testReadBuffer(), reader.read(2) failed"; + ASSERT_TRUE(out[0] == '1') + << "BufferTest::testReadBuffer(), out[0] should be 1 after reading 1,2"; + ASSERT_TRUE(out[1] == '2') + << "BufferTest::testReadBuffer(), out[1] should be 2 after reading 1,2"; + size = 2; - res = reader.read(out+2, size); - ASSERT_TRUE(res == 0) << "BufferTest::testReadBuffer(), reader.read(2) failed"; - ASSERT_TRUE(size == 1) << "BufferTest::testReadBuffer(), reader.read(2) failed"; - ASSERT_TRUE(out[0] == '1') << "BufferTest::testReadBuffer(), out[0] should be 1 after reading 3"; - ASSERT_TRUE(out[1] == '2') << "BufferTest::testReadBuffer(), out[1] should be 2 after reading 3"; - ASSERT_TRUE(out[2] == '3') << "BufferTest::testReadBuffer(), out[2] should be 3 after reading 3"; + res = reader.read(out + 2, size); + ASSERT_TRUE(res == 0) + << "BufferTest::testReadBuffer(), reader.read(2) failed"; + ASSERT_TRUE(size == 1) + << "BufferTest::testReadBuffer(), reader.read(2) failed"; + ASSERT_TRUE(out[0] == '1') + << "BufferTest::testReadBuffer(), out[0] should be 1 after reading 3"; + ASSERT_TRUE(out[1] == '2') + << "BufferTest::testReadBuffer(), out[1] should be 2 after reading 3"; + ASSERT_TRUE(out[2] == '3') + << "BufferTest::testReadBuffer(), out[2] should be 3 after reading 3"; } -TEST(BufferTest, testWriteBytes) -{ +TEST(BufferTest, testWriteBytes) { WriteBuffer b; - Byte out[5] = { 1, 2, 3, 4, 5 }; - IWriteBuffer & writer = b; - ASSERT_TRUE(writer.getSize() == 0) << "BufferTest::testWriteBuffer(), writer.getSize() should be 0 before putting anything in"; + Byte out[5] = {1, 2, 3, 4, 5}; + IWriteBuffer &writer = b; + ASSERT_TRUE(writer.getSize() == 0) + << "BufferTest::testWriteBuffer(), writer.getSize() should be 0 before " + "putting anything in"; Size size = 3; writer.write(out, size); - ASSERT_TRUE(writer.getSize() == 3) << "BufferTest::testWriteBuffer(), writer.getSize() should be 3 after writing 1,2,3"; + ASSERT_TRUE(writer.getSize() == 3) + << "BufferTest::testWriteBuffer(), writer.getSize() should be 3 after " + "writing 1,2,3"; size = 2; - writer.write(out+3, size); - ASSERT_TRUE(writer.getSize() == 5) << "BufferTest::testWriteBuffer(), writer.getSize() should be 5 after writing 4,5"; - const Byte * s = writer.getData(); + writer.write(out + 3, size); + ASSERT_TRUE(writer.getSize() == 5) + << "BufferTest::testWriteBuffer(), writer.getSize() should be 5 after " + "writing 4,5"; + const Byte *s = writer.getData(); size = writer.getSize(); - //NTA_INFO << "s=" << string(s, size) << ", size=" << size; - ASSERT_TRUE(std::string(s, size) == std::string("\1\2\3\4\5")) << "BufferTest::testWriteBuffer(), writer.str() == 12345"; + // NTA_INFO << "s=" << string(s, size) << ", size=" << size; + ASSERT_TRUE(std::string(s, size) == std::string("\1\2\3\4\5")) + << "BufferTest::testWriteBuffer(), writer.str() == 12345"; } -TEST(BufferTest, testEvenMoreComplicatedSerialization) -{ - struct X - { - X() : a((Real)3.4) - , b(6) - , c('c') - , e((Real)-0.04) - { +TEST(BufferTest, testEvenMoreComplicatedSerialization) { + struct X { + X() : a((Real)3.4), b(6), c('c'), e((Real)-0.04) { for (int i = 0; i < 4; ++i) d[i] = 
'A' + i; for (int i = 0; i < 3; ++i) - f[i] = 100 + i; + f[i] = 100 + i; } - + Real a; UInt32 b; Byte c; Byte d[4]; Real e; - Int32 f[3]; + Int32 f[3]; }; - + X xi[2]; xi[0].a = (Real)8.8; @@ -182,17 +185,17 @@ TEST(BufferTest, testEvenMoreComplicatedSerialization) xi[1].c = 't'; xi[1].d[0] = 'X'; xi[1].e = (Real)3.14; - xi[1].f[0] = -999; + xi[1].f[0] = -999; // Write the two Xs to a buffer WriteBuffer wb; - ASSERT_TRUE(wb.getSize() == 0) << "BufferTest::testComplicatedSerialization(), empty WriteBuffer should have 0 size"; - + ASSERT_TRUE(wb.getSize() == 0) << "BufferTest::testComplicatedSerialization()" + ", empty WriteBuffer should have 0 size"; + // Write the number of Xs UInt32 size = 2; wb.write((UInt32 &)size); // Write all Xs. - for (UInt32 i = 0; i < size; ++i) - { + for (UInt32 i = 0; i < size; ++i) { wb.write(xi[i].a); wb.write(xi[i].b); wb.write(xi[i].c); @@ -200,74 +203,84 @@ TEST(BufferTest, testEvenMoreComplicatedSerialization) wb.write((const Byte *)xi[i].d, len); wb.write(xi[i].e); len = 3; - wb.write(xi[i].f, len); + wb.write(xi[i].f, len); } - + ReadBuffer rb(wb.getData(), wb.getSize()); // Read number of Xs rb.read(size); // Allocate array of Xs auto xo = new X[size]; - for (Size i = 0; i < size; ++i) - { + for (Size i = 0; i < size; ++i) { rb.read(xo[i].a); rb.read(xo[i].b); rb.read(xo[i].c); Size len = 4; Int32 res = rb.read(xo[i].d, len); - ASSERT_TRUE(res == 0) << "BufferTest::testComplicatedSerialization(), rb.read(xi[i].d, 4) failed"; - ASSERT_TRUE(len == 4) << "BufferTest::testComplicatedSerialization(), rb.read(xi[i].d, 4) == 4"; + ASSERT_TRUE(res == 0) << "BufferTest::testComplicatedSerialization(), " + "rb.read(xi[i].d, 4) failed"; + ASSERT_TRUE(len == 4) << "BufferTest::testComplicatedSerialization(), " + "rb.read(xi[i].d, 4) == 4"; rb.read(xo[i].e); len = 3; rb.read(xo[i].f, len); - NTA_INFO << "xo[" << i << "]={" << xo[i].a << " " - << xo[i].b << " " - << xo[i].c << " " - << "'" << std::string(xo[i].d, 4) << "'" + NTA_INFO << "xo[" << i << "]={" << xo[i].a << " " << xo[i].b << " " + << xo[i].c << " " + << "'" << std::string(xo[i].d, 4) << "'" << " " << xo[i].e << " " - << "'" << xo[i].f[0] << "," << xo[i].f[1] << "," << xo[i].f[2] << "'" - ; + << "'" << xo[i].f[0] << "," << xo[i].f[1] << "," << xo[i].f[2] + << "'"; } - - ASSERT_TRUE(nearlyEqual(xo[0].a, nupic::Real(8.8))) << "BufferTest::testComplicatedSerialization(), xo[0].a == 8.8"; - ASSERT_TRUE(xo[0].b == 6) << "BufferTest::testComplicatedSerialization(), xo[0].b == 6"; - ASSERT_TRUE(xo[0].c == 'c') << "BufferTest::testComplicatedSerialization(), xo[0].c == 'c'"; - ASSERT_TRUE(std::string(xo[0].d, 4) == std::string("ABCD")) << "BufferTest::testComplicatedSerialization(), xo[0].d == ABCD"; - ASSERT_TRUE(nearlyEqual(xo[0].e, nupic::Real(-0.04))) << "BufferTest::testComplicatedSerialization(), xo[0].e == -0.04"; - ASSERT_TRUE(xo[0].f[0] == 100) << "BufferTest::testComplicatedSerialization(), xo[0].f[0] == 100"; - ASSERT_TRUE(xo[0].f[1] == 101) << "BufferTest::testComplicatedSerialization(), xo[0].f[1] == 101"; - ASSERT_TRUE(xo[0].f[2] == 102) << "BufferTest::testComplicatedSerialization(), xo[0].f[2] == 102"; - - ASSERT_TRUE(xo[1].a == nupic::Real(4.5)) << "BufferTest::testComplicatedSerialization(), xo[1].a == 4.5"; - ASSERT_TRUE(xo[1].b == 6) << "BufferTest::testComplicatedSerialization(), xo[1].b == 6"; - ASSERT_TRUE(xo[1].c == 't') << "BufferTest::testComplicatedSerialization(), xo[1].c == 't'"; - ASSERT_TRUE(std::string(xo[1].d, 4) == std::string("XBCD")) << 
"BufferTest::testComplicatedSerialization(), xo[1].d == XBCD"; - ASSERT_TRUE(nearlyEqual(xo[1].e, nupic::Real(3.14))) << "BufferTest::testComplicatedSerialization(), xo[1].e == 3.14"; - ASSERT_TRUE(xo[1].f[0] == -999) << "BufferTest::testComplicatedSerialization(), xo[1].f[0] == -999"; - ASSERT_TRUE(xo[1].f[1] == 101) << "BufferTest::testComplicatedSerialization(), xo[1].f[1] == 101"; - ASSERT_TRUE(xo[1].f[2] == 102) << "BufferTest::testComplicatedSerialization(), xo[1].f[2] == 102"; + + ASSERT_TRUE(nearlyEqual(xo[0].a, nupic::Real(8.8))) + << "BufferTest::testComplicatedSerialization(), xo[0].a == 8.8"; + ASSERT_TRUE(xo[0].b == 6) + << "BufferTest::testComplicatedSerialization(), xo[0].b == 6"; + ASSERT_TRUE(xo[0].c == 'c') + << "BufferTest::testComplicatedSerialization(), xo[0].c == 'c'"; + ASSERT_TRUE(std::string(xo[0].d, 4) == std::string("ABCD")) + << "BufferTest::testComplicatedSerialization(), xo[0].d == ABCD"; + ASSERT_TRUE(nearlyEqual(xo[0].e, nupic::Real(-0.04))) + << "BufferTest::testComplicatedSerialization(), xo[0].e == -0.04"; + ASSERT_TRUE(xo[0].f[0] == 100) + << "BufferTest::testComplicatedSerialization(), xo[0].f[0] == 100"; + ASSERT_TRUE(xo[0].f[1] == 101) + << "BufferTest::testComplicatedSerialization(), xo[0].f[1] == 101"; + ASSERT_TRUE(xo[0].f[2] == 102) + << "BufferTest::testComplicatedSerialization(), xo[0].f[2] == 102"; + + ASSERT_TRUE(xo[1].a == nupic::Real(4.5)) + << "BufferTest::testComplicatedSerialization(), xo[1].a == 4.5"; + ASSERT_TRUE(xo[1].b == 6) + << "BufferTest::testComplicatedSerialization(), xo[1].b == 6"; + ASSERT_TRUE(xo[1].c == 't') + << "BufferTest::testComplicatedSerialization(), xo[1].c == 't'"; + ASSERT_TRUE(std::string(xo[1].d, 4) == std::string("XBCD")) + << "BufferTest::testComplicatedSerialization(), xo[1].d == XBCD"; + ASSERT_TRUE(nearlyEqual(xo[1].e, nupic::Real(3.14))) + << "BufferTest::testComplicatedSerialization(), xo[1].e == 3.14"; + ASSERT_TRUE(xo[1].f[0] == -999) + << "BufferTest::testComplicatedSerialization(), xo[1].f[0] == -999"; + ASSERT_TRUE(xo[1].f[1] == 101) + << "BufferTest::testComplicatedSerialization(), xo[1].f[1] == 101"; + ASSERT_TRUE(xo[1].f[2] == 102) + << "BufferTest::testComplicatedSerialization(), xo[1].f[2] == 102"; } -TEST(BufferTest, testComplicatedSerialization) -{ - struct X - { - X() : a((Real)3.4) - , b(6) - , c('c') - , e((Real)-0.04) - { +TEST(BufferTest, testComplicatedSerialization) { + struct X { + X() : a((Real)3.4), b(6), c('c'), e((Real)-0.04) { for (int i = 0; i < 4; ++i) d[i] = 'A' + i; } - + Real a; UInt32 b; Byte c; Byte d[4]; Real e; }; - + X xi[2]; xi[0].a = (Real)8.8; @@ -275,17 +288,17 @@ TEST(BufferTest, testComplicatedSerialization) xi[1].c = 't'; xi[1].d[0] = 'X'; xi[1].e = (Real)3.14; - + // Write the two Xs to a buffer WriteBuffer wb; - ASSERT_TRUE(wb.getSize() == 0) << "BufferTest::testComplicatedSerialization(), empty WriteBuffer should have 0 size"; - + ASSERT_TRUE(wb.getSize() == 0) << "BufferTest::testComplicatedSerialization()" + ", empty WriteBuffer should have 0 size"; + // Write the number of Xs UInt32 size = 2; wb.write((UInt32 &)size); // Write all Xs. 
- for (UInt32 i = 0; i < size; ++i) - { + for (UInt32 i = 0; i < size; ++i) { wb.write(xi[i].a); wb.write(xi[i].b); wb.write(xi[i].c); @@ -293,110 +306,112 @@ TEST(BufferTest, testComplicatedSerialization) wb.write((const Byte *)xi[i].d, len); wb.write(xi[i].e); } - + ReadBuffer rb(wb.getData(), wb.getSize()); // Read number of Xs rb.read(size); // Allocate array of Xs auto xo = new X[size]; - for (Size i = 0; i < size; ++i) - { + for (Size i = 0; i < size; ++i) { rb.read(xo[i].a); rb.read(xo[i].b); rb.read(xo[i].c); Size size = 4; Int32 res = rb.read(xo[i].d, size); - ASSERT_TRUE(res == 0) << "BufferTest::testComplicatedSerialization(), rb.read(xi[i].d, 4) failed"; - ASSERT_TRUE(size == 4) << "BufferTest::testComplicatedSerialization(), rb.read(xi[i].d, 4) == 4"; + ASSERT_TRUE(res == 0) << "BufferTest::testComplicatedSerialization(), " + "rb.read(xi[i].d, 4) failed"; + ASSERT_TRUE(size == 4) << "BufferTest::testComplicatedSerialization(), " + "rb.read(xi[i].d, 4) == 4"; rb.read(xo[i].e); - NTA_INFO << "xo[" << i << "]={" << xo[i].a << " " - << xo[i].b << " " - << xo[i].c << " " - << "'" << std::string(xo[i].d, 4) << "'" - << " " << xo[i].e - ; + NTA_INFO << "xo[" << i << "]={" << xo[i].a << " " << xo[i].b << " " + << xo[i].c << " " + << "'" << std::string(xo[i].d, 4) << "'" + << " " << xo[i].e; } - - ASSERT_TRUE(nearlyEqual(xo[0].a, nupic::Real(8.8))) << "BufferTest::testComplicatedSerialization(), xo[0].a == 8.8"; - ASSERT_TRUE(xo[0].b == 6) << "BufferTest::testComplicatedSerialization(), xo[0].b == 6"; - ASSERT_TRUE(xo[0].c == 'c') << "BufferTest::testComplicatedSerialization(), xo[0].c == 'c'"; - ASSERT_TRUE(std::string(xo[0].d, 4) == std::string("ABCD")) << "BufferTest::testComplicatedSerialization(), xo[0].d == ABCD"; - ASSERT_TRUE(nearlyEqual(xo[0].e, nupic::Real(-0.04))) << "BufferTest::testComplicatedSerialization(), xo[0].e == -0.04"; - - ASSERT_TRUE(xo[1].a == nupic::Real(4.5)) << "BufferTest::testComplicatedSerialization(), xo[1].a == 4.5"; - ASSERT_TRUE(xo[1].b == 6) << "BufferTest::testComplicatedSerialization(), xo[1].b == 6"; - ASSERT_TRUE(xo[1].c == 't') << "BufferTest::testComplicatedSerialization(), xo[1].c == 't'"; - ASSERT_TRUE(std::string(xo[1].d, 4) == std::string("XBCD")) << "BufferTest::testComplicatedSerialization(), xo[1].d == XBCD"; - ASSERT_TRUE(nearlyEqual(xo[1].e, nupic::Real(3.14))) << "BufferTest::testComplicatedSerialization(), xo[1].e == 3.14"; + + ASSERT_TRUE(nearlyEqual(xo[0].a, nupic::Real(8.8))) + << "BufferTest::testComplicatedSerialization(), xo[0].a == 8.8"; + ASSERT_TRUE(xo[0].b == 6) + << "BufferTest::testComplicatedSerialization(), xo[0].b == 6"; + ASSERT_TRUE(xo[0].c == 'c') + << "BufferTest::testComplicatedSerialization(), xo[0].c == 'c'"; + ASSERT_TRUE(std::string(xo[0].d, 4) == std::string("ABCD")) + << "BufferTest::testComplicatedSerialization(), xo[0].d == ABCD"; + ASSERT_TRUE(nearlyEqual(xo[0].e, nupic::Real(-0.04))) + << "BufferTest::testComplicatedSerialization(), xo[0].e == -0.04"; + + ASSERT_TRUE(xo[1].a == nupic::Real(4.5)) + << "BufferTest::testComplicatedSerialization(), xo[1].a == 4.5"; + ASSERT_TRUE(xo[1].b == 6) + << "BufferTest::testComplicatedSerialization(), xo[1].b == 6"; + ASSERT_TRUE(xo[1].c == 't') + << "BufferTest::testComplicatedSerialization(), xo[1].c == 't'"; + ASSERT_TRUE(std::string(xo[1].d, 4) == std::string("XBCD")) + << "BufferTest::testComplicatedSerialization(), xo[1].d == XBCD"; + ASSERT_TRUE(nearlyEqual(xo[1].e, nupic::Real(3.14))) + << "BufferTest::testComplicatedSerialization(), xo[1].e == 3.14"; } 
-TEST(BufferTest, testArrayMethods) -{ +TEST(BufferTest, testArrayMethods) { // Test read UInt32 array { - const Byte * s = "1 2 3 444"; + const Byte *s = "1 2 3 444"; ReadBuffer b(s, (Size)::strlen(s)); - IReadBuffer & reader = b; + IReadBuffer &reader = b; UInt32 result[4]; - std::fill(result, result+4, 0); - for (auto & elem : result) - { - ASSERT_TRUE(elem== 0); + std::fill(result, result + 4, 0); + for (auto &elem : result) { + ASSERT_TRUE(elem == 0); } - + reader.read((UInt32 *)result, 3); - for (UInt32 i = 0; i < 3; ++i) - { - ASSERT_TRUE(result[i] == i+1); + for (UInt32 i = 0; i < 3; ++i) { + ASSERT_TRUE(result[i] == i + 1); } UInt32 val = 0; reader.read(val); ASSERT_TRUE(val == 444); } - + // Test read Int32 array { - const Byte * s = "-1 -2 -3 444"; + const Byte *s = "-1 -2 -3 444"; ReadBuffer b(s, (Size)::strlen(s)); - IReadBuffer & reader = b; + IReadBuffer &reader = b; Int32 result[4]; - std::fill(result, result+4, 0); - for (auto & elem : result) - { - ASSERT_TRUE(elem== 0); + std::fill(result, result + 4, 0); + for (auto &elem : result) { + ASSERT_TRUE(elem == 0); } - + reader.read((Int32 *)result, 3); - for (Int32 i = 0; i < 3; ++i) - { - ASSERT_TRUE(result[i] == -i-1); + for (Int32 i = 0; i < 3; ++i) { + ASSERT_TRUE(result[i] == -i - 1); } Int32 val = 0; reader.read(val); ASSERT_TRUE(val == 444); } - + // Test read Real32 array { - const Byte * s = "1.5 2.5 3.5 444.555"; + const Byte *s = "1.5 2.5 3.5 444.555"; ReadBuffer b(s, (Size)::strlen(s)); - IReadBuffer & reader = b; + IReadBuffer &reader = b; Real32 result[4]; - std::fill(result, result+4, (Real32)0); - for (auto & elem : result) - { - ASSERT_TRUE(elem== 0); + std::fill(result, result + 4, (Real32)0); + for (auto &elem : result) { + ASSERT_TRUE(elem == 0); } - + reader.read((Real32 *)result, 3); - for (UInt32 i = 0; i < 3; ++i) - { - ASSERT_TRUE(result[i] == i+1.5); + for (UInt32 i = 0; i < 3; ++i) { + ASSERT_TRUE(result[i] == i + 1.5); } Real32 val = 0; @@ -404,4 +419,3 @@ TEST(BufferTest, testArrayMethods) ASSERT_TRUE(nearlyEqual(val, Real32(444.555))); } } - diff --git a/src/test/unit/ntypes/CollectionTest.cpp b/src/test/unit/ntypes/CollectionTest.cpp index 00c6a13e50..d546117b0f 100644 --- a/src/test/unit/ntypes/CollectionTest.cpp +++ b/src/test/unit/ntypes/CollectionTest.cpp @@ -24,42 +24,34 @@ * Implementation of Collection test */ -#include -#include #include #include +#include +#include // Collection implementation needed for explicit instantiation #include using namespace nupic; -struct CollectionTest : public ::testing::Test -{ - struct Item - { +struct CollectionTest : public ::testing::Test { + struct Item { int x; - Item() : x(-1) - { - } + Item() : x(-1) {} - Item(int x) : x(x) - { - } + Item(int x) : x(x) {} }; }; namespace nupic { - // The Collection class must be explicitly instantiated. - template class Collection; - template class Collection; - template class Collection; -} - +// The Collection class must be explicitly instantiated. 
+template class Collection; +template class Collection; +template class Collection; +} // namespace nupic -TEST_F(CollectionTest, testEmptyCollection) -{ +TEST_F(CollectionTest, testEmptyCollection) { Collection c; ASSERT_TRUE(c.getCount() == 0); ASSERT_TRUE(c.contains("blah") == false); @@ -67,8 +59,7 @@ TEST_F(CollectionTest, testEmptyCollection) ASSERT_ANY_THROW(c.getByName("blah")); } -TEST_F(CollectionTest, testCollectionWith_1_Item) -{ +TEST_F(CollectionTest, testCollectionWith_1_Item) { auto p = new Item(5); Collection c; ASSERT_TRUE(c.contains("x") == false); @@ -77,15 +68,14 @@ TEST_F(CollectionTest, testCollectionWith_1_Item) ASSERT_TRUE(c.getCount() == 1); ASSERT_TRUE(c.getByIndex(0).second->x == 5); ASSERT_TRUE(c.getByName("x")->x == 5); - + ASSERT_ANY_THROW(c.getByIndex(1)); ASSERT_ANY_THROW(c.getByName("blah")); delete p; } -TEST_F(CollectionTest, testCollectionWith_2_Items) -{ +TEST_F(CollectionTest, testCollectionWith_2_Items) { Collection c; c.add("x1", Item(1)); c.add("x2", Item(2)); @@ -98,20 +88,17 @@ TEST_F(CollectionTest, testCollectionWith_2_Items) ASSERT_TRUE(c.contains("no such item") == false); ASSERT_TRUE(c.contains("x1") == true); - ASSERT_TRUE(c.contains("x2") == true); + ASSERT_TRUE(c.contains("x2") == true); ASSERT_TRUE(c.getByName("x1").x == 1); ASSERT_TRUE(c.getByName("x2").x == 2); - + ASSERT_ANY_THROW(c.getByIndex(2)); - ASSERT_ANY_THROW(c.getByName("blah")); + ASSERT_ANY_THROW(c.getByName("blah")); } - -TEST_F(CollectionTest, testCollectionWith_137_Items) -{ +TEST_F(CollectionTest, testCollectionWith_137_Items) { Collection c; - for (int i = 0; i < 137; ++i) - { + for (int i = 0; i < 137; ++i) { std::stringstream ss; ss << i; c.add(ss.str(), i); @@ -119,17 +106,15 @@ TEST_F(CollectionTest, testCollectionWith_137_Items) ASSERT_TRUE(c.getCount() == 137); - for (int i = 0; i < 137; ++i) - { + for (int i = 0; i < 137; ++i) { ASSERT_TRUE(c.getByIndex(i).second == i); } ASSERT_ANY_THROW(c.getByIndex(137)); - ASSERT_ANY_THROW(c.getByName("blah")); + ASSERT_ANY_THROW(c.getByName("blah")); } -TEST_F(CollectionTest, testCollectionAddRemove) -{ +TEST_F(CollectionTest, testCollectionAddRemove) { Collection c; c.add("0", 0); c.add("1", 1); @@ -139,7 +124,7 @@ TEST_F(CollectionTest, testCollectionAddRemove) ASSERT_TRUE(c.contains("1")); ASSERT_TRUE(c.contains("2")); ASSERT_TRUE(!c.contains("3")); - + ASSERT_ANY_THROW(c.add("0", 0)); ASSERT_ANY_THROW(c.add("1", 1)); ASSERT_ANY_THROW(c.add("2", 2)); @@ -160,7 +145,7 @@ TEST_F(CollectionTest, testCollectionAddRemove) c.remove("1"); // c is now 0, 2 ASSERT_ANY_THROW(c.remove("1")); - + ASSERT_TRUE(c.getCount() == 2); ASSERT_TRUE(c.contains("0")); ASSERT_TRUE(!c.contains("1")); @@ -169,7 +154,7 @@ TEST_F(CollectionTest, testCollectionAddRemove) ASSERT_EQ(0, c.getByIndex(0).second); // item "2" has shifted into position 1 ASSERT_EQ(2, c.getByIndex(1).second); - + // should append to end of collection c.add("1", 1); // c is now 0, 2, 1 @@ -203,5 +188,4 @@ TEST_F(CollectionTest, testCollectionAddRemove) // c is now empty ASSERT_TRUE(c.getCount() == 0); ASSERT_TRUE(!c.contains("2")); - } diff --git a/src/test/unit/ntypes/DimensionsTest.cpp b/src/test/unit/ntypes/DimensionsTest.cpp index 7739e1ee74..2e6b1b3471 100644 --- a/src/test/unit/ntypes/DimensionsTest.cpp +++ b/src/test/unit/ntypes/DimensionsTest.cpp @@ -24,16 +24,14 @@ * Implementation of BasicType test */ -#include #include +#include using namespace nupic; -class DimensionsTest : public ::testing::Test -{ +class DimensionsTest : public ::testing::Test { public: - 
DimensionsTest() - { + DimensionsTest() { zero.push_back(0); one_two.push_back(1); @@ -43,19 +41,17 @@ class DimensionsTest : public ::testing::Test three_four.push_back(4); } - Coordinate zero; // [0]; - Coordinate one_two; // [1,2] + Coordinate zero; // [0]; + Coordinate one_two; // [1,2] Coordinate three_four; // [3,4] - //internal helper method - static std::string vecToString(std::vector vec) - { + // internal helper method + static std::string vecToString(std::vector vec) { std::stringstream ss; ss << "["; - for (size_t i = 0; i < vec.size(); i++) - { + for (size_t i = 0; i < vec.size(); i++) { ss << vec[i]; - if (i != vec.size()-1) + if (i != vec.size() - 1) ss << " "; } ss << "]"; @@ -63,9 +59,7 @@ class DimensionsTest : public ::testing::Test } }; - -TEST_F(DimensionsTest, EmptyDimensions) -{ +TEST_F(DimensionsTest, EmptyDimensions) { // empty dimensions (unspecified) Dimensions d; ASSERT_TRUE(d.isUnspecified()); @@ -80,8 +74,7 @@ TEST_F(DimensionsTest, EmptyDimensions) ASSERT_EQ((unsigned int)0, d.getDimensionCount()); } -TEST_F(DimensionsTest, DontCareDimensions) -{ +TEST_F(DimensionsTest, DontCareDimensions) { // dontcare dimensions [0] Dimensions d; d.push_back(0); @@ -95,8 +88,7 @@ TEST_F(DimensionsTest, DontCareDimensions) ASSERT_EQ((unsigned int)1, d.getDimensionCount()); } -TEST_F(DimensionsTest, InvalidDimensions) -{ +TEST_F(DimensionsTest, InvalidDimensions) { // invalid dimensions Dimensions d; d.push_back(1); @@ -113,8 +105,7 @@ TEST_F(DimensionsTest, InvalidDimensions) ASSERT_EQ((unsigned int)2, d.getDimensionCount()); } -TEST_F(DimensionsTest, ValidDimensions) -{ +TEST_F(DimensionsTest, ValidDimensions) { // valid dimensions [2,3] // two rows, three columns Dimensions d; @@ -132,33 +123,29 @@ TEST_F(DimensionsTest, ValidDimensions) ASSERT_EQ((unsigned int)2, d.getDimensionCount()); } -TEST_F(DimensionsTest, Check2DXMajor) -{ - //check a two dimensional matrix for proper x-major ordering +TEST_F(DimensionsTest, Check2DXMajor) { + // check a two dimensional matrix for proper x-major ordering std::vector x; x.push_back(4); x.push_back(5); Dimensions d(x); size_t testDim1 = 4; size_t testDim2 = 5; - for(size_t i = 0; i < testDim1; i++) - { - for(size_t j = 0; j < testDim2; j++) - { + for (size_t i = 0; i < testDim1; i++) { + for (size_t j = 0; j < testDim2; j++) { Coordinate testCoordinate; testCoordinate.push_back(i); testCoordinate.push_back(j); - ASSERT_EQ(i+j*testDim1, d.getIndex(testCoordinate)); + ASSERT_EQ(i + j * testDim1, d.getIndex(testCoordinate)); ASSERT_EQ(vecToString(testCoordinate), - vecToString(d.getCoordinate(i+j*testDim1))); + vecToString(d.getCoordinate(i + j * testDim1))); } } } -TEST_F(DimensionsTest, Check3DXMajor) -{ - //check a three dimensional matrix for proper x-major ordering +TEST_F(DimensionsTest, Check3DXMajor) { + // check a three dimensional matrix for proper x-major ordering std::vector x; x.push_back(3); x.push_back(4); @@ -167,32 +154,26 @@ TEST_F(DimensionsTest, Check3DXMajor) size_t testDim1 = 3; size_t testDim2 = 4; size_t testDim3 = 5; - for(size_t i = 0; i < testDim1; i++) - { - for(size_t j = 0; j < testDim2; j++) - { - for(size_t k = 0; k < testDim3; k++) - { + for (size_t i = 0; i < testDim1; i++) { + for (size_t j = 0; j < testDim2; j++) { + for (size_t k = 0; k < testDim3; k++) { Coordinate testCoordinate; testCoordinate.push_back(i); testCoordinate.push_back(j); testCoordinate.push_back(k); - ASSERT_EQ(i + - j*testDim1 + - k*testDim1*testDim2, d.getIndex(testCoordinate)); + ASSERT_EQ(i + j * testDim1 + k * testDim1 * 
testDim2, + d.getIndex(testCoordinate)); ASSERT_EQ(vecToString(testCoordinate), - vecToString(d.getCoordinate(i + - j*testDim1 + - k*testDim1*testDim2))); + vecToString(d.getCoordinate(i + j * testDim1 + + k * testDim1 * testDim2))); } } } } -TEST_F(DimensionsTest, AlternateConstructor) -{ +TEST_F(DimensionsTest, AlternateConstructor) { // alternate constructor std::vector x; x.push_back(2); @@ -201,7 +182,7 @@ TEST_F(DimensionsTest, AlternateConstructor) ASSERT_TRUE(!d.isUnspecified()); ASSERT_TRUE(!d.isDontcare()); ASSERT_TRUE(d.isValid()); - + ASSERT_EQ((unsigned int)2, d.getDimension(0)); ASSERT_EQ((unsigned int)5, d.getDimension(1)); ASSERT_ANY_THROW(d.getDimension(2)); diff --git a/src/test/unit/ntypes/MemParserTest.cpp b/src/test/unit/ntypes/MemParserTest.cpp index 183f6b18c4..14467d394b 100644 --- a/src/test/unit/ntypes/MemParserTest.cpp +++ b/src/test/unit/ntypes/MemParserTest.cpp @@ -22,140 +22,120 @@ /** @file * Notes - */ + */ #include +#include #include #include -#include #include using namespace nupic; - // ------------------------------------------------------- // Test using get methods // ------------------------------------------------------- -TEST(MemParserTest, getMethods) -{ - std::stringstream ss; - +TEST(MemParserTest, getMethods) { + std::stringstream ss; + // Write one of each type to the stream unsigned long a = 10; long b = -20; double c = 1.5; float d = 1.6f; std::string e = "hello"; - - ss << a << " " - << b << " " - << c << " " - << d << " " - << e << " "; - - // Read back + + ss << a << " " << b << " " << c << " " << d << " " << e << " "; + + // Read back MemParser in(ss, (UInt32)ss.str().size()); - + unsigned long test_a = 0; in.get(test_a); ASSERT_EQ(a, test_a) << "get ulong"; - + long test_b = 0; in.get(test_b); ASSERT_EQ(b, test_b) << "get long"; - + double test_c = 0; in.get(test_c); ASSERT_EQ(c, test_c) << "get double"; - + float test_d = 0; in.get(test_d); ASSERT_EQ(d, test_d) << "get float"; - + std::string test_e = ""; in.get(test_e); ASSERT_EQ(e, test_e) << "get string"; - // Test EOF ASSERT_ANY_THROW(in.get(test_e)); } - // ------------------------------------------------------- // Test passing in -1 for the size to read in entire stream // ------------------------------------------------------- -TEST(MemParserTest, PassInNegativeOne) -{ - std::stringstream ss; - +TEST(MemParserTest, PassInNegativeOne) { + std::stringstream ss; + // Write one of each type to the stream unsigned long a = 10; long b = -20; double c = 1.5; float d = 1.6f; std::string e = "hello"; - - ss << a << " " - << b << " " - << c << " " - << d << " " - << e << " "; - - // Read back + + ss << a << " " << b << " " << c << " " << d << " " << e << " "; + + // Read back MemParser in(ss); - + unsigned long test_a = 0; in.get(test_a); ASSERT_EQ(a, test_a) << "get ulong b"; - + long test_b = 0; in.get(test_b); ASSERT_EQ(b, test_b) << "get long b"; - + double test_c = 0; in.get(test_c); ASSERT_EQ(c, test_c) << "get double b"; - + float test_d = 0; in.get(test_d); ASSERT_EQ(d, test_d) << "get float b"; - + std::string test_e = ""; in.get(test_e); ASSERT_EQ(e, test_e) << "get string b"; - // Test EOF ASSERT_ANY_THROW(in.get(test_e)); } - // ------------------------------------------------------- // Test using >> operator // ------------------------------------------------------- -TEST(MemParserTest, RightShiftOperator) -{ - std::stringstream ss; - +TEST(MemParserTest, RightShiftOperator) { + std::stringstream ss; + // Write one of each type to the stream unsigned long a = 10; long b = -20; 
double c = 1.5; float d = 1.6f; std::string e = "hello"; - - ss << a << " " - << b << " " - << c << " " - << d << " " - << e << " "; - - // Read back + + ss << a << " " << b << " " << c << " " << d << " " << e << " "; + + // Read back MemParser in(ss, (UInt32)ss.str().size()); - + unsigned long test_a = 0; long test_b = 0; double test_c = 0; @@ -167,25 +147,22 @@ TEST(MemParserTest, RightShiftOperator) ASSERT_EQ(c, test_c) << ">> double"; ASSERT_EQ(d, test_d) << ">> float"; ASSERT_EQ(e, test_e) << ">> string"; - // Test EOF ASSERT_ANY_THROW(in >> test_e); } - // ------------------------------------------------------- // Test reading trying to read an int when we have a string // ------------------------------------------------------- -TEST(MemParserTest, ReadIntWhenStrig) -{ - std::stringstream ss; +TEST(MemParserTest, ReadIntWhenStrig) { + std::stringstream ss; ss << "hello"; - - // Read back + + // Read back MemParser in(ss, (UInt32)ss.str().size()); - + // Test EOF - long v; + long v; ASSERT_ANY_THROW(in.get(v)); -} +} diff --git a/src/test/unit/ntypes/MemStreamTest.cpp b/src/test/unit/ntypes/MemStreamTest.cpp index f948b53450..db2c2fb0c4 100644 --- a/src/test/unit/ntypes/MemStreamTest.cpp +++ b/src/test/unit/ntypes/MemStreamTest.cpp @@ -27,133 +27,117 @@ #include #include +#include +#include #include #include -#include -#include using namespace nupic; - -static size_t memLimitsTest(size_t max) -{ +static size_t memLimitsTest(size_t max) { OMemStream ms; // Create a large string to dump the stream - size_t chunkSize = 0x1000000; // 16 MByte - std::string test(chunkSize, 'M'); + size_t chunkSize = 0x1000000; // 16 MByte + std::string test(chunkSize, 'M'); /* - std::string test2 (0x10000000, '1'); - std::string test3 (0x10000000, '2'); - std::string test4 (0x10000000, '3'); - std::string test5 (0x10000000, '4'); - std::string test6 (0x10000000, '5'); - std::string test7 (0x10000000, '6'); - std::string test8 (0x10000000, '7'); - std::string test9 (0x10000000, '8'); + std::string test2 (0x10000000, '1'); + std::string test3 (0x10000000, '2'); + std::string test4 (0x10000000, '3'); + std::string test5 (0x10000000, '4'); + std::string test6 (0x10000000, '5'); + std::string test7 (0x10000000, '6'); + std::string test8 (0x10000000, '7'); + std::string test9 (0x10000000, '8'); */ - size_t count = 1; while (count * chunkSize <= max) { - //std::cout << hex << "0x" << count << "."; - //std::cout.flush(); + // std::cout << hex << "0x" << count << "."; + // std::cout.flush(); try { ms << test; - } catch (std::exception& /* unused */) { - NTA_DEBUG << "Exceeded memory limit at " << std::hex << "0x" << count * chunkSize << std::dec - << " bytes."; + } catch (std::exception & /* unused */) { + NTA_DEBUG << "Exceeded memory limit at " << std::hex << "0x" + << count * chunkSize << std::dec << " bytes."; break; } count++; } // Return largest size that worked - return (count-1) * chunkSize; + return (count - 1) * chunkSize; } - - - // ------------------------------------------------------- // Test input stream // ------------------------------------------------------- -TEST(MemStreamTest, InputStream) -{ - std::string test("hi there"); +TEST(MemStreamTest, InputStream) { + std::string test("hi there"); - IMemStream ms((char*)(test.data()), test.size()); + IMemStream ms((char *)(test.data()), test.size()); std::stringstream ss(test); - for (int i=0; i<5; i++) - { + for (int i = 0; i < 5; i++) { std::string s1, s2; ms >> s1; ss >> s2; ASSERT_EQ(s2, s1) << "in"; ASSERT_EQ(ss.fail(), ms.fail()) << "in fail"; 
ASSERT_EQ(ss.eof(), ms.eof()) << "in eof"; - } - + } - // Test changing the buffer - std::string test2("bye now"); - ms.str((char*)(test2.data()), test2.size()); + // Test changing the buffer + std::string test2("bye now"); + ms.str((char *)(test2.data()), test2.size()); ms.seekg(0); ms.clear(); std::stringstream ss2(test2); - for (int i=0; i<5; i++) - { + for (int i = 0; i < 5; i++) { std::string s1, s2; ms >> s1; ss2 >> s2; ASSERT_EQ(s2, s1) << "in2"; ASSERT_EQ(ss2.fail(), ms.fail()) << "in2 fail"; ASSERT_EQ(ss2.eof(), ms.eof()) << "in2 eof"; - } + } } - // ------------------------------------------------------- // Test setting the buffer on a default input stream // ------------------------------------------------------- -TEST(MemStreamTest, BufferDefaultInputStream) -{ - std::string test("third test"); +TEST(MemStreamTest, BufferDefaultInputStream) { + std::string test("third test"); IMemStream ms; - ms.str((char*)(test.data()), test.size()); + ms.str((char *)(test.data()), test.size()); std::stringstream ss(test); - for (int i=0; i<5; i++) - { + for (int i = 0; i < 5; i++) { std::string s1, s2; ms >> s1; ss >> s2; ASSERT_EQ(s2, s1) << "in2"; ASSERT_EQ(ss.fail(), ms.fail()) << "in2 fail"; ASSERT_EQ(ss.eof(), ms.eof()) << "in2 eof"; - } + } } - // ------------------------------------------------------- // Test output stream // ------------------------------------------------------- -TEST(MemStreamTest, OutputStream) -{ +TEST(MemStreamTest, OutputStream) { OMemStream ms; std::stringstream ss; - for (int i=0; i<500; i++) - { + for (int i = 0; i < 500; i++) { ms << i << " "; ss << i << " "; } - - const char* dataP = ms.str(); + + const char *dataP = ms.str(); size_t size = ms.pcount(); std::string msStr(dataP, size); std::string ssStr = ss.str(); @@ -165,12 +149,11 @@ TEST(MemStreamTest, OutputStream) // ------------------------------------------------------- // Test memory limits // ------------------------------------------------------- -// Set max at 0x10000000 for day to day testing so that test doesn't take too long. -// To determine the actual memory limits, change this max to something very large and -// see where we break. -TEST(MemStreamTest, MemoryLimits) -{ +// Set max at 0x10000000 for day to day testing so that test doesn't take too +// long. To determine the actual memory limits, change this max to something +// very large and see where we break. 
+TEST(MemStreamTest, MemoryLimits) { size_t max = 0x10000000L; size_t sizeLimit = memLimitsTest(max); - ASSERT_EQ(sizeLimit >= max, true) << "maximum stream size"; + ASSERT_EQ(sizeLimit >= max, true) << "maximum stream size"; } diff --git a/src/test/unit/ntypes/NodeSetTest.cpp b/src/test/unit/ntypes/NodeSetTest.cpp index 97d78f12a4..e5191bea8e 100644 --- a/src/test/unit/ntypes/NodeSetTest.cpp +++ b/src/test/unit/ntypes/NodeSetTest.cpp @@ -24,17 +24,15 @@ * Implementation of BasicType test */ -#include // Only required because of issue #802 -#include #include - +#include +#include // Only required because of issue #802 using namespace nupic; -TEST(NodeSetTest, Basic) -{ +TEST(NodeSetTest, Basic) { NodeSet ns(4); - + ASSERT_TRUE(ns.begin() == ns.end()); ns.allOn(); auto i = ns.begin(); @@ -47,10 +45,10 @@ TEST(NodeSetTest, Basic) ASSERT_TRUE(*i == 3); ++i; ASSERT_TRUE(i == ns.end()); - + ns.allOff(); ASSERT_TRUE(ns.begin() == ns.end()); - + ns.add(1); ns.add(3); i = ns.begin(); @@ -69,9 +67,9 @@ TEST(NodeSetTest, Basic) ASSERT_TRUE(*i == 4); ++i; ASSERT_TRUE(i == ns.end()); - + ASSERT_ANY_THROW(ns.add(5)); - + ns.remove(3); i = ns.begin(); ASSERT_TRUE(*i == 1); @@ -88,5 +86,4 @@ TEST(NodeSetTest, Basic) ASSERT_TRUE(*i == 4); ++i; ASSERT_TRUE(i == ns.end()); - } diff --git a/src/test/unit/ntypes/ScalarTest.cpp b/src/test/unit/ntypes/ScalarTest.cpp index 32f48557c6..f2c4a3f5b0 100644 --- a/src/test/unit/ntypes/ScalarTest.cpp +++ b/src/test/unit/ntypes/ScalarTest.cpp @@ -24,24 +24,23 @@ * Implementation of Scalar test */ -#include #include +#include using namespace nupic; -TEST(ScalarTest, All) -{ +TEST(ScalarTest, All) { Scalar a(NTA_BasicType_UInt16); - //Test UInt16 + // Test UInt16 a = Scalar(NTA_BasicType_UInt16); ASSERT_ANY_THROW(a.getValue()); ASSERT_EQ((UInt16)0, a.getValue()); ASSERT_EQ(NTA_BasicType_UInt16, a.getType()); a.value.uint16 = 10; ASSERT_EQ((UInt16)10, a.getValue()); - - //Test UInt32 + + // Test UInt32 a = Scalar(NTA_BasicType_UInt32); ASSERT_ANY_THROW(a.getValue()); ASSERT_EQ((UInt32)0, a.getValue()); @@ -49,15 +48,15 @@ TEST(ScalarTest, All) a.value.uint32 = 10; ASSERT_EQ((UInt32)10, a.getValue()); - //Test UInt64 + // Test UInt64 a = Scalar(NTA_BasicType_UInt64); ASSERT_ANY_THROW(a.getValue()); ASSERT_EQ((UInt64)0, a.getValue()); ASSERT_EQ(NTA_BasicType_UInt64, a.getType()); a.value.uint64 = 10; ASSERT_EQ((UInt64)10, a.getValue()); - - //Test Int16 + + // Test Int16 a = Scalar(NTA_BasicType_Int16); ASSERT_ANY_THROW(a.getValue()); ASSERT_EQ((Int16)0, a.getValue()); @@ -65,7 +64,7 @@ TEST(ScalarTest, All) a.value.int16 = 10; ASSERT_EQ((Int16)10, a.getValue()); - //Test Int32 + // Test Int32 a = Scalar(NTA_BasicType_Int32); ASSERT_ANY_THROW(a.getValue()); ASSERT_EQ((Int32)0, a.getValue()); @@ -73,7 +72,7 @@ TEST(ScalarTest, All) a.value.int32 = 10; ASSERT_EQ((Int32)10, a.getValue()); - //Test Int64 + // Test Int64 a = Scalar(NTA_BasicType_Int64); ASSERT_ANY_THROW(a.getValue()); ASSERT_EQ((Int64)0, a.getValue()); @@ -81,7 +80,7 @@ TEST(ScalarTest, All) a.value.int64 = 10; ASSERT_EQ((Int64)10, a.getValue()); - //Test Real32 + // Test Real32 a = Scalar(NTA_BasicType_Real32); ASSERT_ANY_THROW(a.getValue()); ASSERT_EQ((Real32)0, a.getValue()); @@ -89,7 +88,7 @@ TEST(ScalarTest, All) a.value.real32 = 10; ASSERT_EQ((Real32)10, a.getValue()); - //Test Real64 + // Test Real64 a = Scalar(NTA_BasicType_Real64); ASSERT_ANY_THROW(a.getValue()); ASSERT_EQ((Real64)0, a.getValue()); @@ -97,20 +96,20 @@ TEST(ScalarTest, All) a.value.real64 = 10; ASSERT_EQ((Real64)10, a.getValue()); - 
//Test Handle + // Test Handle a = Scalar(NTA_BasicType_Handle); ASSERT_ANY_THROW(a.getValue()); - ASSERT_EQ((Handle)nullptr, a.getValue()); + ASSERT_EQ((Handle) nullptr, a.getValue()); ASSERT_EQ(NTA_BasicType_Handle, a.getType()); int x = 10; a.value.handle = &x; - int* p = (int*)(a.getValue()); + int *p = (int *)(a.getValue()); ASSERT_EQ(&x, a.getValue()); ASSERT_EQ(x, *p); (*p)++; ASSERT_EQ(11, *p); - - //Test Byte + + // Test Byte a = Scalar(NTA_BasicType_Byte); ASSERT_ANY_THROW(a.getValue()); ASSERT_EQ((Byte)0, a.getValue()); @@ -120,7 +119,7 @@ TEST(ScalarTest, All) a.value.byte++; ASSERT_EQ('b', a.getValue()); - //Test Bool + // Test Bool a = Scalar(NTA_BasicType_Bool); ASSERT_ANY_THROW(a.getValue()); ASSERT_EQ(false, a.getValue()); diff --git a/src/test/unit/ntypes/ValueTest.cpp b/src/test/unit/ntypes/ValueTest.cpp index 5e10f776df..d5f679db24 100644 --- a/src/test/unit/ntypes/ValueTest.cpp +++ b/src/test/unit/ntypes/ValueTest.cpp @@ -24,52 +24,48 @@ * Implementation of Value test */ -#include #include +#include using namespace nupic; -TEST(ValueTest, Scalar) -{ - boost::shared_ptr s(new Scalar(NTA_BasicType_Int32)); +TEST(ValueTest, Scalar) { + boost::shared_ptr s(new Scalar(NTA_BasicType_Int32)); s->value.int32 = 10; Value v(s); ASSERT_TRUE(v.isScalar()); - ASSERT_TRUE(! v.isString()); - ASSERT_TRUE(! v.isArray()); + ASSERT_TRUE(!v.isString()); + ASSERT_TRUE(!v.isArray()); ASSERT_EQ(Value::scalarCategory, v.getCategory()); ASSERT_EQ(NTA_BasicType_Int32, v.getType()); - + boost::shared_ptr s1 = v.getScalar(); ASSERT_TRUE(s1 == s); - + ASSERT_ANY_THROW(v.getArray()); ASSERT_ANY_THROW(v.getString()); - + EXPECT_STREQ("Scalar of type Int32", v.getDescription().c_str()); - - + Int32 x = v.getScalarT(); ASSERT_EQ(10, x); - - ASSERT_ANY_THROW(v.getScalarT()); + ASSERT_ANY_THROW(v.getScalarT()); } -TEST(ValueTest, Array) -{ - boost::shared_ptr s(new Array(NTA_BasicType_Int32)); +TEST(ValueTest, Array) { + boost::shared_ptr s(new Array(NTA_BasicType_Int32)); s->allocateBuffer(10); Value v(s); ASSERT_TRUE(v.isArray()); - ASSERT_TRUE(! v.isString()); - ASSERT_TRUE(! v.isScalar()); + ASSERT_TRUE(!v.isString()); + ASSERT_TRUE(!v.isScalar()); ASSERT_EQ(Value::arrayCategory, v.getCategory()); ASSERT_EQ(NTA_BasicType_Int32, v.getType()); - + boost::shared_ptr s1 = v.getArray(); ASSERT_TRUE(s1 == s); - + ASSERT_ANY_THROW(v.getScalar()); ASSERT_ANY_THROW(v.getString()); ASSERT_ANY_THROW(v.getScalarT()); @@ -77,33 +73,31 @@ TEST(ValueTest, Array) EXPECT_STREQ("Array of type Int32", v.getDescription().c_str()); } -TEST(ValueTest, String) -{ +TEST(ValueTest, String) { boost::shared_ptr s(new std::string("hello world")); Value v(s); - ASSERT_TRUE(! v.isArray()); + ASSERT_TRUE(!v.isArray()); ASSERT_TRUE(v.isString()); - ASSERT_TRUE(! 
v.isScalar()); + ASSERT_TRUE(!v.isScalar()); ASSERT_EQ(Value::stringCategory, v.getCategory()); ASSERT_EQ(NTA_BasicType_Byte, v.getType()); - + boost::shared_ptr s1 = v.getString(); EXPECT_STREQ("hello world", s1->c_str()); - + ASSERT_ANY_THROW(v.getScalar()); ASSERT_ANY_THROW(v.getArray()); ASSERT_ANY_THROW(v.getScalarT()); - + EXPECT_STREQ("string (hello world)", v.getDescription().c_str()); } -TEST(ValueTest, ValueMap) -{ +TEST(ValueTest, ValueMap) { boost::shared_ptr s(new Scalar(NTA_BasicType_Int32)); s->value.int32 = 10; boost::shared_ptr a(new Array(NTA_BasicType_Real32)); boost::shared_ptr str(new std::string("hello world")); - + ValueMap vm; vm.add("scalar", s); vm.add("array", a); @@ -116,17 +110,17 @@ TEST(ValueTest, ValueMap) ASSERT_TRUE(!vm.contains("foo")); ASSERT_TRUE(!vm.contains("scalar2")); ASSERT_TRUE(!vm.contains("xscalar")); - + boost::shared_ptr s1 = vm.getScalar("scalar"); ASSERT_TRUE(s1 == s); - + boost::shared_ptr a1 = vm.getArray("array"); ASSERT_TRUE(a1 == a); - + boost::shared_ptr def(new Scalar(NTA_BasicType_Int32)); Int32 x = vm.getScalarT("scalar", (Int32)20); ASSERT_EQ((Int32)10, x); - + x = vm.getScalarT("scalar2", (Int32)20); ASSERT_EQ((Int32)20, x); diff --git a/src/test/unit/os/DirectoryTest.cpp b/src/test/unit/os/DirectoryTest.cpp index d2fd807fda..d101e19966 100644 --- a/src/test/unit/os/DirectoryTest.cpp +++ b/src/test/unit/os/DirectoryTest.cpp @@ -24,46 +24,41 @@ * Implementation for Directory test */ - -#include -#include -#include -#include #include #include +#include +#include +#include +#include #include - #if defined(NTA_OS_WINDOWS) - #include +#include #else - #include +#include #endif #include // sort using namespace std; using namespace nupic; -static std::string getCurrDir() -{ - char buff[APR_PATH_MAX+1]; +static std::string getCurrDir() { + char buff[APR_PATH_MAX + 1]; #if defined(NTA_OS_WINDOWS) DWORD res = ::GetCurrentDirectoryA(APR_PATH_MAX, (LPSTR)buff); NTA_CHECK(res > 0) << OS::getErrorMessage(); #else - char * s = ::getcwd(buff, APR_PATH_MAX); + char *s = ::getcwd(buff, APR_PATH_MAX); NTA_CHECK(s != nullptr) << OS::getErrorMessage(); #endif - return buff; + return buff; } - std::string sep(Path::sep); -TEST(DirectoryTest, Existence) -{ - +TEST(DirectoryTest, Existence) { + ASSERT_TRUE(!Directory::exists("No such dir")); if (Directory::exists("dir_0")) Directory::removeTree("dir_0"); @@ -73,36 +68,33 @@ TEST(DirectoryTest, Existence) Directory::removeTree("dir_0"); } -TEST(DirectoryTest, setCWD) -{ +TEST(DirectoryTest, setCWD) { Directory::create("dir_1"); - + std::string baseDir = Path::makeAbsolute(getCurrDir()); Directory::setCWD("dir_1"); - std::string cwd1 = Path::makeAbsolute(getCurrDir()); + std::string cwd1 = Path::makeAbsolute(getCurrDir()); - std::string cwd2 = Path::makeAbsolute(baseDir + Path::sep + std::string("dir_1")); + std::string cwd2 = + Path::makeAbsolute(baseDir + Path::sep + std::string("dir_1")); ASSERT_EQ(cwd1, cwd2); - + Directory::setCWD(baseDir); ASSERT_EQ(baseDir, getCurrDir()); Directory::removeTree("dir_1"); } -TEST(DirectoryTest, getCWD) -{ - ASSERT_EQ(getCurrDir(), Directory::getCWD()); -} +TEST(DirectoryTest, getCWD) { ASSERT_EQ(getCurrDir(), Directory::getCWD()); } -TEST(DirectoryTest, RemoveTreeAndCreate) -{ +TEST(DirectoryTest, RemoveTreeAndCreate) { std::string p = Path::makeAbsolute(std::string("someDir")); std::string d = Path::join(p, "someSubDir"); if (Path::exists(p)) Directory::removeTree(p); ASSERT_TRUE(!Path::exists(p)); - ASSERT_THROW(Directory::create(d), exception); // nonrecursive 
create should fail + ASSERT_THROW(Directory::create(d), + exception); // nonrecursive create should fail Directory::create(d, false, true /* recursive */); ASSERT_TRUE(Path::exists(d)); Directory::removeTree(p); @@ -110,13 +102,11 @@ TEST(DirectoryTest, RemoveTreeAndCreate) ASSERT_TRUE(!Path::exists(p)); } - -TEST(DirectoryTest, CopyTree) -{ +TEST(DirectoryTest, CopyTree) { std::string p = Path::makeAbsolute(std::string("someDir")); std::string a = Path::join(p, "A"); std::string b = Path::join(p, "B"); - + if (Path::exists(p)) Directory::removeTree(p); ASSERT_TRUE(!Path::exists(p)); @@ -141,7 +131,7 @@ TEST(DirectoryTest, CopyTree) std::string dest = Path::join(a, "B", "1.txt"); ASSERT_TRUE(!Directory::exists(Path::normalize(Path::join(a, "B")))); - + Directory::copyTree(b, a); ASSERT_TRUE(Directory::exists(Path::normalize(Path::join(a, "B")))); @@ -157,21 +147,19 @@ TEST(DirectoryTest, CopyTree) Directory::removeTree(p); ASSERT_TRUE(!Path::exists(p)); - } -TEST(DirectoryTest, Iterator) -{ +TEST(DirectoryTest, Iterator) { if (Directory::exists("A")) Directory::removeTree("A"); Directory::create("A"); Directory::create("A" + sep + "B"); Directory::create("A" + sep + "C"); - + { Directory::Iterator di("A"); Directory::Entry entry; - Directory::Entry * e = nullptr; + Directory::Entry *e = nullptr; vector subdirs; e = di.next(entry); @@ -179,31 +167,30 @@ TEST(DirectoryTest, Iterator) ASSERT_TRUE(e->type == Directory::Entry::DIRECTORY); subdirs.push_back(e->path); string first = e->path; - + e = di.next(entry); ASSERT_TRUE(e != nullptr); ASSERT_TRUE(e->type == Directory::Entry::DIRECTORY); subdirs.push_back(e->path); - + e = di.next(entry); ASSERT_TRUE(e == nullptr); - + // Get around different directory iteration orders on different platforms std::sort(subdirs.begin(), subdirs.end()); ASSERT_TRUE(subdirs[0] == "B"); ASSERT_TRUE(subdirs[1] == "C"); - + // check that after reset first entry is returned again di.reset(); e = di.next(entry); ASSERT_TRUE(e != nullptr); ASSERT_TRUE(e->type == Directory::Entry::DIRECTORY); ASSERT_TRUE(e->path == first); - } + } // Cleanup test dirs ASSERT_TRUE(Path::exists("A")); Directory::removeTree("A"); ASSERT_TRUE(!Path::exists("A")); } - diff --git a/src/test/unit/os/EnvTest.cpp b/src/test/unit/os/EnvTest.cpp index c6b1038685..8a71daa867 100644 --- a/src/test/unit/os/EnvTest.cpp +++ b/src/test/unit/os/EnvTest.cpp @@ -24,62 +24,58 @@ * @file */ - -#include #include +#include using namespace nupic; - -TEST(EnvTest, Basic) -{ +TEST(EnvTest, Basic) { std::string name; std::string value; bool result; - + // get value that is not set value = "DONTCHANGEME"; result = Env::get("NOTDEFINED", value); ASSERT_FALSE(result); EXPECT_STREQ("DONTCHANGEME", value.c_str()); - + // get value that should be set value = ""; result = Env::get("PATH", value); ASSERT_TRUE(result); ASSERT_TRUE(value.length() > 0) << "get path value"; - + // set a value name = "myname"; value = "myvalue"; Env::set(name, value); - + // retrieve it value = ""; result = Env::get(name, value); ASSERT_TRUE(result); EXPECT_STREQ("myvalue", value.c_str()); - + // set it to something different value = "mynewvalue"; Env::set(name, value); - + // retrieve the new value result = Env::get(name, value); ASSERT_TRUE(result); EXPECT_STREQ("mynewvalue", value.c_str()); - + // delete the value value = "DONTCHANGEME"; Env::unset(name); result = Env::get(name, value); ASSERT_FALSE(result); EXPECT_STREQ("DONTCHANGEME", value.c_str()); - + // delete a value that is not set - // APR response is not documented. 
Will see a warning if - // APR reports an error. - // Is there any way to do an actual test here? + // APR response is not documented. Will see a warning if + // APR reports an error. + // Is there any way to do an actual test here? Env::unset(name); } - diff --git a/src/test/unit/os/OSTest.cpp b/src/test/unit/os/OSTest.cpp index e722a1315d..0d219e292c 100644 --- a/src/test/unit/os/OSTest.cpp +++ b/src/test/unit/os/OSTest.cpp @@ -24,28 +24,26 @@ * @file */ -#include +#include +#include #include +#include #include -#include -#include using namespace nupic; - -TEST(OSTest, Basic) -{ +TEST(OSTest, Basic) { #if defined(NTA_OS_WINDOWS) #else // save the parts of the environment we'll be changing std::string savedHOME; bool isHomeSet = Env::get("HOME", savedHOME); - + Env::set("HOME", "/home1/myhome"); Env::set("USER", "user1"); Env::set("LOGNAME", "logname1"); - + EXPECT_STREQ("/home1/myhome", OS::getHomeDir().c_str()) << "OS::getHomeDir"; bool caughtException = false; Env::unset("HOME"); @@ -61,14 +59,13 @@ TEST(OSTest, Basic) Env::set("HOME", savedHOME); } - #endif // Test getUserName() { #if defined(NTA_OS_WINDOWS) Env::set("USERNAME", "123"); - ASSERT_TRUE(OS::getUserName() == "123"); + ASSERT_TRUE(OS::getUserName() == "123"); #else // case 1 - USER defined Env::set("USER", "123"); @@ -81,23 +78,22 @@ TEST(OSTest, Basic) // case 3 - USER and LOGNAME not defined Env::unset("LOGNAME"); - + std::stringstream ss(""); ss << getuid(); ASSERT_TRUE(OS::getUserName() == ss.str()); #endif } - // Test getStackTrace() { #if defined(NTA_OS_WINDOWS) // std::string stackTrace = OS::getStackTrace(); -// ASSERT_TRUE(!stackTrace.empty()); +// ASSERT_TRUE(!stackTrace.empty()); // // stackTrace = OS::getStackTrace(); // ASSERT_TRUE(!stackTrace.empty()); -#endif +#endif } // Test executeCommand() @@ -106,7 +102,4 @@ TEST(OSTest, Basic) ASSERT_TRUE(output == "ABCDefg\n"); } - - } - diff --git a/src/test/unit/os/PathTest.cpp b/src/test/unit/os/PathTest.cpp index e939677683..a11724ed45 100644 --- a/src/test/unit/os/PathTest.cpp +++ b/src/test/unit/os/PathTest.cpp @@ -24,17 +24,16 @@ * Implementation for Path test */ +#include #include -#include -#include #include +#include +#include #include -#include using namespace std; using namespace nupic; - class PathTest : public ::testing::Test { public: string testOutputDir_ = Path::makeAbsolute("TestEverything.out"); @@ -42,28 +41,25 @@ class PathTest : public ::testing::Test { PathTest() { // Create if it doesn't exist if (!Path::exists(testOutputDir_)) { - std::cout << "Tester -- creating output directory " << std::string(testOutputDir_) << "\n"; - // will throw if unsuccessful. + std::cout << "Tester -- creating output directory " + << std::string(testOutputDir_) << "\n"; + // will throw if unsuccessful. 
Directory::create(string(testOutputDir_)); - } + } } - string fromTestOutputDir(const string& path) { + string fromTestOutputDir(const string &path) { Path testoutputpath(testOutputDir_); if (path != "") testoutputpath += path; return string(testoutputpath); } - }; // test static exists() -TEST_F(PathTest, exists) -{ -} +TEST_F(PathTest, exists) {} -TEST_F(PathTest, getParent) -{ +TEST_F(PathTest, getParent) { #if defined(NTA_OS_WINDOWS) // no tests defined #else @@ -82,7 +78,7 @@ TEST_F(PathTest, getParent) g = Path::getParent(g); EXPECT_STREQ("/", g.c_str()) << "getParent5"; - + // Parent should normalize first, to avoid parent(a/b/..)->(a/b) g = "/a/b/.."; EXPECT_STREQ("/", Path::getParent(g).c_str()) << "getParent6"; @@ -93,81 +89,73 @@ TEST_F(PathTest, getParent) g = "a"; EXPECT_STREQ(".", Path::getParent(g).c_str()) << "getParent8"; - + // getParent() of a relative directory above us should work g = "../../a"; EXPECT_STREQ("../..", Path::getParent(g).c_str()) << "getParent9"; g = "."; EXPECT_STREQ("..", Path::getParent(g).c_str()) << "getParent10"; - + #endif - std::string x = Path::join("someDir", "X"); x = Path::makeAbsolute(x); std::string y = Path::join(x, "Y"); - std::string parent = Path::getParent(y); ASSERT_TRUE(x == parent); - } // test static getFilename() TEST_F(PathTest, getFilename) -{ -} +{} // test static getBasename() -TEST_F(PathTest, getBasename) -{ +TEST_F(PathTest, getBasename) { #if defined(NTA_OS_WINDOWS) // no tests defined #else EXPECT_STREQ("bar", Path::getBasename("/foo/bar").c_str()) << "basename1"; EXPECT_STREQ("", Path::getBasename("/foo/bar/").c_str()) << "basename2"; EXPECT_STREQ("bar.ext", - Path::getBasename("/this is a long dir / foo$/bar.ext").c_str()) - << "basename3"; + Path::getBasename("/this is a long dir / foo$/bar.ext").c_str()) + << "basename3"; #endif } // test static getExtension() -TEST_F(PathTest, getExtension) -{ +TEST_F(PathTest, getExtension) { string sep(Path::sep); std::string ext = Path::getExtension("abc" + sep + "def.ext"); ASSERT_TRUE(ext == "ext"); } // test static normalize() -TEST_F(PathTest, normalize) -{ +TEST_F(PathTest, normalize) { #if defined(NTA_OS_WINDOWS) // no tests defined #else EXPECT_STREQ("/foo/bar", Path::normalize("//foo/quux/..//bar").c_str()) - << "normalize1"; - EXPECT_STREQ("/foo/contains a lot of spaces", - Path::normalize("///foo/a/b/c/../../d/../../contains a lot of spaces/g.tgz/..").c_str()) - << "normalize2"; - EXPECT_STREQ("../..", Path::normalize("../foo/../..").c_str()) - << "normalize3"; + << "normalize1"; + EXPECT_STREQ( + "/foo/contains a lot of spaces", + Path::normalize( + "///foo/a/b/c/../../d/../../contains a lot of spaces/g.tgz/..") + .c_str()) + << "normalize2"; + EXPECT_STREQ("../..", Path::normalize("../foo/../..").c_str()) + << "normalize3"; EXPECT_STREQ("/", Path::normalize("/../..").c_str()) << "normalize4"; -#endif - +#endif } // test static makeAbsolute() -TEST_F(PathTest, makeAbsolute) -{ -} +TEST_F(PathTest, makeAbsolute) {} // test static split() -TEST_F(PathTest, split) -{ +TEST_F(PathTest, split) { #if defined(NTA_OS_WINDOWS) // no tests defined #else @@ -179,51 +167,41 @@ TEST_F(PathTest, split) ASSERT_EQ(sv[1], "foo") << "split1.2"; ASSERT_EQ(sv[2], "bar") << "split1.3"; } - EXPECT_STREQ("/foo/bar", Path::join(sv.begin(), sv.end()).c_str()) << "split1.4"; + EXPECT_STREQ("/foo/bar", Path::join(sv.begin(), sv.end()).c_str()) + << "split1.4"; sv = Path::split("foo/bar"); ASSERT_EQ(2U, sv.size()) << "split2 size"; - if (sv.size() == 2) - { + if (sv.size() == 2) { ASSERT_EQ(sv[0], 
"foo") << "split2.2"; ASSERT_EQ(sv[1], "bar") << "split2.3"; } EXPECT_STREQ("foo/bar", Path::join(sv.begin(), sv.end()).c_str()) - << "split2.3"; + << "split2.3"; sv = Path::split("foo//bar/"); ASSERT_EQ(2U, sv.size()) << "split3 size"; - if (sv.size() == 2) - { + if (sv.size() == 2) { ASSERT_EQ(sv[0], "foo") << "split3.2"; ASSERT_EQ(sv[1], "bar") << "split3.3"; } EXPECT_STREQ("foo/bar", Path::join(sv.begin(), sv.end()).c_str()) - << "split3.4"; - -#endif - + << "split3.4"; +#endif } // test static join() -TEST_F(PathTest, join) -{ -} +TEST_F(PathTest, join) {} // test static remove() -TEST_F(PathTest, remove) -{ -} +TEST_F(PathTest, remove) {} // test static rename() -TEST_F(PathTest, rename) -{ -} +TEST_F(PathTest, rename) {} // test static copy() -TEST_F(PathTest, copy) -{ +TEST_F(PathTest, copy) { { OFStream f("a.txt"); f << "12345"; @@ -235,7 +213,7 @@ TEST_F(PathTest, copy) f >> s; ASSERT_TRUE(s == "12345"); } - + { if (Path::exists("b.txt")) Path::remove("b.txt"); @@ -247,16 +225,15 @@ TEST_F(PathTest, copy) f >> s; ASSERT_TRUE(s == "12345"); } - + Path::remove("a.txt"); Path::remove("b.txt"); ASSERT_TRUE(!Path::exists("a.txt")); ASSERT_TRUE(!Path::exists("b.txt")); -} +} // test static copy() in temp directory -TEST_F(PathTest, copyInTemp) -{ +TEST_F(PathTest, copyInTemp) { { OFStream f("a.txt"); f << "12345"; @@ -268,7 +245,7 @@ TEST_F(PathTest, copyInTemp) f >> s; ASSERT_TRUE(s == "12345"); } - + string destination = fromTestOutputDir("pathtest_dir"); { destination += "b.txt"; @@ -282,34 +259,31 @@ TEST_F(PathTest, copyInTemp) f >> s; ASSERT_TRUE(s == "12345"); } - + Path::remove("a.txt"); Path::remove(destination); ASSERT_TRUE(!Path::exists("a.txt")); ASSERT_TRUE(!Path::exists(destination)); -} - -//test static isRootdir() -TEST_F(PathTest, isRootDir) -{ } -//test static isAbsolute() -TEST_F(PathTest, isAbsolute) -{ +// test static isRootdir() +TEST_F(PathTest, isRootDir) {} + +// test static isAbsolute() +TEST_F(PathTest, isAbsolute) { #if defined(NTA_OS_WINDOWS) ASSERT_TRUE(Path::isAbsolute("c:")); ASSERT_TRUE(Path::isAbsolute("c:\\")); ASSERT_TRUE(Path::isAbsolute("c:\\foo\\")); - ASSERT_TRUE(Path::isAbsolute("c:\\foo\\bar")); - - ASSERT_TRUE(Path::isAbsolute("\\\\foo")); - ASSERT_TRUE(Path::isAbsolute("\\\\foo\\")); + ASSERT_TRUE(Path::isAbsolute("c:\\foo\\bar")); + + ASSERT_TRUE(Path::isAbsolute("\\\\foo")); + ASSERT_TRUE(Path::isAbsolute("\\\\foo\\")); ASSERT_TRUE(Path::isAbsolute("\\\\foo\\bar")); ASSERT_TRUE(Path::isAbsolute("\\\\foo\\bar\\baz")); - - ASSERT_TRUE(!Path::isAbsolute("foo")); - ASSERT_TRUE(!Path::isAbsolute("foo\\bar")); + + ASSERT_TRUE(!Path::isAbsolute("foo")); + ASSERT_TRUE(!Path::isAbsolute("foo\\bar")); ASSERT_TRUE(!Path::isAbsolute("\\")); ASSERT_TRUE(!Path::isAbsolute("\\\\")); ASSERT_TRUE(!Path::isAbsolute("\\foo")); @@ -317,15 +291,14 @@ TEST_F(PathTest, isAbsolute) ASSERT_TRUE(Path::isAbsolute("/")); ASSERT_TRUE(Path::isAbsolute("/foo")); ASSERT_TRUE(Path::isAbsolute("/foo/")); - ASSERT_TRUE(Path::isAbsolute("/foo/bar")); - - ASSERT_TRUE(!Path::isAbsolute("foo")); - ASSERT_TRUE(!Path::isAbsolute("foo/bar")); -#endif + ASSERT_TRUE(Path::isAbsolute("/foo/bar")); + + ASSERT_TRUE(!Path::isAbsolute("foo")); + ASSERT_TRUE(!Path::isAbsolute("foo/bar")); +#endif } -TEST_F(PathTest, getExecutablePath) -{ +TEST_F(PathTest, getExecutablePath) { // test static getExecutablePath std::string path = Path::getExecutablePath(); std::cout << "Executable path: '" << path << "'\n"; @@ -334,9 +307,9 @@ TEST_F(PathTest, getExecutablePath) std::string basename = 
Path::getBasename(path); #if defined(NTA_OS_WINDOWS) EXPECT_STREQ(basename.c_str(), "unit_tests.exe") - << "basename should be unit_tests"; + << "basename should be unit_tests"; #else EXPECT_STREQ(basename.c_str(), "unit_tests") - << "basename should be unit_tests"; + << "basename should be unit_tests"; #endif -} +} diff --git a/src/test/unit/os/RegexTest.cpp b/src/test/unit/os/RegexTest.cpp index 951d5f675f..c51154b4b7 100644 --- a/src/test/unit/os/RegexTest.cpp +++ b/src/test/unit/os/RegexTest.cpp @@ -24,54 +24,43 @@ * Implementation for Directory test */ - -#include #include +#include using namespace std; using namespace nupic; +TEST(RegexTest, Basic) { -TEST(RegexTest, Basic) -{ - ASSERT_TRUE(regex::match(".*", "")); ASSERT_TRUE(regex::match(".*", "dddddfsdsgregegr")); - ASSERT_TRUE(regex::match("d.*", "d")); + ASSERT_TRUE(regex::match("d.*", "d")); ASSERT_TRUE(regex::match("^d.*", "ddsfffdg")); ASSERT_TRUE(!regex::match("d.*", "")); ASSERT_TRUE(!regex::match("d.*", "a")); ASSERT_TRUE(!regex::match("^d.*", "ad")); ASSERT_TRUE(!regex::match("Sensor", "CategorySensor")); - - - ASSERT_TRUE(regex::match("\\\\", "\\")); - -// ASSERT_TRUE(regex::match("\\w", "a")); -// ASSERT_TRUE(regex::match("\\d", "3")); -// ASSERT_TRUE(regex::match("\\w{3}", "abc")); -// ASSERT_TRUE(regex::match("^\\w{3}$", "abc")); -// ASSERT_TRUE(regex::match("[\\w]{3}", "abc")); - + + ASSERT_TRUE(regex::match("\\\\", "\\")); + + // ASSERT_TRUE(regex::match("\\w", "a")); + // ASSERT_TRUE(regex::match("\\d", "3")); + // ASSERT_TRUE(regex::match("\\w{3}", "abc")); + // ASSERT_TRUE(regex::match("^\\w{3}$", "abc")); + // ASSERT_TRUE(regex::match("[\\w]{3}", "abc")); + ASSERT_TRUE(regex::match("[A-Za-z0-9_]{3}", "abc")); - + // Invalid expression tests (should throw) - try - { + try { ASSERT_TRUE(regex::match("", "")); ASSERT_TRUE(false); + } catch (...) { } - catch (...) - { - } - - try - { + + try { ASSERT_TRUE(regex::match("xyz[", "")); ASSERT_TRUE(false); - } - catch (...) - { + } catch (...) { } } - diff --git a/src/test/unit/os/TimerTest.cpp b/src/test/unit/os/TimerTest.cpp index 431604fd39..f9bb4d166d 100644 --- a/src/test/unit/os/TimerTest.cpp +++ b/src/test/unit/os/TimerTest.cpp @@ -26,18 +26,17 @@ #define SLEEP_MICROSECONDS (100 * 1000) -#include -#include -#include // fabs #include #include +#include // fabs +#include +#include using namespace nupic; -TEST(TimerTest, Basic) -{ -// Tests are minimal because we have no way to run performance-sensitive tests in a controlled -// environment. +TEST(TimerTest, Basic) { + // Tests are minimal because we have no way to run performance-sensitive tests + // in a controlled environment. 
Timer t1; Timer t2(/* startme= */ true); @@ -69,15 +68,14 @@ TEST(TimerTest, Basic) ASSERT_EQ(t1.getStartCount(), 2); } -TEST(TimerTest, Drift) -{ -// Test start/stop delay accumulation +TEST(TimerTest, Drift) { + // Test start/stop delay accumulation Timer t; const UInt EPOCHS = 1000000; // 1M - const UInt EPSILON = 5; // tolerate 5us drift on 1M restarts - for(UInt i=0; i -#include #include +#include +#include using namespace nupic; -class PyHelpersTest : public ::testing::Test -{ - public: - PyHelpersTest() - { +class PyHelpersTest : public ::testing::Test { +public: + PyHelpersTest() { NTA_DEBUG << "Py_Initialize()"; Py_Initialize(); } - ~PyHelpersTest() - { + ~PyHelpersTest() { NTA_DEBUG << "Py_Finalize()"; Py_Finalize(); } }; +TEST_F(PyHelpersTest, pyPtrConstructionNULL) { + PyObject *p = NULL; + EXPECT_THROW(py::Ptr(p, /* allowNULL: */ false), std::exception); -TEST_F(PyHelpersTest, pyPtrConstructionNULL) -{ - PyObject * p = NULL; - EXPECT_THROW(py::Ptr(p, /* allowNULL: */false), std::exception); - - py::Ptr pp1(p, /* allowNULL: */true); + py::Ptr pp1(p, /* allowNULL: */ true); ASSERT_TRUE((PyObject *)pp1 == NULL); ASSERT_TRUE(pp1.isNULL()); } -TEST_F(PyHelpersTest, pyPtrConstructionNonNULL) -{ - PyObject * p = PyTuple_New(1); +TEST_F(PyHelpersTest, pyPtrConstructionNonNULL) { + PyObject *p = PyTuple_New(1); py::Ptr pp2(p); ASSERT_TRUE(!pp2.isNULL()); ASSERT_TRUE((PyObject *)pp2 == p); @@ -67,10 +61,9 @@ TEST_F(PyHelpersTest, pyPtrConstructionNonNULL) ASSERT_TRUE(pp2.isNULL()); Py_DECREF(p); } - -TEST_F(PyHelpersTest, pyPtrConstructionAssign) -{ - PyObject * p = PyTuple_New(1); + +TEST_F(PyHelpersTest, pyPtrConstructionAssign) { + PyObject *p = PyTuple_New(1); ASSERT_TRUE(p->ob_refcnt == 1); py::Ptr pp(NULL, /* allowNULL */ true); ASSERT_TRUE(pp.isNULL()); @@ -83,8 +76,7 @@ TEST_F(PyHelpersTest, pyPtrConstructionAssign) ASSERT_TRUE(p->ob_refcnt == 1); } -TEST_F(PyHelpersTest, pyString) -{ +TEST_F(PyHelpersTest, pyString) { py::String ps1(std::string("123")); ASSERT_TRUE(PyString_Check(ps1) != 0); @@ -106,21 +98,20 @@ TEST_F(PyHelpersTest, pyString) ASSERT_TRUE(std::string(ps2) == expected); ASSERT_TRUE(std::string(ps3) == expected); - PyObject * p = PyString_FromString("777"); + PyObject *p = PyString_FromString("777"); py::String ps4(p); ASSERT_TRUE(std::string(ps4) == std::string("777")); } -TEST_F(PyHelpersTest, pyInt) -{ +TEST_F(PyHelpersTest, pyInt) { py::Int n1(-5); py::Int n2(-6666); py::Int n3(long(0)); py::Int n4(555); py::Int n5(6666); - + ASSERT_TRUE(n1 == -5); - int x = n2; + int x = n2; int expected = -6666; ASSERT_TRUE(x == expected); ASSERT_TRUE(n3 == 0); @@ -130,16 +121,15 @@ TEST_F(PyHelpersTest, pyInt) ASSERT_TRUE(x == expected); } -TEST_F(PyHelpersTest, pyLong) -{ +TEST_F(PyHelpersTest, pyLong) { py::Long n1(-5); py::Long n2(-66666666); py::Long n3(long(0)); py::Long n4(555); py::Long n5(66666666); - + ASSERT_TRUE(n1 == -5); - long x = n2; + long x = n2; long expected = -66666666; ASSERT_TRUE(x == expected); ASSERT_TRUE(n3 == 0); @@ -149,14 +139,13 @@ TEST_F(PyHelpersTest, pyLong) ASSERT_TRUE(x == expected); } -TEST_F(PyHelpersTest, pyUnsignedLong) -{ +TEST_F(PyHelpersTest, pyUnsignedLong) { py::UnsignedLong n1((unsigned long)(-5)); py::UnsignedLong n2((unsigned long)(-66666666)); py::UnsignedLong n3((unsigned long)(0)); py::UnsignedLong n4(555); py::UnsignedLong n5(66666666); - + ASSERT_TRUE(n1 == (unsigned long)(-5)); ASSERT_TRUE(n2 == (unsigned long)(-66666666)); ASSERT_TRUE(n3 == 0); @@ -164,8 +153,7 @@ TEST_F(PyHelpersTest, pyUnsignedLong) ASSERT_TRUE(n5 == 
66666666); } -TEST_F(PyHelpersTest, pyFloat) -{ +TEST_F(PyHelpersTest, pyFloat) { ASSERT_TRUE(py::Float::getMax() == std::numeric_limits::max()); ASSERT_TRUE(py::Float::getMin() == std::numeric_limits::min()); @@ -176,7 +164,7 @@ TEST_F(PyHelpersTest, pyFloat) py::Float n3(333.555); py::Float n4(0.02); py::Float n5("0.02"); - + ASSERT_TRUE(max == py::Float::getMax()); ASSERT_TRUE(min == py::Float::getMin()); ASSERT_TRUE(n1 == -0.5); @@ -186,8 +174,7 @@ TEST_F(PyHelpersTest, pyFloat) ASSERT_TRUE(n5 == 0.02); } -TEST_F(PyHelpersTest, pyBool) -{ +TEST_F(PyHelpersTest, pyBool) { const auto trueRefcount = Py_REFCNT(Py_True); const auto falseRefcount = Py_REFCNT(Py_False); @@ -223,21 +210,19 @@ TEST_F(PyHelpersTest, pyBool) ASSERT_EQ(falseRefcount, Py_REFCNT(Py_False)); } -TEST_F(PyHelpersTest, pyTupleEmpty) -{ +TEST_F(PyHelpersTest, pyTupleEmpty) { py::String s1("item_1"); py::String s2("item_2"); - + py::Tuple empty; ASSERT_TRUE(PyTuple_Check(empty) != 0); ASSERT_TRUE(empty.getCount() == 0); - + EXPECT_THROW(empty.setItem(0, s1), std::exception); EXPECT_THROW(empty.getItem(0), std::exception); } -TEST_F(PyHelpersTest, pyTupleOneItem) -{ +TEST_F(PyHelpersTest, pyTupleOneItem) { py::String s1("item_1"); py::String s2("item_2"); @@ -248,19 +233,18 @@ TEST_F(PyHelpersTest, pyTupleOneItem) t1.setItem(0, s1); py::String item1(t1.getItem(0)); ASSERT_TRUE(std::string(item1) == std::string(s1)); - + py::String fastItem1(t1.fastGetItem(0)); ASSERT_TRUE(std::string(fastItem1) == std::string(s1)); fastItem1.release(); - + EXPECT_THROW(t1.setItem(1, s2), std::exception); EXPECT_THROW(t1.getItem(1), std::exception); ASSERT_TRUE(t1.getCount() == 1); } -TEST_F(PyHelpersTest, pyTupleTwoItems) -{ +TEST_F(PyHelpersTest, pyTupleTwoItems) { py::String s1("item_1"); py::String s2("item_2"); @@ -282,29 +266,25 @@ TEST_F(PyHelpersTest, pyTupleTwoItems) ASSERT_TRUE(std::string(fastItem2) == std::string(s2)); fastItem2.release(); - EXPECT_THROW(t2.setItem(2, s2), std::exception); EXPECT_THROW(t2.getItem(2), std::exception); ASSERT_TRUE(t2.getCount() == 2); } - -TEST_F(PyHelpersTest, pyListEmpty) -{ +TEST_F(PyHelpersTest, pyListEmpty) { py::String s1("item_1"); py::String s2("item_2"); py::List empty; ASSERT_TRUE(PyList_Check(empty) != 0); ASSERT_TRUE(empty.getCount() == 0); - + EXPECT_THROW(empty.setItem(0, s1), std::exception); EXPECT_THROW(empty.getItem(0), std::exception); } -TEST_F(PyHelpersTest, pyListOneItem) -{ +TEST_F(PyHelpersTest, pyListOneItem) { py::String s1("item_1"); py::String s2("item_2"); @@ -321,12 +301,11 @@ TEST_F(PyHelpersTest, pyListOneItem) ASSERT_TRUE(t1.getCount() == 1); ASSERT_TRUE(std::string(item1) == std::string(s1)); - + EXPECT_THROW(t1.getItem(1), std::exception); } -TEST_F(PyHelpersTest, pyListTwoItems) -{ +TEST_F(PyHelpersTest, pyListTwoItems) { py::String s1("item_1"); py::String s2("item_2"); @@ -343,49 +322,42 @@ TEST_F(PyHelpersTest, pyListTwoItems) t2.append(s2); ASSERT_TRUE(t2.getCount() == 2); - + py::String item2(t2.getItem(1)); ASSERT_TRUE(std::string(item2) == std::string(s2)); py::String fastItem2(t2.fastGetItem(1)); ASSERT_TRUE(std::string(fastItem2) == std::string(s2)); fastItem2.release(); - EXPECT_THROW(t2.getItem(2), std::exception); } -TEST_F(PyHelpersTest, pyDictEmpty) -{ +TEST_F(PyHelpersTest, pyDictEmpty) { py::Dict d; ASSERT_EQ(PyDict_Size(d), 0); ASSERT_TRUE(d.getItem("blah") == NULL); } -TEST_F(PyHelpersTest, pyDictExternalPyObjectFailed) -{ +TEST_F(PyHelpersTest, pyDictExternalPyObjectFailed) { // NULL object EXPECT_THROW(py::Dict(NULL), std::exception); // 
Wrong type (must be a dictionary) py::String s("1234"); - try - { + try { py::Dict d(s.release()); NTA_THROW << "py::Dict d(s) Should fail!!!"; - } - catch(...) - { + } catch (...) { } // SHOULDFAIL fails to fail :-) - //SHOULDFAIL(py::Dict(s)); + // SHOULDFAIL(py::Dict(s)); } -TEST_F(PyHelpersTest, pyDictExternalPyObjectSuccessful) -{ - PyObject * p = PyDict_New(); +TEST_F(PyHelpersTest, pyDictExternalPyObjectSuccessful) { + PyObject *p = PyDict_New(); PyDict_SetItem(p, py::String("1234"), py::String("5678")); - + py::Dict d(p); ASSERT_TRUE(PyDict_Contains(d, py::String("1234")) == 1); @@ -394,41 +366,34 @@ TEST_F(PyHelpersTest, pyDictExternalPyObjectSuccessful) ASSERT_TRUE(PyDict_Contains(d, py::String("777")) == 1); } - + // getItem with default (exisiting and non-exisitng key) -TEST_F(PyHelpersTest, pyDictGetItem) -{ +TEST_F(PyHelpersTest, pyDictGetItem) { py::Dict d; d.setItem("A", py::String("AAA")); - PyObject * defaultItem = (PyObject *)123; - - py::String A(d.getItem("A")); + PyObject *defaultItem = (PyObject *)123; + + py::String A(d.getItem("A")); ASSERT_TRUE(std::string(A) == std::string("AAA")); // No "B" in the dict, so expect to get the default item - PyObject * B = (d.getItem("B", defaultItem)); + PyObject *B = (d.getItem("B", defaultItem)); ASSERT_TRUE(B == defaultItem); PyDict_SetItem(d, py::String("777"), py::String("999")); ASSERT_TRUE(PyDict_Contains(d, py::String("777")) == 1); } - -TEST_F(PyHelpersTest, pyModule) -{ +TEST_F(PyHelpersTest, pyModule) { py::Module module("sys"); ASSERT_TRUE(std::string(PyModule_GetName(module)) == std::string("sys")); } -TEST_F(PyHelpersTest, pyClass) -{ - py::Class c("datetime", "date"); -} +TEST_F(PyHelpersTest, pyClass) { py::Class c("datetime", "date"); } + +TEST_F(PyHelpersTest, pyInstance) { -TEST_F(PyHelpersTest, pyInstance) -{ - py::Tuple args(3); args.setItem(0, py::Long(2000)); args.setItem(1, py::Long(11)); @@ -466,8 +431,7 @@ TEST_F(PyHelpersTest, pyInstance) } } -TEST_F(PyHelpersTest, pyCustomException) -{ +TEST_F(PyHelpersTest, pyCustomException) { py::Tuple args(1); args.setItem(0, py::String("error message!")); py::Instance e(PyExc_RuntimeError, args); @@ -475,12 +439,9 @@ TEST_F(PyHelpersTest, pyCustomException) PyErr_SetObject(PyExc_RuntimeError, e); - try - { + try { py::checkPyError(0); - } - catch (const nupic::Exception & e) - { + } catch (const nupic::Exception &e) { NTA_DEBUG << e.getMessage(); } } diff --git a/src/test/unit/types/BasicTypeTest.cpp b/src/test/unit/types/BasicTypeTest.cpp index 7ef7ef06a6..8b421ff369 100644 --- a/src/test/unit/types/BasicTypeTest.cpp +++ b/src/test/unit/types/BasicTypeTest.cpp @@ -29,8 +29,7 @@ using namespace nupic; -TEST(BasicTypeTest, isValid) -{ +TEST(BasicTypeTest, isValid) { ASSERT_TRUE(BasicType::isValid(NTA_BasicType_Byte)); ASSERT_TRUE(BasicType::isValid(NTA_BasicType_Int16)); ASSERT_TRUE(BasicType::isValid(NTA_BasicType_UInt16)); @@ -44,14 +43,12 @@ TEST(BasicTypeTest, isValid) ASSERT_TRUE(BasicType::isValid(NTA_BasicType_Handle)); ASSERT_TRUE(BasicType::isValid(NTA_BasicType_Bool)); - ASSERT_TRUE(!BasicType::isValid(NTA_BasicType_Last)); ASSERT_TRUE(!(BasicType::isValid(NTA_BasicType(NTA_BasicType_Last + 777)))); ASSERT_TRUE(!(BasicType::isValid(NTA_BasicType(-1)))); } -TEST(BasicTypeTest, getSize) -{ +TEST(BasicTypeTest, getSize) { ASSERT_TRUE(BasicType::getSize(NTA_BasicType_Byte) == 1); ASSERT_TRUE(BasicType::getSize(NTA_BasicType_Int16) == 2); ASSERT_TRUE(BasicType::getSize(NTA_BasicType_UInt16) == 2); @@ -62,36 +59,40 @@ TEST(BasicTypeTest, getSize) 
ASSERT_TRUE(BasicType::getSize(NTA_BasicType_Real32) == 4); ASSERT_TRUE(BasicType::getSize(NTA_BasicType_Real64) == 8); ASSERT_TRUE(BasicType::getSize(NTA_BasicType_Bool) == sizeof(bool)); - #ifdef NTA_DOUBLE_PRECISION - ASSERT_TRUE(BasicType::getSize(NTA_BasicType_Real) == 8); // Real64 - #else - ASSERT_TRUE(BasicType::getSize(NTA_BasicType_Real) == 4); // Real32 - #endif +#ifdef NTA_DOUBLE_PRECISION + ASSERT_TRUE(BasicType::getSize(NTA_BasicType_Real) == 8); // Real64 +#else + ASSERT_TRUE(BasicType::getSize(NTA_BasicType_Real) == 4); // Real32 +#endif ASSERT_TRUE(BasicType::getSize(NTA_BasicType_Handle) == sizeof(void *)); } -TEST(BasicTypeTest, getName) -{ +TEST(BasicTypeTest, getName) { ASSERT_TRUE(BasicType::getName(NTA_BasicType_Byte) == std::string("Byte")); ASSERT_TRUE(BasicType::getName(NTA_BasicType_Int16) == std::string("Int16")); - ASSERT_TRUE(BasicType::getName(NTA_BasicType_UInt16) == std::string("UInt16")); + ASSERT_TRUE(BasicType::getName(NTA_BasicType_UInt16) == + std::string("UInt16")); ASSERT_TRUE(BasicType::getName(NTA_BasicType_Int32) == std::string("Int32")); - ASSERT_TRUE(BasicType::getName(NTA_BasicType_UInt32) == std::string("UInt32")); + ASSERT_TRUE(BasicType::getName(NTA_BasicType_UInt32) == + std::string("UInt32")); ASSERT_TRUE(BasicType::getName(NTA_BasicType_Int64) == std::string("Int64")); - ASSERT_TRUE(BasicType::getName(NTA_BasicType_UInt64) == std::string("UInt64")); - ASSERT_TRUE(BasicType::getName(NTA_BasicType_Real32) == std::string("Real32")); - ASSERT_TRUE(BasicType::getName(NTA_BasicType_Real64) == std::string("Real64")); - #ifdef NTA_DOUBLE_PRECISION - ASSERT_TRUE(BasicType::getName(NTA_BasicType_Real) == std::string("Real64")); - #else - ASSERT_TRUE(BasicType::getName(NTA_BasicType_Real) == std::string("Real32")); - #endif - ASSERT_TRUE(BasicType::getName(NTA_BasicType_Handle) == std::string("Handle")); + ASSERT_TRUE(BasicType::getName(NTA_BasicType_UInt64) == + std::string("UInt64")); + ASSERT_TRUE(BasicType::getName(NTA_BasicType_Real32) == + std::string("Real32")); + ASSERT_TRUE(BasicType::getName(NTA_BasicType_Real64) == + std::string("Real64")); +#ifdef NTA_DOUBLE_PRECISION + ASSERT_TRUE(BasicType::getName(NTA_BasicType_Real) == std::string("Real64")); +#else + ASSERT_TRUE(BasicType::getName(NTA_BasicType_Real) == std::string("Real32")); +#endif + ASSERT_TRUE(BasicType::getName(NTA_BasicType_Handle) == + std::string("Handle")); ASSERT_TRUE(BasicType::getName(NTA_BasicType_Bool) == std::string("Bool")); } -TEST(BasicTypeTest, parse) -{ +TEST(BasicTypeTest, parse) { ASSERT_TRUE(BasicType::parse("Byte") == NTA_BasicType_Byte); ASSERT_TRUE(BasicType::parse("Int16") == NTA_BasicType_Int16); ASSERT_TRUE(BasicType::parse("UInt16") == NTA_BasicType_UInt16); diff --git a/src/test/unit/types/ExceptionTest.cpp b/src/test/unit/types/ExceptionTest.cpp index e2ff51fefb..fb01b2c83f 100644 --- a/src/test/unit/types/ExceptionTest.cpp +++ b/src/test/unit/types/ExceptionTest.cpp @@ -24,35 +24,27 @@ * Implementation of Fraction test */ -#include #include +#include using namespace nupic; -TEST(ExceptionTest, Basic) -{ - try - { +TEST(ExceptionTest, Basic) { + try { throw nupic::Exception("FFF", 123, "MMM"); - } - catch (const Exception & e) - { + } catch (const Exception &e) { ASSERT_EQ(std::string(e.getFilename()), std::string("FFF")); ASSERT_EQ(e.getLineNumber(), 123); ASSERT_EQ(std::string(e.getMessage()), std::string("MMM")); ASSERT_EQ(std::string(e.getStackTrace()), std::string("")); } - try - { + try { throw nupic::Exception("FFF", 123, "MMM", "TB"); - } - 
catch (const Exception & e) - { + } catch (const Exception &e) { ASSERT_EQ(std::string(e.getFilename()), std::string("FFF")); ASSERT_EQ(e.getLineNumber(), 123); ASSERT_EQ(std::string(e.getMessage()), std::string("MMM")); ASSERT_EQ(std::string(e.getStackTrace()), std::string("TB")); - } + } } - diff --git a/src/test/unit/types/FractionTest.cpp b/src/test/unit/types/FractionTest.cpp index cc0b2b123f..fab44f0593 100644 --- a/src/test/unit/types/FractionTest.cpp +++ b/src/test/unit/types/FractionTest.cpp @@ -24,17 +24,16 @@ * Implementation of Fraction test */ -#include -#include +#include #include +#include #include -#include +#include using namespace nupic; -TEST(FractionTest, All) -{ - //create fractions +TEST(FractionTest, All) { + // create fractions Fraction(1); Fraction(0); Fraction(-1); @@ -53,9 +52,9 @@ TEST(FractionTest, All) ASSERT_ANY_THROW(Fraction(tooLarge, 0)); ASSERT_ANY_THROW(Fraction(tooLarge, 1)); ASSERT_ANY_THROW(Fraction(0, tooLarge)); - // There is some strange interaction with the SHOULDFAIL macro here. + // There is some strange interaction with the SHOULDFAIL macro here. // Without this syntax, the compiler thinks we're declaring a new variable - // tooLarge of type Fraction (which masks the old tooLarge). + // tooLarge of type Fraction (which masks the old tooLarge). ASSERT_ANY_THROW(Fraction x(tooLarge)); ASSERT_ANY_THROW(Fraction(20000000)); ASSERT_ANY_THROW(Fraction(-tooLarge)); @@ -64,8 +63,8 @@ TEST(FractionTest, All) ASSERT_ANY_THROW(Fraction(-tooLarge)); ASSERT_ANY_THROW(new Fraction(std::numeric_limits::max())); ASSERT_ANY_THROW(new Fraction(std::numeric_limits::min())); - - //Test isNaturalNumber() (natural numbers must be nonnegative) + + // Test isNaturalNumber() (natural numbers must be nonnegative) ASSERT_TRUE(Fraction(1).isNaturalNumber()); ASSERT_TRUE(Fraction(0).isNaturalNumber()); ASSERT_TRUE(!Fraction(-1).isNaturalNumber()); @@ -76,7 +75,7 @@ TEST(FractionTest, All) ASSERT_TRUE(!Fraction(3, 2).isNaturalNumber()); ASSERT_TRUE(!Fraction(-3, 2).isNaturalNumber()); - //Test getNumerator() + // Test getNumerator() ASSERT_EQ(2, Fraction(2, 1).getNumerator()); ASSERT_EQ(0, Fraction(0, 1).getNumerator()); ASSERT_EQ(-2, Fraction(-2, 1).getNumerator()); @@ -84,13 +83,13 @@ TEST(FractionTest, All) ASSERT_EQ(0, Fraction(0, -2).getNumerator()); ASSERT_EQ(-2, Fraction(-2, -2).getNumerator()); - //Test getDenominator() + // Test getDenominator() ASSERT_EQ(1, Fraction(0).getDenominator()); ASSERT_EQ(1, Fraction(2).getDenominator()); ASSERT_EQ(-2, Fraction(0, -2).getDenominator()); ASSERT_EQ(-2, Fraction(-2, -2).getDenominator()); - - //Test setNumerator() + + // Test setNumerator() Fraction b(1); b.setNumerator(0); ASSERT_EQ(0, b.getNumerator()); @@ -101,7 +100,7 @@ TEST(FractionTest, All) b.setNumerator(2); ASSERT_EQ(2, b.getNumerator()); - //Test setDenominator() + // Test setDenominator() ASSERT_ANY_THROW(Fraction(1).setDenominator(0)); b = Fraction(1); b.setDenominator(2); @@ -110,7 +109,7 @@ TEST(FractionTest, All) b.setDenominator(5); ASSERT_EQ(5, b.getDenominator()); - //Test setFraction() + // Test setFraction() ASSERT_ANY_THROW(Fraction(1).setFraction(1, 0)); ASSERT_ANY_THROW(Fraction(-2).setFraction(-3, 0)); b = Fraction(2); @@ -122,8 +121,8 @@ TEST(FractionTest, All) b = Fraction(0); b.setFraction(-6, 4); ASSERT_TRUE(Fraction(-6, 4) == b); - - //Test computeGCD() + + // Test computeGCD() ASSERT_EQ((UInt32)5, Fraction::computeGCD(5, 10)); ASSERT_EQ((UInt32)1, Fraction::computeGCD(1, 1)); ASSERT_EQ((UInt32)1, Fraction::computeGCD(0, 1)); @@ -131,14 +130,14 
@@ TEST(FractionTest, All) ASSERT_EQ((UInt32)1, Fraction::computeGCD(1, 0)); ASSERT_EQ((UInt32)1, Fraction::computeGCD(1, -1)); - //Test computeLCM + // Test computeLCM ASSERT_EQ((UInt32)10, Fraction::computeLCM(5, 2)); ASSERT_EQ((UInt32)1, Fraction::computeLCM(1, 1)); ASSERT_EQ((UInt32)0, Fraction::computeLCM(0, 0)); ASSERT_EQ((UInt32)0, Fraction::computeLCM(0, -1)); - ASSERT_EQ((UInt32)0 , Fraction::computeLCM(-1, 2)); - - //Test reduce() + ASSERT_EQ((UInt32)0, Fraction::computeLCM(-1, 2)); + + // Test reduce() Fraction a = Fraction(1); a.reduce(); ASSERT_EQ(1, a.getNumerator()); @@ -172,40 +171,40 @@ TEST(FractionTest, All) ASSERT_EQ(-1, a.getNumerator()); ASSERT_EQ(3, a.getDenominator()); - //Test * + // Test * Fraction one = Fraction(1); Fraction zero = Fraction(0); Fraction neg_one = Fraction(-1); - ASSERT_TRUE(one == one*one); - ASSERT_TRUE(one == neg_one*neg_one); - ASSERT_TRUE(zero == zero*one); - ASSERT_TRUE(zero == zero*zero); - ASSERT_TRUE(zero == zero*neg_one); - ASSERT_TRUE(neg_one == one*neg_one); - ASSERT_TRUE(neg_one == neg_one*one); - ASSERT_TRUE(Fraction(10) == one*Fraction(20, 2)); + ASSERT_TRUE(one == one * one); + ASSERT_TRUE(one == neg_one * neg_one); + ASSERT_TRUE(zero == zero * one); + ASSERT_TRUE(zero == zero * zero); + ASSERT_TRUE(zero == zero * neg_one); + ASSERT_TRUE(neg_one == one * neg_one); + ASSERT_TRUE(neg_one == neg_one * one); + ASSERT_TRUE(Fraction(10) == one * Fraction(20, 2)); + + ASSERT_TRUE(one == one * 1); + ASSERT_TRUE(one == one * 1); + ASSERT_TRUE(zero == zero * 1); + ASSERT_TRUE(zero == zero * 1); + ASSERT_TRUE(zero == zero * -1); + ASSERT_TRUE(zero == zero * -1); + ASSERT_TRUE(-1 == one * -1); + ASSERT_TRUE(-1 == neg_one * 1); + ASSERT_TRUE(Fraction(10) == one * 10); + ASSERT_TRUE(Fraction(10) == neg_one * -10); - ASSERT_TRUE(one == one*1); - ASSERT_TRUE(one == one*1); - ASSERT_TRUE(zero == zero*1); - ASSERT_TRUE(zero == zero*1); - ASSERT_TRUE(zero == zero*-1); - ASSERT_TRUE(zero == zero*-1); - ASSERT_TRUE(-1 == one*-1); - ASSERT_TRUE(-1 == neg_one*1); - ASSERT_TRUE(Fraction(10) == one*10); - ASSERT_TRUE(Fraction(10) == neg_one*-10); - - //Test / - ASSERT_TRUE(one == one/one); - ASSERT_TRUE(zero == zero/one); - ASSERT_TRUE(zero == zero/neg_one); - ASSERT_TRUE(Fraction(-0) == zero/neg_one); - ASSERT_ANY_THROW(one/zero); - ASSERT_TRUE(Fraction(3, 2) == Fraction(3)/Fraction(2)); - ASSERT_TRUE(Fraction(2, -3) == Fraction(2)/Fraction(-3)); + // Test / + ASSERT_TRUE(one == one / one); + ASSERT_TRUE(zero == zero / one); + ASSERT_TRUE(zero == zero / neg_one); + ASSERT_TRUE(Fraction(-0) == zero / neg_one); + ASSERT_ANY_THROW(one / zero); + ASSERT_TRUE(Fraction(3, 2) == Fraction(3) / Fraction(2)); + ASSERT_TRUE(Fraction(2, -3) == Fraction(2) / Fraction(-3)); - //Test - + // Test - ASSERT_TRUE(zero == one - one); ASSERT_TRUE(neg_one == zero - one); ASSERT_TRUE(one == zero - neg_one); @@ -213,7 +212,7 @@ TEST(FractionTest, All) ASSERT_TRUE(Fraction(1, 2) == Fraction(3, 2) - one); ASSERT_TRUE(Fraction(-1, 2) == Fraction(-3, 2) - neg_one); - //Test + + // Test + ASSERT_TRUE(zero == neg_one + one); ASSERT_TRUE(one == zero + one); ASSERT_TRUE(one == (neg_one + one) + one); @@ -222,7 +221,7 @@ TEST(FractionTest, All) ASSERT_TRUE(Fraction(1, 2) == Fraction(-1, 2) + one); ASSERT_TRUE(Fraction(-3, 2) == neg_one + Fraction(-1, 2)); - //Test % + // Test % ASSERT_TRUE(Fraction(1, 2) == Fraction(3, 2) % one); ASSERT_TRUE(Fraction(-1, 2) == Fraction(-1, 2) % one); ASSERT_TRUE(Fraction(3, 2) == Fraction(7, 2) % Fraction(2)); @@ -230,9 +229,9 @@ TEST(FractionTest, All) 
ASSERT_TRUE(Fraction(-1, 2) == Fraction(-3, 2) % neg_one); ASSERT_TRUE(Fraction(1, 2) == Fraction(3, 2) % neg_one); ASSERT_ANY_THROW(Fraction(1, 2) % Fraction(0)); - ASSERT_ANY_THROW(Fraction(-3,2) % Fraction(0, -2)); + ASSERT_ANY_THROW(Fraction(-3, 2) % Fraction(0, -2)); - //Test < + // Test < ASSERT_TRUE(zero < one); ASSERT_TRUE(!(one < zero)); ASSERT_TRUE(!(zero < zero)); @@ -241,7 +240,7 @@ TEST(FractionTest, All) ASSERT_TRUE(Fraction(-3, 2) < Fraction(1, -2)); ASSERT_TRUE(Fraction(-1, 2) < Fraction(3, 2)); - //Test > + // Test > ASSERT_TRUE(one > zero); ASSERT_TRUE(!(zero > zero)); ASSERT_TRUE(!(one > one)); @@ -250,7 +249,7 @@ TEST(FractionTest, All) ASSERT_TRUE(Fraction(1, -2) > Fraction(-3, 2)); ASSERT_TRUE(Fraction(1, 2) > Fraction(-3, 2)); - //Test <= + // Test <= ASSERT_TRUE(zero <= one); ASSERT_TRUE(!(one <= zero)); ASSERT_TRUE(Fraction(1, 2) <= one); @@ -261,7 +260,7 @@ TEST(FractionTest, All) ASSERT_TRUE(neg_one <= neg_one); ASSERT_TRUE(Fraction(-7, 4) <= Fraction(14, -8)); - //Test >= + // Test >= ASSERT_TRUE(one >= zero); ASSERT_TRUE(!(zero >= one)); ASSERT_TRUE(one >= Fraction(1, 2)); @@ -272,7 +271,7 @@ TEST(FractionTest, All) ASSERT_TRUE(neg_one >= neg_one); ASSERT_TRUE(Fraction(-7, 4) >= Fraction(14, -8)); - //Test == + // Test == ASSERT_TRUE(one == one); ASSERT_TRUE(zero == zero); ASSERT_TRUE(!(one == zero)); @@ -281,7 +280,7 @@ TEST(FractionTest, All) ASSERT_TRUE(Fraction(0, 1) == Fraction(0, -1)); ASSERT_TRUE(Fraction(0, 1) == Fraction(0, 2)); - //Test << + // Test << std::stringstream ss; ss << Fraction(3, 4); EXPECT_STREQ("3/4", ss.str().c_str()); @@ -317,7 +316,7 @@ TEST(FractionTest, All) EXPECT_STREQ("1", ss.str().c_str()); ss.str(""); - //Test fromDouble() + // Test fromDouble() ASSERT_TRUE(one == Fraction::fromDouble(1.0)); ASSERT_TRUE(zero == Fraction::fromDouble(0.0)); ASSERT_TRUE(Fraction(1, 2) == Fraction::fromDouble(0.5)); @@ -326,15 +325,15 @@ TEST(FractionTest, All) ASSERT_TRUE(Fraction(1, 3) == Fraction::fromDouble(.3333333)); ASSERT_TRUE(Fraction(1, -3) == Fraction::fromDouble(-.33333333)); ASSERT_ANY_THROW(Fraction::fromDouble((double)(tooLarge))); - ASSERT_ANY_THROW(Fraction::fromDouble(1.0/(double)(tooLarge))); + ASSERT_ANY_THROW(Fraction::fromDouble(1.0 / (double)(tooLarge))); ASSERT_ANY_THROW(Fraction::fromDouble(-(double)tooLarge)); - ASSERT_ANY_THROW(Fraction::fromDouble(-1.0/(double)(tooLarge))); + ASSERT_ANY_THROW(Fraction::fromDouble(-1.0 / (double)(tooLarge))); ASSERT_ANY_THROW(Fraction::fromDouble(std::numeric_limits::max())); ASSERT_ANY_THROW(Fraction::fromDouble(std::numeric_limits::min())); ASSERT_ANY_THROW(Fraction::fromDouble(-std::numeric_limits::max())); ASSERT_ANY_THROW(Fraction::fromDouble(-std::numeric_limits::min())); - //Test toDouble() + // Test toDouble() ASSERT_EQ(0.0, Fraction(0).toDouble()); ASSERT_EQ(0.0, Fraction(-0).toDouble()); ASSERT_EQ(0.0, Fraction(0, 1).toDouble()); @@ -342,5 +341,4 @@ TEST(FractionTest, All) ASSERT_EQ(-0.5, Fraction(-1, 2).toDouble()); ASSERT_EQ(-0.5, Fraction(1, -2).toDouble()); ASSERT_EQ(0.5, Fraction(-1, -2).toDouble()); - } diff --git a/src/test/unit/utils/GroupByTest.cpp b/src/test/unit/utils/GroupByTest.cpp index 013afde63d..a93cb90c2c 100644 --- a/src/test/unit/utils/GroupByTest.cpp +++ b/src/test/unit/utils/GroupByTest.cpp @@ -24,11 +24,11 @@ * Implementation of unit tests for groupBy */ -#include #include "gtest/gtest.h" +#include -using std::tuple; using std::tie; +using std::tuple; using std::vector; using nupic::groupBy; @@ -36,372 +36,288 @@ using nupic::iterGroupBy; namespace { - 
struct ReturnValue1 { +struct ReturnValue1 { + int key; + vector results0; +}; + +TEST(GroupByTest, OneSequence) { + const vector sequence0 = {7, 12, 12, 16}; + + auto identity = [](int a) { return a; }; + + const vector expectedValues = { + {7, {7}}, {12, {12, 12}}, {16, {16}}}; + + // + // groupBy + // + size_t i = 0; + for (auto data : groupBy(sequence0, identity)) { int key; - vector results0; - }; - - TEST(GroupByTest, OneSequence) - { - const vector sequence0 = {7, 12, 12, 16}; - - auto identity = [](int a) { return a; }; - - const vector expectedValues = { - {7, {7}}, - {12, {12, 12}}, - {16, {16}} - }; - - // - // groupBy - // - size_t i = 0; - for (auto data : groupBy(sequence0, identity)) - { - int key; - vector::const_iterator - begin0, end0; - - tie(key, - begin0, end0) = data; - - const ReturnValue1 actualValue = - {key, {begin0, end0}}; - - EXPECT_EQ(expectedValues[i].key, actualValue.key); - EXPECT_EQ(expectedValues[i].results0, actualValue.results0); - - i++; - } - - // - // iterGroupBy - // - i = 0; - for (auto data : iterGroupBy( - sequence0.begin(), sequence0.end(), identity)) - { - int key; - vector::const_iterator - begin0, end0; - - tie(key, - begin0, end0) = data; - - const ReturnValue1 actualValue = - {key, {begin0, end0}}; - - EXPECT_EQ(expectedValues[i].key, actualValue.key); - EXPECT_EQ(expectedValues[i].results0, actualValue.results0); - - i++; - } - - EXPECT_EQ(expectedValues.size(), i); - } + vector::const_iterator begin0, end0; + + tie(key, begin0, end0) = data; + const ReturnValue1 actualValue = {key, {begin0, end0}}; + EXPECT_EQ(expectedValues[i].key, actualValue.key); + EXPECT_EQ(expectedValues[i].results0, actualValue.results0); + + i++; + } - struct ReturnValue2 { + // + // iterGroupBy + // + i = 0; + for (auto data : iterGroupBy(sequence0.begin(), sequence0.end(), identity)) { int key; - vector results0; - vector results1; - }; - - TEST(GroupByTest, TwoSequences) - { - const vector sequence0 = {7, 12, 16}; - const vector sequence1 = {3, 4, 5}; - - auto identity = [](int a) { return a; }; - auto times3 = [](int a) { return a*3; }; - - const vector expectedValues = { - {7, {7}, {}}, - {9, {}, {3}}, - {12, {12}, {4}}, - {15, {}, {5}}, - {16, {16}, {}} - }; - - // - // groupBy - // - size_t i = 0; - for (auto data : groupBy(sequence0, identity, - sequence1, times3)) - { - int key; - vector::const_iterator - begin0, end0, - begin1, end1; - - tie(key, - begin0, end0, - begin1, end1) = data; - - const ReturnValue2 actualValue = - {key, {begin0, end0}, {begin1, end1}}; - - EXPECT_EQ(expectedValues[i].key, actualValue.key); - EXPECT_EQ(expectedValues[i].results0, actualValue.results0); - EXPECT_EQ(expectedValues[i].results1, actualValue.results1); - - i++; - } - - // - // iterGroupBy - // - i = 0; - for (auto data : iterGroupBy( - sequence0.begin(), sequence0.end(), identity, - sequence1.begin(), sequence1.end(), times3)) - { - int key; - vector::const_iterator - begin0, end0, - begin1, end1; - - tie(key, - begin0, end0, - begin1, end1) = data; - - const ReturnValue2 actualValue = - {key, {begin0, end0}, {begin1, end1}}; - - EXPECT_EQ(expectedValues[i].key, actualValue.key); - EXPECT_EQ(expectedValues[i].results0, actualValue.results0); - EXPECT_EQ(expectedValues[i].results1, actualValue.results1); - - i++; - } - - EXPECT_EQ(expectedValues.size(), i); + vector::const_iterator begin0, end0; + + tie(key, begin0, end0) = data; + + const ReturnValue1 actualValue = {key, {begin0, end0}}; + + EXPECT_EQ(expectedValues[i].key, actualValue.key); + 
EXPECT_EQ(expectedValues[i].results0, actualValue.results0); + + i++; } + EXPECT_EQ(expectedValues.size(), i); +} + +struct ReturnValue2 { + int key; + vector results0; + vector results1; +}; + +TEST(GroupByTest, TwoSequences) { + const vector sequence0 = {7, 12, 16}; + const vector sequence1 = {3, 4, 5}; + + auto identity = [](int a) { return a; }; + auto times3 = [](int a) { return a * 3; }; + + const vector expectedValues = {{7, {7}, {}}, + {9, {}, {3}}, + {12, {12}, {4}}, + {15, {}, {5}}, + {16, {16}, {}}}; + + // + // groupBy + // + size_t i = 0; + for (auto data : groupBy(sequence0, identity, sequence1, times3)) { + int key; + vector::const_iterator begin0, end0, begin1, end1; + + tie(key, begin0, end0, begin1, end1) = data; + + const ReturnValue2 actualValue = {key, {begin0, end0}, {begin1, end1}}; + + EXPECT_EQ(expectedValues[i].key, actualValue.key); + EXPECT_EQ(expectedValues[i].results0, actualValue.results0); + EXPECT_EQ(expectedValues[i].results1, actualValue.results1); + i++; + } - struct ReturnValue3 { + // + // iterGroupBy + // + i = 0; + for (auto data : iterGroupBy(sequence0.begin(), sequence0.end(), identity, + sequence1.begin(), sequence1.end(), times3)) { int key; - vector results0; - vector results1; - vector results2; - }; - - TEST(GroupByTest, ThreeSequences) - { - const vector sequence0 = {7, 12, 16}; - const vector sequence1 = {3, 4, 5}; - const vector sequence2 = {3, 3, 4, 5}; - - auto identity = [](int a) { return a; }; - auto times3 = [](int a) { return a*3; }; - auto times4 = [](int a) { return a*4; }; - - const vector expectedValues = { - {7, {7}, {}, {}}, - {9, {}, {3}, {}}, - {12, {12}, {4}, {3, 3}}, - {15, {}, {5}, {}}, - {16, {16}, {}, {4}}, - {20, {}, {}, {5}} - }; - - // - // groupBy - // - size_t i = 0; - for (auto data : groupBy(sequence0, identity, - sequence1, times3, - sequence2, times4)) - { - int key; - vector::const_iterator - begin0, end0, - begin1, end1, - begin2, end2; - - tie(key, - begin0, end0, - begin1, end1, - begin2, end2) = data; - - const ReturnValue3 actualValue = - {key, {begin0, end0}, {begin1, end1}, {begin2, end2}}; - - EXPECT_EQ(expectedValues[i].key, actualValue.key); - EXPECT_EQ(expectedValues[i].results0, actualValue.results0); - EXPECT_EQ(expectedValues[i].results1, actualValue.results1); - EXPECT_EQ(expectedValues[i].results2, actualValue.results2); - - i++; - } - - // - // iterGroupBy - // - i = 0; - for (auto data : iterGroupBy( - sequence0.begin(), sequence0.end(), identity, - sequence1.begin(), sequence1.end(), times3, - sequence2.begin(), sequence2.end(), times4)) - { - int key; - vector::const_iterator - begin0, end0, - begin1, end1, - begin2, end2; - - tie(key, - begin0, end0, - begin1, end1, - begin2, end2) = data; - - const ReturnValue3 actualValue = - {key, {begin0, end0}, {begin1, end1}, {begin2, end2}}; - - EXPECT_EQ(expectedValues[i].key, actualValue.key); - EXPECT_EQ(expectedValues[i].results0, actualValue.results0); - EXPECT_EQ(expectedValues[i].results1, actualValue.results1); - EXPECT_EQ(expectedValues[i].results2, actualValue.results2); - - i++; - } - - EXPECT_EQ(expectedValues.size(), i); + vector::const_iterator begin0, end0, begin1, end1; + + tie(key, begin0, end0, begin1, end1) = data; + + const ReturnValue2 actualValue = {key, {begin0, end0}, {begin1, end1}}; + + EXPECT_EQ(expectedValues[i].key, actualValue.key); + EXPECT_EQ(expectedValues[i].results0, actualValue.results0); + EXPECT_EQ(expectedValues[i].results1, actualValue.results1); + + i++; } + EXPECT_EQ(expectedValues.size(), i); +} - struct 
ReturnValue4 { +struct ReturnValue3 { + int key; + vector results0; + vector results1; + vector results2; +}; + +TEST(GroupByTest, ThreeSequences) { + const vector sequence0 = {7, 12, 16}; + const vector sequence1 = {3, 4, 5}; + const vector sequence2 = {3, 3, 4, 5}; + + auto identity = [](int a) { return a; }; + auto times3 = [](int a) { return a * 3; }; + auto times4 = [](int a) { return a * 4; }; + + const vector expectedValues = { + {7, {7}, {}, {}}, {9, {}, {3}, {}}, {12, {12}, {4}, {3, 3}}, + {15, {}, {5}, {}}, {16, {16}, {}, {4}}, {20, {}, {}, {5}}}; + + // + // groupBy + // + size_t i = 0; + for (auto data : + groupBy(sequence0, identity, sequence1, times3, sequence2, times4)) { int key; - vector results0; - vector results1; - vector results2; - vector results3; - }; - - TEST(GroupByTest, FourSequences) - { - const vector sequence0 = {7, 12, 16}; - const vector sequence1 = {3, 4, 5}; - const vector sequence2 = {3, 3, 4, 5}; - const vector sequence3 = {3, 3, 4, 5}; - - auto identity = [](int a) { return a; }; - auto times3 = [](int a) { return a*3; }; - auto times4 = [](int a) { return a*4; }; - auto times5 = [](int a) { return a*5; }; - - const vector expectedValues = { - {7, {7}, {}, {}, {}}, - {9, {}, {3}, {}, {}}, - {12, {12}, {4}, {3, 3}, {}}, - {15, {}, {5}, {}, {3, 3}}, - {16, {16}, {}, {4}, {}}, - {20, {}, {}, {5}, {4}}, - {25, {}, {}, {}, {5}} - }; - - // - // groupBy - // - size_t i = 0; - for (auto data : groupBy(sequence0, identity, - sequence1, times3, - sequence2, times4, - sequence3, times5)) - { - int key; - vector::const_iterator - begin0, end0, - begin1, end1, - begin2, end2, - begin3, end3; + vector::const_iterator begin0, end0, begin1, end1, begin2, end2; - tie(key, - begin0, end0, - begin1, end1, - begin2, end2, - begin3, end3) = data; - - const ReturnValue4 actualValue = - {key, {begin0, end0}, {begin1, end1}, {begin2, end2}, {begin3, end3}}; - - EXPECT_EQ(expectedValues[i].key, actualValue.key); - EXPECT_EQ(expectedValues[i].results0, actualValue.results0); - EXPECT_EQ(expectedValues[i].results1, actualValue.results1); - EXPECT_EQ(expectedValues[i].results2, actualValue.results2); - EXPECT_EQ(expectedValues[i].results3, actualValue.results3); - - i++; - } - - // - // iterGroupBy - // - i = 0; - for (auto data : iterGroupBy( - sequence0.begin(), sequence0.end(), identity, - sequence1.begin(), sequence1.end(), times3, - sequence2.begin(), sequence2.end(), times4, - sequence3.begin(), sequence3.end(), times5)) - { - int key; - vector::const_iterator - begin0, end0, - begin1, end1, - begin2, end2, - begin3, end3; + tie(key, begin0, end0, begin1, end1, begin2, end2) = data; + + const ReturnValue3 actualValue = { + key, {begin0, end0}, {begin1, end1}, {begin2, end2}}; - tie(key, - begin0, end0, - begin1, end1, - begin2, end2, - begin3, end3) = data; + EXPECT_EQ(expectedValues[i].key, actualValue.key); + EXPECT_EQ(expectedValues[i].results0, actualValue.results0); + EXPECT_EQ(expectedValues[i].results1, actualValue.results1); + EXPECT_EQ(expectedValues[i].results2, actualValue.results2); - const ReturnValue4 actualValue = - {key, {begin0, end0}, {begin1, end1}, {begin2, end2}, {begin3, end3}}; + i++; + } - EXPECT_EQ(expectedValues[i].key, actualValue.key); - EXPECT_EQ(expectedValues[i].results0, actualValue.results0); - EXPECT_EQ(expectedValues[i].results1, actualValue.results1); - EXPECT_EQ(expectedValues[i].results2, actualValue.results2); - EXPECT_EQ(expectedValues[i].results3, actualValue.results3); + // + // iterGroupBy + // + i = 0; + for (auto data : 
iterGroupBy(sequence0.begin(), sequence0.end(), identity, + sequence1.begin(), sequence1.end(), times3, + sequence2.begin(), sequence2.end(), times4)) { + int key; + vector::const_iterator begin0, end0, begin1, end1, begin2, end2; - i++; - } + tie(key, begin0, end0, begin1, end1, begin2, end2) = data; - EXPECT_EQ(expectedValues.size(), i); + const ReturnValue3 actualValue = { + key, {begin0, end0}, {begin1, end1}, {begin2, end2}}; + + EXPECT_EQ(expectedValues[i].key, actualValue.key); + EXPECT_EQ(expectedValues[i].results0, actualValue.results0); + EXPECT_EQ(expectedValues[i].results1, actualValue.results1); + EXPECT_EQ(expectedValues[i].results2, actualValue.results2); + + i++; } + EXPECT_EQ(expectedValues.size(), i); +} + +struct ReturnValue4 { + int key; + vector results0; + vector results1; + vector results2; + vector results3; +}; + +TEST(GroupByTest, FourSequences) { + const vector sequence0 = {7, 12, 16}; + const vector sequence1 = {3, 4, 5}; + const vector sequence2 = {3, 3, 4, 5}; + const vector sequence3 = {3, 3, 4, 5}; + + auto identity = [](int a) { return a; }; + auto times3 = [](int a) { return a * 3; }; + auto times4 = [](int a) { return a * 4; }; + auto times5 = [](int a) { return a * 5; }; + + const vector expectedValues = { + {7, {7}, {}, {}, {}}, {9, {}, {3}, {}, {}}, + {12, {12}, {4}, {3, 3}, {}}, {15, {}, {5}, {}, {3, 3}}, + {16, {16}, {}, {4}, {}}, {20, {}, {}, {5}, {4}}, + {25, {}, {}, {}, {5}}}; + + // + // groupBy + // + size_t i = 0; + for (auto data : groupBy(sequence0, identity, sequence1, times3, sequence2, + times4, sequence3, times5)) { + int key; + vector::const_iterator begin0, end0, begin1, end1, begin2, end2, + begin3, end3; + + tie(key, begin0, end0, begin1, end1, begin2, end2, begin3, end3) = data; + const ReturnValue4 actualValue = { + key, {begin0, end0}, {begin1, end1}, {begin2, end2}, {begin3, end3}}; + + EXPECT_EQ(expectedValues[i].key, actualValue.key); + EXPECT_EQ(expectedValues[i].results0, actualValue.results0); + EXPECT_EQ(expectedValues[i].results1, actualValue.results1); + EXPECT_EQ(expectedValues[i].results2, actualValue.results2); + EXPECT_EQ(expectedValues[i].results3, actualValue.results3); + + i++; + } - struct ReturnValue5 { + // + // iterGroupBy + // + i = 0; + for (auto data : iterGroupBy(sequence0.begin(), sequence0.end(), identity, + sequence1.begin(), sequence1.end(), times3, + sequence2.begin(), sequence2.end(), times4, + sequence3.begin(), sequence3.end(), times5)) { int key; - vector results0; - vector results1; - vector results2; - vector results3; - vector results4; - }; - - TEST(GroupByTest, FiveSequences) - { - const vector sequence0 = {7, 12, 16}; - const vector sequence1 = {3, 4, 5}; - const vector sequence2 = {3, 3, 4, 5}; - const vector sequence3 = {3, 3, 4, 5}; - const vector sequence4 = {2, 2, 3}; - - auto identity = [](int a) { return a; }; - auto times3 = [](int a) { return a*3; }; - auto times4 = [](int a) { return a*4; }; - auto times5 = [](int a) { return a*5; }; - auto times6 = [](int a) { return a*6; }; - - const vector expectedValues = { + vector::const_iterator begin0, end0, begin1, end1, begin2, end2, + begin3, end3; + + tie(key, begin0, end0, begin1, end1, begin2, end2, begin3, end3) = data; + + const ReturnValue4 actualValue = { + key, {begin0, end0}, {begin1, end1}, {begin2, end2}, {begin3, end3}}; + + EXPECT_EQ(expectedValues[i].key, actualValue.key); + EXPECT_EQ(expectedValues[i].results0, actualValue.results0); + EXPECT_EQ(expectedValues[i].results1, actualValue.results1); + 
EXPECT_EQ(expectedValues[i].results2, actualValue.results2); + EXPECT_EQ(expectedValues[i].results3, actualValue.results3); + + i++; + } + + EXPECT_EQ(expectedValues.size(), i); +} + +struct ReturnValue5 { + int key; + vector results0; + vector results1; + vector results2; + vector results3; + vector results4; +}; + +TEST(GroupByTest, FiveSequences) { + const vector sequence0 = {7, 12, 16}; + const vector sequence1 = {3, 4, 5}; + const vector sequence2 = {3, 3, 4, 5}; + const vector sequence3 = {3, 3, 4, 5}; + const vector sequence4 = {2, 2, 3}; + + auto identity = [](int a) { return a; }; + auto times3 = [](int a) { return a * 3; }; + auto times4 = [](int a) { return a * 4; }; + auto times5 = [](int a) { return a * 5; }; + auto times6 = [](int a) { return a * 6; }; + + const vector expectedValues = { {7, {7}, {}, {}, {}, {}}, {9, {}, {3}, {}, {}, {}}, {12, {12}, {4}, {3, 3}, {}, {2, 2}}, @@ -409,92 +325,71 @@ namespace { {16, {16}, {}, {4}, {}, {}}, {18, {}, {}, {}, {}, {3}}, {20, {}, {}, {5}, {4}, {}}, - {25, {}, {}, {}, {5}, {}} - }; - - // - // groupBy - // - size_t i = 0; - for (auto data : groupBy(sequence0, identity, - sequence1, times3, - sequence2, times4, - sequence3, times5, - sequence4, times6)) - { - int key; - vector::const_iterator - begin0, end0, - begin1, end1, - begin2, end2, - begin3, end3, - begin4, end4; - - tie(key, - begin0, end0, - begin1, end1, - begin2, end2, - begin3, end3, - begin4, end4) = data; - - const ReturnValue5 actualValue = - {key, - {begin0, end0}, {begin1, end1}, - {begin2, end2}, {begin3, end3}, - {begin4, end4}}; - - EXPECT_EQ(expectedValues[i].key, actualValue.key); - EXPECT_EQ(expectedValues[i].results0, actualValue.results0); - EXPECT_EQ(expectedValues[i].results1, actualValue.results1); - EXPECT_EQ(expectedValues[i].results2, actualValue.results2); - EXPECT_EQ(expectedValues[i].results3, actualValue.results3); - EXPECT_EQ(expectedValues[i].results4, actualValue.results4); - - i++; - } - - // - // iterGroupBy - // - i = 0; - for (auto data : iterGroupBy( - sequence0.begin(), sequence0.end(), identity, - sequence1.begin(), sequence1.end(), times3, - sequence2.begin(), sequence2.end(), times4, - sequence3.begin(), sequence3.end(), times5, - sequence4.begin(), sequence4.end(), times6)) - { - int key; - vector::const_iterator - begin0, end0, - begin1, end1, - begin2, end2, - begin3, end3, - begin4, end4; - - tie(key, - begin0, end0, - begin1, end1, - begin2, end2, - begin3, end3, - begin4, end4) = data; - - const ReturnValue5 actualValue = - {key, - {begin0, end0}, {begin1, end1}, - {begin2, end2}, {begin3, end3}, - {begin4, end4}}; - - EXPECT_EQ(expectedValues[i].key, actualValue.key); - EXPECT_EQ(expectedValues[i].results0, actualValue.results0); - EXPECT_EQ(expectedValues[i].results1, actualValue.results1); - EXPECT_EQ(expectedValues[i].results2, actualValue.results2); - EXPECT_EQ(expectedValues[i].results3, actualValue.results3); - EXPECT_EQ(expectedValues[i].results4, actualValue.results4); - - i++; - } - - EXPECT_EQ(expectedValues.size(), i); + {25, {}, {}, {}, {5}, {}}}; + + // + // groupBy + // + size_t i = 0; + for (auto data : groupBy(sequence0, identity, sequence1, times3, sequence2, + times4, sequence3, times5, sequence4, times6)) { + int key; + vector::const_iterator begin0, end0, begin1, end1, begin2, end2, + begin3, end3, begin4, end4; + + tie(key, begin0, end0, begin1, end1, begin2, end2, begin3, end3, begin4, + end4) = data; + + const ReturnValue5 actualValue = {key, + {begin0, end0}, + {begin1, end1}, + {begin2, end2}, + {begin3, 
end3}, + {begin4, end4}}; + + EXPECT_EQ(expectedValues[i].key, actualValue.key); + EXPECT_EQ(expectedValues[i].results0, actualValue.results0); + EXPECT_EQ(expectedValues[i].results1, actualValue.results1); + EXPECT_EQ(expectedValues[i].results2, actualValue.results2); + EXPECT_EQ(expectedValues[i].results3, actualValue.results3); + EXPECT_EQ(expectedValues[i].results4, actualValue.results4); + + i++; } + + // + // iterGroupBy + // + i = 0; + for (auto data : iterGroupBy(sequence0.begin(), sequence0.end(), identity, + sequence1.begin(), sequence1.end(), times3, + sequence2.begin(), sequence2.end(), times4, + sequence3.begin(), sequence3.end(), times5, + sequence4.begin(), sequence4.end(), times6)) { + int key; + vector::const_iterator begin0, end0, begin1, end1, begin2, end2, + begin3, end3, begin4, end4; + + tie(key, begin0, end0, begin1, end1, begin2, end2, begin3, end3, begin4, + end4) = data; + + const ReturnValue5 actualValue = {key, + {begin0, end0}, + {begin1, end1}, + {begin2, end2}, + {begin3, end3}, + {begin4, end4}}; + + EXPECT_EQ(expectedValues[i].key, actualValue.key); + EXPECT_EQ(expectedValues[i].results0, actualValue.results0); + EXPECT_EQ(expectedValues[i].results1, actualValue.results1); + EXPECT_EQ(expectedValues[i].results2, actualValue.results2); + EXPECT_EQ(expectedValues[i].results3, actualValue.results3); + EXPECT_EQ(expectedValues[i].results4, actualValue.results4); + + i++; + } + + EXPECT_EQ(expectedValues.size(), i); } +} // namespace diff --git a/src/test/unit/utils/MovingAverageTest.cpp b/src/test/unit/utils/MovingAverageTest.cpp index 0eb2ef363b..137937d9c8 100644 --- a/src/test/unit/utils/MovingAverageTest.cpp +++ b/src/test/unit/utils/MovingAverageTest.cpp @@ -30,9 +30,7 @@ using namespace nupic; using namespace nupic::util; - -TEST(MovingAverage, Instance) -{ +TEST(MovingAverage, Instance) { MovingAverage m{3}; Real32 newAverage; @@ -73,9 +71,7 @@ TEST(MovingAverage, Instance) } }; - -TEST(MovingAverage, SlidingWindowInit) -{ +TEST(MovingAverage, SlidingWindowInit) { std::vector existingHistorical = {3.0, 4.0, 5.0}; MovingAverage m{3, existingHistorical}; ASSERT_EQ(m.getSlidingWindow(), existingHistorical); @@ -85,9 +81,7 @@ TEST(MovingAverage, SlidingWindowInit) ASSERT_EQ(m2.getSlidingWindow(), emptyVector); } - -TEST(MovingAverage, EqualsOperator) -{ +TEST(MovingAverage, EqualsOperator) { MovingAverage ma{3}; MovingAverage maP{3}; ASSERT_EQ(ma, maP); diff --git a/src/test/unit/utils/RandomPrivateOrig.c b/src/test/unit/utils/RandomPrivateOrig.c index 5a67430f13..7280a970a6 100644 --- a/src/test/unit/utils/RandomPrivateOrig.c +++ b/src/test/unit/utils/RandomPrivateOrig.c @@ -18,24 +18,23 @@ * * http://numenta.org/licenses/ * --------------------------------------------------------------------- - * - * Numenta note: this source code is from OpenBSD 2.0. + * + * Numenta note: this source code is from OpenBSD 2.0. * Went back this far to avoid G P L pollution as this - * code was absorbed into gcc. - * + * code was absorbed into gcc. + * * Small modifications have been made to allow it to compile - * with a modern C compiler and to avoid name clashes with the + * with a modern C compiler and to avoid name clashes with the * system-supplied random() functions - * - * Numenta Random number generation code in Random.cpp is based on this code. + * + * Numenta Random number generation code in Random.cpp is based on this code. * We retain the code here for unit testing. 
- * - * Build system note: this code is #included into RandomTest.cpp -- it does + * + * Build system note: this code is #included into RandomTest.cpp -- it does * not appear in build system files because the filename contains _private_ - * + * */ - /* * Copyright (c) 1983 Regents of the University of California. * All rights reserved. @@ -79,7 +78,6 @@ static char *rcsid = "$OpenBSD: random.c,v 1.3 1996/08/19 08:33:46 tholo Exp $"; // NTA -- declare to avoid implicit typing at first use long myrandom(); - /* * random.c: * @@ -103,7 +101,7 @@ long myrandom(); * state information, which will allow a degree seven polynomial. (Note: * the zeroeth word of state information also has some other information * stored in it -- see setstate() for details). - * + * * The random number generation technique is a linear feedback shift register * approach, employing trinomials (since there are fewer terms to sum up that * way). In this approach, the least significant bit of all the numbers in @@ -127,39 +125,39 @@ long myrandom(); * for the polynomial (actually a trinomial) that the R.N.G. is based on, and * the separation between the two lower order coefficients of the trinomial. */ -#define TYPE_0 0 /* linear congruential */ -#define BREAK_0 8 -#define DEG_0 0 -#define SEP_0 0 +#define TYPE_0 0 /* linear congruential */ +#define BREAK_0 8 +#define DEG_0 0 +#define SEP_0 0 -#define TYPE_1 1 /* x**7 + x**3 + 1 */ -#define BREAK_1 32 -#define DEG_1 7 -#define SEP_1 3 +#define TYPE_1 1 /* x**7 + x**3 + 1 */ +#define BREAK_1 32 +#define DEG_1 7 +#define SEP_1 3 -#define TYPE_2 2 /* x**15 + x + 1 */ -#define BREAK_2 64 -#define DEG_2 15 -#define SEP_2 1 +#define TYPE_2 2 /* x**15 + x + 1 */ +#define BREAK_2 64 +#define DEG_2 15 +#define SEP_2 1 -#define TYPE_3 3 /* x**31 + x**3 + 1 */ -#define BREAK_3 128 -#define DEG_3 31 -#define SEP_3 3 +#define TYPE_3 3 /* x**31 + x**3 + 1 */ +#define BREAK_3 128 +#define DEG_3 31 +#define SEP_3 3 -#define TYPE_4 4 /* x**63 + x + 1 */ -#define BREAK_4 256 -#define DEG_4 63 -#define SEP_4 1 +#define TYPE_4 4 /* x**63 + x + 1 */ +#define BREAK_4 256 +#define DEG_4 63 +#define SEP_4 1 /* * Array versions of the above information to make code run faster -- * relies on fact that TYPE_i == i. 
*/ -#define MAX_TYPES 5 /* max number of types above */ +#define MAX_TYPES 5 /* max number of types above */ -static int degrees[MAX_TYPES] = { DEG_0, DEG_1, DEG_2, DEG_3, DEG_4 }; -static int seps [MAX_TYPES] = { SEP_0, SEP_1, SEP_2, SEP_3, SEP_4 }; +static int degrees[MAX_TYPES] = {DEG_0, DEG_1, DEG_2, DEG_3, DEG_4}; +static int seps[MAX_TYPES] = {SEP_0, SEP_1, SEP_2, SEP_3, SEP_4}; /* * Initially, everything is set up as if from: @@ -176,18 +174,14 @@ static int seps [MAX_TYPES] = { SEP_0, SEP_1, SEP_2, SEP_3, SEP_4 }; */ static long randtbl[DEG_3 + 1] = { - TYPE_3, - (long)0x991539b1, (long)0x16a5bce3, (long)0x6774a4cd, - (long)0x3e01511e, (long)0x4e508aaa, (long)0x61048c05, - (long)0xf5500617, (long)0x846b7115, (long)0x6a19892c, - (long)0x896a97af, (long)0xdb48f936, (long)0x14898454, - (long)0x37ffd106, (long)0xb58bff9c, (long)0x59e17104, - (long)0xcf918a49, (long)0x09378c83, (long)0x52c7a471, - (long)0x8d293ea9, (long)0x1f4fc301, (long)0xc3db71be, - (long)0x39b44e1c, (long)0xf8a44ef9, (long)0x4c8b80b1, - (long)0x19edc328, (long)0x87bf4bdd, (long)0xc9b240e5, - (long)0xe9ee4b1b, (long)0x4382aee7, (long)0x535b6b41, - (long)0xf3bec5da, + TYPE_3, (long)0x991539b1, (long)0x16a5bce3, (long)0x6774a4cd, + (long)0x3e01511e, (long)0x4e508aaa, (long)0x61048c05, (long)0xf5500617, + (long)0x846b7115, (long)0x6a19892c, (long)0x896a97af, (long)0xdb48f936, + (long)0x14898454, (long)0x37ffd106, (long)0xb58bff9c, (long)0x59e17104, + (long)0xcf918a49, (long)0x09378c83, (long)0x52c7a471, (long)0x8d293ea9, + (long)0x1f4fc301, (long)0xc3db71be, (long)0x39b44e1c, (long)0xf8a44ef9, + (long)0x4c8b80b1, (long)0x19edc328, (long)0x87bf4bdd, (long)0xc9b240e5, + (long)0xe9ee4b1b, (long)0x4382aee7, (long)0x535b6b41, (long)0xf3bec5da, }; /* @@ -235,31 +229,29 @@ static long *end_ptr = &randtbl[DEG_3 + 1]; * introduced by the L.C.R.N.G. Note that the initialization of randtbl[] * for default usage relies on values produced by this routine. */ -void -mysrandom(unsigned long x) -{ - long int test; - int i; - ldiv_t val; +void mysrandom(unsigned long x) { + long int test; + int i; + ldiv_t val; - if (rand_type == TYPE_0) - state[0] = x; - else { - state[0] = x; - for (i = 1; i < rand_deg; i++) { - /* - * Implement the following, without overflowing 31 bits: - * - * state[i] = (16807 * state[i - 1]) % 2147483647; - * - * 2^31-1 (prime) = 2147483647 = 127773*16807+2836 - */ - val = ldiv(state[i-1], 127773); - test = 16807 * val.rem - 2836 * val.quot; - state[i] = test + (test < 0 ? 2147483647 : 0); - } - fptr = &state[rand_sep]; - rptr = &state[0]; + if (rand_type == TYPE_0) + state[0] = x; + else { + state[0] = x; + for (i = 1; i < rand_deg; i++) { + /* + * Implement the following, without overflowing 31 bits: + * + * state[i] = (16807 * state[i - 1]) % 2147483647; + * + * 2^31-1 (prime) = 2147483647 = 127773*16807+2836 + */ + val = ldiv(state[i - 1], 127773); + test = 16807 * val.rem - 2836 * val.quot; + state[i] = test + (test < 0 ? 
2147483647 : 0); + } + fptr = &state[rand_sep]; + rptr = &state[0]; #ifdef RANDOM_SUPERDEBUG printf("srandom: init for seed = %ld\n", x); for (i = 0; i < rand_deg; i++) { @@ -270,13 +262,14 @@ mysrandom(unsigned long x) (void)myrandom(); #ifdef RANDOM_SUPERDEBUG printf("srandom: after init for seed = %ld\n", x); - printf("srandom: *fptr = %ld; *rptr = %ld fptr = %p rptr = %p end_ptr = %p\n", *fptr, *rptr, fptr, rptr, end_ptr); + printf( + "srandom: *fptr = %ld; *rptr = %ld fptr = %p rptr = %p end_ptr = %p\n", + *fptr, *rptr, fptr, rptr, end_ptr); for (i = 0; i < rand_deg; i++) { printf("srandom %d %ld\n", i, state[i]); } #endif - - } + } } /* @@ -287,65 +280,63 @@ mysrandom(unsigned long x) * the break values for the different R.N.G.'s, we choose the best (largest) * one we can and set things up for it. srandom() is then called to * initialize the state information. - * + * * Note that on return from srandom(), we set state[-1] to be the type * multiplexed with the current value of the rear pointer; this is so * successive calls to initstate() won't lose this information and will be * able to restart with setstate(). - * + * * Note: the first thing we do is save the current state, if any, just like * setstate() so that it doesn't matter when initstate is called. * * Returns a pointer to the old state. */ -char * -myinitstate(unsigned long seed, char *arg_state, int n) +char *myinitstate(unsigned long seed, char *arg_state, int n) // NTA // unsigned long seed; /* seed for R.N.G. */ // char *arg_state; /* pointer to state array */ // int n; /* # bytes of state info */ // { - char *ostate = (char *)(&state[-1]); + char *ostate = (char *)(&state[-1]); - if (rand_type == TYPE_0) - state[-1] = rand_type; - else - state[-1] = MAX_TYPES * (rptr - state) + rand_type; - if (n < BREAK_0) { - (void)fprintf(stderr, - "random: not enough state (%d bytes); ignored.\n", n); - return(nullptr); - } - if (n < BREAK_1) { - rand_type = TYPE_0; - rand_deg = DEG_0; - rand_sep = SEP_0; - } else if (n < BREAK_2) { - rand_type = TYPE_1; - rand_deg = DEG_1; - rand_sep = SEP_1; - } else if (n < BREAK_3) { - rand_type = TYPE_2; - rand_deg = DEG_2; - rand_sep = SEP_2; - } else if (n < BREAK_4) { - rand_type = TYPE_3; - rand_deg = DEG_3; - rand_sep = SEP_3; - } else { - rand_type = TYPE_4; - rand_deg = DEG_4; - rand_sep = SEP_4; - } - state = &(((long *)arg_state)[1]); /* first location */ - end_ptr = &state[rand_deg]; /* must set end_ptr before srandom */ - mysrandom(seed); - if (rand_type == TYPE_0) - state[-1] = rand_type; - else - state[-1] = MAX_TYPES*(rptr - state) + rand_type; - return(ostate); + if (rand_type == TYPE_0) + state[-1] = rand_type; + else + state[-1] = MAX_TYPES * (rptr - state) + rand_type; + if (n < BREAK_0) { + (void)fprintf(stderr, "random: not enough state (%d bytes); ignored.\n", n); + return (nullptr); + } + if (n < BREAK_1) { + rand_type = TYPE_0; + rand_deg = DEG_0; + rand_sep = SEP_0; + } else if (n < BREAK_2) { + rand_type = TYPE_1; + rand_deg = DEG_1; + rand_sep = SEP_1; + } else if (n < BREAK_3) { + rand_type = TYPE_2; + rand_deg = DEG_2; + rand_sep = SEP_2; + } else if (n < BREAK_4) { + rand_type = TYPE_3; + rand_deg = DEG_3; + rand_sep = SEP_3; + } else { + rand_type = TYPE_4; + rand_deg = DEG_4; + rand_sep = SEP_4; + } + state = &(((long *)arg_state)[1]); /* first location */ + end_ptr = &state[rand_deg]; /* must set end_ptr before srandom */ + mysrandom(seed); + if (rand_type == TYPE_0) + state[-1] = rand_type; + else + state[-1] = MAX_TYPES * (rptr - state) + rand_type; + return 
(ostate); } /* @@ -363,39 +354,36 @@ myinitstate(unsigned long seed, char *arg_state, int n) * * Returns a pointer to the old state information. */ -char * -mysetstate(char *arg_state) -{ - long *new_state = (long *)arg_state; - int type = new_state[0] % MAX_TYPES; - int rear = new_state[0] / MAX_TYPES; - char *ostate = (char *)(&state[-1]); +char *mysetstate(char *arg_state) { + long *new_state = (long *)arg_state; + int type = new_state[0] % MAX_TYPES; + int rear = new_state[0] / MAX_TYPES; + char *ostate = (char *)(&state[-1]); - if (rand_type == TYPE_0) - state[-1] = rand_type; - else - state[-1] = MAX_TYPES * (rptr - state) + rand_type; - switch(type) { - case TYPE_0: - case TYPE_1: - case TYPE_2: - case TYPE_3: - case TYPE_4: - rand_type = type; - rand_deg = degrees[type]; - rand_sep = seps[type]; - break; - default: - (void)fprintf(stderr, - "random: state info corrupted; not changed.\n"); - } - state = &new_state[1]; - if (rand_type != TYPE_0) { - rptr = &state[rear]; - fptr = &state[(rear + rand_sep) % rand_deg]; - } - end_ptr = &state[rand_deg]; /* set end_ptr too */ - return(ostate); + if (rand_type == TYPE_0) + state[-1] = rand_type; + else + state[-1] = MAX_TYPES * (rptr - state) + rand_type; + switch (type) { + case TYPE_0: + case TYPE_1: + case TYPE_2: + case TYPE_3: + case TYPE_4: + rand_type = type; + rand_deg = degrees[type]; + rand_sep = seps[type]; + break; + default: + (void)fprintf(stderr, "random: state info corrupted; not changed.\n"); + } + state = &new_state[1]; + if (rand_type != TYPE_0) { + rptr = &state[rear]; + fptr = &state[(rear + rand_sep) % rand_deg]; + } + end_ptr = &state[rand_deg]; /* set end_ptr too */ + return (ostate); } /* @@ -415,30 +403,30 @@ mysetstate(char *arg_state) * * Returns a 31-bit random number. */ -long -myrandom() -{ - long i; +long myrandom() { + long i; - if (rand_type == TYPE_0) - i = state[0] = (state[0] * 1103515245 + 12345) & 0x7fffffff; - else { + if (rand_type == TYPE_0) + i = state[0] = (state[0] * 1103515245 + 12345) & 0x7fffffff; + else { #ifdef RANDOM_SUPERDEBUG - printf("random: *fptr = %ld; *rptr = %ld fptr = %p rptr = %p end_ptr = %p\n", *fptr, *rptr, fptr, rptr, end_ptr); + printf( + "random: *fptr = %ld; *rptr = %ld fptr = %p rptr = %p end_ptr = %p\n", + *fptr, *rptr, fptr, rptr, end_ptr); #endif - *fptr += *rptr; - i = (*fptr >> 1) & 0x7fffffff; /* chucking least random bit */ - if (++fptr >= end_ptr) { - fptr = state; - ++rptr; - } else if (++rptr >= end_ptr) - rptr = state; - } + *fptr += *rptr; + i = (*fptr >> 1) & 0x7fffffff; /* chucking least random bit */ + if (++fptr >= end_ptr) { + fptr = state; + ++rptr; + } else if (++rptr >= end_ptr) + rptr = state; + } #ifdef RANDOM_SUPERDEBUG printf("random: returning %ld\n", i); for (int j = 0; j < rand_deg; j++) { printf("srandom %d %ld\n", j, state[j]); } #endif - return(i); + return (i); } diff --git a/src/test/unit/utils/RandomTest.cpp b/src/test/unit/utils/RandomTest.cpp index 4ab1fd6370..ab589064b0 100644 --- a/src/test/unit/utils/RandomTest.cpp +++ b/src/test/unit/utils/RandomTest.cpp @@ -24,16 +24,15 @@ * @file */ - -#include +#include +#include #include +#include #include #include -#include +#include #include #include -#include -#include using namespace nupic; @@ -42,218 +41,181 @@ using namespace nupic; // Expected values with seed of 148 // Comparing against expected values ensures the same result // on all platforms. 
-UInt32 expected[1000] = -{ -33267067, 1308471064, 567525506, 744151466, 1514731226, -320263095, 850223971, 272111712, 1447892216, 1051935310, -767613522, 1966421570, 740590150, 238534644, 439141684, -1006390399, 683226375, 105163871, 1148420320, 1897583466, -1364189210, 2056512848, 2012249758, 724656770, 862951273, -739791114, 2018040533, 733864322, 181139176, 1764207967, -1203036433, 214406243, 925195384, 1770561940, 958557709, -292442962, 2090825035, 1808781680, 564554675, 1391233604, -713233342, 1332168197, 1210171526, 1453823492, 1570702841, -1649313210, 312730244, 106445569, 1754477081, 1461150564, -2004029035, 971182643, 1370179764, 1868795145, 1695839414, -85647389, 461102611, 1566396299, 819511712, 642241788, -1183120618, 2022548145, 856648031, 2108316002, 1645626437, -1815205741, 253275317, 1588967825, 1476503773, 817829992, -832717781, 42253468, 2514541, 2042889307, 1496076960, -1573217382, 1544718869, 1808807204, 1679662951, 1151712303, -1122474120, 1536208338, 2122894946, 345170236, 1257519835, -1671250712, 430817626, 1718622447, 1090163363, 1250329338, -213380587, 125800334, 1125393835, 1070028618, 86632688, -623536625, 737750711, 339908005, 65020802, 66770837, -1157737997, 897738583, 109024305, 1160252538, 793144242, -1605101265, 585986273, 190379463, 1266424822, 118165576, -1342091766, 241415294, 1654373915, 1317503065, 586585531, -764410102, 841270129, 1017403157, 335548901, 1931433493, -120248847, 548929488, 2057233827, 1245642682, 1618958107, -2143866515, 1869179307, 209225170, 336290873, 1934200109, -275996007, 1494028870, 684455044, 385020312, 506797761, -1477599286, 1990121578, 1092784034, 1667978750, 1109062752, -1210949610, 862586868, 1350478046, 717839877, 32606285, -1937063577, 1482249980, 873876415, 806983086, 1817798881, -657826260, 927231933, 219244722, 567576439, 25390968, -1838202829, 563959306, 1894570275, 2047427999, 900250179, -1681286737, 175940359, 246795402, 218258133, 560960671, -753593163, 1695857420, 403598601, 1846377197, 1216352522, -1512661353, 909843159, 2078939390, 715655752, 1627683037, -2111545676, 505235681, 962449369, 837938443, 1312218768, -632764602, 1495764703, 91967053, 852009324, 2063341142, -117358021, 542728505, 479816800, 2011928297, 442672857, -1380066980, 1545731386, 618613216, 1626862382, 1763989519, -1179573887, 232971897, 1312363291, 1583172489, 2079349094, -381232165, 948350194, 841708605, 312687908, 1664005946, -321907994, 276749936, 21757980, 1284357363, 1114688379, -1333976748, 1917121966, 462969434, 1425943801, 621647642, -378826928, 1543301823, 1164376148, 858643728, 1407746472, -1607049005, 91227060, 805994210, 78178573, 1718089442, -422500081, 1257752460, 1951061339, 1734863373, 693441301, -1882926785, 2116095538, 1641791496, 577151743, 281299798, -1158313794, 899059737, 558049734, 1180071774, 35933453, -1672738113, 366564874, 1953055419, 2135707547, 1792508676, -427219413, 367050827, 1188326851, 1591595561, 1225694556, -448589675, 1051160918, 1316921616, 1254583885, 1129339491, -887527411, 1677083966, 239608304, 691105102, 1264463691, -933049605, 426548240, 1233075582, 427357453, 1003699983, -1514375380, 1585671248, 1902759720, 2072425115, 618259374, -1938693173, 1597679580, 984824249, 1744264944, 1585903480, -629849277, 24000710, 1952954307, 1818176128, 1615596271, -1031165215, 119282155, 519273542, 200603184, 1373866040, -1648613033, 1088130595, 903466358, 1888221337, 1779235697, -20446402, 673787295, 58300289, 1253521984, 1101144748, -1062000272, 620413716, 539332348, 817276345, 545355183, -1157591723, 608485870, 
2143034764, 2142415972, 205267167, -1581454596, 624781601, 229267877, 1386925255, 295474081, -1844864148, 270606823, 414756236, 216654042, 471210007, -1788622276, 1865267076, 1559340602, 544604986, 1606004765, -1191092651, 565051388, 132308412, 1249392941, 1818573372, -1233453161, 163909565, 291503441, 1772785509, 981185910, -836858624, 782893584, 1589671781, 832409740, 777825908, -1794938948, 266380688, 1402607509, 2024206825, 1653305944, -1698081590, 1721587325, 1923912767, 2112837826, 1938241368, -247639126, 1753976454, 1656024796, 1806979728, 151097793, -1114545913, 850588731, 716149181, 1246854326, 2099981672, -387238906, 332823839, 116407590, 678742347, 2105609348, -1097593500, 1515600971, 741019285, 539781633, 200527064, -1518845193, 187236933, 466907752, 773969055, 63960110, -2120213696, 324566997, 1785547436, 1896642815, 289921176, -1576305156, 2144281941, 2043897630, 1084846304, 1803778021, -47511775, 51908569, 506883105, 763660957, 1298762895, -459381129, 1150899863, 1631586734, 575788719, 1829642210, -1589712435, 1673382220, 1197759533, 183248072, 65680205, -1398286597, 1702093265, 252917139, 1865194350, 328578672, -316877249, 1837924398, 653145670, 2102424685, 1587083566, -943066846, 1531246193, 1583881859, 839480828, 468608849, -1240176233, 886992604, 520517419, 1747059338, 1650653561, -1819280314, 58956819, 654069776, 1303383401, 634745539, -336228338, 745612188, 160644111, 1533987871, 928860260, -226324316, 784790821, 483469877, 479241455, 502501523, -812048550, 796118705, 192942273, 1465194220, 751059742, -1780025839, 260777418, 134822288, 1216424051, 1100258246, -603431137, 309116636, 1987250850, 1123948556, 2056175974, -1490420763, 795745223, 2115132793, 2144490539, 2099128624, -602394684, 333235229, 697257164, 763038795, 1867223101, -1626117424, 989363112, 504530274, 2109587301, 1468604567, -1007031797, 774152203, 117239624, 1199974070, 91862775, -868299367, 832516262, 352640193, 1003121655, 2048940313, -1452898440, 1606552792, 210573301, 1292665642, 583017701, -119265627, 635602758, 1378762924, 86914772, 632609649, -1330407900, 689309457, 965844879, 2027665064, 1452348252, -685584332, 1506298840, 294227716, 1190114606, 1468402493, -1762832284, 49662755, 95071049, 1880071908, 1249636825, -186933824, 600887627, 2082153087, 539574018, 1604009282, -1983609752, 1992472458, 1063078427, 46699405, 1137654452, -1646096128, 165965032, 1773257210, 877375404, 252879805, -258383212, 60299656, 942189262, 1224228091, 2087964720, -247053866, 1909812423, 1446779912, 541281583, 952443381, -767698757, 156630219, 1002106136, 862769806, 2036702127, -104259313, 1049703631, 490106107, 38928753, 1589277649, -2094115389, 2022538505, 1434266459, 1009710168, 2069237911, -424437263, 508322648, 87719295, 50210826, 1385698052, -340599100, 308594038, 1445997708, 1282788362, 1532822129, -1386478780, 1529842229, 1295150904, 685775044, 2071123812, -100110637, 1453473802, 80270383, 1102216773, 168759960, -2116972510, 1206476086, 1218463591, 459594969, 1245404839, -660257592, 406226711, 1120459697, 2094524051, 1415936879, -1042213960, 371477667, 1924259528, 1129933255, 421688493, -1162473932, 1470532356, 730282531, 460987993, 605837070, -115621012, 1847466773, 2135679299, 1410771916, 385758170, -2059319463, 1510882553, 1839231972, 2139589846, 465615678, -2007991932, 2109078709, 1672091764, 1078971876, 421190030, -770012956, 1739229468, 827416741, 1890472653, 1686269872, -95869973, 785202965, 2057747539, 2020129501, 1915136220, -331952384, 1035119785, 1238184928, 1062234915, 1496107778, -1844021999, 
1177855927, 1196090904, 1832217650, 441144195, -1581849074, 1744053466, 1952026748, 1273597398, 1736159664, -270158778, 1134105682, 1697754725, 1942250542, 65593910, -2118944756, 564779850, 1804823379, 798877849, 307768855, -1343609603, 894747822, 1092971820, 1253873494, 767393675, -860624393, 1585825878, 1802513461, 2098809321, 500577145, -1151137591, 1795347672, 1678433072, 199744847, 1480081675, -2119577267, 1781593921, 1076651493, 1924120367, 907707671, -665327509, 46795497, 2041813354, 215598587, 1989046039, -2107407264, 187059695, 406342242, 1764746995, 985937544, -714111097, 960872950, 1880685367, 1807082918, 67262796, -500595394, 520223663, 1653088674, 155625207, 471549336, -6182171, 1306762799, 119413361, 1684615243, 1506507646, -1599495036, 1656708862, 1140617920, 528662881, 1433345581, -2048325591, 1193990390, 1480141078, 1942655297, 1409588977, -1321703470, 1902578914, 1596648672, 1728045712, 1519842261, -435102569, 294673161, 333231564, 168304288, 2101756079, -400494360, 668899682, 474496094, 2053583035, 824524890, -946045431, 2059765206, 2131287689, 1065458792, 1596896802, -1490311687, 517470180, 1106122016, 483445959, 1046133061, -391983950, 384287903, 92639803, 1872125028, 179459552, -1502228781, 1046344850, 2082038466, 951393805, 626906914, -1454397080, 1386496374, 921580076, 1787628644, 1554800662, -875852507, 40639356, 76216697, 1350348602, 2094222391, -900741587, 148910385, 2006503950, 884545628, 1214369177, -1455917104, 227373667, 1731839357, 414555472, 710819627, -630488770, 806539422, 1095107530, 723128573, 531180803, -1274567082, 77873706, 1577525653, 1209121901, 1029267512, -56948920, 516035333, 268280238, 978528996, 156180329, -1823080901, 1854381503, 196819685, 1899297598, 1057246457, -143558429, 652555537, 1206156842, 2578731, 1537101165, -273042371, 1458495835, 1764474832, 2004881728, 1873051307, -327810811, 487886850, 532107082, 1422918341, 1211015424, -1063287885, 550001776, 1288889130, 493329890, 1759123677, -170672994, 550278810, 127675362, 438953233, 1528807806, -283855691, 114550486, 1235705662, 480675376, 2013848084, -145468471, 624233805, 518919973, 1351625314, 626812536, -2056021138, 1624667685, 2085308371, 1673012322, 1482065766, -1810876031, 2000823134, 1969952616, 195499465, 1276257827, -1033484392, 1258787350, 1826259603, 174889875, 1752117240, -1437899632, 345562869, 154912403, 1565574994, 784516102, -1683720209, 1849430685, 899066588, 771942223, 182622414, -765431024, 917410695, 806856219, 1284350997, 121552361, -1433668756, 1192888487, 1746220046, 1371493479, 718417162, -1080802164, 1034885862, 571756648, 903271133, 1230385327, -1848014475, 1936755525, 341689029, 1526790431, 2111645400, -2093806270, 817206415, 309724622, 101235025, 235297762, -1094240724, 1784955234, 2084728447, 1993307313, 409413810, -119867213, 611254689, 1326824505, 926723433, 1895605687, -1448376866, 212908541, 941010526, 1047113264, 1584402020, -1659427688, 2127915429, 471804235, 83700688, 883702914, -1702189562, 1931715164, 672974791, 2043878592, 1311021947, -637136544, 1990201214, 2128228362, 946861166, 2091436239, -216042476, 2041101890, 1728907825, 153287276, 1886925555, -2138321635, 273154489, 350696597, 1317662492, 1199877922, -98818636, 618555710, 1412786463, 1039829162, 1665668975, -849704836, 551773203, 1646100756, 1321509071, 635473891, -382320022, 876214985, 419705407, 1055294813, 772609929, -1730727354, 1692431357, 615327495, 1711472069, 491808875, -559280086, 1927514545, 385427118, 140704264, 2080801821, -124869025, 131542251, 206472663, 475565622, 1449204744, 
-1406350585, 574384258, 2067760454, 671653401, 1614213421, -1585945781, 1521358237, 18502976, 1084562889, 695383660, -653976867, 1466882911, 1571598645, 1073682275, 374694077, -196724927, 656925981, 2067125434, 812052422, 220914402, -411450662, 1371332509, 945300, 796877780, 1512036773, -2081747121, 921746805, 1643579024, 140736136, 1397312428, -945300120, 1547086722, 1971696686, 865576927, 71256475, -1438426459, 304039060, 1592614712, 1456929435, 1388601950, -140514724, 2110906303, 708001213, 1712113369, 1037104930, -1082695290, 1908838296, 1694030911, 1002337077, 573407071, -1914945314, 1413787739, 1944739580, 1915890614, 63181871, -1309292705, 1850154087, 984928676, 805388081, 1990890224, -234757456, 1750688202, 1390493298, 58970495, 468781481, -1461749773, 1497396954, 772820541, 906880837, 806842742, -13938843, 1047395561, 770265397, 721940057, 612025282, -1807370327, 1804635347, 373379931, 1353917590, 659488776, -946787002, 1121379256, 2073276515, 744042934, 889786222, -2136458386, 2053335639, 592456662, 973903415, 711240072 -}; - - -TEST(RandomTest, Seeding) -{ +UInt32 expected[1000] = { + 33267067, 1308471064, 567525506, 744151466, 1514731226, 320263095, + 850223971, 272111712, 1447892216, 1051935310, 767613522, 1966421570, + 740590150, 238534644, 439141684, 1006390399, 683226375, 105163871, + 1148420320, 1897583466, 1364189210, 2056512848, 2012249758, 724656770, + 862951273, 739791114, 2018040533, 733864322, 181139176, 1764207967, + 1203036433, 214406243, 925195384, 1770561940, 958557709, 292442962, + 2090825035, 1808781680, 564554675, 1391233604, 713233342, 1332168197, + 1210171526, 1453823492, 1570702841, 1649313210, 312730244, 106445569, + 1754477081, 1461150564, 2004029035, 971182643, 1370179764, 1868795145, + 1695839414, 85647389, 461102611, 1566396299, 819511712, 642241788, + 1183120618, 2022548145, 856648031, 2108316002, 1645626437, 1815205741, + 253275317, 1588967825, 1476503773, 817829992, 832717781, 42253468, + 2514541, 2042889307, 1496076960, 1573217382, 1544718869, 1808807204, + 1679662951, 1151712303, 1122474120, 1536208338, 2122894946, 345170236, + 1257519835, 1671250712, 430817626, 1718622447, 1090163363, 1250329338, + 213380587, 125800334, 1125393835, 1070028618, 86632688, 623536625, + 737750711, 339908005, 65020802, 66770837, 1157737997, 897738583, + 109024305, 1160252538, 793144242, 1605101265, 585986273, 190379463, + 1266424822, 118165576, 1342091766, 241415294, 1654373915, 1317503065, + 586585531, 764410102, 841270129, 1017403157, 335548901, 1931433493, + 120248847, 548929488, 2057233827, 1245642682, 1618958107, 2143866515, + 1869179307, 209225170, 336290873, 1934200109, 275996007, 1494028870, + 684455044, 385020312, 506797761, 1477599286, 1990121578, 1092784034, + 1667978750, 1109062752, 1210949610, 862586868, 1350478046, 717839877, + 32606285, 1937063577, 1482249980, 873876415, 806983086, 1817798881, + 657826260, 927231933, 219244722, 567576439, 25390968, 1838202829, + 563959306, 1894570275, 2047427999, 900250179, 1681286737, 175940359, + 246795402, 218258133, 560960671, 753593163, 1695857420, 403598601, + 1846377197, 1216352522, 1512661353, 909843159, 2078939390, 715655752, + 1627683037, 2111545676, 505235681, 962449369, 837938443, 1312218768, + 632764602, 1495764703, 91967053, 852009324, 2063341142, 117358021, + 542728505, 479816800, 2011928297, 442672857, 1380066980, 1545731386, + 618613216, 1626862382, 1763989519, 1179573887, 232971897, 1312363291, + 1583172489, 2079349094, 381232165, 948350194, 841708605, 312687908, + 1664005946, 321907994, 276749936, 
21757980, 1284357363, 1114688379, + 1333976748, 1917121966, 462969434, 1425943801, 621647642, 378826928, + 1543301823, 1164376148, 858643728, 1407746472, 1607049005, 91227060, + 805994210, 78178573, 1718089442, 422500081, 1257752460, 1951061339, + 1734863373, 693441301, 1882926785, 2116095538, 1641791496, 577151743, + 281299798, 1158313794, 899059737, 558049734, 1180071774, 35933453, + 1672738113, 366564874, 1953055419, 2135707547, 1792508676, 427219413, + 367050827, 1188326851, 1591595561, 1225694556, 448589675, 1051160918, + 1316921616, 1254583885, 1129339491, 887527411, 1677083966, 239608304, + 691105102, 1264463691, 933049605, 426548240, 1233075582, 427357453, + 1003699983, 1514375380, 1585671248, 1902759720, 2072425115, 618259374, + 1938693173, 1597679580, 984824249, 1744264944, 1585903480, 629849277, + 24000710, 1952954307, 1818176128, 1615596271, 1031165215, 119282155, + 519273542, 200603184, 1373866040, 1648613033, 1088130595, 903466358, + 1888221337, 1779235697, 20446402, 673787295, 58300289, 1253521984, + 1101144748, 1062000272, 620413716, 539332348, 817276345, 545355183, + 1157591723, 608485870, 2143034764, 2142415972, 205267167, 1581454596, + 624781601, 229267877, 1386925255, 295474081, 1844864148, 270606823, + 414756236, 216654042, 471210007, 1788622276, 1865267076, 1559340602, + 544604986, 1606004765, 1191092651, 565051388, 132308412, 1249392941, + 1818573372, 1233453161, 163909565, 291503441, 1772785509, 981185910, + 836858624, 782893584, 1589671781, 832409740, 777825908, 1794938948, + 266380688, 1402607509, 2024206825, 1653305944, 1698081590, 1721587325, + 1923912767, 2112837826, 1938241368, 247639126, 1753976454, 1656024796, + 1806979728, 151097793, 1114545913, 850588731, 716149181, 1246854326, + 2099981672, 387238906, 332823839, 116407590, 678742347, 2105609348, + 1097593500, 1515600971, 741019285, 539781633, 200527064, 1518845193, + 187236933, 466907752, 773969055, 63960110, 2120213696, 324566997, + 1785547436, 1896642815, 289921176, 1576305156, 2144281941, 2043897630, + 1084846304, 1803778021, 47511775, 51908569, 506883105, 763660957, + 1298762895, 459381129, 1150899863, 1631586734, 575788719, 1829642210, + 1589712435, 1673382220, 1197759533, 183248072, 65680205, 1398286597, + 1702093265, 252917139, 1865194350, 328578672, 316877249, 1837924398, + 653145670, 2102424685, 1587083566, 943066846, 1531246193, 1583881859, + 839480828, 468608849, 1240176233, 886992604, 520517419, 1747059338, + 1650653561, 1819280314, 58956819, 654069776, 1303383401, 634745539, + 336228338, 745612188, 160644111, 1533987871, 928860260, 226324316, + 784790821, 483469877, 479241455, 502501523, 812048550, 796118705, + 192942273, 1465194220, 751059742, 1780025839, 260777418, 134822288, + 1216424051, 1100258246, 603431137, 309116636, 1987250850, 1123948556, + 2056175974, 1490420763, 795745223, 2115132793, 2144490539, 2099128624, + 602394684, 333235229, 697257164, 763038795, 1867223101, 1626117424, + 989363112, 504530274, 2109587301, 1468604567, 1007031797, 774152203, + 117239624, 1199974070, 91862775, 868299367, 832516262, 352640193, + 1003121655, 2048940313, 1452898440, 1606552792, 210573301, 1292665642, + 583017701, 119265627, 635602758, 1378762924, 86914772, 632609649, + 1330407900, 689309457, 965844879, 2027665064, 1452348252, 685584332, + 1506298840, 294227716, 1190114606, 1468402493, 1762832284, 49662755, + 95071049, 1880071908, 1249636825, 186933824, 600887627, 2082153087, + 539574018, 1604009282, 1983609752, 1992472458, 1063078427, 46699405, + 1137654452, 1646096128, 165965032, 1773257210, 
877375404, 252879805, + 258383212, 60299656, 942189262, 1224228091, 2087964720, 247053866, + 1909812423, 1446779912, 541281583, 952443381, 767698757, 156630219, + 1002106136, 862769806, 2036702127, 104259313, 1049703631, 490106107, + 38928753, 1589277649, 2094115389, 2022538505, 1434266459, 1009710168, + 2069237911, 424437263, 508322648, 87719295, 50210826, 1385698052, + 340599100, 308594038, 1445997708, 1282788362, 1532822129, 1386478780, + 1529842229, 1295150904, 685775044, 2071123812, 100110637, 1453473802, + 80270383, 1102216773, 168759960, 2116972510, 1206476086, 1218463591, + 459594969, 1245404839, 660257592, 406226711, 1120459697, 2094524051, + 1415936879, 1042213960, 371477667, 1924259528, 1129933255, 421688493, + 1162473932, 1470532356, 730282531, 460987993, 605837070, 115621012, + 1847466773, 2135679299, 1410771916, 385758170, 2059319463, 1510882553, + 1839231972, 2139589846, 465615678, 2007991932, 2109078709, 1672091764, + 1078971876, 421190030, 770012956, 1739229468, 827416741, 1890472653, + 1686269872, 95869973, 785202965, 2057747539, 2020129501, 1915136220, + 331952384, 1035119785, 1238184928, 1062234915, 1496107778, 1844021999, + 1177855927, 1196090904, 1832217650, 441144195, 1581849074, 1744053466, + 1952026748, 1273597398, 1736159664, 270158778, 1134105682, 1697754725, + 1942250542, 65593910, 2118944756, 564779850, 1804823379, 798877849, + 307768855, 1343609603, 894747822, 1092971820, 1253873494, 767393675, + 860624393, 1585825878, 1802513461, 2098809321, 500577145, 1151137591, + 1795347672, 1678433072, 199744847, 1480081675, 2119577267, 1781593921, + 1076651493, 1924120367, 907707671, 665327509, 46795497, 2041813354, + 215598587, 1989046039, 2107407264, 187059695, 406342242, 1764746995, + 985937544, 714111097, 960872950, 1880685367, 1807082918, 67262796, + 500595394, 520223663, 1653088674, 155625207, 471549336, 6182171, + 1306762799, 119413361, 1684615243, 1506507646, 1599495036, 1656708862, + 1140617920, 528662881, 1433345581, 2048325591, 1193990390, 1480141078, + 1942655297, 1409588977, 1321703470, 1902578914, 1596648672, 1728045712, + 1519842261, 435102569, 294673161, 333231564, 168304288, 2101756079, + 400494360, 668899682, 474496094, 2053583035, 824524890, 946045431, + 2059765206, 2131287689, 1065458792, 1596896802, 1490311687, 517470180, + 1106122016, 483445959, 1046133061, 391983950, 384287903, 92639803, + 1872125028, 179459552, 1502228781, 1046344850, 2082038466, 951393805, + 626906914, 1454397080, 1386496374, 921580076, 1787628644, 1554800662, + 875852507, 40639356, 76216697, 1350348602, 2094222391, 900741587, + 148910385, 2006503950, 884545628, 1214369177, 1455917104, 227373667, + 1731839357, 414555472, 710819627, 630488770, 806539422, 1095107530, + 723128573, 531180803, 1274567082, 77873706, 1577525653, 1209121901, + 1029267512, 56948920, 516035333, 268280238, 978528996, 156180329, + 1823080901, 1854381503, 196819685, 1899297598, 1057246457, 143558429, + 652555537, 1206156842, 2578731, 1537101165, 273042371, 1458495835, + 1764474832, 2004881728, 1873051307, 327810811, 487886850, 532107082, + 1422918341, 1211015424, 1063287885, 550001776, 1288889130, 493329890, + 1759123677, 170672994, 550278810, 127675362, 438953233, 1528807806, + 283855691, 114550486, 1235705662, 480675376, 2013848084, 145468471, + 624233805, 518919973, 1351625314, 626812536, 2056021138, 1624667685, + 2085308371, 1673012322, 1482065766, 1810876031, 2000823134, 1969952616, + 195499465, 1276257827, 1033484392, 1258787350, 1826259603, 174889875, + 1752117240, 1437899632, 345562869, 154912403, 
1565574994, 784516102, + 1683720209, 1849430685, 899066588, 771942223, 182622414, 765431024, + 917410695, 806856219, 1284350997, 121552361, 1433668756, 1192888487, + 1746220046, 1371493479, 718417162, 1080802164, 1034885862, 571756648, + 903271133, 1230385327, 1848014475, 1936755525, 341689029, 1526790431, + 2111645400, 2093806270, 817206415, 309724622, 101235025, 235297762, + 1094240724, 1784955234, 2084728447, 1993307313, 409413810, 119867213, + 611254689, 1326824505, 926723433, 1895605687, 1448376866, 212908541, + 941010526, 1047113264, 1584402020, 1659427688, 2127915429, 471804235, + 83700688, 883702914, 1702189562, 1931715164, 672974791, 2043878592, + 1311021947, 637136544, 1990201214, 2128228362, 946861166, 2091436239, + 216042476, 2041101890, 1728907825, 153287276, 1886925555, 2138321635, + 273154489, 350696597, 1317662492, 1199877922, 98818636, 618555710, + 1412786463, 1039829162, 1665668975, 849704836, 551773203, 1646100756, + 1321509071, 635473891, 382320022, 876214985, 419705407, 1055294813, + 772609929, 1730727354, 1692431357, 615327495, 1711472069, 491808875, + 559280086, 1927514545, 385427118, 140704264, 2080801821, 124869025, + 131542251, 206472663, 475565622, 1449204744, 1406350585, 574384258, + 2067760454, 671653401, 1614213421, 1585945781, 1521358237, 18502976, + 1084562889, 695383660, 653976867, 1466882911, 1571598645, 1073682275, + 374694077, 196724927, 656925981, 2067125434, 812052422, 220914402, + 411450662, 1371332509, 945300, 796877780, 1512036773, 2081747121, + 921746805, 1643579024, 140736136, 1397312428, 945300120, 1547086722, + 1971696686, 865576927, 71256475, 1438426459, 304039060, 1592614712, + 1456929435, 1388601950, 140514724, 2110906303, 708001213, 1712113369, + 1037104930, 1082695290, 1908838296, 1694030911, 1002337077, 573407071, + 1914945314, 1413787739, 1944739580, 1915890614, 63181871, 1309292705, + 1850154087, 984928676, 805388081, 1990890224, 234757456, 1750688202, + 1390493298, 58970495, 468781481, 1461749773, 1497396954, 772820541, + 906880837, 806842742, 13938843, 1047395561, 770265397, 721940057, + 612025282, 1807370327, 1804635347, 373379931, 1353917590, 659488776, + 946787002, 1121379256, 2073276515, 744042934, 889786222, 2136458386, + 2053335639, 592456662, 973903415, 711240072}; + +TEST(RandomTest, Seeding) { // make sure the global instance is seeded from time() - // in the test situation, we can be sure we were seeded less than 100000 seconds ago - // make sure random number system is initialized by creating a random - // object. Use the object to make sure the compiler doesn't complain about - // an unused variable + // in the test situation, we can be sure we were seeded less than 100000 + // seconds ago make sure random number system is initialized by creating a + // random object. Use the object to make sure the compiler doesn't complain + // about an unused variable Random r; UInt64 x = r.getUInt64(); ASSERT_TRUE(x != 0); @@ -265,8 +227,7 @@ TEST(RandomTest, Seeding) } } -TEST(RandomTest, CopyConstructor) -{ +TEST(RandomTest, CopyConstructor) { // test copy constructor. 
Random r1(289436); int i; @@ -275,8 +236,7 @@ TEST(RandomTest, CopyConstructor) Random r2(r1); UInt32 v1, v2; - for (i = 0; i < 100; i++) - { + for (i = 0; i < 100; i++) { v1 = r1.getUInt32(); v2 = r2.getUInt32(); if (v1 != v2) @@ -285,8 +245,7 @@ TEST(RandomTest, CopyConstructor) ASSERT_TRUE(v1 == v2) << "copy constructor"; } -TEST(RandomTest, OperatorEquals) -{ +TEST(RandomTest, OperatorEquals) { // test operator= Random r1(289436); int i; @@ -298,8 +257,7 @@ TEST(RandomTest, OperatorEquals) r2 = r1; UInt32 v1, v2; - for (i = 0; i < 100; i++) - { + for (i = 0; i < 100; i++) { v1 = r1.getUInt32(); v2 = r2.getUInt32(); if (v1 != v2) @@ -308,8 +266,7 @@ TEST(RandomTest, OperatorEquals) ASSERT_TRUE(v1 == v2) << "operator="; } -TEST(RandomTest, SerializationDeserialization) -{ +TEST(RandomTest, SerializationDeserialization) { // test serialization/deserialization Random r1(862973); int i; @@ -324,10 +281,15 @@ TEST(RandomTest, SerializationDeserialization) std::string x(ostream.str(), ostream.pcount()); NTA_INFO << "random serialize string: '" << x << "'"; // Serialization should be deterministic and platform independent - std::string expectedString = "random-v1 862973 RandomImpl 2 31 4241808047 927171440 115246761 3188485113 2358188524 3270869522 282383075 1082613868 441984109 995051899 2794036324 2239422562 898636415 2372250535 3369849014 2122900843 1895341779 2450525880 394177447 3199534303 3887683026 656347524 48907782 1135809043 334191338 2900562231 197628021 1265227140 569581351 466443697 843098206 7 10 endrandom-v1"; + std::string expectedString = + "random-v1 862973 RandomImpl 2 31 4241808047 927171440 115246761 " + "3188485113 2358188524 3270869522 282383075 1082613868 441984109 " + "995051899 2794036324 2239422562 898636415 2372250535 3369849014 " + "2122900843 1895341779 2450525880 394177447 3199534303 3887683026 " + "656347524 48907782 1135809043 334191338 2900562231 197628021 1265227140 " + "569581351 466443697 843098206 7 10 endrandom-v1"; ASSERT_EQ(expectedString, x); - // deserialize into r2 std::string s(ostream.str(), ostream.pcount()); std::stringstream ss(s); @@ -338,8 +300,7 @@ TEST(RandomTest, SerializationDeserialization) ASSERT_EQ(r1.getSeed(), r2.getSeed()); UInt32 v1, v2; - for (i = 0; i < 100; i++) - { + for (i = 0; i < 100; i++) { v1 = r1.getUInt32(); v2 = r2.getUInt32(); NTA_CHECK(v1 == v2); @@ -347,8 +308,7 @@ TEST(RandomTest, SerializationDeserialization) ASSERT_EQ(v1, v2) << "serialization"; } -TEST(RandomTest, ReturnInCorrectRange) -{ +TEST(RandomTest, ReturnInCorrectRange) { // make sure that we are returning values in the correct range // @todo perform statistical tests Random r; @@ -357,8 +317,7 @@ TEST(RandomTest, ReturnInCorrectRange) int i; UInt32 max32 = 10000000; UInt64 max64 = (UInt64)max32 * (UInt64)max32; - for (i = 0; i < 200; i++) - { + for (i = 0; i < 200; i++) { UInt32 i32 = r.getUInt32(max32); ASSERT_TRUE(i32 < max32) << "UInt32"; UInt64 i64 = r.getUInt64(max64); @@ -368,30 +327,28 @@ TEST(RandomTest, ReturnInCorrectRange) } } -TEST(RandomTest, getUInt64) -{ +TEST(RandomTest, getUInt64) { // tests for getUInt64 Random r1(1); - ASSERT_EQ(3723745761376425000, r1.getUInt64()) - << "check getUInt64, seed 1, first call"; + ASSERT_EQ(3723745761376425000, r1.getUInt64()) + << "check getUInt64, seed 1, first call"; ASSERT_EQ(7464235991977222558, r1.getUInt64()) - << "check getUInt64, seed 1, second call"; + << "check getUInt64, seed 1, second call"; Random r2(2); ASSERT_EQ(7543924162171776743, r2.getUInt64()) - << "check getUInt64, seed 2, first call"; 
+ << "check getUInt64, seed 2, first call"; ASSERT_EQ(1206857364816002550, r2.getUInt64()) - << "check getUInt64, seed 2, second call"; + << "check getUInt64, seed 2, second call"; Random r3(7464235991977222558); ASSERT_EQ(3609339244249306794, r3.getUInt64()) - << "check getUInt64, big seed, first call"; + << "check getUInt64, big seed, first call"; ASSERT_EQ(4084830275585779078, r3.getUInt64()) - << "check getUInt64, big seed, second call"; + << "check getUInt64, big seed, second call"; } -TEST(RandomTest, getReal64) -{ +TEST(RandomTest, getReal64) { // tests for getReal64 Random r1(1); ASSERT_FLOAT_EQ(0.40250281741114691, r1.getReal64()); @@ -406,8 +363,7 @@ TEST(RandomTest, getReal64) ASSERT_FLOAT_EQ(0.23239565201722456, r3.getReal64()); } -TEST(RandomTest, Sampling) -{ +TEST(RandomTest, Sampling) { // tests for sampling UInt32 population[] = {1, 2, 3, 4}; @@ -435,27 +391,22 @@ TEST(RandomTest, Sampling) // nChoices > nPopulation UInt32 choices[5]; bool caught = false; - try - { + try { r.sample(population, 4, choices, 5); - } - catch (LoggingException& exc) - { + } catch (LoggingException &exc) { caught = true; } - ASSERT_TRUE(caught) - << "checking for exception from population too small"; + ASSERT_TRUE(caught) << "checking for exception from population too small"; } } -TEST(RandomTest, Shuffling) -{ +TEST(RandomTest, Shuffling) { // tests for shuffling Random r(42); UInt32 arr[] = {1, 2, 3, 4}; - UInt32* start = arr; - UInt32* end = start + 4; + UInt32 *start = arr; + UInt32 *end = start + 4; r.shuffle(start, end); ASSERT_EQ(1, arr[0]) << "check element 0"; @@ -464,13 +415,12 @@ TEST(RandomTest, Shuffling) ASSERT_EQ(2, arr[3]) << "check element 3"; } -TEST(RandomTest, CapnpSerialization) -{ +TEST(RandomTest, CapnpSerialization) { // tests for Cap'n Proto serialization Random r1, r2; UInt32 v1, v2; - const char* outputPath = "RandomTest1.temp"; + const char *outputPath = "RandomTest1.temp"; { std::ofstream out(outputPath, std::ios::binary); diff --git a/src/test/unit/utils/WatcherTest.cpp b/src/test/unit/utils/WatcherTest.cpp index 3ad8c6ece8..b8408b0118 100644 --- a/src/test/unit/utils/WatcherTest.cpp +++ b/src/test/unit/utils/WatcherTest.cpp @@ -24,9 +24,9 @@ * Implementation of Watcher test */ -#include -#include #include +#include +#include #include #include @@ -43,10 +43,8 @@ using namespace nupic; - -TEST(WatcherTest, SampleNetwork) -{ - //generate sample network +TEST(WatcherTest, SampleNetwork) { + // generate sample network Network n; n.addRegion("level1", "TestNode", ""); n.addRegion("level2", "TestNode", ""); @@ -59,76 +57,76 @@ TEST(WatcherTest, SampleNetwork) n.link("level2", "level3", "TestFanIn2", ""); n.initialize(); - //erase any previous contents of testfile + // erase any previous contents of testfile OFStream o("testfile"); o.close(); - - //test creation + + // test creation Watcher w("testfile"); - //test uint32Params + // test uint32Params unsigned int id1 = w.watchParam("level1", "uint32Param"); ASSERT_EQ(id1, (unsigned int)1); - //test uint64Params + // test uint64Params unsigned int id2 = w.watchParam("level1", "uint64Param"); ASSERT_EQ(id2, (unsigned int)2); - //test int32Params + // test int32Params w.watchParam("level1", "int32Param"); - //test int64Params + // test int64Params w.watchParam("level1", "int64Param"); - //test real32Params + // test real32Params w.watchParam("level1", "real32Param"); - //test real64Params + // test real64Params w.watchParam("level1", "real64Param"); - //test stringParams + // test stringParams w.watchParam("level1", 
"stringParam"); - //test unclonedParams + // test unclonedParams w.watchParam("level1", "unclonedParam", 0); w.watchParam("level1", "unclonedParam", 1); - - //test attachToNetwork() + + // test attachToNetwork() w.attachToNetwork(n); - //test two simultaneous Watchers on the same network with different files - Watcher* w2 = new Watcher("testfile2"); + // test two simultaneous Watchers on the same network with different files + Watcher *w2 = new Watcher("testfile2"); - //test int64ArrayParam + // test int64ArrayParam w2->watchParam("level1", "int64ArrayParam"); - //test real32ArrayParam + // test real32ArrayParam w2->watchParam("level1", "real32ArrayParam"); - //test output + // test output w2->watchOutput("level1", "bottomUpOut"); - //test int64ArrayParam, sparse = false + // test int64ArrayParam, sparse = false w2->watchParam("level1", "int64ArrayParam", -1, false); w2->attachToNetwork(n); - //set one of the uncloned parameters to 1 instead of 0 - //n.getRegions().getByName("level1")->getNodeAtIndex(1).setParameterUInt32("unclonedParam", (UInt32)1); - //n.run(3); - //see if Watcher notices change in parameter values after 3 iterations - n.getRegions().getByName("level1")->setParameterUInt64("uint64Param", (UInt64)66); + // set one of the uncloned parameters to 1 instead of 0 + // n.getRegions().getByName("level1")->getNodeAtIndex(1).setParameterUInt32("unclonedParam", + // (UInt32)1); n.run(3); see if Watcher notices change in parameter values + // after 3 iterations + n.getRegions().getByName("level1")->setParameterUInt64("uint64Param", + (UInt64)66); n.run(3); - //test flushFile() - this should produce output + // test flushFile() - this should produce output w.flushFile(); - //test closeFile() + // test closeFile() w.closeFile(); - //test to make sure data is flushed when Watcher is deleted + // test to make sure data is flushed when Watcher is deleted delete w2; } - -TEST(WatcherTest, FileTest1) -{ - //test file output + +TEST(WatcherTest, FileTest1) { + // test file output IFStream inStream("testfile"); std::string tempString; - if (inStream.is_open()) - { + if (inStream.is_open()) { getline(inStream, tempString); - ASSERT_EQ("Info: watchID, regionName, nodeType, nodeIndex, varName", tempString); + ASSERT_EQ("Info: watchID, regionName, nodeType, nodeIndex, varName", + tempString); getline(inStream, tempString); ASSERT_EQ("1, level1, TestNode, -1, uint32Param", tempString); getline(inStream, tempString); @@ -149,30 +147,24 @@ TEST(WatcherTest, FileTest1) ASSERT_EQ("9, level1, TestNode, 1, unclonedParam", tempString); getline(inStream, tempString); ASSERT_EQ("Data: watchID, iteration, paramValue", tempString); - + unsigned int i = 1; - while (! 
inStream.eof() ) - { + while (!inStream.eof()) { std::stringstream stream; std::string value; getline(inStream, tempString); - if (tempString.size() == 0) - { + if (tempString.size() == 0) { break; } - switch (tempString.at(0)) - { + switch (tempString.at(0)) { case '1': stream << "1, " << i << ", 33"; break; case '2': stream << "2, " << i; - if (i < 4) - { + if (i < 4) { stream << ", 66"; - } - else - { + } else { stream << ", 65"; } break; @@ -199,7 +191,7 @@ TEST(WatcherTest, FileTest1) i++; break; } - + value = stream.str(); ASSERT_EQ(value, tempString); } @@ -208,15 +200,14 @@ TEST(WatcherTest, FileTest1) Path::remove("testfile"); } - -TEST(WatcherTest, FileTest2) -{ + +TEST(WatcherTest, FileTest2) { IFStream inStream2("testfile2"); std::string tempString; - if (inStream2.is_open()) - { + if (inStream2.is_open()) { getline(inStream2, tempString); - ASSERT_EQ("Info: watchID, regionName, nodeType, nodeIndex, varName", tempString); + ASSERT_EQ("Info: watchID, regionName, nodeType, nodeIndex, varName", + tempString); getline(inStream2, tempString); ASSERT_EQ("1, level1, TestNode, -1, int64ArrayParam", tempString); getline(inStream2, tempString); @@ -227,19 +218,16 @@ TEST(WatcherTest, FileTest2) ASSERT_EQ("4, level1, TestNode, -1, int64ArrayParam", tempString); getline(inStream2, tempString); ASSERT_EQ("Data: watchID, iteration, paramValue", tempString); - + unsigned int i = 1; - while (! inStream2.eof() ) - { + while (!inStream2.eof()) { std::stringstream stream; std::string value; getline(inStream2, tempString); - if (tempString.size() == 0) - { + if (tempString.size() == 0) { break; } - switch (tempString.at(0)) - { + switch (tempString.at(0)) { case '1': stream << "1, " << i << ", 4 1 2 3"; break; @@ -248,18 +236,13 @@ TEST(WatcherTest, FileTest2) break; case '3': stream << "3, " << i << ", 64"; - if (i == 1) - { - for (unsigned int j = 3; j < 64; j+=2) - { + if (i == 1) { + for (unsigned int j = 3; j < 64; j += 2) { stream << " " << j; } - } - else - { + } else { stream << " 0"; - for (unsigned int j = 2; j < 64; j++) - { + for (unsigned int j = 2; j < 64; j++) { stream << " " << j; } } @@ -269,12 +252,12 @@ TEST(WatcherTest, FileTest2) i++; break; } - + value = stream.str(); ASSERT_EQ(value, tempString); } } inStream2.close(); - + Path::remove("testfile2"); }
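
Two illustrative C++ sketches follow; they are not part of the patch above and only restate, in standalone form, behavior the reformatted tests already exercise.

The GroupByTest cases check an n-way merge: groupBy(seq0, key0, seq1, key1, ...) yields, for every key produced by any input, the key followed by one [begin, end) iterator pair per sequence covering the elements whose mapped value equals that key (possibly an empty range). A minimal usage sketch, assuming the header that declares groupBy() lives at nupic/utils/GroupBy.hpp:

#include <iostream>
#include <tuple>
#include <vector>

#include <nupic/utils/GroupBy.hpp> // assumed header location for groupBy()

using namespace nupic;
using namespace std;

int main() {
  const vector<int> seq0 = {1, 2, 3};
  const vector<int> seq1 = {1, 3};
  const vector<int> seq2 = {2, 2, 3};
  auto identity = [](int x) { return x; };

  // Each yielded tuple: key, then one [begin, end) pair per input sequence.
  for (auto data : groupBy(seq0, identity, seq1, identity, seq2, identity)) {
    int key;
    vector<int>::const_iterator b0, e0, b1, e1, b2, e2;
    tie(key, b0, e0, b1, e1, b2, e2) = data;
    cout << "key " << key << ": " << (e0 - b0) << " + " << (e1 - b1) << " + "
         << (e2 - b2) << " matching elements" << endl;
  }
  return 0;
}

The seeding loop in mysrandom() (RandomPrivateOrig.c above) computes state[i] = (16807 * state[i-1]) % 2147483647 without overflowing 31 bits by using the decomposition 2147483647 = 127773 * 16807 + 2836 (Schrage's method). A standalone sketch of that arithmetic, spot-checked against plain 64-bit math:

#include <cassert>
#include <cstdlib>

// Computes (a * x) % m for a = 16807, m = 2^31 - 1, using only intermediate
// values that fit in a signed 32-bit long: write x = q*(x/q) + (x%q), where
// m = a*q + r.
long schrage16807(long x) {
  const long a = 16807, m = 2147483647L, q = 127773, r = 2836; // m == a*q + r
  std::ldiv_t v = std::ldiv(x, q); // v.quot = x / q, v.rem = x % q
  long t = a * v.rem - r * v.quot; // congruent to a*x (mod m), may be negative
  return t + (t < 0 ? m : 0);      // fold back into [0, m - 1]
}

int main() {
  // Check a few inputs against the straightforward 64-bit computation.
  for (long x : {1L, 12345L, 2147483646L}) {
    assert(schrage16807(x) == static_cast<long>((16807LL * x) % 2147483647LL));
  }
  return 0;
}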