Remove multi_task_multi_head flag in QtOpt.
PiperOrigin-RevId: 402639396
Change-Id: I42eddc752a38be9cef7146cae10daf34c83ab41a
Yao Lu authored and copybara-github committed Oct 12, 2021
1 parent 2dc591b commit 4edb06e
Showing 2 changed files with 3 additions and 52 deletions.
29 changes: 0 additions & 29 deletions tf_agents/agents/qtopt/qtopt_agent.py
@@ -124,7 +124,6 @@ def __init__(
       debug_summaries=False,
       summarize_grads_and_vars=False,
       train_step_counter=None,
-      multi_task_multi_head=False,
       info_spec=None,
       name=None):
     """Creates a Qtopt Agent.
@@ -209,10 +208,6 @@ def __init__(
         will be written during training.
       train_step_counter: An optional counter to increment every time the train
         op is run. Defaults to the global_step.
-      multi_task_multi_head: Multi_task support. Assuming 1) a one_hot vector
-        'task_id' exists in observation. 2) Q_network is a multi_head network,
-        with each head representing a separate task. Using 'task_id' to select
-        Q value for each task.
       info_spec: If not None, the policy info spec is set to this spec.
       name: The name of this agent. All variables in this module will fall under
         that name. Defaults to the class name.
@@ -241,7 +236,6 @@ def __init__(
       }
     else:
       self._info_spec = ()
-    self._multi_task_multi_head = multi_task_multi_head

     self._q_network = q_network
     net_observation_spec = (time_step_spec.observation, action_spec)
@@ -336,10 +330,6 @@ def policy_q_network(self):
   def enable_td3(self):
     return self._enable_td3

-  @property
-  def multi_task_multi_head(self):
-    return self._multi_task_multi_head
-
   def _setup_data_converter(self, q_network, gamma, n_step_update):
     if q_network.state_spec:
       if not self._in_graph_bellman_update:
@@ -392,7 +382,6 @@ def _setup_policy(self, time_step_spec, action_spec, emit_log_probability):
         num_elites=self._num_elites_cem,
         num_iterations=self._num_iter_cem,
         emit_log_probability=emit_log_probability,
-        multi_task_multi_head=self._multi_task_multi_head,
         training=False)

     collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
@@ -401,8 +390,6 @@ def _setup_policy(self, time_step_spec, action_spec, emit_log_probability):
     return policy, collect_policy

   def _check_network_output(self, net, label):
-    if self._multi_task_multi_head:
-      return
     network_utils.check_single_floating_network_output(
         net.create_variables(), expected_output_shape=(), label=label)

@@ -671,13 +658,6 @@ def _compute_q_values(
         network_state=network_state,
         training=training)

-    if self._multi_task_multi_head:
-      if 'task_id' not in time_steps.observation:
-        raise ValueError('In order to support multi_task_multi_head, a one_hot'
-                         ' task_id field is required in observation.')
-      task_id = tf.argmax(time_steps.observation['task_id'], axis=-1)
-      q_values = tf.gather(q_values, task_id, batch_dims=1)
-
     return q_values

   def _compute_next_q_values(self, next_time_steps, info, network_state=()):
@@ -700,15 +680,6 @@ def _compute_next_q_values(self, next_time_steps, info, network_state=()):
         network_state=network_state,
         training=False)

-    if self._multi_task_multi_head:
-      if 'task_id' not in next_time_steps.observation:
-        raise ValueError('In order to support multi_task_multi_head, a '
-                         'one_hot task_id field is required in observation.')
-      task_id = tf.argmax(next_time_steps.observation['task_id'], axis=-1)
-      q_values_target_delayed = tf.gather(
-          q_values_target_delayed, task_id, batch_dims=1)
-      q_values_target_delayed_2 = tf.gather(
-          q_values_target_delayed_2, task_id, batch_dims=1)
       q_next_state = tf.minimum(q_values_target_delayed_2,
                                 q_values_target_delayed)
     else:
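The branch deleted from `_compute_q_values` and `_compute_next_q_values` selected one Q head per example: the one_hot `task_id` observation was converted to an integer index with `tf.argmax` and used to gather along the head dimension of the `[B, num_tasks]` Q output. A minimal standalone sketch of that gather pattern (shapes and values below are illustrative, not part of TF-Agents):

import tensorflow as tf

# Illustrative multi-head Q output for a batch of 2 examples and 3 task heads.
q_values = tf.constant([[0.1, 0.5, 0.9],
                        [0.2, 0.4, 0.6]])          # [B, num_tasks]
task_id_one_hot = tf.constant([[0., 0., 1.],
                               [0., 1., 0.]])      # [B, num_tasks]

# Convert the one_hot task_id to an index, then pick that head per example.
task_id = tf.argmax(task_id_one_hot, axis=-1)              # [B] -> [2, 1]
per_task_q = tf.gather(q_values, task_id, batch_dims=1)    # [B] -> [0.9, 0.4]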
26 changes: 3 additions & 23 deletions tf_agents/policies/qtopt_cem_policy.py
@@ -116,7 +116,6 @@ def __init__(self,
                preprocess_state_action: bool = True,
                training: bool = False,
                weights: types.NestedTensorOrArray = None,
-               multi_task_multi_head: bool = False,
                name: Optional[str] = None):
     """Builds a CEM-Policy given a network and a sampler.
@@ -144,10 +143,6 @@ def __init__(self,
         happen after a few layers inside the network.
       training: Whether it is in training mode or inference mode.
       weights: A nested structure of weights w/ the same structure as action.
-      multi_task_multi_head: Multi_task support. Assuming 1) a one_hot vector
-        'task_id' exists in observation. 2) Q_network is a multi_head network,
-        with each head representing a separate task. Using 'task_id' to select
-        Q value for each task.
       name: The name of this policy. All variables in this module will fall
         under that name. Defaults to the class name.
@@ -165,10 +160,9 @@ def __init__(self,
           (action_spec, network_action_spec))

     if q_network:
-      if not multi_task_multi_head:
-        network_utils.check_single_floating_network_output(
-            q_network.create_variables(),
-            expected_output_shape=(), label=str(q_network))
+      network_utils.check_single_floating_network_output(
+          q_network.create_variables(),
+          expected_output_shape=(), label=str(q_network))
       policy_state_spec = q_network.state_spec
     else:
       policy_state_spec = ()
@@ -184,7 +178,6 @@ def __init__(self,
     self._observation_spec = time_step_spec.observation
     self._training = training
     self._preprocess_state_action = preprocess_state_action
-    self._multi_task_multi_head = multi_task_multi_head
     self._weights = weights

     super(CEMPolicy, self).__init__(
@@ -374,10 +367,6 @@ def _score(
       policy_state: A Tensor, or a nested dict, list or tuple of Tensors
         representing the previous policy_state.

-    Raises:
-      ValueError: If `task_id` is not in `observation` when the policy is in
-        multi_task_multi_head mode.
-
     Returns:
       a tensor of shape [B, N] representing the scores for the actions.
     """
@@ -400,15 +389,6 @@ def expand_to_megabatch(feature):
     scores, next_policy_state = self.compute_target_q(
         observation, sample_actions, step_type, policy_state)  # [BxN]

-    if self._multi_task_multi_head:
-      if 'task_id' not in observation:
-        raise ValueError('In order to support multi_task_multi_head, a one_hot'
-                         ' task_id field is required in observation.')
-
-      task_id = nest_utils.tile_batch(
-          tf.argmax(observation['task_id'], axis=-1), self._num_samples)
-      scores = tf.gather(scores, task_id, batch_dims=1)
-
     if self._preprocess_state_action:
       next_policy_state = tf.nest.map_structure(
           lambda x: tf.reshape(x, [-1, self._num_samples] + x.shape.as_list(  # pylint:disable=g-long-lambda
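On the policy side, the branch deleted from `_score` applied the same selection to the CEM megabatch: `compute_target_q` returns scores for all B*N sampled actions, so the per-example `task_id` first had to be tiled by `num_samples` (via `nest_utils.tile_batch`) before the gather. A rough standalone sketch of that tile-then-gather step, using `tf.repeat` in place of `tile_batch` and made-up sizes:

import tensorflow as tf

batch_size, num_samples, num_tasks = 2, 3, 4   # illustrative sizes only

# Multi-head scores for the flattened megabatch of B*N sampled actions.
scores = tf.random.uniform([batch_size * num_samples, num_tasks])  # [B*N, num_tasks]
task_id = tf.constant([2, 0])                  # one task index per example, [B]

# Repeat each example's task_id once per sampled action, then gather its head.
tiled_task_id = tf.repeat(task_id, num_samples, axis=0)               # [B*N]
per_task_scores = tf.gather(scores, tiled_task_id, batch_dims=1)      # [B*N]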
