Skip to content

Commit

Permalink
Move vectorized sampler from tf to garage.sampler (#840)
Browse files Browse the repository at this point in the history
- refactor tests accordingly
- fix tf_vars initialize bug in examples/np
  • Loading branch information
zequnyu authored Aug 16, 2019
1 parent dc8c106 commit e6943ac
Show file tree
Hide file tree
Showing 17 changed files with 35 additions and 30 deletions.
4 changes: 3 additions & 1 deletion examples/np/cem_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from garage.experiment import run_experiment
from garage.np.algos import CEM
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import OnPolicyVectorizedSampler
from garage.tf.envs import TfEnv
from garage.tf.experiment import LocalTFRunner
from garage.tf.policies import CategoricalMLPPolicy
from garage.tf.samplers import OnPolicyVectorizedSampler


def run_task(snapshot_config, *_):
Expand All @@ -27,6 +27,8 @@ def run_task(snapshot_config, *_):

baseline = LinearFeatureBaseline(env_spec=env.spec)

runner.initialize_tf_vars()

n_samples = 20

algo = CEM(
Expand Down
4 changes: 3 additions & 1 deletion examples/np/cma_es_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from garage.experiment import run_experiment
from garage.np.algos import CMAES
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import OnPolicyVectorizedSampler
from garage.tf.envs import TfEnv
from garage.tf.experiment import LocalTFRunner
from garage.tf.policies import CategoricalMLPPolicy
from garage.tf.samplers import OnPolicyVectorizedSampler


def run_task(snapshot_config, *_):
Expand All @@ -27,6 +27,8 @@ def run_task(snapshot_config, *_):

baseline = LinearFeatureBaseline(env_spec=env.spec)

runner.initialize_tf_vars()

n_samples = 20

algo = CMAES(
Expand Down
3 changes: 2 additions & 1 deletion src/garage/np/algos/batch_polopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from garage.misc import special, tensor_utils
from garage.np.algos.base import RLAlgorithm
from garage.tf.samplers import BatchSampler, OnPolicyVectorizedSampler
from garage.sampler import OnPolicyVectorizedSampler
from garage.tf.samplers import BatchSampler


class BatchPolopt(RLAlgorithm):
Expand Down
2 changes: 1 addition & 1 deletion src/garage/np/algos/off_policy_rl_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from garage.np.algos import RLAlgorithm
from garage.tf.samplers import OffPolicyVectorizedSampler
from garage.sampler import OffPolicyVectorizedSampler


class OffPolicyRLAlgorithm(RLAlgorithm):
Expand Down
7 changes: 6 additions & 1 deletion src/garage/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
from garage.sampler.base import Sampler
from garage.sampler.batch_sampler import BatchSampler
from garage.sampler.is_sampler import ISSampler
from garage.sampler.off_policy_vectorized_sampler import (
OffPolicyVectorizedSampler)
from garage.sampler.on_policy_vectorized_sampler import (
OnPolicyVectorizedSampler)
from garage.sampler.parallel_vec_env_executor import ParallelVecEnvExecutor
from garage.sampler.ray_sampler import RaySampler, SamplerWorker
from garage.sampler.stateful_pool import singleton_pool
from garage.sampler.vec_env_executor import VecEnvExecutor

__all__ = [
'BaseSampler', 'BatchSampler', 'Sampler', 'ISSampler', 'singleton_pool',
'RaySampler', 'SamplerWorker', 'ParallelVecEnvExecutor', 'VecEnvExecutor'
'RaySampler', 'SamplerWorker', 'ParallelVecEnvExecutor', 'VecEnvExecutor',
'OffPolicyVectorizedSampler', 'OnPolicyVectorizedSampler'
]
2 changes: 1 addition & 1 deletion src/garage/sampler/batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class BatchSampler(BaseSampler):
"""

def __init__(self, algo, env):
super(BatchSampler, self).__init__(algo, env)
super().__init__(algo, env)

def start_worker(self):
"""Start worker function."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

from garage.misc import tensor_utils
from garage.misc.overrides import overrides
from garage.sampler import VecEnvExecutor
from garage.tf.samplers.batch_sampler import BatchSampler
from garage.sampler.batch_sampler import BatchSampler
from garage.sampler.vec_env_executor import VecEnvExecutor


class OffPolicyVectorizedSampler(BatchSampler):
Expand All @@ -33,7 +33,7 @@ class OffPolicyVectorizedSampler(BatchSampler):
def __init__(self, algo, env, n_envs=None, no_reset=True):
if n_envs is None:
n_envs = int(algo.rollout_batch_size)
super(OffPolicyVectorizedSampler, self).__init__(algo, env, n_envs)
super().__init__(algo, env)
self.n_envs = n_envs
self.no_reset = no_reset

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from garage.misc import tensor_utils
from garage.misc.overrides import overrides
from garage.misc.prog_bar_counter import ProgBarCounter
from garage.sampler import VecEnvExecutor
from garage.sampler.batch_sampler import BatchSampler
from garage.sampler.utils import truncate_paths
from garage.tf.samplers.batch_sampler import BatchSampler
from garage.sampler.vec_env_executor import VecEnvExecutor


class OnPolicyVectorizedSampler(BatchSampler):
def __init__(self, algo, env, n_envs=1):
super(OnPolicyVectorizedSampler, self).__init__(algo, env, n_envs)
super().__init__(algo, env)
self.n_envs = n_envs

@overrides
Expand Down
3 changes: 2 additions & 1 deletion src/garage/tf/algos/batch_polopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

from garage.misc import special
from garage.np.algos import RLAlgorithm
from garage.sampler import OnPolicyVectorizedSampler
from garage.tf.misc import tensor_utils
from garage.tf.samplers import BatchSampler, OnPolicyVectorizedSampler
from garage.tf.samplers import BatchSampler


class BatchPolopt(RLAlgorithm):
Expand Down
9 changes: 1 addition & 8 deletions src/garage/tf/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
from garage.tf.samplers.batch_sampler import BatchSampler
from garage.tf.samplers.off_policy_vectorized_sampler import (
OffPolicyVectorizedSampler)
from garage.tf.samplers.on_policy_vectorized_sampler import (
OnPolicyVectorizedSampler)
from garage.tf.samplers.ray_sampler import (RaySamplerTF, SamplerWorkerTF)

__all__ = [
'BatchSampler', 'OffPolicyVectorizedSampler', 'OnPolicyVectorizedSampler',
'RaySamplerTF', 'SamplerWorkerTF'
]
__all__ = ['BatchSampler', 'RaySamplerTF', 'SamplerWorkerTF']
3 changes: 2 additions & 1 deletion src/garage/tf/samplers/batch_sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import tensorflow as tf

from garage.sampler import parallel_sampler, singleton_pool
from garage.sampler import parallel_sampler
from garage.sampler.base import BaseSampler
from garage.sampler.stateful_pool import singleton_pool
from garage.sampler.utils import truncate_paths


Expand Down
3 changes: 2 additions & 1 deletion tests/garage/np/algos/test_batch_polopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from garage.envs import normalize
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import OnPolicyVectorizedSampler
from garage.tf.envs import TfEnv
from garage.tf.samplers import BatchSampler, OnPolicyVectorizedSampler
from garage.tf.samplers import BatchSampler
from tests.fixtures.algos import DummyAlgo
from tests.fixtures.policies import DummyPolicy, DummyPolicyWithoutVectorized

Expand Down
2 changes: 1 addition & 1 deletion tests/garage/np/algos/test_cem.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from garage.np.algos import CEM
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import OnPolicyVectorizedSampler
from garage.tf.envs import TfEnv
from garage.tf.experiment import LocalTFRunner
from garage.tf.policies import CategoricalMLPPolicy
from garage.tf.samplers import OnPolicyVectorizedSampler
from tests.fixtures import TfGraphTestCase


Expand Down
2 changes: 1 addition & 1 deletion tests/garage/np/algos/test_cma_es.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from garage.np.algos import CMAES
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import OnPolicyVectorizedSampler
from garage.tf.envs import TfEnv
from garage.tf.experiment import LocalTFRunner
from garage.tf.policies import CategoricalMLPPolicy
from garage.tf.samplers import OnPolicyVectorizedSampler
from tests.fixtures import TfGraphTestCase


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from garage.envs import normalize
from garage.np.exploration_strategies import OUStrategy
from garage.replay_buffer import SimpleReplayBuffer
from garage.sampler import OffPolicyVectorizedSampler
from garage.tf.algos import DDPG
from garage.tf.envs import TfEnv
from garage.tf.experiment import LocalTFRunner
from garage.tf.policies import ContinuousMLPPolicyWithModel
from garage.tf.q_functions import ContinuousMLPQFunction
from garage.tf.samplers import OffPolicyVectorizedSampler
from tests.fixtures import TfGraphTestCase
from tests.fixtures.envs.dummy import DummyDictEnv
from tests.fixtures.policies import DummyPolicy
Expand Down
4 changes: 1 addition & 3 deletions tests/garage/sampler/test_ray_batched_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@

from garage.envs.grid_world_env import GridWorldEnv
from garage.np.policies import ScriptedPolicy
from garage.sampler import RaySampler, SamplerWorker
from garage.sampler import OnPolicyVectorizedSampler, RaySampler, SamplerWorker
from garage.tf.envs import TfEnv
from garage.tf.samplers.on_policy_vectorized_sampler \
import OnPolicyVectorizedSampler


class TestSampler:
Expand Down
3 changes: 2 additions & 1 deletion tests/garage/tf/algos/test_batch_polopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from garage.envs import normalize
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import OnPolicyVectorizedSampler
from garage.tf.envs import TfEnv
from garage.tf.samplers import BatchSampler, OnPolicyVectorizedSampler
from garage.tf.samplers import BatchSampler
from tests.fixtures.algos import DummyTFAlgo
from tests.fixtures.policies import DummyPolicy, DummyPolicyWithoutVectorized

Expand Down

0 comments on commit e6943ac

Please sign in to comment.