Skip to content

Commit

Permalink
[numpy] Refactor np module (example runs through) (apache#15055)
Browse files Browse the repository at this point in the history
* Refactor notebook

* notebook working with hybrid block

* More refactoring

* Remove unnecessary use_np_compat

* Use class decorator to initialize numpy ndarrays in parameter.py

* Clear notebook outputs

* Improve np decorator

* Remove npe op from optimizer

* Fix CI

* Fix functools.wraps issue in Python2

* Fix ci
  • Loading branch information
reminisce authored and Ying committed Jul 2, 2019
1 parent ff6c5ba commit 5344278
Show file tree
Hide file tree
Showing 19 changed files with 578 additions and 333 deletions.
257 changes: 153 additions & 104 deletions example/numpy/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fundamentals of MXNet Numpy Module\n",
"# Fundamentals of MXNet-NumPy Module\n",
"\n",
"## Namespaces for Imperative Programming\n",
"- `mxnet.numpy`: Regular NumPy operators\n",
"- `mxnet.numpy.random`: NumPy random operators\n",
"- `mxnet.numpy.linalg`: NumPy linear algebra operators\n",
"- `mxnet.numpy_extension`: Operators implemented in MXNet that do not exist in the official NumPy\n",
"- `mxnet.numpy_extension`: Operators implemented in MXNet that do not exist in the official NumPy and some utils (e.g. context related functions).\n",
"\n",
"## Operator Namespaces for Gluon\n",
"`F` can be either `mxnet.ndarray` or `mxnet.symbol`. Note that `np` and `npe` are aliases of `numpy` and `numpy_extension`, respectively.\n",
Expand All @@ -20,15 +20,27 @@
"- `F.npe`: Operators implemented in MXNet that do not exist in official NumPy\n",
"\n",
"## New `ndarray` and `symbol`\n",
"`mxnet.numpy.ndarray` (visible to users) and `mxnet.symbol.numpy._Symbol` (not visible to users)\n",
"`mxnet.numpy.ndarray` (visible to users) and `mxnet.symbol.numpy._Symbol` (not directly visible to users)\n",
"- Same name as in the official NumPy package\n",
"- Dispatch convience fluent method calls to MXNet Numpy operators\n",
"- Override many convenience fluent methods that do not exist in the official NumPy ndarray\n",
"- Make the behavior of built-in methods consistent with the official NumPy\n",
" - Indexing: `__getitem__` and `__setitem__`\n",
" - Many binary element-wise with broadcasting, not supported in `mxnet.symbol.Symbol`\n",
" \n",
"## Examples of ndarray and symbol Basics\n",
"## User Experience of Module Importing (In Progress)\n",
"**Legacy**\n",
"```python\n",
"import mxnet as mx\n",
"from mxnet import gluon\n",
"```\n",
"**Numpy**\n",
"```python\n",
"from mxnet import np, npe, gluon\n",
"```\n",
"\n",
" \n",
"## MXNet NumPy in Action\n",
"### Scalar and zero-size tensors"
]
},
Expand All @@ -41,9 +53,6 @@
"import mxnet as mx\n",
"from mxnet import numpy as np\n",
"\n",
"# use numpy-compatible semantics\n",
"mx.set_np_compat(True)\n",
"\n",
"# create a scalar tensor\n",
"x = np.array(3.14)\n",
"print(x) # x is actually an ndarray, but a scalar value will be printed"
Expand Down Expand Up @@ -158,7 +167,63 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Binary element-wise operations with broadcasting in new and old symbols"
"### There is a line between classic operators and numpy operators...\n",
"- Numpy operators can only accept numpy `ndarray`s/`_Symbol`s as inputs\n",
"- Classic operators can only accept classic `NDArray`s/`Symbol`s as inputs\n",
"- Explicit conversions must be performed if users want to leverage operators on both sides\n",
"- The layer inheriting from `HybridBlock` must have the same type of outputs, i.e., either all classic `NDArray`s or all numpy `ndarray`s, before hybridization\n",
"\n",
"#### Imperative"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a = mx.nd.ones((2, 3)) # create a classic NDArray\n",
"print(a)\n",
"out = np.sum(a) # feeding it to a numpy operator would result in failure"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"b = a.as_np_ndarray() # convert `a` to a numpy ndarray sharing the same data memory\n",
"print(b)\n",
"out = np.sum(b) # feed the numpy ndarray to a numpy operator\n",
"print('np.sum(b) =', out)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"out = mx.nd.sum(b) # feeding `b` to a classic operator would reuslt in failure"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c = b.as_classic_ndarray() # convert `b` to a classic ndarray\n",
"out = mx.nd.sum(c) # feed the classic ndarray to a classic operator\n",
"print('mx.nd.sum(c) =', str(out))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Gluon"
]
},
{
Expand All @@ -168,19 +233,15 @@
"outputs": [],
"source": [
"from mxnet import gluon\n",
"class TestBinaryBroadcast(gluon.HybridBlock):\n",
" def hybrid_forward(self, F, x1, x2):\n",
" print(\"x1 type in hybrid_forward:\", str(type(x1)))\n",
" print(\"x2 type in hybrid_forward:\", str(type(x2)))\n",
" return x1 + x2\n",
"class TestMultipleOutputs(gluon.HybridBlock):\n",
" def hybrid_forward(self, F, x):\n",
" ret1 = F.sum(x) # a classic operator produces a classic NDArray\n",
" ret2 = F.np.sum(x) # a numpy operator produces a numpy NDArray\n",
" return ret1, ret2\n",
"\n",
"net = TestBinaryBroadcast()\n",
"x1 = mx.nd.ones((2, 1))\n",
"x2 = mx.nd.ones((1, 3))\n",
"print('x1 input tensor type: ', str(type(x1)))\n",
"print('x2 input tensor type: ', str(type(x2)))\n",
"out = net(x1, x2) # ok: imperative execution supports broadcasting\n",
"print(out)"
"net = TestMultipleOutputs()\n",
"net.hybridize()\n",
"out = net(a) # `a` is a classic NDArray and will cause an error on `F.np.sum` which is a numpy operator"
]
},
{
Expand All @@ -189,12 +250,9 @@
"metadata": {},
"outputs": [],
"source": [
"net.hybridize() # mark the block for execution using a computational graph\n",
"try:\n",
" out = net(x1, x2) # error: old symbol `+` operation does not support broadcasting\n",
" assert False # should not reach here\n",
"except mx.MXNetError:\n",
" print(\"ERROR: cannot perform broadcast add for two symbols of mxnet.sym.Symbol\")"
"net = TestMultipleOutputs() # redefine a net with no pre-built graph\n",
"net.hybridize()\n",
"out = net(b) # `b` is a numpy ndarray and will cause an error on `F.sum` which is a classic operator"
]
},
{
Expand All @@ -203,19 +261,15 @@
"metadata": {},
"outputs": [],
"source": [
"class TestBinaryBroadcast2(gluon.HybridBlock):\n",
" def hybrid_forward(self, F, x1, x2):\n",
" print(\"x1 type in hybrid_forward:\", str(type(x1)))\n",
" print(\"x2 type in hybrid_forward:\", str(type(x2)))\n",
" return x1.as_np_ndarray() + x2 # convert x1 to new numpy ndarray/symbol\n",
"\n",
"net2 = TestBinaryBroadcast2()\n",
"net2.hybridize()\n",
"class TestMultipleOutputs2(gluon.HybridBlock):\n",
" def hybrid_forward(self, F, x): # x is known to be a numpy ndarray\n",
" ret1 = F.sum(x.as_classic_ndarray()) # a classic operator produces a classic NDArray\n",
" ret2 = F.np.sum() # a numpy operator produces a numpy NDArray\n",
" return ret1, ret2 # two outputs of the layer with different types would result in failure in building the graph\n",
"\n",
"print('x1 input tensor type: ', str(type(x1)))\n",
"print('x2 input tensor type: ', str(type(x2)))\n",
"out =net2(x1, x2)\n",
"print(out)"
"net = TestMultipleOutputs2()\n",
"net.hybridize()\n",
"out = net(b)"
]
},
{
Expand All @@ -224,34 +278,45 @@
"metadata": {},
"outputs": [],
"source": [
"net = TestBinaryBroadcast() # Create a new block object to clear the graph\n",
"net.hybridize() # mark the block for execution using a computational graph\n",
"class TestMultipleOutputs3(gluon.HybridBlock):\n",
" def hybrid_forward(self, F, x): # x is known to be a numpy ndarray\n",
" ret1 = F.sum(x.as_classic_ndarray()) # a classic operator produces a classic NDArray\n",
" ret2 = F.np.sum(x) # a numpy operator produces a numpy NDArray\n",
" return ret1.as_np_ndarray(), ret2 # two outputs of the layer with different types would result in failure in building the graph\n",
"\n",
"x1 = x1.as_np_ndarray() # convert x1 to np.ndarray so that _NumpySymbol will be used in graph construction\n",
"print('x1 input tensor type: ', str(type(x1)))\n",
"x2 = x2.as_np_ndarray() # convert x2 to np.ndarray so that _NumpySymbol will be used in graph construction\n",
"print('x2 input tensor type: ', str(type(x2)))\n",
"out = net(x1, x2) # ok: `+` operation supports broadcasting for _NumpySymbol\n",
"print(out) # mxnet.numpy.ndarray type, because it's from a np operator"
"net = TestMultipleOutputs3()\n",
"net.hybridize()\n",
"out = net(b)\n",
"print('classic operator output: ', out[0])\n",
"print('numpy operator output: ', out[1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A Simple Linear Regression Model\n",
"Let's consider a simple linear regression model as the following.\n",
"Given dataset `{x, y}`, where `x`s represent input examples and `y`s represent observed data, find the parameters `w1` and `w2` for the following model.\n",
"```\n",
"y_pred = np.dot(np.maximum(np.dot(x, w1), 0), w2)\n",
"```"
"### Binary element-wise operations with broadcasting in new and old symbols"
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### MXNet Numpy Operators in Imperative Programming"
"class TestBinaryBroadcast(gluon.HybridBlock):\n",
" def hybrid_forward(self, F, x1, x2):\n",
" print(\"x1 type in hybrid_forward:\", str(type(x1)))\n",
" print(\"x2 type in hybrid_forward:\", str(type(x2)))\n",
" return x1 + x2\n",
"\n",
"net = TestBinaryBroadcast()\n",
"x1 = mx.nd.ones((2, 1))\n",
"x2 = mx.nd.ones((1, 3))\n",
"print('x1 input tensor type: ', str(type(x1)))\n",
"print('x2 input tensor type: ', str(type(x2)))\n",
"out = net(x1, x2) # ok: imperative execution supports broadcasting\n",
"print(out)"
]
},
{
Expand All @@ -260,56 +325,41 @@
"metadata": {},
"outputs": [],
"source": [
"import mxnet as mx\n",
"from mxnet import numpy as np, numpy_extension as npe\n",
"from mxnet import autograd\n",
"\n",
"\n",
"# Use numpy-compatible semantics to support scalar tensors\n",
"mx.set_np_compat(True)\n",
"\n",
"# N is number of examples; D_in is input dimension;\n",
"# H is hidden dimension; D_out is output dimension.\n",
"N, D_in, H, D_out = 64, 1000, 100, 10\n",
"\n",
"# Create random input and output data\n",
"x = mx.nd.random.normal(shape=(N, D_in)).as_np_ndarray() # x is of type mxnet.numpy.ndarray\n",
"y = mx.nd.random.normal(shape=(N, D_out)).as_np_ndarray() # y is of type mxnet.numpy.ndarray\n",
"\n",
"# Randomly initialize weights\n",
"w1 = mx.nd.random.normal(shape=(D_in, H)).as_np_ndarray() # w1 is of type mxnet.numpy.ndarray\n",
"w1.attach_grad() # w1.grad is of type mxnet.numpy.ndarray\n",
"w2 = mx.nd.random.normal(shape=(H, D_out)).as_np_ndarray() # w2 is of type mxnet.numpy.ndarray\n",
"w2.attach_grad() # w2.grad is of type mxnet.numpy.ndarray\n",
"\n",
"learning_rate = 1e-6\n",
"\n",
"\n",
"for t in range(50):\n",
" with autograd.record():\n",
" # Forward pass: compute predicted y\n",
" h = x.dot(w1) # equivalent to np.dot(x, w1)\n",
" h_relu = npe.relu(h) # equivalent to mx.nd.relu(h)\n",
" y_pred = h_relu.dot(w2) # equivalent to np.dot(h_relu, w2)\n",
"\n",
" # Compute loss\n",
" # (y_pred - y) ** 2 calls np.ndarray.__pow__\n",
" # sum() calls np.sum() which should return a scalar tensor\n",
" loss = ((y_pred - y) ** 2).sum()\n",
" # Note that the print function will invoke loss.asnumpy()\n",
" print(t, loss) # loss is a scalar tensor of type mxnet.numpy.ndarray\n",
" loss.backward()\n",
"net.hybridize() # mark the block for execution using a computational graph\n",
"try:\n",
" out = net(x1, x2) # error: old symbol `+` operation does not support broadcasting\n",
" assert False # should not reach here\n",
"except mx.MXNetError:\n",
" print(\"ERROR: cannot perform broadcast add for two symbols of type mx.sym.Symbol\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net = TestBinaryBroadcast() # redefine a net to clear the pre-built graph cache\n",
"net.hybridize()\n",
"\n",
" # Update weights\n",
" w1 -= learning_rate * w1.grad\n",
" w2 -= learning_rate * w2.grad"
"x1 = x1.as_np_ndarray() # convert x1 to np.ndarray\n",
"x2 = x2.as_np_ndarray() # convert x2 to np.ndarray\n",
"print('x1 input tensor type: ', str(type(x1)))\n",
"print('x2 input tensor type: ', str(type(x2)))\n",
"out = net(x1, x2) # ok: a graph is built with numpy symbols which supports broadcasting, because inputs are np.ndarray's, \n",
"print(out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### MXNet Numpy Operators in Gluon `HybridBlock`"
"## A Simple Linear Regression Model\n",
"Let's consider a simple linear regression model as the following.\n",
"Given dataset `{x, y}`, where `x`s represent input examples and `y`s represent observed data, find the parameters `w1` and `w2` for the following model.\n",
"```\n",
"y_pred = np.dot(np.maximum(np.dot(x, w1), 0), w2)\n",
"```"
]
},
{
Expand All @@ -319,13 +369,10 @@
"outputs": [],
"source": [
"import mxnet as mx\n",
"from mxnet import gluon, autograd\n",
"\n",
"\n",
"# Use numpy-compatible semantics to support scalar tensors\n",
"mx.set_np_compat(True)\n",
"from mxnet import gluon, autograd, np\n",
"\n",
"\n",
"@np.use_np_compat\n",
"class LinearRegression(gluon.HybridBlock):\n",
" def __init__(self, num_input_dim=1000, num_hidden_dim=100, num_output_dim=10):\n",
" super(LinearRegression, self).__init__()\n",
Expand All @@ -337,7 +384,7 @@
"\n",
" def hybrid_forward(self, F, x, w1, w2):\n",
" h = x.dot(w1) # equivalent to F.np.dot(x, w1)\n",
" h_relu = F.npe.relu(h) # equivalent to F.relu(h)\n",
" h_relu = F.npe.relu(h) # equivalent to F.relu(h) but generating np.ndarray\n",
" y_pred = h_relu.dot(w2) # equivalent to F.np.dot(h_relu, w2)\n",
" return y_pred\n",
"\n",
Expand All @@ -356,7 +403,9 @@
"y = mx.nd.random.normal(shape=(64, 10)).as_np_ndarray() # y is of type mxnet.numpy.ndarray\n",
"\n",
"total_loss = TotalLoss()\n",
"trainer = gluon.Trainer(regressor.collect_params(), 'sgd', {'learning_rate': 1e-3, 'momentum': 0.9})\n",
"trainer = gluon.Trainer(regressor.collect_params(),\n",
" 'sgd',\n",
" {'learning_rate': 1e-3, 'momentum': 0.9, 'allow_np': True})\n",
"\n",
"for t in range(50):\n",
" with autograd.record():\n",
Expand Down
7 changes: 7 additions & 0 deletions include/mxnet/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,13 @@ inline bool shape_is_known(const TShape& x) {
return true;
}

inline bool shape_is_known(const std::vector<TShape>& shapes) {
for (const TShape& shape : shapes) {
if (!shape_is_known(shape)) return false;
}
return true;
}

/*! \brief helper function to cast type of container elements */
template<typename SrcIter, typename DstIter>
inline DstIter ShapeTypeCast(const SrcIter begin,
Expand Down
Loading

0 comments on commit 5344278

Please sign in to comment.