aboutsummaryrefslogtreecommitdiff
path: root/src/owned.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/owned.rs')
-rw-r--r--src/owned.rs371
1 files changed, 371 insertions, 0 deletions
diff --git a/src/owned.rs b/src/owned.rs
new file mode 100644
index 0000000..a621378
--- /dev/null
+++ b/src/owned.rs
@@ -0,0 +1,371 @@
+use std::borrow::Borrow;
+use std::collections::HashMap;
+use std::error;
+use std::fmt;
+use std::io;
+use std::result;
+
+use super::{TrieSetSlice, CHUNK_SIZE};
+
+// This implementation was pretty much cribbed from raphlinus' contribution
+// to the standard library: https://github.com/rust-lang/rust/pull/33098/files
+//
+// The fundamental principle guiding this implementation is to take advantage
+// of the fact that similar Unicode codepoints are often grouped together, and
+// that most boolean Unicode properties are quite sparse over the entire space
+// of Unicode codepoints.
+//
+// To do this, we represent sets using something like a trie (which gives us
+// prefix compression). The "final" states of the trie are embedded in leaves
+// or "chunks," where each chunk is a 64 bit integer. Each bit position of the
+// integer corresponds to whether a particular codepoint is in the set or not.
+// These chunks are not just a compact representation of the final states of
+// the trie, but are also a form of suffix compression. In particular, if
+// multiple ranges of 64 contiguous codepoints map have the same set membership
+// ordering, then they all map to the exact same chunk in the trie.
+//
+// We organize this structure by partitioning the space of Unicode codepoints
+// into three disjoint sets. The first set corresponds to codepoints
+// [0, 0x800), the second [0x800, 0x1000) and the third [0x10000, 0x110000).
+// These partitions conveniently correspond to the space of 1 or 2 byte UTF-8
+// encoded codepoints, 3 byte UTF-8 encoded codepoints and 4 byte UTF-8 encoded
+// codepoints, respectively.
+//
+// Each partition has its own tree with its own root. The first partition is
+// the simplest, since the tree is completely flat. In particular, to determine
+// the set membership of a Unicode codepoint (that is less than `0x800`), we
+// do the following (where `cp` is the codepoint we're testing):
+//
+// let chunk_address = cp >> 6;
+// let chunk_bit = cp & 0b111111;
+// let chunk = tree1[cp >> 6];
+// let is_member = 1 == ((chunk >> chunk_bit) & 1);
+//
+// We do something similar for the second partition:
+//
+// // we subtract 0x20 since (0x800 >> 6) == 0x20.
+// let child_address = (cp >> 6) - 0x20;
+// let chunk_address = tree2_level1[child_address];
+// let chunk_bit = cp & 0b111111;
+// let chunk = tree2_level2[chunk_address];
+// let is_member = 1 == ((chunk >> chunk_bit) & 1);
+//
+// And so on for the third partition.
+//
+// Note that as a special case, if the second or third partitions are empty,
+// then the trie will store empty slices for those levels. The `contains`
+// check knows to return `false` in those cases.
+
+const CHUNKS: usize = 0x110000 / CHUNK_SIZE;
+
+/// A type alias that maps to `std::result::Result<T, ucd_trie::Error>`.
+pub type Result<T> = result::Result<T, Error>;
+
+/// An error that can occur during construction of a trie.
+#[derive(Clone, Debug)]
+pub enum Error {
+ /// This error is returned when an invalid codepoint is given to
+ /// `TrieSetOwned::from_codepoints`. An invalid codepoint is a `u32` that
+ /// is greater than `0x10FFFF`.
+ InvalidCodepoint(u32),
+ /// This error is returned when a set of Unicode codepoints could not be
+ /// sufficiently compressed into the trie provided by this crate. There is
+ /// no work-around for this error at this time.
+ GaveUp,
+}
+
+impl error::Error for Error {}
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match *self {
+ Error::InvalidCodepoint(cp) => write!(
+ f,
+ "could not construct trie set containing an \
+ invalid Unicode codepoint: 0x{:X}",
+ cp
+ ),
+ Error::GaveUp => {
+ write!(f, "could not compress codepoint set into a trie")
+ }
+ }
+ }
+}
+
+impl From<Error> for io::Error {
+ fn from(err: Error) -> io::Error {
+ io::Error::new(io::ErrorKind::Other, err)
+ }
+}
+
+/// An owned trie set.
+#[derive(Clone)]
+pub struct TrieSetOwned {
+ tree1_level1: Vec<u64>,
+ tree2_level1: Vec<u8>,
+ tree2_level2: Vec<u64>,
+ tree3_level1: Vec<u8>,
+ tree3_level2: Vec<u8>,
+ tree3_level3: Vec<u64>,
+}
+
+impl fmt::Debug for TrieSetOwned {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "TrieSetOwned(...)")
+ }
+}
+
+impl TrieSetOwned {
+ fn new(all: &[bool]) -> Result<TrieSetOwned> {
+ let mut bitvectors = Vec::with_capacity(CHUNKS);
+ for i in 0..CHUNKS {
+ let mut bitvector = 0u64;
+ for j in 0..CHUNK_SIZE {
+ if all[i * CHUNK_SIZE + j] {
+ bitvector |= 1 << j;
+ }
+ }
+ bitvectors.push(bitvector);
+ }
+
+ let tree1_level1 =
+ bitvectors.iter().cloned().take(0x800 / CHUNK_SIZE).collect();
+
+ let (mut tree2_level1, mut tree2_level2) = compress_postfix_leaves(
+ &bitvectors[0x800 / CHUNK_SIZE..0x10000 / CHUNK_SIZE],
+ )?;
+ if tree2_level2.len() == 1 && tree2_level2[0] == 0 {
+ tree2_level1.clear();
+ tree2_level2.clear();
+ }
+
+ let (mid, mut tree3_level3) = compress_postfix_leaves(
+ &bitvectors[0x10000 / CHUNK_SIZE..0x110000 / CHUNK_SIZE],
+ )?;
+ let (mut tree3_level1, mut tree3_level2) =
+ compress_postfix_mid(&mid, 64)?;
+ if tree3_level3.len() == 1 && tree3_level3[0] == 0 {
+ tree3_level1.clear();
+ tree3_level2.clear();
+ tree3_level3.clear();
+ }
+
+ Ok(TrieSetOwned {
+ tree1_level1,
+ tree2_level1,
+ tree2_level2,
+ tree3_level1,
+ tree3_level2,
+ tree3_level3,
+ })
+ }
+
+ /// Create a new trie set from a set of Unicode scalar values.
+ ///
+ /// This returns an error if a set could not be sufficiently compressed to
+ /// fit into a trie.
+ pub fn from_scalars<I, C>(scalars: I) -> Result<TrieSetOwned>
+ where
+ I: IntoIterator<Item = C>,
+ C: Borrow<char>,
+ {
+ let mut all = vec![false; 0x110000];
+ for s in scalars {
+ all[*s.borrow() as usize] = true;
+ }
+ TrieSetOwned::new(&all)
+ }
+
+ /// Create a new trie set from a set of Unicode scalar values.
+ ///
+ /// This returns an error if a set could not be sufficiently compressed to
+ /// fit into a trie. This also returns an error if any of the given
+ /// codepoints are greater than `0x10FFFF`.
+ pub fn from_codepoints<I, C>(codepoints: I) -> Result<TrieSetOwned>
+ where
+ I: IntoIterator<Item = C>,
+ C: Borrow<u32>,
+ {
+ let mut all = vec![false; 0x110000];
+ for cp in codepoints {
+ let cp = *cp.borrow();
+ if cp > 0x10FFFF {
+ return Err(Error::InvalidCodepoint(cp));
+ }
+ all[cp as usize] = true;
+ }
+ TrieSetOwned::new(&all)
+ }
+
+ /// Return this set as a slice.
+ #[inline(always)]
+ pub fn as_slice(&self) -> TrieSetSlice<'_> {
+ TrieSetSlice {
+ tree1_level1: &self.tree1_level1,
+ tree2_level1: &self.tree2_level1,
+ tree2_level2: &self.tree2_level2,
+ tree3_level1: &self.tree3_level1,
+ tree3_level2: &self.tree3_level2,
+ tree3_level3: &self.tree3_level3,
+ }
+ }
+
+ /// Returns true if and only if the given Unicode scalar value is in this
+ /// set.
+ pub fn contains_char(&self, c: char) -> bool {
+ self.as_slice().contains_char(c)
+ }
+
+ /// Returns true if and only if the given codepoint is in this set.
+ ///
+ /// If the given value exceeds the codepoint range (i.e., it's greater
+ /// than `0x10FFFF`), then this returns false.
+ pub fn contains_u32(&self, cp: u32) -> bool {
+ self.as_slice().contains_u32(cp)
+ }
+}
+
+fn compress_postfix_leaves(chunks: &[u64]) -> Result<(Vec<u8>, Vec<u64>)> {
+ let mut root = vec![];
+ let mut children = vec![];
+ let mut bychild = HashMap::new();
+ for &chunk in chunks {
+ if !bychild.contains_key(&chunk) {
+ let start = bychild.len();
+ if start > ::std::u8::MAX as usize {
+ return Err(Error::GaveUp);
+ }
+ bychild.insert(chunk, start as u8);
+ children.push(chunk);
+ }
+ root.push(bychild[&chunk]);
+ }
+ Ok((root, children))
+}
+
+fn compress_postfix_mid(
+ chunks: &[u8],
+ chunk_size: usize,
+) -> Result<(Vec<u8>, Vec<u8>)> {
+ let mut root = vec![];
+ let mut children = vec![];
+ let mut bychild = HashMap::new();
+ for i in 0..(chunks.len() / chunk_size) {
+ let chunk = &chunks[i * chunk_size..(i + 1) * chunk_size];
+ if !bychild.contains_key(chunk) {
+ let start = bychild.len();
+ if start > ::std::u8::MAX as usize {
+ return Err(Error::GaveUp);
+ }
+ bychild.insert(chunk, start as u8);
+ children.extend(chunk);
+ }
+ root.push(bychild[chunk]);
+ }
+ Ok((root, children))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::TrieSetOwned;
+ use crate::general_category;
+ use std::collections::HashSet;
+
+ fn mk(scalars: &[char]) -> TrieSetOwned {
+ TrieSetOwned::from_scalars(scalars).unwrap()
+ }
+
+ fn ranges_to_set(ranges: &[(u32, u32)]) -> Vec<u32> {
+ let mut set = vec![];
+ for &(start, end) in ranges {
+ for cp in start..end + 1 {
+ set.push(cp);
+ }
+ }
+ set
+ }
+
+ #[test]
+ fn set1() {
+ let set = mk(&['a']);
+ assert!(set.contains_char('a'));
+ assert!(!set.contains_char('b'));
+ assert!(!set.contains_char('β'));
+ assert!(!set.contains_char('☃'));
+ assert!(!set.contains_char('😼'));
+ }
+
+ #[test]
+ fn set_combined() {
+ let set = mk(&['a', 'b', 'β', '☃', '😼']);
+ assert!(set.contains_char('a'));
+ assert!(set.contains_char('b'));
+ assert!(set.contains_char('β'));
+ assert!(set.contains_char('☃'));
+ assert!(set.contains_char('😼'));
+
+ assert!(!set.contains_char('c'));
+ assert!(!set.contains_char('θ'));
+ assert!(!set.contains_char('⛇'));
+ assert!(!set.contains_char('🐲'));
+ }
+
+ // Basic tests on all of the general category sets. We check that
+ // membership is correct on every Unicode codepoint... because we can.
+
+ macro_rules! category_test {
+ ($name:ident, $ranges:ident) => {
+ #[test]
+ fn $name() {
+ let set = ranges_to_set(general_category::$ranges);
+ let hashset: HashSet<u32> = set.iter().cloned().collect();
+ let trie = TrieSetOwned::from_codepoints(&set).unwrap();
+ for cp in 0..0x110000 {
+ assert!(trie.contains_u32(cp) == hashset.contains(&cp));
+ }
+ // Test that an invalid codepoint is treated correctly.
+ assert!(!trie.contains_u32(0x110000));
+ assert!(!hashset.contains(&0x110000));
+ }
+ };
+ }
+
+ category_test!(gencat_cased_letter, CASED_LETTER);
+ category_test!(gencat_close_punctuation, CLOSE_PUNCTUATION);
+ category_test!(gencat_connector_punctuation, CONNECTOR_PUNCTUATION);
+ category_test!(gencat_control, CONTROL);
+ category_test!(gencat_currency_symbol, CURRENCY_SYMBOL);
+ category_test!(gencat_dash_punctuation, DASH_PUNCTUATION);
+ category_test!(gencat_decimal_number, DECIMAL_NUMBER);
+ category_test!(gencat_enclosing_mark, ENCLOSING_MARK);
+ category_test!(gencat_final_punctuation, FINAL_PUNCTUATION);
+ category_test!(gencat_format, FORMAT);
+ category_test!(gencat_initial_punctuation, INITIAL_PUNCTUATION);
+ category_test!(gencat_letter, LETTER);
+ category_test!(gencat_letter_number, LETTER_NUMBER);
+ category_test!(gencat_line_separator, LINE_SEPARATOR);
+ category_test!(gencat_lowercase_letter, LOWERCASE_LETTER);
+ category_test!(gencat_math_symbol, MATH_SYMBOL);
+ category_test!(gencat_mark, MARK);
+ category_test!(gencat_modifier_letter, MODIFIER_LETTER);
+ category_test!(gencat_modifier_symbol, MODIFIER_SYMBOL);
+ category_test!(gencat_nonspacing_mark, NONSPACING_MARK);
+ category_test!(gencat_number, NUMBER);
+ category_test!(gencat_open_punctuation, OPEN_PUNCTUATION);
+ category_test!(gencat_other, OTHER);
+ category_test!(gencat_other_letter, OTHER_LETTER);
+ category_test!(gencat_other_number, OTHER_NUMBER);
+ category_test!(gencat_other_punctuation, OTHER_PUNCTUATION);
+ category_test!(gencat_other_symbol, OTHER_SYMBOL);
+ category_test!(gencat_paragraph_separator, PARAGRAPH_SEPARATOR);
+ category_test!(gencat_private_use, PRIVATE_USE);
+ category_test!(gencat_punctuation, PUNCTUATION);
+ category_test!(gencat_separator, SEPARATOR);
+ category_test!(gencat_space_separator, SPACE_SEPARATOR);
+ category_test!(gencat_spacing_mark, SPACING_MARK);
+ category_test!(gencat_surrogate, SURROGATE);
+ category_test!(gencat_symbol, SYMBOL);
+ category_test!(gencat_titlecase_letter, TITLECASE_LETTER);
+ category_test!(gencat_unassigned, UNASSIGNED);
+ category_test!(gencat_uppercase_letter, UPPERCASE_LETTER);
+}