Skip to content

Commit

Permalink
NUP-2475: Make sure Network API links use the same dtype at both ends
Browse files Browse the repository at this point in the history
  • Loading branch information
lscheinkman committed Mar 6, 2018
1 parent 1ece537 commit 8a1d5ee
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 0 deletions.
71 changes: 71 additions & 0 deletions bindings/py/tests/network_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import json
import unittest
import pytest

try:
# NOTE need to import capnp first to activate the magic necessary for
Expand All @@ -31,15 +32,67 @@
else:
from nupic.proto.NetworkProto_capnp import NetworkProto

from nupic.bindings.regions.PyRegion import PyRegion

import nupic.bindings.engine_internal as engine
from nupic.bindings.tools.serialization_test_py_region import \
SerializationTestPyRegion


class TestLinks(PyRegion):
"""
Test region used to test link validation
"""
def __init__(self): pass
def initialize(self): pass
def compute(self): pass
def getOutputElementCount(self): pass
@classmethod
def getSpec(cls):
return {
"description": TestLinks.__doc__,
"singleNodeOnly": True,
"inputs": {
"UInt32": {
"description": "UInt32 Data",
"dataType": "UInt32",
"isDefaultInput": True,
"required": False,
"count": 0
},
"Real32": {
"description": "Real32 Data",
"dataType": "Real32",
"isDefaultInput": False,
"required": False,
"count": 0
},
},
"outputs": {
"UInt32": {
"description": "UInt32 Data",
"dataType": "UInt32",
"isDefaultOutput": True,
"required": False,
"count": 0
},
"Real32": {
"description": "Real32 Data",
"dataType": "Real32",
"isDefaultOutput": False,
"required": False,
"count": 0
},
},
"parameters": { }
}

class NetworkTest(unittest.TestCase):

def setUp(self):
"""Register test region"""
engine.Network.registerPyRegion(TestLinks.__module__, TestLinks.__name__)


@unittest.skipUnless(
capnp, "pycapnp is not installed, skipping serialization test.")
Expand Down Expand Up @@ -107,3 +160,21 @@ def testSimpleTwoRegionNetworkIntrospection(self):
self.fail("Unable to iterate network links.")


def testNetworkLinkTypeValidation(self):
"""
This tests whether the links source and destination dtypes match
"""
network = engine.Network()
network.addRegion("from", "py.TestLinks", "")
network.addRegion("to", "py.TestLinks", "")

# Check for valid links
network.link("from", "to", "UniformLink", "", "UInt32", "UInt32")
network.link("from", "to", "UniformLink", "", "Real32", "Real32")

# Check for invalid links
with pytest.raises(RuntimeError):
network.link("from", "to", "UniformLink", "", "Real32", "UInt32")
with pytest.raises(RuntimeError):
network.link("from", "to", "UniformLink", "", "UInt32", "Real32")

2 changes: 2 additions & 0 deletions src/nupic/engine/Input.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ const Array &Input::getData() const {
return data_;
}

NTA_BasicType Input::getDataType() const { return data_.getType(); }

Region &Input::getRegion() { return region_; }

const std::vector<Link *> &Input::getLinks() { return links_; }
Expand Down
5 changes: 5 additions & 0 deletions src/nupic/engine/Input.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ class Input {
*/
const Array &getData() const;

/**
* Get the data type
*/
NTA_BasicType getDataType() const;

/**
*
* Get the Region that the input belongs to.
Expand Down
8 changes: 8 additions & 0 deletions src/nupic/engine/Network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ Implementation of the Network class
#include <stdexcept>

#include <nupic/engine/Input.hpp>
#include <nupic/engine/Output.hpp>
#include <nupic/types/BasicType.hpp>
#include <nupic/engine/Link.hpp>
#include <nupic/engine/Network.hpp>
#include <nupic/engine/NuPIC.hpp> // for register/unregister
Expand All @@ -41,6 +43,7 @@ Implementation of the Network class
#include <nupic/os/Path.hpp>
#include <nupic/proto/NetworkProto.capnp.h>
#include <nupic/proto/RegionProto.capnp.h>
#include <nupic/types/BasicType.hpp>
#include <nupic/utils/Log.hpp>
#include <nupic/utils/StringUtils.hpp>
#include <yaml-cpp/yaml.h>
Expand Down Expand Up @@ -317,6 +320,11 @@ void Network::link(const std::string &srcRegionName,
<< " does not exist on region " << destRegionName;
}

NTA_CHECK(srcOutput->getDataType() == destInput->getDataType())
<< "Network::link -- Mismatched data types."
<< BasicType::getName(srcOutput->getDataType())
<< " != " << BasicType::getName(destInput->getDataType());

// Create the link itself
auto link =
new Link(linkType, linkParams, srcOutput, destInput, propagationDelay);
Expand Down
2 changes: 2 additions & 0 deletions src/nupic/engine/Output.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,6 @@ size_t Output::getNodeOutputElementCount() const {

bool Output::hasOutgoingLinks() { return (!links_.empty()); }

NTA_BasicType Output::getDataType() const { return data_->getType(); }

} // namespace nupic
5 changes: 5 additions & 0 deletions src/nupic/engine/Output.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ class Output {
*/
const Array &getData() const;

/**
* Get the data type
*/
NTA_BasicType getDataType() const;

/**
*
* Tells whether the output is region level.
Expand Down

0 comments on commit 8a1d5ee

Please sign in to comment.