Training from scratch #55
Comments
Did you check the paper's appendix? I think you could find more intuition there, especially regarding linear probing. I was having a similar issue in which the loss wasn't decreasing. Then I realized that I was initializing the optimizer with the model parameters before adding the linear head to the model, so the classification parameters were never registered with the optimizer. With that solved, I'm still struggling to train it in a supervised fashion. I would suggest that we join a communication channel such as Discord to share progress on this. |
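The optimizer pitfall described in the comment above can be reproduced in a few lines. This is a minimal PyTorch sketch, not code from the repository: the tiny `encoder`, the `head`, and the helper `optimized_ids` are hypothetical stand-ins chosen for illustration.

```python
import torch
from torch import nn

# Hypothetical stand-in for a pretrained encoder (a real one would be a ViT).
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.GELU())

# Pitfall: the optimizer is created before the linear head exists,
# so it only ever holds the encoder's parameters.
stale_optimizer = torch.optim.SGD(encoder.parameters(), lr=0.1)
head = nn.Linear(64, 10)
model = nn.Sequential(encoder, head)


def optimized_ids(optimizer):
    """Return the ids of every parameter the optimizer will update."""
    return {id(p) for group in optimizer.param_groups for p in group["params"]}


head_ids = {id(p) for p in head.parameters()}
# True here: none of the head's parameters are known to the stale optimizer.
head_missing = head_ids.isdisjoint(optimized_ids(stale_optimizer))

# Fix A: build the optimizer only after the head has been attached.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Fix B: keep the existing optimizer and register the head as a new group.
stale_optimizer.add_param_group({"params": head.parameters()})
```

A quick check such as comparing `optimized_ids(optimizer)` against the ids of `model.parameters()` before training starts can catch this kind of silent omission early.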
Hey @FalsoMoralista thanks for the comment! Sure, let's take this to discord. My handle is johnweak15. Do add me there! |
Were you all able to solve this? - Brett |
@bdytx5 yes we did. @lazarosgogos was also able to conduct some insightful experiments with it. What did you want to know specifically? |
Well, I tried pre-training with I-JEPA on CIFAR-10 and then fine-tuning the pretrained model on CIFAR-10, using the labels during fine-tuning (with just the target encoder). I compared this against fine-tuning a randomly initialized model, and the results seemed to be the same. I averaged the output embeddings of the last layer. Does this seem strange? Note that I just used the tiny_vit. |
Curious! For how many epochs did you pre-train on CIFAR-10? |
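For readers following along, the pooling step mentioned above (averaging the last layer's output embeddings before a classifier) might look like the following. This is a sketch under stated assumptions: `MeanPoolProbe` and `ToyPatchEncoder` are hypothetical names, and the toy encoder only mimics the `(batch, num_patches, embed_dim)` output shape of a ViT-style target encoder.

```python
import torch
from torch import nn


class ToyPatchEncoder(nn.Module):
    """Hypothetical stand-in that returns one embedding per image patch."""

    def __init__(self, embed_dim=32, patch=8):
        super().__init__()
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch, stride=patch)

    def forward(self, x):
        x = self.proj(x)                      # (batch, embed_dim, H/patch, W/patch)
        return x.flatten(2).transpose(1, 2)   # (batch, num_patches, embed_dim)


class MeanPoolProbe(nn.Module):
    """Average the patch embeddings, then classify with a linear head."""

    def __init__(self, encoder, embed_dim, num_classes, freeze_encoder=False):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Linear(embed_dim, num_classes)
        if freeze_encoder:
            # Linear probing: only the head receives gradients.
            for p in self.encoder.parameters():
                p.requires_grad = False

    def forward(self, x):
        tokens = self.encoder(x)       # (batch, num_patches, embed_dim)
        pooled = tokens.mean(dim=1)    # average over the patch dimension
        return self.head(pooled)       # (batch, num_classes)


probe = MeanPoolProbe(ToyPatchEncoder(), embed_dim=32, num_classes=10)
logits = probe(torch.randn(4, 3, 32, 32))
```

With `freeze_encoder=False` this corresponds to fine-tuning as described in the comment; setting it to `True` turns the same module into a linear probe.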
Around 10 or so, as after that train and validation loss began to rise. |
If you've left the config file untouched, there is most likely a warmup period of, e.g., 40 epochs out of the 300 in total. The loss going up after some epochs (how many depends on your configuration and total number of epochs) is normal behavior, as mentioned in #41. Try letting your model train for at least 50-60 epochs (with appropriate changes in the configuration) and then try a downstream task. In the early epochs the model doesn't learn semantic representations of the data, even though the loss seems to go down (I've tested this personally). Once I get to test the ViT-tiny and ViT-small models, I will get back with the differences. |
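To make the warmup point above concrete, here is a generic linear-warmup plus cosine-decay schedule. It is a sketch, not the repository's scheduler: the function name `lr_at_epoch` is hypothetical, and the learning-rate values are illustrative assumptions. Only the 40-epoch warmup and 300-epoch total come from the comment.

```python
import math


def lr_at_epoch(epoch, warmup_epochs=40, total_epochs=300,
                start_lr=2e-4, peak_lr=1e-3, final_lr=1e-6):
    """Linear warmup to peak_lr, then cosine decay down to final_lr.

    The learning-rate defaults are assumed values for illustration only.
    """
    if epoch < warmup_epochs:
        # Warmup: interpolate linearly from start_lr to peak_lr.
        progress = epoch / warmup_epochs
        return start_lr + progress * (peak_lr - start_lr)
    # Decay: follow half a cosine wave from peak_lr to final_lr.
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return final_lr + 0.5 * (peak_lr - final_lr) * (1.0 + math.cos(math.pi * progress))


for epoch in (0, 10, 40, 60, 300):
    print(f"epoch {epoch:3d}: lr = {lr_at_epoch(epoch):.6f}")
```

Under this kind of schedule, a run stopped around epoch 10 is only a quarter of the way through the warmup ramp, so the learning rate is still rising when training ends.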
Ah, I overlooked this. Good catch |
@lazarosgogos Did you get to test ViT-small models? It'd be really helpful if you could share the working configuration for those. Thanks. |
@akshayneema It heavily depends on what type of resources (e.g. GPUs) you have at hand. The more VRAM you have, the bigger the model you can load. The bigger the images you use, the more VRAM you'll need. For example, to train on ImageNet images with a ViT-small model on 16GB of VRAM, I was able to load at most a batch of 60 images per iteration (the rest of the config was untouched). |
Thanks for the reply @lazarosgogos. Can you also share what the results were like for you using ViT-small? Were they competitive with ViT-H or ViT-G? Did you also change the predictor architecture to suit the ViT-small architecture? I am currently training on a single GeForce RTX 3090 GPU with a batch size of 32. I am using UMAP to visualise the embeddings generated by the target encoder, and they do not look that great. |
@akshayneema The results using ViT-small were not competitive with ViT-Huge or ViT-Giant, not even close. The difference in some linear probing tasks was immense (>30%). The point of using ViT-small or ViT-base is mostly, in my opinion, to run tests and see how the model performs, in order to then train a ViT-Huge for final results. I did not touch the architecture of the predictor when testing how ViT-small behaves. Batch size plays a role in training as well; keep that in mind. |
Hey everyone!
First off, thanks for the great work. I implemented my own version of I-JEPA (https://github.com/Ugenteraan/I-JEPA) by referencing this repository.
I used the Doges 77 Breeds (https://www.kaggle.com/datasets/madibokishev/doges-77-breeds) dataset for the training. The loss goes down convincingly during the SSL training. However, in the downstream task, when I load the pre-trained encoder weights and use probing, the accuracy is no better than with randomly initialized encoder weights.
Does anyone have a clue what might be causing this?
Thanks in advance! Cheers.