Commit
allow critic.last_bias_init_value to be configurable; various fixes and debug message improvements.
Le Horizon committed May 5, 2022
1 parent 9c51823 commit a455316
Showing 8 changed files with 23 additions and 5 deletions.
2 changes: 2 additions & 0 deletions alf/algorithms/merlin_algorithm.py
@@ -543,6 +543,8 @@ def __init__(self,
             enc_layers.append(res_block)
             in_channels = 64
 
+        if output_activation is None:
+            output_activation = alf.math.identity
         enc_layers.extend([
             nn.Flatten(),
             alf.layers.FC(
6 changes: 5 additions & 1 deletion alf/config_util.py
@@ -337,7 +337,11 @@ def pre_config(configs):
         try:
             config1(name, value, mutable=False)
             _HANDLED_PRE_CONFIGS.append((name, value))
-        except ValueError:
+        except ValueError as e:
+            # Most of the time, for command line flags, this warning is a
+            # false alarm. It can be useful for other failures, e.g. when the
+            # config value has already been used before being configured.
+            logging.warning("pre_config potential error: %s", e)
             _PRE_CONFIGS.append((name, value))
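For reference, a minimal sketch of how this code path is exercised (the config name and value below are invented for illustration; pre_config is exported at ALF's top level):

    import alf

    # Queue a value before the config it targets has been defined/imported.
    # If config1() raises ValueError at this point (e.g. the config is not
    # yet known, or its value has already been used), the warning above is
    # logged and the pair stays in _PRE_CONFIGS to be handled later.
    alf.pre_config({"CriticNetwork.last_bias_init_value": -1.0})
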
3 changes: 2 additions & 1 deletion alf/data_structures.py
@@ -279,7 +279,8 @@ def _generate_time_step(batched,
     if env_id is None:
         env_id = md.arange(batch_size, dtype=md.int32)
     if reward is not None:
-        assert reward.shape[:1] == outer_dims
+        assert reward.shape[:1] == outer_dims, "%s, %s" % (reward.shape,
+                                                           outer_dims)
     if prev_action is not None:
         flat_action = nest.flatten(prev_action)
         assert flat_action[0].shape[:1] == outer_dims
4 changes: 3 additions & 1 deletion alf/networks/critic_networks.py
@@ -77,6 +77,7 @@ def __init__(self,
                  joint_fc_layer_params=None,
                  activation=torch.relu_,
                  kernel_initializer=None,
+                 last_bias_init_value=0.0,
                  use_fc_bn=False,
                  use_naive_parallel_network=False,
                  name="CriticNetwork"):
@@ -174,7 +175,8 @@ def __init__(self,
             last_activation=math_ops.identity,
             use_fc_bn=use_fc_bn,
             last_kernel_initializer=last_kernel_initializer,
-            name=name)
+            last_bias_init_value=last_bias_init_value,
+            name=name + ".joint_encoder")
         self._use_naive_parallel_network = use_naive_parallel_network
 
     def make_parallel(self, n):
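With the parameter threaded through to the joint encoder, the critic's final-layer bias becomes configurable, as the commit title says. A hedged sketch (the config key and value are invented and assume CriticNetwork is reachable under this name in the user's setup):

    import alf

    # Start the critic's last FC layer with bias -1.0, e.g. to make initial
    # value estimates pessimistic. Hypothetical value for illustration only.
    alf.config("CriticNetwork", last_bias_init_value=-1.0)
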
4 changes: 3 additions & 1 deletion alf/networks/encoding_networks.py
@@ -405,6 +405,7 @@ def __init__(self,
                  last_layer_size=None,
                  last_activation=None,
                  last_kernel_initializer=None,
+                 last_bias_init_value=0.0,
                  last_use_fc_bn=False,
                  name="EncodingNetwork"):
         """
@@ -540,7 +541,8 @@ def __init__(self,
                     last_layer_size,
                     activation=last_activation,
                     use_bn=last_use_fc_bn,
-                    kernel_initializer=last_kernel_initializer))
+                    kernel_initializer=last_kernel_initializer,
+                    bias_init_value=last_bias_init_value))
             input_size = last_layer_size
 
         if output_tensor_spec is not None:
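The same knob is plumbed into the optional last layer of EncodingNetwork. A standalone sketch (the input spec and layer sizes are invented):

    import torch
    import alf
    from alf.tensor_specs import TensorSpec
    from alf.utils import math_ops

    # An encoder whose final FC layer starts with every bias at 0.5.
    net = alf.networks.EncodingNetwork(
        input_tensor_spec=TensorSpec((16, )),
        fc_layer_params=(64, ),
        last_layer_size=1,
        last_activation=math_ops.identity,
        last_bias_init_value=0.5)
    out, _ = net(torch.zeros(4, 16))  # out has shape (4, 1)
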
1 change: 1 addition & 0 deletions alf/trainers/policy_trainer.py
@@ -498,6 +498,7 @@ def __init__(self, config: TrainerConfig, ddp_rank: int = -1):
         logging.info(
             "observation_spec=%s" % pprint.pformat(env.observation_spec()))
         logging.info("action_spec=%s" % pprint.pformat(env.action_spec()))
+        logging.info("reward_spec=%s" % pprint.pformat(env.reward_spec()))
 
         # for offline buffer construction
         untransformed_observation_spec = env.observation_spec()
2 changes: 2 additions & 0 deletions alf/utils/external_configurables.py
@@ -46,3 +46,5 @@
 
 gin.external_configurable(torch.nn.init.xavier_normal_,
                           'torch.nn.init.xavier_normal_')
+gin.external_configurable(torch.nn.Embedding, 'torch.nn.Embedding')
+gin.external_configurable(torch.nn.Sequential, 'torch.nn.Sequential')
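A standalone sketch of what these registrations enable (the bound values are invented): gin.external_configurable returns a wrapped constructor whose arguments can then be bound from a gin config string or file.

    import gin
    import torch

    Embedding = gin.external_configurable(torch.nn.Embedding,
                                          'torch.nn.Embedding')
    # Bind a constructor argument via gin instead of hard-coding it.
    gin.parse_config("torch.nn.Embedding.embedding_dim = 16")
    layer = Embedding(num_embeddings=1000)  # embedding_dim=16 comes from gin
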
6 changes: 5 additions & 1 deletion alf/utils/normalizers.py
@@ -138,7 +138,11 @@ def _summary(name, val):
     def _summarize_all(path, t, m2, m):
         if path:
             path += "."
-        spec = TensorSpec.from_tensor(m if m2 is None else m2)
+        if m2 is not None:
+            spec = TensorSpec.from_tensor(m2)
+        else:
+            assert m is not None
+            spec = TensorSpec.from_tensor(m)
         _summary(path + "tensor.batch_min",
                  _reduce_along_batch_dims(t, spec, torch.min))
         _summary(path + "tensor.batch_max",
