Feature/save agents #185

Merged
merged 51 commits into from
Nov 23, 2020

@cpnota cpnota commented Nov 20, 2020

It's here! This PR adds the ability to save/load agents, addressing #161.

There are a few keys to the design. First of all, rather than having agent.act(state) and agent.eval(state), agents were split into a training-mode Agent and a TestAgent. Both types of agents can be instantiated by a Preset:

agent = preset.agent()
test_agent = preset.test_agent()

The Preset is a serializable object containing the hyperparameters and all necessary torch models. The TestAgent inherits a copy of the model trained by the Agent, allowing TestAgents from different points in training to be stored.
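
The relationship between a Preset, its Agent, and its TestAgents can be sketched with plain Python. This is a toy illustration, not the library's implementation: `ToyModel`, `ToyAgent`, `ToyTestAgent`, and `ToyPreset` are hypothetical names, and `copy.deepcopy` is only an assumed way for a test agent to obtain its own copy of the model.

```python
import copy


class ToyModel:
    """Hypothetical stand-in for a torch model: a single trainable weight."""

    def __init__(self, weight=0.0):
        self.weight = weight


class ToyAgent:
    """Training-mode agent: acts and updates the shared model in place."""

    def __init__(self, model, lr):
        self.model = model
        self.lr = lr

    def act(self, state):
        # Placeholder update: nudge the weight toward the observed state.
        self.model.weight += self.lr * (state - self.model.weight)
        return self.model.weight


class ToyTestAgent:
    """Test-mode agent: acts with its own model copy and never updates it."""

    def __init__(self, model):
        self.model = model

    def act(self, state):
        return self.model.weight


class ToyPreset:
    """Holds the hyperparameters and the model; builds both kinds of agent."""

    def __init__(self, lr):
        self.lr = lr
        self.model = ToyModel()

    def agent(self):
        # The training agent shares the preset's model, so training updates it.
        return ToyAgent(self.model, self.lr)

    def test_agent(self):
        # Assumed here: the test agent receives a snapshot of the model.
        return ToyTestAgent(copy.deepcopy(self.model))


preset = ToyPreset(lr=0.5)
agent = preset.agent()
agent.act(1.0)               # weight: 0.0 -> 0.5
early = preset.test_agent()  # snapshot after one update
agent.act(1.0)               # weight: 0.5 -> 0.75
late = preset.test_agent()   # snapshot after two updates
print(early.act(0.0), late.act(0.0))  # 0.5 0.75
```

Because each test agent in this sketch holds its own copy, `early` keeps reporting the weight from the moment it was created even though training continued afterwards.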

The second key to the design is that the Preset, rather than the Agent itself, is saved:

preset.save(filename)
preset = torch.load(filename)

This is important because the underlying Agent objects are often difficult to serialize, and even when they can be serialized, they can take up an excessive amount of storage (for example, a standard 1 million frame Atari replay buffer is ~7 GB).
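
A rough check of the storage argument, as a sketch under stated assumptions: frames are taken to be 84x84 grayscale at one byte per pixel with each frame stored once, and the standard-library `pickle` module stands in for `torch.save`. The `preset_state` and `agent_state` dictionaries are hypothetical and much simpler than the real objects, and the buffer is scaled down to 1,000 frames.

```python
import pickle

# Assumption: each frame is 84x84 grayscale at one byte per pixel.
FRAME_BYTES = 84 * 84

# Estimated size of a 1-million-frame buffer under that assumption.
full_buffer_gb = 1_000_000 * FRAME_BYTES / 1e9
print(full_buffer_gb)  # 7.056

# Hypothetical preset state: hyperparameters plus a small set of weights.
preset_state = {"hyperparameters": {"lr": 1e-3}, "weights": [0.0] * 1000}

# Hypothetical agent state: the same preset plus a scaled-down replay buffer
# (1,000 frames here, not 1 million).
agent_state = {
    "preset": preset_state,
    "replay_buffer": [bytes(FRAME_BYTES) for _ in range(1000)],
}

# Serialized sizes in bytes: the buffer dominates the agent's footprint.
preset_bytes = len(pickle.dumps(preset_state))
agent_bytes = len(pickle.dumps(agent_state))
print(preset_bytes < agent_bytes)  # True

# Round trip: the preset state survives serialization unchanged.
restored = pickle.loads(pickle.dumps(preset_state))
```

Under these assumptions the full buffer comes to about 7.06 GB, which is consistent with the ~7 GB figure, while the preset state alone serializes to a few kilobytes.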

One thing to note is that while this design supports creating a training-mode Agent with a previously trained network, it does not support a full "resume" of training: for example, scheduler states will be reset and replay buffers will be cleared. Full resume functionality introduces many difficulties that interfere with the design of the library. However, we may implement a partial solution in the future.

Example usage can be found below:

# construct the preset
preset = builder().hyperparameters(lr=1e-3).env(some_env).build()

# run agent in train mode
agent = preset.agent()
for i in range(episodes):
    run_episode(agent, some_env)

# run agent in test mode
test_agent = preset.test_agent()
for i in range(test_episodes):
    run_episode(test_agent, some_env)

# save the model for later
preset.save('dqn.pt')

# load model from disk and watch
preset = torch.load('dqn.pt')
test_agent = preset.test_agent()
for i in range(test_episodes):
    run_episode(test_agent, some_env, render=True)

@cpnota cpnota merged commit 8f65a70 into develop Nov 23, 2020
@cpnota cpnota deleted the feature/save-agents branch November 23, 2020 18:09