diff --git a/src/nupic/engine/Network.cpp b/src/nupic/engine/Network.cpp index 1f90ca6741..e3df60eab4 100644 --- a/src/nupic/engine/Network.cpp +++ b/src/nupic/engine/Network.cpp @@ -319,10 +319,23 @@ void Network::link(const std::string &srcRegionName, << " does not exist on region " << destRegionName; } - NTA_CHECK(srcOutput->getDataType() == destInput->getDataType()) - << "Network::link -- Mismatched data types." - << BasicType::getName(srcOutput->getDataType()) - << " != " << BasicType::getName(destInput->getDataType()); + // Validate link data types + if (srcOutput->isSparse() == destInput->isSparse()) { + NTA_CHECK(srcOutput->getDataType() == destInput->getDataType()) + << "Network::link -- Mismatched data types." + << BasicType::getName(srcOutput->getDataType()) + << " != " << BasicType::getName(destInput->getDataType()); + } else if (srcOutput->isSparse()) { + // Sparse to dense: unit32 -> bool + NTA_CHECK(srcOutput->getDataType() == NTA_BasicType_UInt32 && + destInput->getDataType() == NTA_BasicType_Bool) + << "Network::link -- Sparse to Dense link: source must be uint32 and " + "destination must be boolean"; + } else if (destInput->isSparse()) { + // Dense to sparse: NTA_BasicType -> uint32 + NTA_CHECK(destInput->getDataType() == NTA_BasicType_UInt32) + << "Network::link -- Dense to Sparse link: destination must be uint32"; + } // Create the link itself auto link =