Skip to content

Commit

Permalink
change test to slice_nd
Browse files Browse the repository at this point in the history
  • Loading branch information
jotix16 committed Jun 28, 2021
1 parent fd37e8e commit 4eaa89a
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tests/test_TFNetworkRecLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3361,12 +3361,14 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
with make_scope() as session:
print("Create non-optimized rec layer (with subnet layer moved out)")
rec_layer_dict["optimize_move_layers_out"] = False
rec_layer_dict["unit"]["window"]["class"] = "slice_nd"
net1 = TFNetwork(config=config, train_flag=True, name="<root_not_opt>")
if shared_base_net:
net1.construct_from_dict(shared_base_net)
for key in shared_base_net:
assert key in net1.layers
net1.construct_from_dict({"output_not_opt": rec_layer_dict})
rec_layer_dict["unit"]["window"]["class"] = "slice_nd2"
rec_layer_dict["optimize_move_layers_out"] = True
print("Create optimized rec layer (with subnet layer inside loop)")
net2 = TFNetwork(config=config, extern_data=net1.extern_data, train_flag=True, name="<root_opt>")
Expand Down Expand Up @@ -3611,7 +3613,7 @@ def random_start_positions(source, **kwargs):
from_="position",
other_subnet_layers={
"my_layer": {"class": "gather_nd", "from": "base:data", "position": ":i"},
"window": {"class": "slice_nd2", # no_opt: [B,4,D], opt: [B,T,4,D]
"window": {"class": "slice_nd", # no_opt: [B,4,D], opt: [B,T,4,D]
"from": "base:data", "start": "data:source", "size": 4, "is_output_layer": True},
},
shared_base_net={
Expand Down

0 comments on commit 4eaa89a

Please sign in to comment.