Question about the output of the decision transformer #67

Closed
Pulsar110 opened this issue Jan 2, 2024 · 3 comments

Comments

@Pulsar110

From the code in here:
https://github.com/kzl/decision-transformer/blob/master/gym/decision_transformer/models/decision_transformer.py#L92-L99

        # reshape x so that the second dimension corresponds to the original
        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
        x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)

        # get predictions
        return_preds = self.predict_return(x[:,2])  # predict next return given state and action
        state_preds = self.predict_state(x[:,2])    # predict next state given state and action
        action_preds = self.predict_action(x[:,1])  # predict next action given state

I'm not sure I understand why self.predict_return(x[:, 2]) or self.predict_state(x[:, 2]) is predicting the return/next state given the state and action. From the comment at the top, x[:, 2] is only the action token? Am I missing something?

And if this code is correct, what is the use of x[:, 0]?
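
For context, my reading of the earlier part of forward() is that the three embedding streams are interleaved into a single token sequence before the transformer, roughly like this (a toy paraphrase with made-up shapes, not the exact repo code):

    import torch

    batch_size, seq_length, hidden_size = 2, 4, 8

    # toy stand-ins for the real return/state/action embedding layers
    returns_embeddings = torch.randn(batch_size, seq_length, hidden_size)
    state_embeddings = torch.randn(batch_size, seq_length, hidden_size)
    action_embeddings = torch.randn(batch_size, seq_length, hidden_size)

    # interleave into one sequence: R_1, s_1, a_1, R_2, s_2, a_2, ...
    # (B, 3, T, H) -> (B, T, 3, H) -> (B, 3T, H)
    stacked_inputs = torch.stack(
        (returns_embeddings, state_embeddings, action_embeddings), dim=1
    ).permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, hidden_size)

    # the quoted reshape is the inverse; pretend this is the transformer output
    x = stacked_inputs
    x = x.reshape(batch_size, seq_length, 3, hidden_size).permute(0, 2, 1, 3)

    # x[:, 0] = outputs at the R_t token positions
    # x[:, 1] = outputs at the s_t token positions
    # x[:, 2] = outputs at the a_t token positions
    assert torch.equal(x[:, 2], action_embeddings)  # sanity check on the layout

So x[:, 2] is the transformer output at the a_t token positions, which is why I read it as "just the action".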

I have also asked this question in the huggingface/transformers repo:
huggingface/transformers#27916

@nawta

nawta commented Jan 14, 2024

I had the same question. I guess the reason there's no problem is that Decision Transformer does not actually use return_preds (the return at the next timestep) or state_preds.
If this were Trajectory Transformer, a bug would show up.

@yangyichu

yangyichu commented Jan 19, 2024

I think the comments in min_decision_transformer are easier to understand: https://github.com/nikhilbarhate99/min-decision-transformer/blob/d6694248b48c57c84fc7487e6e8017dcca861b02/decision_transformer/model.py#L152
The new action is conditioned on everything up to and including r_t, s_t,
while the new state and new return are conditioned on everything up to and including r_t, s_t, a_t. So the original implementation is right, I think? @nawta @Pulsar110
As for x[:,0] I don't think we can get anything useful out of that.

@kzl
Owner

kzl commented Apr 28, 2024

Hi -

You can view next-token prediction like this:

ind | input | transformer | output
  0 |   R_t |     -->     | x[:,0]
  1 |   s_t |    \-->     | x[:,1]
  2 |   a_t |    \-->     | x[:,2]

Hence x[:,0] sees R_t; x[:,1] sees R_t and s_t; and x[:,2] sees all three.

Therefore, with the above formulation of the predictions, we have:

  1. return_preds = Q(R_t, s_t, a_t) (for time t+1), matching traditional Q functions
  2. state_preds = f(R_t, s_t, a_t) (for time t+1), matching traditional dynamics models s_{t+1} = f(s_t, a_t)
  3. action_preds = pi(R_t, s_t) (for time t), which is the standard return-conditioned policy formulation

If we wanted to make a prediction using x[:,0], it would be akin to predicting state given return. In most RL settings, you observe the state and reward simultaneously right after you take the action, and hence there is no need to predict the state, which is why there is no corresponding predictor for index 0. In fact, some followups to the original work combine both the return and state into one token in order to remove this redundancy.
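
As a concrete sanity check (a toy sketch, not code from this repo), you can enumerate what a causal mask lets each output position attend to in the interleaved sequence:

    # Toy check: with a causal mask over the interleaved sequence
    # R_1, s_1, a_1, R_2, s_2, a_2, ..., each output position can only
    # attend to positions at or before itself.
    seq_length = 2
    labels = [f"{m}_{t + 1}" for t in range(seq_length) for m in ("R", "s", "a")]
    # labels == ['R_1', 's_1', 'a_1', 'R_2', 's_2', 'a_2']

    for t in range(seq_length):
        for k, name in enumerate(("R", "s", "a")):
            pos = 3 * t + k              # position of this token in the sequence
            visible = labels[: pos + 1]  # what a causal mask lets it attend to
            print(f"x[:,{k},{t}] (output at {name}_{t + 1}) sees {visible}")

    # x[:,1,t] (the s_t slot) has seen ..., R_t, s_t      -> feeds predict_action
    # x[:,2,t] (the a_t slot) has seen ..., R_t, s_t, a_t -> feeds predict_return / predict_state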

Sorry I have been away for a very long time.

Best, Kevin

@kzl kzl closed this as completed Apr 28, 2024