
[Feature] Dispatch for DDPG loss module #1215

Merged · 9 commits · Jun 4, 2023

Conversation

@Blonck (Contributor) commented Jun 1, 2023

Description

Enable dispatching arguments for the .forward() method of the DDPG loss module.
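Concretely, `@dispatch` lets the loss be called either with a single tensordict or with plain tensors as keyword arguments, where nested keys are passed as underscore-joined names. Below is a minimal runnable sketch of the mechanism on a toy stand-in module, not the actual `DDPGLoss`; key names and shapes are illustrative:

```python
import torch
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import TensorDictModuleBase, dispatch

# Toy stand-in for a loss module, only to show what @dispatch enables.
class ToyLoss(TensorDictModuleBase):
    in_keys = ["observation", "action", ("next", "reward"), ("next", "done")]
    out_keys = ["loss"]

    @dispatch
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Dummy "loss" just to show the data flow.
        loss = tensordict.get(("next", "reward")).mean()
        return TensorDict({"loss": loss}, batch_size=[])

module = ToyLoss()
td = TensorDict(
    {
        "observation": torch.randn(8, 4),
        "action": torch.randn(8, 2),
        "next": {
            "reward": torch.randn(8, 1),
            "done": torch.zeros(8, 1, dtype=torch.bool),
        },
    },
    batch_size=[8],
)

module(td)      # the usual tensordict call
loss = module(  # the dispatched call: kwargs, nested keys joined with "_"
    observation=td["observation"],
    action=td["action"],
    next_reward=td["next", "reward"],
    next_done=td["next", "done"],
)
```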

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot added the CLA Signed label Jun 1, 2023
@Blonck (Contributor, Author) commented Jun 1, 2023

Open Questions

This pull request raises some open questions regarding the patch. Some of these questions may also apply to other loss modules.

  1. Currently, the PR breaks backward compatibility by renaming the argument input_tensordict to tensordict. This renaming is necessary for @dispatch to function properly.
     Potential solutions:
     • Accept the change and standardize the argument name. It is unlikely that anyone is using keyword arguments to call .forward() (my preferred option).
     • Enhance the functionality of @dispatch to allow alternative input names. This could involve adding a parameter that, when set, skips the name validation.
  2. The underlying value estimator relies on optional input keys, such as steps_to_next_obs, which may be used by .value_estimate() if not None. The current DQNLoss implementation ignores these keys:

```python
class DQNLoss(LossModule):
    ...

    @dispatch(
        source=[
            "observation",
            ("next", "observation"),
            "action",
            ("next", "reward"),
            ("next", "done"),
        ],
        dest=["loss"],
    )
    def forward(self, tensordict: TensorDictBase) -> TensorDict:
        ...
```

Possible solutions include:

  • Simply ignore optional arguments.
  • Add optional tensor dict keys to the dispatch argument list, requiring users to provide the argument (which can be set to None).
  • Extend @dispatch to support optional arguments.
  3. Some tensordict keys used by the value estimator, like ("next", "reward") and ("next", "done"), have fixed values despite the advantage module allowing configuration. These keys are also input keys for .forward().
     Should they be made configurable by adding them to _AcceptedKeys of the loss module in this PR?

  4. Sometimes input/output tensordict keys are dynamic and may depend on the configuration of the loss module. For DDPG loss this is not the case, but, for example, the output keys of the A2CLoss depend on its configuration.
     One solution would be to use a @property for that. In the case of A2CLoss it would look like:

```python
@property
def out_keys(self):
    outs = ["loss_objective"]
    if self.critic_coef:
        outs.append("loss_critic")
    if self.entropy_bonus:
        outs.append("entropy")
        outs.append("loss_entropy")
    return outs
```
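For illustration, here is how that property would behave on a hypothetical stand-in class (a sketch, not the real A2CLoss; only the flag names are taken from the example above):

```python
# Hypothetical stand-in class, only to show the dynamic out_keys property.
class ToyA2CLoss:
    def __init__(self, critic_coef=1.0, entropy_bonus=True):
        self.critic_coef = critic_coef
        self.entropy_bonus = entropy_bonus

    @property
    def out_keys(self):
        outs = ["loss_objective"]
        if self.critic_coef:
            outs.append("loss_critic")
        if self.entropy_bonus:
            outs.extend(["entropy", "loss_entropy"])
        return outs

print(ToyA2CLoss().out_keys)
# ['loss_objective', 'loss_critic', 'entropy', 'loss_entropy']
print(ToyA2CLoss(critic_coef=0.0, entropy_bonus=False).out_keys)
# ['loss_objective']
```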

Do you see any problems with that approach?

@Blonck requested a review from vmoens June 1, 2023 08:21
@Blonck (Contributor, Author) commented Jun 1, 2023

> The underlying value estimator relies on optional input keys, such as steps_to_next_obs, which may be used by .value_estimate() if not None. The current DQNLoss implementation ignores these keys:

You can ignore this one. @dispatch does not allow any optional arguments for now, so we need to go with the first solution and ignore optional arguments.

@vmoens (Contributor) commented Jun 1, 2023

> Open Questions
>
> This pull request raises some open questions regarding the patch. Some of these questions may also apply to other loss modules.
>
> 1. Currently, the PR breaks backward compatibility by renaming the argument input_tensordict to tensordict. This renaming is necessary for @dispatch to function properly.
>    Potential solutions:
>    • Accept the change and standardize the argument name. It is unlikely that anyone is using keyword arguments to call .forward() (my preferred option).
>    • Enhance the functionality of @dispatch to allow alternative input names. This could involve adding a parameter that, when set, skips the name validation.

I'm open to a bc-breaking change in this case.

@vmoens (Contributor) commented Jun 1, 2023

> • Simply ignore optional arguments.
> • Add optional tensor dict keys to the dispatch argument list, requiring users to provide the argument (which can be set to None).
> • Extend @dispatch to support optional arguments.

Two things here:

  • We can't assume that dispatch will have all the functionality. I would keep things small-scale if possible.
  • If there is a key whose presence or absence controls the flow, I think it's best to consider it absent for dispatch.

That being said, in some places in the code we do

```python
value = data.get(key, None)
if value is None:
    foo()
else:
    bar()
```

In this case, we could have the user call

    module(..., steps_to_next_obs=tensor)

which will populate the tensordict with that key; if it is not passed, the resulting value will be None.
That could require changing the function that reads "steps_to_next_obs".

In summary:
I would suggest keeping the behaviour simple and not including "steps_to_next_obs" in the in_keys, but keeping track of this in a TorchRL issue.
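A toy sketch of the pattern described above, using a hypothetical value_estimate helper (not TorchRL code): the optional key controls the flow, and its absence maps to None.

```python
import torch
from tensordict import TensorDict

def value_estimate(data: TensorDict) -> torch.Tensor:
    # Optional key: absent -> None -> single-step branch.
    steps = data.get("steps_to_next_obs", None)
    gamma = 0.99
    discount = gamma if steps is None else gamma**steps
    return data["reward"] + discount * data["next_value"]

td = TensorDict(
    {"reward": torch.ones(4), "next_value": torch.zeros(4)},
    batch_size=[4],
)
value_estimate(td)  # key absent: single-step branch
td["steps_to_next_obs"] = 3 * torch.ones(4)
value_estimate(td)  # key present: multi-step branch
```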

@vmoens (Contributor) commented Jun 1, 2023

> Some tensor dict keys used by the value estimator, like ("next", "reward") and ("next", "done"), have fixed values despite the advantage module allowing configuration. These keys are also input keys for .forward().
> Should they be made configurable by adding them to _AcceptedKeys of the loss module in this PR?

Yes, I think they should be part of the _AcceptedKeys. Do you have an example of a place where this happens?

@vmoens (Contributor) commented Jun 1, 2023

> Sometimes input/output tensordict keys are dynamic and may depend on the configuration of the loss module. For DDPG loss this is not the case, but, for example, the output keys of the A2CLoss depend on its configuration.
> One solution would be to use a @property for that. In the case of A2CLoss it would look like:

Yes we should be using a property!

@vmoens added the enhancement label Jun 1, 2023
@Blonck (Contributor, Author) commented Jun 1, 2023

> Yes, I think they should be part of the _AcceptedKeys. Do you have an example of a place where this happens?

For example, ("next", "reward"), or better ("next", self.tensor_keys.reward), is used in all advantage modules.
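A rough sketch of how the _AcceptedKeys pattern could make that configurable (simplified; not the actual TorchRL classes):

```python
from dataclasses import dataclass

@dataclass
class _AcceptedKeys:
    # Default tensordict key names; users can override them.
    reward: str = "reward"
    done: str = "done"

class SomeLoss:
    def __init__(self, **key_overrides):
        self.tensor_keys = _AcceptedKeys(**key_overrides)

    def reward_key(self):
        # Advantage code would read ("next", self.tensor_keys.reward)
        # instead of the hard-coded ("next", "reward").
        return ("next", self.tensor_keys.reward)

print(SomeLoss().reward_key())                    # ('next', 'reward')
print(SomeLoss(reward="my_reward").reward_key())  # ('next', 'my_reward')
```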

@vmoens (Contributor) commented Jun 1, 2023

That should be customisable

@vmoens added the bc breaking label Jun 1, 2023
@vmoens (Contributor) left a review comment

Fantastic, really fancy!
On a high level: do you think that the in_keys will be recyclable across modules or will we need to re-code it every time?


@Blonck (Contributor, Author) commented Jun 3, 2023

> On a high level: do you think that the in_keys will be recyclable across modules or will we need to re-code it every time?

Not sure if there is really something common to all loss modules here; it seems too dependent on what actually happens inside each loss for anything to be recycled. But I will think about it, maybe I'll come up with an idea.

@vmoens (Contributor) left a review comment

Final review

@vmoens merged commit 331f677 into pytorch:main Jun 4, 2023