GH-2928: Unify loss reduction #2933

Merged: 3 commits, Sep 3, 2022
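This PR unifies how models report their training loss: every `forward_loss` now returns the summed (un-averaged) loss together with the number of data points it covers, and the underlying loss functions use `reduction="sum"`, so the training loop can track and average the loss consistently. A minimal sketch of that contract (toy model and names are assumptions for illustration, not flair code):

```python
from typing import Tuple

import torch


class ToyClassifier(torch.nn.Module):
    """Hypothetical model following the unified forward_loss contract."""

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.decoder = torch.nn.Linear(num_features, num_classes)
        # loss is summed over the batch instead of averaged
        self.loss_function = torch.nn.CrossEntropyLoss(reduction="sum")

    def forward_loss(self, features: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, int]:
        scores = self.decoder(features)
        # return the summed loss plus the number of data points it covers
        return self.loss_function(scores, labels), len(labels)
```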
2 changes: 1 addition & 1 deletion flair/models/lemmatizer_model.py
@@ -395,7 +395,7 @@ def _calculate_loss(self, scores, labels):
 
         return self.loss(scores_in_correct_format, target), len(labels)
 
-    def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> torch.Tensor:
+    def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]:
         scores, labels = self.forward_pass(sentences)
 
         return self._calculate_loss(scores, labels)
2 changes: 1 addition & 1 deletion flair/models/multitask_model.py
@@ -56,7 +56,7 @@ def forward(self, *args) -> torch.Tensor:
     def _prepare_tensors(self, data_points: List[DT]) -> Tuple[torch.Tensor, ...]:
         raise NotImplementedError("`_prepare_tensors` is not used for multitask learning")
 
-    def forward_loss(self, sentences: Union[List[Sentence], Sentence]):
+    def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]:
         """
         Abstract forward loss implementation of flair.nn.Model's interface.
         Calls the respective forward loss of each model.
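The docstring above states that the multitask model "calls the respective forward loss of each model". One plausible way such calls compose under the unified return type is to add up both the summed losses and the data-point counts across tasks; this is only an illustrative sketch with assumed names, not the actual flair implementation:

```python
from typing import Dict, List, Tuple

import torch


def multitask_forward_loss(models: Dict[str, torch.nn.Module],
                           batches: Dict[str, List]) -> Tuple[torch.Tensor, int]:
    """Aggregate (summed loss, data-point count) pairs across tasks (illustrative only)."""
    total_loss = torch.zeros(1)
    total_count = 0
    for task_id, batch in batches.items():
        # each per-task model is assumed to follow the unified contract
        loss, count = models[task_id].forward_loss(batch)
        total_loss = total_loss + loss
        total_count += count
    return total_loss, total_count
```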
2 changes: 1 addition & 1 deletion flair/models/sequence_tagger_utils/viterbi.py
@@ -35,7 +35,7 @@ def forward(self, features_tuple: tuple, targets: torch.Tensor) -> torch.Tensor:
         :param features_tuple: CRF scores from forward method in shape (batch size, seq len, tagset size, tagset size),
             lengths of sentences in batch, transitions from CRF
         :param targets: true tags for sentences which will be converted to matrix indices.
-        :return: average Viterbi Loss over batch size
+        :return: summed Viterbi Loss over all data points
         """
         features, lengths, transitions = features_tuple
 
8 changes: 3 additions & 5 deletions flair/models/tars_model.py
@@ -34,18 +34,16 @@ def __init__(self):
 
         super(FewshotClassifier, self).__init__()
 
-    def forward_loss(
-        self, data_points: Union[List[Sentence], Sentence]
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, int]]:
+    def forward_loss(self, data_points: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]:
 
         if not isinstance(data_points, list):
             data_points = [data_points]
 
         # Transform input data into TARS format
         sentences = self._get_tars_formatted_sentences(data_points)
 
-        loss = self.tars_model.forward_loss(sentences)
-        return loss
+        loss, count = self.tars_model.forward_loss(sentences)
+        return loss, count
 
     def _prepare_tensors(self, data_points: List[Sentence]) -> Tuple[torch.Tensor, ...]:
         return self.tars_model._prepare_tensors(data_points)
6 changes: 3 additions & 3 deletions flair/models/text_regression_model.py
@@ -33,7 +33,7 @@ def __init__(
 
         nn.init.xavier_uniform_(self.decoder.weight)
 
-        self.loss_function = nn.MSELoss()
+        self.loss_function = nn.MSELoss(reduction="sum")
 
         # auto-spawn on GPU if available
         self.to(flair.device)
@@ -53,13 +53,13 @@ def forward(self, *args: torch.Tensor) -> torch.Tensor:
         label_scores = self.decoder(text_embedding_tensor)
         return label_scores
 
-    def forward_loss(self, sentences: List[Sentence]) -> torch.Tensor:
+    def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]:
 
         labels = self._labels_to_tensor(sentences)
         text_embedding_tensor = self._prepare_tensors(sentences)
         scores = self.forward(*text_embedding_tensor)
 
-        return self.loss_function(scores.squeeze(1), labels)
+        return self.loss_function(scores.squeeze(1), labels), len(sentences)
 
     def _labels_to_tensor(self, sentences: List[Sentence]):
         indices = [
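With `reduction="sum"` the regression loss is no longer a per-batch mean; dividing the returned loss by the returned data-point count recovers it. A small self-contained check (toy tensors with assumed values):

```python
import torch

scores = torch.tensor([0.5, 1.5, 2.0])
labels = torch.tensor([1.0, 1.0, 1.0])

summed = torch.nn.MSELoss(reduction="sum")(scores, labels)   # 0.25 + 0.25 + 1.0 = 1.5
mean = torch.nn.MSELoss(reduction="mean")(scores, labels)    # 0.5

# dividing the summed loss by the data-point count gives the old mean back
assert torch.isclose(summed / len(labels), mean)
```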
4 changes: 2 additions & 2 deletions flair/nn/model.py
@@ -48,7 +48,7 @@ def _prepare_tensors(self, data_points: List[DT]) -> Tuple[torch.Tensor, ...]:
         raise NotImplementedError
 
     @abstractmethod
-    def forward_loss(self, data_points: List[DT]) -> Union[torch.Tensor, Tuple[torch.Tensor, int]]:
+    def forward_loss(self, data_points: List[DT]) -> Tuple[torch.Tensor, int]:
         """Performs a forward pass and returns a loss tensor for backpropagation.
         Implement this to enable training."""
         raise NotImplementedError
@@ -575,7 +575,7 @@ def __init__(
         self.gradient_reversal = RevGrad()
 
         if self.multi_label:
-            self.loss_function: _Loss = torch.nn.BCEWithLogitsLoss(weight=self.loss_weights)
+            self.loss_function: _Loss = torch.nn.BCEWithLogitsLoss(weight=self.loss_weights, reduction="sum")
         else:
             self.loss_function = torch.nn.CrossEntropyLoss(weight=self.loss_weights, reduction="sum")
         self.train_on_gold_pairs_only = train_on_gold_pairs_only
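For the multi-label case, `BCEWithLogitsLoss(reduction="sum")` sums over every (data point, label) cell rather than over data points alone. A brief illustration with assumed toy tensors:

```python
import torch

logits = torch.tensor([[0.2, -1.0, 0.5],
                       [1.5, 0.0, -0.3]])      # 2 data points, 3 labels
targets = torch.tensor([[1.0, 0.0, 1.0],
                        [0.0, 1.0, 0.0]])

summed = torch.nn.BCEWithLogitsLoss(reduction="sum")(logits, targets)
mean = torch.nn.BCEWithLogitsLoss(reduction="mean")(logits, targets)

# "sum" adds the loss of all 6 cells, so dividing by targets.numel()
# (not by the number of data points) recovers the elementwise mean
assert torch.isclose(summed / targets.numel(), mean)
```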
11 changes: 3 additions & 8 deletions flair/trainers/trainer.py
@@ -500,11 +500,8 @@ def train(
                 for batch_step in batch_steps:
 
                     # forward pass
-                    loss = self.model.forward_loss(batch_step)
-
-                    if isinstance(loss, tuple):
-                        average_over += loss[1]
-                        loss = loss[0]
+                    loss, datapoint_count = self.model.forward_loss(batch_step)
+                    average_over += datapoint_count
 
                     # Backward
                     if use_amp:
@@ -1022,9 +1019,7 @@ def find_learning_rate(
            step += 1
 
            # forward pass
-            loss = self.model.forward_loss(batch)
-            if isinstance(loss, tuple):
-                loss = loss[0]
+            loss, datapoint_count = self.model.forward_loss(batch)
 
            # update optimizer and scheduler
            optimizer.zero_grad()
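The trainer change above also illustrates why a count is returned at all: when batch steps differ in size, averaging per-batch means over-weights small batches, while summing losses and dividing by the total count gives the true per-datapoint average. A short standalone illustration with assumed numbers:

```python
import torch

# summed losses and data-point counts of two unevenly sized batch steps
batch_losses = [torch.tensor(8.0), torch.tensor(1.0)]
batch_counts = [4, 1]

true_mean = sum(batch_losses) / sum(batch_counts)  # (8 + 1) / 5 = 1.8
mean_of_batch_means = sum(l / c for l, c in zip(batch_losses, batch_counts)) / len(batch_counts)  # 1.5

print(float(true_mean), float(mean_of_batch_means))  # 1.8 vs 1.5: not the same
```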