Added metadata node to all NIRNodes #79

Merged · 10 commits · Mar 28, 2024
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -9,7 +9,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

runs-on: ${{ matrix.os }}

2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -56,7 +56,7 @@
]

# MyST settings
nb_execution_mode = "off" # this can be turned to 'auto' once the package is stable
nb_execution_timeout = 300
nb_execution_show_tb = True

100 changes: 90 additions & 10 deletions docs/source/primitives.md
@@ -1,6 +1,14 @@
# Primitives

At its core, NIR is simply a [directed graph](https://en.wikipedia.org/wiki/Directed_graph) (using the [`NIRGraph` primitive](https://github.com/neuromorphs/NIR/blob/main/nir/ir/graph.py)).
The nodes of the graph are computational units, and the edges are the (directed) connections between them.
There are no restrictions on the graph structure, so it can be a simple feedforward network, a recurrent network, a graph with cycles, or even one with duplicated connections, if needed.

However, if you plan to execute the graph on restricted neuromorphic hardware, please **verify that the graph is compatible with the hardware**.

## NIR computational primitives

NIR defines 16 fundamental primitives, listed in the table below, which backends are free to implement as they see fit; this can lead to varying outputs across platforms. While discrepancies could be minimized by constraining implementations or by making backends aware of each other's discretization choices, NIR does not do this because it is declarative, specifying only the necessary inputs and outputs. Constraining implementations would cause hardware incompatibilities, and making backends aware of each other would create O(N^2) overhead for N backends. The primitives are nonetheless computationally expressive and able to solve complex partial differential equations (PDEs).

| Primitive | Parameters | Computation | Reset |
|-|-|-|-|
@@ -21,17 +29,89 @@ NIR defines 16 fundamental primitives listed in the table below, which backends
| **AvgPooling** | $p$ | **SumPooling**; **Scale** | - |
| **Threshold** | $\theta_\text{thr}$ | $H(I - \theta_\text{thr})$ | - |
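
To make the table concrete, the **Threshold** row translates directly into code. Here is a minimal NumPy sketch (assuming $H$ denotes the Heaviside step function, written as `np.heaviside` below):

```python
import numpy as np

# Threshold primitive: H(I - theta_thr), with H the Heaviside step function
def threshold(I: np.ndarray, theta_thr: np.ndarray) -> np.ndarray:
    return np.heaviside(I - theta_thr, 0.0)
```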

Each primitive is defined by its own dynamical equation, specified in the [API docs](https://nnir.readthedocs.io/en/latest/).

## Connectivity

Each computational unit is a node in a static graph.
In the graph, each node has a name like "Neuron 1" or, in some cases, simply an index such as "1".
Connections between nodes are tuples of strings describing the source and the target.
As an example, `("A", "B")` tells us that the output of node `A` is sent to node `B`.

Describing the full connectivity is as simple as listing all the connections in the graph:
```python
[
    ("A", "B"),
    ("B", "C"),
    ("C", "D"),
    ...
]
```

## Input and output nodes
Given a graph, how do we know which nodes should receive inputs? And which nodes should provide outputs?
For that, we define two special nodes: `Input` and `Output`.
Both nodes are "dummies" in the sense that they do not perform any computation; they only mark the beginning and the end of the graph.
Note that a single node can be both an input and an output node.

To clarify the dimensionality/input types of the input and output nodes, we require the user to specify the shape *and* name of the input, like so:
```python
import numpy as np

import nir

nir.Input(
    input_type={"input": np.array([28, 28])}
)
nir.Output(
    output_type={"output": np.array([2])}
)
```

## A Graph Example in Python
To illustrate how a computational graph can be defined using the NIR Python primitives, here is an example of a graph with a single `LIF` neuron, together with input and output nodes:

```python
import numpy as np

import nir

nir.NIRGraph(
    nodes={
        "input": nir.Input({"input": np.array([1])}),
        "lif": nir.LIF(...),
        "output": nir.Output({"output": np.array([1])}),
    },
    edges=[
        ("input", "lif"),
        ("lif", "output"),
    ],
)
```

## Metadata

Each node in the graph can have metadata attached to it.
The metadata is a dictionary that can contain any information that may be helpful for the user or backend.
Any dictionary entries can be added, although we recommend restricting the entries to strings, numbers, and arrays.
Here is an example of a metadata dictionary attached to a graph:

```python
import nir

nir.NIRGraph(
    ...,
    metadata={"some": "metadata", "info": 1},
)
```
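
Metadata can likewise be attached to individual nodes, since this PR adds a `metadata` field to every `NIRNode`. A minimal sketch (the `LIF` parameter values and metadata entries are purely illustrative):

```python
import numpy as np

import nir

# Attach free-form metadata to a single node (hypothetical values)
lif = nir.LIF(
    tau=np.array([0.02]),
    r=np.array([1.0]),
    v_leak=np.array([0.0]),
    v_threshold=np.array([1.0]),
    metadata={"discretization": "euler", "dt": 1e-4},
)
```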


```{admonition} Do not rely on the metadata
:class: warning
It is vital that **no backend relies on this metadata**.
Metadata entries should contain non-essential meta-information about nodes or graphs, such as the discretization scheme with which the graph was trained, timestamps, and so on: tidbits that can improve the model or its execution, but are not necessary for the execution itself.

If a backend were to rely strictly on this metadata, everyone else would be forced to adhere to a standard that is not enforced.
NIR graphs should be self-contained and unambiguous, such that the graph itself (without the metadata) contains all the information necessary to execute it.
```

## Importing and exporting
While the NIR library is written in Python, the graph can be defined and used in any language.
We provide import and export functions to and from the [Hierarchical Data Format](https://en.wikipedia.org/wiki/Hierarchical_Data_Format), which allows for easy storage and retrieval of the graph.
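
As a minimal sketch (assuming the top-level `nir.write` and `nir.read` helpers exposed by the Python package), a graph can be round-tripped through a `.nir` file like so:

```python
import numpy as np

import nir

# Build a trivial graph and round-trip it through an HDF5-backed file
graph = nir.NIRGraph(
    nodes={
        "input": nir.Input({"input": np.array([1])}),
        "output": nir.Output({"output": np.array([1])}),
    },
    edges=[("input", "output")],
)
nir.write("graph.nir", graph)
restored = nir.read("graph.nir")
```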

## Format
The intermediate representation can be stored as an HDF5 file, which benefits from compression.
See [the usage page](usage) for more information.
7 changes: 5 additions & 2 deletions nir/ir/conv.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np

@@ -41,6 +41,9 @@ class Conv1d(NIRNode):
dilation: int # Dilation
groups: int # Groups
bias: np.ndarray # Bias C_out
input_type: Optional[Dict[str, np.ndarray]] = None  # Input type (name -> shape)
output_type: Optional[Dict[str, np.ndarray]] = None  # Output type (name -> shape)
metadata: Dict[str, Any] = field(default_factory=dict)  # Free-form metadata dictionary

def __post_init__(self):
if isinstance(self.padding, str) and self.padding not in ["same", "valid"]:
6 changes: 5 additions & 1 deletion nir/ir/delay.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import numpy as np

@@ -16,6 +17,9 @@ class Delay(NIRNode):
"""

delay: np.ndarray # Delay
input_type: Optional[Dict[str, np.ndarray]] = None
output_type: Optional[Dict[str, np.ndarray]] = None
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
# set input and output shape, if not set by user
8 changes: 5 additions & 3 deletions nir/ir/flatten.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import numpy as np

@@ -21,6 +21,9 @@ class Flatten(NIRNode):
input_type: Types
start_dim: int = 1 # First dimension to flatten
end_dim: int = -1 # Last dimension to flatten
input_type: Optional[Dict[str, np.ndarray]] = None
output_type: Optional[Dict[str, np.ndarray]] = None
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.input_type = parse_shape_argument(self.input_type, "input")
@@ -41,7 +44,6 @@ def __post_init__(self):

def to_dict(self) -> Dict[str, Any]:
ret = super().to_dict()
del ret["input_type"]
ret["input_type"] = self.input_type["input"]
return ret

9 changes: 5 additions & 4 deletions nir/ir/graph.py
@@ -1,6 +1,6 @@
from collections import Counter
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import numpy as np

@@ -27,6 +27,9 @@ class NIRGraph(NIRNode):

nodes: Nodes # List of computational nodes
edges: Edges # List of edges between nodes
input_type: Optional[Dict[str, np.ndarray]] = None
output_type: Optional[Dict[str, np.ndarray]] = None
metadata: Dict[str, Any] = field(default_factory=dict)

@property
def inputs(self):
@@ -456,7 +459,6 @@ def __post_init__(self):

def to_dict(self) -> Dict[str, Any]:
ret = super().to_dict()
del ret["input_type"]
ret["shape"] = self.input_type["input"]
return ret

@@ -484,7 +486,6 @@ def __post_init__(self):

def to_dict(self) -> Dict[str, Any]:
ret = super().to_dict()
del ret["output_type"]
ret["shape"] = self.output_type["output"]
return ret

6 changes: 5 additions & 1 deletion nir/ir/linear.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import numpy as np

@@ -20,6 +21,9 @@ class Affine(NIRNode):

weight: np.ndarray # Weight term
bias: np.ndarray # Bias term
input_type: Optional[Dict[str, np.ndarray]] = None
output_type: Optional[Dict[str, np.ndarray]] = None
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
assert len(self.weight.shape) >= 2, "Weight must be at least 2D"
18 changes: 17 additions & 1 deletion nir/ir/neuron.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import numpy as np

@@ -45,6 +46,9 @@ class CubaLIF(NIRNode):
v_leak: np.ndarray # Leak voltage
v_threshold: np.ndarray # Firing threshold
w_in: np.ndarray = 1.0 # Input current weight
input_type: Optional[Dict[str, np.ndarray]] = None
output_type: Optional[Dict[str, np.ndarray]] = None
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
assert (
@@ -71,6 +75,9 @@ class I(NIRNode):  # noqa: E742
"""

r: np.ndarray
input_type: Optional[Dict[str, np.ndarray]] = None
output_type: Optional[Dict[str, np.ndarray]] = None
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.input_type = {"input": np.array(self.r.shape)}
@@ -101,6 +108,9 @@ class IF(NIRNode):

r: np.ndarray # Resistance
v_threshold: np.ndarray # Firing threshold
input_type: Optional[Dict[str, np.ndarray]] = None
output_type: Optional[Dict[str, np.ndarray]] = None
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
assert (
@@ -127,6 +137,9 @@ class LI(NIRNode):
tau: np.ndarray # Time constant
r: np.ndarray # Resistance
v_leak: np.ndarray # Leak voltage
input_type: Optional[Dict[str, np.ndarray]] = None
output_type: Optional[Dict[str, np.ndarray]] = None
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
assert (
@@ -166,6 +179,9 @@ class LIF(NIRNode):
r: np.ndarray # Resistance
v_leak: np.ndarray # Leak voltage
v_threshold: np.ndarray # Firing threshold
input_type: Optional[Dict[str, np.ndarray]] = None
output_type: Optional[Dict[str, np.ndarray]] = None
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
assert (
12 changes: 9 additions & 3 deletions nir/ir/node.py
@@ -10,16 +10,22 @@ class NIRNode:
instantiated.
"""

# Note: Adding input/output types and metadata as follows is ideal, but requires Python 3.10
# TODO: implement this in 2025 when 3.9 is EOL
# input_type: Dict[str, np.ndarray] = field(init=False, kw_only=True)
# output_type: Dict[str, np.ndarray] = field(init=False, kw_only=True)
# metadata: Dict[str, Any] = field(init=True, default_factory=dict)

def __eq__(self, other):
return self is other

def to_dict(self) -> Dict[str, Any]:
"""Serialize into a dictionary."""
ret = asdict(self)
if "input_type" in ret.keys():
del ret["input_type"]
if "output_type" in ret.keys():
del ret["output_type"]
# Note: The customization below won't be automatically done recursively for nested NIRNode.
# Therefore, classes with nested NIRNodes, e.g. NIRGraph, must implement their own to_dict
ret["type"] = type(self).__name__
6 changes: 5 additions & 1 deletion nir/ir/pooling.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import numpy as np

@@ -12,6 +13,9 @@ class SumPool2d(NIRNode):
kernel_size: np.ndarray # (Height, Width)
stride: np.ndarray # (Height, width)
padding: np.ndarray # (Height, width)
input_type: Optional[Dict[str, np.ndarray]] = None
output_type: Optional[Dict[str, np.ndarray]] = None
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.input_type = {"input": None}
6 changes: 5 additions & 1 deletion nir/ir/surrogate_gradient.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import numpy as np

@@ -19,6 +20,9 @@ class Threshold(NIRNode):
"""

threshold: np.ndarray # Firing threshold
input_type: Optional[Dict[str, np.ndarray]] = None
output_type: Optional[Dict[str, np.ndarray]] = None
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
self.input_type = {"input": np.array(self.threshold.shape)}
2 changes: 1 addition & 1 deletion nir/ir/utils.py
@@ -62,7 +62,7 @@ def calculate_conv_output(
/ _index_tuple(stride, i)
+ 1
)
shapes.append(int(shape.item()))
return np.array(shapes)

