
[DRAFT] ppo chess with llm and ConditionalPolicySwitch to sunfish bot #2763

Draft · wants to merge 6 commits into base: gh/mikaylagawarecki/1/base

Conversation

[ghstack-poisoned]

pytorch-bot (bot) commented Feb 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2763

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 7 New Failures, 8 Unrelated Failures

As of commit f5d2b15 with merge base dbc8e2e:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Feb 5, 2025 (authors need to sign the CLA before a PR can be reviewed).
mikaylagawarecki added a commit that referenced this pull request Feb 5, 2025
ghstack-source-id: a0f5b1468f42104c76996e53d037da3a67e80156
Pull Request resolved: #2763
mikaylagawarecki added a commit that referenced this pull request Feb 5, 2025
ghstack-source-id: 8d779f57e748d0ff755b7cce76ca560aee4d8c6f
Pull Request resolved: #2763
mikaylagawarecki added a commit that referenced this pull request Feb 5, 2025
ghstack-source-id: c1a108c14fe9f7a4f3c23a34272e9ae21c2eefa6
Pull Request resolved: #2763
hidden = output.hidden_states[-1][:, input_length - 1, :]
return log_prob, hidden
else:
while True:
Author @mikaylagawarecki commented Feb 5, 2025:

When collecting data, I do the following

(1) The LLM input will be something like the following, tokenized:

You are playing a game of chess. The list of moves so far are [<start>, Nf3, Nh6, Nc3] and the legal moves are [Rg8, Nc6, Na6, Ng8, Nf5, Ng4, g6, f6, e6, d6, c6, b6, a6, g5, f5, e5, d5, c5, b5, a5]. Please choose one of the legal moves. Respond only with the following sentence, with no additional explanatory text. Example Answer: I choose Rg8!

(2) I generate a maximum of 7 new tokens in a loop (using argmax over the logits to pick each token), breaking if any of them is `!` (this is supposed to mark the end of the sentence).
(3) I use a regex to verify the format and that the chosen move is legal.
(4) Repeat (2) and (3) until a valid move is chosen.
(5) Pad `output_tokens` to length 7.

The reasoning for (5) is that I need to make sure that, during the PPO loss computation, the distribution in `dist` has the same dimensions as `action`.
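
A rough sketch of steps (2)-(5) (not the PR's code; it assumes a HuggingFace-style causal LM `model` and `tokenizer`, and `pad_id` plus the regex are illustrative only):

```python
import re
import torch

MAX_NEW_TOKENS = 7  # fixed action length, see step (5)

@torch.no_grad()
def generate_move(model, tokenizer, prompt_ids, legal_moves, pad_id):
    # step (2): greedy (argmax) decoding, at most 7 new tokens, stop on "!"
    new_tokens, ids = [], prompt_ids
    for _ in range(MAX_NEW_TOKENS):
        logits = model(ids).logits[:, -1, :]       # no KV cache, for brevity
        nxt = logits.argmax(dim=-1, keepdim=True)
        new_tokens.append(nxt)
        ids = torch.cat([ids, nxt], dim=-1)
        if "!" in tokenizer.decode(nxt[0]):
            break
    out = torch.cat(new_tokens, dim=-1)
    # step (3): check the answer format and that the chosen move is legal
    text = tokenizer.decode(out[0])
    match = re.fullmatch(r"\s*I choose (\S+?)!?\s*", text)
    if match is None or match.group(1) not in legal_moves:
        return None  # step (4): the caller retries until a valid move comes out
    # step (5): pad the action to the fixed 7-token budget
    pad = out.new_full((1, MAX_NEW_TOKENS - out.shape[-1]), pad_id)
    return torch.cat([out, pad], dim=-1)
```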

Does all the above sound reasonable or is there something better I could be doing here?

Contributor:

Yeah I think it's good!
The while loop for generation is a bit worrisome, because then the samples are not really gathered from the LLM.
One way to do this could be to terminate the game if the output is not valid and assign a losing reward (like -1 or something).
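
A minimal sketch of that alternative (hypothetical helper; key names assume torchrl's ("next", "reward") / ("next", "done") convention):

```python
import torch
from tensordict import TensorDict

def end_game_on_invalid_move(td: TensorDict) -> TensorDict:
    """Instead of re-sampling, treat an unparseable/illegal answer as a terminal
    transition with a losing reward, so the stored action is exactly what the LLM produced."""
    td.set(("next", "reward"), torch.tensor([-1.0]))
    td.set(("next", "done"), torch.ones(1, dtype=torch.bool))
    td.set(("next", "terminated"), torch.ones(1, dtype=torch.bool))
    return td
```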

Comment on lines 340 to 350
actor_llm_policy = ProbSeq(
    Mod(
        LLMWrapper(llm, tokenizer, mode="policy"),
        in_keys=["obs_tokens"],
        out_keys=["logits", "hidden"],
    ),
    prob_module,
    # if use lambda: 'function' object has no attribute 'in_keys'
    Mod(AggregateProb(), in_keys=["sample_log_prob"], out_keys=["sample_log_prob"]),
    return_composite=True,
)
Author:

Does the way data_llm_policy and actor_llm_policy are defined make sense here?

return tensor1_padded, tensor2_padded

for data in tqdm(collector):
# FIXME: reward seems to be getting wrongly propagated (e.g. sunfish's win gets reflected as llm's win)
Author:

Debugging this one; not sure whether it's an issue with how I applied the transforms.

Comment on lines 375 to 380
# layout=torch.jagged errors with Qwen
# File "/home/mg1998/.conda/envs/rl/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 859, in forward
# cache_position = torch.arange(
# past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
# )
# AttributeError: 'ConstantIntNode' object has no attribute 'add'
Author:

The takeaway here for me is that we can't generically expect to use NJT for input_tokens to the LLM unless the LLM's forward is written in an "NJT-friendly" manner, i.e. without queries to sizes or implicit assumptions about dense input. In this case Qwen makes some such assumptions.
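
If that is the constraint, one workaround (a sketch, not what the PR does) is to densify the jagged batch before it reaches the model:

```python
import torch

def njt_to_dense(obs_tokens_njt: torch.Tensor, pad_id: int):
    """Convert a jagged (layout=torch.jagged) batch of token ids into a padded
    dense batch plus an attention mask, for models whose forward assumes dense input.
    Assumes pad_id never appears as a real token."""
    dense = torch.nested.to_padded_tensor(obs_tokens_njt, padding=pad_id)
    attention_mask = (dense != pad_id).long()
    return dense, attention_mask

# usage (hypothetical): out = llm(input_ids=dense, attention_mask=attention_mask)
```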

mikaylagawarecki added a commit that referenced this pull request Feb 5, 2025
ghstack-source-id: 5b0ae766a774d3cfc78b3c5d1c07a3633b9e1345
Pull Request resolved: #2763
@mikaylagawarecki changed the title from "ppo chess with llm draft" to "[DRAFT] ppo chess with llm and ConditionalPolicySwitch to sunfish bot" on Feb 5, 2025
mikaylagawarecki added a commit that referenced this pull request Feb 6, 2025
ghstack-source-id: ac573be6daed7607813f24db6e5706069a308718
Pull Request resolved: #2763
mikaylagawarecki added a commit that referenced this pull request Feb 6, 2025
ghstack-source-id: 43ed58c8609054368e5170b980893b948266a65b
Pull Request resolved: #2763
return tensordict_reset


class Score(Transform):
Author:

trying this instead of a randomly initialized critic head
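
Not the PR's actual Score transform, but a minimal sketch of the idea under some assumptions (the env exposes the position as a FEN string under "fen"; spec names follow current torchrl): a cheap material-balance score is written into the tensordict and could serve as a value baseline instead of a randomly initialized critic head.

```python
import chess
import torch
from torchrl.data import Unbounded  # named UnboundedContinuousTensorSpec in older torchrl
from torchrl.envs.transforms import Transform

PIECE_VALUES = {chess.PAWN: 1, chess.KNIGHT: 3, chess.BISHOP: 3,
                chess.ROOK: 5, chess.QUEEN: 9, chess.KING: 0}

class MaterialScore(Transform):
    """Appends a material-balance score (white minus black) computed from the FEN."""

    def __init__(self):
        super().__init__(in_keys=["fen"], out_keys=["score"])

    def _call(self, tensordict):
        board = chess.Board(tensordict["fen"])
        score = sum(
            PIECE_VALUES[piece.piece_type] * (1 if piece.color == chess.WHITE else -1)
            for piece in board.piece_map().values()
        )
        tensordict.set("score", torch.tensor([float(score)]))
        return tensordict

    def _reset(self, tensordict, tensordict_reset):
        return self._call(tensordict_reset)

    def transform_observation_spec(self, observation_spec):
        observation_spec["score"] = Unbounded(shape=(1,), dtype=torch.float32)
        return observation_spec
```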

return observation_spec


def run_player(input_queue, output_queue):
Contributor:

oh wow that's clever! Didn't think it'd take that to use sunfish
Do you think we should upstream that in ChessEnv?
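
For reference, the general pattern this points at (a sketch only; the PR's run_player drives sunfish directly, whereas here the engine is abstracted into a hypothetical pick_move callable):

```python
import multiprocessing as mp

def run_player(input_queue, output_queue, pick_move):
    # pick_move: any callable mapping a position (e.g. a FEN string) to a move string
    # (for instance a thin wrapper around sunfish's search, omitted here).
    while True:
        fen = input_queue.get()
        if fen is None:                  # sentinel: shut the worker down
            break
        output_queue.put(pick_move(fen))

# Usage sketch:
# in_q, out_q = mp.Queue(), mp.Queue()
# proc = mp.Process(target=run_player, args=(in_q, out_q, sunfish_pick_move), daemon=True)
# proc.start()
# in_q.put(current_fen); bot_move = out_q.get()
```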


for data in tqdm(rb):

data = gae(data)
Contributor:

gae should go above the tqdm(rb)

You need the data to be presented sequentially to apply gae - even slices may not be ideal (better to compute it once and for all than at each iter).

Also I would put a no_grad around it for safety.
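
Concretely, the suggestion amounts to something like this (a sketch; `gae` and `rb` are the script's own objects, `loss_module` is a placeholder for the PPO loss):

```python
with torch.no_grad():          # advantage estimation does not need gradients
    data = gae(data)           # computed once, on the sequentially ordered rollout
rb.extend(data)

for minibatch in tqdm(rb):     # minibatches can then be consumed in any order
    loss_vals = loss_module(minibatch)
```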

Author:

This was something I wanted to talk about! If gae runs before the data is put into the replay buffer, we have no way to nicely append_transform it to the collector (I think).
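
One option worth checking (a sketch, assuming the collector's postproc hook fits here; constructor arguments other than postproc are placeholders): collectors take a postproc callable that is applied to each collected batch, which could host the GAE computation instead of an appended transform.

```python
collector = SyncDataCollector(
    env,
    data_llm_policy,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    postproc=lambda td: gae(td),  # runs on the full batch before it is yielded
)
```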

rb = ReplayBuffer(
    storage=LazyStackStorage(100),
    batch_size=48,
    sampler=SliceSamplerWithoutReplacement(slice_len=8, end_key=("next", "done")),
Contributor:

in theory we may not need a slice sampler
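
If the slice sampler is indeed unnecessary (e.g. because GAE is computed before the data enters the buffer), a plain without-replacement sampler would look like this (sketch):

```python
from torchrl.data import LazyStackStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

rb = ReplayBuffer(
    storage=LazyStackStorage(100),
    batch_size=48,
    sampler=SamplerWithoutReplacement(),
)
```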

shifted=True,
)

for data in tqdm(collector):
Contributor:

usually there are 3 nested loops

for data in collector:
    for n in range(n_epochs):
        rb.extend(gae(data.copy())) 
        for d in rb:
            loss(d) # etc
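
Expanded slightly into a full update step (a sketch; loss_module/optim are placeholders, the loss keys follow torchrl's ClipPPOLoss, and GAE plus the buffer refill are hoisted out of the epoch loop per the earlier comment):

```python
for data in collector:                    # 1) collect a batch of games
    with torch.no_grad():
        data = gae(data)                  # advantages computed once per batch
    rb.empty()
    rb.extend(data)
    for _ in range(n_epochs):             # 2) several PPO epochs over that batch
        for minibatch in rb:              # 3) minibatch updates
            loss_vals = loss_module(minibatch)
            loss = (loss_vals["loss_objective"]
                    + loss_vals["loss_critic"]
                    + loss_vals["loss_entropy"])
            loss.backward()
            optim.step()
            optim.zero_grad()
```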
