Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Recreate initializer per replica to make sure seed is properly set
Browse files (browse the repository at this point in the history)
APJansen committed Jan 9, 2024
1 parent 2f3dec6 commit 1375daa
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions n3fit/src/n3fit/backends/keras_backend/multi_dense.py
Original file line number Diff line number Diff line change
@@ -161,7 +161,8 @@ class MultiInitializer(Initializer):
"""

def __init__(self, single_initializer: Initializer, replica_seeds: List[int]):
- self.single_initializer = single_initializer
+ self.initializer_class = type(single_initializer)
+ self.initializer_config = single_initializer.get_config()
self.base_seed = single_initializer.seed if hasattr(single_initializer, "seed") else None
self.replica_seeds = replica_seeds

@@ -170,8 +171,9 @@ def __call__(self, shape, dtype=None, **kwargs):
per_replica_weights = []
for replica_seed in self.replica_seeds:
if self.base_seed is not None:
- self.single_initializer.seed = self.base_seed + replica_seed
+ self.initializer_config["seed"] = self.base_seed + replica_seed
+ single_initializer = self.initializer_class.from_config(self.initializer_config)

- per_replica_weights.append(self.single_initializer(shape, dtype, **kwargs))
+ per_replica_weights.append(single_initializer(shape, dtype, **kwargs))

return tf.stack(per_replica_weights, axis=0)

0 comments on commit 1375daa

Please sign in to comment.