ShufflingSampler can lead to significantly different free energies compared to default sampler #40
As a suggestion for a workaround, we could include an index tensor in the dataset to indicate whether a given datapoint corresponds to a duplicate or not:

```python
import numpy as np
import torch as to
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset

import tvo
# Presumably tvo's broadcast helper (cf. scatter_to_processes, gather_from_processes);
# torch.distributed.broadcast has the same call signature.
from tvo.utils.parallel import broadcast


class TVODataLoader(DataLoader):
    def __init__(self, data: to.Tensor, **kwargs):
        """TVO DataLoader class. Derived from torch.utils.data.DataLoader.

        :param data: Tensor containing the input dataset. Must have exactly two
                     dimensions (N, D).
        :param kwargs: forwarded to pytorch's DataLoader.

        TVODataLoader is constructed exactly the same way as pytorch's DataLoader,
        but it restricts datasets to a TensorDataset constructed from the data passed
        as parameter. All other arguments are forwarded to pytorch's DataLoader.

        In the case of distributed execution with unevenly sized datasets per worker,
        TVODataLoader will sample a few datapoints twice to guarantee that each
        worker iterates over the same number of batches.

        When iterated over, TVODataLoader yields a tuple containing the indices of
        the datapoints in each batch, the actual datapoints, as well as an index
        tensor indicating whether a datapoint corresponds to a duplicate or not.

        TVODataLoader instances optionally expose the attribute `precision`, which is
        set to the dtype of the dataset in data if it is a floating point dtype.
        """
        N = data.shape[0]
        if data.dtype is not to.uint8:
            self.precision = data.dtype

        notduplicate = to.ones(N, dtype=to.bool)

        if tvo.get_run_policy() == "mpi":
            assert dist.is_initialized()
            # Ranks ..., (comm_size-2), (comm_size-1) are assigned one data point more
            # than ranks 0, 1, ... if the dataset cannot be evenly distributed across
            # MPI processes (the split point depends on the total number of data points
            # and the number of MPI processes; cf. scatter_to_processes,
            # gather_from_processes).
            # To ensure that all workers can loop over batches in sync, we assign the
            # processes with fewer datapoints one randomly sampled additional datapoint,
            # and we mark these additional datapoints as duplicates (s.t. models can
            # optionally neglect them).
            n_samples = to.tensor(N)
            comm_size = dist.get_world_size()
            broadcast(n_samples, src=comm_size - 1)
            n_extra_samples = n_samples.item() - N
            if n_extra_samples > 0:
                # By definition (cf. scatter_to_processes), the number of datapoints on
                # different MPI ranks should not differ by more than one.
                assert n_extra_samples == 1
                replace = n_extra_samples > N  # should always be False
                idxs_repeat = np.random.choice(N, size=n_extra_samples, replace=replace)
                data = to.cat((data, data[idxs_repeat]), dim=0)
                notduplicate = to.cat(
                    (notduplicate, to.zeros(n_extra_samples, dtype=to.bool)), dim=0
                )

        dataset = TensorDataset(to.arange(data.shape[0]), data, notduplicate)

        super().__init__(dataset, **kwargs)
```
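For illustration, a minimal non-MPI usage sketch (hypothetical data; with a `"seq"` run policy no padding occurs, so the mask is a no-op):

```python
# Toy usage of the TVODataLoader sketched above, without MPI.
data = to.rand(10, 4)  # N=10 datapoints of dimension D=4
loader = TVODataLoader(data, batch_size=5, shuffle=False)
for idx, batch, notduplicate in loader:
    # Without MPI no duplicates are appended: notduplicate is all True,
    # so masking with it leaves the batch unchanged.
    assert batch[notduplicate].shape == batch.shape
```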
The batch loop in `Trainer._train_epoch` would then mask out the duplicates when evaluating free energies:

```python
for idx, batch, notduplicate in train_data:
    # ...
    batch_F = model.free_energy(idx[notduplicate], batch[notduplicate], train_states)
    # or, alternatively, free_energy handles the index tensor internally:
    batch_F = model.free_energy(idx, batch, train_states, notduplicate=notduplicate)
```

One would additionally need to make sure that the computation of `N_train` in `Trainer.__init__` only counts non-duplicates:
```python
# ...
notduplicate = train_data.dataset.tensors[2]
N_train = to.tensor(notduplicate.sum().item())
```

Similarly, the …
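As a quick sanity check of that `N_train` counting (a toy sketch with made-up values, assuming the loader above appended exactly one duplicate):

```python
# Toy check: a worker holding 3 real datapoints plus 1 appended duplicate.
notduplicate = to.tensor([True, True, True, False])
N_train = to.tensor(notduplicate.sum().item())
assert N_train.item() == 3  # the duplicate is excluded from the datapoint count
```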
In distributed execution, the `ShufflingSampler` potentially samples duplicate data points to ensure synchronized batch processing on each worker. These duplicate data points then contribute twice to the E- and M-steps.

In its current version, the `Trainer` includes the terms associated with duplicate data points when evaluating free energies (e.g., here, here, here, here). This can lead to significantly different results compared to a sequential execution without the `ShufflingSampler` and hence without duplicate datapoints: for an SSSC-House benchmark (σ=50, D=144, H=512, |K|=30), I observed a free energy difference on the order of 7.

Furthermore, the additional data points introduce additional terms in the Theta updates (in the `update_param_batch` methods), s.t. the M-step results differ from those obtained in the sequential execution setting.
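For intuition, a toy sketch (hypothetical free energy values, not from the benchmark above) of how a single duplicate shifts the summed free energy, and how the `notduplicate` mask from the proposed workaround restores the sequential result:

```python
import torch as to

# Hypothetical per-datapoint free energy terms on one worker (N=3 real datapoints).
F_n = to.tensor([-10.0, -12.0, -8.0])
# The worker is padded with a duplicate of the first datapoint:
F_padded = to.cat((F_n, F_n[:1]))
notduplicate = to.tensor([True, True, True, False])

print(F_n.sum().item())                     # sequential-run total: -30.0
print(F_padded.sum().item())                # duplicate counted twice: -40.0
print(F_padded[notduplicate].sum().item())  # masked as in the workaround: -30.0
```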