[DRAFT] ppo chess with llm and ConditionalPolicySwitch to sunfish bot #2763
Conversation
        hidden = output.hidden_states[-1][:, input_length - 1, :]
        return log_prob, hidden
    else:
        while True:
When collecting data, I do the following:
(1) The LLM input is something like the following, tokenized:
You are playing a game of chess. The list of moves so far are [<start>, Nf3, Nh6, Nc3] and the legal moves are [Rg8, Nc6, Na6, Ng8, Nf5, Ng4, g6, f6, e6, d6, c6, b6, a6, g5, f5, e5, d5, c5, b5, a5]. Please choose one of the legal moves. Respond only with the following sentence, with no additional explanatory text. Example Answer: I choose Rg8!
(2) I generate a maximum of 7 new tokens in a loop (using argmax over the logits to sample each token), breaking if any of them is ! (this is supposed to be the end of the sentence)
(3) I use a regex to verify the format and that the chosen move is legal
(4) I repeat (2) and (3) until a valid move is chosen
(5) I pad output_tokens to length 7
The reasoning for (5) is that I need to make sure that, during the PPO loss computation, the distribution dist has the same dimensions as action.
Does all of the above sound reasonable, or is there something better I could be doing here?
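A minimal sketch of what steps (2)-(5) could look like (an illustrative reconstruction, not code from this PR; a Hugging Face-style causal LM and tokenizer are assumed, and sample_move and MOVE_RE are made-up names):

import re
import torch

MAX_NEW_TOKENS = 7
MOVE_RE = re.compile(r"I choose (\S+)!")

def sample_move(model, tokenizer, prompt_ids, legal_moves, pad_token_id):
    # (4) repeat until the model produces a well-formed, legal move
    while True:
        generated = []
        ids = prompt_ids
        # (2) greedily generate up to 7 tokens, stopping on "!"
        for _ in range(MAX_NEW_TOKENS):
            logits = model(input_ids=ids).logits[:, -1, :]
            tok = logits.argmax(dim=-1, keepdim=True)
            generated.append(tok)
            ids = torch.cat([ids, tok], dim=-1)
            if "!" in tokenizer.decode(tok[0]):
                break
        out = torch.cat(generated, dim=-1)
        # (3) regex-check the format and the legality of the chosen move
        m = MOVE_RE.fullmatch(tokenizer.decode(out[0]).strip())
        if m is not None and m.group(1) in legal_moves:
            break
    # (5) pad to a fixed length so dist and action have matching dims in the PPO loss
    pad = torch.full((1, MAX_NEW_TOKENS - out.shape[-1]), pad_token_id,
                     dtype=out.dtype, device=out.device)
    return torch.cat([out, pad], dim=-1)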
Yeah, I think it's good!
The while loop for generation is a bit worrisome, though, because then the samples are not really drawn from the LLM's distribution.
One way around this could be to terminate the game if the output is not valid and assign a losing reward (like -1 or something).
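A rough sketch of that idea (illustrative only; parse_move, decoded_text, and the key names are assumptions, not code from this PR):

import torch

# Instead of re-sampling in a loop, treat an invalid output as a loss:
move = parse_move(decoded_text, legal_moves)  # hypothetical parser; returns None if invalid
if move is None:
    td.set(("next", "reward"), torch.tensor([-1.0]))  # losing reward for the LLM
    td.set(("next", "done"), torch.tensor([True]))    # terminate the game
else:
    td.set("action", move)
    td = env.step(td)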
examples/agents/ppo-chess-llm.py
actor_llm_policy = ProbSeq(
    Mod(
        LLMWrapper(llm, tokenizer, mode="policy"),
        in_keys=["obs_tokens"],
        out_keys=["logits", "hidden"],
    ),
    prob_module,
    # a bare lambda fails here: 'function' object has no attribute 'in_keys'
    Mod(AggregateProb(), in_keys=["sample_log_prob"], out_keys=["sample_log_prob"]),
    return_composite=True,
)
Does the way data_llm_policy and actor_llm_policy are defined make sense here?
    return tensor1_padded, tensor2_padded

for data in tqdm(collector):
    # FIXME: reward seems to be getting wrongly propagated (e.g. sunfish's win gets reflected as llm's win)
Debugging this one; I'm not sure whether it's an issue with how I applied the transforms.
examples/agents/ppo-chess-llm.py
# layout=torch.jagged errors with Qwen
#   File "/home/mg1998/.conda/envs/rl/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 859, in forward
#     cache_position = torch.arange(
#         past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
#     )
# AttributeError: 'ConstantIntNode' object has no attribute 'add'
The takeaway here for me is that we can't generically expect to use NJT for the LLM's input_tokens unless the LLM's forward is written in an "NJT-friendly" manner, without queries to sizes or implicit assumptions about dense input. In this case, Qwen makes some such assumptions.
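For illustration, a dense-padding fallback that sidesteps the jagged layout (toy shapes; the model call is commented out):

import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]  # per-sample token ids

# The jagged layout that trips Qwen's size queries in forward():
# njt = torch.nested.nested_tensor(seqs, layout=torch.jagged)

# Dense fallback: right-pad and pass an attention mask instead.
padded = pad_sequence(seqs, batch_first=True, padding_value=0)
mask = (padded != 0).long()
# out = model(input_ids=padded, attention_mask=mask)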
    return tensordict_reset


class Score(Transform):
Trying this instead of a randomly initialized critic head.
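For context, a minimal sketch of what such a transform could look like, assuming a plain material-count score and a "fen" string observation (the actual scoring in this PR may differ):

import chess
import torch
from torchrl.envs.transforms import Transform

PIECE_VALUES = {chess.PAWN: 1, chess.KNIGHT: 3, chess.BISHOP: 3,
                chess.ROOK: 5, chess.QUEEN: 9}

class Score(Transform):
    # Writes a material-balance score to use in place of a randomly
    # initialized critic head. Sketch only; assumes the env exposes "fen".
    def _call(self, tensordict):
        board = chess.Board(tensordict["fen"])
        score = sum(
            v * (len(board.pieces(p, chess.WHITE)) - len(board.pieces(p, chess.BLACK)))
            for p, v in PIECE_VALUES.items()
        )
        tensordict.set("state_value", torch.tensor([float(score)]))
        return tensordict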
    return observation_spec


def run_player(input_queue, output_queue):
Oh wow, that's clever! I didn't think it would take that to use sunfish.
Do you think we should upstream it into ChessEnv?
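Presumably the pattern is along these lines (an illustrative reconstruction; compute_sunfish_move is a hypothetical helper, not the PR's code):

import multiprocessing as mp

def run_player(input_queue, output_queue):
    # Run the sunfish engine in a separate process, talking over queues.
    while True:
        position = input_queue.get()
        if position is None:  # sentinel to shut the worker down
            break
        output_queue.put(compute_sunfish_move(position))  # hypothetical helper

in_q, out_q = mp.Queue(), mp.Queue()
proc = mp.Process(target=run_player, args=(in_q, out_q), daemon=True)
proc.start()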
for data in tqdm(rb):
    data = gae(data)
gae should go above the tqdm(rb).
You need the data to be presented sequentially to apply GAE - even slices may not be ideal (it's better to compute it once and for all than at each iteration).
Also, I would put a no_grad around it for safety.
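Concretely, something like this (sketch, using the names already in the script):

with torch.no_grad():  # advantage estimation needs no gradients
    data = gae(data)   # compute GAE once, on sequentially ordered data
rb.extend(data)        # then sample minibatches from the buffer
for d in tqdm(rb):
    ...                # loss/optimization as before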
This was something I wanted to talk about! If GAE happens before the data goes into the replay buffer, we have no way to nicely append_transform it to the collector (I think).
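One possibility, if it fits here: collectors accept a postproc callable that is applied to each batch before it is yielded, which could host the GAE call (sketch; env, policy, and frames_per_batch stand in for the script's actual arguments):

from torchrl.collectors import SyncDataCollector

collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=frames_per_batch,
    postproc=gae,  # runs on each sequential batch as it is collected
)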
rb = ReplayBuffer(
    storage=LazyStackStorage(100),
    batch_size=48,
    sampler=SliceSamplerWithoutReplacement(slice_len=8, end_key=("next", "done")),
In theory we may not need a slice sampler here.
    shifted=True,
)

for data in tqdm(collector):
Usually there are 3 nested loops:

for data in collector:
    for n in range(n_epochs):
        rb.extend(gae(data.copy()))
        for d in rb:
            loss(d)  # etc.