[BUG maybe] Memory Blow up when using Warp with FO RL #524

Open
dyumanaditya opened this issue Feb 12, 2025 · 1 comment
Labels
bug Something isn't working

Comments

@dyumanaditya

dyumanaditya commented Feb 12, 2025

Bug Description

Hello,

I have been trying to replicate some of the RL environments so that I can use first-order gradients, as is done in DFlex for SHAC and AHAC (link).

However, one of the main changes in Warp is that the internal state is no longer a set of Torch tensors, which means I need to create a custom PyTorch function (torch.autograd.Function), as suggested in other issues, to keep track of the gradients.

After doing this, however, the VRAM usage on my GPU blows up compared with the same run in DFlex, where all the states are Torch tensors instead of Warp arrays. I am running the cartpole env with SHAC, which occupies ~500 MB of VRAM with DFlex but exceeds 8 GB with Warp before the run even completes.

Am I doing something wrong? Most of the environment code is the same as in the repo linked above; the only difference is this custom function. Is there a way to retain the same memory usage as DFlex while shifting to Warp?

Please let me know if any other details are required. I wasn't able to debug much with the Tape visualization, but I am attaching it here as well.

[Image: Warp Tape visualization attachment]

import torch
import warp as wp


class Simulation(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        joint_q,
        joint_qd,
        action,
        model,
        state_0,
        control,
        integrator,
        sim_substeps,
        sim_dt,
    ):
        """
        Forward pass for the simulation step.
        """
        # save inputs for backward pass
        ctx.model = model
        ctx.state_0 = state_0
        ctx.control = control
        ctx.integrator = integrator
        ctx.sim_substeps = sim_substeps
        ctx.sim_dt = sim_dt

        # gradient tracking for the action
        ctx.joint_q = wp.from_torch(joint_q)
        ctx.joint_qd = wp.from_torch(joint_qd)
        ctx.action = wp.from_torch(action.flatten())

        # assign inputs to Warp state variables
        # NOTE: this is necessary to ensure that the integrator uses the correct values
        ctx.state_0.joint_q = ctx.joint_q  # current joint positions
        ctx.state_0.joint_qd = ctx.joint_qd  # current joint velocities
        # ctx.control.joint_act = ctx.action  # given joint action

        # prepare state for saving the simulation result
        state_1 = model.state()

        # prepare a Warp Tape for gradient tracking
        ctx.tape = wp.Tape()

        with ctx.tape:
            # Simulate forward
            # for _ in range(sim_substeps):
            state_0.clear_forces()
            # wp.sim.collide(model, state_0)
            # print(control.joint_act)
            # print(control.joint_act.shape)
            integrator.simulate(model, state_0, state_1, sim_dt, control)
            # state_0 = state_1

        # save the state for the backward pass
        ctx.state_1 = state_1

        # return the joint positions and velocities
        return wp.to_torch(state_1.joint_q), wp.to_torch(state_1.joint_qd)

    @staticmethod
    def backward(ctx, *grad_joint):
        """
        Backward pass for gradient computation.
        """
        # assign gradients to Warp state variables
        grad_joint_q, grad_joint_qd = grad_joint
        ctx.state_1.joint_q.grad = wp.from_torch(grad_joint_q, dtype=wp.float32)
        ctx.state_1.joint_qd.grad = wp.from_torch(grad_joint_qd, dtype=wp.float32)

        # backpropagate through the Warp simulation
        ctx.tape.backward()

        return_grads = (
            wp.to_torch(ctx.tape.gradients[ctx.joint_q]),   # joint_q
            wp.to_torch(ctx.tape.gradients[ctx.joint_qd]),  # joint_qd
            wp.to_torch(ctx.tape.gradients[ctx.action]),    # action
            None,  # model
            None,  # state_0
            None,  # control
            None,  # integrator
            None,  # sim_substeps
            None,  # sim_dt
        )

        ctx.tape.zero()

        # return adjoint w.r.t. inputs
        # Return as many inputs as forward pass arguments
        return return_grads
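
For context, a rough sketch of how this function is invoked from the rollout loop; env, policy, and the observation construction are placeholders, not the actual environment code:

# Hypothetical single rollout step; `env` and `policy` are placeholders for
# the environment and policy objects used in the SHAC training loop.
def rollout_step(env, policy, joint_q, joint_qd):
    obs = torch.cat([joint_q, joint_qd], dim=-1)  # placeholder observation
    action = policy(obs)
    joint_q, joint_qd = Simulation.apply(
        joint_q, joint_qd, action,
        env.model, env.state_0, env.control, env.integrator,
        env.sim_substeps, env.sim_dt,
    )
    return joint_q, joint_qd, action

The loss accumulated over a short horizon of such steps is then differentiated with loss.backward(), which calls Simulation.backward once per step.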

System Information

No response

@dyumanaditya dyumanaditya added the bug Something isn't working label Feb 12, 2025
@dyumanaditya dyumanaditya changed the title [BUG maybe] Memory Blow up when using Warp with RL [BUG maybe] Memory Blow up when using Warp with FO RL Feb 12, 2025
@etaoxing
Contributor

The main tricks to get first-order RL working are graph capture and gradient checkpointing (recomputing intermediates in the backward pass). I'm working on the code release for Rewarped; it should be out in the next month!
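
For illustration, here is a minimal sketch of the gradient-checkpointing part (graph capture not shown), written against the torch.autograd.Function from the issue above: forward runs the substeps without a Tape and saves only its inputs, and backward recomputes the substeps under a fresh Tape before backpropagating. The class name, the joint_act mapping, and the assumption that the model was finalized with requires_grad=True are all illustrative; a full implementation would also need to checkpoint any other state the integrator reads, not just joint_q and joint_qd.

class CheckpointedSim(torch.autograd.Function):
    """Sketch: forward runs the substeps without a Tape and saves only its
    inputs; backward recomputes the substeps under a fresh Tape."""

    @staticmethod
    def forward(ctx, joint_q, joint_qd, action, model, state_0, control,
                integrator, sim_substeps, sim_dt):
        # save only what is needed to re-run the step in the backward pass
        ctx.saved_inputs = (joint_q.detach().clone(),
                            joint_qd.detach().clone(),
                            action.detach().flatten().clone())
        ctx.sim_args = (model, state_0, control, integrator, sim_substeps, sim_dt)

        # untaped rollout: no adjoint buffers are retained across steps
        state_0.joint_q = wp.from_torch(joint_q.detach())
        state_0.joint_qd = wp.from_torch(joint_qd.detach())
        state_in = state_0
        for _ in range(sim_substeps):
            state_out = model.state()
            state_in.clear_forces()
            integrator.simulate(model, state_in, state_out, sim_dt, control)
            state_in = state_out

        return wp.to_torch(state_in.joint_q), wp.to_torch(state_in.joint_qd)

    @staticmethod
    def backward(ctx, grad_q, grad_qd):
        joint_q, joint_qd, action = ctx.saved_inputs
        model, state_0, control, integrator, sim_substeps, sim_dt = ctx.sim_args

        # recompute the same substeps, this time recorded on a Tape
        # (the model is assumed to have been finalized with requires_grad=True)
        wp_q = wp.from_torch(joint_q, requires_grad=True)
        wp_qd = wp.from_torch(joint_qd, requires_grad=True)
        wp_act = wp.from_torch(action, requires_grad=True)
        state_0.joint_q = wp_q
        state_0.joint_qd = wp_qd
        control.joint_act = wp_act  # assumes the action maps directly to joint_act

        tape = wp.Tape()
        state_in = state_0
        with tape:
            for _ in range(sim_substeps):
                state_out = model.state()
                state_in.clear_forces()
                integrator.simulate(model, state_in, state_out, sim_dt, control)
                state_in = state_out

        # seed the output adjoints and backpropagate through the recomputation
        state_in.joint_q.grad = wp.from_torch(grad_q.contiguous(), dtype=wp.float32)
        state_in.joint_qd.grad = wp.from_torch(grad_qd.contiguous(), dtype=wp.float32)
        tape.backward()

        grads = (
            wp.to_torch(tape.gradients[wp_q]),    # joint_q
            wp.to_torch(tape.gradients[wp_qd]),   # joint_qd
            wp.to_torch(tape.gradients[wp_act]),  # action
            None, None, None, None, None, None,   # non-differentiable inputs
        )
        tape.zero()
        return grads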
