diff --git a/README.md b/README.md
index d010012..59fbe33 100644
--- a/README.md
+++ b/README.md
@@ -23,11 +23,11 @@ It can be used to:
- layer warning labels: over-trained; under-trained
-## Quick Links
+## Quick Links
- Please see [our latest talk from the Silicon Valley ACM meetup](https://www.youtube.com/watch?v=Tnafo6JVoJs)
-- Join the [Discord Server](https://discord.gg/uVVsEAcfyF)
+- Join the [Discord Server](https://discord.gg/uVVsEAcfyF)
- For a deeper dive into the theory, see [our latest talk at ENS](https://youtu.be/xEuBwBj_Ov4)
@@ -84,7 +84,7 @@ and `summary` dictionary of generalization metrics
'mp_softrank': 0.52}
```
-## Advanced Usage
+## Advanced Usage
The `watcher` object has several functions and analysis features, described below.
@@ -109,13 +109,13 @@ watcher.distances(model_1, model_2)
To analyze a PEFT / LoRA fine-tuned model, specify the `peft` option.
- peft = True: Forms the BA low-rank matrix and analyzes the delta layers, with a 'lora_BA' tag in the name
-
+
```details = watcher.analyze(peft=True)```
- - peft = 'with_base': Analyes the base_model, the delta, and the combined layer weight matrices.
-
+ - peft = 'with_base': Analyzes the base_model, the delta, and the combined layer weight matrices.
+
```details = watcher.analyze(peft='with_base')```
-
+
The base_model and fine-tuned model must have the same layer names; weightwatcher will ignore layers that do not share the same name.
Also, at this point, biases are not considered. Finally, both models should be stored in the same format (i.e., safetensors).
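For example, here is a minimal sketch (the model variables are hypothetical; per the options above, `peft=True` analyzes the low-rank delta layers):

```python
import weightwatcher as ww

# fine_tuned_model: a hypothetical PEFT / LoRA fine-tune, stored as safetensors
watcher = ww.WeightWatcher(model=fine_tuned_model)
details = watcher.analyze(peft=True)

# the delta layers carry the 'lora_BA' tag in their name (see above)
print(details[details['name'].str.contains('lora_BA')])
```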
@@ -146,9 +146,9 @@ Visually, the ESD looks like a straight line on a log-log plot (above left).
The goal of the WeightWatcher project is to find generalization metrics that most accurately reflect observed test accuracies, across many different models and architectures, for pre-trained models and for models undergoing training.
-
+
-
+
[Our HTSR theory](https://jmlr.org/papers/volume22/20-410/20-410.pdf) says that well-trained, well-correlated layers should be significantly different from the MP (Marchenko-Pastur) random bulk and, specifically, should be heavy-tailed. There are different layer metrics in WeightWatcher for this, including:
@@ -159,20 +159,20 @@ The goal of the WeightWatcher project is to find generalization metrics that most a
- `num_spikes` : the number of spikes outside the MP bulk region
- `max_rand_eval` : scale of the random noise in the layer
-All of these attempt to measure how on-random and/or non-heavy-tailed the layer ESDs are.
+All of these attempt to measure how non-random and/or non-heavy-tailed the layer ESDs are.
-#### Scale Metrics
+#### Scale Metrics
- log Frobenius norm : log ||W||_F^2
- `log_spectral_norm` : log λ_max
- `stable_rank` : ||W||_F^2 / λ_max
- `mp_softrank` : λ_mp / λ_max (MP bulk edge over the largest eigenvalue)
-
+
#### Shape Metrics
- - `alpha` : Power Law (PL) exponent
+ - `alpha` : Power Law (PL) exponent
- (Truncated) PL quality of fit `D` : (the Kolmogorov-Smirnov Distance metric)
@@ -183,13 +183,13 @@ All of these attempt to measure how non-random and/or non-heavy-tailed the layer
- `E_TPL` : (alpha and Lambda) Extended Truncated Power Law fit
-
+
#### Scale-adjusted Shape Metrics
- `alpha_weighted` : α log λ_max
- `log_alpha_norm` : log ||X||_α^α (the Schatten norm)
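As a quick illustration, here is a minimal sketch for inspecting these metrics in the `details` dataframe (the column names are the metric names listed above):

```python
# details is the pandas dataframe returned by watcher.analyze()
details = watcher.analyze()
cols = ['alpha', 'alpha_weighted', 'log_alpha_norm',
        'log_spectral_norm', 'stable_rank', 'mp_softrank']
print(details[cols].describe())
```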
-#### Direct Correlation Metrics
+#### Direct Correlation Metrics
The random distance metric is a new, non-parametric approach that appears to work well in early testing.
[See this recent blog post](https://calculatedcontent.com/2021/10/17/fantastic-measures-of-generalization-that-actually-work-part-1/)
@@ -209,7 +209,7 @@ There are also related metrics, including the new
- `max_rand_eval` : scale of the random noise in the layer
-#### Summary Statistics:
+#### Summary Statistics:
The layer metrics are averaged in the **summary** statistics:
Get the average metrics, as a `summary` (dict), from the given (or current) `details` dataframe:
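For example (`get_summary` takes the given `details` dataframe, or uses the current one if called with no argument):

```python
details = watcher.analyze()
summary = watcher.get_summary(details)  # or watcher.get_summary()
print(summary['alpha'], summary['mp_softrank'])
```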
@@ -233,8 +233,8 @@ The summary statistics can be used to gauge the test error of a series of pre/tr
#### Predicting the Generalization Error
-WeightWatcher (WW) can be used to compare the test error for a series of models, trained on the similar dataset, but with different hyperparameters **θ**, or even different but related architectures.
-
+WeightWatcher (WW) can be used to compare the test error for a series of models trained on similar datasets, but with different hyperparameters **θ**, or even with different but related architectures.
+
Our Theory of HT-SR predicts that models with smaller PL exponents `alpha`, on average, correspond to models that generalize better.
Here is an example of the `alpha_weighted` capacity metric for all the current pretrained VGG models.
@@ -242,7 +242,7 @@ Here is an example of the `alpha_weighted` capacity metric for all the current p
Notice: we *did not peek* at the ImageNet test data to build this plot.
-
+
This can be reproduced with the Examples Notebooks for [VGG](https://github.com/CalculatedContent/WeightWatcher/blob/master/examples/WW-VGG.ipynb) and also for [ResNet](https://github.com/CalculatedContent/WeightWatcher/blob/master/examples/WW-ResNet.ipynb)
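As a rough sketch of this kind of comparison (the `models` dict here is hypothetical; the point is that smaller `alpha` / `alpha_weighted`, on average, should track better test accuracy):

```python
import weightwatcher as ww

# hypothetical series of models trained with different hyperparameters θ
for name, model in models.items():
    watcher = ww.WeightWatcher(model=model)
    details = watcher.analyze()
    summary = watcher.get_summary(details)
    print(f"{name}: alpha = {summary['alpha']:.3f}, "
          f"alpha_weighted = {summary['alpha_weighted']:.3f}")
```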
The `randomize` option lets you compare the ESD of the layer weight matrix to the ESD of its randomized counterpart.
This is a good way to visualize the correlations in the true ESD, and to detect signatures of over- and under-fitting:
-
+
```python
details = watcher.analyze(randomize=True, plot=True)
```
Fig (a) is well trained; Fig (b) may be over-fit.
-
-That orange spike on the far right is the tell-tale clue; it's caled a **Correlation Trap**.
+
+That orange spike on the far right is the tell-tale clue; it's called a **Correlation Trap**.
A **Correlation Trap** is characterized by Fig (b): here the actual (green) and random (red) ESDs look almost identical, except for a small shelf of correlation (just right of 0). In the random (red) ESD, the largest eigenvalue (orange) lies far to the right of, and is separated from, the bulk of the ESD.
-
+

-
+
When layers look like Figure (b) above, they have not been trained properly: they look almost random, with only a little information present, and the information the layer has learned may even be spurious.
-
+
Moreover, the metric `num_rand_spikes` (in the `details` dataframe) contains the number of spikes (or traps) that appear in the layer.
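For example, here is a minimal sketch for flagging such layers (assuming the usual pandas `details` dataframe from `analyze(randomize=True)`):

```python
details = watcher.analyze(randomize=True)
trapped = details[details['num_rand_spikes'] > 0]  # layers with potential Correlation Traps
print(trapped[['num_rand_spikes']])
```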
-The `SVDSharpness` transform can be used to remove Correlation Traps during training (after each epoch) or after training using
-
+The `SVDSharpness` transform can be used to remove Correlation Traps, either during training (after each epoch) or after training, using
+
```python
sharpened_model = watcher.SVDSharpness(model=...)
```
-
+
Sharpening a model is similar to clipping the layer weight matrices, but uses Random Matrix Theory to do this in a more principled way than simple clipping.
-
+
### Early Stopping
@@ -293,7 +293,7 @@ Sharpening a model is similar to clipping the layer weight matrices, but uses Ra