Skip to content

Commit

Permalink
Merge pull request #46 from salesforce/custom-reset-api
Browse files Browse the repository at this point in the history
add custom resetter API
  • Loading branch information
Emerald01 authored Jul 28, 2022
2 parents 4f380f8 + 5e103b3 commit 7fd2454
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
9 changes: 9 additions & 0 deletions warp_drive/env_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ def __init__(
self.env_resetter = CUDAEnvironmentReset(
function_manager=self.cuda_function_manager
)
# custom reset function, if not found, will ignore
reset_function = f"Cuda{self.name}Reset"
self.env_resetter.register_custom_reset_function(
self.cuda_data_manager,
reset_function_name=reset_function)

def reset_all_envs(self):
"""
Expand Down Expand Up @@ -268,6 +273,10 @@ def reset_only_done_envs(self):
self.env_resetter.reset_when_done(self.cuda_data_manager, mode="if_done")
return {}

def custom_reset_all_envs(self, args=None, block=None, grid=None):
self.env_resetter.custom_reset(args=args, block=block, grid=grid)
return {}

def step_all_envs(self, actions=None):
"""
Step through all the environments
Expand Down
26 changes: 26 additions & 0 deletions warp_drive/managers/function_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,32 @@ def __init__(self, function_manager: CUDAFunctionManager):
"undo_done_flag_and_reset_timestep"
)

self._cuda_custom_reset = None
self._cuda_reset_feed = None

def register_custom_reset_function(self, data_manager: CUDADataManager, reset_function_name=None):
if reset_function_name is None or reset_function_name not in self._function_manager._cuda_function_names:
return
self._cuda_custom_reset = self._function_manager.get_function(reset_function_name)
self._cuda_reset_feed = CUDAFunctionFeed(data_manager)

def custom_reset(self,
args: Optional[list] = None,
block=None,
grid=None):

assert self._cuda_custom_reset is not None and self._cuda_reset_feed is not None, \
"Custom Reset function is not defined, call register_custom_reset_function() first"
assert args is None or isinstance(args, list)
if block is None:
block = self._block
if grid is None:
grid = self._grid
if args is None or len(args) == 0:
self._cuda_custom_reset(block=block, grid=grid)
else:
self._cuda_custom_reset(*self._cuda_reset_feed(args), block=block, grid=grid)

def reset_when_done(
self,
data_manager: CUDADataManager,
Expand Down

0 comments on commit 7fd2454

Please sign in to comment.