Pytorch refactor #168
base: master
Conversation
archived old code, new files, basic model rough-in
getting ready to refactor model init and data
@mortonjt, some updates:
We've pored over the model a few times, but cannot get it to converge to the same place that the multimodal unit test expects. We think starting conditions or the Adam optimizer may have an outsized effect here, but we're also not sure we didn't miss something.
The general structure of our tensors is [batch, sample, whatever], and tracing each operation seems to do what we expect.
So our X is drawn from a multinomial to do a bunch of categorical draws at once, giving us OTU indices in the form [batch, sample].
That goes into the embedding, giving us [batch, sample, latent].
Then we slice the bias and add it to latent after reshaping it to match (maybe something went wrong here, but we've inspected that line a few different ways and it seems to do what we want).
Then, in the decoder, we run a linear layer on one less dimension, giving us [batch, sample, ALR_metabolites]; we then prepend zeros to that last dimension and run softmax over it to hopefully get [batch, sample, P_metabolites].
We then parameterize the multinomial and calculate likelihoods.
As far as we can tell, this is what should be happening, so we don't really know why we get such unstable correlations from the unit tests, ranging from -0.29 to accidentally passing for U, and always failing by the point we check V.
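The shape bookkeeping described above can be sketched end to end. This is an illustrative sketch, not the PR's actual code: the sizes, and names like encoder, encoder_bias, and decoder, are assumptions for the example.

```python
import torch
import torch.nn as nn

# Illustrative sizes only
num_microbes, latent_dim, num_metabolites = 10, 2, 5
batch, n_samples = 4, 3

encoder = nn.Embedding(num_microbes, latent_dim)
encoder_bias = nn.Parameter(torch.randn(num_microbes))
decoder = nn.Linear(latent_dim, num_metabolites - 1)  # ALR: one fewer dimension

X = torch.randint(0, num_microbes, (batch, n_samples))  # OTU indices: [batch, sample]
z = encoder(X)                                          # [batch, sample, latent]
z = z + encoder_bias[X].reshape((*X.shape, 1))          # bias broadcast over latent
alr = decoder(z)                                        # [batch, sample, ALR_metabolites]
zeros = torch.zeros((*alr.shape[:-1], 1))               # prepend the ALR reference column
probs = torch.softmax(torch.cat([zeros, alr], dim=-1), dim=-1)  # [batch, sample, P_metabolites]
```

Tracing the shapes this way at least confirms each step lands where the comment above says it should; probs sums to 1 over the last dimension and can parameterize the multinomial.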
def mmvec_training_loop(model, learning_rate, batch_size, epochs):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                 betas=(0.8, 0.9), maximize=True)
    for epoch in range(epochs):
We need to add better logic here, so that a single epoch represents the correct number of batch draws for the data.
@mortonjt, the paper seems to imply that an epoch represents a random draw (in batches, of course) for each read in the feature-table, but the original code seems to use nnz, which I interpret to mean "n non-zero". So this would be the number of different types of sample:microbe pairs, rather than the number of observations. What was the goal there, and should we replicate that?
This line was largely to make the concept of epoch more interpretable. And yes, nnz is the number of non-zeros.
One epoch is completed if you read through the entire dataset, which means that you should be able to process all of the reads. Since the batch size is computed over the number of reads, this is used to compute the number of "iterations" within each loop.
So it should read like this: num iterations per epoch = (total number of reads [aka nnz]) / (num reads per batch).
We're basically calculating how many batches are within an epoch, in order to read through the entire dataset.
That being said -- I don't think you really need this. I think the current implementation is fine -- we just need a way to make the term epochs interpretable to the user.
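That bookkeeping is a one-liner; a minimal sketch (pure Python, function name assumed for illustration):

```python
def iterations_per_epoch(nnz, batch_size):
    """How many batch draws it takes to read through all nnz reads once.

    Ceiling division, so a partial final batch still counts as a draw.
    """
    return (nnz + batch_size - 1) // batch_size

# e.g. 1000 reads at batch_size 64 -> 16 draws per epoch
total = iterations_per_epoch(nnz=1000, batch_size=64)
```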
I think nnz in the older implementation was actually the number of non-zero cells, not the sum of those cells.
It sounds like the goal was to make it the number of reads outright though (sum of the entire table). I think it's probably worth making sure that epoch fits that, if only for the sake of explanation. (It hasn't seemed to matter too much in practice while we've been testing.)
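The two readings differ on any real table; a toy example of the distinction (numbers made up):

```python
import numpy as np

# Toy feature-table: rows are samples, columns are microbes
table = np.array([[5, 0, 2],
                  [0, 3, 0]])

n_nonzero_cells = int(np.count_nonzero(table))  # distinct sample:microbe pairs with reads
total_reads = int(table.sum())                  # every read in the whole table
```

Here the nnz reading gives 3 while the total-reads reading gives 10, so the number of "iterations" per epoch would differ by about 3x between the two definitions.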
v_r, v_p = spearmanr(pdist(model.V.T), pdist(self.V.T))

self.assertGreater(u_r, 0.5)
self.assertGreater(v_r, 0.5)
We always fail by this point, but often fail the u_r test above as well. @mortonjt, we're kind of at a loss here.
mmvec/ALR.py
Outdated

forward_dist = forward_dist.log_prob(self.metabolites)

l_y = forward_dist.sum(0).sum()
Missing the norm that is multiplied against the data likelihood. @mortonjt, we aren't 100% sure what its purpose is, but it kind of looks like a weird mean if you squint.
What is the interpretation of this line: https://github.com/biocore/mmvec/blob/master/mmvec/multimodal.py#L137?
ok, so there are two ways you can deal with the data:
- You try to use the mini-batches to approximate the loss on the entire dataset
- You just compute the per-sample loss for each mini-batch
For all intents and purposes, I think it is ok to just compute the per-sample loss -- this appears to be an emerging standard in deep learning.
I think taking a mean is very ok. It'll basically be just l_y = forward_dist.sum(0).mean(). I'm able to get the tests to pass once I run this model locally.
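The two reductions only differ by a constant scale, which a toy tensor makes concrete (the log-likelihood values here are made up for illustration):

```python
import torch

# Pretend per-entry log-likelihoods with shape [batch, sample]
log_probs = torch.tensor([[-1.0, -2.0],
                          [-3.0, -4.0]])

l_sum = log_probs.sum(0).sum()    # total over the whole mini-batch
l_mean = log_probs.sum(0).mean()  # per-sample average, smaller by a factor of n_samples
```

Since the mean is the sum divided by a constant, gradients point the same way; the mean just keeps the loss scale independent of how many samples land in a batch.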
How about this: let me try to reproduce the findings. Sometimes it may require tweaking learning rates and batch sizes. U and V do have an identifiability issue, so that is something to consider. The one metric that should always pass is
I think the implementation in this pull request is actually correct. We don't expect the U and V tests to always pass (this is why we are running SVD after fitting the model). It's the U @ V test that needs to pass.
I'm able to get the tests passing on my side (r > 0.5, p < 0.05). The only thing that you may want to drop is the total_count argument in the multinomial.
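As far as I can tell from the torch.distributions source, Multinomial.log_prob derives the count from the value itself and never uses total_count, so dropping it should be safe for likelihood computation. A quick sketch checking that against the pmf by hand (validation disabled, since the support check is what compares against total_count):

```python
import torch
from torch.distributions import Multinomial

probs = torch.tensor([0.2, 0.3, 0.5])
counts = torch.tensor([1.0, 2.0, 3.0])  # totals vary sample to sample in practice

# No total_count; validate_args=False so the support check (which uses
# total_count) doesn't reject counts summing to something other than 1.
dist = Multinomial(probs=probs, validate_args=False)
ll = dist.log_prob(counts)

# The same quantity from the multinomial pmf:
# log n! - sum(log x_i!) + sum(x_i * log p_i)
n = counts.sum()
manual = (torch.lgamma(n + 1) - torch.lgamma(counts + 1).sum()
          + (counts * probs.log()).sum())
```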
mmvec/multimodal.py
Outdated
self.encoder = nn.Embedding(num_microbes, latent_dim)
self.decoder = nn.Sequential(
    nn.Linear(latent_dim, num_metabolites),
    nn.Softmax(dim=2)
self.input_bias = nn.Parameter(torch.randn(num_microbes))
I think you might have looked at an older commit. We should have that in the current model.
mmvec/multimodal.py
Outdated
# Three likelihoods, the likelihood of each weight and the likelihood
# of the data fitting in the way that we thought
# LY
z = self.encoder(X)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bias = self.input_bias[X]
z = z + bias.view(-1, 1)
Same as above, although the .view(-1, 1) looks nicer
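For what it's worth, the existing reshape((*X.shape, 1)) and unsqueeze(-1) produce the same tensor, and the latter avoids repeating X.shape. A quick check (sizes are illustrative):

```python
import torch

X = torch.randint(0, 10, (4, 3))  # [batch, sample] indices
bias = torch.randn(10)

a = bias[X].reshape((*X.shape, 1))  # [batch, sample, 1]
b = bias[X].unsqueeze(-1)           # same shape, arguably reads better
```

Either spelling gives a trailing singleton dimension that broadcasts cleanly against [batch, sample, latent].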
mmvec/ALR.py
Outdated
z = z + self.encoder_bias[X].reshape((*X.shape, 1))
y_pred = self.decoder(z)

forward_dist = Multinomial(total_count=0,
I'd suggest getting rid of the total_count=0 parameter -- we don't actually need it for log_prob. And it may introduce a bug downstream (since the total_count isn't actually zero).
This was actually a result of doing things in batch. We would run into an issue where log_prob would indicate our calculation was out of the support of the distribution, because it had different counts sample to sample, so we solved it via this suggestion: pytorch/pytorch#42407 (comment)
That said, looking at the documentation again, I wonder if we should be using logits instead of probs?
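For what it's worth, the two parameterizations should agree once the probabilities are normalized; passing logits just skips the explicit softmax and is generally the more numerically stable route. A toy comparison (numbers made up):

```python
import torch
from torch.distributions import Multinomial

logits = torch.tensor([0.0, 1.0, 2.0])
counts = torch.tensor([1.0, 2.0, 3.0])  # sums to the declared total_count

d_logits = Multinomial(total_count=6, logits=logits)
d_probs = Multinomial(total_count=6, probs=torch.softmax(logits, dim=-1))
```

Both distributions assign the same log-likelihood to counts; the logits form avoids exponentiating and renormalizing by hand.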
Thanks for the review @mortonjt!