train_segmentation.py
import os
from typing import Optional
from pyannote.audio import Model
from transformers import Trainer, TrainingArguments, HfArgumentParser
from datasets import load_dataset, DatasetDict
from diarizers import Preprocess, SegmentationModel, DataCollator, Metrics
from dataclasses import dataclass, field


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to
    specify them on the command line.
    """

    dataset_name: str = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_config_name: str = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    train_split_name: str = field(
        default="train",
        metadata={"help": "The name of the training split to use (via the datasets library). Defaults to 'train'."},
    )
    eval_split_name: str = field(
        default="validation",
        metadata={"help": "The name of the evaluation split to use (via the datasets library). Defaults to 'validation'."},
    )
    split_on_subset: str = field(
        default=None,
        metadata={"help": "Automatically split the specified dataset subset into train/validation/test. Defaults to 'None'."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )


if __name__ == "__main__":
    # Restrict training to the first visible GPU:
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    parser = HfArgumentParser((DataTrainingArguments, ModelArguments, TrainingArguments))
    data_args, model_args, training_args = parser.parse_args_into_dataclasses()

    # Load the dataset (num_proc=None falls back to a single process):
    if data_args.dataset_config_name:
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            num_proc=data_args.preprocessing_num_workers,
        )
    else:
        dataset = load_dataset(
            data_args.dataset_name,
            num_proc=data_args.preprocessing_num_workers,
        )

    train_split_name = data_args.train_split_name
    val_split_name = data_args.eval_split_name

    # Split into Train-Val-Test:
    if data_args.split_on_subset:
        train_testvalid = dataset[data_args.split_on_subset].train_test_split(test_size=0.2, seed=0)
        test_valid = train_testvalid["test"].train_test_split(test_size=0.5, seed=0)
        dataset = DatasetDict(
            {
                "train": train_testvalid["train"],
                "validation": test_valid["test"],
                "test": test_valid["train"],
            }
        )
        train_split_name = "train"
        val_split_name = "validation"
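        # Note: test_size=0.2 followed by a 50/50 split of the held-out portion leaves the chosen
        # subset roughly 80% train / 10% validation / 10% test (seeded for reproducibility).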

    # Load the Pretrained Pyannote Segmentation Model:
    pretrained = Model.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_auth_token=True,
    )
    model = SegmentationModel.from_pyannote_model(pretrained)

    # Load the preprocessor:
    preprocessor = Preprocess(model.config)
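    # (Best-effort gloss: the preprocessor chunks each annotated recording into fixed-length windows
    # matching the segmentation model's input duration and derives frame-level speaker-activity
    # labels; see the diarizers `Preprocess` implementation for the exact behaviour.)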

    # Preprocess:
    if training_args.do_train:
        train_set = dataset[train_split_name].map(
            lambda file: preprocessor(file, random=False, overlap=0.5),
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=next(iter(dataset.values())).column_names,
            batched=True,
            batch_size=1,
        ).shuffle().with_format("torch")

    if training_args.do_eval:
        val_set = dataset[val_split_name].map(
            lambda file: preprocessor(file, random=False, overlap=0.0),
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=next(iter(dataset.values())).column_names,
            batched=True,
            keep_in_memory=True,
            batch_size=1,
        ).with_format("torch")

    # Load metrics:
    metrics = Metrics(model.specifications)
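    # (`Metrics` is built from the model's output specifications; passed to the Trainer as
    # `compute_metrics`, it reports segmentation error components on the eval set -- hedged
    # summary, see the diarizers `Metrics` implementation for the exact quantities.)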

    # Define the Trainer:
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_set if training_args.do_train else None,
        eval_dataset=val_set if training_args.do_eval else None,
        data_collator=DataCollator(max_speakers_per_chunk=model.config.max_speakers_per_chunk),
        compute_metrics=metrics,
    )

    # Evaluate the pretrained model before fine-tuning, then train:
    if training_args.do_eval:
        first_eval = trainer.evaluate()
        print("Initial metric values: ", first_eval)
    if training_args.do_train:
        trainer.train()

    # Write Training Stats:
    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "speaker diarization"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name
    kwargs["tags"] = ["speaker-diarization", "speaker-segmentation"]

    # Push to Hub:
    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)
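
# After training, the fine-tuned checkpoint (in `output_dir` or on the Hub) can be swapped into a
# pyannote diarization pipeline. The sketch below is illustrative and follows the diarizers README:
# the checkpoint name is a placeholder, and the private `pipeline._segmentation.model` attribute is
# an assumption that may differ across pyannote versions.
#
#   from diarizers import SegmentationModel
#   from pyannote.audio import Pipeline
#
#   pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=True)
#   finetuned = SegmentationModel.from_pretrained("<your-username>/<your-fine-tuned-model>")
#   pipeline._segmentation.model = finetuned.to_pyannote_model()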