diff --git a/agent.py b/agent.py
index e278ac1..3db44b0 100644
--- a/agent.py
+++ b/agent.py
@@ -78,11 +78,11 @@ def cache(self, state, next_state, action, reward, done):
         reward (float),
         done(bool))
         """
-        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
-        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
-        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
-        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
-        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
+        state = torch.FloatTensor(state)
+        next_state = torch.FloatTensor(next_state)
+        action = torch.LongTensor([action])
+        reward = torch.DoubleTensor([reward])
+        done = torch.BoolTensor([done])

         self.memory.append( (state, next_state, action, reward, done,) )

@@ -93,6 +93,8 @@ def recall(self):
         """
         batch = random.sample(self.memory, self.batch_size)
         state, next_state, action, reward, done = map(torch.stack, zip(*batch))
+        if self.use_cuda:
+            state, next_state, action, reward, done = state.cuda(), next_state.cuda(), action.cuda(), reward.cuda(), done.cuda()
         return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
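The pattern the diff moves to, in a minimal self-contained sketch: cache() stores experiences as CPU tensors so the replay buffer does not consume GPU memory, and recall() transfers only the sampled batch to the GPU. The ReplayBuffer class name and constructor arguments below are illustrative assumptions, not the project's actual API; only the cache()/recall() bodies mirror the diff.

    import random
    from collections import deque

    import torch


    class ReplayBuffer:
        def __init__(self, capacity=100000, batch_size=32):
            self.memory = deque(maxlen=capacity)
            self.batch_size = batch_size
            self.use_cuda = torch.cuda.is_available()

        def cache(self, state, next_state, action, reward, done):
            # Keep stored experiences on the CPU so the buffer itself never
            # occupies GPU memory.
            self.memory.append((
                torch.FloatTensor(state),
                torch.FloatTensor(next_state),
                torch.LongTensor([action]),
                torch.DoubleTensor([reward]),
                torch.BoolTensor([done]),
            ))

        def recall(self):
            batch = random.sample(self.memory, self.batch_size)
            state, next_state, action, reward, done = map(torch.stack, zip(*batch))
            if self.use_cuda:
                # Transfer only the sampled batch to the GPU.
                state, next_state, action, reward, done = (
                    t.cuda() for t in (state, next_state, action, reward, done)
                )
            return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

Quick check of the intended effect: cached tensors stay on the CPU, and tensors returned by recall() end up on the GPU only when one is available.

    buffer = ReplayBuffer(batch_size=4)
    for _ in range(8):
        buffer.cache(state=[0.0] * 4, next_state=[1.0] * 4, action=1, reward=0.5, done=False)
    print(buffer.memory[0][0].device)  # cpu
    print(buffer.recall()[0].device)   # cuda:0 if available, otherwise cpu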