-
Notifications
You must be signed in to change notification settings - Fork 3
/
test_tensorforce.py
34 lines (28 loc) · 966 Bytes
/
test_tensorforce.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from tensorforce import Agent, Environment
# Train a Tensorforce agent on the Gym CartPole level with the act/observe loop.
# Cleanup is wrapped in try/finally so the agent and the environment are closed
# even when agent creation or training raises.

# Number of training episodes; named so the run length is easy to adjust.
NUM_EPISODES = 300

# Pre-defined or custom environment
environment = Environment.create(
    environment='gym', level='CartPole', max_episode_timesteps=500
)

try:
    # Instantiate a Tensorforce agent
    agent = Agent.create(
        agent='tensorforce',
        environment=environment,  # alternatively: states, actions, (max_episode_timesteps)
        memory=10000,
        update=dict(unit='timesteps', batch_size=64),
        optimizer=dict(type='adam', learning_rate=3e-4),
        policy=dict(network='auto'),
        objective='policy_gradient',
        reward_estimation=dict(horizon=20),
    )
    try:
        # Train for NUM_EPISODES episodes
        for _ in range(NUM_EPISODES):
            # Initialize episode
            states = environment.reset()
            terminal = False
            # Loop until the environment reports a truthy terminal value.
            while not terminal:
                # Episode timestep: pick an action, apply it, then feed the
                # resulting terminal flag and reward back to the agent.
                actions = agent.act(states=states)
                states, terminal, reward = environment.execute(actions=actions)
                agent.observe(terminal=terminal, reward=reward)
    finally:
        # Release the agent even if training raised part-way through.
        agent.close()
finally:
    # Close the environment last, keeping the original agent-then-environment
    # shutdown order.
    environment.close()