Improve XR-Transformer tokenizer memory efficiency
jiong-zhang committed Apr 1, 2022
1 parent d14af1a commit f64878e
Showing 6 changed files with 323 additions and 198 deletions.
219 changes: 134 additions & 85 deletions pecos/xmc/xtransformer/matcher.py
@@ -26,9 +26,9 @@
from pecos.xmc import MLModel, MLProblem, PostProcessor
from sklearn.preprocessing import normalize as sk_normalize
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import AdamW, AutoConfig, get_scheduler
from transformers import AdamW, AutoConfig, get_scheduler, BatchEncoding

from .module import XMCDataset
from .module import XMCTensorDataset, XMCTextDataset
from .network import ENCODER_CLASSES, HingeLoss, TransformerLinearXMCHead

logging.getLogger(transformers.__name__).setLevel(logging.WARNING)
@@ -91,7 +91,12 @@ class TrainParams(pecos.BaseParams): # type: ignore
cost_sensitive_ranker (bool, optional): if True, use clustering count aggregation for the ranker's cost-sensitive learning
Default False
pre_tokenize (bool, optional): if True, will tokenize training instances before training
This could potentially accelerate batch-generation but increases memory cost.
Default False
use_gpu (bool, optional): whether to use GPU if available. Default True
eval_by_true_shorlist (bool, optional): if True, will compute validation scores by true label
shortlisting at the intermediate layer. Default False
checkpoint_dir (str): path to save training checkpoints. Default empty to use a temp dir.
cache_dir (str): dir to store the pre-trained models downloaded from
@@ -126,7 +131,9 @@ class TrainParams(pecos.BaseParams): # type: ignore
save_steps: int = 100

cost_sensitive_ranker: bool = False
pre_tokenize: bool = False
use_gpu: bool = True
eval_by_true_shorlist: bool = False

checkpoint_dir: str = ""
cache_dir: str = ""
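
A usage sketch of the new flags (hypothetical call: construction via pecos.BaseParams.from_dict and the nested TransformerMatcher.TrainParams path are assumptions; the field names come from the hunk above):

    # Hypothetical sketch: keep memory low by tokenizing batches on the fly
    # (pre_tokenize=False) and running only the predicted-shortlist validation pass.
    train_params = TransformerMatcher.TrainParams.from_dict(
        {
            "pre_tokenize": False,
            "eval_by_true_shorlist": False,
            "use_gpu": True,
        }
    )
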
@@ -453,7 +460,25 @@ def download_model(cls, model_shortcut, num_labels=2, hidden_dropout_prob=0.1, c
text_model = TransformerLinearXMCHead(config.hidden_size, num_labels)
return cls(text_encoder, text_tokenizer, text_model)

def text_to_tensor(self, corpus, max_length=None, **kwargs):
@staticmethod
def _get_tokenizer_config(**kwargs):
"""Obtain tokenizer config.
Additional kwargs are added to, or overwrite, the default values.
Returns:
tokenizer_config (dict)
"""
convert_kwargs = {
"add_special_tokens": True,
"padding": "max_length",
"truncation": True,
"return_tensors": "pt", # return pytorch tensors
"return_token_type_ids": True,
"return_attention_mask": True,
}
return {**convert_kwargs, **kwargs}
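
A quick illustration of the merge above (hypothetical call): caller-supplied kwargs are layered on top of the defaults, so later keys win.

    # max_length is added and padding is overridden; the remaining defaults
    # (truncation, return_tensors="pt", ...) are left untouched.
    cfg = TransformerMatcher._get_tokenizer_config(max_length=128, padding="longest")
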

def text_to_tensor(self, corpus, max_length=None):
"""Convert input text corpus into padded tensors
Args:
@@ -468,30 +493,24 @@ def text_to_tensor(self, corpus, max_length=None, **kwargs):
"token_type_ids": tensor of token type ids,
}
"""
convert_kwargs = {
"add_special_tokens": True,
"padding": "max_length",
"truncation": True,
"max_length": max_length,
"return_tensors": "pt", # return pytorch tensors
"return_token_type_ids": True,
"return_attention_mask": True,
}
# this is to disable the warning message for the tokenizer
# REF: https://github.com/huggingface/transformers/issues/5486
os.environ["TOKENIZERS_PARALLELISM"] = "true"
LOGGER.info("***** Encoding data len={} truncation={}*****".format(len(corpus), max_length))
t_start = time.time()
feature_tensors = self.text_tokenizer.batch_encode_plus(
batch_text_or_text_pairs=corpus,
**convert_kwargs,
**self._get_tokenizer_config(max_length=max_length),
)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

feature_tensors["instance_number"] = torch.arange(feature_tensors["input_ids"].shape[0])

LOGGER.info("***** Finished with time cost={} *****".format(time.time() - t_start))
return feature_tensors
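
A minimal usage sketch of the method above (matcher stands for a TransformerMatcher instance; the strings and max_length are placeholders):

    feats = matcher.text_to_tensor(["a short text", "another text"], max_length=32)
    # feats acts as a dict of padded tensors and now also carries the row index:
    #   feats["input_ids"].shape[0]  -> 2
    #   feats["instance_number"]     -> tensor([0, 1])
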

@staticmethod
def _get_label_tensors(M, Y, idx_padding=-1, val_padding=0, max_labels=None):
def _get_label_tensors(M, Y, idx_padding=-1, max_labels=None):
"""
Given matching matrix M and label matrix Y, construct label tensors for XMC training
The non-zero indices of Y are seen as positive labels and therefore all
@@ -515,8 +534,6 @@ def _get_label_tensors(M, Y, idx_padding=-1, val_padding=0, max_labels=None):
included.
idx_padding (int, optional): the index used to pad all label_indices
to the same length. Default -1
val_padding (float, optional): the value used to fill in
label_values corresponding to the zero entrees in Y. Default 0
max_labels (int, optional): max number of labels considered for each
instance, will subsample from existing label indices if need to.
Default None to use max row nnz of M.
@@ -557,7 +574,7 @@ def _get_label_tensors(M, Y, idx_padding=-1, val_padding=0, max_labels=None):
nr_inst = M1.shape[0]
label_indices = np.zeros((nr_inst, max_labels), dtype=np.int64) + idx_padding
if Y is not None:
label_values = np.zeros((nr_inst, max_labels), dtype=np.float32) + val_padding
label_values = np.zeros((nr_inst, max_labels), dtype=np.float32)

for i in range(nr_inst):
offset = 0
@@ -673,17 +690,22 @@ def predict(
elif not isinstance(pred_params, TransformerMatcher.PredParams):
raise TypeError(f"Unsupported type for pred_params: {type(pred_params)}")

if isinstance(X_text, list):
X_text = self.text_to_tensor(
X_text,
num_workers=kwargs.get("batch_gen_workers", 4),
max_length=pred_params.truncate_length,
)
if isinstance(X_text, (dict, BatchEncoding)):
nr_inst = X_text["input_ids"].shape[0]
elif isinstance(X_text, list):
nr_inst = len(X_text)
else:
raise ValueError(f"Invalid type for X_text ({type(X_text)})")

nr_inst = X_text["input_ids"].shape[0]
max_pred_chunk = kwargs.pop("max_pred_chunk", 10**7)

if max_pred_chunk is None or max_pred_chunk >= nr_inst:
if isinstance(X_text, list):
X_text = self.text_to_tensor(
X_text,
max_length=pred_params.truncate_length,
)

label_pred, embeddings = self._predict(
X_text,
X_feat=X_feat,
@@ -696,8 +718,16 @@
embedding_chunks = []
P_chunks = []
for i in range(0, nr_inst, max_pred_chunk):
if isinstance(X_text, list):
cur_X_text = self.text_to_tensor(
X_text[i : i + max_pred_chunk],
max_length=pred_params.truncate_length,
)
else:
cur_X_text = {k: v[i : i + max_pred_chunk] for k, v in X_text.items()}

cur_P, cur_embedding = self._predict(
{k: v[i : i + max_pred_chunk] for k, v in X_text.items()},
cur_X_text,
X_feat=None if X_feat is None else X_feat[i : i + max_pred_chunk, :],
csr_codes=None if csr_codes is None else csr_codes[i : i + max_pred_chunk, :],
pred_params=pred_params,
@@ -724,7 +754,7 @@ def _predict(
"""Predict with the transformer matcher
Args:
X_text (dict): prediction inputs, dictionary of tensors
X_text (dict or BatchEncoding): prediction inputs, dictionary of tensors
{
"input_ids": tensor of input token ids,
"attention_mask": tensor of attention masks,
@@ -775,11 +805,11 @@ def _predict(
label_indices_pt, label_values_pt = TransformerMatcher._get_label_tensors(
csr_codes_next, None, idx_padding=self.text_model.label_pad
)
data = XMCDataset(
data = XMCTensorDataset(
X_text["input_ids"],
X_text["attention_mask"],
X_text["token_type_ids"],
torch.arange(X_text["input_ids"].shape[0]),
X_text["instance_number"],
label_values=label_values_pt,
label_indices=label_indices_pt,
)
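
Tying back to the predict() changes above, a hypothetical call with raw text now tokenizes one chunk at a time instead of materializing tensors for the whole corpus up front (corpus and prev_csr_codes are placeholder names):

    # Raw strings are tokenized lazily, max_pred_chunk rows at a time.
    P, embeddings = matcher.predict(
        corpus,                     # list of raw strings
        csr_codes=prev_csr_codes,   # optional label shortlist from the previous layer
        max_pred_chunk=10**6,
    )
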
@@ -969,20 +999,37 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None):
# put text_model to GPU
self.text_model.to(self.device)

label_indices_pt, label_values_pt = TransformerMatcher._get_label_tensors(
M_next,
prob.Y,
idx_padding=self.text_model.label_pad,
max_labels=max_act_labels,
)
train_data = XMCDataset(
prob.X_text["input_ids"],
prob.X_text["attention_mask"],
prob.X_text["token_type_ids"],
torch.arange(prob.X_text["input_ids"].shape[0]), # instance number
label_values=label_values_pt,
label_indices=label_indices_pt,
)
if prob.is_tokenized:
LOGGER.info("Using XMCTensorDataset for tokenized inputs!")
label_indices_pt, label_values_pt = TransformerMatcher._get_label_tensors(
M_next,
prob.Y,
idx_padding=self.text_model.label_pad,
max_labels=max_act_labels,
)
train_data = XMCTensorDataset(
prob.X_text["input_ids"],
prob.X_text["attention_mask"],
prob.X_text["token_type_ids"],
prob.X_text["instance_number"],
label_values=label_values_pt,
label_indices=label_indices_pt,
)
else:
LOGGER.info("Using XMCTextDataset for text inputs!")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
train_data = XMCTextDataset(
prob.X_text,
lambda x: self.text_tokenizer(
text=x,
**self._get_tokenizer_config(max_length=pred_params.truncate_length),
),
feature_keys=["input_ids", "attention_mask", "token_type_ids", "instance_number"],
Y=prob.Y,
M=M_next,
idx_padding=self.text_model.label_pad,
max_labels=max_act_labels,
)
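
The closure handed to XMCTextDataset above is what defers tokenization to batch-generation time; in isolation it behaves roughly like this (illustrative only, with a placeholder max_length):

    # Tokenize a single instance (or a small batch) on demand, producing padded
    # tensors for just that slice of the corpus rather than for everything at once.
    tokenize_fn = lambda x: matcher.text_tokenizer(
        text=x,
        **matcher._get_tokenizer_config(max_length=128),
    )
    batch_feats = tokenize_fn(["one raw training instance"])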

# since number of active labels may vary
# using pinned memory will slow down data loading
@@ -1062,10 +1109,10 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None):

# Start Batch Training
LOGGER.info("***** Running training *****")
LOGGER.info(" Num examples = %d", prob.X_text["input_ids"].shape[0])
LOGGER.info(" Num examples = %d", prob.nr_inst)
LOGGER.info(" Num labels = %d", self.nr_labels)
if prob.M is not None:
LOGGER.info(" Num active labels per instance = %d", label_indices_pt.shape[1])
LOGGER.info(" Num active labels per instance = %d", train_data.num_active_labels)
LOGGER.info(" Num Epochs = %d", train_params.num_train_epochs)
LOGGER.info(" Learning Rate Schedule = %s", train_params.lr_schedule)
LOGGER.info(" Batch size = %d", train_params.batch_size)
@@ -1083,7 +1130,9 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None):
self.text_encoder.zero_grad()
self.text_model.zero_grad()
for epoch in range(1, int(train_params.num_train_epochs) + 1):
if do_resample and epoch > 1: # redo subsample negative labels
if (
isinstance(train_data, XMCTensorDataset) and do_resample and epoch > 1
): # redo subsample negative labels
label_indices_pt, label_values_pt = TransformerMatcher._get_label_tensors(
M_next,
prob.Y,
@@ -1171,10 +1220,12 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None):
if val_prob is not None:
if val_prob.M is None:
test_combos = zip(["all"], [None])
else:
elif train_params.eval_by_true_shorlist:
test_combos = zip(
["trn_ns", "pred_ns"], [val_prob.M, val_csr_codes]
)
else:
test_combos = zip(["pred_ns"], [val_csr_codes])
for val_type, valid_M in test_combos:
avr_beam = 1 if valid_M is None else valid_M.nnz / valid_M.shape[0]
# compute loss and prediction on test set
@@ -1285,8 +1336,8 @@ def train(
return_dict (bool): if True, return a dictionary with model
and its prediction/embeddings on train/validation dataset.
Default False.
return_train_pred (bool): if True and return_dict, return prediction matrix on training data
return_train_embeddings (bool): if True and return_dict, return training instance embeddings
return_pred_on_trn (bool): if True and return_dict, return prediction matrix on training data
return_embed_on_trn (bool): if True and return_dict, return training instance embeddings
Returns:
results (TransformerMatcher or dict):
if return_dict=True, return a dictionary:
@@ -1336,38 +1387,36 @@ def train(
matcher.train_params = train_params
matcher.pred_params = pred_params

# tokenize X_text if X_text is given as raw text
saved_trn_pt = kwargs.get("saved_trn_pt", "")
if not prob.is_tokenized:
if saved_trn_pt and os.path.isfile(saved_trn_pt):
trn_tensors = torch.load(saved_trn_pt)
LOGGER.info("trn tensors loaded_from {}".format(saved_trn_pt))
else:
trn_tensors = matcher.text_to_tensor(
prob.X_text,
num_workers=train_params.batch_gen_workers,
max_length=pred_params.truncate_length,
)
if saved_trn_pt:
torch.save(trn_tensors, saved_trn_pt)
LOGGER.info("trn tensors saved to {}".format(saved_trn_pt))
prob.X_text = trn_tensors

if val_prob is not None and not val_prob.is_tokenized:
saved_val_pt = kwargs.get("saved_val_pt", "")
if saved_val_pt and os.path.isfile(saved_val_pt):
val_tensors = torch.load(saved_val_pt)
LOGGER.info("val tensors loaded from {}".format(saved_val_pt))
else:
val_tensors = matcher.text_to_tensor(
val_prob.X_text,
num_workers=train_params.batch_gen_workers,
max_length=pred_params.truncate_length,
)
if saved_val_pt:
torch.save(val_tensors, saved_val_pt)
LOGGER.info("val tensors saved to {}".format(saved_val_pt))
val_prob.X_text = val_tensors
if train_params.pre_tokenize:
saved_trn_pt = kwargs.get("saved_trn_pt", "")
if not prob.is_tokenized:
if saved_trn_pt and os.path.isfile(saved_trn_pt):
trn_tensors = torch.load(saved_trn_pt)
LOGGER.info("trn tensors loaded_from {}".format(saved_trn_pt))
else:
trn_tensors = matcher.text_to_tensor(
prob.X_text,
max_length=pred_params.truncate_length,
)
if saved_trn_pt:
torch.save(trn_tensors, saved_trn_pt)
LOGGER.info("trn tensors saved to {}".format(saved_trn_pt))
prob.X_text = trn_tensors

if val_prob is not None and not val_prob.is_tokenized:
saved_val_pt = kwargs.get("saved_val_pt", "")
if saved_val_pt and os.path.isfile(saved_val_pt):
val_tensors = torch.load(saved_val_pt)
LOGGER.info("val tensors loaded from {}".format(saved_val_pt))
else:
val_tensors = matcher.text_to_tensor(
val_prob.X_text,
max_length=pred_params.truncate_length,
)
if saved_val_pt:
torch.save(val_tensors, saved_val_pt)
LOGGER.info("val tensors saved to {}".format(saved_val_pt))
val_prob.X_text = val_tensors
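
Putting the pre_tokenize branch above together, a hypothetical call (the full train() signature is not shown in these hunks, so argument names beyond the kwargs used above are assumptions) can cache tokenized tensors across runs:

    res = TransformerMatcher.train(
        prob,
        val_prob=val_prob,
        train_params=train_params,            # with train_params.pre_tokenize == True
        pred_params=pred_params,
        saved_trn_pt="cache/trn_tensors.pt",  # reloaded on later runs if the file exists
        saved_val_pt="cache/val_tensors.pt",
        return_dict=True,
        return_pred_on_trn=True,
    )
    matcher, P_trn = res["matcher"], res["trn_pred"]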

bootstrapping = kwargs.get("bootstrapping", None)
if bootstrapping is not None:
@@ -1413,12 +1462,12 @@ def train(
matcher.concat_model = None

return_dict = kwargs.get("return_dict", False)
return_train_pred = kwargs.get("return_train_pred", False) and return_dict
return_train_embeddings = kwargs.get("return_train_embeddings", False) and return_dict
return_pred_on_trn = kwargs.get("return_pred_on_trn", False) and return_dict
return_embed_on_trn = kwargs.get("return_embed_on_trn", False) and return_dict

P_trn, inst_embeddings = None, None
train_concat = pred_params.ensemble_method not in ["transformer-only"]
if train_concat or return_train_pred or return_train_embeddings:
if train_concat or return_pred_on_trn or return_embed_on_trn:
# getting the instance embeddings of training data
# since X_feat is not passed, transformer-only result is produced
P_trn, inst_embeddings = matcher.predict(
@@ -1499,9 +1548,9 @@ def train(
if return_dict:
return {
"matcher": matcher,
"trn_pred": P_trn if return_train_pred else None,
"trn_pred": P_trn if return_pred_on_trn else None,
"val_pred": P_val,
"trn_embeddings": inst_embeddings if return_train_embeddings else None,
"trn_embeddings": inst_embeddings if return_embed_on_trn else None,
"val_embeddings": val_inst_embeddings,
}
else: