gh-3097: fix multitask model training #3101

Merged (6 commits) on Feb 14, 2023.
Changes from all commits
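For orientation before the diff: a minimal, hedged sketch of how the use_all_tasks option added by this PR might be passed when assembling a multitask model. Here ner_model and pos_model stand in for already-built flair.nn.Classifier instances sharing one embedding, and the task ids and loss factors are illustrative, not from this PR:

    from flair.models import MultitaskModel

    multitask_model = MultitaskModel(
        models=[ner_model, pos_model],   # hypothetical pre-built classifiers
        task_ids=["Task_0", "Task_1"],
        loss_factors=[1.0, 0.5],
        use_all_tasks=True,  # new: train each sentence on all of its tasks per batch
    )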
2 changes: 1 addition & 1 deletion flair/embeddings/base.py
@@ -45,7 +45,7 @@ def embed(self, data_points: Union[DT, List[DT]]) -> List[DT]:
         if not isinstance(data_points, list):
             data_points = [data_points]
 
-        if not self._everything_embedded(data_points) or not self.static_embeddings:
+        if not self._everything_embedded(data_points):
            self._add_embeddings_internal(data_points)
 
         return data_points
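A hedged reading of why this one-line change matters for multitask training: several task heads call embed() on the same sentences within one batch, and the removed `or not self.static_embeddings` clause forced non-static embeddings to be recomputed on every call; after the fix, a second call reuses the vectors already attached:

    embeddings.embed(sentences)  # first head: computes and attaches vectors
    embeddings.embed(sentences)  # second head: _everything_embedded() is True,
                                 # so nothing is recomputed, even for dynamic embeddings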
64 changes: 52 additions & 12 deletions flair/models/multitask_model.py
@@ -28,6 +28,7 @@ def __init__(
         models: List[flair.nn.Classifier],
         task_ids: Optional[List[str]] = None,
         loss_factors: Optional[List[float]] = None,
+        use_all_tasks: bool = False,
     ):
         """
         :param models: Key (Task ID) - Value (flair.nn.Model) Pairs to stack model
@@ -38,6 +39,7 @@ def __init__(
 
         self.tasks: Dict[str, flair.nn.Classifier] = {}
         self.loss_factors: Dict[str, float] = {}
+        self.use_all_tasks = use_all_tasks
 
         if not loss_factors:
             loss_factors = [1.0] * len(models)
@@ -64,7 +66,7 @@ def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]:
         :param sentences: batch of sentences
         :return: loss
         """
-        batch_split = self.split_batch_to_task_ids(sentences)
+        batch_split = self.split_batch_to_task_ids(sentences, all_tasks=self.use_all_tasks)
         loss = torch.tensor(0.0, device=flair.device)
         count = 0
         for task_id, split in batch_split.items():
@@ -82,20 +84,25 @@ def predict(
             task.predict(sentences, **predictargs)
 
     @staticmethod
-    def split_batch_to_task_ids(sentences: Union[List[Sentence], Sentence]) -> Dict:
+    def split_batch_to_task_ids(sentences: Union[List[Sentence], Sentence], all_tasks: bool = False) -> Dict:
         """
         Splits a batch of sentences to its respective model. If a single sentence is assigned to several tasks
         (i.e. same corpus but different tasks), then the model assignment for this batch is randomly chosen.
         :param sentences: batch of sentences
+        :param all_tasks: use all tasks of each sentence. If deactivated, a random task will be sampled
         :return: Key-value pairs as (task_id, list of sentence ids in batch)
         """
         batch_to_task_mapping: Dict[str, List[int]] = {}
         for sentence_id, sentence in enumerate(sentences):
-            multitask_id = random.choice(sentence.get_labels("multitask_id"))
-            if multitask_id.value in batch_to_task_mapping:
-                batch_to_task_mapping[multitask_id.value].append(sentence_id)
-            elif multitask_id.value not in batch_to_task_mapping:
-                batch_to_task_mapping[multitask_id.value] = [sentence_id]
+            if all_tasks:
+                multitask_ids = sentence.get_labels("multitask_id")
+            else:
+                multitask_ids = [random.choice(sentence.get_labels("multitask_id"))]
+            for multitask_id in multitask_ids:
+                if multitask_id.value in batch_to_task_mapping:
+                    batch_to_task_mapping[multitask_id.value].append(sentence_id)
+                elif multitask_id.value not in batch_to_task_mapping:
+                    batch_to_task_mapping[multitask_id.value] = [sentence_id]
         return batch_to_task_mapping
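To make the routing concrete, a hedged illustration of the mapping this method returns (the task ids are made up; in practice the multitask corpus setup attaches the "multitask_id" labels):

    # batch: sentence 0 labeled for "Task_0" and "Task_1"; sentences 1-2 for "Task_0" only
    #
    # all_tasks=False (default; one task sampled per sentence):
    #     {"Task_0": [0, 1, 2]}   or   {"Task_1": [0], "Task_0": [1, 2]}
    #
    # all_tasks=True (used when use_all_tasks=True, and always during evaluation):
    #     {"Task_0": [0, 1, 2], "Task_1": [0]}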

def evaluate(
Expand All @@ -110,6 +117,7 @@ def evaluate(
exclude_labels: List[str] = [],
gold_label_dictionary: Optional[Dictionary] = None,
return_loss: bool = True,
evaluate_all: bool = True,
**evalargs,
) -> Result:
"""
@@ -118,10 +126,35 @@ def evaluate(
         'cpu' (embeddings are stored on CPU) or 'gpu' (embeddings are stored on GPU)
         :param mini_batch_size: size of batches
         :param num_workers: number of workers for DataLoader class
+        :param evaluate_all: whether to evaluate all tasks, or only the single task selected by gold_label_type
         :return: Tuple of Result object and loss value (float)
         """
 
-        batch_split = self.split_batch_to_task_ids(data_points)
+        if not evaluate_all:
+            if gold_label_type not in self.tasks:
+                raise ValueError(
+                    "evaluating a single task on a multitask model requires 'gold_label_type' to be a valid task."
+                )
+            data = [
+                dp
+                for dp in data_points
+                if any(label.value == gold_label_type for label in dp.get_labels("multitask_id"))
+            ]
+            return self.tasks[gold_label_type].evaluate(
+                data,
+                gold_label_type=self.tasks[gold_label_type].label_type,
+                out_path=out_path,
+                embedding_storage_mode=embedding_storage_mode,
+                mini_batch_size=mini_batch_size,
+                num_workers=num_workers,
+                main_evaluation_metric=main_evaluation_metric,
+                exclude_labels=exclude_labels,
+                gold_label_dictionary=gold_label_dictionary,
+                return_loss=return_loss,
+                **evalargs,
+            )
+
+        batch_split = self.split_batch_to_task_ids(data_points, all_tasks=True)
 
         loss = torch.tensor(0.0, device=flair.device)
         main_score = 0.0
@@ -131,8 +164,8 @@ def evaluate(
         for task_id, split in batch_split.items():
             result = self.tasks[task_id].evaluate(
                 data_points=[data_points[i] for i in split],
-                gold_label_type=gold_label_type[task_id],
-                out_path=f"{out_path}_{task_id}.txt",
+                gold_label_type=self.tasks[task_id].label_type,
+                out_path=f"{out_path}_{task_id}.txt" if out_path is not None else None,
                 embedding_storage_mode=embedding_storage_mode,
                 mini_batch_size=mini_batch_size,
                 num_workers=mini_batch_size,
@@ -157,7 +190,7 @@ def evaluate(
                 + task_id
                 + " - "
                 + "Label type: "
-                + self.label_type.get(task_id)
+                + self.tasks[task_id].label_type
                 + "\n\n"
                 + result.detailed_results
             )
@@ -183,6 +216,7 @@ def _get_state_dict(self):
             **initial_model_state,
             "model_states": {task: model._get_state_dict() for task, model in self.tasks.items()},
             "loss_factors": [self.loss_factors[task] for task in self.tasks.keys()],
+            "use_all_tasks": self.use_all_tasks,
         }
 
         return model_state
@@ -200,7 +234,13 @@ def _init_model_with_state_dict(cls, state, **kwargs):
             models.append(Classifier.load(task_state))
             tasks.append(task)
 
-        model = cls(models=models, task_ids=tasks, loss_factors=loss_factors, **kwargs)
+        model = cls(
+            models=models,
+            task_ids=tasks,
+            loss_factors=loss_factors,
+            use_all_tasks=state.get("use_all_tasks", False),
+            **kwargs,
+        )
         return model
 
     @property
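A hedged sketch of the new single-task evaluation path; multitask_model, test_sentences, and the task id "Task_0" are illustrative stand-ins:

    result = multitask_model.evaluate(
        test_sentences,
        gold_label_type="Task_0",  # names a *task id* here, not a label type;
                                   # the task's own label_type is used internally
        evaluate_all=False,        # new: restrict evaluation to this one task
        mini_batch_size=32,
    )
    print(result.main_score)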
6 changes: 3 additions & 3 deletions flair/trainers/trainer.py
@@ -518,7 +518,6 @@ def train(
                 # forward pass
                 loss, datapoint_count = self.model.forward_loss(batch_step)
                 average_over += datapoint_count
-
                 # Backward
                 if use_amp:
                     with amp.scale_loss(loss, optimizer) as scaled_loss:
@@ -528,8 +527,9 @@ def train(
                 train_loss += loss.item()
 
                 # identify dynamic embeddings (always deleted) on first sentence
-                if not dynamic_embeddings:
-                    dynamic_embeddings = identify_dynamic_embeddings(batch[0])
+
+                if dynamic_embeddings is None:
+                    dynamic_embeddings = identify_dynamic_embeddings(batch)
 
                 # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                 store_embeddings(batch, embeddings_storage_mode, dynamic_embeddings)
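Two fixes are combined here (reading from the diff): the scan now covers the whole first batch rather than only batch[0], and the cached result is tested with `is None` rather than truthiness. A short sketch of why the sentinel distinction matters (values are illustrative):

    dynamic_embeddings = None        # not scanned yet -> scan on the next batch
    dynamic_embeddings = []          # scanned; all embeddings static -> keep result
    dynamic_embeddings = ["bert"]    # scanned; delete "bert" vectors every batch

    # `not dynamic_embeddings` is True for both None and [], so an empty (all-static)
    # result used to trigger a rescan on every batch; `is None` only matches the
    # genuinely-unscanned case.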
32 changes: 20 additions & 12 deletions flair/training_utils.py
@@ -13,7 +13,7 @@
 from torch.utils.data import Dataset
 
 import flair
-from flair.data import DT, DataPoint, Dictionary, Sentence, _iter_dataset
+from flair.data import DT, Dictionary, Sentence, _iter_dataset
 
 log = logging.getLogger("flair")
 
@@ -376,8 +376,8 @@ def store_embeddings(
         dynamic_embeddings = None
 
     # if dynamic embedding keys not passed, identify them automatically
-    elif not dynamic_embeddings:
-        dynamic_embeddings = identify_dynamic_embeddings(data_points[0])
+    elif dynamic_embeddings is None:
+        dynamic_embeddings = identify_dynamic_embeddings(data_points)
 
     # always delete dynamic embeddings
     for data_point in data_points:
@@ -390,15 +390,23 @@ def store_embeddings(
         data_point.to("cpu", pin_memory=pin_memory)
 
 
-def identify_dynamic_embeddings(data_point: DataPoint):
+def identify_dynamic_embeddings(data_points: List[DT]):
     dynamic_embeddings = []
-    if isinstance(data_point, Sentence):
-        first_token = data_point[0]
-        for name, vector in first_token._embeddings.items():
-            if vector.requires_grad:
-                dynamic_embeddings.append(name)
-
-    for name, vector in data_point._embeddings.items():
-        if vector.requires_grad:
-            dynamic_embeddings.append(name)
-    return dynamic_embeddings
+    all_embeddings = []
+    for data_point in data_points:
+        if isinstance(data_point, Sentence):
+            first_token = data_point[0]
+            for name, vector in first_token._embeddings.items():
+                if vector.requires_grad:
+                    dynamic_embeddings.append(name)
+                all_embeddings.append(name)
+
+        for name, vector in data_point._embeddings.items():
+            if vector.requires_grad:
+                dynamic_embeddings.append(name)
+            all_embeddings.append(name)
+        if dynamic_embeddings:
+            return dynamic_embeddings
+    if not all_embeddings:
+        return None
+    return list(set(dynamic_embeddings))
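Putting the utility and trainer changes together, a small hedged example of the new return contract (assumes the GloVe vectors can be downloaded; GloVe embeddings are static, so nothing here requires grad):

    from flair.data import Sentence
    from flair.embeddings import WordEmbeddings
    from flair.training_utils import identify_dynamic_embeddings

    sentences = [Sentence("hello world"), Sentence("flair is neat")]
    print(identify_dynamic_embeddings(sentences))  # None: nothing embedded yet

    WordEmbeddings("glove").embed(sentences)
    print(identify_dynamic_embeddings(sentences))  # []: embeddings exist, none dynamic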