From 4cfb9d57ed32b98994954792f351fb439b85540d Mon Sep 17 00:00:00 2001
From: Jiong Zhang
Date: Fri, 15 Oct 2021 05:17:16 -0700
Subject: [PATCH] Text2text refactor (#78)

* text2text refactor and update robin_hood version
---
 THIRD-PARTY-LICENSES.txt | 2 +-
 pecos/apps/text2text/README.md | 21 +
 pecos/apps/text2text/model.py | 377 ++++-----
 pecos/apps/text2text/train.py | 127 +--
 .../robin_hood_hashing/robin_hood.h | 796 ++++++++----------
 5 files changed, 586 insertions(+), 737 deletions(-)

diff --git a/THIRD-PARTY-LICENSES.txt b/THIRD-PARTY-LICENSES.txt
index b2970c39..800bfb3f 100644
--- a/THIRD-PARTY-LICENSES.txt
+++ b/THIRD-PARTY-LICENSES.txt
@@ -68,7 +68,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 ------
 
-** Robin Hood Hashing; version 3.9.0 --
+** Robin Hood Hashing; version 3.11.3 --
 https://github.com/martinus/robin-hood-hashing
 Copyright (c) 2018-2019 Martin Ankerl
 
diff --git a/pecos/apps/text2text/README.md b/pecos/apps/text2text/README.md
index dd84b7aa..5af78e89 100644
--- a/pecos/apps/text2text/README.md
+++ b/pecos/apps/text2text/README.md
@@ -105,6 +105,27 @@ python3 -m pecos.apps.text2text.predict \
   --predicted-output-item-path ./test-prediction.txt
 ```
 The predictions are saved in the `./test-prediction.txt`.
+
+### Advanced Usage: Train/Pred params via JSON input
+`pecos.apps.text2text` supports accepting training and prediction parameters from an input JSON file.
+Moreover, `python3 -m pecos.apps.text2text.train` provides the option to generate all parameters in JSON format to stdout.
+
+You can generate train/pred parameter skeletons as `.json` files
+with all of the parameters, which you can then edit and fill in.
+```bash
+  > python3 -m pecos.apps.text2text.train --generate-train-params-skeleton &> train_params.json
+  > python3 -m pecos.apps.text2text.train --generate-pred-params-skeleton &> pred_params.json
+```
+After editing the `train_params.json` and `pred_params.json` files, you can do training via:
+```bash
+  > python3 -m pecos.apps.text2text.train \
+    --input-text-path ./training-data.txt \
+    --vectorizer-config-path ./config.json \
+    --output-item-path ./output-labels.txt \
+    --model-folder ./pecos-text2text-model \
+    --train-params-path train_params.json \
+    --pred-params-path pred_params.json
+```
 
 ***
 Copyright (2021) Amazon.com, Inc.
diff --git a/pecos/apps/text2text/model.py b/pecos/apps/text2text/model.py
index c5349663..67cb9121 100644
--- a/pecos/apps/text2text/model.py
+++ b/pecos/apps/text2text/model.py
@@ -8,9 +8,9 @@
 # or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 # OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 # and limitations under the License.
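The README hunk above documents the CLI flow for the parameter skeletons; the same objects can also be produced and edited programmatically. A minimal sketch, assuming this refactored `pecos` is installed: the edited field is illustrative (its name comes from the `--max-leaf-size` CLI flag and is assumed to sit under `indexer_params`), and the real skeleton may carry additional metadata keys not shown here.

```python
import json

from pecos.apps.text2text.model import Text2Text

# The same object that --generate-train-params-skeleton prints to stdout:
# every field filled with its default value.
train_params = Text2Text.TrainParams.from_dict({}, recursive=True)
skeleton = train_params.to_dict()

# Illustrative edit: bump the clustering leaf size inside the nested
# HierarchicalKMeans.TrainParams block (assumed field location).
skeleton["indexer_params"]["max_leaf_size"] = 128

with open("train_params.json", "w", encoding="utf-8") as fout:
    json.dump(skeleton, fout, indent=True)
```

The edited file is then consumed via `--train-params-path`, which `train.py` loads back with `Text2Text.TrainParams.from_dict(json.load(fin))`.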
+import dataclasses as dc import gc import hashlib -import itertools import json import logging import pathlib @@ -18,11 +18,13 @@ from os import makedirs, path import numpy as np +import pecos from pecos.utils import smat_util from pecos.utils.cluster_util import ClusterChain from pecos.utils.featurization.text.preprocess import Preprocessor from pecos.xmc import Indexer, LabelEmbeddingFactory from pecos.xmc.xlinear import XLinearModel +from pecos.xmc.base import HierarchicalKMeans LOGGER = logging.getLogger(__name__) @@ -89,6 +91,48 @@ def __init__(self, preprocessor=None, xlinear_models=None, output_items=None): self.xlinear_models = xlinear_models self.output_items = output_items + @dc.dataclass + class TrainParams(pecos.BaseParams): + """Training Parameters of Text2Text. + + indexer_params (HierarchicalKMeans.TrainParams): params to generate hierarchial label tree. + xlinear_params (XLinearModel.TrainParams): train params for XLinearModel + """ + + indexer_params: HierarchicalKMeans.TrainParams = None # type: ignore + xlinear_params: XLinearModel.TrainParams = None # type: ignore + + @dc.dataclass + class PredParams(pecos.BaseParams): + """Pred Parameters of XTransformer. + + xlinear_params (XLinearModel.PredParams): pred params for linear ranker + """ + + xlinear_params: XLinearModel.PredParams = None # type: ignore + + def override_with_kwargs(self, pred_kwargs): + """override pred_params with kwargs. + + Args: + pred_kwargs: + beam_size (int): the beam search size. + Overrides only_topk for all models except for the bottom one. + only_topk (int): the final topk predictions to generate. + Overrides only_topk for bottom model. + post_processor (str): post processor scheme for prediction. + Overrides post_processor for all models. + """ + self.xlinear_params.override_with_kwargs(pred_kwargs) + return self + + def get_pred_params(self): + """get pred_params saved in the model""" + ret_pred_params = self.PredParams( + xlinear_params=self.xlinear_models.get_pred_params(), + ) + return ret_pred_params + def save(self, model_folder): """Save the Text2Text model @@ -129,6 +173,12 @@ def load(cls, model_folder, is_predict_only=False, **kwargs): for i, model_kwargs in enumerate(ensemble_config["kwargs"]): folder = path.join(xlinear_folder, "{}".format(i)) xlinear_models += [(XLinearModel.load(folder, is_predict_only, **kwargs), model_kwargs)] + + if len(xlinear_models) > 1: + LOGGER.warning( + f"Deprecation warning: loaded {len(xlinear_models)} xlinear models, multi-model ensemble prediction will be deprecated in future releases." + ) + output_items = None folder_path = pathlib.Path(model_folder) json_output_items_filepath = folder_path / "output_items.json" @@ -144,28 +194,12 @@ def train( cls, input_text_path, output_text_path, + label_embed_type="pifa", vectorizer_config=None, - dtype=np.float32, - label_embed_type=["pifa"], - indexer_algo=["hierarchicalkmeans"], - imbalanced_ratio=0.0, - imbalanced_depth=100, - spherical=True, - nr_splits=16, - max_leaf_size=[100], - seed=[0], - max_iter=[20], - solver_type=["L2R_L2LOSS_SVC_DUAL"], - Cp=[1.0], - Cn=[1.0], - bias=1.0, - threshold=[0.1], - negative_sampling_scheme="tfn", - pred_kwargs=None, - rel_mode="disable", - rel_norm="no-norm", - threads=-1, + train_params=None, + pred_params=None, workspace_folder=None, + **kwargs, ): """Train a Text2Text model @@ -179,50 +213,24 @@ def train( output_text_path (str): The file path for output text items. Format: each line corresponds to a representation of the output item. 
We assume utf-8 encoding for text. - vectorizer_config_json (str): Json_format string for vectorizer config (default None) - dtype (float32 | float64): data type (default float32) label_embed_type (list of str): Label embedding types. (default pifa). We support pifa, pifa_lf_concat::Z=path, and pifa_lf_convex_combine::Z=path::alpha=scalar_value. Multiple values will lead to different individual models for ensembling. - indexer_algo (list of str): Indexer algorithm (default ["hierarchicalkmeans"]). - imbalanced_ratio (float): Value between 0.0 and 0.5 (inclusive). Indicates how relaxed the balancedness - constraint of 2_means can be. Specifically, if an iteration of 2_means is clustering L labels, - the size of the output 2 clusters will be within approx imbalanced_ratio * 2 * L of each other. - (default 0.0) - imbalanced_depth (int): After hierarchical 2_means clustering has reached this depth, - it will continue clustering as if imbalanced_ratio is set to 0.0. (default 100) - spherical (bool): Do l2_normalize cluster centers while clustering (default True). - nr_splits (int): number of splits used to construct hierarchy (a power of 2 is recommended, default 16) - max_leaf_size (list of int): The max size of the leaf nodes of hierarchical 2_means clustering. - Multiple values (separated by comma) are supported and will lead to different - individual models for ensembling. (default [100]) - seed (list of int): Random seeds (default [0]). Multiple values will lead to different individual - models for ensembling. - max_iter (int): The max iteration for indexing (default 20) - solver_type (list of string): solver type for ranking (default ["L2R_L2LOSS_SVC_DUAL"]) - Cp (float): Coefficient for positive class in the loss function (default 1.0) - Cn (float): Coefficient for negative class in the loss function (default 1.0) - bias (float): bias for the ranking model (default=1.0) - threshold (float): Threshold to sparsify the model weights (default 0.1) - negative_sampling (str, choices=[tfn, man, tfn+man]): Negative Sampling Schemes (default tfn) - pred_kwargs (dict): kwargs for prediction used in matching-aware training - only_topk (int): the default number of top labels used in the prediction - beam_size (int): the default size of beam search used in the prediction - post_processor (str): the default post processor used in the prediction - rel_mode (bool, optional): mode to use relevance score for cost sensitive learning - 'disable': do not use cost-sensitive learning (default) - 'induce': induce relevance matrix into relvance chain by label aggregation. - Use all 1.0 if relevance score is not provided. - 'ranker-only': only use cost-sensitive learning for ranker. 
- rel_norm (str): norm type to row-wise normalzie relevance matrix for cost-sensitive learning + vectorizer_config_json (str): Json_format string for vectorizer config (default None) + train_params (Text2Text.TrainParams): params to train Text2Text model + pred_params (Text2Text.PredParams): params to predict Text2Text model workspace_folder: (str, default=None): A folder name for storing intermediate variables during training + kwargs: + {"beam_size": INT, "only_topk": INT, "post_processor": STR}, + Default None to use HierarchicalMLModel.PredParams defaults Returns: A Text2Text object """ ws = CachedWorkspace(workspace_folder) + dtype = np.float32 # Train Preprocessor and obtain X, Y XY_kwargs = dict( @@ -293,157 +301,98 @@ def train( if R is not None: LOGGER.info(f"Relevance matrix R loaded, cost sensitive learning enabled.") - # Grid Parameters for XLinearModel - ranker_param_names = [ - "bias", - "Cp", - "Cn", - "solver_type", - "threshold", - "negative_sampling_scheme", - "pred_kwargs", - "rel_mode", - "rel_norm", - ] - - ranker_grid_params = {} - for name in ranker_param_names: - tmp = locals()[name] - ranker_grid_params[name] = tmp if isinstance(tmp, (list, tuple)) else [tmp] - - indexer_param_names = [ - "indexer_algo", - "imbalanced_ratio", - "imbalanced_depth", - "spherical", - "seed", - "max_iter", - "max_leaf_size", - "nr_splits", - "label_embed_type", - ] - - indexer_grid_params = {} - for name in indexer_param_names: - tmp = locals()[name] - indexer_grid_params[name] = tmp if isinstance(tmp, (list, tuple)) else [tmp] - - # Generate various label features - label_feat_set = {} - for embed_type in indexer_grid_params["label_embed_type"]: - label_embed_kwargs = dict( - input_text_path=input_text_path, - output_text_path=output_text_path, - dtype=str(dtype), - vectorizer_config=vectorizer_config, - embed_type=embed_type, - ) - label_embed_path = ws.get_path_for_name_and_kwargs("L", label_embed_kwargs) - if path.exists(label_embed_path): - LOGGER.info(f"Loading existing {embed_type} features for {Y.shape[1]} labels...") - label_feat_set[embed_type] = XLinearModel.load_feature_matrix(label_embed_path) - else: - LOGGER.info(f"Generating {embed_type} features for {Y.shape[1]} labels...") - # parse embed_type string, expect either the following three cases: - # (1) pifa - # (2) pifa_lf_concat::Z=path - # (3) pifa_lf_convex_combine::Z=path::alpha=value - lemb_key_val_list = embed_type.split("::") - lemb_type = lemb_key_val_list[0] - lemb_kwargs = {} - for key_val_str in lemb_key_val_list[1:]: - key, val = key_val_str.split("=") - if key == "Z": - Z = smat_util.load_matrix(val) - lemb_kwargs.update({"Z": Z}) - elif key == "alpha": - alpha = float(val) - lemb_kwargs.update({"alpha": alpha}) - else: - raise ValueError(f"key={key}, val={val} is not supported!") - if "lf" in lemb_type and lemb_kwargs.get("Z", None) is None: - raise ValueError( - "pifa_lf_concat/pifa_lf_convex_combine must provide external path for Z." 
- ) - # Create label features - label_feat_set[embed_type] = LabelEmbeddingFactory.create( - Y, - X, - method=lemb_type, - **lemb_kwargs, - ) - XLinearModel.save_feature_matrix(label_embed_path, label_feat_set[embed_type]) - - for indexer_values in itertools.product( - *[indexer_grid_params[k] for k in indexer_param_names] - ): - # Indexing - indexer_kwargs = dict(zip(indexer_param_names, indexer_values)) - indexer_kwargs_local = indexer_kwargs.copy() - C_path = ws.get_path_for_name_and_kwargs("C", indexer_kwargs_local) - if path.exists(C_path): - LOGGER.info(f"Loading existing clustering code with params {indexer_kwargs_local}") - C = ClusterChain.load(C_path) - else: - label_embed_type = indexer_kwargs.pop( - "label_embed_type", None - ) # as label_embed_type is not a valid argument for XLinearModel.train - LOGGER.info(f"Clustering with params {indexer_kwargs_local}...") - C = Indexer.gen( - label_feat_set[indexer_kwargs_local["label_embed_type"]], - indexer_kwargs.pop("indexer_algo"), - threads=threads, - **indexer_kwargs, - ) - LOGGER.info(f"Created {C[-1].shape[1]} clusters.") - C.save(C_path) - - # Ensemble Models - for ranker_values in itertools.product( - *[ranker_grid_params[k] for k in ranker_param_names] - ): - ranker_kwargs = dict(zip(ranker_param_names, ranker_values)) - ranker_kwargs_local = ranker_kwargs.copy() - # Model Training - ranker_kwargs_local.update(indexer_kwargs_local) - - model_path = ws.get_path_for_name_and_kwargs("model", ranker_kwargs_local) - if path.exists(model_path): - LOGGER.info(f"Model with params {ranker_kwargs_local} exists") + # construct indexing, training and prediction params + if train_params is None: + # fill all BaseParams class with their default value + train_params = cls.TrainParams.from_dict(dict(), recursive=True) + else: + train_params = cls.TrainParams.from_dict(train_params) + + # construct pred_params + if pred_params is None: + # fill all BaseParams with their default value + pred_params = cls.PredParams.from_dict(dict(), recursive=True) + else: + pred_params = cls.PredParams.from_dict(pred_params) + pred_params = pred_params.override_with_kwargs(kwargs) + + # 1. 
Generate label features + label_embed_kwargs = dict( + input_text_path=input_text_path, + output_text_path=output_text_path, + dtype=str(dtype), + vectorizer_config=vectorizer_config, + embed_type=label_embed_type, + ) + label_embed_path = ws.get_path_for_name_and_kwargs("L", label_embed_kwargs) + if path.exists(label_embed_path): + LOGGER.info(f"Loading existing {label_embed_type} features for {Y.shape[1]} labels...") + label_feat = XLinearModel.load_feature_matrix(label_embed_path) + else: + LOGGER.info(f"Generating {label_embed_type} features for {Y.shape[1]} labels...") + # parse embed_type string, expect either the following three cases: + # (1) pifa + # (2) pifa_lf_concat::Z=path + # (3) pifa_lf_convex_combine::Z=path::alpha=value + lemb_key_val_list = label_embed_type.split("::") + lemb_type = lemb_key_val_list[0] + lemb_kwargs = {} + for key_val_str in lemb_key_val_list[1:]: + key, val = key_val_str.split("=") + if key == "Z": + Z = smat_util.load_matrix(val) + lemb_kwargs.update({"Z": Z}) + elif key == "alpha": + alpha = float(val) + lemb_kwargs.update({"alpha": alpha}) else: - LOGGER.info(f"Training model with params {ranker_kwargs_local}...") - m = XLinearModel.train( - X, - Y, - C=C, - R=R, - threads=threads, - **ranker_kwargs, - ) - m.save(model_path) - del m - gc.collect() - - del C - gc.collect() - - del X, Y, R, label_feat_set + raise ValueError(f"key={key}, val={val} is not supported!") + if "lf" in lemb_type and lemb_kwargs.get("Z", None) is None: + raise ValueError( + "pifa_lf_concat/pifa_lf_convex_combine must provide external path for Z." + ) + # Create label features + label_feat = LabelEmbeddingFactory.create( + Y, + X, + method=lemb_type, + **lemb_kwargs, + ) + XLinearModel.save_feature_matrix(label_embed_path, label_feat) + + # 2. 
Indexing + indexer_kwargs_dict = train_params.indexer_params.to_dict() + C_path = ws.get_path_for_name_and_kwargs("C", indexer_kwargs_dict) + if path.exists(C_path): + LOGGER.info(f"Loading existing clustering code with params {indexer_kwargs_dict}") + C = ClusterChain.load(C_path) + else: + LOGGER.info(f"Clustering with params: {json.dumps(indexer_kwargs_dict, indent=True)}") + C = Indexer.gen(label_feat, train_params=train_params.indexer_params) + LOGGER.info("Hierarchical label tree: {}".format([cc.shape[0] for cc in C])) + C.save(C_path) + + del label_feat gc.collect() - xlinear_models = [] - for indexer_values in itertools.product( - *[indexer_grid_params[k] for k in indexer_param_names] - ): - indexer_kwargs = dict(zip(indexer_param_names, indexer_values)) - indexer_kwargs_local = indexer_kwargs.copy() - for ranker_values in itertools.product( - *[ranker_grid_params[k] for k in ranker_param_names] - ): - ranker_kwargs = dict(zip(ranker_param_names, ranker_values)) - ranker_kwargs_local = ranker_kwargs.copy() - ranker_kwargs_local.update(indexer_kwargs_local) - model_path = ws.get_path_for_name_and_kwargs("model", ranker_kwargs_local) - xlinear_models += [(XLinearModel.load(model_path), ranker_kwargs_local)] + # Ensemble Models + LOGGER.info( + f"Training model with train params: {json.dumps(train_params.xlinear_params.to_dict(), indent=True)}" + ) + LOGGER.info( + f"Training model with pred params: {json.dumps(pred_params.xlinear_params.to_dict(), indent=True)}" + ) + m = XLinearModel.train( + X, + Y, + C=C, + R=R, + train_params=train_params.xlinear_params, + pred_params=pred_params.xlinear_params, + **kwargs, + ) + + xlinear_models = [[m, train_params.to_dict()]] # Load output items with open(output_text_path, "r", encoding="utf-8") as f: @@ -451,22 +400,18 @@ def train( return cls(preprocessor, xlinear_models, output_items) - def predict( - self, corpus, topk=10, beam_size=None, post_processor=None, threshold=None, **kwargs - ): + def predict(self, corpus, threshold=None, **kwargs): """Predict labels for given inputs Args: corpus (list of strings): input strings. - topk (int, optional): override the only topk specified in the model - Default None to disable overriding - beam_size (int, optional): override the beam size specified in the model - Default None to disable overriding - post_processor (str, optional): override the post_processor specified in the model - Default None to disable overriding threshold (float, optional): Drop output items with scores less than this threshold among top-k items Default None to not threshold kwargs: + only_topk (int, optional): override the only topk specified in the model + Default None to disable overriding + beam_size (int, optional): override the beam size specified in the model + Default None to disable overriding post_processor (str, optional): override the post_processor specified in the model Default None to disable overriding threads (int, optional): the number of threads to use for predicting. 
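Putting the refactored `train`/`predict` API together, a rough end-to-end sketch: paths, the workspace folder, and all numeric values are placeholders, the vectorizer config mirrors the CLI default, and the keyword overrides follow the signatures and docstrings in this diff.

```python
from pecos.apps.text2text.model import Text2Text

# Train with default TrainParams/PredParams; extra kwargs such as beam_size
# are forwarded and override the stored prediction parameters.
t2t = Text2Text.train(
    "training-data.txt",        # each line: OUTPUT_ID1,OUTPUT_ID2,...\tINPUT_TEXT
    "output-labels.txt",        # one output item per line
    label_embed_type="pifa",
    vectorizer_config={"type": "tfidf", "kwargs": {}},
    train_params=None,          # None -> all defaults
    pred_params=None,
    workspace_folder="./t2t-workspace",
    beam_size=10,
)
t2t.save("./pecos-text2text-model")

# Predict-time overrides go through **kwargs; the result is a sorted sparse
# CSR matrix of scores over the output items.
t2t = Text2Text.load("./pecos-text2text-model", is_predict_only=True)
Y_pred = t2t.predict(["example query text"], only_topk=5, threshold=0.01)
```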
@@ -476,13 +421,7 @@ def predict( """ X = self.preprocessor.predict(corpus) - - Y_pred = [ - m.predict( - X, only_topk=topk, beam_size=beam_size, post_processor=post_processor, **kwargs - ) - for m, _ in self.xlinear_models - ] + Y_pred = [m.predict(X, **kwargs) for m, _ in self.xlinear_models] if len(Y_pred) > 1: Y_pred = smat_util.CsrEnsembler.average(*Y_pred) @@ -493,7 +432,7 @@ def predict( Y_pred.data[Y_pred.data <= threshold] = 0 Y_pred.eliminate_zeros() - return smat_util.sorted_csr(Y_pred, topk) + return smat_util.sorted_csr(Y_pred, only_topk=kwargs.get("only_topk", None)) def set_output_constraint(self, output_items_to_keep): """Prune the tree diff --git a/pecos/apps/text2text/train.py b/pecos/apps/text2text/train.py index 162614e3..c3cc8067 100644 --- a/pecos/apps/text2text/train.py +++ b/pecos/apps/text2text/train.py @@ -11,12 +11,12 @@ import argparse import logging import sys +import json -import numpy as np from pecos.core import XLINEAR_SOLVERS from pecos.utils import cli, logging_util from pecos.utils.featurization.text.vectorizers import Vectorizer -from pecos.xmc import Indexer, PostProcessor +from pecos.xmc import PostProcessor from .model import Text2Text @@ -27,12 +27,42 @@ def parse_arguments(args): parser = argparse.ArgumentParser( description="Text2Text: Read input text training files, output item files and train a model" ) + parser.add_argument( + "--generate-train-params-skeleton", + action="store_true", + help="generate template train-params-json to stdout", + ) + parser.add_argument( + "--generate-pred-params-skeleton", + action="store_true", + help="generate template pred-params-json to stdout", + ) + + gen_train_params = "--generate-train-params-skeleton" in sys.argv + gen_pred_params = "--generate-pred-params-skeleton" in sys.argv + skip_training = gen_train_params or gen_pred_params + # ========= parameter jsons ============ + parser.add_argument( + "--train-params-path", + type=str, + default=None, + metavar="TRAIN_PARAMS_PATH", + help="Json file for train_params (default None)", + ) + parser.add_argument( + "--pred-params-path", + type=str, + default=None, + metavar="PRED_PARAMS_PATH", + help="Json file for pred_params (default None)", + ) + # ======= actual arguments ======== parser.add_argument( "-i", "--input-text-path", type=str, - required=True, + required=not skip_training, metavar="INPUT_TEXT_PATH", help="Text input file name. Format: in each line, OUTPUT_ID1,OUTPUT_ID2,OUTPUT_ID3,...\t INPUT_TEXT \ where OUTPUT_IDs are the zero-based output item indices corresponding to the line numbers of OUTPUT_ITEM_PATH. We assume utf-8 encoding for text.", @@ -42,7 +72,7 @@ def parse_arguments(args): "-q", "--output-item-path", type=str, - required=True, + required=not skip_training, metavar="OUTPUT_ITEM_PATH", help="Output item file name. Format: each line corresponds to a representation of the output item. We assume utf-8 encoding for text.", ) @@ -51,7 +81,7 @@ def parse_arguments(args): "-m", "--model-folder", type=str, - required=True, + required=not skip_training, metavar="MODEL_FOLDER", help="Output model folder name", ) @@ -81,19 +111,12 @@ def parse_arguments(args): help='Json-format string for vectorizer config (default {"type":"tfidf", "kwargs":{}})', ) - parser.add_argument( - "--dtype", - type=lambda x: np.float32 if "32" in x else np.float64, - default=np.float32, - help="data type for the csr matrix. float32 | float64. 
(default float32)", - ) - parser.add_argument( "--max-leaf-size", - type=cli.comma_separated_type(int), - default=[100], - metavar="INT-LIST", - help="The max size of the leaf nodes of hierarchical 2-means clustering. Multiple values (separated by comma) are supported and will lead to different individual models for ensembling. (default [100])", + type=int, + default=100, + metavar="INT", + help="The max size of the leaf nodes of hierarchical 2-means clustering. (default 100)", ) parser.add_argument( @@ -122,21 +145,12 @@ def parse_arguments(args): parser.add_argument( "--label-embed-type", - type=cli.comma_separated_type(str), + type=str, default="pifa", - metavar="STR-LIST", + metavar="STR", help="Label embedding types. (default pifa).\ We support pifa, pifa_lf_concat::Z=path, and pifa_lf_convex_combine::Z=path::alpha=scalar_value,\ - where path is the additional user-porivded label embedding path and alpha is the scalar value for convex combination.\ - Multiple values (separated by comma) are supported and will lead to different individual models for ensembling.", - ) - - parser.add_argument( - "--indexer", - choices=Indexer.indexer_dict.keys(), - default="hierarchicalkmeans", - metavar="STR", - help=f"Indexer algorithm (default hierarchicalkmeans). Available choices are {', '.join(Indexer.indexer_dict.keys())}", + where path is the additional user-porivded label embedding path and alpha is the scalar value for convex combination.", ) parser.add_argument( @@ -149,10 +163,10 @@ def parse_arguments(args): parser.add_argument( "--seed", - type=cli.comma_separated_type(int), - default=[0], - metavar="INT-LIST", - help="Random seeds (default 0). Multiple values (separated by comma) are supported and will lead to different individual models for ensembling.", + type=int, + default=0, + metavar="INT", + help="Random seeds (default 0).", ) parser.add_argument( @@ -286,6 +300,33 @@ def train(args): Args: args (argparse.Namespace): Command line arguments parsed by `parser.parse_args()` """ + if args.generate_train_params_skeleton: + train_params = Text2Text.TrainParams.from_dict({}, recursive=True) + print(f"{json.dumps(train_params.to_dict(), indent=True)}") + return + + if args.generate_pred_params_skeleton: + pred_params = Text2Text.PredParams.from_dict({}, recursive=True) + print(f"{json.dumps(pred_params.to_dict(), indent=True)}") + return + + if args.train_params_path: + with open(args.train_params_path, "r") as fin: + train_params = Text2Text.TrainParams.from_dict(json.load(fin)) + else: + train_params = Text2Text.TrainParams.from_dict( + {k: v for k, v in vars(args).items() if v is not None}, + recursive=True, + ) + + if args.pred_params_path: + with open(args.pred_params_path, "r") as fin: + pred_params = Text2Text.PredParams.from_dict(json.load(fin)) + else: + pred_params = Text2Text.PredParams.from_dict( + {k: v for k, v in vars(args).items() if v is not None}, + recursive=True, + ) pred_kwargs = { "beam_size": args.beam_size, @@ -299,27 +340,11 @@ def train(args): args.input_text_path, args.output_item_path, label_embed_type=args.label_embed_type, - max_leaf_size=args.max_leaf_size, - nr_splits=args.nr_splits, vectorizer_config=vectorizer_config, - dtype=args.dtype, - indexer_algo=[args.indexer], - imbalanced_ratio=args.imbalanced_ratio, - imbalanced_depth=args.imbalanced_depth, - spherical=args.spherical, - seed=args.seed, - max_iter=args.max_iter, - threads=args.threads, - solver_type=args.solver_type, - Cp=args.Cp, - Cn=args.Cn, - bias=args.bias, - threshold=args.threshold, - 
negative_sampling_scheme=args.negative_sampling, - pred_kwargs=pred_kwargs, - rel_mode=args.rel_mode, - rel_norm=args.rel_norm, + train_params=train_params, + pred_params=pred_params, workspace_folder=args.workspace_folder, + **pred_kwargs, ) t2t_model.save(args.model_folder) diff --git a/pecos/core/third_party/robin_hood_hashing/robin_hood.h b/pecos/core/third_party/robin_hood_hashing/robin_hood.h index 31adf080..511a308d 100644 --- a/pecos/core/third_party/robin_hood_hashing/robin_hood.h +++ b/pecos/core/third_party/robin_hood_hashing/robin_hood.h @@ -10,7 +10,7 @@ // // Licensed under the MIT License . // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2020 Martin Ankerl +// Copyright (c) 2018-2021 Martin Ankerl // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -34,14 +34,15 @@ #define ROBIN_HOOD_H_INCLUDED // see https://semver.org/ -#define ROBIN_HOOD_VERSION_MAJOR 3 // for incompatible API changes -#define ROBIN_HOOD_VERSION_MINOR 9 // for adding functionality in a backwards-compatible manner -#define ROBIN_HOOD_VERSION_PATCH 0 // for backwards-compatible bug fixes +#define ROBIN_HOOD_VERSION_MAJOR 3 // for incompatible API changes +#define ROBIN_HOOD_VERSION_MINOR 11 // for adding functionality in a backwards-compatible manner +#define ROBIN_HOOD_VERSION_PATCH 3 // for backwards-compatible bug fixes #include #include #include #include +#include #include // only to support hash of smart pointers #include #include @@ -50,14 +51,12 @@ #if __cplusplus >= 201703L # include #endif -#if defined(__aarch64__) -# include // for getauxval -#endif // #define ROBIN_HOOD_LOG_ENABLED #ifdef ROBIN_HOOD_LOG_ENABLED # include -# define ROBIN_HOOD_LOG(x) std::cout << __FUNCTION__ << "@" << __LINE__ << ": " << x << std::endl +# define ROBIN_HOOD_LOG(...) \ + std::cout << __FUNCTION__ << "@" << __LINE__ << ": " << __VA_ARGS__ << std::endl; #else # define ROBIN_HOOD_LOG(x) #endif @@ -65,8 +64,8 @@ // #define ROBIN_HOOD_TRACE_ENABLED #ifdef ROBIN_HOOD_TRACE_ENABLED # include -# define ROBIN_HOOD_TRACE(x) \ - std::cout << __FUNCTION__ << "@" << __LINE__ << ": " << x << std::endl +# define ROBIN_HOOD_TRACE(...) \ + std::cout << __FUNCTION__ << "@" << __LINE__ << ": " << __VA_ARGS__ << std::endl; #else # define ROBIN_HOOD_TRACE(x) #endif @@ -194,6 +193,17 @@ static Counts& counts() { # define ROBIN_HOOD_PRIVATE_DEFINITION_HAS_NATIVE_WCHART() 1 #endif +// detect if MSVC supports the pair(std::piecewise_construct_t,...) consructor being constexpr +#ifdef _MSC_VER +# if _MSC_VER <= 1900 +# define ROBIN_HOOD_PRIVATE_DEFINITION_BROKEN_CONSTEXPR() 1 +# else +# define ROBIN_HOOD_PRIVATE_DEFINITION_BROKEN_CONSTEXPR() 0 +# endif +#else +# define ROBIN_HOOD_PRIVATE_DEFINITION_BROKEN_CONSTEXPR() 0 +#endif + // workaround missing "is_trivially_copyable" in g++ < 5.0 // See https://stackoverflow.com/a/31798726/48181 #if defined(__GNUC__) && __GNUC__ < 5 @@ -215,43 +225,6 @@ static Counts& counts() { # define ROBIN_HOOD_PRIVATE_DEFINITION_NODISCARD() #endif -// detect hardware CRC availability. 
-#if !defined(ROBIN_HOOD_DISABLE_INTRINSICS) -// only use CRC for 64bit targets -# if ROBIN_HOOD(BITNESS) == 64 && \ - (defined(__SSE4_2__) || defined(__ARM_FEATURE_CRC32) || defined(_MSC_VER)) -# define ROBIN_HOOD_PRIVATE_DEFINITION_HAS_CRC32() 1 -# if defined(__ARM_NEON) || defined(__ARM_NEON__) || defined(_M_ARM64) -# ifdef _M_ARM64 -# include -# else -# include -# endif - -# define ROBIN_HOOD_CRC32_64(crc, v) \ - static_cast( \ - __crc32cd(static_cast(crc), static_cast(v))) -# define ROBIN_HOOD_CRC32_32(crc, v) \ - __crc32cw(static_cast(crc), static_cast(v)) -# else -# include -# define ROBIN_HOOD_CRC32_64(crc, v) \ - static_cast( \ - _mm_crc32_u64(static_cast(crc), static_cast(v))) -# define ROBIN_HOOD_CRC32_32(crc, v) \ - _mm_crc32_u32(static_cast(crc), static_cast(v)) -# endif -# else -# define ROBIN_HOOD_PRIVATE_DEFINITION_HAS_CRC32() 0 -# endif - -# if defined(_MSC_VER) -# include -# endif -#else -# define ROBIN_HOOD_PRIVATE_DEFINITION_HAS_CRC32() 0 -#endif - namespace robin_hood { #if ROBIN_HOOD(CXX) >= ROBIN_HOOD(CXX14) @@ -358,14 +331,14 @@ inline T reinterpret_cast_no_cast_align_warning(void const* ptr) noexcept { // make sure this is not inlined as it is slow and dramatically enlarges code, thus making other // inlinings more difficult. Throws are also generally the slow path. template -ROBIN_HOOD(NOINLINE) +[[noreturn]] ROBIN_HOOD(NOINLINE) #if ROBIN_HOOD(HAS_EXCEPTIONS) -void doThrow(Args&&... args) { + void doThrow(Args&&... args) { // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay) throw E(std::forward(args)...); } #else -void doThrow(Args&&... ROBIN_HOOD_UNUSED(args) /*unused*/) { + void doThrow(Args&&... ROBIN_HOOD_UNUSED(args) /*unused*/) { abort(); } #endif @@ -431,6 +404,7 @@ class BulkPoolAllocator { void reset() noexcept { while (mListForFree) { T* tmp = *mListForFree; + ROBIN_HOOD_LOG("std::free") std::free(mListForFree); mListForFree = reinterpret_cast_no_cast_align_warning(tmp); } @@ -466,8 +440,10 @@ class BulkPoolAllocator { // calculate number of available elements in ptr if (numBytes < ALIGNMENT + ALIGNED_SIZE) { // not enough data for at least one element. Free and return. + ROBIN_HOOD_LOG("std::free") std::free(ptr); } else { + ROBIN_HOOD_LOG("add to buffer") add(ptr, numBytes); } } @@ -531,8 +507,9 @@ class BulkPoolAllocator { size_t const numElementsToAlloc = calcNumElementsToAlloc(); // alloc new memory: [prev |T, T, ... T] - // std::cout << (sizeof(T*) + ALIGNED_SIZE * numElementsToAlloc) << " bytes" << std::endl; size_t const bytes = ALIGNMENT + ALIGNED_SIZE * numElementsToAlloc; + ROBIN_HOOD_LOG("std::malloc " << bytes << " = " << ALIGNMENT << " + " << ALIGNED_SIZE + << " * " << numElementsToAlloc) add(assertNotNull(std::malloc(bytes)), bytes); return mHead; } @@ -569,6 +546,7 @@ struct NodeAllocator { // we are not using the data, so just free it. void addOrFree(void* ptr, size_t ROBIN_HOOD_UNUSED(numBytes) /*unused*/) noexcept { + ROBIN_HOOD_LOG("std::free") std::free(ptr); } }; @@ -576,14 +554,6 @@ struct NodeAllocator { template struct NodeAllocator : public BulkPoolAllocator {}; -// dummy hash, unsed as mixer when robin_hood::hash is already used -template -struct identity_hash { - constexpr size_t operator()(T const& obj) const noexcept { - return static_cast(obj); - } -}; - // c++14 doesn't have is_nothrow_swappable, and clang++ 6.0.1 doesn't like it either, so I'm making // my own here. 
namespace swappable { @@ -644,14 +614,20 @@ struct pair { , second(std::forward(b)) {} template - constexpr pair( - std::piecewise_construct_t /*unused*/, std::tuple a, - std::tuple b) noexcept(noexcept(pair(std::declval&>(), - std::declval&>(), - ROBIN_HOOD_STD::index_sequence_for(), - ROBIN_HOOD_STD::index_sequence_for()))) + // MSVC 2015 produces error "C2476: ‘constexpr’ constructor does not initialize all members" + // if this constructor is constexpr +#if !ROBIN_HOOD(BROKEN_CONSTEXPR) + constexpr +#endif + pair(std::piecewise_construct_t /*unused*/, std::tuple a, + std::tuple + b) noexcept(noexcept(pair(std::declval&>(), + std::declval&>(), + ROBIN_HOOD_STD::index_sequence_for(), + ROBIN_HOOD_STD::index_sequence_for()))) : pair(a, b, ROBIN_HOOD_STD::index_sequence_for(), - ROBIN_HOOD_STD::index_sequence_for()) {} + ROBIN_HOOD_STD::index_sequence_for()) { + } // constructor called from the std::piecewise_construct_t ctor template @@ -713,23 +689,7 @@ inline constexpr bool operator>=(pair const& x, pair const& y) { return !(x < y); } -namespace detail { - -static size_t fallback_hash_int(uint64_t x) noexcept { - // inspired by lemire's strongly universal hashing - // https://lemire.me/blog/2018/08/15/fast-strongly-universal-64-bit-hashing-everywhere/ - // - // Instead of shifts, we use rotations so we don't lose any bits. - // - // Added a final multiplcation with a constant for more mixing. It is most important that - // the lower bits are well mixed. - auto h1 = x * UINT64_C(0xA24BAED4963EE407); - auto h2 = detail::rotr(x, 32U) * UINT64_C(0x9FB21C651E98DF25); - auto h = detail::rotr(h1 + h2, 32U); - return static_cast(h); -} - -static size_t fallback_hash_bytes(void const* ptr, size_t const len) noexcept { +inline size_t hash_bytes(void const* ptr, size_t len) noexcept { static constexpr uint64_t m = UINT64_C(0xc6a4a7935bd1e995); static constexpr uint64_t seed = UINT64_C(0xe17a1465); static constexpr unsigned int r = 47; @@ -778,242 +738,24 @@ static size_t fallback_hash_bytes(void const* ptr, size_t const len) noexcept { } h ^= h >> r; - h *= m; - h ^= h >> r; - return static_cast(h); -} - -#if ROBIN_HOOD(HAS_CRC32) - -# ifndef _M_ARM64 -// see e.g. -// https://github.com/simdjson/simdjson/blob/9863f62321f59d73c7731d4ada2d7c4ed6a0a251/src/isadetection.h -static inline void cpuid(uint32_t* eax, uint32_t* ebx, uint32_t* ecx, uint32_t* edx) { -# if defined(_MSC_VER) - int cpuInfo[4]; - __cpuid(cpuInfo, static_cast(*eax)); - *eax = static_cast(cpuInfo[0]); - *ebx = static_cast(cpuInfo[1]); - *ecx = static_cast(cpuInfo[2]); - *edx = static_cast(cpuInfo[3]); -# else - uint32_t a = *eax; - uint32_t b{}; - uint32_t c = *ecx; - uint32_t d{}; - // NOLINTNEXTLINE(hicpp-no-assembler) - asm volatile("cpuid\n\t" : "+a"(a), "=b"(b), "+c"(c), "=d"(d)); - *eax = a; - *ebx = b; - *ecx = c; - *edx = d; -# endif -} -# endif - -inline bool hasCrc32Support() noexcept { -# if defined(__x86_64__) || defined(__x86_64) || defined(_M_X64) || defined(_M_AMD64) - uint32_t eax{}; - uint32_t ebx{}; - uint32_t ecx{}; - uint32_t edx{}; - - // EBX for EAX=0x1 - eax = 0x1; - cpuid(&eax, &ebx, &ecx, &edx); - - // check SSE4.2 - return 0U != (ecx & (1U << 20U)); -# elif defined(__aarch64__) - auto hwcap = getauxval(AT_HWCAP); - if (hwcap != ENOENT) { - // HWCAP_CRC32 is not necessarily defined, so hardcode it. 
- // see https://github.com/torvalds/linux/blob/master/arch/arm64/include/uapi/asm/hwcap.h - return (hwcap & (1U << 7U)) != 0; - } -# elif defined(_M_ARM64) - return true; -# endif - return false; -} - -inline size_t hash_bytes_1_to_16(void const* ptr, size_t len, uint64_t seed) noexcept { - // random odd 64bit constants - static constexpr uint64_t c1 = UINT64_C(0x38a3affe8230452c); - static constexpr uint64_t c2 = UINT64_C(0xd55c04dccfde5383); - - auto const* d8 = reinterpret_cast(ptr); - - if (len > 8) { - // 9-16 bytes - auto h1 = ROBIN_HOOD_CRC32_64(seed, detail::unaligned_load(d8)); - auto h2 = ROBIN_HOOD_CRC32_64(seed, detail::unaligned_load(d8 + len - 8)); - return h1 * c1 + h2 * c2; - } - - uint64_t input{}; - if (len <= 4) { - uint64_t a = d8[0]; // 0, 0, 0, 0 - uint64_t b = d8[(len - 1) / 2]; // 0, 0, 1, 1 - uint64_t c = d8[len / 2]; // 0, 1, 1, 2 - uint64_t d = d8[len - 1]; // 0, 1, 2, 3 - input = (a << 24U) | (b << 16U) | (c << 8U) | d; - } else { - // 5-8 bytes - uint64_t a = detail::unaligned_load(d8); - uint64_t b = detail::unaligned_load(d8 + len - 4); - input = (a << 32U) | b; - } - return ROBIN_HOOD_CRC32_64(seed, input) * c1; -} -inline size_t hash_bytes_8_to_xxx(void const* ptr, size_t len, uint64_t seed) { - auto const* d8 = reinterpret_cast(ptr); - - static constexpr auto bs = 128U; - uint64_t h1 = seed; - uint64_t h2 = seed; - uint64_t h3 = seed; - uint64_t h4 = seed; - - auto next = d8; - auto numBlocks = (len - 1) / bs; - auto end = d8 + numBlocks * bs; - while (next != end) { - h1 = ROBIN_HOOD_CRC32_64(h1, detail::unaligned_load(next + 0U)); - h2 = ROBIN_HOOD_CRC32_64(h2, detail::unaligned_load(next + 8U)); - h3 = ROBIN_HOOD_CRC32_64(h3, detail::unaligned_load(next + 16U)); - h4 = ROBIN_HOOD_CRC32_64(h4, detail::unaligned_load(next + 24U)); - - h1 = ROBIN_HOOD_CRC32_64(h1, detail::unaligned_load(next + 32U + 0U)); - h2 = ROBIN_HOOD_CRC32_64(h2, detail::unaligned_load(next + 32U + 8U)); - h3 = ROBIN_HOOD_CRC32_64(h3, detail::unaligned_load(next + 32U + 16U)); - h4 = ROBIN_HOOD_CRC32_64(h4, detail::unaligned_load(next + 32U + 24U)); - - h1 = ROBIN_HOOD_CRC32_64(h1, detail::unaligned_load(next + 64U + 0U)); - h2 = ROBIN_HOOD_CRC32_64(h2, detail::unaligned_load(next + 64U + 8U)); - h3 = ROBIN_HOOD_CRC32_64(h3, detail::unaligned_load(next + 64U + 16U)); - h4 = ROBIN_HOOD_CRC32_64(h4, detail::unaligned_load(next + 64U + 24U)); - - h1 = ROBIN_HOOD_CRC32_64(h1, detail::unaligned_load(next + 96U + 0U)); - h2 = ROBIN_HOOD_CRC32_64(h2, detail::unaligned_load(next + 96U + 8U)); - h3 = ROBIN_HOOD_CRC32_64(h3, detail::unaligned_load(next + 96U + 16U)); - h4 = ROBIN_HOOD_CRC32_64(h4, detail::unaligned_load(next + 96U + 24U)); - - next += bs; - } - - auto remainingBytes = len - (numBlocks * bs); - - auto numBlocks8 = (remainingBytes + 7U) / 8U; - end += numBlocks8 * 8; - switch (numBlocks8) { - case 16: - h1 = ROBIN_HOOD_CRC32_64(h1, detail::unaligned_load(end - 128U)); - ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - case 15: - h2 = ROBIN_HOOD_CRC32_64(h2, detail::unaligned_load(end - 120U)); - ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - case 14: - h3 = ROBIN_HOOD_CRC32_64(h3, detail::unaligned_load(end - 112U)); - ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - case 13: - h4 = ROBIN_HOOD_CRC32_64(h4, detail::unaligned_load(end - 104U)); - ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - case 12: - h1 = ROBIN_HOOD_CRC32_64(h1, detail::unaligned_load(end - 96U)); - ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - case 11: - h2 = ROBIN_HOOD_CRC32_64(h2, detail::unaligned_load(end - 88U)); - 
ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - case 10: - h3 = ROBIN_HOOD_CRC32_64(h3, detail::unaligned_load(end - 80U)); - ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - case 9: - h4 = ROBIN_HOOD_CRC32_64(h4, detail::unaligned_load(end - 72U)); - ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - case 8: - h1 = ROBIN_HOOD_CRC32_64(h1, detail::unaligned_load(end - 64U)); - ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - case 7: - h2 = ROBIN_HOOD_CRC32_64(h2, detail::unaligned_load(end - 56U)); - ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - case 6: - h3 = ROBIN_HOOD_CRC32_64(h3, detail::unaligned_load(end - 48U)); - ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - case 5: - h4 = ROBIN_HOOD_CRC32_64(h4, detail::unaligned_load(end - 40U)); - ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - case 4: - h1 = ROBIN_HOOD_CRC32_64(h1, detail::unaligned_load(end - 32U)); - ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - case 3: - h2 = ROBIN_HOOD_CRC32_64(h2, detail::unaligned_load(end - 24U)); - ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - case 2: - h3 = ROBIN_HOOD_CRC32_64(h3, detail::unaligned_load(end - 16U)); - ROBIN_HOOD(FALLTHROUGH); // FALLTHROUGH - - default: - // make sure that we don't skip past the real length with the last one - h4 = ROBIN_HOOD_CRC32_64(h4, detail::unaligned_load(d8 + len - 8U)); - break; - } - - // how best to combine h1 to h4? Multiplying and summing with a random odd number seems to be - // very fast and leads to few collisions. - return h1 * 0x38a3affe8230452c + h2 * 0xd55c04dccfde5383 + h3 * 0xd348c89fbf80760f + - h4 * 0xdc105301318a46f3; -} - -#endif - -} // namespace detail - -inline size_t hash_bytes(void const* ptr, size_t len) noexcept { - if (len == 0) { - return 0; - } -#if ROBIN_HOOD(HAS_CRC32) && ROBIN_HOOD(BITNESS) == 64 - static bool const hasCrc = detail::hasCrc32Support(); - if (ROBIN_HOOD_LIKELY(hasCrc)) { - auto seed = len * UINT64_C(0xf012a09363e97a8f); - if (len <= 16U) { - return detail::hash_bytes_1_to_16(ptr, len, seed); - } - - return detail::hash_bytes_8_to_xxx(ptr, len, seed); - } -#endif - return detail::fallback_hash_bytes(ptr, len); + // not doing the final step here, because this will be done by keyToIdx anyways + // h *= m; + // h ^= h >> r; + return static_cast(h); } inline size_t hash_int(uint64_t x) noexcept { -#if ROBIN_HOOD(HAS_CRC32) - static bool const hasCrc = detail::hasCrc32Support(); - if (ROBIN_HOOD_LIKELY(hasCrc)) { -# if ROBIN_HOOD(BITNESS) == 64 - // rotr 32 results in bad hash, when hash_int is applied twice. - return ROBIN_HOOD_CRC32_64(0, x ^ UINT64_C(0xA24BAED4963EE407)) ^ - (ROBIN_HOOD_CRC32_64(0, x) << 32U); -# else - return ROBIN_HOOD_CRC32_32(ROBIN_HOOD_CRC32_32(0, static_cast(x)), - static_cast(x >> 32U)); -# endif - } -#endif - return detail::fallback_hash_int(x); -} - -inline size_t hash_int(uint32_t x) noexcept { -#if ROBIN_HOOD(HAS_CRC32) - static bool const hasCrc = detail::hasCrc32Support(); - if (ROBIN_HOOD_LIKELY(hasCrc)) { - // rotr 32 results in bad hash, when hash_int is applied twice. - return ROBIN_HOOD_CRC32_32(0, x); - } -#endif - return detail::fallback_hash_int(x); + // tried lots of different hashes, let's stick with murmurhash3. It's simple, fast, well tested, + // and doesn't need any special 128bit operations. + x ^= x >> 33U; + x *= UINT64_C(0xff51afd7ed558ccd); + x ^= x >> 33U; + + // not doing the final step here, because this will be done by keyToIdx anyways + // x *= UINT64_C(0xc4ceb9fe1a85ec53); + // x ^= x >> 33U; + return static_cast(x); } // A thin wrapper around std::hash, performing an additional simple mixing step of the result. 
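For orientation, the replacement hashing is murmurhash3-style mixing whose final multiply/xor is deliberately deferred to `keyToIdx`, which now folds in the per-table `mHashMultiplier`. Below is a small sketch of that arithmetic, written in Python with explicit 64-bit wrap-around; the constants are copied from this diff, the info-byte bookkeeping of `keyToIdx` is omitted, and `InitialInfoNumBits = 5` is an assumption about the library's internal constant.

```python
MASK64 = (1 << 64) - 1

def hash_int(x: int) -> int:
    """Partial murmur3 finalizer, mirroring the updated robin_hood::hash_int."""
    x &= MASK64
    x ^= x >> 33
    x = (x * 0xFF51AFD7ED558CCD) & MASK64
    x ^= x >> 33
    return x

def key_to_idx(h: int, mask: int,
               hash_multiplier: int = 0xC4CEB9FE1A85EC53,  # default mHashMultiplier
               info_num_bits: int = 5) -> int:
    """Finishing step done by Table::keyToIdx: one more mul + shift, then mask."""
    h = (h * hash_multiplier) & MASK64
    h ^= h >> 33
    return (h >> info_num_bits) & mask

# bucket index for key 42 in a table with 1024 slots
idx = key_to_idx(hash_int(42), mask=1024 - 1)
```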
@@ -1073,14 +815,12 @@ struct hash::value>::type> { } }; -#define ROBIN_HOOD_HASH_INT(T) \ - template <> \ - struct hash { \ - size_t operator()(T const& obj) const noexcept { \ - using Type = \ - std::conditional::type; \ - return hash_int(static_cast(obj)); \ - } \ +#define ROBIN_HOOD_HASH_INT(T) \ + template <> \ + struct hash { \ + size_t operator()(T const& obj) const noexcept { \ + return hash_int(static_cast(obj)); \ + } \ } #if defined(__GNUC__) && !defined(__clang__) @@ -1380,7 +1120,7 @@ class Table using Node = DataNode; - // helpers for doInsert: extract first entry (only const required) + // helpers for insertKeyPrepareEmptySpot: extract first entry (only const required) ROBIN_HOOD(NODISCARD) key_type const& getFirstConst(Node const& n) const noexcept { return n.getFirst(); } @@ -1606,18 +1346,17 @@ class Table // The upper 1-5 bits need to be a reasonable good hash, to save comparisons. template void keyToIdx(HashKey&& key, size_t* idx, InfoType* info) const { - // for a user-specified hash that is *not* robin_hood::hash, apply robin_hood::hash as - // an additional mixing step. This serves as a bad hash prevention, if the given data is + // In addition to whatever hash is used, add another mul & shift so we get better hashing. + // This serves as a bad hash prevention, if the given data is // badly mixed. - using Mix = - typename std::conditional, hasher>::value, - ::robin_hood::detail::identity_hash, - ::robin_hood::hash>::type; + auto h = static_cast(WHash::operator()(key)); + + h *= mHashMultiplier; + h ^= h >> 33U; // the lower InitialInfoNumBits are reserved for info. - auto h = Mix{}(WHash::operator()(key)); *info = mInfoInc + static_cast((h & InfoMask) >> mInfoHashShift); - *idx = (h >> InitialInfoNumBits) & mMask; + *idx = (static_cast(h) >> InitialInfoNumBits) & mMask; } // forwards the index by one, wrapping around at the end @@ -1707,12 +1446,12 @@ class Table } // inserts a keyval that is guaranteed to be new, e.g. when the hashmap is resized. - // @return index where the element was created - size_t insert_move(Node&& keyval) { + // @return True on success, false if something went wrong + void insert_move(Node&& keyval) { // we don't retry, fail if overflowing // don't need to check max num elements if (0 == mMaxNumElementsAllowed && !try_increase_info()) { - throwOverflowError(); // impossible to reach LCOV_EXCL_LINE + throwOverflowError(); } size_t idx{}; @@ -1749,20 +1488,25 @@ class Table mInfo[insertion_idx] = insertion_info; ++mNumElements; - return insertion_idx; } public: using iterator = Iter; using const_iterator = Iter; + Table() noexcept(noexcept(Hash()) && noexcept(KeyEqual())) + : WHash() + , WKeyEqual() { + ROBIN_HOOD_TRACE(this) + } + // Creates an empty hash map. Nothing is allocated yet, this happens at the first insert. // This tremendously speeds up ctor & dtor of a map that never receives an element. The // penalty is payed at the first insert, and not before. Lookup of this empty map works // because everybody points to DummyInfoByte::b. parameter bucket_count is dictated by the // standard, but we can ignore it. 
explicit Table( - size_t ROBIN_HOOD_UNUSED(bucket_count) /*unused*/ = 0, const Hash& h = Hash{}, + size_t ROBIN_HOOD_UNUSED(bucket_count) /*unused*/, const Hash& h = Hash{}, const KeyEqual& equal = KeyEqual{}) noexcept(noexcept(Hash(h)) && noexcept(KeyEqual(equal))) : WHash(h) , WKeyEqual(equal) { @@ -1793,6 +1537,7 @@ class Table , DataPool(std::move(static_cast(o))) { ROBIN_HOOD_TRACE(this) if (o.mMask) { + mHashMultiplier = std::move(o.mHashMultiplier); mKeyVals = std::move(o.mKeyVals); mInfo = std::move(o.mInfo); mNumElements = std::move(o.mNumElements); @@ -1811,6 +1556,7 @@ class Table if (o.mMask) { // only move stuff if the other map actually has some data destroy(); + mHashMultiplier = std::move(o.mHashMultiplier); mKeyVals = std::move(o.mKeyVals); mInfo = std::move(o.mInfo); mNumElements = std::move(o.mNumElements); @@ -1842,8 +1588,13 @@ class Table // elements and insert them, but copying is probably faster. auto const numElementsWithBuffer = calcNumElementsWithBuffer(o.mMask + 1); - mKeyVals = static_cast(detail::assertNotNull( - std::malloc(calcNumBytesTotal(numElementsWithBuffer)))); + auto const numBytesTotal = calcNumBytesTotal(numElementsWithBuffer); + + ROBIN_HOOD_LOG("std::malloc " << numBytesTotal << " = calcNumBytesTotal(" + << numElementsWithBuffer << ")") + mHashMultiplier = o.mHashMultiplier; + mKeyVals = static_cast( + detail::assertNotNull(std::malloc(numBytesTotal))); // no need for calloc because clonData does memcpy mInfo = reinterpret_cast(mKeyVals + numElementsWithBuffer); mNumElements = o.mNumElements; @@ -1891,12 +1642,16 @@ class Table // no luck: we don't have the same array size allocated, so we need to realloc. if (0 != mMask) { // only deallocate if we actually have data! + ROBIN_HOOD_LOG("std::free") std::free(mKeyVals); } auto const numElementsWithBuffer = calcNumElementsWithBuffer(o.mMask + 1); - mKeyVals = static_cast(detail::assertNotNull( - std::malloc(calcNumBytesTotal(numElementsWithBuffer)))); + auto const numBytesTotal = calcNumBytesTotal(numElementsWithBuffer); + ROBIN_HOOD_LOG("std::malloc " << numBytesTotal << " = calcNumBytesTotal(" + << numElementsWithBuffer << ")") + mKeyVals = static_cast( + detail::assertNotNull(std::malloc(numBytesTotal))); // no need for calloc here because cloneData performs a memcpy. 
mInfo = reinterpret_cast(mKeyVals + numElementsWithBuffer); @@ -1905,6 +1660,7 @@ class Table WHash::operator=(static_cast(o)); WKeyEqual::operator=(static_cast(o)); DataPool::operator=(static_cast(o)); + mHashMultiplier = o.mHashMultiplier; mNumElements = o.mNumElements; mMask = o.mMask; mMaxNumElementsAllowed = o.mMaxNumElementsAllowed; @@ -1972,13 +1728,54 @@ class Table template typename std::enable_if::value, Q&>::type operator[](const key_type& key) { ROBIN_HOOD_TRACE(this) - return doCreateByKey(key); + auto idxAndState = insertKeyPrepareEmptySpot(key); + switch (idxAndState.second) { + case InsertionState::key_found: + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) + Node(*this, std::piecewise_construct, std::forward_as_tuple(key), + std::forward_as_tuple()); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = Node(*this, std::piecewise_construct, + std::forward_as_tuple(key), std::forward_as_tuple()); + break; + + case InsertionState::overflow_error: + throwOverflowError(); + } + + return mKeyVals[idxAndState.first].getSecond(); } template typename std::enable_if::value, Q&>::type operator[](key_type&& key) { ROBIN_HOOD_TRACE(this) - return doCreateByKey(std::move(key)); + auto idxAndState = insertKeyPrepareEmptySpot(key); + switch (idxAndState.second) { + case InsertionState::key_found: + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) + Node(*this, std::piecewise_construct, std::forward_as_tuple(std::move(key)), + std::forward_as_tuple()); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = + Node(*this, std::piecewise_construct, std::forward_as_tuple(std::move(key)), + std::forward_as_tuple()); + break; + + case InsertionState::overflow_error: + throwOverflowError(); + } + + return mKeyVals[idxAndState.first].getSecond(); } template @@ -1989,17 +1786,38 @@ class Table } } + void insert(std::initializer_list ilist) { + for (auto&& vt : ilist) { + insert(std::move(vt)); + } + } + template std::pair emplace(Args&&... 
args) { ROBIN_HOOD_TRACE(this) Node n{*this, std::forward(args)...}; - auto r = doInsert(std::move(n)); - if (!r.second) { - // insertion not possible: destroy node - // NOLINTNEXTLINE(bugprone-use-after-move) + auto idxAndState = insertKeyPrepareEmptySpot(getFirstConst(n)); + switch (idxAndState.second) { + case InsertionState::key_found: + n.destroy(*this); + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) Node(*this, std::move(n)); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = std::move(n); + break; + + case InsertionState::overflow_error: n.destroy(*this); + throwOverflowError(); + break; } - return r; + + return std::make_pair(iterator(mKeyVals + idxAndState.first, mInfo + idxAndState.first), + InsertionState::key_found != idxAndState.second); } template @@ -2027,34 +1845,34 @@ class Table template std::pair insert_or_assign(const key_type& key, Mapped&& obj) { - return insert_or_assign_impl(key, std::forward(obj)); + return insertOrAssignImpl(key, std::forward(obj)); } template std::pair insert_or_assign(key_type&& key, Mapped&& obj) { - return insert_or_assign_impl(std::move(key), std::forward(obj)); + return insertOrAssignImpl(std::move(key), std::forward(obj)); } template std::pair insert_or_assign(const_iterator hint, const key_type& key, Mapped&& obj) { (void)hint; - return insert_or_assign_impl(key, std::forward(obj)); + return insertOrAssignImpl(key, std::forward(obj)); } template std::pair insert_or_assign(const_iterator hint, key_type&& key, Mapped&& obj) { (void)hint; - return insert_or_assign_impl(std::move(key), std::forward(obj)); + return insertOrAssignImpl(std::move(key), std::forward(obj)); } std::pair insert(const value_type& keyval) { ROBIN_HOOD_TRACE(this) - return doInsert(keyval); + return emplace(keyval); } std::pair insert(value_type&& keyval) { - return doInsert(std::move(keyval)); + return emplace(std::move(keyval)); } // Returns 1 if key is found, 0 otherwise. @@ -2238,23 +2056,36 @@ class Table // reserves space for the specified number of elements. Makes sure the old data fits. // exactly the same as reserve(c). void rehash(size_t c) { - reserve(c); + // forces a reserve + reserve(c, true); } // reserves space for the specified number of elements. Makes sure the old data fits. - // Exactly the same as resize(c). Use resize(0) to shrink to fit. + // Exactly the same as rehash(c). Use rehash(0) to shrink to fit. void reserve(size_t c) { + // reserve, but don't force rehash + reserve(c, false); + } + + // If possible reallocates the map to a smaller one. This frees the underlying table. + // Does not do anything if load_factor is too large for decreasing the table's size. + void compact() { ROBIN_HOOD_TRACE(this) - auto const minElementsAllowed = (std::max)(c, mNumElements); auto newSize = InitialNumElements; - while (calcMaxNumElementsAllowed(newSize) < minElementsAllowed && newSize != 0) { + while (calcMaxNumElementsAllowed(newSize) < mNumElements && newSize != 0) { newSize *= 2; } if (ROBIN_HOOD_UNLIKELY(newSize == 0)) { throwOverflowError(); } - rehashPowerOfTwo(newSize); + ROBIN_HOOD_LOG("newSize > mMask + 1: " << newSize << " > " << mMask << " + 1") + + // only actually do anything when the new size is bigger than the old one. This prevents to + // continuously allocate for each reserve() call. 
+ if (newSize < mMask + 1) { + rehashPowerOfTwo(newSize, true); + } } size_type size() const noexcept { // NOLINT(modernize-use-nodiscard) @@ -2345,9 +2176,30 @@ class Table return find(e) != end(); } + void reserve(size_t c, bool forceRehash) { + ROBIN_HOOD_TRACE(this) + auto const minElementsAllowed = (std::max)(c, mNumElements); + auto newSize = InitialNumElements; + while (calcMaxNumElementsAllowed(newSize) < minElementsAllowed && newSize != 0) { + newSize *= 2; + } + if (ROBIN_HOOD_UNLIKELY(newSize == 0)) { + throwOverflowError(); + } + + ROBIN_HOOD_LOG("newSize > mMask + 1: " << newSize << " > " << mMask << " + 1") + + // only actually do anything when the new size is bigger than the old one. This prevents to + // continuously allocate for each reserve() call. + if (forceRehash || newSize > mMask + 1) { + rehashPowerOfTwo(newSize, false); + } + } + // reserves space for at least the specified number of elements. // only works if numBuckets if power of two - void rehashPowerOfTwo(size_t numBuckets) { + // True on success, false otherwise + void rehashPowerOfTwo(size_t numBuckets, bool forceFree) { ROBIN_HOOD_TRACE(this) Node* const oldKeyVals = mKeyVals; @@ -2356,18 +2208,29 @@ class Table const size_t oldMaxElementsWithBuffer = calcNumElementsWithBuffer(mMask + 1); // resize operation: move stuff - init_data(numBuckets); + initData(numBuckets); if (oldMaxElementsWithBuffer > 1) { for (size_t i = 0; i < oldMaxElementsWithBuffer; ++i) { if (oldInfo[i] != 0) { + // might throw an exception, which is really bad since we are in the middle of + // moving stuff. insert_move(std::move(oldKeyVals[i])); // destroy the node but DON'T destroy the data. oldKeyVals[i].~Node(); } } - // don't destroy old data: put it into the pool instead - DataPool::addOrFree(oldKeyVals, calcNumBytesTotal(oldMaxElementsWithBuffer)); + // this check is not necessary as it's guarded by the previous if, but it helps + // silence g++'s overeager "attempt to free a non-heap object 'map' + // [-Werror=free-nonheap-object]" warning. + if (oldKeyVals != reinterpret_cast_no_cast_align_warning(&mMask)) { + // don't destroy old data: put it into the pool instead + if (forceFree) { + std::free(oldKeyVals); + } else { + DataPool::addOrFree(oldKeyVals, calcNumBytesTotal(oldMaxElementsWithBuffer)); + } + } } } @@ -2382,27 +2245,63 @@ class Table template std::pair try_emplace_impl(OtherKey&& key, Args&&... 
args) { ROBIN_HOOD_TRACE(this) - auto it = find(key); - if (it == end()) { - return emplace(std::piecewise_construct, - std::forward_as_tuple(std::forward(key)), - std::forward_as_tuple(std::forward(args)...)); + auto idxAndState = insertKeyPrepareEmptySpot(key); + switch (idxAndState.second) { + case InsertionState::key_found: + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) Node( + *this, std::piecewise_construct, std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(args)...)); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = Node(*this, std::piecewise_construct, + std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(args)...)); + break; + + case InsertionState::overflow_error: + throwOverflowError(); + break; } - return {it, false}; + + return std::make_pair(iterator(mKeyVals + idxAndState.first, mInfo + idxAndState.first), + InsertionState::key_found != idxAndState.second); } template - std::pair insert_or_assign_impl(OtherKey&& key, Mapped&& obj) { + std::pair insertOrAssignImpl(OtherKey&& key, Mapped&& obj) { ROBIN_HOOD_TRACE(this) - auto it = find(key); - if (it == end()) { - return emplace(std::forward(key), std::forward(obj)); + auto idxAndState = insertKeyPrepareEmptySpot(key); + switch (idxAndState.second) { + case InsertionState::key_found: + mKeyVals[idxAndState.first].getSecond() = std::forward(obj); + break; + + case InsertionState::new_node: + ::new (static_cast(&mKeyVals[idxAndState.first])) Node( + *this, std::piecewise_construct, std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(obj))); + break; + + case InsertionState::overwrite_node: + mKeyVals[idxAndState.first] = Node(*this, std::piecewise_construct, + std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(obj))); + break; + + case InsertionState::overflow_error: + throwOverflowError(); + break; } - it->second = std::forward(obj); - return {it, false}; + + return std::make_pair(iterator(mKeyVals + idxAndState.first, mInfo + idxAndState.first), + InsertionState::key_found != idxAndState.second); } - void init_data(size_t max_elements) { + void initData(size_t max_elements) { mNumElements = 0; mMask = max_elements - 1; mMaxNumElementsAllowed = calcMaxNumElementsAllowed(max_elements); @@ -2410,8 +2309,11 @@ class Table auto const numElementsWithBuffer = calcNumElementsWithBuffer(max_elements); // calloc also zeroes everything - mKeyVals = reinterpret_cast(detail::assertNotNull( - std::calloc(1, calcNumBytesTotal(numElementsWithBuffer)))); + auto const numBytesTotal = calcNumBytesTotal(numElementsWithBuffer); + ROBIN_HOOD_LOG("std::calloc " << numBytesTotal << " = calcNumBytesTotal(" + << numElementsWithBuffer << ")") + mKeyVals = reinterpret_cast( + detail::assertNotNull(std::calloc(1, numBytesTotal))); mInfo = reinterpret_cast(mKeyVals + numElementsWithBuffer); // set sentinel @@ -2421,86 +2323,34 @@ class Table mInfoHashShift = InitialInfoHashShift; } - template - typename std::enable_if::value, Q&>::type doCreateByKey(Arg&& key) { - while (true) { - size_t idx{}; - InfoType info{}; - keyToIdx(key, &idx, &info); - nextWhileLess(&info, &idx); + enum class InsertionState { overflow_error, key_found, new_node, overwrite_node }; - // while we potentially have a match. 
-            // 0 we don't want to skip forward
-            while (info == mInfo[idx]) {
-                if (WKeyEqual::operator()(key, mKeyVals[idx].getFirst())) {
-                    // key already exists, do not insert.
-                    return mKeyVals[idx].getSecond();
-                }
-                next(&info, &idx);
-            }
-
-            // unlikely that this evaluates to true
-            if (ROBIN_HOOD_UNLIKELY(mNumElements >= mMaxNumElementsAllowed)) {
-                increase_size();
-                continue;
-            }
-
-            // key not found, so we are now exactly where we want to insert it.
-            auto const insertion_idx = idx;
-            auto const insertion_info = info;
-            if (ROBIN_HOOD_UNLIKELY(insertion_info + mInfoInc > 0xFF)) {
-                mMaxNumElementsAllowed = 0;
-            }
-
-            // find an empty spot
-            while (0 != mInfo[idx]) {
-                next(&info, &idx);
-            }
-
-            auto& l = mKeyVals[insertion_idx];
-            if (idx == insertion_idx) {
-                // put at empty spot. This forwards all arguments into the node where the object
-                // is constructed exactly where it is needed.
-                ::new (static_cast<void*>(&l))
-                    Node(*this, std::piecewise_construct,
-                         std::forward_as_tuple(std::forward<Arg>(key)), std::forward_as_tuple());
-            } else {
-                shiftUp(idx, insertion_idx);
-                l = Node(*this, std::piecewise_construct,
-                         std::forward_as_tuple(std::forward<Arg>(key)), std::forward_as_tuple());
-            }
-
-            // mKeyVals[idx].getFirst() = std::move(key);
-            mInfo[insertion_idx] = static_cast<uint8_t>(insertion_info);
-
-            ++mNumElements;
-            return mKeyVals[insertion_idx].getSecond();
-        }
-    }
-
-    // This is exactly the same code as operator[], except for the return values
-    template <typename Arg>
-    std::pair<iterator, bool> doInsert(Arg&& keyval) {
-        while (true) {
+    // Finds key, and if not already present prepares a spot where to pot the key & value.
+    // This potentially shifts nodes out of the way, updates mInfo and number of inserted
+    // elements, so the only operation left to do is create/assign a new node at that spot.
+    template <typename OtherKey>
+    std::pair<size_t, InsertionState> insertKeyPrepareEmptySpot(OtherKey&& key) {
+        for (int i = 0; i < 256; ++i) {
             size_t idx{};
             InfoType info{};
-            keyToIdx(getFirstConst(keyval), &idx, &info);
+            keyToIdx(key, &idx, &info);
             nextWhileLess(&info, &idx);
 
             // while we potentially have a match
             while (info == mInfo[idx]) {
-                if (WKeyEqual::operator()(getFirstConst(keyval), mKeyVals[idx].getFirst())) {
+                if (WKeyEqual::operator()(key, mKeyVals[idx].getFirst())) {
                     // key already exists, do NOT insert.
                     // see http://en.cppreference.com/w/cpp/container/unordered_map/insert
-                    return std::make_pair(iterator(mKeyVals + idx, mInfo + idx),
-                                          false);
+                    return std::make_pair(idx, InsertionState::key_found);
                 }
                 next(&info, &idx);
             }
 
             // unlikely that this evaluates to true
             if (ROBIN_HOOD_UNLIKELY(mNumElements >= mMaxNumElementsAllowed)) {
-                increase_size();
+                if (!increase_size()) {
+                    return std::make_pair(size_t(0), InsertionState::overflow_error);
+                }
                 continue;
             }
 
@@ -2516,20 +2366,19 @@ class Table
                 next(&info, &idx);
             }
 
-            auto& l = mKeyVals[insertion_idx];
-            if (idx == insertion_idx) {
-                ::new (static_cast<void*>(&l)) Node(*this, std::forward<Arg>(keyval));
-            } else {
+            if (idx != insertion_idx) {
                 shiftUp(idx, insertion_idx);
-                l = Node(*this, std::forward<Arg>(keyval));
             }
-
             // put at empty spot
             mInfo[insertion_idx] = static_cast<uint8_t>(insertion_info);
-
             ++mNumElements;
-            return std::make_pair(iterator(mKeyVals + insertion_idx, mInfo + insertion_idx), true);
+            return std::make_pair(insertion_idx, idx == insertion_idx
+                                                     ? InsertionState::new_node
+                                                     : InsertionState::overwrite_node);
         }
+
+        // enough attempts failed, so finally give up.
+        return std::make_pair(size_t(0), InsertionState::overflow_error);
     }
 
     bool try_increase_info() {
@@ -2560,28 +2409,41 @@ class Table
         return true;
     }
 
-    void increase_size() {
+    // True if resize was possible, false otherwise
+    bool increase_size() {
         // nothing allocated yet? just allocate InitialNumElements
         if (0 == mMask) {
-            init_data(InitialNumElements);
-            return;
+            initData(InitialNumElements);
+            return true;
         }
 
         auto const maxNumElementsAllowed = calcMaxNumElementsAllowed(mMask + 1);
         if (mNumElements < maxNumElementsAllowed && try_increase_info()) {
-            return;
+            return true;
         }
 
         ROBIN_HOOD_LOG("mNumElements=" << mNumElements << ", maxNumElementsAllowed="
                                        << maxNumElementsAllowed << ", load="
                                        << (static_cast<double>(mNumElements) * 100.0 /
                                            (static_cast<double>(mMask) + 1)))
-        // it seems we have a really bad hash function! don't try to resize again
+
         if (mNumElements * 2 < calcMaxNumElementsAllowed(mMask + 1)) {
-            throwOverflowError();
+            // we have to resize, even though there would still be plenty of space left!
+            // Try to rehash instead. Delete freed memory so we don't steadyily increase mem in case
+            // we have to rehash a few times
+            nextHashMultiplier();
+            rehashPowerOfTwo(mMask + 1, true);
+        } else {
+            // we've reached the capacity of the map, so the hash seems to work nice. Keep using it.
+            rehashPowerOfTwo((mMask + 1) * 2, false);
         }
+        return true;
+    }
 
-        rehashPowerOfTwo((mMask + 1) * 2);
+    void nextHashMultiplier() {
+        // adding an *even* number, so that the multiplier will always stay odd. This is necessary
+        // so that the hash stays a mixing function (and thus doesn't have any information loss).
+        mHashMultiplier += UINT64_C(0xc4ceb9fe1a85ec54);
     }
 
     void destroy() {
@@ -2595,9 +2457,10 @@ class Table
         // This protection against not deleting mMask shouldn't be needed as it's sufficiently
        // protected with the 0==mMask check, but I have this anyways because g++ 7 otherwise
-        // reports a compile error: attempt to free a non-heap object ‘fm’
+        // reports a compile error: attempt to free a non-heap object 'fm'
         // [-Werror=free-nonheap-object]
         if (mKeyVals != reinterpret_cast_no_cast_align_warning<Node*>(&mMask)) {
+            ROBIN_HOOD_LOG("std::free")
             std::free(mKeyVals);
         }
     }
@@ -2613,13 +2476,14 @@ class Table
     }
 
     // members are sorted so no padding occurs
-    Node* mKeyVals = reinterpret_cast_no_cast_align_warning<Node*>(&mMask); // 8 byte  8
-    uint8_t* mInfo = reinterpret_cast<uint8_t*>(&mMask);                    // 8 byte 16
-    size_t mNumElements = 0;                                                // 8 byte 24
-    size_t mMask = 0;                                                       // 8 byte 32
-    size_t mMaxNumElementsAllowed = 0;                                      // 8 byte 40
-    InfoType mInfoInc = InitialInfoInc;                                     // 4 byte 44
-    InfoType mInfoHashShift = InitialInfoHashShift;                         // 4 byte 48
+    uint64_t mHashMultiplier = UINT64_C(0xc4ceb9fe1a85ec53);                // 8 byte  8
+    Node* mKeyVals = reinterpret_cast_no_cast_align_warning<Node*>(&mMask); // 8 byte 16
+    uint8_t* mInfo = reinterpret_cast<uint8_t*>(&mMask);                    // 8 byte 24
+    size_t mNumElements = 0;                                                // 8 byte 32
+    size_t mMask = 0;                                                       // 8 byte 40
+    size_t mMaxNumElementsAllowed = 0;                                      // 8 byte 48
+    InfoType mInfoInc = InitialInfoInc;                                     // 4 byte 52
+    InfoType mInfoHashShift = InitialInfoHashShift;                         // 4 byte 56
                                                                             // 16 byte 56 if NodeAllocator
 };
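
The refactor above is internal to the table: `try_emplace_impl` and `insertOrAssignImpl` now route through `insertKeyPrepareEmptySpot` and the `InsertionState` enum, while the user-facing `robin_hood::unordered_map` keeps its std-style `(iterator, bool)` return values. A minimal sketch of that unchanged behavior, assuming the vendored `robin_hood.h` is on the include path (illustration only, not part of this patch):

```cpp
#include <cstdio>
#include <string>

#include "robin_hood.h" // the single-header library updated by this patch

int main() {
    robin_hood::unordered_map<std::string, int> counts;
    counts.reserve(1024); // pre-allocates; later inserts only rehash once the table fills up

    // try_emplace returns (iterator, bool); bool is true when a new node was created.
    auto inserted = counts.try_emplace("pecos", 1);
    std::printf("new node: %d\n", inserted.second ? 1 : 0);

    // insert_or_assign overwrites the mapped value when the key is already present.
    auto assigned = counts.insert_or_assign("pecos", 2);
    std::printf("new node: %d, value: %d\n", assigned.second ? 1 : 0, assigned.first->second);
    return 0;
}
```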
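`nextHashMultiplier()` leans on the constants shown above: `mHashMultiplier` starts at the odd value `0xc4ceb9fe1a85ec53` and each bad-hash rehash adds the even value `0xc4ceb9fe1a85ec54`, so the multiplier always stays odd; an odd multiplier is invertible modulo 2^64, which is the "no information loss" property the comment mentions. A tiny standalone check of that invariant (a sketch for illustration, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // Constants as they appear in the patch above.
    std::uint64_t multiplier = UINT64_C(0xc4ceb9fe1a85ec53); // initial mHashMultiplier (odd)
    for (int i = 0; i < 16; ++i) {
        assert(multiplier % 2 == 1);                // odd + even stays odd
        multiplier += UINT64_C(0xc4ceb9fe1a85ec54); // what nextHashMultiplier() adds
    }
    return 0;
}
```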