From 8a1d5eeb146d1f1af2c27fb3846aa06d40b33e50 Mon Sep 17 00:00:00 2001 From: Luiz Scheinkman Date: Tue, 6 Mar 2018 11:59:24 -0800 Subject: [PATCH] NUP-2475: Make sure Network API links use the same dtype at both ends --- bindings/py/tests/network_test.py | 71 +++++++++++++++++++++++++++++++ src/nupic/engine/Input.cpp | 2 + src/nupic/engine/Input.hpp | 5 +++ src/nupic/engine/Network.cpp | 8 ++++ src/nupic/engine/Output.cpp | 2 + src/nupic/engine/Output.hpp | 5 +++ 6 files changed, 93 insertions(+) diff --git a/bindings/py/tests/network_test.py b/bindings/py/tests/network_test.py index a9608fb7ef..4b49b0f418 100644 --- a/bindings/py/tests/network_test.py +++ b/bindings/py/tests/network_test.py @@ -21,6 +21,7 @@ import json import unittest +import pytest try: # NOTE need to import capnp first to activate the magic necessary for @@ -31,15 +32,67 @@ else: from nupic.proto.NetworkProto_capnp import NetworkProto +from nupic.bindings.regions.PyRegion import PyRegion import nupic.bindings.engine_internal as engine from nupic.bindings.tools.serialization_test_py_region import \ SerializationTestPyRegion +class TestLinks(PyRegion): + """ + Test region used to test link validation + """ + def __init__(self): pass + def initialize(self): pass + def compute(self): pass + def getOutputElementCount(self): pass + @classmethod + def getSpec(cls): + return { + "description": TestLinks.__doc__, + "singleNodeOnly": True, + "inputs": { + "UInt32": { + "description": "UInt32 Data", + "dataType": "UInt32", + "isDefaultInput": True, + "required": False, + "count": 0 + }, + "Real32": { + "description": "Real32 Data", + "dataType": "Real32", + "isDefaultInput": False, + "required": False, + "count": 0 + }, + }, + "outputs": { + "UInt32": { + "description": "UInt32 Data", + "dataType": "UInt32", + "isDefaultOutput": True, + "required": False, + "count": 0 + }, + "Real32": { + "description": "Real32 Data", + "dataType": "Real32", + "isDefaultOutput": False, + "required": False, + "count": 0 + }, + }, + "parameters": { } + } class NetworkTest(unittest.TestCase): + def setUp(self): + """Register test region""" + engine.Network.registerPyRegion(TestLinks.__module__, TestLinks.__name__) + @unittest.skipUnless( capnp, "pycapnp is not installed, skipping serialization test.") @@ -107,3 +160,21 @@ def testSimpleTwoRegionNetworkIntrospection(self): self.fail("Unable to iterate network links.") + def testNetworkLinkTypeValidation(self): + """ + This tests whether the links source and destination dtypes match + """ + network = engine.Network() + network.addRegion("from", "py.TestLinks", "") + network.addRegion("to", "py.TestLinks", "") + + # Check for valid links + network.link("from", "to", "UniformLink", "", "UInt32", "UInt32") + network.link("from", "to", "UniformLink", "", "Real32", "Real32") + + # Check for invalid links + with pytest.raises(RuntimeError): + network.link("from", "to", "UniformLink", "", "Real32", "UInt32") + with pytest.raises(RuntimeError): + network.link("from", "to", "UniformLink", "", "UInt32", "Real32") + diff --git a/src/nupic/engine/Input.cpp b/src/nupic/engine/Input.cpp index 82b17dce68..fcdb4406c3 100644 --- a/src/nupic/engine/Input.cpp +++ b/src/nupic/engine/Input.cpp @@ -124,6 +124,8 @@ const Array &Input::getData() const { return data_; } +NTA_BasicType Input::getDataType() const { return data_.getType(); } + Region &Input::getRegion() { return region_; } const std::vector &Input::getLinks() { return links_; } diff --git a/src/nupic/engine/Input.hpp b/src/nupic/engine/Input.hpp index c661c82bb8..00a9a8569c 100644 --- a/src/nupic/engine/Input.hpp +++ b/src/nupic/engine/Input.hpp @@ -156,6 +156,11 @@ class Input { */ const Array &getData() const; + /** + * Get the data type + */ + NTA_BasicType getDataType() const; + /** * * Get the Region that the input belongs to. diff --git a/src/nupic/engine/Network.cpp b/src/nupic/engine/Network.cpp index 08b79cf1ee..851ea7b5af 100644 --- a/src/nupic/engine/Network.cpp +++ b/src/nupic/engine/Network.cpp @@ -30,6 +30,8 @@ Implementation of the Network class #include #include +#include +#include #include #include #include // for register/unregister @@ -41,6 +43,7 @@ Implementation of the Network class #include #include #include +#include #include #include #include @@ -317,6 +320,11 @@ void Network::link(const std::string &srcRegionName, << " does not exist on region " << destRegionName; } + NTA_CHECK(srcOutput->getDataType() == destInput->getDataType()) + << "Network::link -- Mismatched data types." + << BasicType::getName(srcOutput->getDataType()) + << " != " << BasicType::getName(destInput->getDataType()); + // Create the link itself auto link = new Link(linkType, linkParams, srcOutput, destInput, propagationDelay); diff --git a/src/nupic/engine/Output.cpp b/src/nupic/engine/Output.cpp index e2536dd5ba..0f96f00309 100644 --- a/src/nupic/engine/Output.cpp +++ b/src/nupic/engine/Output.cpp @@ -108,4 +108,6 @@ size_t Output::getNodeOutputElementCount() const { bool Output::hasOutgoingLinks() { return (!links_.empty()); } +NTA_BasicType Output::getDataType() const { return data_->getType(); } + } // namespace nupic diff --git a/src/nupic/engine/Output.hpp b/src/nupic/engine/Output.hpp index 0dd8c8e4c0..ba5bf297d9 100644 --- a/src/nupic/engine/Output.hpp +++ b/src/nupic/engine/Output.hpp @@ -135,6 +135,11 @@ class Output { */ const Array &getData() const; + /** + * Get the data type + */ + NTA_BasicType getDataType() const; + /** * * Tells whether the output is region level.