Skip to content

Commit

Permalink
python code for sampling of distributed clustering
Browse files Browse the repository at this point in the history
  • Loading branch information
yaushian authored and EC2 Default User committed Feb 9, 2023
1 parent 742e097 commit 7c0cd96
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 17 deletions.
1 change: 0 additions & 1 deletion pecos/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ class TfidfVectorizerParam(ctypes.Structure):
]

def __init__(self, base_vect_param_list, norm_p):

self.num_base_vect = len(base_vect_param_list)
self.c_base_params = (TfidfBaseVectorizerParam * self.num_base_vect)()
for i, base_vect_param in enumerate(base_vect_param_list):
Expand Down
58 changes: 55 additions & 3 deletions pecos/distributed/xmc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from scipy.sparse import csr_matrix, csc_matrix
from pecos.distributed.comm.abs_dist_comm import DistComm
from pecos.utils.profile_util import MemInfo
import math
from copy import deepcopy


LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -359,9 +361,29 @@ def _train_meta_cluster(self, X, Y):
)

LOGGER.info("Starts generating meta tree cluster on main node...")
meta_indexer_params = self._indexer_params.to_dict()
meta_indexer_params["max_leaf_size"] = self._get_meta_tree_max_leaf_size(Y.shape[1])
meta_cluster_chain = Indexer.gen(label_feat, **meta_indexer_params)
meta_indexer_params = deepcopy(self._indexer_params)
meta_indexer_params.max_leaf_size = self._get_meta_tree_max_leaf_size(Y.shape[1])

if self._indexer_params.do_sample:
total_binary_indexer_depth = self._indexer_params.infer_binary_tree_depth(
label_feat.shape[0],
)
total_binary_warmup_layers = math.floor(
self._indexer_params.warmup_ratio * total_binary_indexer_depth
)
self._meta_binary_indexer_depth = meta_indexer_params.infer_binary_tree_depth(
label_feat.shape[0],
)
meta_indexer_params.warmup_ratio = min(
1, total_binary_warmup_layers / self._meta_binary_indexer_depth
)

meta_indexer_params.max_sample_rate = self._indexer_params.get_layer_sample_rate(
self._meta_binary_indexer_depth - 1,
total_binary_indexer_depth,
)

meta_cluster_chain = Indexer.gen(label_feat, train_params=meta_indexer_params)
LOGGER.info(f"Done generating meta tree cluster." f" {MemInfo.mem_info()}")

return meta_cluster_chain
Expand Down Expand Up @@ -394,6 +416,32 @@ def _train_sub_clusters(self, self_sub_tree_assign_arr_list, X, Y):
LOGGER.info(
f"Starts generating {idx}th sub-tree cluster on rank {self._dist_comm.get_rank()}..."
)

if self._indexer_params.do_sample:
sub_binary_indexer_depth = self._indexer_params.infer_binary_tree_depth(
label_feat.shape[0],
)

total_binary_indexer_depth = (
sub_binary_indexer_depth + self._meta_binary_indexer_depth
)
total_binary_warmup_layers = math.floor(
self._indexer_params.warmup_ratio * total_binary_indexer_depth
)

if self._meta_binary_indexer_depth >= total_binary_warmup_layers:
self._indexer_params.min_sample_rate = (
self._indexer_params.get_layer_sample_rate(
self._meta_binary_indexer_depth - 1,
total_binary_indexer_depth,
)
)
self._indexer_params.warmup_ratio = 0
else:
self._indexer_params.warmup_ratio = float(
total_binary_warmup_layers - self._meta_binary_indexer_depth
) / float(sub_binary_indexer_depth)

cluster_chain = Indexer.gen(label_feat, train_params=self._indexer_params)
LOGGER.info(
f"Done generating {idx}th sub-tree cluster on rank {self._dist_comm.get_rank()}."
Expand Down Expand Up @@ -439,12 +487,16 @@ def dist_get_cluster_chain(self, X, Y):

# Create meta tree cluster chain on main node
grp_sub_tree_assign_arr_list = None
self._meta_binary_indexer_depth = None
if self._dist_comm.get_rank() == 0:
meta_cluster_chain = self._train_meta_cluster(X, Y)
# Get sub-tree assignment arrays list for leaf cluster layer of meta-tree
sub_tree_assign_arr_list = smat_util.get_csc_col_nonzero(meta_cluster_chain[-1])
# Divide into n_machine groups to scatter
grp_sub_tree_assign_arr_list = self._divide_sub_cluster_jobs(sub_tree_assign_arr_list)
self._meta_binary_indexer_depth = self._dist_comm.bcast(
self._meta_binary_indexer_depth, root=0
)

# Create sub-tree cluster chain on all nodes
self_sub_tree_assign_arr_list = self._dist_comm.scatter(
Expand Down
1 change: 0 additions & 1 deletion pecos/utils/featurization/text/sentencepiece/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def predict(args):


if __name__ == "__main__":

parser = argparse.ArgumentParser(description="SentencePiece: tokenize text")

parser.add_argument(
Expand Down
1 change: 0 additions & 1 deletion pecos/utils/featurization/text/sentencepiece/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def train(args):


if __name__ == "__main__":

parser = argparse.ArgumentParser(description="SentencePiece: train tokenization model")

parser.add_argument(
Expand Down
37 changes: 30 additions & 7 deletions pecos/xmc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class TrainParams(pecos.BaseParams): # type: ignore
We use linear sampling strategy with warmup, which linearly increases sampling rate from `min_sample_rate` to `max_sample_rate`.
The top (total_layers * `warmup_ratio`) layers are warmup layers, which use a fixed sampling rate `min_sample_rate`.
The sampling rate for layer l is `min_sample_rate` + max(l + 1 - warmup_layers, 0) * (`max_sample_rate` - `min_sample_rate`) / (total_layers - warmup_layers).
Please refer to 'self.get_layer_sample_rate()' function for complete definition.
max_sample_rate (float, optional): the maximum sampling rate at the end of the linear sampling strategy. Default is `1.0`.
min_sample_rate (float, optional): the minimum sampling rate at the beginning warmup stage of the linear sampling strategy. Default is `0.1`.
Note that 0 < min_sample_rate <= max_sample_rate <= 1.0.
Expand All @@ -117,6 +118,33 @@ class TrainParams(pecos.BaseParams): # type: ignore
min_sample_rate: float = 0.1
warmup_ratio: float = 0.4

def infer_binary_tree_depth(self, label_num):
    """Infer the depth of the binary tree needed to index `label_num` labels.

    The depth is the smallest integer d >= 1 such that 2**d leaf clusters of
    `self.max_leaf_size` labels each can hold all `label_num` labels.

    Args:
        label_num (int): the number of labels to be placed in the binary tree.

    Returns:
        int: the depth of the binary tree (always >= 1).

    Raises:
        ValueError: if `label_num` is not positive, or if the tree would have
            more leaf clusters than labels (2**depth > label_num).
    """
    # math.log2 rejects non-positive input with an opaque "math domain error";
    # fail early with a message that names the offending argument instead.
    if label_num <= 0:
        raise ValueError(f"label_num must be a positive integer, got {label_num}")

    # math.ceil already returns an int in Python 3, so no extra cast is needed.
    depth = max(1, math.ceil(math.log2(label_num / self.max_leaf_size)))
    if (2**depth) > label_num:
        raise ValueError(
            f"max_leaf_size > 1 is needed for feat_mat.shape[0] == {label_num} to avoid empty clusters"
        )
    return depth

def get_layer_sample_rate(self, cur_layer, binary_tree_depth):
    """Return the sampling rate used at layer `cur_layer` of a binary tree.

    The first int(`warmup_ratio` * `binary_tree_depth`) layers are warmup layers
    and use `min_sample_rate`. After warmup the rate grows linearly with the
    layer index and equals `max_sample_rate` at the last layer.
    The definition of this function corresponds to pecos/core/utils/clustering.hpp "ClusteringSampler" struct

    Args:
        cur_layer (int): 0-based layer index in a binary tree.
        binary_tree_depth (int): the depth of a binary tree.

    Returns:
        float: the sampling rate for `cur_layer`.
    """
    n_warmup = int(self.warmup_ratio * binary_tree_depth)
    if cur_layer >= n_warmup:
        rate_span = self.max_sample_rate - self.min_sample_rate
        # Multiply before dividing so the floating-point result is unchanged.
        scaled = rate_span * float(cur_layer + 1 - n_warmup)
        return self.min_sample_rate + scaled / float(binary_tree_depth - n_warmup)
    # Still inside the warmup prefix: fixed minimum rate.
    return self.min_sample_rate

@classmethod
def gen(
cls,
Expand Down Expand Up @@ -160,12 +188,7 @@ def gen(
smat.csc_matrix(np.ones((nr_instances, 1), dtype=np.float32))
)

depth = max(1, int(math.ceil(math.log2(nr_instances / train_params.max_leaf_size))))
if (2**depth) > nr_instances:
raise ValueError(
f"max_leaf_size > 1 is needed for feat_mat.shape[0] == {nr_instances} to avoid empty clusters"
)
train_params.depth = depth
train_params.depth = train_params.infer_binary_tree_depth(nr_instances)

partition_algo = cls.SKMEANS if train_params.spherical else cls.KMEANS
train_params.partition_algo = partition_algo
Expand All @@ -186,7 +209,7 @@ def gen(
train_params,
codes,
)
C = cls.convert_codes_to_csc_matrix(codes, depth)
C = cls.convert_codes_to_csc_matrix(codes, train_params.depth)
cluster_chain = ClusterChain.from_partial_chain(
C, min_codes=train_params.min_codes, nr_splits=train_params.nr_splits
)
Expand Down
3 changes: 0 additions & 3 deletions pecos/xmc/xtransformer/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def nr_inst(self):


class XMCTextTensorizer(object):

DEFAULT_FEATURE_KEYS = ["input_ids", "attention_mask", "token_type_ids", "instance_number"]

def __init__(self, text, feature_keys=None, input_transform=None):
Expand Down Expand Up @@ -167,7 +166,6 @@ def __init__(
max_labels=None,
pre_compute=False,
):

self.label_padding_idx = label_padding_idx
self.has_label = Y is not None
self.has_ns = M is not None
Expand Down Expand Up @@ -230,7 +228,6 @@ def num_active_labels(self):
return self.label_width

def get_lbl_mat(self, M, Y, max_labels=None):

if M is None and Y is None:
# 1.inference at top layer
self.label_width = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def test_unigram_model(tmpdir):


def test_cli_bpe_model(tmpdir, capsys):

test_input_file = tmpdir.join("test_input")
test_input_file.write_text("hello world", encoding="utf-8")
test_encoded_file = tmpdir.join("test_encoded")
Expand Down

0 comments on commit 7c0cd96

Please sign in to comment.