From 4edb06e5adad85fef775e913aaeb4ac333609d6d Mon Sep 17 00:00:00 2001
From: Yao Lu
Date: Tue, 12 Oct 2021 13:19:52 -0700
Subject: [PATCH] Remove multi_task_multi_head flag in QtOpt.

PiperOrigin-RevId: 402639396
Change-Id: I42eddc752a38be9cef7146cae10daf34c83ab41a
---
 tf_agents/agents/qtopt/qtopt_agent.py  | 29 --------------------------
 tf_agents/policies/qtopt_cem_policy.py | 26 +++--------------------
 2 files changed, 3 insertions(+), 52 deletions(-)

diff --git a/tf_agents/agents/qtopt/qtopt_agent.py b/tf_agents/agents/qtopt/qtopt_agent.py
index d84dc49d8..18b01c588 100644
--- a/tf_agents/agents/qtopt/qtopt_agent.py
+++ b/tf_agents/agents/qtopt/qtopt_agent.py
@@ -124,7 +124,6 @@ def __init__(
       debug_summaries=False,
       summarize_grads_and_vars=False,
       train_step_counter=None,
-      multi_task_multi_head=False,
       info_spec=None,
       name=None):
     """Creates a Qtopt Agent.
@@ -209,10 +208,6 @@ def __init__(
         will be written during training.
       train_step_counter: An optional counter to increment every time the train
         op is run. Defaults to the global_step.
-      multi_task_multi_head: Multi_task support. Assuming 1) a one_hot vector
-        'task_id' exists in observation. 2) Q_network is a multi_head network,
-        with each head representing a separate task. Using 'task_id' to select
-        Q value for each task.
       info_spec: If not None, the policy info spec is set to this spec.
       name: The name of this agent. All variables in this module will fall
         under that name. Defaults to the class name.
@@ -241,7 +236,6 @@ def __init__(
       }
     else:
       self._info_spec = ()
-    self._multi_task_multi_head = multi_task_multi_head
     self._q_network = q_network
 
     net_observation_spec = (time_step_spec.observation, action_spec)
@@ -336,10 +330,6 @@ def policy_q_network(self):
   def enable_td3(self):
     return self._enable_td3
 
-  @property
-  def multi_task_multi_head(self):
-    return self._multi_task_multi_head
-
   def _setup_data_converter(self, q_network, gamma, n_step_update):
     if q_network.state_spec:
       if not self._in_graph_bellman_update:
@@ -392,7 +382,6 @@ def _setup_policy(self, time_step_spec, action_spec, emit_log_probability):
         num_elites=self._num_elites_cem,
         num_iterations=self._num_iter_cem,
         emit_log_probability=emit_log_probability,
-        multi_task_multi_head=self._multi_task_multi_head,
         training=False)
 
     collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
@@ -401,8 +390,6 @@ def _setup_policy(self, time_step_spec, action_spec, emit_log_probability):
     return policy, collect_policy
 
   def _check_network_output(self, net, label):
-    if self._multi_task_multi_head:
-      return
-
     network_utils.check_single_floating_network_output(
         net.create_variables(), expected_output_shape=(), label=label)
@@ -671,13 +658,6 @@ def _compute_q_values(
         network_state=network_state,
         training=training)
 
-    if self._multi_task_multi_head:
-      if 'task_id' not in time_steps.observation:
-        raise ValueError('In order to support multi_task_multi_head, a one_hot'
-                         ' task_id field is required in observation.')
-      task_id = tf.argmax(time_steps.observation['task_id'], axis=-1)
-      q_values = tf.gather(q_values, task_id, batch_dims=1)
-
     return q_values
 
   def _compute_next_q_values(self, next_time_steps, info, network_state=()):
@@ -700,15 +680,6 @@ def _compute_next_q_values(self, next_time_steps, info, network_state=()):
           network_state=network_state,
           training=False)
 
-      if self._multi_task_multi_head:
-        if 'task_id' not in next_time_steps.observation:
-          raise ValueError('In order to support multi_task_multi_head, a '
-                           'one_hot task_id field is required in observation.')
-        task_id = tf.argmax(next_time_steps.observation['task_id'], axis=-1)
-        q_values_target_delayed = tf.gather(
-            q_values_target_delayed, task_id, batch_dims=1)
-        q_values_target_delayed_2 = tf.gather(
-            q_values_target_delayed_2, task_id, batch_dims=1)
       q_next_state = tf.minimum(q_values_target_delayed_2,
                                 q_values_target_delayed)
     else:
diff --git a/tf_agents/policies/qtopt_cem_policy.py b/tf_agents/policies/qtopt_cem_policy.py
index d8078d40a..3d8c05086 100644
--- a/tf_agents/policies/qtopt_cem_policy.py
+++ b/tf_agents/policies/qtopt_cem_policy.py
@@ -116,7 +116,6 @@ def __init__(self,
                preprocess_state_action: bool = True,
                training: bool = False,
                weights: types.NestedTensorOrArray = None,
-               multi_task_multi_head: bool = False,
                name: Optional[str] = None):
     """Builds a CEM-Policy given a network and a sampler.
 
@@ -144,10 +143,6 @@ def __init__(self,
         happen after a few layers inside the network.
       training: Whether it is in training mode or inference mode.
       weights: A nested structure of weights w/ the same structure as action.
-      multi_task_multi_head: Multi_task support. Assuming 1) a one_hot vector
-        'task_id' exists in observation. 2) Q_network is a multi_head network,
-        with each head representing a separate task. Using 'task_id' to select
-        Q value for each task.
       name: The name of this policy. All variables in this module will fall
         under that name. Defaults to the class name.
 
@@ -165,10 +160,9 @@ def __init__(self,
           (action_spec, network_action_spec))
 
     if q_network:
-      if not multi_task_multi_head:
-        network_utils.check_single_floating_network_output(
-            q_network.create_variables(),
-            expected_output_shape=(), label=str(q_network))
+      network_utils.check_single_floating_network_output(
+          q_network.create_variables(),
+          expected_output_shape=(), label=str(q_network))
       policy_state_spec = q_network.state_spec
     else:
       policy_state_spec = ()
@@ -184,7 +178,6 @@ def __init__(self,
     self._observation_spec = time_step_spec.observation
     self._training = training
     self._preprocess_state_action = preprocess_state_action
-    self._multi_task_multi_head = multi_task_multi_head
     self._weights = weights
 
     super(CEMPolicy, self).__init__(
@@ -374,10 +367,6 @@ def _score(
       policy_state: A Tensor, or a nested dict, list or tuple of Tensors
         representing the previous policy_state.
 
-    Raises:
-      ValueError: If `task_id` is not in `observation` when the policy is in
-        multi_task_multi_head mode.
-
     Returns:
       a tensor of shape [B, N] representing the scores for the actions.
     """
@@ -400,15 +389,6 @@ def expand_to_megabatch(feature):
     scores, next_policy_state = self.compute_target_q(
         observation, sample_actions, step_type, policy_state)  # [BxN]
 
-    if self._multi_task_multi_head:
-      if 'task_id' not in observation:
-        raise ValueError('In order to support multi_task_multi_head, a one_hot'
-                         ' task_id field is required in observation.')
-
-      task_id = nest_utils.tile_batch(
-          tf.argmax(observation['task_id'], axis=-1), self._num_samples)
-      scores = tf.gather(scores, task_id, batch_dims=1)
-
     if self._preprocess_state_action:
       next_policy_state = tf.nest.map_structure(
          lambda x: tf.reshape(x, [-1, self._num_samples] + x.shape.as_list(  # pylint:disable=g-long-lambda
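Note (illustrative, not part of the patch): every branch this change deletes follows the same pattern, selecting one head of a multi-head Q-network output with a one-hot 'task_id' observation via argmax plus tf.gather(..., batch_dims=1). A minimal standalone sketch of that pattern, with made-up shapes and variable names (q_values_all_heads, task_id_one_hot are hypothetical, not from the repository):

    import tensorflow as tf

    batch_size, num_tasks = 4, 3
    # Multi-head Q output: one scalar Q-value per task head, shape [B, num_tasks].
    q_values_all_heads = tf.random.uniform([batch_size, num_tasks])
    # One-hot 'task_id' as it would appear in the observation, shape [B, num_tasks].
    task_id_one_hot = tf.one_hot([0, 2, 1, 0], depth=num_tasks)

    # Per-task selection as in the removed code paths.
    task_id = tf.argmax(task_id_one_hot, axis=-1)                    # [B], int64
    q_values = tf.gather(q_values_all_heads, task_id, batch_dims=1)  # [B]
    print(q_values.shape)  # (4,)

After this change the single floating-point output check runs unconditionally, so networks used with QtoptAgent and CEMPolicy are expected to emit one scalar Q-value per batch element.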