Skip to content

Commit

Permalink
Cloudpickle for experiment API method_call (#821)
Browse files Browse the repository at this point in the history
- use cloudpickle for method_call dump/loads
- add integration test for examples/
  • Loading branch information
zequnyu authored and krzentner committed Jul 30, 2019
1 parent 72e3cc3 commit f4a6271
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 2 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
'akro==0.0.6',
'cached_property',
'click',
'cloudpickle',
'cma==1.1.06',
# dm_control throws an error during install about not being able to
# find a build dependency (absl-py). Later pip executes the `install`
Expand Down
3 changes: 2 additions & 1 deletion src/garage/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import subprocess
import sys

import cloudpickle
import dateutil.tz
import numpy as np

Expand Down Expand Up @@ -276,7 +277,7 @@ def run_experiment(method_call=None,

for task in batch_tasks:
call = task.pop('method_call')
data = base64.b64encode(pickle.dumps(call)).decode('utf-8')
data = base64.b64encode(cloudpickle.dumps(call)).decode('utf-8')
task['args_data'] = data
exp_count += 1

Expand Down
3 changes: 2 additions & 1 deletion src/garage/experiment/experiment_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sys
import uuid

import cloudpickle
import dateutil.tz
import dowel
from dowel import logger
Expand Down Expand Up @@ -182,7 +183,7 @@ def run_experiment(argv):
snapshot_mode=args.snapshot_mode,
snapshot_gap=args.snapshot_gap)

method_call = pickle.loads(base64.b64decode(args.args_data))
method_call = cloudpickle.loads(base64.b64decode(args.args_data))
try:
method_call(snapshot_config, variant_data, args.resume_from_dir,
args.resume_from_epoch)
Expand Down
38 changes: 38 additions & 0 deletions tests/integration_tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
This is an integration test to make sure scripts from examples/
work when running `python examples/xx/xxx.py`.
"""
from garage.experiment import LocalRunner, run_experiment
from garage.np.baselines import LinearFeatureBaseline
from garage.tf.algos import VPG
from garage.tf.envs import TfEnv
from garage.tf.policies import CategoricalMLPPolicy


def _run_task(snapshot_config, *_):
with LocalRunner(snapshot_config=snapshot_config) as runner:
env = TfEnv(env_name='CartPole-v1')

policy = CategoricalMLPPolicy(
name='policy', env_spec=env.spec, hidden_sizes=(32, 32))

baseline = LinearFeatureBaseline(env_spec=env.spec)

algo = VPG(
env_spec=env.spec,
policy=policy,
baseline=baseline,
max_path_length=100,
discount=0.99,
optimizer_args=dict(tf_optimizer_args=dict(learning_rate=0.01, )))

runner.setup(algo, env)
runner.train(n_epochs=3, batch_size=100)


if __name__ == '__main__':
run_experiment(
_run_task,
snapshot_mode='last',
seed=1,
)

0 comments on commit f4a6271

Please sign in to comment.