I ran into a problem when running the code from the PyTorch DQN tutorial. The code and the error are:
resize = T.Compose([T.ToPILImage(), T.Scale(40, interpolation=Image.CUBIC), T.ToTensor()])

# This is based on the code from gym.
screen_width = 600
def get_cart_location():
    world_width = env.x_threshold * 2
    scale = screen_width / world_width
    return int(env.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART


def get_screen():
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))  # transpose into torch order (CHW)
    # Strip off the top and bottom of the screen
    screen = screen[:, 160:320]
    view_width = 320
    cart_location = get_cart_location()
    if cart_location < view_width // 2:
        slice_range = slice(view_width)
    elif cart_location > (screen_width - view_width // 2):
        slice_range = slice(-view_width, None)
    else:
        slice_range = slice(cart_location - view_width // 2, cart_location + view_width // 2)
    # Strip off the edges, so that we have a square image centered on a cart
    screen = screen[:, :, slice_range]
    # Convert to float, rescale, convert to torch tensor (this doesn't require a copy)
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)
    # Resize, and add a batch dimension (BCHW)
    return resize(screen).unsqueeze(0)

env.reset()
plt.imshow(get_screen().squeeze(0).permute(1, 2, 0).numpy(), interpolation='none')
plt.show()
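To check the cropping arithmetic without rendering anything, here is a gym-free sketch of the same two steps (cart position to pixel column, then choosing the crop window) written as plain functions. The names `cart_pixel`, `crop_slice`, `SCREEN_WIDTH` and `VIEW_WIDTH` are hypothetical, and it assumes CartPole's default `x_threshold` of 2.4, so the 600-pixel screen covers 4.8 world units at 125 pixels per unit and a centred cart lands at column 300.

```python
# Hypothetical, gym-free restatement of the cropping logic in get_screen().
# Assumes CartPole's default x_threshold of 2.4 and the 600-pixel screen width.

SCREEN_WIDTH = 600
VIEW_WIDTH = 320


def cart_pixel(x, x_threshold=2.4, screen_width=SCREEN_WIDTH):
    """Map the cart's world x-position to a pixel column (middle of the cart)."""
    world_width = x_threshold * 2
    scale = screen_width / world_width
    return int(x * scale + screen_width / 2.0)


def crop_slice(cart_location, view_width=VIEW_WIDTH, screen_width=SCREEN_WIDTH):
    """Pick a view_width-wide column range around the cart, clamped to the screen."""
    if cart_location < view_width // 2:
        # Cart near the left edge: take the leftmost view_width columns
        return slice(view_width)
    elif cart_location > (screen_width - view_width // 2):
        # Cart near the right edge: take the rightmost view_width columns
        return slice(-view_width, None)
    else:
        # Otherwise centre the window on the cart
        return slice(cart_location - view_width // 2, cart_location + view_width // 2)


if __name__ == "__main__":
    print(cart_pixel(0.0))              # → 300
    print(crop_slice(cart_pixel(0.0)))  # → slice(140, 460, None)
```

Near either edge of the track the window is clamped instead of centred, so the crop always stays 320 pixels wide.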
AttributeError                            Traceback (most recent call last)
<ipython-input-5-0076b471379d> in <module>()
     30
     31 env.reset()
---> 32 plt.imshow(get_screen().squeeze(0).permute(1, 2, 0).numpy(), interpolation='none')
     33 plt.show()

<ipython-input-5-0076b471379d> in get_screen()
     14     screen = screen[:, 160:320]
     15     view_width = 320
---> 16     cart_location = get_cart_location()
     17     if cart_location < view_width // 2:
     18         slice_range = slice(view_width)

<ipython-input-5-0076b471379d> in get_cart_location()
      4 screen_width = 600
      5 def get_cart_location():
----> 6     world_width = env.x_threshold * 2
      7     scale = screen_width / world_width
      8     return int(env.state[0] * scale + screen_width / 2.0) # MIDDLE OF CART

AttributeError: 'TimeLimit' object has no attribute 'x_threshold'
I don't know how this happened, since I have installed gym and run the gym demo successfully. Any advice would be appreciated.
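The error message says `env` is a `TimeLimit` object: `gym.make` returns the CartPole environment wrapped in a `TimeLimit` wrapper, while `x_threshold` and `state` are attributes of the inner CartPole environment, and in the gym version used here the wrapper evidently does not expose them. The sketch below reproduces that structure with hypothetical toy classes (`ToyCartPole`, `ToyTimeLimit`; these are not gym's real classes) to show why the lookup fails and how an `unwrapped` property reaches the inner environment.

```python
# Hypothetical toy classes that mimic the structure behind the error; they are
# not gym's real classes. ToyCartPole stands in for the raw CartPole env and
# ToyTimeLimit for a wrapper that does not forward unknown attributes.


class ToyCartPole:
    def __init__(self):
        self.x_threshold = 2.4             # CartPole's track half-width
        self.state = (0.0, 0.0, 0.0, 0.0)  # (x, x_dot, theta, theta_dot)

    @property
    def unwrapped(self):
        return self                         # the innermost env unwraps to itself


class ToyTimeLimit:
    def __init__(self, env):
        self.env = env                      # inner env; its attributes are NOT copied over

    @property
    def unwrapped(self):
        return self.env.unwrapped           # walk down to the innermost env


env = ToyTimeLimit(ToyCartPole())

try:
    env.x_threshold                         # the wrapper has no such attribute
except AttributeError as err:
    print(err)  # 'ToyTimeLimit' object has no attribute 'x_threshold'

print(env.unwrapped.x_threshold)            # 2.4
```

Applied to the tutorial, the corresponding change is to create the environment as `env = gym.make('CartPole-v0').unwrapped`, which is how the PyTorch DQN tutorial itself set it up at the time; `env.reset()`, `env.render(...)`, `env.x_threshold` and `env.state` then all act on the raw CartPole environment. Note that this also bypasses the `TimeLimit` episode-length cap.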