Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster, better feature count estimates for intPKs #467

Merged
merged 1 commit into from
Aug 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ _When adding new entries to the changelog, please include issue/PR numbers where

* Bugfix: Set GDAL and PROJ environment variables on startup, which fixes an issue where Kart may or may not work properly depending on whether GDAL and PROJ are appropriately configured in the user's environment
* Bugfix: `kart restore` now simply discards all working copy changes, as it is intended to - previously it would complain if there were "structural" schema differences between the working copy and HEAD.
* Feature-count estimates are now more accurate and generally also faster [#467](https://github.com/koordinates/kart/issues/467)

## 0.10.2

Expand Down
253 changes: 192 additions & 61 deletions kart/dataset3_paths.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import functools
import logging
import math
import random
from collections import defaultdict

import pygit2

from .exceptions import NotYetImplemented
from .serialise_util import (
Expand All @@ -10,6 +15,7 @@
)
from .utils import chunk

L = logging.getLogger("kart.dataset3_paths")

_LOWERCASE_HEX_ALPHABET = "0123456789abcdef"

Expand Down Expand Up @@ -171,26 +177,17 @@ def tree_names(self):
for i in range(self.branches):
yield self._single_tree_int_encoder.encode_int(i)

def sample_subtrees(self, num_trees, *, max_tree_id=None):
def _nonrecursive_diff(self, tree_a, tree_b):
"""
Yields a sample set of outermost trees such as might contain features for sampling,
for feature count estimation. Eg: ["A/A/D/E", "A/A/H/J", ...]
Yields num_trees trees. All returned trees will be *before* the max_tree_id supplied -
useful if primary keys are clustered at the low end of the tree structure (see IntPathEncoder).
Returns a dict mapping names to OIDs which differ between the trees.
(either the key is present in both, and the OID is different,
or the key is only present in one of the trees)
"""
if max_tree_id is None:
total_subtrees = self.max_trees
else:
total_subtrees = max_tree_id
if num_trees >= total_subtrees:
yield SAMPLE_ALL_TREES
return
a = {obj.name: obj for obj in tree_a} if tree_a else {}
b = {obj.name: obj for obj in tree_b} if tree_b else {}
all_names = sorted(list(set(a.keys() | b.keys())))

stride = total_subtrees / num_trees
assert stride > 1
for i in range(num_trees):
tree_idx = round(i * stride)
yield self._path_int_encoder.encode_int(tree_idx)
return {k: (a.get(k), b.get(k)) for k in all_names if a.get(k) != b.get(k)}


class MsgpackHashPathEncoder(PathEncoder):
Expand All @@ -217,14 +214,70 @@ def encode_pks_to_path(self, pk_values):
parts.append(self._encode_file_name_from_packed_pk(packed_pk))
return "/".join(parts)

def sample_subtrees(self, num_trees, *, max_tree_id=None):
def _num_expected_distributed_tree_blobs(self, num_samples, branch_factor):
"""
Returns the expected number of children in a tree of the given size.
"""
# https://docs.google.com/document/d/11CeJKbiNQoLmhDcYIM68cJSA_nKBHW7kYVybh2N-Lww/edit#heading=h.7z95y6hc62gn
return math.log(1 - num_samples / branch_factor) / math.log(
1 - 1 / branch_factor
)

def _recursive_diff_estimate(
self, tree1, tree2, branch_count, total_samples_to_take
):
"""
Yields a sample set of outermost trees such as might contain features for sampling,
for feature count estimation. Eg: ["A/A/D/E", "A/A/H/J", ...]
Yields num_trees trees, max_tree_id is ignored since features are distributed uniformly
and randomly all over the structure.
Samples some subtrees of the given two trees, and returns an estimate of the number
of features that are expected to differ.

Takes advantage of the fact that features are evenly distributed amongst trees
- all we need to do is drill down from the root tree until the number of
children of the inspected subtree is much less than the branch factor.
Then we just sample a few trees at that level, take the average number
of subtrees, and multiply by the appropriate exponent of the branch factor
to get a total number of blobs.
"""
yield from super().sample_subtrees(num_trees, max_tree_id=None)
diff = self._nonrecursive_diff(tree1, tree2)

diff_size = len(diff)
if diff_size < branch_count / 2:
estimated_blobs = self._num_expected_distributed_tree_blobs(
diff_size, branch_count
)
L.debug(
f"Found {diff_size} diffs for an estimate of {estimated_blobs} blobs."
)
return estimated_blobs, 1

L.debug(f"Found {diff_size} diffs, checking next level:")

total_subsample_size = 0
total_subsamples_taken = 0
total_samples_taken = 0
for tree1, tree2 in diff.values():
if isinstance(tree1, pygit2.Blob) or isinstance(tree2, pygit2.Blob):
subsample_size = 1
samples_taken = 1
else:
subsample_size, samples_taken = self._recursive_diff_estimate(
tree1, tree2, branch_count, total_samples_to_take
)
total_subsample_size += subsample_size
total_subsamples_taken += 1
total_samples_taken += samples_taken
if total_samples_taken >= total_samples_to_take:
break

return (
1.0 * diff_size * total_subsample_size / total_subsamples_taken,
total_samples_taken,
)

def diff_estimate(self, tree1, tree2, branch_count, total_samples_to_take):
diff_count, samples_taken = self._recursive_diff_estimate(
tree1, tree2, branch_count, total_samples_to_take
)
return int(round(diff_count))


class IntPathEncoder(PathEncoder):
Expand All @@ -245,52 +298,130 @@ def encode_pks_to_path(self, pk_values):
filename = self.encode_filename(pk_values)
return f"{tree_path}/{filename}"

def _nonrecursive_diff(self, tree_a, tree_b):
"""
Returns a dict mapping names to OIDs which differ between the trees.
(either the key is present in both, and the OID is different,
or the key is only present in one of the trees)
def _recursive_depth_first_diff_estimate(
self, tree1, tree2, *, path, paths_fully_explored, diffs_by_path, rand
):
"""
a = {obj.name: obj for obj in tree_a}
b = {obj.name: obj for obj in tree_b}
all_names = a.keys() | b.keys()
return {k: (a.get(k), b.get(k)) for k in all_names if a.get(k) != b.get(k)}
Dives as deep as possible into the diff for the given trees, returning one
feature-count sample.

def max_tree_id(self, repo, base_feature_tree, target_feature_tree):
"""
Looks at a few trees to determine the maximum integer ID of the trees in the given diff.
Used as an upper bound for feature count sampling.
Dives into a random branch at each level, but without replacement;
any bottom-level branch that has already been sampled will be avoided.

e.g if the only tree is 'A/A/A/A', returns 0
Returns 0 if all branches at the current level have already been sampled.
"""
max_tree_path = self._max_feature_tree_path(
repo, base_feature_tree, target_feature_tree
)
return self._path_int_encoder.decode_int(max_tree_path)

def _max_feature_tree_path(
self, repo, base_feature_tree, target_feature_tree, *, depth=0
try:
diff = diffs_by_path[path]
except KeyError:
diff = self._nonrecursive_diff(tree1, tree2)
diffs_by_path[path] = diff

diff_items = list(diff.items())
rand.shuffle(diff_items)

child1, child2 = diff_items[0][1]
if isinstance(child1, pygit2.Blob) or isinstance(child2, pygit2.Blob):
# we're at the bottom level
paths_fully_explored.add(path)
return len(diff)

for (name, (child1, child2)) in diff_items:
child_path = f"{path}/{name}"
if child_path in paths_fully_explored:
continue
num_features = self._recursive_depth_first_diff_estimate(
child1,
child2,
path=child_path,
paths_fully_explored=paths_fully_explored,
diffs_by_path=diffs_by_path,
rand=rand,
)
if not num_features:
# no (new) features found in a subtree. try another subtree
continue
else:
# by recursing into a subtree, we found some new features.
# return these to the root level as a sample.
return num_features
else:
# this path was already fully sampled by a previous call to this function,
# so we haven't sampled any new trees.
paths_fully_explored.add(path)
return 0

def diff_estimate(
self,
tree1,
tree2,
branch_count,
total_samples_to_take,
):
"""
Returns the path of the tree containing the greatest PK,
relative to the given feature tree. Recurses to self.levels.
"""
base_feature_tree = base_feature_tree or repo.empty_tree
target_feature_tree = target_feature_tree or repo.empty_tree

if base_feature_tree == target_feature_tree:
return None
Samples some subtrees of the given two trees, and returns an estimate of the number
of features that are expected to differ.

diff = self._nonrecursive_diff(base_feature_tree, target_feature_tree)
# Diff is always non-empty since the trees must differ.
max_path = max(diff.keys())
if depth == self.levels - 1:
return max_path
else:
a, b = diff[max_path]
return (
f"{max_path}/{self._max_feature_tree_path(repo, a, b, depth=depth + 1)}"
This is a lot harder than the equivalent method on MsgpackHashPathEncoder
because there is no reliable distribution of features across trees.
"""
# start with a deterministic random state.
# Otherwise we'll get unreproducible results.
rand = random.Random(0)

paths_fully_explored = set()
diffs_by_path = {}
samples = []
for i in range(total_samples_to_take):
num_features = self._recursive_depth_first_diff_estimate(
tree1,
tree2,
path="",
paths_fully_explored=paths_fully_explored,
diffs_by_path=diffs_by_path,
rand=rand,
)
if num_features:
samples.append(num_features)
else:
# we sampled the entire diff, so this will be an exact result
return sum(samples)

num_samples = len(samples)
max_level = max(x.count("/") for x in diffs_by_path.keys())

# keyed by level, the total number of trees encountered
# (we didn't sample all of them, but we know they exist)
# doesn't include the deepest level.
# e.g. {0: 1, 1: 1, 2: 1, 3: 34}
trees_seen_at_level = defaultdict(int)

# keyed by level, the total number of trees we sampled
# (not including the deepest level)
# e.g. {0: 1, 1: 1, 2: 1, 3: 16}
trees_sampled_at_level = defaultdict(int)
for k, v in diffs_by_path.items():
level = k.count("/")
if level < max_level:
trees_seen_at_level[level] += len(v)
if level:
trees_sampled_at_level[level - 1] += 1
assert trees_sampled_at_level[max_level - 1] == num_samples

# now we have sampled a bunch of deepest-level trees to figure out how many
# blobs are in them, we can work backwards. how many trees do we expect to exist?
num_features = sum(samples)
for level in trees_sampled_at_level.keys():
# if we only sampled half of the trees at this level, then multiply the
# current total feature count (FC) by 2.
# Even though trees_seen_at_level only has the *known* trees,
# doing this repeatedly until we get to the root level will still result
# in a total FC estimate which approaches the actual FC.
level_multiplier = (
trees_seen_at_level[level] / trees_sampled_at_level[level]
)
num_features *= level_multiplier

return int(round(num_features))


# The encoder that was previously used for all datasets.
Expand Down
Loading