Losses.md docs #658

Merged
merged 7 commits into from Jan 29, 2023
lint
shaydeci committed Jan 29, 2023
commit e6cc16d897d78a095f5fce559fa70957a8ddd783
179 changes: 125 additions & 54 deletions documentation/assets/Losses.md
@@ -1,6 +1,6 @@
# Losses in SG

SuperGradients supports any PyTorch-based loss function. In addition, it provides multiple loss function implementations for various tasks:

cross_entropy
mse
@@ -15,30 +15,13 @@ SuperGradients provides multiple Loss function implementations for various tasks
kd_loss
dice_ce_edge_loss

All of the above are just string aliases for the underlying torch.nn.Module classes implementing the specified loss functions.

## Basic Usage of Implemented Loss Functions in SG

The most basic use case is when using a direct Trainer.train(...) call:


In your `my_training_script.py`:
```python
...
trainer = Trainer("external_criterion_test")
@@ -64,7 +47,57 @@ Another usage case, is when using a direct Trainer.train(...) call:
Losses.CROSS_ENTROPY
```
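Since the diff view truncates the snippet above, here is a minimal sketch of the same pattern for reference. It assumes `Losses` is importable from `super_gradients.common.object_names`, and leaves the dataloaders, model, and remaining hyper-parameters as placeholders:
```python
from super_gradients import Trainer
from super_gradients.common.object_names import Losses  # assumed import path

trainer = Trainer("external_criterion_test")
train_dataloader = ...  # your training dataloader
valid_dataloader = ...  # your validation dataloader
model = ...             # your model

train_params = {
    # ... other required training hyper-parameters (epochs, lr, metrics, etc.)
    "loss": Losses.CROSS_ENTROPY,  # string alias, same as passing "cross_entropy"
}
trainer.train(model=model, training_params=train_params,
              train_loader=train_dataloader, valid_loader=valid_dataloader)
```
Any alias from the list at the top of this document can be passed the same way.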


Another use case is training from configuration files, for example using train_from_recipe (or similar, when the underlying train method being called is Trainer.train_from_config(...)).

When doing so, in your `my_training_hyperparams.yaml` file:
```yaml
...
...

loss: yolox_loss

criterion_params:
  strides: [8, 16, 32] # output strides of all yolo outputs
  num_classes: 80
```
The `criterion_params` dictionary will be unpacked into the underlying `yolox_loss` class constructor.
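Conceptually, the recipe above amounts to constructing the loss with `criterion_params` as keyword arguments. A hedged sketch, assuming the `yolox_loss` alias resolves to a `YoloXDetectionLoss` class importable from `super_gradients.training.losses` (class name and import path are assumptions about the current SG version):
```python
from super_gradients.training.losses import YoloXDetectionLoss  # assumed class/path

# criterion_params are forwarded as keyword arguments to the loss constructor
criterion = YoloXDetectionLoss(strides=[8, 16, 32], num_classes=80)
```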

## Passing Instantiated nn.Module Objects as Loss Functions

SuperGradients also supports passing instantiated nn.Module objects as loss functions, as demonstrated below.
When using a direct Trainer.train(...) call, simply pass the instantiated nn.Module under the "loss" key inside training_params in your `my_training_script.py`:
```python
...
trainer = Trainer("external_criterion_test")
train_dataloader = ...
valid_dataloader = ...
model = ...
train_params = {
    ...
    "loss": torch.nn.CrossEntropyLoss(),
    ...
}
trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)
```
Though not as convenient as using `register_loss` (discussed in more detail in a later section), one can equivalently pass instantiated objects when using train_from_recipe (or similar, when the underlying train method is Trainer.train_from_config(...)), as demonstrated below:


In your `my_training_hyperparams.yaml` file:
```yaml
...
...
loss:
  _target_: torch.nn.CrossEntropyLoss
```
Note that when passing an instantiated loss object, `criterion_params` will be ignored.
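For reference, a minimal sketch of what the `_target_` entry amounts to when the recipe is resolved, assuming Hydra-style instantiation is used under the hood (hydra and omegaconf are assumed to be available; whether SG resolves the loss exactly this way is an assumption):
```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hydra resolves _target_ to a class and instantiates it; with no sibling keys,
# the constructor is called with no arguments (so criterion_params plays no role here).
loss_cfg = OmegaConf.create({"_target_": "torch.nn.CrossEntropyLoss"})
criterion = instantiate(loss_cfg)  # -> torch.nn.CrossEntropyLoss()
```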


## Using Your Own Loss

SuperGradients also supports user-defined loss functions, provided they inherit from torch.nn.Module and their `forward` signature has the form:

```python
@@ -76,9 +109,77 @@ forward(preds, target):
...
```
And as the argument names suggest, the first argument is the model's output, and target is the label/ground truth (argument naming is arbitrary and does not need to be specifically 'preds' or 'target').
Loss functions accepting additional arguments in their `forward` method will be supported in the future.

### Using Your Own Loss - Logging Loss Outputs

In the most common case, where the loss function returns a single item for backprop, the loss output will appear in the training logs (i.e., Tensorboard and any other supported SGLogger; for more information on SGLoggers click [here](https://github.com/Deci-AI/super-gradients)) over epochs, under `<LOSS_CLASS.__name__>`.

forward(...) should return a (loss, loss_items) tuple, where loss is the tensor used
for backprop (i.e., what your original loss function returns), and loss_items should be a tensor of
shape (n_items) consisting of values computed during the forward pass that we want to log over the
entire epoch. For example, the loss itself should always be logged. Another example is a scenario
where the computed loss is the sum of a few components we would like to log separately.

For example:
```python
class MyLoss(_Loss):
    ...
    def forward(self, inputs, targets):
        ...
        total_loss = comp1 + comp2
        loss_items = torch.cat((total_loss.unsqueeze(0), comp1.unsqueeze(0), comp2.unsqueeze(0))).detach()
        return total_loss, loss_items
    ...

Trainer.train(...,
              train_params={"loss": MyLoss(),
                            ...,
                            "metric_to_watch": "MyLoss/loss_0"})
```


The above snippet will log `MyLoss/loss_0`, `MyLoss/loss_1` and `MyLoss/loss_2`, as they are named by their positional index in loss_items.
Note that we also defined "MyLoss/loss_0" as our watched metric, which means a checkpoint is saved every epoch in which we reach the best loss score.

For more visibility, you can also define a "component_names" property in the loss class,
returning a list of strings of length n_items whose ith element is the name of the ith entry in loss_items.
Each item will then be logged, rendered on the tensorboard, and "watched" (i.e., saving model checkpoints
according to it) under `<LOSS_CLASS.__name__>/<COMPONENT_NAME>`.

For example:
```python
class MyLoss(_Loss):
    ...
    def forward(self, inputs, targets):
        ...
        total_loss = comp1 + comp2
        loss_items = torch.cat((total_loss.unsqueeze(0), comp1.unsqueeze(0), comp2.unsqueeze(0))).detach()
        return total_loss, loss_items
    ...
    @property
    def component_names(self):
        return ["total_loss", "my_1st_component", "my_2nd_component"]

Trainer.train(...,
              train_params={"loss": MyLoss(),
                            ...,
                            "metric_to_watch": "MyLoss/my_1st_component"})
```


The above code will log and monitor `MyLoss/total_loss`, `MyLoss/my_1st_component` and `MyLoss/my_2nd_component`.


Since running logs will save the loss_items in some internal state, it is recommended to
detach loss_items from their computational graph for memory efficiency.

### Using Your Own Loss - Training with Configuration Files

When using configuration files, for example when training using train_from_recipe (or similar, when the underlying train method being called is Trainer.train_from_config(...)), register your loss class in your ``my_loss.py`` by decorating it with `register_loss`:
```python
import torch.nn
from super_gradients.common.registry import register_loss
@@ -123,33 +224,3 @@ Last, in your ``my_train_from_recipe_script.py`` file, just import the newly reg
```
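The diff truncates the registration example, so here is a minimal sketch of what a registered loss might look like; the alias "my_custom_loss", the class name, and the loss body are illustrative, not the snippet from the PR:
```python
import torch
import torch.nn
from super_gradients.common.registry import register_loss


@register_loss("my_custom_loss")  # illustrative alias; pick any unique name
class MyCustomLoss(torch.nn.Module):
    def forward(self, preds, target):
        # single-item loss: returning just the backprop tensor is enough
        return torch.nn.functional.cross_entropy(preds, target)
```
Once `my_loss.py` is imported before training starts, the registered alias can be referenced from the recipe just like the built-in ones, e.g. `loss: my_custom_loss`.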