This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[API] Standardize MXNet NumPy creation functions #20572

Merged
54 commits merged into master from data-api-creation-func on Nov 4, 2021
54 commits
b56f8bd
standardize mxnet numpy creation functions
barry-jin Sep 9, 2021
89f9f6b
update
barry-jin Sep 9, 2021
1d1b2e4
fix linspace
barry-jin Sep 10, 2021
856e683
Merge remote-tracking branch 'upstream/master' into data-api-creation…
barry-jin Sep 10, 2021
5014fad
merge & add tests
barry-jin Sep 10, 2021
85dde73
add NumpyLinspaceParam
barry-jin Sep 10, 2021
9cb5881
fix lint'
barry-jin Sep 10, 2021
6c032ea
fix
barry-jin Sep 10, 2021
ba01eed
add indexing test
barry-jin Sep 10, 2021
3db3b95
fix tests
barry-jin Sep 10, 2021
dbea44a
fix sanity
barry-jin Sep 11, 2021
b2da040
merge
barry-jin Sep 13, 2021
91c9b75
fix lint
barry-jin Sep 13, 2021
47a61bd
fix tests
barry-jin Sep 13, 2021
c6e5596
disable warning
barry-jin Sep 13, 2021
b96e75a
fix
barry-jin Sep 14, 2021
a7a671b
update
barry-jin Sep 14, 2021
cc01d77
skip signature standardization
barry-jin Sep 15, 2021
6c2f48c
fix lint
barry-jin Sep 15, 2021
b78365e
update
barry-jin Sep 16, 2021
4a798da
rm test_contants
barry-jin Sep 18, 2021
302bc73
Merge remote-tracking branch 'upstream/master' into data-api-creation…
barry-jin Sep 20, 2021
5994d01
Merge remote-tracking branch 'upstream/master' into data-api-creation…
barry-jin Sep 24, 2021
b09814c
Add Code Signing Key
barry-jin Sep 27, 2021
c9fe889
Revert "Add Code Signing Key"
barry-jin Sep 27, 2021
0e385dc
Merge remote-tracking branch 'upstream/master' into data-api-creation…
barry-jin Oct 12, 2021
60ee272
Replace context with device & update multiarray.py/_op.py
barry-jin Oct 12, 2021
10810da
Merge remote-tracking branch 'upstream/master' into data-api-creation…
barry-jin Oct 14, 2021
eac4095
ctx => device
barry-jin Oct 14, 2021
74c8fc6
ctx/context => device
barry-jin Oct 14, 2021
a6c9f75
fix conflict
barry-jin Oct 14, 2021
02afedf
fix
barry-jin Oct 14, 2021
c178259
fix multiarray
barry-jin Oct 14, 2021
61cb8e5
update ndarray.py
barry-jin Oct 14, 2021
1111e98
fix
barry-jin Oct 15, 2021
c583c9c
fix
barry-jin Oct 15, 2021
558d7a9
fix tests
barry-jin Oct 16, 2021
ccfbc28
update
barry-jin Oct 16, 2021
2bf9efc
fix tests
barry-jin Oct 17, 2021
b77c6f9
update rand_zipfian
barry-jin Oct 18, 2021
1519650
update
barry-jin Oct 18, 2021
dc0e3b9
device => cuda_device in util.py
barry-jin Oct 18, 2021
d96a824
context.gpu_memory_info => device.gpu_memory_info
barry-jin Oct 18, 2021
d126e89
fix docs
barry-jin Oct 19, 2021
3b1f70c
Merge remote-tracking branch 'upstream/master' into data-api-creation…
barry-jin Oct 19, 2021
17d9898
rm context in doc
barry-jin Oct 19, 2021
a0b4b2b
fix conflict
barry-jin Oct 24, 2021
bb1e4ad
fix lint
barry-jin Oct 25, 2021
e775844
remove npv
barry-jin Oct 25, 2021
f67701a
Revert "remove npv"
barry-jin Oct 25, 2021
c8cd07d
merge
barry-jin Oct 27, 2021
00a13f6
merge
barry-jin Oct 31, 2021
a2a154c
merge
barry-jin Nov 1, 2021
5f472b2
Merge branch 'master' into data-api-creation-func
barry-jin Nov 3, 2021
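
A minimal sketch of the device-based calling convention these commits converge on for the NumPy creation functions, with names and keywords taken from the diffs below (an illustration under those assumptions, not a complete API reference):

```python
import mxnet as mx
from mxnet import np, npx
npx.set_np()

# pick a device the same way the updated tutorials do
device = mx.gpu() if npx.num_gpus() > 0 else mx.cpu()

# creation functions take a `device` keyword where they previously took `ctx`
x = np.ones((3, 4), device=device)
y = np.random.uniform(size=(3, 4), device=device)

# arrays expose a `device` attribute and move between devices with `to_device`
z = (x + y).to_device(mx.cpu())
print(z.device)
```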
2 changes: 1 addition & 1 deletion benchmark/python/sparse/cast_storage.py
@@ -41,7 +41,7 @@ def measure_cost(repeat, f, *args, **kwargs):

def run_cast_storage_synthetic():
def dense_to_sparse(m, n, density, ctx, repeat, stype):
set_default_context(ctx)
set_default_device(ctx)
data_shape = (m, n)
dns_data = rand_ndarray(data_shape, stype, density).tostype('default')
dns_data.wait_to_read()
4 changes: 2 additions & 2 deletions benchmark/python/sparse/dot.py
@@ -26,7 +26,7 @@
import mxnet as mx
import numpy as np
import numpy.random as rnd
from mxnet.test_utils import rand_ndarray, set_default_context, assert_almost_equal, get_bz2_data
from mxnet.test_utils import rand_ndarray, set_default_device, assert_almost_equal, get_bz2_data
from mxnet.base import check_call, _LIB
from util import estimate_density

@@ -267,7 +267,7 @@ def test_dot_synthetic(data_dict):
# Benchmark MXNet and SciPy's dot operator
def bench_dot(lhs_shape, rhs_shape, lhs_stype, rhs_stype,
lhs_den, rhs_den, trans_lhs, ctx, num_repeat=10, fw="mxnet", distribution="uniform"):
set_default_context(ctx)
set_default_device(ctx)
assert fw == "mxnet" or fw == "scipy"
# Set funcs
dot_func_sparse = mx.nd.sparse.dot if fw == "mxnet" else sp.spmatrix.dot
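
A small sketch of the renamed benchmark helper in use, assuming `mxnet.test_utils.set_default_device` is available as the import above suggests:

```python
import mxnet as mx
from mxnet.test_utils import set_default_device

# route subsequent default allocations to the chosen device
set_default_device(mx.gpu(0) if mx.device.num_gpus() > 0 else mx.cpu())
```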
4 changes: 2 additions & 2 deletions benchmark/python/sparse/sparse_op.py
@@ -155,7 +155,7 @@ def measure_cost_backward_baseline(repeat, dot, transpose, lhs, rhs):
return diff / repeat

def bench_dot_forward(m, k, n, density, ctx, repeat):
set_default_context(ctx)
set_default_device(ctx)
dns = mx.nd.random.uniform(shape=(k, n)).copyto(ctx)
data_shape = (m, k)
csr_data = rand_ndarray(data_shape, 'csr', density)
@@ -184,7 +184,7 @@ def bench_dot_forward(m, k, n, density, ctx, repeat):
ratio_baseline, costs_baseline[0], costs_baseline[1]))

def bench_dot_backward(m, k, n, density, ctx, repeat):
set_default_context(ctx)
set_default_device(ctx)
dns = mx.nd.random.uniform(shape=(m, n)).copyto(ctx)
data_shape = (m, k)
csr_data = rand_ndarray(data_shape, 'csr', density)
2 changes: 2 additions & 0 deletions ci/docker/runtime_functions.sh
@@ -874,6 +874,8 @@ unittest_array_api_standardization() {
export DMLC_LOG_STACK_TRACE_DEPTH=100
python3 -m pytest --durations=50 --cov-report xml:tests_api.xml --verbose \
array_api_tests/test_type_promotion.py::test_elementwise_function_two_arg_bool_type_promotion
python3 -m pytest --durations=50 --cov-report xml:tests_api.xml --verbose array_api_tests/test_creation_functions.py
python3 -m pytest --durations=50 --cov-report xml:tests_api.xml --verbose array_api_tests/test_indexing.py
popd
}

@@ -15,9 +15,9 @@
specific language governing permissions and limitations
under the License.

mxnet.context
mxnet.device
=============

.. automodule:: mxnet.context
.. automodule:: mxnet.device
:members:
:autosummary:
12 changes: 6 additions & 6 deletions docs/python_docs/python/api/index.rst
@@ -86,10 +86,10 @@ Gluon related modules
Key value store interface of MXNet for parameter synchronization.

.. card::
:title: mxnet.context
:link: mxnet/context/index.html
:title: mxnet.device
:link: mxnet/device/index.html

CPU and GPU context information.
CPU and GPU device information.

.. card::
:title: mxnet.profiler
@@ -116,10 +116,10 @@ Advanced modules
API for querying MXNet enabled features.

.. card::
:title: mxnet.context
:link: context/index.html
:title: mxnet.device
:link: device/index.html

MXNet array context for specifying in-memory storage device.
MXNet array device for specifying in-memory storage device.

.. card::
:title: mxnet.profiler
2 changes: 1 addition & 1 deletion docs/python_docs/python/api/npx/index.rst
@@ -50,7 +50,7 @@ Devices
cpu_pinned
gpu
gpu_memory_info
current_context
current_device
num_gpus

Neural networks
@@ -71,18 +71,18 @@ And we are done. You can test the installation now by importing mxnet from python

## Running a pre-trained ResNet-50 model on Jetson

We are now ready to run a pre-trained model and run inference on a Jetson module. In this tutorial we use a ResNet-50 model trained on the ImageNet dataset. We run the following classification script with either cpu/gpu context using python3.
We are now ready to run a pre-trained model and run inference on a Jetson module. In this tutorial we use a ResNet-50 model trained on the ImageNet dataset. We run the following classification script with either cpu/gpu device using python3.

```{.python .input}
from mxnet import gluon
import mxnet as mx

# set context
# set device
gpus = mx.test_utils.list_gpus()
ctx = mx.gpu() if gpus else mx.cpu()
device = mx.gpu() if gpus else mx.cpu()

# load pre-trained model
net = gluon.model_zoo.vision.resnet50_v1(pretrained=True, ctx=ctx)
net = gluon.model_zoo.vision.resnet50_v1(pretrained=True, device=device)
net.hybridize(static_alloc=True, static_shape=True)

# load labels
@@ -99,7 +99,7 @@ img = mx.image.color_normalize(img.astype(dtype='float32')/255,
std=mx.np.array([0.229, 0.224, 0.225])) # normalize
img = img.transpose((2, 0, 1)) # channel first
img = mx.np.expand_dims(img, axis=0) # batchify
img = img.as_in_ctx(ctx)
img = img.to_device(device)

prob = mx.npx.softmax(net(img)) # predict and normalize output
idx = mx.npx.topk(prob, k=5)[0] # get top 5 result
8 changes: 4 additions & 4 deletions docs/python_docs/python/tutorials/extend/customop.md
@@ -101,7 +101,7 @@ class SigmoidProp(mx.operator.CustomOpProp):
# return 3 lists representing inputs shapes, outputs shapes, and aux data shapes.
return (data_shape,), (output_shape,), ()

def create_operator(self, ctx, in_shapes, in_dtypes):
def create_operator(self, device, in_shapes, in_dtypes):
# create and return the CustomOp class.
return Sigmoid()
```
@@ -183,7 +183,7 @@ class DenseProp(mx.operator.CustomOpProp):
# return 3 lists representing inputs shapes, outputs shapes, and aux data shapes.
return (data_shape, weight_shape), (output_shape,), ()

def create_operator(self, ctx, in_shapes, in_dtypes):
def create_operator(self, device, in_shapes, in_dtypes):
# create and return the CustomOp class.
return Dense(self._bias)
```
@@ -201,8 +201,8 @@ class DenseBlock(mx.gluon.Block):
self.weight = gluon.Parameter('weight', shape=(channels, in_channels))

def forward(self, x):
ctx = x.context
return mx.nd.Custom(x, self.weight.data(ctx), bias=self._bias, op_type='dense')
device = x.device
return mx.nd.Custom(x, self.weight.data(device), bias=self._bias, op_type='dense')
```

### Example usage
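
The tutorial's own usage example is collapsed in this view; a rough, hypothetical sketch of exercising the custom dense block above might look like the following (constructor arguments and shapes are assumptions, and the `dense` op must already be registered through the `DenseProp` class shown earlier):

```python
# hypothetical usage sketch; argument names and shapes are assumed
block = DenseBlock(channels=5, in_channels=3, bias=0.1)
block.initialize()

x = mx.nd.random.uniform(shape=(4, 3))   # batch of 4 samples with 3 features
y = block(x)                             # runs the registered 'dense' custom op
print(y.shape)                           # expected (4, 5) for a 3 -> 5 dense layer
```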
@@ -195,7 +195,7 @@ a = np.array(a)
(type(a), a)
```

Additionally, you can move them to different GPU contexts. You will dive more
Additionally, you can move them to different GPU devices. You will dive more
into this later, but here is an example for now.

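The example itself is collapsed in this view; a minimal sketch of moving an array onto a GPU, assuming at least one GPU is visible, could look like:

```python
from mxnet import np, npx

a = np.array([1, 2, 3])
if npx.num_gpus() > 0:
    a_gpu = a.to_device(npx.gpu(0))   # copy the array onto the first GPU
    print(a_gpu.device)
```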
@@ -564,7 +564,7 @@ with warnings.catch_warnings():
warnings.simplefilter("ignore")
net_loaded = nn.SymbolBlock.imports("MLP_hybrid-symbol.json",
['data'], "MLP_hybrid-0000.params",
ctx=None)
device=None)
```

```{.python .input}
@@ -35,7 +35,7 @@ import mxnet as mx
from mxnet.gluon import nn
npx.set_np()

ctx = mx.cpu()
device = mx.cpu()
```

## Initialization
@@ -103,7 +103,7 @@ To initialize your network using different built-in types, you have to use the
from mxnet import init

# Constant init initializes the weights to be a constant value for all the params
net.initialize(init=init.Constant(3), ctx=ctx)
net.initialize(init=init.Constant(3), device=device)
print(net[0].weight.data()[0])
```

@@ -113,7 +113,7 @@ already initialized the weight but want to reinitialize the weight, set the
`force_reinit` flag to `True`.

```{.python .input}
net.initialize(init=init.Normal(sigma=0.2), force_reinit=True, ctx=ctx)
net.initialize(init=init.Normal(sigma=0.2), force_reinit=True, device=device)
print(net[0].weight.data()[0])
```

@@ -215,7 +215,7 @@ The current data loading pipeline is the major bottleneck for many training task
- `Transform.__call__()/forward()`
- `Batchify`
- (optional communicate through shared_mem)
- `split_and_load(ctxs)`
- `split_and_load(devices)`
- training on GPUs

Performance concerns include slow python dataset/transform functions, multithreading issues due to global interpreter lock, Python multiprocessing issues due to speed, and batchify issues due to poor memory management.
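
As a rough illustration of the `split_and_load` step listed in the pipeline above (the device list and batch contents are placeholders):

```python
import mxnet as mx
from mxnet import gluon, np

devices = [mx.gpu(0), mx.gpu(1)] if mx.device.num_gpus() >= 2 else [mx.cpu()]
batch = np.random.uniform(size=(8, 3, 32, 32))   # stand-in for a batch from the DataLoader

# split the batch across the available devices and load one shard onto each
shards = gluon.utils.split_and_load(batch, devices, even_split=False)
print([shard.device for shard in shards])
```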
@@ -322,13 +322,13 @@ hybridize the model.

```{.python .input}
# Create the model based on the blueprint provided and initialize the parameters
ctx = mx.gpu()
device = mx.gpu()

initializer = mx.initializer.Xavier()

model = LeafNetwork()
model.initialize(initializer, ctx=ctx)
model.summary(mx.np.random.uniform(size=(4, 3, 128, 128), ctx=ctx))
model.initialize(initializer, device=device)
model.summary(mx.np.random.uniform(size=(4, 3, 128, 128), device=device))
model.hybridize()
```

@@ -368,7 +368,7 @@ def test(val_data):
for batch in val_data:
data = batch[0]
labels = batch[1]
outputs = model(data.as_in_ctx(ctx))
outputs = model(data.to_device(device))
acc.update([labels], [outputs])

_, accuracy = acc.get()
@@ -396,8 +396,8 @@ for epoch in range(epochs):
data = batch[0]
label = batch[1]
with mx.autograd.record():
outputs = model(data.as_in_ctx(ctx))
loss = loss_fn(outputs, label.as_in_ctx(ctx))
outputs = model(data.to_device(device))
loss = loss_fn(outputs, label.to_device(device))
mx.autograd.backward(loss)
trainer.step(batch_size)
accuracy.update([label], [outputs])
@@ -36,11 +36,11 @@ npx.num_gpus() #This command provides the number of GPUs MXNet can access

## Allocate data to a GPU

MXNet's ndarray is very similar to NumPy's. One major difference is that MXNet's ndarray has a `context` attribute specifying which device an array is on. By default, arrays are stored on `npx.cpu()`. To change it to the first GPU, you can use the following code, `npx.gpu()` or `npx.gpu(0)` to indicate the first GPU.
MXNet's ndarray is very similar to NumPy's. One major difference is that MXNet's ndarray has a `device` attribute specifying which device an array is on. By default, arrays are stored on `npx.cpu()`. To change it to the first GPU, you can use the following code, `npx.gpu()` or `npx.gpu(0)` to indicate the first GPU.

```{.python .input}
gpu = npx.gpu() if npx.num_gpus() > 0 else npx.cpu()
x = np.ones((3,4), ctx=gpu)
x = np.ones((3,4), device=gpu)
x
```

@@ -63,7 +63,7 @@ If you have multiple GPUs on your machine, MXNet can access each of them through
To perform an operation on a particular GPU, you only need to guarantee that the input of an operation is already on that GPU. The output is allocated on the same GPU as well. Almost all operators in the `np` and `npx` module support running on a GPU.

```{.python .input}
y = np.random.uniform(size=(3,4), ctx=gpu)
y = np.random.uniform(size=(3,4), device=gpu)
x + y
```

@@ -115,17 +115,17 @@ class LeafNetwork(nn.HybridBlock):
return batch
```

Load the saved parameters onto GPU 0 directly as shown below; additionally, you could use `net.collect_params().reset_ctx(gpu)` to change the device.
Load the saved parameters onto GPU 0 directly as shown below; additionally, you could use `net.collect_params().reset_device(gpu)` to change the device.

```{.python .input}
net = LeafNetwork()
net.load_parameters('leaf_models.params', ctx=gpu)
net.load_parameters('leaf_models.params', device=gpu)
```
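
The prose above also mentions `reset_device`; a minimal sketch of that alternative, reusing the `LeafNetwork` class and `gpu` device defined earlier in the tutorial and assuming the parameters are first loaded on the CPU:

```python
# load on the CPU first, then move every parameter to the GPU
net = LeafNetwork()
net.load_parameters('leaf_models.params', device=npx.cpu())
net.collect_params().reset_device(gpu)
```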

Use the following command to create input data on GPU 0. The forward function will then run on GPU 0.

```{.python .input}
x = np.random.uniform(size=(1, 3, 128, 128), ctx=gpu)
x = np.random.uniform(size=(1, 3, 128, 128), device=gpu)
net(x)
```

@@ -201,7 +201,7 @@ devices = available_gpus[:num_gpus]
print('Using {} GPUs'.format(len(devices)))

# Diff 2: reinitialize the parameters and place them on multiple GPUs
net.initialize(force_reinit=True, ctx=devices)
net.initialize(force_reinit=True, device=devices)

# Loss and trainer are the same as before
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
@@ -91,10 +91,10 @@ lr_factor = 0.75
# learning rate change at following epochs
lr_epochs = [10, 20, 30]

num_gpus = mx.context.num_gpus()
num_gpus = mx.device.num_gpus()
# you can replace num_workers with the number of cores on your device
num_workers = 8
ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
device = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
batch_size = per_device_batch_size * max(num_gpus, 1)
```

@@ -166,11 +166,11 @@ Before we go to training, one unique Gluon feature you should be aware of is hyb

```{.python .input}
# load pre-trained resnet50_v2 from model zoo
finetune_net = resnet50_v2(pretrained=True, ctx=ctx)
finetune_net = resnet50_v2(pretrained=True, device=device)

# change last softmax layer since number of classes are different
finetune_net.output = nn.Dense(classes)
finetune_net.output.initialize(init.Xavier(), ctx=ctx)
finetune_net.output.initialize(init.Xavier(), device=device)
# hybridize for better performance
finetune_net.hybridize()

@@ -195,11 +195,11 @@ Now let's define the test metrics and start fine-tuning.


```{.python .input}
def test(net, val_data, ctx):
def test(net, val_data, device):
metric = mx.gluon.metric.Accuracy()
for i, (data, label) in enumerate(val_data):
data = gluon.utils.split_and_load(data, ctx_list=ctx, even_split=False)
label = gluon.utils.split_and_load(label, ctx_list=ctx, even_split=False)
data = gluon.utils.split_and_load(data, device, even_split=False)
label = gluon.utils.split_and_load(label, device, even_split=False)
outputs = [net(x) for x in data]
metric.update(label, outputs)
return metric.get()
@@ -215,8 +215,8 @@ for epoch in range(1, epochs + 1):

for i, (data, label) in enumerate(train_data):
# get the images and labels
data = gluon.utils.split_and_load(data, ctx_list=ctx, even_split=False)
label = gluon.utils.split_and_load(label, ctx_list=ctx, even_split=False)
data = gluon.utils.split_and_load(data, device, even_split=False)
label = gluon.utils.split_and_load(label, device, even_split=False)
with autograd.record():
outputs = [finetune_net(x) for x in data]
loss = [softmax_cross_entropy(yhat, y) for yhat, y in zip(outputs, label)]
@@ -229,12 +229,12 @@ for epoch in range(1, epochs + 1):

_, train_acc = metric.get()
train_loss /= num_batch
_, val_acc = test(finetune_net, val_data, ctx)
_, val_acc = test(finetune_net, val_data, device)

print('[Epoch %d] Train-acc: %.3f, loss: %.3f | Val-acc: %.3f | learning-rate: %.3E | time: %.1f' %
(epoch, train_acc, train_loss, val_acc, trainer.learning_rate, time.time() - tic))

_, test_acc = test(finetune_net, test_data, ctx)
_, test_acc = test(finetune_net, test_data, device)
print('[Finished] Test-acc: %.3f' % (test_acc))
```
