Use minimal perfect hashing for lookups #37

Merged: 8 commits, Apr 16, 2019
163 changes: 100 additions & 63 deletions scripts/unicode.py
@@ -18,7 +18,7 @@
# Since this should not require frequent updates, we just store this
# out-of-line and check the unicode.rs file into git.
import collections
import requests
import urllib.request

UNICODE_VERSION = "9.0.0"
UCD_URL = "https://www.unicode.org/Public/%s/ucd/" % UNICODE_VERSION
@@ -68,9 +68,9 @@ def __init__(self):

def stats(name, table):
count = sum(len(v) for v in table.values())
print "%s: %d chars => %d decomposed chars" % (name, len(table), count)
print("%s: %d chars => %d decomposed chars" % (name, len(table), count))

print "Decomposition table stats:"
print("Decomposition table stats:")
stats("Canonical decomp", self.canon_decomp)
stats("Compatible decomp", self.compat_decomp)
stats("Canonical fully decomp", self.canon_fully_decomp)
@@ -79,8 +79,8 @@ def stats(name, table):
self.ss_leading, self.ss_trailing = self._compute_stream_safe_tables()

def _fetch(self, filename):
resp = requests.get(UCD_URL + filename)
return resp.text
resp = urllib.request.urlopen(UCD_URL + filename)
return resp.read().decode('utf-8')

def _load_unicode_data(self):
self.combining_classes = {}
@@ -234,7 +234,7 @@ def _decompose(char_int, compatible):
# need to store their overlap when they agree. When they don't agree,
# store the decomposition in the compatibility table since we'll check
# that first when normalizing to NFKD.
assert canon_fully_decomp <= compat_fully_decomp
assert set(canon_fully_decomp) <= set(compat_fully_decomp)

for ch in set(canon_fully_decomp) & set(compat_fully_decomp):
if canon_fully_decomp[ch] == compat_fully_decomp[ch]:
@@ -284,27 +284,37 @@ def _compute_stream_safe_tables(self):

return leading_nonstarters, trailing_nonstarters

hexify = lambda c: hex(c)[2:].upper().rjust(4, '0')
hexify = lambda c: '{:04X}'.format(c)

def gen_combining_class(combining_classes, out):
out.write("#[inline]\n")
out.write("pub fn canonical_combining_class(c: char) -> u8 {\n")
out.write(" match c {\n")

for char, combining_class in sorted(combining_classes.items()):
out.write(" '\u{%s}' => %s,\n" % (hexify(char), combining_class))
def gen_mph_data(name, d, kv_type, kv_callback):
(salt, keys) = minimal_perfect_hash(d)
out.write("pub(crate) const %s_SALT: &[u16] = &[\n" % name.upper())
for s in salt:
out.write(" 0x{:x},\n".format(s))
out.write("];\n")
out.write("pub(crate) const {}_KV: &[{}] = &[\n".format(name.upper(), kv_type))
for k in keys:
out.write(" {},\n".format(kv_callback(k)))
out.write("];\n\n")

out.write(" _ => 0,\n")
out.write(" }\n")
out.write("}\n")
def gen_combining_class(combining_classes, out):
gen_mph_data('canonical_combining_class', combining_classes, 'u32',
lambda k: "0x{:X}".format(int(combining_classes[k]) | (k << 8)))
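The call above packs each entry's codepoint key and 8-bit combining class into one u32, with the key in the high bits and the value in the low byte (the u8_lookup_fk / u8_lookup_fv helpers in src/lookups.rs below unpack it). A small Python illustration of that layout, using a hypothetical codepoint:

```python
# Hypothetical entry: U+0301 COMBINING ACUTE ACCENT has canonical combining class 230.
key, value = 0x0301, 230
packed = value | (key << 8)    # same expression as the kv_callback above
assert packed == 0x301E6
assert packed >> 8 == key      # what u8_lookup_fk recovers
assert packed & 0xFF == value  # what u8_lookup_fv recovers
```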

def gen_composition_table(canon_comp, out):
out.write("#[inline]\n")
out.write("pub fn composition_table(c1: char, c2: char) -> Option<char> {\n")
table = {}
for (c1, c2), c3 in canon_comp.items():
if c1 < 0x10000 and c2 < 0x10000:
table[(c1 << 16) | c2] = c3
(salt, keys) = minimal_perfect_hash(table)
gen_mph_data('COMPOSITION_TABLE', table, '(u32, char)',
lambda k: "(0x%s, '\\u{%s}')" % (hexify(k), hexify(table[k])))

out.write("pub(crate) fn composition_table_astral(c1: char, c2: char) -> Option<char> {\n")
out.write(" match (c1, c2) {\n")

for (c1, c2), c3 in sorted(canon_comp.items()):
out.write(" ('\u{%s}', '\u{%s}') => Some('\u{%s}'),\n" % (hexify(c1), hexify(c2), hexify(c3)))
if c1 >= 0x10000 and c2 >= 0x10000:
out.write(" ('\\u{%s}', '\\u{%s}') => Some('\\u{%s}'),\n" % (hexify(c1), hexify(c2), hexify(c3)))

out.write(" _ => None,\n")
out.write(" }\n")
@@ -313,23 +323,9 @@ def gen_composition_table(canon_comp, out):
def gen_decomposition_tables(canon_decomp, compat_decomp, out):
tables = [(canon_decomp, 'canonical'), (compat_decomp, 'compatibility')]
for table, name in tables:
out.write("#[inline]\n")
out.write("pub fn %s_fully_decomposed(c: char) -> Option<&'static [char]> {\n" % name)
# The "Some" constructor is around the match statement here, because
# putting it into the individual arms would make the item_bodies
# checking of rustc takes almost twice as long, and it's already pretty
# slow because of the huge number of match arms and the fact that there
# is a borrow inside each arm
out.write(" Some(match c {\n")

for char, chars in sorted(table.items()):
d = ", ".join("'\u{%s}'" % hexify(c) for c in chars)
out.write(" '\u{%s}' => &[%s],\n" % (hexify(char), d))

out.write(" _ => return None,\n")
out.write(" })\n")
out.write("}\n")
out.write("\n")
gen_mph_data(name + '_decomposed', table, "(u32, &'static [char])",
lambda k: "(0x{:x}, &[{}])".format(k,
", ".join("'\\u{%s}'" % hexify(c) for c in table[k])))

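For a concrete sense of what the generated decomposition entries look like, here is a hypothetical line that the callback above would emit into CANONICAL_DECOMPOSED_KV for U+00C1, whose canonical decomposition is U+0041 U+0301:

```python
# Hypothetical generated entry for U+00C1 -> U+0041 U+0301.
entry = "(0x{:x}, &[{}])".format(
    0xC1, ", ".join("'\\u{%s}'" % '{:04X}'.format(c) for c in (0x41, 0x301)))
assert entry == "(0xc1, &['\\u{0041}', '\\u{0301}'])"
```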
def gen_qc_match(prop_table, out):
out.write(" match c {\n")
@@ -371,40 +367,25 @@ def gen_nfkd_qc(prop_tables, out):
out.write("}\n")

def gen_combining_mark(general_category_mark, out):
out.write("#[inline]\n")
out.write("pub fn is_combining_mark(c: char) -> bool {\n")
out.write(" match c {\n")

for char in general_category_mark:
out.write(" '\u{%s}' => true,\n" % hexify(char))

out.write(" _ => false,\n")
out.write(" }\n")
out.write("}\n")
gen_mph_data('combining_mark', general_category_mark, 'u32',
lambda k: '0x{:04x}'.format(k))

def gen_stream_safe(leading, trailing, out):
# This could be done as a hash but the table is very small.
out.write("#[inline]\n")
out.write("pub fn stream_safe_leading_nonstarters(c: char) -> usize {\n")
out.write(" match c {\n")

for char, num_leading in leading.items():
out.write(" '\u{%s}' => %d,\n" % (hexify(char), num_leading))
for char, num_leading in sorted(leading.items()):
out.write(" '\\u{%s}' => %d,\n" % (hexify(char), num_leading))

out.write(" _ => 0,\n")
out.write(" }\n")
out.write("}\n")
out.write("\n")

out.write("#[inline]\n")
out.write("pub fn stream_safe_trailing_nonstarters(c: char) -> usize {\n")
out.write(" match c {\n")

for char, num_trailing in trailing.items():
out.write(" '\u{%s}' => %d,\n" % (hexify(char), num_trailing))

out.write(" _ => 0,\n")
out.write(" }\n")
out.write("}\n")
gen_mph_data('trailing_nonstarters', trailing, 'u32',
lambda k: "0x{:X}".format(int(trailing[k]) | (k << 8)))

def gen_tests(tests, out):
out.write("""#[derive(Debug)]
@@ -419,7 +400,7 @@ def gen_tests(tests, out):
""")

out.write("pub const NORMALIZATION_TESTS: &[NormalizationTest] = &[\n")
str_literal = lambda s: '"%s"' % "".join("\u{%s}" % c for c in s)
str_literal = lambda s: '"%s"' % "".join("\\u{%s}" % c for c in s)

for test in tests:
out.write(" NormalizationTest {\n")
@@ -432,9 +413,65 @@ def gen_tests(tests, out):

out.write("];\n")

# Guaranteed to be less than n.
def my_hash(x, salt, n):
Member: probably should have a comment saying "guaranteed to be less than n"

# This hash is based on the theory that multiplication is efficient
mask_32 = 0xffffffff
y = ((x + salt) * 2654435769) & mask_32
y ^= (x * 0x31415926) & mask_32
return (y * n) >> 32
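The result is guaranteed to be less than n because y is masked to 32 bits: y * n < n * 2**32, and the final shift divides by 2**32, rounding down. A quick sanity check of that bound with a few hypothetical inputs, using the function above:

```python
# my_hash always yields a valid index into a table of size n.
for x in (0, 1, 0x0301, 0x10FFFF):
    for n in (1, 7, 1000):
        assert 0 <= my_hash(x, 0, n) < n
```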

# Compute minimal perfect hash function, d can be either a dict or list of keys.
def minimal_perfect_hash(d):
n = len(d)
buckets = dict((h, []) for h in range(n))
for key in d:
h = my_hash(key, 0, n)
buckets[h].append(key)
bsorted = [(len(buckets[h]), h) for h in range(n)]
bsorted.sort(reverse = True)
claimed = [False] * n
salts = [0] * n
keys = [0] * n
for (bucket_size, h) in bsorted:
# Note: the traditional perfect hashing approach would also special-case
# bucket_size == 1 here and assign any empty slot, rather than iterating
# until rehash finds an empty slot. But we're not doing that so we can
# avoid the branch.
if bucket_size == 0:
break
else:
for salt in range(1, 32768):
rehashes = [my_hash(key, salt, n) for key in buckets[h]]
# Make sure there are no rehash collisions within this bucket.
if all(not claimed[hash] for hash in rehashes):
Member: Is there a guarantee that we won't have a collision amongst the rehashes? Is it just really unlikely? (I suspect it's the latter but want to confirm.)

Contributor Author: Yes, if it finds a suitable salt, that comes with a guarantee that the rehash won't have a collision (this is what the claimed bool array keeps track of). On the other hand, it's possible that no salt can be found that satisfies that, but I believe that to be quite a low probability. There are things that can be done to make it more robust; I'll try to add a comment outlining them in case somebody does run into it with a data update.

Member: Oh, wait, the set check deals with this; I'd forgotten it was there 😄. To be clear, I was specifically worried about cases where a single run of rehashes has collisions, which claimed won't catch since we update it later. (Worth leaving a comment saying that.)

if len(set(rehashes)) < bucket_size:
continue
salts[h] = salt
for key in buckets[h]:
rehash = my_hash(key, salt, n)
claimed[rehash] = True
keys[rehash] = key
break
if salts[h] == 0:
print("minimal perfect hashing failed")
# Note: if this happens (because of unfortunate data), then there are
# a few things that could be done. First, the hash function could be
# tweaked. Second, the bucket order could be scrambled (especially the
# singletons). Right now, the buckets are sorted, which has the advantage
# of being deterministic.
#
# As a more extreme approach, the singleton bucket optimization could be
# applied (give the direct address for singleton buckets, rather than
# relying on a rehash). That is definitely the more standard approach in
# the minimal perfect hashing literature, but in testing the branch was a
# significant slowdown.
exit(1)
return (salts, keys)
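The lookup side consumes (salts, keys) with the same two-level scheme: hash with salt 0 to pick a bucket, rehash with that bucket's salt to get the final slot, then compare the stored key to confirm membership. This is roughly what the Rust mph_lookup helper (imported by src/lookups.rs from the new perfect_hash module) does; a minimal Python sketch, assuming the tables come from minimal_perfect_hash above:

```python
def mph_lookup_index(x, salts, keys):
    """Return the slot holding key x, or None if x is not in the table."""
    n = len(salts)
    salt = salts[my_hash(x, 0, n)]   # first-level hash picks the bucket's salt
    slot = my_hash(x, salt, n)       # second-level hash picks the slot
    return slot if keys[slot] == x else None

# Hypothetical round trip: every key should land back on its own slot.
# salts, keys = minimal_perfect_hash(some_table)
# assert all(keys[mph_lookup_index(k, salts, keys)] == k for k in some_table)
```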

if __name__ == '__main__':
data = UnicodeData()
with open("tables.rs", "w") as out:
with open("tables.rs", "w", newline = "\n") as out:
out.write(PREAMBLE)
out.write("use quick_check::IsNormalized;\n")
out.write("use quick_check::IsNormalized::*;\n")
@@ -470,6 +507,6 @@ def gen_tests(tests, out):
gen_stream_safe(data.ss_leading, data.ss_trailing, out)
out.write("\n")

with open("normalization_tests.rs", "w") as out:
with open("normalization_tests.rs", "w", newline = "\n") as out:
out.write(PREAMBLE)
gen_tests(data.norm_tests, out)
8 changes: 3 additions & 5 deletions src/lib.rs
@@ -65,7 +65,9 @@ pub use stream_safe::StreamSafe;
use std::str::Chars;

mod decompose;
mod lookups;
mod normalize;
mod perfect_hash;
mod recompose;
mod quick_check;
mod stream_safe;
@@ -80,11 +82,7 @@ mod normalization_tests;
pub mod char {
pub use normalize::{decompose_canonical, decompose_compatible, compose};

/// Look up the canonical combining class of a character.
pub use tables::canonical_combining_class;

/// Return whether the given character is a combining mark (`General_Category=Mark`)
pub use tables::is_combining_mark;
pub use lookups::{canonical_combining_class, is_combining_mark};
}


89 changes: 89 additions & 0 deletions src/lookups.rs
@@ -0,0 +1,89 @@
// Copyright 2019 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Lookups of unicode properties using minimal perfect hashing.

use perfect_hash::mph_lookup;
use tables::*;

/// Look up the canonical combining class for a codepoint.
///
/// The value returned is as defined in the Unicode Character Database.
pub fn canonical_combining_class(c: char) -> u8 {
mph_lookup(c.into(), CANONICAL_COMBINING_CLASS_SALT, CANONICAL_COMBINING_CLASS_KV,
u8_lookup_fk, u8_lookup_fv, 0)
}

pub(crate) fn composition_table(c1: char, c2: char) -> Option<char> {
if c1 < '\u{10000}' && c2 < '\u{10000}' {
mph_lookup((c1 as u32) << 16 | (c2 as u32),
COMPOSITION_TABLE_SALT, COMPOSITION_TABLE_KV,
pair_lookup_fk, pair_lookup_fv_opt, None)
} else {
composition_table_astral(c1, c2)
}
}

pub(crate) fn canonical_fully_decomposed(c: char) -> Option<&'static [char]> {
mph_lookup(c.into(), CANONICAL_DECOMPOSED_SALT, CANONICAL_DECOMPOSED_KV,
pair_lookup_fk, pair_lookup_fv_opt, None)
}

pub(crate) fn compatibility_fully_decomposed(c: char) -> Option<&'static [char]> {
mph_lookup(c.into(), COMPATIBILITY_DECOMPOSED_SALT, COMPATIBILITY_DECOMPOSED_KV,
pair_lookup_fk, pair_lookup_fv_opt, None)
}

/// Return whether the given character is a combining mark (`General_Category=Mark`)
pub fn is_combining_mark(c: char) -> bool {
mph_lookup(c.into(), COMBINING_MARK_SALT, COMBINING_MARK_KV,
bool_lookup_fk, bool_lookup_fv, false)
}

pub fn stream_safe_trailing_nonstarters(c: char) -> usize {
mph_lookup(c.into(), TRAILING_NONSTARTERS_SALT, TRAILING_NONSTARTERS_KV,
u8_lookup_fk, u8_lookup_fv, 0) as usize
}

/// Extract the key in a 24 bit key and 8 bit value packed in a u32.
#[inline]
fn u8_lookup_fk(kv: u32) -> u32 {
kv >> 8
}

/// Extract the value in a 24 bit key and 8 bit value packed in a u32.
#[inline]
fn u8_lookup_fv(kv: u32) -> u8 {
(kv & 0xff) as u8
}

/// Extract the key for a boolean lookup.
#[inline]
fn bool_lookup_fk(kv: u32) -> u32 {
kv
}

/// Extract the value for a boolean lookup.
#[inline]
fn bool_lookup_fv(_kv: u32) -> bool {
true
}

/// Extract the key in a pair.
#[inline]
fn pair_lookup_fk<T>(kv: (u32, T)) -> u32 {
kv.0
}

/// Extract the value in a pair, returning an option.
#[inline]
fn pair_lookup_fv_opt<T>(kv: (u32, T)) -> Option<T> {
Some(kv.1)
}
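composition_table above handles a pair of BMP code points by packing them into a single u32 key, (c1 << 16) | c2, matching the table built by gen_composition_table in scripts/unicode.py; pairs involving astral code points fall through to the generated match in composition_table_astral. A small Python illustration of the packing, with a hypothetical pair:

```python
# Hypothetical pair: 'A' (U+0041) followed by COMBINING ACUTE ACCENT (U+0301) composes to U+00C1.
c1, c2 = 0x0041, 0x0301
key = (c1 << 16) | c2    # both code points are below U+10000, so each fits in 16 bits
assert key == 0x00410301
assert (key >> 16, key & 0xFFFF) == (c1, c2)
```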