
Conversation

@tyler-griggs (Member)

Summary

  • Adds optional loss_fn and loss_fn_config parameters to forward_backward() for Tinker API compatibility (see the sketch below)
  • Maps Tinker algorithm names (e.g., "ppo") to SkyRL equivalents
  • Removes ppo_train() from FSDP workers; gradient scaling is applied at optim_step instead
  • Updates all tests to use the new unified API
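
A rough sketch of the unified call shape (hypothetical: only forward_backward, loss_fn, and loss_fn_config come from this PR; the rest of the signature is assumed):

from typing import Any, Optional

# Hypothetical signature sketch; the real method is WorkerDispatch.forward_backward
# in worker_dispatch.py, and parameters beyond loss_fn / loss_fn_config are assumed.
def forward_backward(
    batch: "TrainingInputBatch",
    loss_fn: Optional[str] = None,                    # e.g. "ppo" (Tinker name, mapped to a SkyRL loss)
    loss_fn_config: Optional[dict[str, Any]] = None,  # e.g. Tinker-style ratio bounds
):
    """Run forward + backward on one micro batch; None keeps the default loss."""
    ...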

Changes

  1. WorkerDispatch (worker_dispatch.py):

    • Added loss_fn and loss_fn_config parameters to forward_backward()
    • Passes these parameters through to worker mesh methods
  2. PolicyWorkerBase (worker.py):

    • Added convert_tinker_loss_config() static method to convert Tinker's absolute ratio bounds to SkyRL's offset format (sketched after this list)
    • Gradient scaling now happens at optim_step time, based on the number of accumulated micro batches (see the second sketch after this list)
    • Removed separate ppo_train() path for FSDP workers
  3. Tests:

    • Updated test helpers to use comprehensive parameter sets
    • Added test_convert_tinker_loss_config for Tinker config conversion
    • Updated all GPU tests to use pass_through routing and positional batch parameters
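
A minimal sketch of the bound-to-offset conversion, written as a free function for brevity. It assumes Tinker supplies absolute PPO ratio bounds (e.g. 0.8 / 1.2) and SkyRL expects clip offsets around 1.0 (e.g. 0.2); the key names are illustrative, not the actual config schema:

def convert_tinker_loss_config(loss_fn_config: dict) -> dict:
    """Sketch: map Tinker's absolute PPO ratio bounds to SkyRL-style clip offsets."""
    converted = dict(loss_fn_config)
    if "ratio_low" in converted:
        # absolute lower bound (e.g. 0.8) -> offset below 1.0 (e.g. 0.2)
        converted["eps_clip_low"] = 1.0 - converted.pop("ratio_low")
    if "ratio_high" in converted:
        # absolute upper bound (e.g. 1.2) -> offset above 1.0 (e.g. 0.2)
        converted["eps_clip_high"] = converted.pop("ratio_high") - 1.0
    return converted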
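
And a rough sketch of the optim_step-time gradient scaling, assuming the worker tracks how many micro batches have accumulated gradients since the last optimizer step (all names here are illustrative):

import torch

def optim_step(model: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               accumulated_micro_batches: int) -> None:
    """Sketch: scale the summed gradients by the number of accumulated micro
    batches so the update matches the mean over the mini batch, then step."""
    scale = 1.0 / max(accumulated_micro_batches, 1)
    for param in model.parameters():
        if param.grad is not None:
            param.grad.mul_(scale)
    optimizer.step()
    optimizer.zero_grad()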

Test Plan

  • ✅ CPU tests pass: test_normalize_mini_batch_size, test_convert_tinker_loss_config
  • GPU tests can be run with: pytest tests/gpu/gpu_ci/test_training_step.py

Stack

🤖 Generated with Claude Code

tyler-griggs changed the base branch from main to arm on January 23, 2026 01:56
@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request parameterizes forward_backward with a loss_fn for Tinker API compatibility and refactors weight synchronization into a new save_weights_for_sampler method. The changes are generally positive, simplifying the API and improving test coverage. However, I've identified a critical issue in trainer.py where the critic training path appears to be broken due to incorrect arguments being passed to _execute_training_step. Additionally, there's an inconsistency in a new test file regarding a test utility function call.

Comment on lines 1088 to 1091
with Timer("critic_train", self.all_timings):
critic_status = self._execute_training_step("critic", data, "critic")
with Timer("policy_train", self.all_timings):
policy_status = self._execute_training_step("policy", data, "policy")

critical

There appears to be a critical issue with the arguments passed to _execute_training_step for both the critic and policy models. The function signature for _execute_training_step is (self, model: str, data: TrainingInputBatch), but it's being called with three arguments here (e.g., self._execute_training_step("critic", data, "critic")). This will result in a TypeError at runtime.

While the intent seems to be to pass a loss_fn, the implementation appears incomplete. Specifically:

  1. The signature of _execute_training_step hasn't been updated to accept a third argument.
  2. Even if it were updated, the critic training path would likely fail. The loss_fn would be "critic", which is not handled by PolicyWorkerBase._get_loss_fn, and CriticWorkerBase doesn't have a comparable method to handle a parameterized loss function.

To fix this, you'll need to update the signature of _execute_training_step and ensure that both policy and critic workers can correctly handle the new loss_fn parameter. For the critic, you might want to pass None as the loss_fn if it's not meant to be parameterized, and handle that case in _execute_training_step.
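
One possible shape for that fix (illustrative only; the worker-dispatch call and default handling below are assumptions about this codebase, not its actual API):

from typing import Optional

def _execute_training_step(self, model: str, data: "TrainingInputBatch",
                           loss_fn: Optional[str] = None):
    # None keeps the default loss, so the critic path stays unparameterized.
    if loss_fn is None:
        return self.worker_dispatch.forward_backward(model, data)
    return self.worker_dispatch.forward_backward(model, data, loss_fn=loss_fn)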


# === Step 1: Do a training step ===
dp_size = policy_group.actor_infos[0].rank.dp_size
dummy_batch = make_dummy_training_batch(batch_size=dp_size)

medium

The call to make_dummy_training_batch here and on line 190 seems inconsistent with changes in other test files. In other files like test_save_load_checkpoint.py and test_training_step.py, the batch_size argument was removed from this call (e.g., make_dummy_training_batch()).

If the signature of make_dummy_training_batch has changed, this could lead to test failures. For consistency across the test suite, please update this call to match the new pattern.

Suggested change
- dummy_batch = make_dummy_training_batch(batch_size=dp_size)
+ dummy_batch = make_dummy_training_batch()

tyler-griggs changed the base branch from arm to main on January 23, 2026 01:56
tyler-griggs and others added 4 commits January 23, 2026 01:59
- Remove ppo_train() from PolicyWorkerBase and CriticWorkerBase
- Workers now use forward_backward() + optim_step() with gradient scaling
- Trainer branches on strategy: Megatron uses ppo_train, FSDP uses forward_backward + optim_step
- WorkerDispatch forward_backward no longer takes Tinker params (loss_fn, loss_fn_config)
- Update tests to use TrainingInputBatch and remove ppo_train tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…anges

- Megatron: Remove redundant batch_to_experience call (iterator already yields Experience)
- test_save_load_model.py: Use TrainingInputBatch, remove extra forward_backward arg
- test_worker_offload.py: Use TrainingInputBatch, remove extra forward_backward arg

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
tyler-griggs force-pushed the tgriggs/loss_fn_clean branch from 7a0f4c3 to 70fb844 on January 23, 2026 01:59
tyler-griggs marked this pull request as draft on January 23, 2026 02:31