This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

fix for number of inputs/outputs for backward custom ops #17069

Merged: 2 commits merged into apache:master on Dec 14, 2019

Conversation

samskalicky (Contributor)

Description

The initial custom operator support provides all possible inputs to backward operators: inputs, outputs, and gradients. The user is responsible for computing the correct number of inputs and outputs, and the abstraction logic tells MXNet that the backward operator will have the correct number of inputs.

This fixes a problem where we don't account for the number of gradients when computing the number of inputs for a backward operator. Previously we computed "num_inputs + num_outputs", but the correct computation is "num_inputs + 2*num_outputs", since there is an input gradient for each output argument of the forward operator.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at https://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

@ptrendx (Member) left a comment

LGTM

@wkcn (Member) commented Dec 14, 2019

https://github.com/apache/incubator-mxnet/blob/5f9686af87ecedef12bce6f1025f1deed44aef0a/src/c_api/c_api.cc#L671-L672

Hi. What are inputs and outputs? Do they have the same meaning in forward and backward?

@samskalicky (Contributor, Author) commented Dec 14, 2019

Hi. What are inputs and outputs? Do they have the same meaning in forward and backward?

In the forward pass, an operator takes its inputs and produces its outputs. In the backward pass, an operator receives an input gradient for each of its outputs and produces an output gradient for each of its inputs. The picture looks something like this:

[Figure: MXNetOpArgs — diagram of forward/backward operator inputs and outputs]

In the forward pass the inputs are: A,B and the outputs are: C
In the backward pass the inputs are: dC, A, B, C and the outputs are: dA, dB

In MXNet you can specify which inputs/outputs from the forward function you want to keep for the backward pass. To simplify the first commit of custom op support, we made every possible input/output available during the backward pass. Here's the code where we tell MXNet that we want all inputs/outputs available for the backward function:

https://github.com/apache/incubator-mxnet/blob/bbdc1c3627ad5254c049c2bb871ecb4527d7dc14/src/c_api/c_api.cc#L498-L512

grad_reg is an input when registering the FGradient attribute on the forward function:

https://github.com/apache/incubator-mxnet/blob/bbdc1c3627ad5254c049c2bb871ecb4527d7dc14/src/c_api/c_api.cc#L666

So the fix here is to correctly tell MXNet how many inputs and outputs we have for the backward pass. In the "grad_reg" lambda we registered all inputs, outputs, and input gradients as inputs to the backward op, but we were still reporting the number of backward inputs as (num_inputs + num_outputs). That didn't account for the input gradients, so changing the number to (num_inputs + 2*num_outputs) simply matches what we did in "grad_reg". Previously MXNet wasn't checking or erroring out when the numbers didn't match, but @ptrendx is implementing a new check in #17049, so this fix will help him continue.

@ptrendx ptrendx merged commit 831b548 into apache:master Dec 14, 2019
@wkcn (Member) commented Dec 15, 2019

I see. Thanks @samskalicky for the detailed explanation!
