From f64878e9542ad65ef67400d4a88dd66eacb39b0f Mon Sep 17 00:00:00 2001 From: jiong-zhang Date: Thu, 31 Mar 2022 18:02:36 +0000 Subject: [PATCH] Improve XR-Transformer tokenizer memory efficiency --- pecos/xmc/xtransformer/matcher.py | 219 +++++++++++------- pecos/xmc/xtransformer/model.py | 111 +++------ pecos/xmc/xtransformer/module.py | 150 +++++++++++- pecos/xmc/xtransformer/network.py | 3 +- pecos/xmc/xtransformer/train.py | 32 +-- .../xmc/xtransformer/test_xtransformer.py | 6 +- 6 files changed, 323 insertions(+), 198 deletions(-) diff --git a/pecos/xmc/xtransformer/matcher.py b/pecos/xmc/xtransformer/matcher.py index 20839bd2..b5827c84 100644 --- a/pecos/xmc/xtransformer/matcher.py +++ b/pecos/xmc/xtransformer/matcher.py @@ -26,9 +26,9 @@ from pecos.xmc import MLModel, MLProblem, PostProcessor from sklearn.preprocessing import normalize as sk_normalize from torch.utils.data import DataLoader, RandomSampler, SequentialSampler -from transformers import AdamW, AutoConfig, get_scheduler +from transformers import AdamW, AutoConfig, get_scheduler, BatchEncoding -from .module import XMCDataset +from .module import XMCTensorDataset, XMCTextDataset from .network import ENCODER_CLASSES, HingeLoss, TransformerLinearXMCHead logging.getLogger(transformers.__name__).setLevel(logging.WARNING) @@ -91,7 +91,12 @@ class TrainParams(pecos.BaseParams): # type: ignore cost_sensitive_ranker (bool, optional): if True, use clustering count aggregating for ranker's cost-sensitive learnin Default False + pre_tokenize (bool, optional): if True, will tokenize training instances before training + This could potentially accelerate batch-generation but increases memory cost. + Default False use_gpu (bool, optional): whether to use GPU even if available. Default True + eval_by_true_shorlist (bool, optional): if True, will compute validation scores by true label + shortlisting at intermediat layer. Default False checkpoint_dir (str): path to save training checkpoints. Default empty to use a temp dir. cache_dir (str): dir to store the pre-trained models downloaded from @@ -126,7 +131,9 @@ class TrainParams(pecos.BaseParams): # type: ignore save_steps: int = 100 cost_sensitive_ranker: bool = False + pre_tokenize: bool = False use_gpu: bool = True + eval_by_true_shorlist: bool = False checkpoint_dir: str = "" cache_dir: str = "" @@ -453,7 +460,25 @@ def download_model(cls, model_shortcut, num_labels=2, hidden_dropout_prob=0.1, c text_model = TransformerLinearXMCHead(config.hidden_size, num_labels) return cls(text_encoder, text_tokenizer, text_model) - def text_to_tensor(self, corpus, max_length=None, **kwargs): + @staticmethod + def _get_tokenizer_config(**kwargs): + """Obtain tokenizer config. + Additional given kwargs will be added to/overwritting the default value. 
+
+        Returns:
+            tokenizer_config (dict)
+        """
+        convert_kwargs = {
+            "add_special_tokens": True,
+            "padding": "max_length",
+            "truncation": True,
+            "return_tensors": "pt",  # return pytorch tensors
+            "return_token_type_ids": True,
+            "return_attention_mask": True,
+        }
+        return {**convert_kwargs, **kwargs}
+
+    def text_to_tensor(self, corpus, max_length=None):
         """Convert input text corpus into padded tensors
 
         Args:
@@ -468,15 +493,6 @@ def text_to_tensor(self, corpus, max_length=None, **kwargs):
                 "token_type_ids": tensor of token type ids,
             }
         """
-        convert_kwargs = {
-            "add_special_tokens": True,
-            "padding": "max_length",
-            "truncation": True,
-            "max_length": max_length,
-            "return_tensors": "pt",  # return pytorch tensors
-            "return_token_type_ids": True,
-            "return_attention_mask": True,
-        }
         # this it to disable the warning message for tokenizer
         # REF: https://github.com/huggingface/transformers/issues/5486
         os.environ["TOKENIZERS_PARALLELISM"] = "true"
@@ -484,14 +500,17 @@ def text_to_tensor(self, corpus, max_length=None, **kwargs):
         t_start = time.time()
         feature_tensors = self.text_tokenizer.batch_encode_plus(
             batch_text_or_text_pairs=corpus,
-            **convert_kwargs,
+            **self._get_tokenizer_config(max_length=max_length),
         )
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+        feature_tensors["instance_number"] = torch.arange(feature_tensors["input_ids"].shape[0])
 
         LOGGER.info("***** Finished with time cost={} *****".format(time.time() - t_start))
 
         return feature_tensors
 
     @staticmethod
-    def _get_label_tensors(M, Y, idx_padding=-1, val_padding=0, max_labels=None):
+    def _get_label_tensors(M, Y, idx_padding=-1, max_labels=None):
         """
         Given matching matrix M and label matrix Y, construct label tensors for XMC training
 
         The non-zero indices of Y are seen as positive labels and therefore all
@@ -515,8 +534,6 @@ def _get_label_tensors(M, Y, idx_padding=-1, val_padding=0, max_labels=None):
                 included.
             idx_padding (int, optional): the index used to pad all label_indices
                 to the same length. Default -1
-            val_padding (float, optional): the value used to fill in
-                label_values corresponding to the zero entrees in Y. Default 0
             max_labels (int, optional): max number of labels considered for each
                 instance, will subsample from existing label indices if need to.
                 Default None to use max row nnz of M.
@@ -557,7 +574,7 @@ def _get_label_tensors(M, Y, idx_padding=-1, val_padding=0, max_labels=None): nr_inst = M1.shape[0] label_indices = np.zeros((nr_inst, max_labels), dtype=np.int64) + idx_padding if Y is not None: - label_values = np.zeros((nr_inst, max_labels), dtype=np.float32) + val_padding + label_values = np.zeros((nr_inst, max_labels), dtype=np.float32) for i in range(nr_inst): offset = 0 @@ -673,17 +690,22 @@ def predict( elif not isinstance(pred_params, TransformerMatcher.PredParams): raise TypeError(f"Unsupported type for pred_params: {type(pred_params)}") - if isinstance(X_text, list): - X_text = self.text_to_tensor( - X_text, - num_workers=kwargs.get("batch_gen_workers", 4), - max_length=pred_params.truncate_length, - ) + if isinstance(X_text, (dict, BatchEncoding)): + nr_inst = X_text["input_ids"].shape[0] + elif isinstance(X_text, list): + nr_inst = len(X_text) + else: + raise ValueError(f"Invalid type for X_text ({type(X_text)})") - nr_inst = X_text["input_ids"].shape[0] max_pred_chunk = kwargs.pop("max_pred_chunk", 10**7) if max_pred_chunk is None or max_pred_chunk >= nr_inst: + if isinstance(X_text, list): + X_text = self.text_to_tensor( + X_text, + max_length=pred_params.truncate_length, + ) + label_pred, embeddings = self._predict( X_text, X_feat=X_feat, @@ -696,8 +718,16 @@ def predict( embedding_chunks = [] P_chunks = [] for i in range(0, nr_inst, max_pred_chunk): + if isinstance(X_text, list): + cur_X_text = self.text_to_tensor( + X_text[i : i + max_pred_chunk], + max_length=pred_params.truncate_length, + ) + else: + cur_X_text = {k: v[i : i + max_pred_chunk] for k, v in X_text.items()} + cur_P, cur_embedding = self._predict( - {k: v[i : i + max_pred_chunk] for k, v in X_text.items()}, + cur_X_text, X_feat=None if X_feat is None else X_feat[i : i + max_pred_chunk, :], csr_codes=None if csr_codes is None else csr_codes[i : i + max_pred_chunk, :], pred_params=pred_params, @@ -724,7 +754,7 @@ def _predict( """Predict with the transformer matcher Args: - X_text (dict): prediction inputs, dictionary of tensors + X_text (dict or BatchEncoding): prediction inputs, dictionary of tensors { "input_ids": tensor of input token ids, "attention_mask": tensor of attention masks, @@ -775,11 +805,11 @@ def _predict( label_indices_pt, label_values_pt = TransformerMatcher._get_label_tensors( csr_codes_next, None, idx_padding=self.text_model.label_pad ) - data = XMCDataset( + data = XMCTensorDataset( X_text["input_ids"], X_text["attention_mask"], X_text["token_type_ids"], - torch.arange(X_text["input_ids"].shape[0]), + X_text["instance_number"], label_values=label_values_pt, label_indices=label_indices_pt, ) @@ -969,20 +999,37 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None): # put text_model to GPU self.text_model.to(self.device) - label_indices_pt, label_values_pt = TransformerMatcher._get_label_tensors( - M_next, - prob.Y, - idx_padding=self.text_model.label_pad, - max_labels=max_act_labels, - ) - train_data = XMCDataset( - prob.X_text["input_ids"], - prob.X_text["attention_mask"], - prob.X_text["token_type_ids"], - torch.arange(prob.X_text["input_ids"].shape[0]), # instance number - label_values=label_values_pt, - label_indices=label_indices_pt, - ) + if prob.is_tokenized: + LOGGER.info("Using XMCTensorDataset for tokenized inputs!") + label_indices_pt, label_values_pt = TransformerMatcher._get_label_tensors( + M_next, + prob.Y, + idx_padding=self.text_model.label_pad, + max_labels=max_act_labels, + ) + train_data = XMCTensorDataset( + 
prob.X_text["input_ids"], + prob.X_text["attention_mask"], + prob.X_text["token_type_ids"], + prob.X_text["instance_number"], + label_values=label_values_pt, + label_indices=label_indices_pt, + ) + else: + LOGGER.info("Using XMCTextDataset for text inputs!") + os.environ["TOKENIZERS_PARALLELISM"] = "false" + train_data = XMCTextDataset( + prob.X_text, + lambda x: self.text_tokenizer( + text=x, + **self._get_tokenizer_config(max_length=pred_params.truncate_length), + ), + feature_keys=["input_ids", "attention_mask", "token_type_ids", "instance_number"], + Y=prob.Y, + M=M_next, + idx_padding=self.text_model.label_pad, + max_labels=max_act_labels, + ) # since number of active labels may vary # using pinned memory will slow down data loading @@ -1062,10 +1109,10 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None): # Start Batch Training LOGGER.info("***** Running training *****") - LOGGER.info(" Num examples = %d", prob.X_text["input_ids"].shape[0]) + LOGGER.info(" Num examples = %d", prob.nr_inst) LOGGER.info(" Num labels = %d", self.nr_labels) if prob.M is not None: - LOGGER.info(" Num active labels per instance = %d", label_indices_pt.shape[1]) + LOGGER.info(" Num active labels per instance = %d", train_data.num_active_labels) LOGGER.info(" Num Epochs = %d", train_params.num_train_epochs) LOGGER.info(" Learning Rate Schedule = %s", train_params.lr_schedule) LOGGER.info(" Batch size = %d", train_params.batch_size) @@ -1083,7 +1130,9 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None): self.text_encoder.zero_grad() self.text_model.zero_grad() for epoch in range(1, int(train_params.num_train_epochs) + 1): - if do_resample and epoch > 1: # redo subsample negative labels + if ( + isinstance(train_data, XMCTensorDataset) and do_resample and epoch > 1 + ): # redo subsample negative labels label_indices_pt, label_values_pt = TransformerMatcher._get_label_tensors( M_next, prob.Y, @@ -1171,10 +1220,12 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None): if val_prob is not None: if val_prob.M is None: test_combos = zip(["all"], [None]) - else: + elif train_params.eval_by_true_shorlist: test_combos = zip( ["trn_ns", "pred_ns"], [val_prob.M, val_csr_codes] ) + else: + test_combos = zip(["pred_ns"], [val_csr_codes]) for val_type, valid_M in test_combos: avr_beam = 1 if valid_M is None else valid_M.nnz / valid_M.shape[0] # compute loss and prediction on test set @@ -1285,8 +1336,8 @@ def train( return_dict (bool): if True, return a dictionary with model and its prediction/embeddings on train/validation dataset. Default False. 
- return_train_pred (bool): if True and return_dict, return prediction matrix on training data - return_train_embeddings (bool): if True and return_dict, return training instance embeddings + return_pred_on_trn (bool): if True and return_dict, return prediction matrix on training data + return_embed_on_trn (bool): if True and return_dict, return training instance embeddings Returns: results (TransformerMatcher or dict): if return_dict=True, return a dictionary: @@ -1336,38 +1387,36 @@ def train( matcher.train_params = train_params matcher.pred_params = pred_params - # tokenize X_text if X_text is given as raw text - saved_trn_pt = kwargs.get("saved_trn_pt", "") - if not prob.is_tokenized: - if saved_trn_pt and os.path.isfile(saved_trn_pt): - trn_tensors = torch.load(saved_trn_pt) - LOGGER.info("trn tensors loaded_from {}".format(saved_trn_pt)) - else: - trn_tensors = matcher.text_to_tensor( - prob.X_text, - num_workers=train_params.batch_gen_workers, - max_length=pred_params.truncate_length, - ) - if saved_trn_pt: - torch.save(trn_tensors, saved_trn_pt) - LOGGER.info("trn tensors saved to {}".format(saved_trn_pt)) - prob.X_text = trn_tensors - - if val_prob is not None and not val_prob.is_tokenized: - saved_val_pt = kwargs.get("saved_val_pt", "") - if saved_val_pt and os.path.isfile(saved_val_pt): - val_tensors = torch.load(saved_val_pt) - LOGGER.info("val tensors loaded from {}".format(saved_val_pt)) - else: - val_tensors = matcher.text_to_tensor( - val_prob.X_text, - num_workers=train_params.batch_gen_workers, - max_length=pred_params.truncate_length, - ) - if saved_val_pt: - torch.save(val_tensors, saved_val_pt) - LOGGER.info("val tensors saved to {}".format(saved_val_pt)) - val_prob.X_text = val_tensors + if train_params.pre_tokenize: + saved_trn_pt = kwargs.get("saved_trn_pt", "") + if not prob.is_tokenized: + if saved_trn_pt and os.path.isfile(saved_trn_pt): + trn_tensors = torch.load(saved_trn_pt) + LOGGER.info("trn tensors loaded_from {}".format(saved_trn_pt)) + else: + trn_tensors = matcher.text_to_tensor( + prob.X_text, + max_length=pred_params.truncate_length, + ) + if saved_trn_pt: + torch.save(trn_tensors, saved_trn_pt) + LOGGER.info("trn tensors saved to {}".format(saved_trn_pt)) + prob.X_text = trn_tensors + + if val_prob is not None and not val_prob.is_tokenized: + saved_val_pt = kwargs.get("saved_val_pt", "") + if saved_val_pt and os.path.isfile(saved_val_pt): + val_tensors = torch.load(saved_val_pt) + LOGGER.info("val tensors loaded from {}".format(saved_val_pt)) + else: + val_tensors = matcher.text_to_tensor( + val_prob.X_text, + max_length=pred_params.truncate_length, + ) + if saved_val_pt: + torch.save(val_tensors, saved_val_pt) + LOGGER.info("val tensors saved to {}".format(saved_val_pt)) + val_prob.X_text = val_tensors bootstrapping = kwargs.get("bootstrapping", None) if bootstrapping is not None: @@ -1413,12 +1462,12 @@ def train( matcher.concat_model = None return_dict = kwargs.get("return_dict", False) - return_train_pred = kwargs.get("return_train_pred", False) and return_dict - return_train_embeddings = kwargs.get("return_train_embeddings", False) and return_dict + return_pred_on_trn = kwargs.get("return_pred_on_trn", False) and return_dict + return_embed_on_trn = kwargs.get("return_embed_on_trn", False) and return_dict P_trn, inst_embeddings = None, None train_concat = pred_params.ensemble_method not in ["transformer-only"] - if train_concat or return_train_pred or return_train_embeddings: + if train_concat or return_pred_on_trn or return_embed_on_trn: # 
getting the instance embeddings of training data # since X_feat is not passed, transformer-only result is produced P_trn, inst_embeddings = matcher.predict( @@ -1499,9 +1548,9 @@ def train( if return_dict: return { "matcher": matcher, - "trn_pred": P_trn if return_train_pred else None, + "trn_pred": P_trn if return_pred_on_trn else None, "val_pred": P_val, - "trn_embeddings": inst_embeddings if return_train_embeddings else None, + "trn_embeddings": inst_embeddings if return_embed_on_trn else None, "val_embeddings": val_inst_embeddings, } else: diff --git a/pecos/xmc/xtransformer/model.py b/pecos/xmc/xtransformer/model.py index b1e4ef0a..36102b89 100644 --- a/pecos/xmc/xtransformer/model.py +++ b/pecos/xmc/xtransformer/model.py @@ -16,7 +16,6 @@ import dataclasses as dc import pecos -import torch from pecos.core import clib from pecos.utils import smat_util, torch_util from pecos.utils.cluster_util import ClusterChain @@ -61,7 +60,6 @@ class TrainParams(pecos.BaseParams): only_encoder (bool, optional): if True, skip linear ranker training. Default False fix_clustering (bool, optional): if True, use the same hierarchial label tree for fine-tuning and final prediction. Default false. max_match_clusters (int, optional): max number of clusters on which to fine-tune transformer. Default 32768 - save_emb_dir (str): dir to save instance embeddings. Default None to ignore """ preliminary_indexer_params: HierarchicalKMeans.TrainParams = None # type: ignore @@ -73,7 +71,6 @@ class TrainParams(pecos.BaseParams): only_encoder: bool = False fix_clustering: bool = False max_match_clusters: int = 32768 - save_emb_dir: str = None # type: ignore @dc.dataclass class PredParams(pecos.BaseParams): @@ -231,7 +228,6 @@ def train( train_params (XTransformer.TrainParams): training parameters for XTransformer pred_params (XTransformer.pred_params): pred parameters for XTransformer kwargs: - label_feat (ndarray or csr_matrix, optional): label features on which to generate preliminary HLT saved_trn_pt (str, optional): path to save the tokenized trn text. Use a tempdir if not given saved_val_pt (str, optional): path to save the tokenized val text. Use a tempdir if not given matmul_threads (int, optional): number of threads to use for @@ -289,17 +285,10 @@ def train( "Downloaded encoder from {}.".format(matcher_train_params.model_shortcut) ) - parent_model.to_device(device, n_gpu=n_gpu) - _, inst_embeddings = parent_model.predict( - prob.X_text, - pred_params=matcher_pred_params, - batch_size=matcher_train_params.batch_size * max(1, n_gpu), - batch_gen_workers=matcher_train_params.batch_gen_workers, - only_embeddings=True, - ) - if val_prob: - _, val_inst_embeddings = parent_model.predict( - val_prob.X_text, + if not train_params.only_encoder: + parent_model.to_device(device, n_gpu=n_gpu) + _, inst_embeddings = parent_model.predict( + prob.X_text, pred_params=matcher_pred_params, batch_size=matcher_train_params.batch_size * max(1, n_gpu), batch_gen_workers=matcher_train_params.batch_gen_workers, @@ -308,18 +297,11 @@ def train( else: # 1. Constructing primary Hierarchial Label Tree if clustering is None: - label_feat = kwargs.get("label_feat", None) - if label_feat is None: - if prob.X_feat is None: - raise ValueError( - "Instance features are required to generate label features!" 
- ) - label_feat = LabelEmbeddingFactory.pifa(prob.Y, prob.X_feat) - clustering = Indexer.gen( - label_feat, + LabelEmbeddingFactory.pifa(prob.Y, prob.X_feat), train_params=train_params.preliminary_indexer_params, ) + else: # assert cluster chain in clustering is valid clustering = ClusterChain(clustering) @@ -337,12 +319,6 @@ def train( ) ) - steps_scale = kwargs.get("steps_scale", None) - if steps_scale is None: - steps_scale = [1.0] * nr_transformers - if len(steps_scale) != nr_transformers: - raise ValueError(f"steps-scale length error: {len(steps_scale)}!={nr_transformers}") - # construct fields with chain now we know the depth train_params = HierarchicalMLModel._duplicate_fields_with_name_ending_with_chain( train_params, cls.TrainParams, nr_transformers @@ -399,10 +375,6 @@ def get_negative_samples(mat_true, mat_pred, scheme): for i in range(nr_transformers): cur_train_params = train_params.matcher_params_chain[i] cur_pred_params = pred_params.matcher_params_chain[i] - cur_train_params.max_steps = steps_scale[i] * cur_train_params.max_steps - cur_train_params.num_train_epochs = ( - steps_scale[i] * cur_train_params.num_train_epochs - ) cur_ns = cur_train_params.negative_sampling @@ -448,13 +420,22 @@ def get_negative_samples(mat_true, mat_pred, scheme): init_text_model = deepcopy(parent_model.text_model) bootstrapping = (init_encoder, inst_embeddings, init_text_model) - # determine whether train prediction and instance embeddings are needed - return_train_pred = ( - i + 1 < nr_transformers - ) and "man" in train_params.matcher_params_chain[i + 1].negative_sampling - return_train_embeddings = ( - i + 1 == nr_transformers - ) or "linear" in cur_train_params.bootstrap_method + # determine whether predictions on training data are needed + if i == nr_transformers - 1: + return_pred_on_trn = False + else: + return_pred_on_trn = any( + "man" in tp.negative_sampling + for tp in train_params.matcher_params_chain[i + 1 :] + ) + + # determine whether train instance embeddings are needed + if i == nr_transformers - 1: + return_embed_on_trn = not train_params.only_encoder + else: + return_embed_on_trn = ( + "linear" in train_params.matcher_params_chain[i + 1].bootstrap_method + ) res_dict = TransformerMatcher.train( cur_prob, @@ -465,8 +446,8 @@ def get_negative_samples(mat_true, mat_pred, scheme): pred_params=cur_pred_params, bootstrapping=bootstrapping, return_dict=True, - return_train_pred=return_train_pred, - return_train_embeddings=return_train_embeddings, + return_pred_on_trn=return_pred_on_trn, + return_embed_on_trn=return_embed_on_trn, saved_trn_pt=saved_trn_pt, saved_val_pt=saved_val_pt, ) @@ -474,22 +455,6 @@ def get_negative_samples(mat_true, mat_pred, scheme): M_pred = res_dict["trn_pred"] val_M_pred = res_dict["val_pred"] inst_embeddings = res_dict["trn_embeddings"] - val_inst_embeddings = res_dict["val_embeddings"] - - if train_params.save_emb_dir: - os.makedirs(train_params.save_emb_dir, exist_ok=True) - if inst_embeddings is not None: - smat_util.save_matrix( - os.path.join(train_params.save_emb_dir, "X.trn.npy"), - inst_embeddings, - ) - LOGGER.info(f"Trn embeddings saved to {train_params.save_emb_dir}/X.trn.npy") - if val_inst_embeddings is not None: - smat_util.save_matrix( - os.path.join(train_params.save_emb_dir, "X.val.npy"), - val_inst_embeddings, - ) - LOGGER.info(f"Val embeddings saved to {train_params.save_emb_dir}/X.val.npy") ranker = None if not train_params.only_encoder: @@ -566,7 +531,6 @@ def predict( Default None to disable overriding post_processor (str, 
optional): override the post_processor specified in the model Default None to disable overriding - saved_pt (str, optional): if given, will try to load encoded tensors and skip text encoding batch_size (int, optional): per device batch size for transformer evaluation. Default 8 batch_gen_workers (int, optional): number of CPUs to use for batch generation. Default 4 use_gpu (bool, optional): use GPU if available. Default True @@ -580,7 +544,6 @@ def predict( if not isinstance(self.concat_model, XLinearModel): raise TypeError("concat_model is not present in current XTransformer model!") - saved_pt = kwargs.get("saved_pt", None) batch_size = kwargs.get("batch_size", 8) batch_gen_workers = kwargs.get("batch_gen_workers", 4) use_gpu = kwargs.get("use_gpu", True) @@ -603,20 +566,10 @@ def predict( encoder_pred_params = pred_params.matcher_params_chain # generate instance-to-cluster prediction - if saved_pt and os.path.isfile(saved_pt): - text_tensors = torch.load(saved_pt) - LOGGER.info("Text tensors loaded_from {}".format(saved_pt)) - else: - text_tensors = self.text_encoder.text_to_tensor( - X_text, - num_workers=batch_gen_workers, - max_length=encoder_pred_params.truncate_length, - ) - pred_csr = None self.text_encoder.to_device(device, n_gpu=n_gpu) _, embeddings = self.text_encoder.predict( - text_tensors, + X_text, pred_params=encoder_pred_params, batch_size=batch_size * max(1, n_gpu), batch_gen_workers=batch_gen_workers, @@ -654,7 +607,6 @@ def encode( XTransformer.PredParams. Default None to use pred_params stored during model training. kwargs: - saved_pt (str, optional): if given, will try to load encoded tensors and skip text encoding batch_size (int, optional): per device batch size for transformer evaluation. Default 8 batch_gen_workers (int, optional): number of CPUs to use for batch generation. Default 4 use_gpu (bool, optional): use GPU if available. Default True @@ -664,7 +616,6 @@ def encode( Returns: embeddings (ndarray): instance embedding on training data, shape = (nr_inst, hidden_dim). """ - saved_pt = kwargs.get("saved_pt", None) batch_size = kwargs.get("batch_size", 8) batch_gen_workers = kwargs.get("batch_gen_workers", 4) use_gpu = kwargs.get("use_gpu", True) @@ -685,19 +636,9 @@ def encode( encoder_pred_params = pred_params.matcher_params_chain # generate instance-to-cluster prediction - if saved_pt and os.path.isfile(saved_pt): - text_tensors = torch.load(saved_pt) - LOGGER.info("Text tensors loaded_from {}".format(saved_pt)) - else: - text_tensors = self.text_encoder.text_to_tensor( - X_text, - num_workers=batch_gen_workers, - max_length=encoder_pred_params.truncate_length, - ) - self.text_encoder.to_device(device, n_gpu=n_gpu) _, embeddings = self.text_encoder.predict( - text_tensors, + X_text, pred_params=encoder_pred_params, batch_size=batch_size * max(1, n_gpu), batch_gen_workers=batch_gen_workers, diff --git a/pecos/xmc/xtransformer/module.py b/pecos/xmc/xtransformer/module.py index 5e31d413..ab35ff1b 100644 --- a/pecos/xmc/xtransformer/module.py +++ b/pecos/xmc/xtransformer/module.py @@ -8,9 +8,15 @@ # or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES # OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions # and limitations under the License. 
+import logging
 import numpy as np
+import torch
 import scipy.sparse as smat
+from pecos.utils import smat_util
 from torch.utils.data import Dataset, TensorDataset
+from transformers import BatchEncoding
+
+LOGGER = logging.getLogger(__name__)
 
 
 class MLProblemWithText(object):
@@ -39,7 +45,7 @@ def __init__(self, X_text, Y, X_feat=None, C=None, M=None):
 
     @property
     def is_tokenized(self):
-        return isinstance(self.X_text, dict)
+        return isinstance(self.X_text, (dict, BatchEncoding))
 
     def type_check(self):
         if self.X_feat is not None and not isinstance(self.X_feat, (smat.csr_matrix, np.ndarray)):
@@ -63,8 +69,12 @@ def nr_features(self):
     def nr_codes(self):
         return None if self.C is None else self.C.shape[1]
 
+    @property
+    def nr_inst(self):
+        return self.Y.shape[0]
+
 
-class XMCDataset(Dataset):
+class XMCTensorDataset(Dataset):
     """Dataset to hold feature and label tensors for XMC training and prediction.
 
     Args:
@@ -95,6 +105,13 @@ def __init__(self, *features, label_values=None, label_indices=None):
         self.label_values = label_values
         self.label_indices = label_indices
 
+    @property
+    def num_active_labels(self):
+        if self.label_indices is None:
+            return None
+        else:
+            return self.label_indices.shape[1]
+
     def __getitem__(self, index):
         if self.label_values is not None and self.label_indices is not None:
             return self.data[index] + (self.label_values[index], self.label_indices[index])
@@ -112,3 +129,132 @@ def refresh_labels(self, label_values=None, label_indices=None):
         """Refresh label-values and label-indices from given tensors"""
         self.label_values = label_values
         self.label_indices = label_indices
+
+
+class XMCTextDataset(Dataset):
+    """Dataset to hold text and label/matching matrices for XMC training and prediction.
+    Conduct real-time tokenization of input text and label tensor generation to save memory.
+
+    Args:
+        text (list of str): input text, length = nr_inst
+        input_transform (function): the transform function to process/tokenize text
+        feature_keys (list of str): the feature keys in order for batch generation.
+        Y (csr_matrix, optional): training labels, shape = (nr_inst, nr_labels)
+        M (csr_matrix, optional): matching matrix, shape = (nr_inst, nr_codes)
+            model will be trained only on its non-zero indices
+            its values will not be used.
+        idx_padding (int, optional): the index used to pad all label_indices
+            to the same length. Default -1
+        max_labels (int, optional): max number of labels considered for each
+            instance, will subsample from existing label indices if need to.
+            Default None to ignore.
+
+
+    Return values depend on the Y and M:
+        1. Both Y and M are not None (train on middle layer):
+            data[i] = (feature[0][i], feature[1][i], ..., label_values[i], label_indices[i])
+        2. Both Y and M are None (inference on top layer):
+            data[i] = (feature[0][i], feature[1][i], ...)
+        3. Y is not None, M is None (train on top layer):
+            data[i] = (feature[0][i], feature[1][i], ..., label_values[i])
+        4. 
Y is None, M is not None (inference on middle layer): + data[i] = (feature[0][i], feature[1][i], ..., label_indices[i]) + """ + + def __init__( + self, + text, + input_transform, + feature_keys, + Y=None, + M=None, + idx_padding=-1, + max_labels=None, + ): + self.text = text + self.input_transform = input_transform + self.feature_keys = feature_keys + self.idx_padding = idx_padding + + self.lbl_mat = None + self.has_label = Y is not None + self.has_ns = M is not None + + self.offset = 0 + + if M is None and Y is None: + # 1.inference at top layer + self.label_width = None + elif M is not None and Y is None: + # 2.inference at intermediate layer + self.label_width = max(M.indptr[1:] - M.indptr[:-1]) + self.lbl_mat = smat_util.binarized(M) + elif M is None and Y is not None: + # 3.train at top layer + self.label_width = Y.shape[1] + self.lbl_mat = Y.astype(np.float32) + elif M is not None and Y is not None: + # 4.train at intermediate layer + self.lbl_mat = smat_util.binarized(M) + smat_util.binarized(Y) + self.label_width = max(self.lbl_mat.indptr[1:] - self.lbl_mat.indptr[:-1]) + # put values in M, positive labels equal to y + offset, negative to offset + # offset is used to avoid elimination of zero entrees + self.offset = Y.data.max() + 10.0 + self.lbl_mat.data[:] = self.offset + self.lbl_mat += Y + + if self.label_width is not None and max_labels is not None: + if self.label_width > max_labels: + LOGGER.warning(f"will need to sub-sample from {self.label_width} to {max_labels}") + self.label_width = max_labels + + if Y is not None: + label_lower_bound = max(Y.indptr[1:] - Y.indptr[:-1]) + if label_lower_bound > self.label_width: + LOGGER.warning( + f"label-width ({self.label_width}) is not able to cover all positive labels ({label_lower_bound})!" + ) + + def __len__(self): + return len(self.text) + + @property + def num_active_labels(self): + return self.label_width + + def get_input_tensors(self, i): + ret = self.input_transform(self.text[i]) + ret["instance_number"] = torch.IntTensor([i]) + return tuple(ret[kk].squeeze(dim=0) for kk in self.feature_keys) + + def get_output_tensors(self, i): + if not self.has_ns: + if not self.has_label: + return tuple() + else: + return (torch.FloatTensor(self.lbl_mat[i].toarray()).squeeze(dim=0),) + else: + nr_active = self.lbl_mat.indptr[i + 1] - self.lbl_mat.indptr[i] + rng = slice(self.lbl_mat.indptr[i], self.lbl_mat.indptr[i + 1]) + + if nr_active > self.label_width: + # sub-sample to fit in self.label_width + nr_active = self.label_width + rng = np.random.choice( + np.arange(self.lbl_mat.indptr[i], self.lbl_mat.indptr[i + 1]), + nr_active, + replace=False, + ) + + label_indices = torch.zeros((self.label_width,), dtype=torch.int) + self.idx_padding + label_indices[:nr_active] = torch.from_numpy(self.lbl_mat.indices[rng]) + + if not self.has_label: + return (label_indices,) + else: + label_values = torch.zeros((self.label_width,), dtype=torch.float32) + label_values[:nr_active] = torch.from_numpy(self.lbl_mat.data[rng] - self.offset) + return (label_values, label_indices) + + def __getitem__(self, index): + return self.get_input_tensors(index) + self.get_output_tensors(index) diff --git a/pecos/xmc/xtransformer/network.py b/pecos/xmc/xtransformer/network.py index c4fba78b..1989d45b 100644 --- a/pecos/xmc/xtransformer/network.py +++ b/pecos/xmc/xtransformer/network.py @@ -38,6 +38,7 @@ from transformers.models.bert.modeling_bert import BERT_INPUTS_DOCSTRING, BERT_START_DOCSTRING from transformers.models.roberta.modeling_roberta import ( + 
RobertaPreTrainedModel, ROBERTA_INPUTS_DOCSTRING, ROBERTA_START_DOCSTRING, ) @@ -269,7 +270,7 @@ def forward( """Roberta Model with mutli-label classification head on top for XMC.\n""", ROBERTA_START_DOCSTRING, ) -class RobertaForXMC(BertPreTrainedModel): +class RobertaForXMC(RobertaPreTrainedModel): """ Examples: tokenizer = RobertaTokenizer.from_pretrained('roberta-base') diff --git a/pecos/xmc/xtransformer/train.py b/pecos/xmc/xtransformer/train.py index 6ff9cc14..bc303f50 100644 --- a/pecos/xmc/xtransformer/train.py +++ b/pecos/xmc/xtransformer/train.py @@ -13,12 +13,13 @@ import logging import os import sys +import gc import numpy as np from pecos.utils import cli, logging_util, smat_util, torch_util from pecos.utils.cluster_util import ClusterChain from pecos.utils.featurization.text.preprocess import Preprocessor -from pecos.xmc import PostProcessor +from pecos.xmc import PostProcessor, Indexer, LabelEmbeddingFactory from .matcher import TransformerMatcher from .model import XTransformer @@ -343,14 +344,6 @@ def parse_arguments(): type=int, help="if > 0: set total number of training steps to perform for each sub-task. Overrides num-train-epochs.", ) - parser.add_argument( - "--steps-scale", - nargs="+", - type=float, - default=None, - metavar="FLOAT", - help="scale number of transformer fine-tuning steps for each layer. Default None to ignore", - ) parser.add_argument( "--max-no-improve-cnt", type=int, @@ -408,13 +401,6 @@ def parse_arguments(): type=int, help="Upper limit on labels to put output layer in GPU. Default 65536", ) - parser.add_argument( - "--save-emb-dir", - default=None, - metavar="PATH", - type=str, - help="dir to save the final instance embeddings.", - ) parser.add_argument( "--use-gpu", type=cli.str2bool, @@ -537,8 +523,7 @@ def do_train(args): else: tst_corpus = None - # load cluster chain or label features - cluster_chain, label_feat = None, None + # load cluster chain if os.path.exists(args.code_path): cluster_chain = ClusterChain.from_partial_chain( smat_util.load_matrix(args.code_path), @@ -554,6 +539,15 @@ def do_train(args): label_feat.shape, args.label_feat_path ) ) + else: + label_feat = LabelEmbeddingFactory.pifa(Y_trn, X_trn) + + cluster_chain = Indexer.gen( + label_feat, + train_params=train_params.preliminary_indexer_params, + ) + del label_feat + gc.collect() trn_prob = MLProblemWithText(trn_corpus, Y_trn, X_feat=X_trn) if all(v is not None for v in [tst_corpus, Y_tst]): @@ -568,8 +562,6 @@ def do_train(args): train_params=train_params, pred_params=pred_params, beam_size=args.beam_size, - steps_scale=args.steps_scale, - label_feat=label_feat, ) xtf.save(args.model_dir) diff --git a/test/pecos/xmc/xtransformer/test_xtransformer.py b/test/pecos/xmc/xtransformer/test_xtransformer.py index c8ee2d27..088251ea 100644 --- a/test/pecos/xmc/xtransformer/test_xtransformer.py +++ b/test/pecos/xmc/xtransformer/test_xtransformer.py @@ -148,7 +148,6 @@ def test_encode(tmpdir): model_folder = tmpdir.join("only_encoder") emb_path = model_folder.join("embeddings.npy") - save_emb_path = model_folder.join("X.trn.npy") emb_path_B1 = model_folder.join("embeddings_B1.npy") # Training matcher @@ -164,7 +163,6 @@ def test_encode(tmpdir): cmd += ["--save-steps {}".format(2)] cmd += ["--only-topk {}".format(2)] cmd += ["--batch-gen-workers {}".format(2)] - cmd += ["--save-emb-dir {}".format(str(model_folder))] cmd += ["--only-encoder true"] process = subprocess.run( shlex.split(" ".join(cmd)), stdout=subprocess.PIPE, stderr=subprocess.PIPE @@ -184,9 +182,7 @@ def 
test_encode(tmpdir): ) assert process.returncode == 0, " ".join(cmd) - X_emb_save = np.load(str(save_emb_path)) X_emb_pred = np.load(str(emb_path)) - assert X_emb_pred == approx(X_emb_save, abs=1e-6) # encode with max_pred_chunk=1 cmd = [] @@ -203,7 +199,7 @@ def test_encode(tmpdir): assert process.returncode == 0, " ".join(cmd) X_emb_pred_B1 = np.load(str(emb_path_B1)) - assert X_emb_pred_B1 == approx(X_emb_save, abs=1e-6) + assert X_emb_pred_B1 == approx(X_emb_pred, abs=1e-6) def test_xtransformer_python_api():
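
Note (not part of the patch): the memory saving above comes from tokenizing text on the fly inside the Dataset (XMCTextDataset) instead of materializing padded tensors for the entire corpus up front; the old behavior remains available through the new pre_tokenize training parameter. The snippet below is a minimal, self-contained sketch of that on-the-fly pattern outside PECOS. The class name OnTheFlyTextDataset, the sample texts, and the bert-base-uncased checkpoint are illustrative assumptions, not code from this patch.

# Sketch of on-the-fly tokenization: only one batch of padded tensors is
# resident in memory at a time, instead of tensors for the whole corpus.
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer


class OnTheFlyTextDataset(Dataset):
    """Tokenize raw text inside __getitem__ so padded tensors are created per batch."""

    def __init__(self, texts, tokenizer, max_length=128):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, i):
        enc = self.tokenizer(
            self.texts[i],
            add_special_tokens=True,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            return_token_type_ids=True,
            return_attention_mask=True,
        )
        # squeeze the batch dimension added by return_tensors="pt"
        return (
            enc["input_ids"].squeeze(0),
            enc["attention_mask"].squeeze(0),
            enc["token_type_ids"].squeeze(0),
            torch.tensor(i),  # instance number, mirroring the patch's feature key
        )


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    texts = ["an example document", "another example document"]
    loader = DataLoader(OnTheFlyTextDataset(texts, tokenizer), batch_size=2)
    for input_ids, attention_mask, token_type_ids, inst_num in loader:
        print(input_ids.shape, attention_mask.shape, inst_num)

The trade-off is the same one the patch documents for pre_tokenize: tokenizing everything in advance can make batch generation faster, but costs memory proportional to corpus size; tokenizing in __getitem__ keeps memory flat at the price of repeated tokenizer calls per epoch.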