Fixes dict_hash discrepancy #3195

Merged 3 commits on Mar 3, 2023
4 changes: 3 additions & 1 deletion ludwig/data/cache/util.py
@@ -20,7 +20,9 @@ def calculate_checksum(original_dataset: CacheableDataset, config: ModelConfigDi
"global_defaults": config.get(DEFAULTS, {}),
# PROC_COLUMN contains both the feature name and the feature hash that is computed
# based on each feature's preprocessing parameters and the feature's type.
"feature_proc_columns": {feature[PROC_COLUMN] for feature in features},
# Create a sorted list out of the set because hash_dict requires all values
# of the dict to be ordered objects to ensure it produces the same hash.
"feature_proc_columns": sorted({feature[PROC_COLUMN] for feature in features}),
Collaborator:
I'm wondering if there's a way we can test this. Something that would force the behavior of changing the insertion order of the set.

This looks promising:

By default, the hash() values of str, bytes and datetime objects are “salted” with an unpredictable random value. Although they remain constant within an individual Python process, they are not predictable between repeated invocations of Python.

This is intended to provide protection against a denial-of-service caused by carefully-chosen inputs that exploit the worst case performance of a dict insertion, O(n^2) complexity. See http://www.ocert.org/advisories/ocert-2011-003.html for details.

Changing hash values affects the iteration order of dicts, sets and other mappings. Python has never made guarantees about this ordering (and it typically varies between 32-bit and 64-bit builds).

See also PYTHONHASHSEED.
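A minimal repro sketch along these lines (not part of the PR; the snippet and seed values are illustrative): running the same set comprehension in fresh interpreters with different `PYTHONHASHSEED` values shows that raw set iteration order is process-dependent, while sorting restores a single canonical order, which is exactly what the fix relies on.

```python
# Demonstrate per-process set iteration order via PYTHONHASHSEED.
import ast
import os
import subprocess
import sys

SNIPPET = "print(list({f'feat{i}' for i in range(20)}))"


def set_order(seed: str) -> list:
    # Run the snippet in a fresh interpreter with a fixed hash seed.
    out = subprocess.run(
        [sys.executable, "-c", SNIPPET],
        env={**os.environ, "PYTHONHASHSEED": seed},
        capture_output=True,
        text=True,
        check=True,
    )
    return ast.literal_eval(out.stdout.strip())


orders = [set_order(seed) for seed in ("1", "2", "3")]
# The raw iteration orders typically differ between seeds, but sorting
# collapses them to one canonical order.
print(len({tuple(sorted(o)) for o in orders}))  # 1
```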

Collaborator:

Let me see if I can put a test that exploits this together and add it to this PR.

Collaborator:

Okay, added a test and verified that it repros the issue without this fix and passes with it.

Contributor:

Nice, thanks for looking into the Python docs @tgaddair. From the docs, it makes sense that the ordering of sets would cause problems, so this is a good change!

Also pretty interesting that the salts are scoped to individual processes rather than shared across processes, but it makes sense if you think about it.

Collaborator:

If the salts were the same between processes, then it wouldn't serve its purpose of making the hash function unpredictable (to an attacker). But if they were different within a process, then hash lookups wouldn't work ;). So it makes sense, just never considered that they designed their hash function with that exploit in mind.

"feature_types": [feature[TYPE] for feature in features],
"feature_preprocessing": [feature.get(PREPROCESSING, {}) for feature in features],
}
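To see why the `sorted(...)` call in the hunk above matters for serialization, here is a small sketch (illustrative names; the real code routes values through Ludwig's `NumpyEncoder`, which is assumed here to convert sets to lists in iteration order): the stdlib JSON encoder rejects sets outright, and sorting first yields one canonical string regardless of how the set happens to iterate, so the downstream hash is stable.

```python
import json

cols = {"in1_abc123", "in2_def456", "out1_789aaa"}

try:
    json.dumps({"feature_proc_columns": cols})
except TypeError:
    # Plain json.dumps cannot encode sets; a custom encoder that can
    # would still emit them in process-dependent iteration order.
    pass

# Sorting first yields a single canonical serialization.
canonical = json.dumps({"feature_proc_columns": sorted(cols)}, sort_keys=True)
print(canonical)
```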
6 changes: 6 additions & 0 deletions ludwig/utils/data_utils.py
@@ -394,6 +394,12 @@ def save_json(data_fp, data, sort_keys=True, indent=4):

@DeveloperAPI
def hash_dict(d: dict, max_length: Union[int, None] = 6) -> bytes:
"""Function that maps a dictionary into a unique hash.

Known limitation: all keys and values of the dict must have an ordering. If they do not, the same hash is not
guaranteed. For instance, values that are sets can lead to different hashes when run on different machines or in
different Python sessions. Replacing them with sorted lists is suggested.
"""
s = json.dumps(d, cls=NumpyEncoder, sort_keys=True, ensure_ascii=True)
h = hashlib.md5(s.encode())
d = h.digest()
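The hunk is cut off before the return statement, so as a self-contained sketch of the overall shape (omitting `NumpyEncoder`, and assuming the digest is truncated to `max_length` bytes, since the truncation itself is below the fold): canonical JSON, then MD5, then an optional prefix of the digest.

```python
import hashlib
import json
from typing import Optional


def hash_dict_sketch(d: dict, max_length: Optional[int] = 6) -> bytes:
    # Simplified stand-in for hash_dict: serialize to canonical JSON
    # (sorted keys), hash with MD5, and keep the first max_length bytes.
    s = json.dumps(d, sort_keys=True, ensure_ascii=True)
    digest = hashlib.md5(s.encode()).digest()
    return digest if max_length is None else digest[:max_length]


h = hash_dict_sketch({"feature_proc_columns": ["in1_a", "in2_b"]})
print(len(h))  # 6
```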
27 changes: 27 additions & 0 deletions tests/ludwig/data/test_cache_util.py
@@ -109,3 +109,30 @@ def test_proc_col_checksum_consistency_same_preprocessing_different_types():
config = ModelConfig.from_dict(config)

assert config.input_features[0].proc_column != config.input_features[1].proc_column


@pytest.mark.distributed
def test_checksum_determinism(ray_cluster_2cpu):
"""Tests that checksums are deterministic across different processes (no unordered hash maps)."""
import ray

# Generate a lot of features so the probability of a reordering of feature sets is very high.
config = {
INPUT_FEATURES: [{"name": f"in{i}", "type": "number"} for i in range(100)],
OUTPUT_FEATURES: [{"name": "out1", "type": "binary"}],
}
config = ModelConfig.from_dict(config)

mock_dataset = mock.Mock()
mock_dataset.checksum = uuid.uuid4().hex

@ray.remote(max_calls=1)
def calculate_checksum_remote(dataset, config):
return calculate_checksum(dataset, config)

# Run each checksum calculation as a remote function so it gets its own Python interpreter, as
# the hash function in Python is deterministic within a process, but not between different processes.
# See: https://docs.python.org/3/reference/datamodel.html#object.__hash__
checksum1 = ray.get(calculate_checksum_remote.remote(mock_dataset, config.to_dict()))
checksum2 = ray.get(calculate_checksum_remote.remote(mock_dataset, config.to_dict()))
assert checksum1 == checksum2