Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Implement MXNET_BACKWARD_DO_MIRROR in Gluon. #19133

Open
kohillyang opened this issue Sep 13, 2020 · 1 comment
Open

Implement MXNET_BACKWARD_DO_MIRROR in Gluon. #19133

kohillyang opened this issue Sep 13, 2020 · 1 comment

Comments

@kohillyang
Copy link

kohillyang commented Sep 13, 2020

Description

MXNET_BACKWARD_DO_MIRROR is a technique for reducing GPU memory usage. It is important because tasks such as object detection and semantic segmentation can benefit from a larger batch size. Currently, Gluon (or CachedOp) does not implement it.

MXNET_USE_FUSION is another option that reduces memory usage, but it appears to conflict with MXNET_BACKWARD_DO_MIRROR: if MXNET_BACKWARD_DO_MIRROR is set to 1, MXNET_USE_FUSION must be turned off, which leads to higher memory usage. The following code demonstrates this:

import mxnet as mx
import mxnet.autograd as ag


class NaiveDataset(object):
    """Synthetic two-class dataset used to drive the training loops.

    Even indices produce an all-zeros image whose label has a 1 at position
    0; odd indices produce an all-ones image whose label has a 1 at position
    1.  Labels are vectors of length 1000 and images have shape
    (3, 224, 224).
    """

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for sample ``idx``."""
        one_hot = mx.nd.zeros(shape=(1000, ))
        if idx % 2 == 0:
            one_hot[0] = 1
            pixels = mx.nd.zeros(shape=(3, 224, 224))
        else:
            one_hot[1] = 1
            pixels = mx.nd.ones(shape=(3, 224, 224))
        return mx.nd.array(pixels), one_hot

    def __len__(self):
        """Report the fixed dataset size of 10000 samples."""
        return 10000


def train_gluon_model_with_module():
    """Train a model-zoo ResNet-50 v1 through ``mx.mod.Module`` on GPU 0.

    The network is built with Gluon, called on a symbolic variable to obtain
    a symbol, and wrapped in a ``Module``.  The loss is not part of that
    symbol: for each batch the module outputs are fetched, the gradient of
    the loss with respect to them is computed with autograd, and that
    gradient is handed to ``module.backward``.

    Takes no arguments and returns nothing.  Prints one summed loss value
    per batch.
    """
    import os
    # Uncomment the next two lines to run with mirroring enabled and fusion
    # disabled; ``os`` is imported only for these toggles.
    # os.environ["MXNET_BACKWARD_DO_MIRROR"]="1"
    # os.environ["MXNET_USE_FUSION"]="0"
    ctx_list = [mx.gpu(0)]
    net = mx.gluon.model_zoo.vision.resnet50_v1(pretrained=False)
    net.initialize()
    # One dummy forward pass on a single zero image.
    # NOTE(review): presumably this finishes deferred parameter shape
    # inference so the parameter data can be read below - confirm.
    _ = net(mx.nd.zeros(shape=(1, 3, 224, 224)))
    arg_params = {}
    aux_params = {}
    arg_params_collected = net.collect_params()
    # Both loops walk the same parameter collection, so arg_params and
    # aux_params end up with identical keys; init_params below is called
    # with allow_extra=True.
    for k in arg_params_collected:
        arg_params[k] = arg_params_collected[k].data(mx.cpu())
    for k in arg_params_collected:
        aux_params[k] = arg_params_collected[k].data(mx.cpu())

    # Calling the Gluon network on a symbolic variable yields the symbol
    # that is passed to Module.
    data = mx.sym.var(name="data")
    sym = net(data)
    module = mx.mod.Module(sym, data_names=['data'], label_names=[], context=ctx_list)
    # NOTE(review): the module is bound with a batch of 2 per context while
    # the DataLoader below yields batches of 100 - presumably Module
    # reshapes on forward; confirm.
    module.bind(data_shapes=[("data", (len(ctx_list) * 2, 3, 224, 224))])
    module.init_params(arg_params=arg_params, aux_params=aux_params, allow_missing=False, allow_extra=True)
    module.init_optimizer(force_init=True)
    train_loader = mx.gluon.data.DataLoader(dataset=NaiveDataset(), batch_size=100,
                                            num_workers=8, last_batch="discard", shuffle=True,
                                            thread_pool=False)
    for data_batch in train_loader:
        # Only the images go through the module; labels are used outside it.
        module_data_batch = mx.io.DataBatch(data=[data_batch[0], ], label=None)
        module.forward(module_data_batch, is_train=True)
        y_hat = module.get_outputs(merge_multi_context=True)
        # Split labels and merged predictions into one slice per context.
        label_list = mx.gluon.utils.split_and_load(data_batch[1], ctx_list=ctx_list, batch_axis=0)
        preds_list = mx.gluon.utils.split_and_load(y_hat[0], ctx_list=ctx_list, batch_axis=0)
        pred_grad_list = []
        for pred, label in zip(preds_list, label_list):  # type: mx.nd.NDArray, mx.nd.NDArray
            pred.attach_grad()
            label.attach_grad()
            with ag.record():
                pred_log_softmax = mx.nd.log_softmax(pred,  axis=1)
                # Element-wise negative log-softmax weighted by the label.
                loss = pred_log_softmax * label * -1
            loss.backward()
            # Gradient of the loss with respect to the module output slice.
            pred_grad_list.append(pred.grad)
        # Re-assemble the per-context gradients and feed them to the module
        # as the gradients of its outputs.
        pred_gradients = mx.nd.concatenate(pred_grad_list, axis=0)
        module.backward([pred_gradients])
        module.update()
        # `loss` is the value left over from the last inner-loop iteration.
        print(loss.sum().asnumpy())
        # Wait for all queued MXNet operations before the next batch.
        mx.nd.waitall()


def train_gluon_model_with_gluon():
    """Train a model-zoo ResNet-50 v1 with the Gluon API on GPU 0.

    The network is hybridized with ``static_alloc=True`` and updated by an
    SGD ``mx.gluon.Trainer`` with a learning rate of 1e-2.  For each batch
    the forward pass and loss are recorded with autograd per context slice,
    all losses are back-propagated together, and the trainer takes one step.

    Takes no arguments and returns nothing.  Prints one summed loss value
    per batch.
    """
    ctx_list = [mx.gpu(0)]
    net = mx.gluon.model_zoo.vision.resnet50_v1(pretrained=False)
    # initialize() is called without an explicit context; the parameters are
    # then re-assigned to ctx_list on the next line.
    net.initialize()
    net.collect_params().reset_ctx(ctx_list)
    net.hybridize(static_alloc=True)
    trainer = mx.gluon.Trainer(
        net.collect_params(),  # fix batchnorm, fix first stage, etc...
        'sgd',
        {
            'learning_rate':1e-2
         },
    )

    train_loader = mx.gluon.data.DataLoader(dataset=NaiveDataset(), batch_size=100,
                                            num_workers=8, last_batch="discard", shuffle=True,
                                            thread_pool=False)
    for data_batch in train_loader:
        # Split images and labels into one slice per context.
        data_list = mx.gluon.utils.split_and_load(data_batch[0], ctx_list=ctx_list, batch_axis=0)
        label_list = mx.gluon.utils.split_and_load(data_batch[1], ctx_list=ctx_list, batch_axis=0)
        losses = []
        for data, label in zip(data_list, label_list):  # type: mx.nd.NDArray, mx.nd.NDArray
            with ag.record():
                y_hat = net(data)
                pred_log_softmax = mx.nd.log_softmax(y_hat,  axis=1)
                # Element-wise negative log-softmax weighted by the label.
                loss = pred_log_softmax * label * -1
            losses.append(loss)
        # Back-propagate every per-context loss in a single call.
        ag.backward(losses)
        # NOTE(review): step is called with 1 although each batch holds 100
        # samples - presumably gradients are deliberately left unscaled;
        # confirm.
        trainer.step(1)
        # `loss` is the value left over from the last inner-loop iteration.
        print(loss.sum().asnumpy())
        # Wait for all queued MXNet operations before the next batch.
        mx.nd.waitall()


if __name__ == '__main__':
    # Run one reproduction at a time: the Gluon path is active; swap the
    # commented call to exercise the mx.mod.Module path instead.
    # train_gluon_model_with_module()
    train_gluon_model_with_gluon()

By default, train_gluon_model_with_module and train_gluon_model_with_gluon need almost the same amount of GPU memory, but if MXNET_BACKWARD_DO_MIRROR is set to 1 and MXNET_USE_FUSION is set to 0, train_gluon_model_with_module fails with an OOM exception.

Pull request #11472 tried to implement this by adding an option to HybridBlock's hybridize function. It may be better to control the feature with MXNET_BACKWARD_DO_MIRROR rather than with a separate option, which would keep the behavior consistent with mx.mod.Module.

@szha
Copy link
Member

szha commented Sep 14, 2020

See #18543 (comment)

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

No branches or pull requests

2 participants