Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transformer classifier #840

Merged
merged 11 commits into from
Apr 30, 2024
Merged

Conversation

SalmanMohammadi
Copy link
Collaborator

@SalmanMohammadi SalmanMohammadi commented Apr 22, 2024

Context/Changelog

See #812

This PR adds a TransformerClassifier layer which extends the TransformerDecoder functionality to classification tasks. Exemplar component and model builders have been implemented for the base mistral model.

Test plan

Testing this was tricky as there is currently no reference implementation for base mistral to test against. I performed some numerical testing against HuggingFace MistralModel and AutoModelForSequenceClassification with mistralai/Mistral-7B-v0.1, with the same parameters found in test_llama3.py. However, neither the classification model outputs, nor the base MistralModel and torchtune.models.mistral.mistral produced similar outputs to the HF models.
I've left a sort-of dummy test in which just asserts the output shapes are correct. I can probably test the sequence pooling and classification independently once we agree on how they integrate into the codebase.

Questions/Next steps

This is part of a broader plan to implement RLHF in Torchtune. The TransformerClassifier can hopefully be used against any sequence model we have. @kartikayk - we could implement a recipe for training a mistral reward model using this. I can start implementing a PPO recipe using this reward model, too.

Copy link

pytorch-bot bot commented Apr 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/840

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 13063d1 with merge base fde0dc4 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 22, 2024
@RdoubleA
Copy link
Contributor

Thanks @SalmanMohammadi for this PR! This looks mostly reasonable but I have a few high level suggestions.

The new functionality you're adding is mainly using the output layer of TransformerDecoder to project to number of classes, and then extracting the last predicted token. Why not just make a separate classifier head class that consists of this linear projection and post-processing step in the forward? Then you can compose these together directly in the TransformerDecoder constructor in mistral_classifier:

TransformerDecoder(
    tok_embeddings=tok_embeddings,
    layer=layer,
    num_layers=num_layers,
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    head_dim=head_dim,
    norm=norm,
    output=ClassificationHead(embed_dim, num_classes),
)

This still keeps the flexibility of swapping out the classification head with some other head, and you keep all the added logic contained in a new class. Your testing will be simpler too because you only need to test the head and not the entire transformer classifier.

Also, please make sure you've run the linters with pre-commit run --all-files :)

@SalmanMohammadi
Copy link
Collaborator Author

SalmanMohammadi commented Apr 22, 2024

Thanks @RdoubleA :) I've fixed the linting.

I tried setting it up this way, but in the forward for TransformerClassifier I also use the input token ids to grab the last non-padding token in each sequence of the batch:

        padding_mask = tokens == 0
        if padding_mask.any():
            sequence_lengths = (
                (padding_mask.logical_not().cumsum(-1) == 1).sum(-1).to(logits.device)
            )
        else:
            sequence_lengths = -1

I thought I'd have to modify the function call in TransformerDecoder's forward:

        h = self.norm(h)

        output = self.output(h).float() # pass input tokens in here
        return output

to achieve this - so I wrapped around TransformerDecoder instead.

I agree your suggestion (and @ebsmothers on Discord) is cleaner. Do you have any thoughts on how I'd still be able to use the input tokens in the output callable?

@ebsmothers
Copy link
Contributor

I tried setting it up this way, but in the forward for TransformerClassifier I also use the input token ids to grab the last non-padding token in each sequence of the batch:

@SalmanMohammadi ah I wasn't aware of this in my suggestion on Discord. In that case it is trickier cause you are actually changing the signature of the output layer. Then at a high level I think it makes sense to add a separate module to handle taking the last token. I'll take a closer look at the exact implementation now.

num_heads=num_heads,
head_dim=head_dim,
norm=RMSNorm(embed_dim, eps=norm_eps),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like pre-commit hooks may not be running for your PR locally, you may need to run pre-commit install. Then you can re-run on all the files you've already modified via pre-commit run --all-files

@@ -111,7 +111,7 @@ class TransformerDecoder(nn.Module):
to setup the :func:`~torchtune.modules.KVCache`
norm (nn.Module): Callable that applies normalization to the output of the decoder,
before final MLP.
output (nn.Linear): Callable that applies a linear transformation to the output of
output (nn.Module): Callable that applies a linear transformation to the output of
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically no longer need this change, right? If so I would leave it as-is for now just to be explicit

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bumping this comment, let's switch back to nn.Linear here unless there's a reason not to

output = self.output(h).float()
return output


class TransformerClassifier(nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK so if we are writing a standalone module I think this is the way to do it. The one question I would ask here is if we need a standalone module at all. To explain my thought process here a bit:

The only real difference between this class and an appropriately-parametrized builder of TransformerDecoder is the slicing you do to get just the EOS embedding (i.e. what you currently have in L282-296). But this is basically just "find the last non-padded index in the tokens and slice the logits based on that". In that case it may actually be conceptually simpler to do this in the training script. There you already know the padding index and it is not so different than our existing logic for e.g. ignore_index in our cross-entropy loss calculation. This way you have to pass the padding_index down into the model (in this case it is hardcoded, but in general I imagine we'd want it to be parametrized). If we handle this in the recipe, then you don't need to worry about this. You can also see in our PPO recipe where we do something similar in a utility method with the cross entropy ignore index.

So I think the question comes down to what the full training script will look like. If I understand correctly, the outputs of this model will be used for training the reward model on a paired dataset, is that right? Do you think slicing to the EOS token makes more sense in the reward model training loop or in the model implementation? Admittedly it's kind of a philosophical question without a correct answer, but I bias towards doing it in the recipe because pad ID is a tokenizer property and the model should not have to be aware of it. But please let me know your thoughts here!

Copy link
Collaborator Author

@SalmanMohammadi SalmanMohammadi Apr 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for raising this. I think you make a good point. I think my main consideration would be (also philosophical) how explicit we want to be about that step. Generally speaking, any Seq2Label model will follow the step of classifying the last non-padding token, and produce a [b, n_classes] Tensor. Without the additional step of sequence pooling, a mistral_classifier() will return a [b, s, n_classes] Tensor.

You're right about what the training script will look like. Each recipe will need to replicate this step if we didn't use a dedicated module. If I was relatively new to all this, with good documentation I could understand by reading the recipe that you need to do some kind of pooling to convert [b, s, n_classes], to [b, n_classes], though it may not be clear that this is a common step that will occur after any sequence classification recipe.

Looking forward, we'll want to pull the token padding ID directly from the tokenizer. We have a reference to this in the recipe. I think this is the most significant reason for me to agree that we should lose the separate module and just construct the classifier in the component builder using an output layer, if we're okay with a _classifier model not technically doing the complete classification (sans activation function) i.e. the model builder returns a kind of Seq2Seq model.

Let me know your thoughts : )

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK actually I am gonna raise an entirely orthogonal point here (sorry 😅). But during your description of dimensions it dawned on me that we are not necessarily doing things in the most efficient way here. We are applying the projection to the full sequence, when in reality we only need to do the projection on the EOS position. So rather than project and then slice we could slice and then project. This would decrease the size of the intermediate by a factor of s. I think the usefulness of this depends on how large the overall tensor is, if e.g. n_classes is really small then the intermediate may be small relative to other tensors in the model and the memory savings would not be as substantial. It also makes fitting into the existing class even thornier, since we would then be adding another operation before TransformerDecoder's output projection.

Assuming that we don't follow the new wrench I threw in things, I think I am OK with either of the two approaches. Ultimately we won't know if we bump up into a wall until we integrate into a recipe, so feel free to make a call; then as we get closer to the actual training loop we can see how it works. How does that sound?

Copy link
Collaborator Author

@SalmanMohammadi SalmanMohammadi Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's a good point to raise the extra memory usage due to nn.Linear taking [b, s, d] vs. [b, 1, d]. This scales with sequence length.

I think it gets kind of tricky if we want to add this operation in before the TransformerDecoder output projection, like you said. One way could be to just say output=nn.Identity(), and then slice the sequence before passing it into a nn.Linear defined as part of TransformerClassifier. We'd still need access to the padding token ID to do this, and I think that's the biggest issue.

I think it makes sense to go with the implementation using a utility method, so I'll update this PR. If you have strong feelings towards my above suggestion for saving memory let me know. I'd put the utility function in torchtune/utils since it's not unique to any specific recipe. Like you said, this is something we'll have a clearer idea about when we start adding recipes for downstream use cases.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completely agree with all of this. Let’s not hack in changes for the memory savings now, I feel it makes most sense to start from a place of clean code and then see whether optimizations can/should be made after that. The utility method sounds good too, will take another pass once you update. Thanks for the patience on this discussion!

set_seed(16)


class TestMistralClassifier:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is looking great!

@SalmanMohammadi
Copy link
Collaborator Author

SalmanMohammadi commented Apr 23, 2024

Note, the type hints are wrong for the component and model builders (the docs are correct). I think you're right that my pre-commits still aren't working right. I'll fix those and address any comments in a bit.

…utility function for pooling sequence and tests
@SalmanMohammadi
Copy link
Collaborator Author

I've updated my PR @ebsmothers with the changes we discussed :)

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few more comments but overall this is looking pretty good! Special thanks for the great work hitting a high bar on testing and documentation here.

torchtune/utils/pooling.py Outdated Show resolved Hide resolved
Comment on lines 48 to 51
# TODO find the actual value
# expected = torch.tensor(3.9763)
assert actual.shape == (BSZ, SEQ_LEN, NUM_CLASSES)
# torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-4)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to update this prior to merging. Let me know if you need any assistance with the parity checks here

Copy link
Collaborator Author

@SalmanMohammadi SalmanMohammadi Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could use some guidance here :) Could you point to an existing test in the repo you think is a good example? I'm still working on the base mistral tests in another branch, but there's a lot to do there.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Note: this is a long comment, but for the actual practical suggestion for this particular PR just skip to the end)

The parity checks can be tricky, it also depends on the model and how readily available a solid reference implementation is. I will use our test_lora.py as a simple example here. We typically:

  1. Find a reference implementation, and either copy-paste it or import it into a runnable Python script. We have a bunch of these under tests/torchtune/models/llama2/scripts (admittedly not the best location, but they are not really used and more meant as references). Here's the copy-pasted LoRA layer from the original Microsoft repo.
  2. In the same script, import our implementation. Then
    a) set the random seed for reproducibility, create some dummy input tensor for the model, and initialize the weights of our implementation in a deterministic way. This is done here using our test utility fixed_init_model
    b) Load the weights we initialized for our model into the reference implementation. This usually involves remapping of state dicts (sometimes just a key mapping, sometimes some tensor reshaping). This part is usually fairly annoying, a bit of guess-and-check, and can actually get quite hairy for full models with nontrivial differences in how parameters are represented. Example for the LoRA comparison
    c) Finally call forward on both models and check that the outputs line up. Then take some summary stat (like mean) to use as the expected value in the test. LoRA example
  3. Then in the unit test we use the value from (2) as our expected value (in the LoRA test). We set the same seed and do the same initialization so that we get the same result from the comparison script.

If you've made it this far, you can see it's actually a fair amount of work to do things this way. We try to hit a really high bar on testing for correctness and have found this is the best way to do it. That being said, going through all these steps is definitely not a prerequisite for landing this change. We still need a proper comparison script for Mistral (though it was tested locally prior to landing) and shouldn't block your PR on that.

Instead for testing this change, I would propose: assume the existing Mistral implementation in the repo is correct. Then the delta for your changes in terms of the Mistral builder is really just the change in output layer. Then the much-abbreviated version of the above is: (a) import an existing Mistral builder and hardcode the output projection to nn.Linear(embed_dim, n_classes). Do the same initialization and run forward on that model to get the expected value. Then use this value in your unit test (no state dict remapping needed).

We can tackle the full testing for Mistral I outlined above in a separate PR, but definitely don't want to block this particular change on that. Let me know if this makes sense to you, and if you are still interested in adding the Mistral test following this approach in a separate PR I am happy to provide (even) more detailed guidance when needed.

Copy link
Collaborator Author

@SalmanMohammadi SalmanMohammadi Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much - this is a super high quality explanation. I think some of it I grokked from the code but 2.b would've been tricky and sounds like it can be a pain in the ass. I've started this process with just the attention layer for mistral so far, I'll put a draft PR up to get your thoughts on this early on, soon.

I think your suggestion for testing the classifier sounds sensible. It's a pretty general test to make sure output_layer does what we want it to do - just a linear projection from the final hidden weights to some arbitrary dimension.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've taken a go at writing a test following test_lora.py. Thoughts appreciated.

@@ -404,3 +404,212 @@ def lora_mistral_mlp(
down_proj=down_proj,
up_proj=up_proj,
)


def mistral_classifier(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another path we could consider here is providing an optional output dim argument in the existing Mistral builders and if it's not passed just set to vocab_size. I hope it's not too unintuitive, but could help get around your previous point regarding "classifier" now being a bit of a misnomer.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you're meaning. My only point would be that the implementation I have now minimises changes to the base mistral models. I think keeping it modular for now while we're still experimenting with how this model is going to be used downstream would be easiest, and we can try out potentially cleaner paths once things are more concrete.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep this sounds good to me!

torchtune/models/mistral/_model_builders.py Outdated Show resolved Hide resolved
Comment on lines 33 to 37
# padding token needs to be defined to pool multiple sequences
if padding_token_idx is None and batch_size != 1:
raise ValueError("padding_token_idx must be set if batch_size > 1")
if padding_token_idx is None:
sequence_lengths = -1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part I don't fully understand. Is the assumption here that for a batch size of 1 there will be no padding? This may hold in the current state of the world since we pad to the max length in the batch, but for general padding schemes (e.g. pad to fixed length) this may not hold. Personally I would just err on the side of caution and make padding_token_idx a required arg, then you don't have to branch based on batch_size (unless there's an important use case you're solving for here).

Copy link
Collaborator Author

@SalmanMohammadi SalmanMohammadi Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The assumption is the other way round - if there's no padding, there must be a batch size of 1. Otherwise, you wouldn't have a way to grab the last token from multiple sequences in a batch with variable lengths.

If you do have a batch size of 1, the function is agnostic to this, it'll just give you the last token if you don't pass in a padding token, but it'll give you the last non-padding token otherwise (the default behaviour). This is how HF implements it, I'll attribute it in the function for clarity. My intuition here is that the use case is for inference, potentially using a tokenizer where the padding token isn't defined, or you just generally know you have a full sequence/remove padding on your own side.

Removing the branch for batch size wouldn't actually change how the sequence(s) are pooled, but prevents the failure case. Allowing the function to be used without a padding token just gives a little flexibility.

Let me know if this makes sense. Open to your thoughts :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK this generally makes sense. My main question is whether we will actually ever hit this function during inference; if its purpose is just to extract the EOS from an existing sequence it seems like we shouldn't, right? And yes attribution to the HF implementation would be great here

Copy link
Collaborator Author

@SalmanMohammadi SalmanMohammadi Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I thought about this for a bit and I think I'm with you on making it required. I think having to handle this error wherever its used is more annoying than just supporting a very niche usecase.

if its purpose is just to extract the EOS from an existing sequence

My usecase for performing inference was: run your trained _classifier model to obtain [b, s, n], then having to call pool_sequence_logits to get [b, n], so following this method for generation you'd hit it. I think that you're probably not doing this process without access to a tokenizer, and unlikely for that tokenizer not using a padding token.

tests/torchtune/utils/test_pooling.py Show resolved Hide resolved
@@ -0,0 +1,188 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much for adding this! You really went above and beyond here. Actually one more change request here (this one should be much easier, I promise). Can you add back the previous unit test you deleted in the latest commit? I think my previous wall of text could've been a bit clearer, but the main point I was getting at was that the values from this script can then be used to set the expected values in the unit test.

The idea being that then we will actually run the unit test in our CI, whereas this script is more of a nice-to-have so we can understand the initial testing that went into setting the unit test values. In case it helps, I ran your script locally myself: here's the output I got so that you don't have to run it again. And here is roughly what I think the unit test will look like now (hopefully not too many bugs in there 😅). Again, sorry for the back and forth on this!

Copy link
Collaborator Author

@SalmanMohammadi SalmanMohammadi Apr 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get it now :)
edit: ty for providing the unit test!

Comment on lines 30 to 32
if padding_token_idx is None:
sequence_lengths = -1
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that padding_token_idx is required we can remove this, right?

@SalmanMohammadi SalmanMohammadi mentioned this pull request Apr 26, 2024
4 tasks
@ebsmothers
Copy link
Contributor

OK the changes look good, but the unit test is failing on CI. Does it pass for you locally? I can also help a bit here with debugging if you need

@SalmanMohammadi
Copy link
Collaborator Author

SalmanMohammadi commented Apr 29, 2024

Yeah it's passing locally. It looks like it's failing the second test case. Are there more detailed logs? I sometimes fail some of the tests on my mac because the precision tolerance isn't the same on my machine for some reason (i.e. the numbers look right, but the errors are just slightly above the tolerance).

@ebsmothers
Copy link
Contributor

OK I think I cracked the case here. The lack of detailed logs can often be indication that the runner actually just crashed (can be due to out of memory or something like that). It crashed on the 2nd test case, which is the largest one. You can either reduce the batch size or even just remove that test case, since it's not actually testing fundamentally different logic than the other cases. I tested on my fork with a batch size of 4 and confirmed that the CI succeeds (but tbh I'd just scrap the test case for the reason I mentioned).

@SalmanMohammadi
Copy link
Collaborator Author

SalmanMohammadi commented Apr 30, 2024

Thanks so much for your help debugging :) I'll keep that in mind for the future!

@@ -129,7 +129,7 @@ def __init__(
num_heads: int,
head_dim: int,
norm: nn.Module,
output: nn.Linear,
output: nn.Module,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we change this back to nn.Linear? I don't think we need the general nn.Module in your latest version (lmk if I'm missing something here though)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SalmanMohammadi sorry one more thing on this.. can you also change the corresponding docstring back to nn.Linear too? Then I promise we are good to merge 😄

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my bad! I had a thought that the pre-commits sometimes catch things like this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I'm also confused why it wasn't caught tbh..

@@ -206,6 +206,5 @@ def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
# shape: [b, s, d]
h = self.norm(h)

# shape: [b, s, v]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I would still keep some version of shape comment in just cause it's generally useful for people. could just be # shape: [b, s, out_dim] or something like that

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two more nits, otherwise I think this is good to merge!

@SalmanMohammadi
Copy link
Collaborator Author

Good catch!

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Thanks for your patience on the review!

@ebsmothers ebsmothers merged commit f819b4b into pytorch:main Apr 30, 2024
27 checks passed
@SalmanMohammadi
Copy link
Collaborator Author

Thank you for your patience!!

@SalmanMohammadi SalmanMohammadi deleted the transformer_classifier branch July 20, 2024 22:02
@SalmanMohammadi SalmanMohammadi mentioned this pull request Jul 29, 2024
11 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants