-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use minimal perfect hashing for lookups #37
Changes from all commits
aa84f63
3a4a8f6
db57ffc
2d7bfd1
f64d47c
2e432d2
08996fa
40f9ba6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ | |
# Since this should not require frequent updates, we just store this | ||
# out-of-line and check the unicode.rs file into git. | ||
import collections | ||
import requests | ||
import urllib.request | ||
|
||
UNICODE_VERSION = "9.0.0" | ||
UCD_URL = "https://www.unicode.org/Public/%s/ucd/" % UNICODE_VERSION | ||
|
@@ -68,9 +68,9 @@ def __init__(self): | |
|
||
def stats(name, table): | ||
count = sum(len(v) for v in table.values()) | ||
print "%s: %d chars => %d decomposed chars" % (name, len(table), count) | ||
print("%s: %d chars => %d decomposed chars" % (name, len(table), count)) | ||
|
||
print "Decomposition table stats:" | ||
print("Decomposition table stats:") | ||
stats("Canonical decomp", self.canon_decomp) | ||
stats("Compatible decomp", self.compat_decomp) | ||
stats("Canonical fully decomp", self.canon_fully_decomp) | ||
|
@@ -79,8 +79,8 @@ def stats(name, table): | |
self.ss_leading, self.ss_trailing = self._compute_stream_safe_tables() | ||
|
||
def _fetch(self, filename): | ||
resp = requests.get(UCD_URL + filename) | ||
return resp.text | ||
resp = urllib.request.urlopen(UCD_URL + filename) | ||
return resp.read().decode('utf-8') | ||
|
||
def _load_unicode_data(self): | ||
self.combining_classes = {} | ||
|
@@ -234,7 +234,7 @@ def _decompose(char_int, compatible): | |
# need to store their overlap when they agree. When they don't agree, | ||
# store the decomposition in the compatibility table since we'll check | ||
# that first when normalizing to NFKD. | ||
assert canon_fully_decomp <= compat_fully_decomp | ||
assert set(canon_fully_decomp) <= set(compat_fully_decomp) | ||
|
||
for ch in set(canon_fully_decomp) & set(compat_fully_decomp): | ||
if canon_fully_decomp[ch] == compat_fully_decomp[ch]: | ||
|
@@ -284,27 +284,37 @@ def _compute_stream_safe_tables(self): | |
|
||
return leading_nonstarters, trailing_nonstarters | ||
|
||
hexify = lambda c: hex(c)[2:].upper().rjust(4, '0') | ||
hexify = lambda c: '{:04X}'.format(c) | ||
|
||
def gen_combining_class(combining_classes, out): | ||
out.write("#[inline]\n") | ||
out.write("pub fn canonical_combining_class(c: char) -> u8 {\n") | ||
out.write(" match c {\n") | ||
|
||
for char, combining_class in sorted(combining_classes.items()): | ||
out.write(" '\u{%s}' => %s,\n" % (hexify(char), combining_class)) | ||
def gen_mph_data(name, d, kv_type, kv_callback): | ||
(salt, keys) = minimal_perfect_hash(d) | ||
out.write("pub(crate) const %s_SALT: &[u16] = &[\n" % name.upper()) | ||
for s in salt: | ||
out.write(" 0x{:x},\n".format(s)) | ||
out.write("];\n") | ||
out.write("pub(crate) const {}_KV: &[{}] = &[\n".format(name.upper(), kv_type)) | ||
for k in keys: | ||
out.write(" {},\n".format(kv_callback(k))) | ||
out.write("];\n\n") | ||
|
||
out.write(" _ => 0,\n") | ||
out.write(" }\n") | ||
out.write("}\n") | ||
def gen_combining_class(combining_classes, out): | ||
gen_mph_data('canonical_combining_class', combining_classes, 'u32', | ||
lambda k: "0x{:X}".format(int(combining_classes[k]) | (k << 8))) | ||
|
||
def gen_composition_table(canon_comp, out): | ||
out.write("#[inline]\n") | ||
out.write("pub fn composition_table(c1: char, c2: char) -> Option<char> {\n") | ||
table = {} | ||
for (c1, c2), c3 in canon_comp.items(): | ||
if c1 < 0x10000 and c2 < 0x10000: | ||
table[(c1 << 16) | c2] = c3 | ||
(salt, keys) = minimal_perfect_hash(table) | ||
gen_mph_data('COMPOSITION_TABLE', table, '(u32, char)', | ||
lambda k: "(0x%s, '\\u{%s}')" % (hexify(k), hexify(table[k]))) | ||
|
||
out.write("pub(crate) fn composition_table_astral(c1: char, c2: char) -> Option<char> {\n") | ||
out.write(" match (c1, c2) {\n") | ||
|
||
for (c1, c2), c3 in sorted(canon_comp.items()): | ||
out.write(" ('\u{%s}', '\u{%s}') => Some('\u{%s}'),\n" % (hexify(c1), hexify(c2), hexify(c3))) | ||
if c1 >= 0x10000 and c2 >= 0x10000: | ||
out.write(" ('\\u{%s}', '\\u{%s}') => Some('\\u{%s}'),\n" % (hexify(c1), hexify(c2), hexify(c3))) | ||
|
||
out.write(" _ => None,\n") | ||
out.write(" }\n") | ||
|
@@ -313,23 +323,9 @@ def gen_composition_table(canon_comp, out): | |
def gen_decomposition_tables(canon_decomp, compat_decomp, out): | ||
tables = [(canon_decomp, 'canonical'), (compat_decomp, 'compatibility')] | ||
for table, name in tables: | ||
out.write("#[inline]\n") | ||
out.write("pub fn %s_fully_decomposed(c: char) -> Option<&'static [char]> {\n" % name) | ||
# The "Some" constructor is around the match statement here, because | ||
# putting it into the individual arms would make the item_bodies | ||
# checking of rustc takes almost twice as long, and it's already pretty | ||
# slow because of the huge number of match arms and the fact that there | ||
# is a borrow inside each arm | ||
out.write(" Some(match c {\n") | ||
|
||
for char, chars in sorted(table.items()): | ||
d = ", ".join("'\u{%s}'" % hexify(c) for c in chars) | ||
out.write(" '\u{%s}' => &[%s],\n" % (hexify(char), d)) | ||
|
||
out.write(" _ => return None,\n") | ||
out.write(" })\n") | ||
out.write("}\n") | ||
out.write("\n") | ||
gen_mph_data(name + '_decomposed', table, "(u32, &'static [char])", | ||
lambda k: "(0x{:x}, &[{}])".format(k, | ||
", ".join("'\\u{%s}'" % hexify(c) for c in table[k]))) | ||
|
||
def gen_qc_match(prop_table, out): | ||
out.write(" match c {\n") | ||
|
@@ -371,40 +367,25 @@ def gen_nfkd_qc(prop_tables, out): | |
out.write("}\n") | ||
|
||
def gen_combining_mark(general_category_mark, out): | ||
out.write("#[inline]\n") | ||
out.write("pub fn is_combining_mark(c: char) -> bool {\n") | ||
out.write(" match c {\n") | ||
|
||
for char in general_category_mark: | ||
out.write(" '\u{%s}' => true,\n" % hexify(char)) | ||
|
||
out.write(" _ => false,\n") | ||
out.write(" }\n") | ||
out.write("}\n") | ||
gen_mph_data('combining_mark', general_category_mark, 'u32', | ||
lambda k: '0x{:04x}'.format(k)) | ||
|
||
def gen_stream_safe(leading, trailing, out): | ||
# This could be done as a hash but the table is very small. | ||
out.write("#[inline]\n") | ||
out.write("pub fn stream_safe_leading_nonstarters(c: char) -> usize {\n") | ||
out.write(" match c {\n") | ||
|
||
for char, num_leading in leading.items(): | ||
out.write(" '\u{%s}' => %d,\n" % (hexify(char), num_leading)) | ||
for char, num_leading in sorted(leading.items()): | ||
out.write(" '\\u{%s}' => %d,\n" % (hexify(char), num_leading)) | ||
|
||
out.write(" _ => 0,\n") | ||
out.write(" }\n") | ||
out.write("}\n") | ||
out.write("\n") | ||
|
||
out.write("#[inline]\n") | ||
out.write("pub fn stream_safe_trailing_nonstarters(c: char) -> usize {\n") | ||
out.write(" match c {\n") | ||
|
||
for char, num_trailing in trailing.items(): | ||
out.write(" '\u{%s}' => %d,\n" % (hexify(char), num_trailing)) | ||
|
||
out.write(" _ => 0,\n") | ||
out.write(" }\n") | ||
out.write("}\n") | ||
gen_mph_data('trailing_nonstarters', trailing, 'u32', | ||
lambda k: "0x{:X}".format(int(trailing[k]) | (k << 8))) | ||
|
||
def gen_tests(tests, out): | ||
out.write("""#[derive(Debug)] | ||
|
@@ -419,7 +400,7 @@ def gen_tests(tests, out): | |
""") | ||
|
||
out.write("pub const NORMALIZATION_TESTS: &[NormalizationTest] = &[\n") | ||
str_literal = lambda s: '"%s"' % "".join("\u{%s}" % c for c in s) | ||
str_literal = lambda s: '"%s"' % "".join("\\u{%s}" % c for c in s) | ||
|
||
for test in tests: | ||
out.write(" NormalizationTest {\n") | ||
|
@@ -432,9 +413,65 @@ def gen_tests(tests, out): | |
|
||
out.write("];\n") | ||
|
||
# Guaranteed to be less than n. | ||
def my_hash(x, salt, n): | ||
# This is hash based on the theory that multiplication is efficient | ||
mask_32 = 0xffffffff | ||
y = ((x + salt) * 2654435769) & mask_32 | ||
y ^= (x * 0x31415926) & mask_32 | ||
return (y * n) >> 32 | ||
|
||
# Compute minimal perfect hash function, d can be either a dict or list of keys. | ||
def minimal_perfect_hash(d): | ||
n = len(d) | ||
buckets = dict((h, []) for h in range(n)) | ||
for key in d: | ||
h = my_hash(key, 0, n) | ||
buckets[h].append(key) | ||
bsorted = [(len(buckets[h]), h) for h in range(n)] | ||
bsorted.sort(reverse = True) | ||
claimed = [False] * n | ||
salts = [0] * n | ||
keys = [0] * n | ||
for (bucket_size, h) in bsorted: | ||
# Note: the traditional perfect hashing approach would also special-case | ||
# bucket_size == 1 here and assign any empty slot, rather than iterating | ||
# until rehash finds an empty slot. But we're not doing that so we can | ||
# avoid the branch. | ||
if bucket_size == 0: | ||
break | ||
else: | ||
for salt in range(1, 32768): | ||
rehashes = [my_hash(key, salt, n) for key in buckets[h]] | ||
# Make sure there are no rehash collisions within this bucket. | ||
if all(not claimed[hash] for hash in rehashes): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a guarantee that we won't have a collision amongst the rehashes? Is it just really unlikely? (I suspect it's the latter but want to confirm) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, if it finds a suitable salt that comes with a guarantee the rehash won't have a collision (this is what the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, wait, the (worth leaving a comment saying that) |
||
if len(set(rehashes)) < bucket_size: | ||
continue | ||
salts[h] = salt | ||
for key in buckets[h]: | ||
rehash = my_hash(key, salt, n) | ||
claimed[rehash] = True | ||
keys[rehash] = key | ||
break | ||
if salts[h] == 0: | ||
print("minimal perfect hashing failed") | ||
# Note: if this happens (because of unfortunate data), then there are | ||
# a few things that could be done. First, the hash function could be | ||
# tweaked. Second, the bucket order could be scrambled (especially the | ||
# singletons). Right now, the buckets are sorted, which has the advantage | ||
# of being deterministic. | ||
# | ||
# As a more extreme approach, the singleton bucket optimization could be | ||
# applied (give the direct address for singleton buckets, rather than | ||
# relying on a rehash). That is definitely the more standard approach in | ||
# the minimal perfect hashing literature, but in testing the branch was a | ||
# significant slowdown. | ||
exit(1) | ||
return (salts, keys) | ||
|
||
if __name__ == '__main__': | ||
data = UnicodeData() | ||
with open("tables.rs", "w") as out: | ||
with open("tables.rs", "w", newline = "\n") as out: | ||
out.write(PREAMBLE) | ||
out.write("use quick_check::IsNormalized;\n") | ||
out.write("use quick_check::IsNormalized::*;\n") | ||
|
@@ -470,6 +507,6 @@ def gen_tests(tests, out): | |
gen_stream_safe(data.ss_leading, data.ss_trailing, out) | ||
out.write("\n") | ||
|
||
with open("normalization_tests.rs", "w") as out: | ||
with open("normalization_tests.rs", "w", newline = "\n") as out: | ||
out.write(PREAMBLE) | ||
gen_tests(data.norm_tests, out) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
// Copyright 2019 The Rust Project Developers. See the COPYRIGHT | ||
// file at the top-level directory of this distribution and at | ||
// http://rust-lang.org/COPYRIGHT. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
|
||
//! Lookups of unicode properties using minimal perfect hashing. | ||
|
||
use perfect_hash::mph_lookup; | ||
use tables::*; | ||
|
||
/// Look up the canonical combining class for a codepoint. | ||
/// | ||
/// The value returned is as defined in the Unicode Character Database. | ||
pub fn canonical_combining_class(c: char) -> u8 { | ||
mph_lookup(c.into(), CANONICAL_COMBINING_CLASS_SALT, CANONICAL_COMBINING_CLASS_KV, | ||
u8_lookup_fk, u8_lookup_fv, 0) | ||
} | ||
|
||
pub(crate) fn composition_table(c1: char, c2: char) -> Option<char> { | ||
if c1 < '\u{10000}' && c2 < '\u{10000}' { | ||
mph_lookup((c1 as u32) << 16 | (c2 as u32), | ||
COMPOSITION_TABLE_SALT, COMPOSITION_TABLE_KV, | ||
pair_lookup_fk, pair_lookup_fv_opt, None) | ||
} else { | ||
composition_table_astral(c1, c2) | ||
} | ||
} | ||
|
||
pub(crate) fn canonical_fully_decomposed(c: char) -> Option<&'static [char]> { | ||
mph_lookup(c.into(), CANONICAL_DECOMPOSED_SALT, CANONICAL_DECOMPOSED_KV, | ||
pair_lookup_fk, pair_lookup_fv_opt, None) | ||
} | ||
|
||
pub(crate) fn compatibility_fully_decomposed(c: char) -> Option<&'static [char]> { | ||
mph_lookup(c.into(), COMPATIBILITY_DECOMPOSED_SALT, COMPATIBILITY_DECOMPOSED_KV, | ||
pair_lookup_fk, pair_lookup_fv_opt, None) | ||
} | ||
|
||
/// Return whether the given character is a combining mark (`General_Category=Mark`) | ||
pub fn is_combining_mark(c: char) -> bool { | ||
mph_lookup(c.into(), COMBINING_MARK_SALT, COMBINING_MARK_KV, | ||
bool_lookup_fk, bool_lookup_fv, false) | ||
} | ||
|
||
pub fn stream_safe_trailing_nonstarters(c: char) -> usize { | ||
mph_lookup(c.into(), TRAILING_NONSTARTERS_SALT, TRAILING_NONSTARTERS_KV, | ||
u8_lookup_fk, u8_lookup_fv, 0) as usize | ||
} | ||
|
||
/// Extract the key in a 24 bit key and 8 bit value packed in a u32. | ||
#[inline] | ||
fn u8_lookup_fk(kv: u32) -> u32 { | ||
kv >> 8 | ||
} | ||
|
||
/// Extract the value in a 24 bit key and 8 bit value packed in a u32. | ||
#[inline] | ||
fn u8_lookup_fv(kv: u32) -> u8 { | ||
(kv & 0xff) as u8 | ||
} | ||
|
||
/// Extract the key for a boolean lookup. | ||
#[inline] | ||
fn bool_lookup_fk(kv: u32) -> u32 { | ||
kv | ||
} | ||
|
||
/// Extract the value for a boolean lookup. | ||
#[inline] | ||
fn bool_lookup_fv(_kv: u32) -> bool { | ||
true | ||
} | ||
|
||
/// Extract the key in a pair. | ||
#[inline] | ||
fn pair_lookup_fk<T>(kv: (u32, T)) -> u32 { | ||
kv.0 | ||
} | ||
|
||
/// Extract the value in a pair, returning an option. | ||
#[inline] | ||
fn pair_lookup_fv_opt<T>(kv: (u32, T)) -> Option<T> { | ||
Some(kv.1) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
probably should have a comment saying "guaranteed to be less than
n
"