Skip to content

Commit

Permalink
[XPU] change cuda_rng_state to rng_state in fleet random (#54077)
Browse files Browse the repository at this point in the history
  • Loading branch information
houj04 authored May 24, 2023
1 parent 44044d8 commit 23baa8c
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions python/paddle/distributed/fleet/layers/mpu/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def add(self, name, seed):
self.seeds_.add(seed)
if name in self.states_:
raise ValueError(f'state {name} already exists')
orig_rng_state = paddle.get_cuda_rng_state()
orig_rng_state = paddle.get_rng_state()
paddle.seed(seed)
self.states_[name] = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(orig_rng_state)
self.states_[name] = paddle.get_rng_state()
paddle.set_rng_state(orig_rng_state)

def get_states_tracker(self):
states = {}
Expand All @@ -69,13 +69,13 @@ def set_states_tracker(self, states):
def rng_state(self, name=MODEL_PARALLEL_RNG):
if name not in self.states_:
raise ValueError(f'state {name} does not exist')
orig_cuda_rng_state = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(self.states_[name])
orig_rng_state = paddle.get_rng_state()
paddle.set_rng_state(self.states_[name])
try:
yield
finally:
self.states_[name] = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(orig_cuda_rng_state)
self.states_[name] = paddle.get_rng_state()
paddle.set_rng_state(orig_rng_state)


RNG_STATE_TRACKER = RNGStatesTracker()
Expand Down

0 comments on commit 23baa8c

Please sign in to comment.