I have been trying to replicate some of the RL environments to be able to use first-order gradients, as done in DFlex for SHAC and AHAC (link).
One of the main changes in Warp is that the internal state is no longer a set of torch tensors, which means I need to create a custom PyTorch function (torch.autograd.Function), as suggested in other issues, to keep track of the gradients.
After doing this, however, the VRAM on my GPU seems to blow up compared with the same run in DFlex, where all the states are torch tensors instead of Warp arrays. I am running the cartpole env with SHAC, which occupies ~500 MB of VRAM with DFlex but goes beyond 8 GB before even completing with Warp.
Am I doing something wrong? Most of the environment code is the same as in the repo linked above; the only difference is this custom function. Is there a way I can retain the same memory usage as DFlex while shifting to Warp?
Please let me know if any other details are required. I wasn't able to debug much with the Tape visualization, but I am attaching it here as well.
```python
class Simulation(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        joint_q,
        joint_qd,
        action,
        model,
        state_0,
        control,
        integrator,
        sim_substeps,
        sim_dt,
    ):
        """Forward pass for the simulation step."""
        # save inputs for backward pass
        ctx.model = model
        ctx.state_0 = state_0
        ctx.control = control
        ctx.integrator = integrator
        ctx.sim_substeps = sim_substeps
        ctx.sim_dt = sim_dt
        # gradient tracking for the action
        ctx.joint_q = wp.from_torch(joint_q)
        ctx.joint_qd = wp.from_torch(joint_qd)
        ctx.action = wp.from_torch(action.flatten())
        # assign inputs to Warp state variables
        # NOTE: this is necessary to ensure that the integrator uses the correct values
        ctx.state_0.joint_q = ctx.joint_q    # current joint positions
        ctx.state_0.joint_qd = ctx.joint_qd  # current joint velocities
        # ctx.control.joint_act = ctx.action # given joint action
        # prepare state for saving the simulation result
        state_1 = model.state()
        # prepare a Warp Tape for gradient tracking
        ctx.tape = wp.Tape()
        with ctx.tape:
            # Simulate forward
            # for _ in range(sim_substeps):
            state_0.clear_forces()
            # wp.sim.collide(model, state_0)
            # print(control.joint_act)
            # print(control.joint_act.shape)
            integrator.simulate(model, state_0, state_1, sim_dt, control)
            # state_0 = state_1
        # save the state for the backward pass
        ctx.state_1 = state_1
        # return the joint positions and velocities
        return wp.to_torch(state_1.joint_q), wp.to_torch(state_1.joint_qd)

    @staticmethod
    def backward(ctx, *grad_joint):
        """Backward pass for gradient computation."""
        # assign gradients to Warp state variables
        grad_joint_q, grad_joint_qd = grad_joint
        ctx.state_1.joint_q.grad = wp.from_torch(grad_joint_q, dtype=wp.float32)
        ctx.state_1.joint_qd.grad = wp.from_torch(grad_joint_qd, dtype=wp.float32)
        # backpropagate through the Warp simulation
        ctx.tape.backward()
        return_grads = (
            wp.to_torch(ctx.tape.gradients[ctx.joint_q]),   # joint_q
            wp.to_torch(ctx.tape.gradients[ctx.joint_qd]),  # joint_qd
            wp.to_torch(ctx.tape.gradients[ctx.action]),    # action
            None,  # model
            None,  # state_0
            None,  # control
            None,  # integrator
            None,  # sim_substeps
            None,  # sim_dt
        )
        ctx.tape.zero()
        # return adjoints w.r.t. the inputs
        # (one gradient per forward-pass argument)
        return return_grads
```
System Information
No response
The main tricks to get first-order RL working are graph capture and gradient checkpointing (recomputing intermediates in the backward pass). Working on a code release for Rewarped; it should be out in the next month!
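To illustrate the gradient-checkpointing idea mentioned above, here is a minimal, self-contained sketch in plain PyTorch. The `sim_step` function is a hypothetical stand-in for one differentiable simulation step (it is not Warp code); the point is only that wrapping each step in `torch.utils.checkpoint` discards the step's intermediates during the forward pass and recomputes them during backward, trading compute for memory over a long rollout:

```python
import torch
import torch.utils.checkpoint as checkpoint

def sim_step(q, qd, act):
    # Hypothetical placeholder dynamics; any differentiable update works here.
    qd = qd + 0.1 * act
    q = q + 0.1 * qd
    return q, qd

def rollout(q, qd, actions, use_checkpoint=True):
    # With checkpointing, each step's intermediates are recomputed in the
    # backward pass instead of being stored for the whole rollout.
    for act in actions:
        if use_checkpoint:
            q, qd = checkpoint.checkpoint(sim_step, q, qd, act, use_reentrant=False)
        else:
            q, qd = sim_step(q, qd, act)
    return q, qd

q = torch.zeros(2, requires_grad=True)
qd = torch.zeros(2, requires_grad=True)
actions = [torch.ones(2, requires_grad=True) for _ in range(8)]

qf, qdf = rollout(q, qd, actions)
qf.sum().backward()  # gradients flow through all checkpointed steps
```

For a Warp-backed environment, the per-step `wp.Tape` plays a similar role: instead of retaining every step's tape and states for the whole trajectory, the forward-pass intermediates would be recomputed inside `backward`.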