diff options
Diffstat (limited to 'src/include/fst')
40 files changed, 1999 insertions, 718 deletions
diff --git a/src/include/fst/arc-map.h b/src/include/fst/arc-map.h index 914f81c..c33b546 100644 --- a/src/include/fst/arc-map.h +++ b/src/include/fst/arc-map.h @@ -165,8 +165,8 @@ void ArcMap(MutableFst<A> *fst, C* mapper) { } else { fst->SetFinal(s, final_arc.weight); } - break; } + break; } case MAP_REQUIRE_SUPERFINAL: { if (s != superfinal) { @@ -315,8 +315,6 @@ class ArcMapFstImpl : public CacheImpl<B> { using FstImpl<B>::SetInputSymbols; using FstImpl<B>::SetOutputSymbols; - using VectorFstBaseImpl<typename CacheImpl<B>::State>::NumStates; - using CacheImpl<B>::PushArc; using CacheImpl<B>::HasArcs; using CacheImpl<B>::HasFinal; diff --git a/src/include/fst/bi-table.h b/src/include/fst/bi-table.h index bd37781..d220ce4 100644 --- a/src/include/fst/bi-table.h +++ b/src/include/fst/bi-table.h @@ -23,9 +23,15 @@ #define FST_LIB_BI_TABLE_H__ #include <deque> +using std::deque; +#include <functional> #include <vector> using std::vector; +#include <tr1/unordered_set> +using std::tr1::unordered_set; +using std::tr1::unordered_multiset; + namespace fst { // BI TABLES - these determine a bijective mapping between an @@ -49,14 +55,33 @@ namespace fst { // }; // An implementation using a hash map for the entry to ID mapping. -// The entry T must have == defined and the default constructor -// must produce an entry that will never be seen. H is the hash function. -template <class I, class T, class H> +// H is the hash function and E is the equality function. +// If passed to the constructor, ownership is given to this class. + +template <class I, class T, class H, class E = std::equal_to<T> > class HashBiTable { public: + // Reserves space for 'table_size' elements. + explicit HashBiTable(size_t table_size = 0, H *h = 0, E *e = 0) + : hash_func_(h), + hash_equal_(e), + entry2id_(table_size, (h ? *h : H()), (e ? *e : E())) { + if (table_size) + id2entry_.reserve(table_size); + } - HashBiTable() { - T empty_entry; + HashBiTable(const HashBiTable<I, T, H, E> &table) + : hash_func_(table.hash_func_ ? new H(*table.hash_func_) : 0), + hash_equal_(table.hash_equal_ ? new E(*table.hash_equal_) : 0), + entry2id_(table.entry2id_.begin(), table.entry2id_.end(), + table.entry2id_.size(), + (hash_func_ ? *hash_func_ : H()), + (hash_equal_ ? *hash_equal_ : E())), + id2entry_(table.id2entry_) { } + + ~HashBiTable() { + delete hash_func_; + delete hash_equal_; } I FindId(const T &entry, bool insert = true) { @@ -79,39 +104,67 @@ class HashBiTable { I Size() const { return id2entry_.size(); } private: - unordered_map<T, I, H> entry2id_; + H *hash_func_; + E *hash_equal_; + unordered_map<T, I, H, E> entry2id_; vector<T> id2entry_; - DISALLOW_COPY_AND_ASSIGN(HashBiTable); + void operator=(const HashBiTable<I, T, H, E> &table); // disallow +}; + + +// Enables alternative hash set representations below. +// typedef enum { HS_STL = 0, HS_DENSE = 1, HS_SPARSE = 2 } HSType; +typedef enum { HS_STL = 0, HS_DENSE = 1, HS_SPARSE = 2 } HSType; + +// Default hash set is STL hash_set +template<class K, class H, class E, HSType> +struct HashSet : public unordered_set<K, H, E> { + HashSet(size_t n = 0, const H &h = H(), const E &e = E()) + : unordered_set<K, H, E>(n, h, e) { } + + void rehash(size_t n) { } }; -// An implementation using a hash set for the entry to ID -// mapping. The hash set holds 'keys' which are either the ID -// or kCurrentKey. These keys can be mapped to entrys either by -// looking up in the entry vector or, if kCurrentKey, in current_entry_ -// member. The hash and key equality functions map to entries first. -// The entry T must have == defined and the default constructor -// must produce a entry that will never be seen. H is the hash -// function. -template <class I, class T, class H> +// An implementation using a hash set for the entry to ID mapping. +// The hash set holds 'keys' which are either the ID or kCurrentKey. +// These keys can be mapped to entrys either by looking up in the +// entry vector or, if kCurrentKey, in current_entry_ member. The hash +// and key equality functions map to entries first. H +// is the hash function and E is the equality function. If passed to +// the constructor, ownership is given to this class. +template <class I, class T, class H, + class E = std::equal_to<T>, HSType HS = HS_DENSE> class CompactHashBiTable { public: friend class HashFunc; friend class HashEqual; - CompactHashBiTable() - : hash_func_(*this), - hash_equal_(*this), - keys_(0, hash_func_, hash_equal_) { + // Reserves space for 'table_size' elements. + explicit CompactHashBiTable(size_t table_size = 0, H *h = 0, E *e = 0) + : hash_func_(h), + hash_equal_(e), + compact_hash_func_(*this), + compact_hash_equal_(*this), + keys_(table_size, compact_hash_func_, compact_hash_equal_) { + if (table_size) + id2entry_.reserve(table_size); } - // Reserves space for table_size elements. - explicit CompactHashBiTable(size_t table_size) - : hash_func_(*this), - hash_equal_(*this), - keys_(table_size, hash_func_, hash_equal_) { - id2entry_.reserve(table_size); + CompactHashBiTable(const CompactHashBiTable<I, T, H, E, HS> &table) + : hash_func_(table.hash_func_ ? new H(*table.hash_func_) : 0), + hash_equal_(table.hash_equal_ ? new E(*table.hash_equal_) : 0), + compact_hash_func_(*this), + compact_hash_equal_(*this), + keys_(table.keys_.size(), compact_hash_func_, compact_hash_equal_), + id2entry_(table.id2entry_) { + keys_.insert(table.keys_.begin(), table.keys_.end()); + } + + ~CompactHashBiTable() { + delete hash_func_; + delete hash_equal_; } I FindId(const T &entry, bool insert = true) { @@ -132,20 +185,40 @@ class CompactHashBiTable { } const T &FindEntry(I s) const { return id2entry_[s]; } + I Size() const { return id2entry_.size(); } + // Clear content. With argument, erases last n IDs. + void Clear(ssize_t n = -1) { + if (n < 0 || n > id2entry_.size()) + n = id2entry_.size(); + while (n-- > 0) { + I key = id2entry_.size() - 1; + keys_.erase(key); + id2entry_.pop_back(); + } + keys_.rehash(0); + } + private: - static const I kEmptyKey; // -1 - static const I kCurrentKey; // -2 + static const I kCurrentKey; // -1 + static const I kEmptyKey; // -2 + static const I kDeletedKey; // -3 class HashFunc { public: HashFunc(const CompactHashBiTable &ht) : ht_(&ht) {} - size_t operator()(I k) const { return hf(ht_->Key2T(k)); } + size_t operator()(I k) const { + if (k >= kCurrentKey) { + return (*ht_->hash_func_)(ht_->Key2Entry(k)); + } else { + return 0; + } + } + private: const CompactHashBiTable *ht_; - H hf; }; class HashEqual { @@ -153,38 +226,45 @@ class CompactHashBiTable { HashEqual(const CompactHashBiTable &ht) : ht_(&ht) {} bool operator()(I k1, I k2) const { - return ht_->Key2T(k1) == ht_->Key2T(k2); + if (k1 >= kCurrentKey && k2 >= kCurrentKey) { + return (*ht_->hash_equal_)(ht_->Key2Entry(k1), ht_->Key2Entry(k2)); + } else { + return k1 == k2; + } } private: const CompactHashBiTable *ht_; }; - typedef unordered_set<I, HashFunc, HashEqual> KeyHashSet; + typedef HashSet<I, HashFunc, HashEqual, HS> KeyHashSet; - const T &Key2T(I k) const { - if (k == kEmptyKey) - return empty_entry_; - else if (k == kCurrentKey) + const T &Key2Entry(I k) const { + if (k == kCurrentKey) return *current_entry_; else return id2entry_[k]; } - HashFunc hash_func_; - HashEqual hash_equal_; + H *hash_func_; + E *hash_equal_; + HashFunc compact_hash_func_; + HashEqual compact_hash_equal_; KeyHashSet keys_; vector<T> id2entry_; - const T empty_entry_; const T *current_entry_; - DISALLOW_COPY_AND_ASSIGN(CompactHashBiTable); + void operator=(const CompactHashBiTable<I, T, H, E, HS> &table); // disallow }; -template <class I, class T, class H> -const I CompactHashBiTable<I, T, H>::kEmptyKey = -1; -template <class I, class T, class H> -const I CompactHashBiTable<I, T, H>::kCurrentKey = -2; +template <class I, class T, class H, class E, HSType HS> +const I CompactHashBiTable<I, T, H, E, HS>::kCurrentKey = -1; + +template <class I, class T, class H, class E, HSType HS> +const I CompactHashBiTable<I, T, H, E, HS>::kEmptyKey = -2; + +template <class I, class T, class H, class E, HSType HS> +const I CompactHashBiTable<I, T, H, E, HS>::kDeletedKey = -3; // An implementation using a vector for the entry to ID mapping. @@ -196,7 +276,17 @@ const I CompactHashBiTable<I, T, H>::kCurrentKey = -2; template <class I, class T, class FP> class VectorBiTable { public: - explicit VectorBiTable(FP *fp = 0) : fp_(fp ? fp : new FP()) {} + // Reserves space for 'table_size' elements. + explicit VectorBiTable(FP *fp = 0, size_t table_size = 0) + : fp_(fp ? fp : new FP()) { + if (table_size) + id2entry_.reserve(table_size); + } + + VectorBiTable(const VectorBiTable<I, T, FP> &table) + : fp_(table.fp_ ? new FP(*table.fp_) : 0), + fp2id_(table.fp2id_), + id2entry_(table.id2entry_) { } ~VectorBiTable() { delete fp_; } @@ -227,7 +317,7 @@ class VectorBiTable { vector<I> fp2id_; vector<T> id2entry_; - DISALLOW_COPY_AND_ASSIGN(VectorBiTable); + void operator=(const VectorBiTable<I, T, FP> &table); // disallow }; @@ -235,20 +325,21 @@ class VectorBiTable { // selecting functor S returns true for entries to be hashed in the // vector. The fingerprinting functor FP returns a unique fingerprint // for each entry to be hashed in the vector (these need to be -// suitable for indexing in a vector). The hash functor H is used when -// hashing entry into the compact hash table. -template <class I, class T, class S, class FP, class H> +// suitable for indexing in a vector). The hash functor H is used +// when hashing entry into the compact hash table. If passed to the +// constructor, ownership is given to this class. +template <class I, class T, class S, class FP, class H, HSType HS = HS_DENSE> class VectorHashBiTable { public: friend class HashFunc; friend class HashEqual; - VectorHashBiTable(S *s, FP *fp, H *h, - size_t vector_size = 0, - size_t entry_size = 0) + explicit VectorHashBiTable(S *s, FP *fp = 0, H *h = 0, + size_t vector_size = 0, + size_t entry_size = 0) : selector_(s), - fp_(fp), - h_(h), + fp_(fp ? fp : new FP()), + h_(h ? h : new H()), hash_func_(*this), hash_equal_(*this), keys_(0, hash_func_, hash_equal_) { @@ -256,6 +347,18 @@ class VectorHashBiTable { fp2id_.reserve(vector_size); if (entry_size) id2entry_.reserve(entry_size); + } + + VectorHashBiTable(const VectorHashBiTable<I, T, S, FP, H, HS> &table) + : selector_(new S(table.s_)), + fp_(table.fp_ ? new FP(*table.fp_) : 0), + h_(table.h_ ? new H(*table.h_) : 0), + id2entry_(table.id2entry_), + fp2id_(table.fp2id_), + hash_func_(*this), + hash_equal_(*this), + keys_(table.keys_.size(), hash_func_, hash_equal_) { + keys_.insert(table.keys_.begin(), table.keys_.end()); } ~VectorHashBiTable() { @@ -309,14 +412,20 @@ class VectorHashBiTable { const H &Hash() const { return *h_; } private: - static const I kEmptyKey; - static const I kCurrentKey; + static const I kCurrentKey; // -1 + static const I kEmptyKey; // -2 class HashFunc { public: HashFunc(const VectorHashBiTable &ht) : ht_(&ht) {} - size_t operator()(I k) const { return (*(ht_->h_))(ht_->Key2Entry(k)); } + size_t operator()(I k) const { + if (k >= kCurrentKey) { + return (*(ht_->h_))(ht_->Key2Entry(k)); + } else { + return 0; + } + } private: const VectorHashBiTable *ht_; }; @@ -326,53 +435,54 @@ class VectorHashBiTable { HashEqual(const VectorHashBiTable &ht) : ht_(&ht) {} bool operator()(I k1, I k2) const { - return ht_->Key2Entry(k1) == ht_->Key2Entry(k2); + if (k1 >= kCurrentKey && k2 >= kCurrentKey) { + return ht_->Key2Entry(k1) == ht_->Key2Entry(k2); + } else { + return k1 == k2; + } } private: const VectorHashBiTable *ht_; }; - typedef unordered_set<I, HashFunc, HashEqual> KeyHashSet; + typedef HashSet<I, HashFunc, HashEqual, HS> KeyHashSet; const T &Key2Entry(I k) const { - if (k == kEmptyKey) - return empty_entry_; - else if (k == kCurrentKey) + if (k == kCurrentKey) return *current_entry_; else return id2entry_[k]; } - S *selector_; // Returns true if entry hashed into vector FP *fp_; // Fingerprint used when hashing entry into vector H *h_; // Hash function used when hashing entry into hash_set vector<T> id2entry_; // Maps state IDs to entry - vector<I> fp2id_; // Maps entry fingerprints to IDs + vector<I> fp2id_; // Maps entry fingerprints to IDs // Compact implementation of the hash table mapping entrys to // state IDs using the hash function 'h_' HashFunc hash_func_; HashEqual hash_equal_; KeyHashSet keys_; - const T empty_entry_; const T *current_entry_; - DISALLOW_COPY_AND_ASSIGN(VectorHashBiTable); + // disallow + void operator=(const VectorHashBiTable<I, T, S, FP, H, HS> &table); }; -template <class I, class T, class S, class FP, class H> -const I VectorHashBiTable<I, T, S, FP, H>::kEmptyKey = -1; +template <class I, class T, class S, class FP, class H, HSType HS> +const I VectorHashBiTable<I, T, S, FP, H, HS>::kCurrentKey = -1; -template <class I, class T, class S, class FP, class H> -const I VectorHashBiTable<I, T, S, FP, H>::kCurrentKey = -2; +template <class I, class T, class S, class FP, class H, HSType HS> +const I VectorHashBiTable<I, T, S, FP, H, HS>::kEmptyKey = -3; // An implementation using a hash map for the entry to ID -// mapping. This version permits erasing of s. The entry T -// must have == defined and its default constructor must produce a -// entry that will never be seen. F is the hash function. +// mapping. This version permits erasing of arbitrary states. The +// entry T must have == defined and its default constructor must +// produce a entry that will never be seen. F is the hash function. template <class I, class T, class F> class ErasableBiTable { public: @@ -413,7 +523,8 @@ class ErasableBiTable { const T empty_entry_; I first_; // I of first element in the deque; - DISALLOW_COPY_AND_ASSIGN(ErasableBiTable); + // disallow + void operator=(const ErasableBiTable<I, T, F> &table); //disallow }; } // namespace fst diff --git a/src/include/fst/cache.h b/src/include/fst/cache.h index 0177396..7c96fe1 100644 --- a/src/include/fst/cache.h +++ b/src/include/fst/cache.h @@ -89,14 +89,15 @@ struct DefaultCacheStateAllocator { // CacheState below). This class is used to cache FST elements with // the flags used to indicate what has been cached. Use HasStart() // HasFinal(), and HasArcs() to determine if cached and SetStart(), -// SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note you -// must set the final weight even if the state is non-final to mark it as -// cached. If the 'gc' option is 'false', cached items have the extent -// of the FST - minimizing computation. If the 'gc' option is 'true', -// garbage collection of states (not in use in an arc iterator) is -// performed, in a rough approximation of LRU order, when 'gc_limit' -// bytes is reached - controlling memory use. When 'gc_limit' is 0, -// special optimizations apply - minimizing memory use. +// SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note +// you must set the final weight even if the state is non-final to +// mark it as cached. If the 'gc' option is 'false', cached items have +// the extent of the FST - minimizing computation. If the 'gc' option +// is 'true', garbage collection of states (not in use in an arc +// iterator and not 'protected') is performed, in a rough +// approximation of LRU order, when 'gc_limit' bytes is reached - +// controlling memory use. When 'gc_limit' is 0, special optimizations +// apply - minimizing memory use. template <class S, class C = DefaultCacheStateAllocator<S> > class CacheBaseImpl : public VectorFstBaseImpl<S> { @@ -111,8 +112,10 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> { using FstImpl<Arc>::Properties; using FstImpl<Arc>::SetProperties; using VectorFstBaseImpl<State>::NumStates; + using VectorFstBaseImpl<State>::Start; using VectorFstBaseImpl<State>::AddState; using VectorFstBaseImpl<State>::SetState; + using VectorFstBaseImpl<State>::ReserveStates; explicit CacheBaseImpl(C *allocator = 0) : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0), @@ -120,27 +123,57 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> { cache_gc_(FLAGS_fst_default_cache_gc), cache_size_(0), cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit || FLAGS_fst_default_cache_gc_limit == 0 ? - FLAGS_fst_default_cache_gc_limit : kMinCacheLimit) { - allocator_ = allocator ? allocator : new C(); - } + FLAGS_fst_default_cache_gc_limit : kMinCacheLimit), + protect_(false) { + allocator_ = allocator ? allocator : new C(); + } explicit CacheBaseImpl(const CacheOptions &opts, C *allocator = 0) : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId), cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0), cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ? - opts.gc_limit : kMinCacheLimit) { - allocator_ = allocator ? allocator : new C(); - } + opts.gc_limit : kMinCacheLimit), + protect_(false) { + allocator_ = allocator ? allocator : new C(); + } - // Preserve gc parameters, but initially cache nothing. - CacheBaseImpl(const CacheBaseImpl &impl) - : cache_start_(false), nknown_states_(0), - min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId), - cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0), - cache_limit_(impl.cache_limit_) { - allocator_ = new C(); + // Preserve gc parameters. If preserve_cache true, also preserves + // cache data. + CacheBaseImpl(const CacheBaseImpl<S, C> &impl, bool preserve_cache = false) + : VectorFstBaseImpl<S>(), cache_start_(false), nknown_states_(0), + min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId), + cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0), + cache_limit_(impl.cache_limit_), + protect_(impl.protect_) { + allocator_ = new C(); + if (preserve_cache) { + cache_start_ = impl.cache_start_; + nknown_states_ = impl.nknown_states_; + expanded_states_ = impl.expanded_states_; + min_unexpanded_state_id_ = impl.min_unexpanded_state_id_; + if (impl.cache_first_state_id_ != kNoStateId) { + cache_first_state_id_ = impl.cache_first_state_id_; + cache_first_state_ = allocator_->Allocate(cache_first_state_id_); + *cache_first_state_ = *impl.cache_first_state_; } + cache_states_ = impl.cache_states_; + cache_size_ = impl.cache_size_; + ReserveStates(impl.NumStates()); + for (StateId s = 0; s < impl.NumStates(); ++s) { + const S *state = + static_cast<const VectorFstBaseImpl<S> &>(impl).GetState(s); + if (state) { + S *copied_state = allocator_->Allocate(s); + *copied_state = *state; + AddState(copied_state); + } else { + AddState(0); + } + } + VectorFstBaseImpl<S>::SetStart(impl.Start()); + } + } ~CacheBaseImpl() { allocator_->Free(cache_first_state_, cache_first_state_id_); @@ -174,49 +207,7 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> { } // Gets a state from its ID; add it if necessary. - S *ExtendState(StateId s) { - if (s == cache_first_state_id_) { - return cache_first_state_; // Return 1st cached state - } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) { - cache_first_state_id_ = s; // Remember 1st cached state - cache_first_state_ = allocator_->Allocate(s); - return cache_first_state_; - } else if (cache_first_state_id_ != kNoStateId && - cache_first_state_->ref_count == 0) { - // With Default allocator, the Free and Allocate will reuse the same S*. - allocator_->Free(cache_first_state_, cache_first_state_id_); - cache_first_state_id_ = s; - cache_first_state_ = allocator_->Allocate(s); - return cache_first_state_; // Return 1st cached state - } else { - while (NumStates() <= s) // Add state to main cache - AddState(0); - if (!VectorFstBaseImpl<S>::GetState(s)) { - SetState(s, allocator_->Allocate(s)); - if (cache_first_state_id_ != kNoStateId) { // Forget 1st cached state - while (NumStates() <= cache_first_state_id_) - AddState(0); - SetState(cache_first_state_id_, cache_first_state_); - if (cache_gc_) { - cache_states_.push_back(cache_first_state_id_); - cache_size_ += sizeof(S) + - cache_first_state_->arcs.capacity() * sizeof(Arc); - } - cache_limit_ = kMinCacheLimit; - cache_first_state_id_ = kNoStateId; - cache_first_state_ = 0; - } - if (cache_gc_) { - cache_states_.push_back(s); - cache_size_ += sizeof(S); - if (cache_size_ > cache_limit_) - GC(s, false); - } - } - S *state = VectorFstBaseImpl<S>::GetState(s); - return state; - } - } + S *ExtendState(StateId s); void SetStart(StateId s) { VectorFstBaseImpl<S>::SetStart(s); @@ -246,7 +237,8 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> { const Arc *parc = state->arcs.empty() ? 0 : &(state->arcs.back()); SetProperties(AddArcProperties(Properties(), s, arc, parc)); state->flags |= kCacheModified; - if (cache_gc_ && s != cache_first_state_id_) { + if (cache_gc_ && s != cache_first_state_id_ && + !(state->flags & kCacheProtect)) { cache_size_ += sizeof(Arc); if (cache_size_ > cache_limit_) GC(s, false); @@ -278,7 +270,8 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> { } ExpandedState(s); state->flags |= kCacheArcs | kCacheRecent | kCacheModified; - if (cache_gc_ && s != cache_first_state_id_) { + if (cache_gc_ && s != cache_first_state_id_ && + !(state->flags & kCacheProtect)) { cache_size_ += arcs.capacity() * sizeof(Arc); if (cache_size_ > cache_limit_) GC(s, false); @@ -300,18 +293,73 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> { if (arcs[j].olabel == 0) --state->noepsilons; } + state->arcs.resize(arcs.size() - n); SetProperties(DeleteArcsProperties(Properties())); state->flags |= kCacheModified; + if (cache_gc_ && s != cache_first_state_id_ && + !(state->flags & kCacheProtect)) { + cache_size_ -= n * sizeof(Arc); + } } void DeleteArcs(StateId s) { S *state = ExtendState(s); + size_t n = state->arcs.size(); state->niepsilons = 0; state->noepsilons = 0; state->arcs.clear(); SetProperties(DeleteArcsProperties(Properties())); state->flags |= kCacheModified; + if (cache_gc_ && s != cache_first_state_id_ && + !(state->flags & kCacheProtect)) { + cache_size_ -= n * sizeof(Arc); + } + } + + void DeleteStates(const vector<StateId> &dstates) { + size_t old_num_states = NumStates(); + vector<StateId> newid(old_num_states, 0); + for (size_t i = 0; i < dstates.size(); ++i) + newid[dstates[i]] = kNoStateId; + StateId nstates = 0; + for (StateId s = 0; s < old_num_states; ++s) { + if (newid[s] != kNoStateId) { + newid[s] = nstates; + ++nstates; + } + } + // just for states_.resize(), does unnecessary walk. + VectorFstBaseImpl<S>::DeleteStates(dstates); + SetProperties(DeleteStatesProperties(Properties())); + // Update list of cached states. + typename list<StateId>::iterator siter = cache_states_.begin(); + while (siter != cache_states_.end()) { + if (newid[*siter] != kNoStateId) { + *siter = newid[*siter]; + ++siter; + } else { + cache_states_.erase(siter++); + } + } + } + + void DeleteStates() { + cache_states_.clear(); + allocator_->Free(cache_first_state_, cache_first_state_id_); + for (int s = 0; s < NumStates(); ++s) { + allocator_->Free(VectorFstBaseImpl<S>::GetState(s), s); + SetState(s, 0); + } + nknown_states_ = 0; + min_unexpanded_state_id_ = 0; + cache_first_state_id_ = kNoStateId; + cache_first_state_ = 0; + cache_size_ = 0; + cache_start_ = false; + VectorFstBaseImpl<State>::DeleteStates(); + SetProperties(DeleteAllStatesProperties(Properties(), + kExpanded | kMutable)); } // Is the start state cached? @@ -390,48 +438,17 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> { return min_unexpanded_state_id_; } - // Removes from cache_states_ and uncaches (not referenced-counted) - // states that have not been accessed since the last GC until - // cache_limit_/3 bytes are uncached. If that fails to free enough, - // recurs uncaching recently visited states as well. If still - // unable to free enough memory, then widens cache_limit_. - void GC(StateId current, bool free_recent) { - if (!cache_gc_) - return; - VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this - << "), free recently cached = " << free_recent - << ", cache size = " << cache_size_ - << ", cache limit = " << cache_limit_ << "\n"; - typename list<StateId>::iterator siter = cache_states_.begin(); + // Removes from cache_states_ and uncaches (not referenced-counted + // or protected) states that have not been accessed since the last + // GC until at most cache_fraction * cache_limit_ bytes are cached. + // If that fails to free enough, recurs uncaching recently visited + // states as well. If still unable to free enough memory, then + // widens cache_limit_ to fulfill condition. + void GC(StateId current, bool free_recent, float cache_fraction = 0.666); - size_t cache_target = (2 * cache_limit_)/3 + 1; - while (siter != cache_states_.end()) { - StateId s = *siter; - S* state = VectorFstBaseImpl<S>::GetState(s); - if (cache_size_ > cache_target && state->ref_count == 0 && - (free_recent || !(state->flags & kCacheRecent)) && s != current) { - cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc); - allocator_->Free(state, s); - SetState(s, 0); - cache_states_.erase(siter++); - } else { - state->flags &= ~kCacheRecent; - ++siter; - } - } - if (!free_recent && cache_size_ > cache_target) { - GC(current, true); - } else { - while (cache_size_ > cache_target) { - cache_limit_ *= 2; - cache_target *= 2; - } - } - VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this - << "), free recently cached = " << free_recent - << ", cache size = " << cache_size_ - << ", cache limit = " << cache_limit_ << "\n"; - } + // Setc/clears GC protection: if true, new states are protected + // from garbage collection. + void GCProtect(bool on) { protect_ = on; } void ExpandedState(StateId s) { if (s < min_unexpanded_state_id_) @@ -441,26 +458,30 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> { expanded_states_[s] = true; } + C *GetAllocator() const { + return allocator_; + } + // Caching on/off switch, limit and size accessors. bool GetCacheGc() const { return cache_gc_; } size_t GetCacheLimit() const { return cache_limit_; } size_t GetCacheSize() const { return cache_size_; } private: - static const size_t kMinCacheLimit = 8096; // Minimum (non-zero) cache limit - static const uint32 kCacheFinal = 0x0001; // Final weight has been cached - static const uint32 kCacheArcs = 0x0002; // Arcs have been cached - static const uint32 kCacheRecent = 0x0004; // Mark as visited since GC + static const size_t kMinCacheLimit = 8096; // Minimum (non-zero) cache limit + + static const uint32 kCacheFinal = 0x0001; // Final weight has been cached + static const uint32 kCacheArcs = 0x0002; // Arcs have been cached + static const uint32 kCacheRecent = 0x0004; // Mark as visited since GC + static const uint32 kCacheProtect = 0x0008; // Mark state as GC protected public: - static const uint32 kCacheModified = 0x0008; // Mark state as modified + static const uint32 kCacheModified = 0x0010; // Mark state as modified static const uint32 kCacheFlags = kCacheFinal | kCacheArcs | kCacheRecent - | kCacheModified; - - protected: - C *allocator_; // used to allocate new states + | kCacheProtect | kCacheModified; private: + C *allocator_; // used to allocate new states mutable bool cache_start_; // Is the start state cached? StateId nknown_states_; // # of known states vector<bool> expanded_states_; // states that have been expanded @@ -471,10 +492,113 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> { bool cache_gc_; // enable GC size_t cache_size_; // # of bytes cached size_t cache_limit_; // # of bytes allowed before GC + bool protect_; // Protect new states from GC - void operator=(const CacheBaseImpl<S> &impl); // disallow + void operator=(const CacheBaseImpl<S, C> &impl); // disallow }; +// Gets a state from its ID; add it if necessary. +template <class S, class C> +S *CacheBaseImpl<S, C>::ExtendState(typename S::Arc::StateId s) { + // If 'protect_' true and a new state, protects from garbage collection. + if (s == cache_first_state_id_) { + return cache_first_state_; // Return 1st cached state + } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) { + cache_first_state_id_ = s; // Remember 1st cached state + cache_first_state_ = allocator_->Allocate(s); + if (protect_) cache_first_state_->flags |= kCacheProtect; + return cache_first_state_; + } else if (cache_first_state_id_ != kNoStateId && + cache_first_state_->ref_count == 0 && + !(cache_first_state_->flags & kCacheProtect)) { + // With Default allocator, the Free and Allocate will reuse the same S*. + allocator_->Free(cache_first_state_, cache_first_state_id_); + cache_first_state_id_ = s; + cache_first_state_ = allocator_->Allocate(s); + if (protect_) cache_first_state_->flags |= kCacheProtect; + return cache_first_state_; // Return 1st cached state + } else { + while (NumStates() <= s) // Add state to main cache + AddState(0); + S *state = VectorFstBaseImpl<S>::GetState(s); + if (!state) { + state = allocator_->Allocate(s); + if (protect_) state->flags |= kCacheProtect; + SetState(s, state); + if (cache_first_state_id_ != kNoStateId) { // Forget 1st cached state + while (NumStates() <= cache_first_state_id_) + AddState(0); + SetState(cache_first_state_id_, cache_first_state_); + if (cache_gc_ && !(cache_first_state_->flags & kCacheProtect)) { + cache_states_.push_back(cache_first_state_id_); + cache_size_ += sizeof(S) + + cache_first_state_->arcs.capacity() * sizeof(Arc); + } + cache_limit_ = kMinCacheLimit; + cache_first_state_id_ = kNoStateId; + cache_first_state_ = 0; + } + if (cache_gc_ && !protect_) { + cache_states_.push_back(s); + cache_size_ += sizeof(S); + if (cache_size_ > cache_limit_) + GC(s, false); + } + } + return state; + } +} + +// Removes from cache_states_ and uncaches (not referenced-counted or +// protected) states that have not been accessed since the last GC +// until at most cache_fraction * cache_limit_ bytes are cached. If +// that fails to free enough, recurs uncaching recently visited states +// as well. If still unable to free enough memory, then widens cache_limit_ +// to fulfill condition. +template <class S, class C> +void CacheBaseImpl<S, C>::GC(typename S::Arc::StateId current, + bool free_recent, float cache_fraction) { + if (!cache_gc_) + return; + VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this + << "), free recently cached = " << free_recent + << ", cache size = " << cache_size_ + << ", cache frac = " << cache_fraction + << ", cache limit = " << cache_limit_ << "\n"; + typename list<StateId>::iterator siter = cache_states_.begin(); + + size_t cache_target = cache_fraction * cache_limit_; + while (siter != cache_states_.end()) { + StateId s = *siter; + S* state = VectorFstBaseImpl<S>::GetState(s); + if (cache_size_ > cache_target && state->ref_count == 0 && + (free_recent || !(state->flags & kCacheRecent)) && s != current) { + cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc); + allocator_->Free(state, s); + SetState(s, 0); + cache_states_.erase(siter++); + } else { + state->flags &= ~kCacheRecent; + ++siter; + } + } + if (!free_recent && cache_size_ > cache_target) { // recurses on recent + GC(current, true); + } else if (cache_target > 0) { // widens cache limit + while (cache_size_ > cache_target) { + cache_limit_ *= 2; + cache_target *= 2; + } + } else if (cache_size_ > 0) { + FSTERROR() << "CacheImpl:GC: Unable to free all cached states"; + } + VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this + << "), free recently cached = " << free_recent + << ", cache size = " << cache_size_ + << ", cache frac = " << cache_fraction + << ", cache limit = " << cache_limit_ << "\n"; +} + template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheFinal; template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheArcs; template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheRecent; @@ -516,7 +640,8 @@ class CacheImpl : public CacheBaseImpl< CacheState<A> > { explicit CacheImpl(const CacheOptions &opts) : CacheBaseImpl< CacheState<A> >(opts) {} - CacheImpl(const CacheImpl<State> &impl) : CacheBaseImpl<State>(impl) {} + CacheImpl(const CacheImpl<A> &impl, bool preserve_cache = false) + : CacheBaseImpl<State>(impl, preserve_cache) {} private: void operator=(const CacheImpl<State> &impl); // disallow @@ -536,12 +661,13 @@ class CacheStateIterator : public StateIteratorBase<typename F::Arc> { typedef CacheBaseImpl<State> Impl; CacheStateIterator(const F &fst, Impl *impl) - : fst_(fst), impl_(impl), s_(0) {} + : fst_(fst), impl_(impl), s_(0) { + fst_.Start(); // force start state + } bool Done() const { if (s_ < impl_->NumKnownStates()) return false; - fst_.Start(); // force start state if (s_ < impl_->NumKnownStates()) return false; for (StateId u = impl_->MinUnexpandedState(); diff --git a/src/include/fst/compact-fst.h b/src/include/fst/compact-fst.h index 57c927e..6db3317 100644 --- a/src/include/fst/compact-fst.h +++ b/src/include/fst/compact-fst.h @@ -32,6 +32,7 @@ using std::vector; #include <fst/cache.h> #include <fst/expanded-fst.h> #include <fst/fst-decl.h> // For optional argument declarations +#include <fst/mapped-file.h> #include <fst/matcher.h> #include <fst/test-properties.h> #include <fst/util.h> @@ -134,7 +135,9 @@ class CompactFstData { typedef U Unsigned; CompactFstData() - : states_(0), + : states_region_(0), + compacts_region_(0), + states_(0), compacts_(0), nstates_(0), ncompacts_(0), @@ -150,8 +153,14 @@ class CompactFstData { const Compactor &compactor); ~CompactFstData() { - delete[] states_; - delete[] compacts_; + if (states_region_ == NULL) { + delete [] states_; + } + delete states_region_; + if (compacts_region_ == NULL) { + delete [] compacts_; + } + delete compacts_region_; } template <class Compactor> @@ -175,10 +184,9 @@ class CompactFstData { bool Error() const { return error_; } - // Byte alignment for states and arcs in file format (version 1 only) - static const int kFileAlign = 16; - private: + MappedFile *states_region_; + MappedFile *compacts_region_; Unsigned *states_; CompactElement *compacts_; size_t nstates_; @@ -190,13 +198,11 @@ class CompactFstData { }; template <class E, class U> -const int CompactFstData<E, U>::kFileAlign; - - -template <class E, class U> template <class A, class C> CompactFstData<E, U>::CompactFstData(const Fst<A> &fst, const C &compactor) - : states_(0), + : states_region_(0), + compacts_region_(0), + states_(0), compacts_(0), nstates_(0), ncompacts_(0), @@ -265,7 +271,9 @@ template <class Iterator, class C> CompactFstData<E, U>::CompactFstData(const Iterator &begin, const Iterator &end, const C &compactor) - : states_(0), + : states_region_(0), + compacts_region_(0), + states_(0), compacts_(0), nstates_(0), ncompacts_(0), @@ -361,42 +369,40 @@ CompactFstData<E, U> *CompactFstData<E, U>::Read( data->narcs_ = hdr.NumArcs(); if (compactor.Size() == -1) { - data->states_ = new Unsigned[data->nstates_ + 1]; - if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && - !AlignInput(strm, kFileAlign)) { + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { LOG(ERROR) << "CompactFst::Read: Alignment failed: " << opts.source; delete data; return 0; } - // TODO: memory map this size_t b = (data->nstates_ + 1) * sizeof(Unsigned); - strm.read(reinterpret_cast<char *>(data->states_), b); - if (!strm) { + data->states_region_ = MappedFile::Map(&strm, opts, b); + if (!strm || data->states_region_ == NULL) { LOG(ERROR) << "CompactFst::Read: Read failed: " << opts.source; delete data; return 0; } + data->states_ = static_cast<Unsigned *>( + data->states_region_->mutable_data()); } else { data->states_ = 0; } data->ncompacts_ = compactor.Size() == -1 ? data->states_[data->nstates_] : data->nstates_ * compactor.Size(); - data->compacts_ = new CompactElement[data->ncompacts_]; - // TODO: memory map this - size_t b = data->ncompacts_ * sizeof(CompactElement); - if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && - !AlignInput(strm, kFileAlign)) { + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { LOG(ERROR) << "CompactFst::Read: Alignment failed: " << opts.source; delete data; return 0; } - strm.read(reinterpret_cast<char *>(data->compacts_), b); - if (!strm) { + size_t b = data->ncompacts_ * sizeof(CompactElement); + data->compacts_region_ = MappedFile::Map(&strm, opts, b); + if (!strm || data->compacts_region_ == NULL) { LOG(ERROR) << "CompactFst::Read: Read failed: " << opts.source; delete data; return 0; } + data->compacts_ = static_cast<CompactElement *>( + data->compacts_region_->mutable_data()); return data; } @@ -404,14 +410,14 @@ template<class E, class U> bool CompactFstData<E, U>::Write(ostream &strm, const FstWriteOptions &opts) const { if (states_) { - if (opts.align && !AlignOutput(strm, kFileAlign)) { + if (opts.align && !AlignOutput(strm)) { LOG(ERROR) << "CompactFst::Write: Alignment failed: " << opts.source; return false; } strm.write(reinterpret_cast<char *>(states_), (nstates_ + 1) * sizeof(Unsigned)); } - if (opts.align && !AlignOutput(strm, kFileAlign)) { + if (opts.align && !AlignOutput(strm)) { LOG(ERROR) << "CompactFst::Write: Alignment failed: " << opts.source; return false; } @@ -899,6 +905,17 @@ class CompactFst : public ImplToExpandedFst< CompactFstImpl<A, C, U> > { ImplToFst< Impl, ExpandedFst<A> >::SetImpl(impl, own_impl); } + // Use overloading to extract the type of the argument. + static Impl* GetImplIfCompactFst(const CompactFst<A, C, U> &compact_fst) { + return compact_fst.GetImpl(); + } + + // This does not give privileged treatment to subclasses of CompactFst. + template<typename NonCompactFst> + static Impl* GetImplIfCompactFst(const NonCompactFst& fst) { + return NULL; + } + void operator=(const CompactFst<A, C, U> &fst); // disallow }; @@ -914,20 +931,16 @@ bool CompactFst<A, C, U>::WriteFst(const F &fst, typedef U Unsigned; typedef typename C::Element CompactElement; typedef typename A::Weight Weight; - static const int kFileAlign = - CompactFstData<CompactElement, U>::kFileAlign; int file_version = opts.align ? CompactFstImpl<A, C, U>::kAlignedFileVersion : CompactFstImpl<A, C, U>::kFileVersion; size_t num_arcs = -1, num_states = -1, num_compacts = -1; C first_pass_compactor = compactor; - if (fst.Type() == CompactFst<A, C, U>().Type()) { - const CompactFst<A, C, U> *compact_fst = - reinterpret_cast<const CompactFst<A, C, U> *>(&fst); - num_arcs = compact_fst->GetImpl()->Data()->NumArcs(); - num_states = compact_fst->GetImpl()->Data()->NumStates(); - num_compacts = compact_fst->GetImpl()->Data()->NumCompacts(); - first_pass_compactor = *compact_fst->GetImpl()->GetCompactor(); + if (Impl* impl = GetImplIfCompactFst(fst)) { + num_arcs = impl->Data()->NumArcs(); + num_states = impl->Data()->NumStates(); + num_compacts = impl->Data()->NumCompacts(); + first_pass_compactor = *impl->GetCompactor(); } else { // A first pass is needed to compute the state of the compactor, which // is saved ahead of the rest of the data structures. This unfortunately @@ -971,7 +984,7 @@ bool CompactFst<A, C, U>::WriteFst(const F &fst, &hdr); first_pass_compactor.Write(strm); if (first_pass_compactor.Size() == -1) { - if (opts.align && !AlignOutput(strm, kFileAlign)) { + if (opts.align && !AlignOutput(strm)) { LOG(ERROR) << "CompactFst::Write: Alignment failed: " << opts.source; return false; } @@ -986,7 +999,7 @@ bool CompactFst<A, C, U>::WriteFst(const F &fst, } strm.write(reinterpret_cast<const char *>(&compacts), sizeof(compacts)); } - if (opts.align && !AlignOutput(strm, kFileAlign)) { + if (opts.align && !AlignOutput(strm)) { LOG(ERROR) << "Could not align file during write after writing states"; } C second_pass_compactor = compactor; diff --git a/src/include/fst/compose.h b/src/include/fst/compose.h index dfdff0a..db5ea3a 100644 --- a/src/include/fst/compose.h +++ b/src/include/fst/compose.h @@ -122,7 +122,7 @@ class ComposeFstImplBase : public CacheImpl<A> { ComposeFstImplBase(const Fst<A> &fst1, const Fst<A> &fst2, const CacheOptions &opts) - :CacheImpl<A>(opts) { + : CacheImpl<A>(opts) { VLOG(2) << "ComposeFst(" << this << "): Begin"; SetType("compose"); @@ -137,7 +137,7 @@ class ComposeFstImplBase : public CacheImpl<A> { } ComposeFstImplBase(const ComposeFstImplBase<A> &impl) - : CacheImpl<A>(impl) { + : CacheImpl<A>(impl, true) { SetProperties(impl.Properties(), kCopyProperties); SetInputSymbols(impl.InputSymbols()); SetOutputSymbols(impl.OutputSymbols()); @@ -275,6 +275,13 @@ class ComposeFstImpl : public ComposeFstImplBase<typename M1::Arc> { OrderedExpand(s, fst2_, s2, fst1_, s1, matcher2_, true); } + const FST1 &GetFst1() { return fst1_; } + const FST2 &GetFst2() { return fst2_; } + M1 *GetMatcher1() { return matcher1_; } + M2 *GetMatcher2() { return matcher2_; } + F *GetFilter() { return filter_; } + T *GetStateTable() { return state_table_; } + private: // This does that actual matching of labels in the composition. The // arguments are ordered so matching is called on state 'sa' of diff --git a/src/include/fst/const-fst.h b/src/include/fst/const-fst.h index 80efc8d..e6e85af 100644 --- a/src/include/fst/const-fst.h +++ b/src/include/fst/const-fst.h @@ -28,6 +28,7 @@ using std::vector; #include <fst/expanded-fst.h> #include <fst/fst-decl.h> // For optional argument declarations +#include <fst/mapped-file.h> #include <fst/test-properties.h> #include <fst/util.h> @@ -55,7 +56,8 @@ class ConstFstImpl : public FstImpl<A> { typedef U Unsigned; ConstFstImpl() - : states_(0), arcs_(0), nstates_(0), narcs_(0), start_(kNoStateId) { + : states_region_(0), arcs_region_(0), states_(0), arcs_(0), nstates_(0), + narcs_(0), start_(kNoStateId) { string type = "const"; if (sizeof(U) != sizeof(uint32)) { string size; @@ -69,8 +71,8 @@ class ConstFstImpl : public FstImpl<A> { explicit ConstFstImpl(const Fst<A> &fst); ~ConstFstImpl() { - delete[] states_; - delete[] arcs_; + delete arcs_region_; + delete states_region_; } StateId Start() const { return start_; } @@ -125,9 +127,9 @@ class ConstFstImpl : public FstImpl<A> { static const int kAlignedFileVersion = 1; // Minimum file format version supported static const int kMinFileVersion = 1; - // Byte alignment for states and arcs in file format (version 1 only) - static const int kFileAlign = 16; + MappedFile *states_region_; // Mapped file for states + MappedFile *arcs_region_; // Mapped file for arcs State *states_; // States represenation A *arcs_; // Arcs representation StateId nstates_; // Number of states @@ -145,8 +147,6 @@ template <class A, class U> const int ConstFstImpl<A, U>::kAlignedFileVersion; template <class A, class U> const int ConstFstImpl<A, U>::kMinFileVersion; -template <class A, class U> -const int ConstFstImpl<A, U>::kFileAlign; template<class A, class U> @@ -173,8 +173,10 @@ ConstFstImpl<A, U>::ConstFstImpl(const Fst<A> &fst) : nstates_(0), narcs_(0) { aiter.Next()) ++narcs_; } - states_ = new State[nstates_]; - arcs_ = new A[narcs_]; + states_region_ = MappedFile::Allocate(nstates_ * sizeof(*states_)); + arcs_region_ = MappedFile::Allocate(narcs_ * sizeof(*arcs_)); + states_ = reinterpret_cast<State*>(states_region_->mutable_data()); + arcs_ = reinterpret_cast<A*>(arcs_region_->mutable_data()); size_t pos = 0; for (StateId s = 0; s < nstates_; ++s) { states_[s].final = fst.Final(s); @@ -210,39 +212,40 @@ ConstFstImpl<A, U> *ConstFstImpl<A, U>::Read(istream &strm, impl->start_ = hdr.Start(); impl->nstates_ = hdr.NumStates(); impl->narcs_ = hdr.NumArcs(); - impl->states_ = new State[impl->nstates_]; - impl->arcs_ = new A[impl->narcs_]; // Ensures compatibility if (hdr.Version() == kAlignedFileVersion) hdr.SetFlags(hdr.GetFlags() | FstHeader::IS_ALIGNED); - if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && - !AlignInput(strm, kFileAlign)) { + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { LOG(ERROR) << "ConstFst::Read: Alignment failed: " << opts.source; delete impl; return 0; } + size_t b = impl->nstates_ * sizeof(typename ConstFstImpl<A, U>::State); - strm.read(reinterpret_cast<char *>(impl->states_), b); - if (!strm) { + impl->states_region_ = MappedFile::Map(&strm, opts, b); + if (!strm || impl->states_region_ == NULL) { LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source; delete impl; return 0; } - if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && - !AlignInput(strm, kFileAlign)) { + impl->states_ = reinterpret_cast<State*>( + impl->states_region_->mutable_data()); + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { LOG(ERROR) << "ConstFst::Read: Alignment failed: " << opts.source; delete impl; return 0; } + b = impl->narcs_ * sizeof(A); - strm.read(reinterpret_cast<char *>(impl->arcs_), b); - if (!strm) { + impl->arcs_region_ = MappedFile::Map(&strm, opts, b); + if (!strm || impl->arcs_region_ == NULL) { LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source; delete impl; return 0; } + impl->arcs_ = reinterpret_cast<A*>(impl->arcs_region_->mutable_data()); return impl; } @@ -318,6 +321,17 @@ class ConstFst : public ImplToExpandedFst< ConstFstImpl<A, U> > { ImplToFst< Impl, ExpandedFst<A> >::SetImpl(impl, own_impl); } + // Use overloading to extract the type of the argument. + static Impl* GetImplIfConstFst(const ConstFst &const_fst) { + return const_fst.GetImpl(); + } + + // Note that this does not give privileged treatment to subtypes of ConstFst. + template<typename NonConstFst> + static Impl* GetImplIfConstFst(const NonConstFst& fst) { + return NULL; + } + void operator=(const ConstFst<A, U> &fst); // disallow }; @@ -333,11 +347,9 @@ bool ConstFst<A, U>::WriteFst(const F &fst, ostream &strm, size_t num_arcs = -1, num_states = -1; size_t start_offset = 0; bool update_header = true; - if (fst.Type() == ConstFst<A, U>().Type()) { - const ConstFst<A, U> *const_fst = - reinterpret_cast<const ConstFst<A, U> *>(&fst); - num_arcs = const_fst->GetImpl()->narcs_; - num_states = const_fst->GetImpl()->nstates_; + if (Impl* impl = GetImplIfConstFst(fst)) { + num_arcs = impl->narcs_; + num_states = impl->nstates_; update_header = false; } else if ((start_offset = strm.tellp()) == -1) { // precompute values needed for header when we cannot seek to rewrite it. @@ -363,7 +375,7 @@ bool ConstFst<A, U>::WriteFst(const F &fst, ostream &strm, ConstFstImpl<A, U>::kStaticProperties; FstImpl<A>::WriteFstHeader(fst, strm, opts, file_version, type, properties, &hdr); - if (opts.align && !AlignOutput(strm, ConstFstImpl<A, U>::kFileAlign)) { + if (opts.align && !AlignOutput(strm)) { LOG(ERROR) << "Could not align file during write after header"; return false; } @@ -381,7 +393,7 @@ bool ConstFst<A, U>::WriteFst(const F &fst, ostream &strm, } hdr.SetNumStates(states); hdr.SetNumArcs(pos); - if (opts.align && !AlignOutput(strm, ConstFstImpl<A, U>::kFileAlign)) { + if (opts.align && !AlignOutput(strm)) { LOG(ERROR) << "Could not align file during write after writing states"; } for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { diff --git a/src/include/fst/determinize.h b/src/include/fst/determinize.h index a145e4a..9ff8723 100644 --- a/src/include/fst/determinize.h +++ b/src/include/fst/determinize.h @@ -33,9 +33,10 @@ using std::tr1::unordered_multimap; #include <vector> using std::vector; +#include <fst/arc-map.h> #include <fst/cache.h> +#include <fst/bi-table.h> #include <fst/factor-weight.h> -#include <fst/arc-map.h> #include <fst/prune.h> #include <fst/test-properties.h> @@ -108,24 +109,244 @@ class GallicCommonDivisor { D weight_common_divisor_; }; -// Options for finite-state transducer determinization. + +// Represents an element in a subset +template <class A> +struct DeterminizeElement { + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + DeterminizeElement() {} + + DeterminizeElement(StateId s, Weight w) : state_id(s), weight(w) {} + + bool operator==(const DeterminizeElement<A> & element) const { + return state_id == element.state_id && weight == element.weight; + } + + bool operator<(const DeterminizeElement<A> & element) const { + return state_id < element.state_id || + (state_id == element.state_id && weight == element.weight); + } + + StateId state_id; // Input state Id + Weight weight; // Residual weight +}; + + +// +// DETERMINIZE FILTERS - these can be used in determinization to compute +// transformations on the subsets prior to their being added as destination +// states. The filter operates on a map between a label and the +// corresponding destination subsets. The possibly modified map is +// then used to construct the destination states for arcs exiting state 's'. +// It must define the ordered map type LabelMap and have a default +// and copy constructor. + +// A determinize filter that does not modify its input. template <class Arc> +struct IdentityDeterminizeFilter { + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef slist< DeterminizeElement<Arc> > Subset; + typedef map<Label, Subset*> LabelMap; + + static uint64 Properties(uint64 props) { return props; } + + void operator()(StateId s, LabelMap *label_map) {} +}; + + +// +// DETERMINIZATION STATE TABLES +// +// The determiziation state table has the form: +// +// template <class Arc> +// class DeterminizeStateTable { +// public: +// typedef typename Arc::StateId StateId; +// typedef DeterminizeElement<Arc> Element; +// typedef slist<Element> Subset; +// +// // Required constuctor +// DeterminizeStateTable(); +// +// // Required copy constructor that does not copy state +// DeterminizeStateTable(const DeterminizeStateTable<A,P> &table); +// +// // Lookup state ID by subset (not depending of the element order). +// // If it doesn't exist, then add it. FindState takes +// // ownership of the subset argument (so that it doesn't have to +// // copy it if it creates a new state). +// StateId FindState(Subset *subset); +// +// // Lookup subset by ID. +// const Subset *FindSubset(StateId id) const; +// }; +// + +// The default determinization state table based on the +// compact hash bi-table. +template <class Arc> +class DefaultDeterminizeStateTable { + public: + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef DeterminizeElement<Arc> Element; + typedef slist<Element> Subset; + + explicit DefaultDeterminizeStateTable(size_t table_size = 0) + : table_size_(table_size), + subsets_(table_size_, new SubsetKey(), new SubsetEqual(&elements_)) { } + + DefaultDeterminizeStateTable(const DefaultDeterminizeStateTable<Arc> &table) + : table_size_(table.table_size_), + subsets_(table_size_, new SubsetKey(), new SubsetEqual(&elements_)) { } + + ~DefaultDeterminizeStateTable() { + for (StateId s = 0; s < subsets_.Size(); ++s) + delete subsets_.FindEntry(s); + } + + // Finds the state corresponding to a subset. Only creates a new + // state if the subset is not found. FindState takes ownership of + // the subset argument (so that it doesn't have to copy it if it + // creates a new state). + StateId FindState(Subset *subset) { + StateId ns = subsets_.Size(); + StateId s = subsets_.FindId(subset); + if (s != ns) delete subset; // subset found + return s; + } + + const Subset* FindSubset(StateId s) { return subsets_.FindEntry(s); } + + private: + // Comparison object for hashing Subset(s). Subsets are not sorted in this + // implementation, so ordering must not be assumed in the equivalence + // test. + class SubsetEqual { + public: + SubsetEqual() { // needed for compilation but should never be called + FSTERROR() << "SubsetEqual: default constructor not implemented"; + } + + // Constructor takes vector needed to check equality. See immediately + // below for constraints on it. + explicit SubsetEqual(vector<Element *> *elements) + : elements_(elements) {} + + // At each call to operator(), the elements_ vector should contain + // only NULLs. When this operator returns, elements_ will still + // have this property. + bool operator()(Subset* subset1, Subset* subset2) const { + if (!subset1 && !subset2) + return true; + if ((subset1 && !subset2) || (!subset1 && subset2)) + return false; + + if (subset1->size() != subset2->size()) + return false; + + // Loads first subset elements in element vector. + for (typename Subset::iterator iter1 = subset1->begin(); + iter1 != subset1->end(); + ++iter1) { + Element &element1 = *iter1; + while (elements_->size() <= element1.state_id) + elements_->push_back(0); + (*elements_)[element1.state_id] = &element1; + } + + // Checks second subset matches first via element vector. + for (typename Subset::iterator iter2 = subset2->begin(); + iter2 != subset2->end(); + ++iter2) { + Element &element2 = *iter2; + while (elements_->size() <= element2.state_id) + elements_->push_back(0); + Element *element1 = (*elements_)[element2.state_id]; + if (!element1 || element1->weight != element2.weight) { + // Mismatch found. Resets element vector before returning false. + for (typename Subset::iterator iter1 = subset1->begin(); + iter1 != subset1->end(); + ++iter1) + (*elements_)[iter1->state_id] = 0; + return false; + } else { + (*elements_)[element2.state_id] = 0; // Clears entry + } + } + return true; + } + private: + vector<Element *> *elements_; + }; + + // Hash function for Subset to Fst states. Subset elements are not + // sorted in this implementation, so the hash must be invariant + // under subset reordering. + class SubsetKey { + public: + size_t operator()(const Subset* subset) const { + size_t hash = 0; + if (subset) { + for (typename Subset::const_iterator iter = subset->begin(); + iter != subset->end(); + ++iter) { + const Element &element = *iter; + int lshift = element.state_id % (CHAR_BIT * sizeof(size_t) - 1) + 1; + int rshift = CHAR_BIT * sizeof(size_t) - lshift; + size_t n = element.state_id; + hash ^= n << lshift ^ n >> rshift ^ element.weight.Hash(); + } + } + return hash; + } + }; + + size_t table_size_; + + typedef CompactHashBiTable<StateId, Subset *, + SubsetKey, SubsetEqual, HS_STL> SubsetTable; + + SubsetTable subsets_; + vector<Element *> elements_; + + void operator=(const DefaultDeterminizeStateTable<Arc> &); // disallow +}; + +// Options for finite-state transducer determinization templated on +// the arc type, common divisor, the determinization filter and the +// state table. DeterminizeFst takes ownership of the determinization +// filter and state table if provided. +template <class Arc, + class D = DefaultCommonDivisor<typename Arc::Weight>, + class F = IdentityDeterminizeFilter<Arc>, + class T = DefaultDeterminizeStateTable<Arc> > struct DeterminizeFstOptions : CacheOptions { typedef typename Arc::Label Label; float delta; // Quantization delta for subset weights Label subsequential_label; // Label used for residual final output // when producing subsequential transducers. + F *filter; // Determinization filter + T *state_table; // Determinization state table explicit DeterminizeFstOptions(const CacheOptions &opts, - float del = kDelta, - Label lab = 0) - : CacheOptions(opts), delta(del), subsequential_label(lab) {} - - explicit DeterminizeFstOptions(float del = kDelta, Label lab = 0) - : delta(del), subsequential_label(lab) {} + float del = kDelta, Label lab = 0, + F *filt = 0, + T *table = 0) + : CacheOptions(opts), delta(del), subsequential_label(lab), + filter(filt), state_table(table) {} + + explicit DeterminizeFstOptions(float del = kDelta, Label lab = 0, + F *filt = 0, T *table = 0) + : delta(del), subsequential_label(lab), filter(filt), + state_table(table) {} }; - // Implementation of delayed DeterminizeFst. This base class is // common to the variants that implement acceptor and transducer // determinization. @@ -149,15 +370,15 @@ class DeterminizeFstImplBase : public CacheImpl<A> { typedef typename A::StateId StateId; typedef CacheState<A> State; + template <class D, class F, class T> DeterminizeFstImplBase(const Fst<A> &fst, - const DeterminizeFstOptions<A> &opts) + const DeterminizeFstOptions<A, D, F, T> &opts) : CacheImpl<A>(opts), fst_(fst.Copy()) { SetType("determinize"); - uint64 props = fst.Properties(kFstProperties, false); - SetProperties(DeterminizeProperties(props, - opts.subsequential_label != 0), - kCopyProperties); - + uint64 iprops = fst.Properties(kFstProperties, false); + uint64 dprops = DeterminizeProperties(iprops, + opts.subsequential_label != 0); + SetProperties(F::Properties(dprops), kCopyProperties); SetInputSymbols(fst.InputSymbols()); SetOutputSymbols(fst.OutputSymbols()); } @@ -234,7 +455,7 @@ class DeterminizeFstImplBase : public CacheImpl<A> { // Implementation of delayed determinization for weighted acceptors. // It is templated on the arc type A and the common divisor D. -template <class A, class D> +template <class A, class D, class F, class T> class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> { public: using FstImpl<A>::SetProperties; @@ -244,27 +465,19 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> { typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; - - struct Element { - Element() {} - - Element(StateId s, Weight w) : state_id(s), weight(w) {} - - StateId state_id; // Input state Id - Weight weight; // Residual weight - }; + typedef DeterminizeElement<A> Element; typedef slist<Element> Subset; - typedef map<Label, Subset*> LabelMap; + typedef typename F::LabelMap LabelMap; - DeterminizeFsaImpl(const Fst<A> &fst, D common_divisor, + DeterminizeFsaImpl(const Fst<A> &fst, const vector<Weight> *in_dist, vector<Weight> *out_dist, - const DeterminizeFstOptions<A> &opts) + const DeterminizeFstOptions<A, D, F, T> &opts) : DeterminizeFstImplBase<A>(fst, opts), delta_(opts.delta), in_dist_(in_dist), out_dist_(out_dist), - common_divisor_(common_divisor), - subset_hash_(0, SubsetKey(), SubsetEqual(&elements_)) { + filter_(opts.filter ? opts.filter : new F()), + state_table_(opts.state_table ? opts.state_table : new T()) { if (!fst.Properties(kAcceptor, true)) { FSTERROR() << "DeterminizeFst: argument not an acceptor"; SetProperties(kError, kError); @@ -278,13 +491,13 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> { out_dist_->clear(); } - DeterminizeFsaImpl(const DeterminizeFsaImpl<A, D> &impl) + DeterminizeFsaImpl(const DeterminizeFsaImpl<A, D, F, T> &impl) : DeterminizeFstImplBase<A>(impl), delta_(impl.delta_), in_dist_(0), out_dist_(0), - common_divisor_(impl.common_divisor_), - subset_hash_(0, SubsetKey(), SubsetEqual(&elements_)) { + filter_(new F(*impl.filter_)), + state_table_(new T(*impl.state_table_)) { if (impl.out_dist_) { FSTERROR() << "DeterminizeFsaImpl: cannot copy with out_dist vector"; SetProperties(kError, kError); @@ -292,12 +505,12 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> { } virtual ~DeterminizeFsaImpl() { - for (int i = 0; i < subsets_.size(); ++i) - delete subsets_[i]; + delete filter_; + delete state_table_; } - virtual DeterminizeFsaImpl<A, D> *Copy() { - return new DeterminizeFsaImpl<A, D>(*this); + virtual DeterminizeFsaImpl<A, D, F, T> *Copy() { + return new DeterminizeFsaImpl<A, D, F, T>(*this); } uint64 Properties() const { return Properties(kFstProperties); } @@ -320,12 +533,12 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> { } virtual Weight ComputeFinal(StateId s) { - Subset *subset = subsets_[s]; + const Subset *subset = state_table_->FindSubset(s); Weight final = Weight::Zero(); - for (typename Subset::iterator siter = subset->begin(); + for (typename Subset::const_iterator siter = subset->begin(); siter != subset->end(); ++siter) { - Element &element = *siter; + const Element &element = *siter; final = Plus(final, Times(element.weight, GetFst().Final(element.state_id))); if (!final.Member()) @@ -334,33 +547,9 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> { return final; } - // Finds the state corresponding to a subset. Only creates a new state - // if the subset is not found in the subset hash. FindState takes - // ownership of the subset argument (so that it doesn't have to copy it - // if it creates a new state). - // - // The method exploits the following device: all pairs stored in the - // associative container subset_hash_ are of the form (subset, - // id(subset) + 1), i.e. subset_hash_[subset] > 0 if subset has been - // stored previously. For unassigned subsets, the call to - // subset_hash_[subset] creates a new pair (subset, 0). As a result, - // subset_hash_[subset] == 0 iff subset is new. StateId FindState(Subset *subset) { - StateId &assoc_value = subset_hash_[subset]; - if (assoc_value == 0) { // subset wasn't present; create new state - StateId s = CreateState(subset); - assoc_value = s + 1; - return s; - } else { - delete subset; - return assoc_value - 1; // NB: assoc_value = ID + 1 - } - } - - StateId CreateState(Subset *subset) { - StateId s = subsets_.size(); - subsets_.push_back(subset); - if (in_dist_) + StateId s = state_table_->FindState(subset); + if (in_dist_ && out_dist_->size() <= s) out_dist_->push_back(ComputeDistance(subset)); return s; } @@ -398,24 +587,35 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> { // element weights include the input automaton label weights and the // subsets may contain duplicate states. void LabelSubsets(StateId s, LabelMap *label_map) { - Subset *src_subset = subsets_[s]; + const Subset *src_subset = state_table_->FindSubset(s); - for (typename Subset::iterator siter = src_subset->begin(); + for (typename Subset::const_iterator siter = src_subset->begin(); siter != src_subset->end(); ++siter) { - Element &src_element = *siter; + const Element &src_element = *siter; for (ArcIterator< Fst<A> > aiter(GetFst(), src_element.state_id); !aiter.Done(); aiter.Next()) { const A &arc = aiter.Value(); Element dest_element(arc.nextstate, Times(src_element.weight, arc.weight)); - Subset* &dest_subset = (*label_map)[arc.ilabel]; - if (dest_subset == 0) + + // The LabelMap may be a e.g. multimap with more complex + // determinization filters, so we insert efficiently w/o using []. + typename LabelMap::iterator liter = label_map->lower_bound(arc.ilabel); + Subset* dest_subset; + if (liter == label_map->end() || liter->first != arc.ilabel) { dest_subset = new Subset; + label_map->insert(liter, make_pair(arc.ilabel, dest_subset)); + } else { + dest_subset = liter->second; + } + dest_subset->push_front(dest_element); } } + // Applies the determinization filter + (*filter_)(s, label_map); } // Adds an arc from state S to the destination state associated @@ -469,98 +669,17 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> { CacheImpl<A>::PushArc(s, arc); } - // Comparison object for hashing Subset(s). Subsets are not sorted in this - // implementation, so ordering must not be assumed in the equivalence - // test. - class SubsetEqual { - public: - // Constructor takes vector needed to check equality. See immediately - // below for constraints on it. - explicit SubsetEqual(vector<Element *> *elements) - : elements_(elements) {} - - // At each call to operator(), the elements_ vector should contain - // only NULLs. When this operator returns, elements_ will still - // have this property. - bool operator()(Subset* subset1, Subset* subset2) const { - if (subset1->size() != subset2->size()) - return false; - - // Loads first subset elements in element vector. - for (typename Subset::iterator iter1 = subset1->begin(); - iter1 != subset1->end(); - ++iter1) { - Element &element1 = *iter1; - while (elements_->size() <= element1.state_id) - elements_->push_back(0); - (*elements_)[element1.state_id] = &element1; - } - - // Checks second subset matches first via element vector. - for (typename Subset::iterator iter2 = subset2->begin(); - iter2 != subset2->end(); - ++iter2) { - Element &element2 = *iter2; - while (elements_->size() <= element2.state_id) - elements_->push_back(0); - Element *element1 = (*elements_)[element2.state_id]; - if (!element1 || element1->weight != element2.weight) { - // Mismatch found. Resets element vector before returning false. - for (typename Subset::iterator iter1 = subset1->begin(); - iter1 != subset1->end(); - ++iter1) - (*elements_)[iter1->state_id] = 0; - return false; - } else { - (*elements_)[element2.state_id] = 0; // Clears entry - } - } - return true; - } - private: - vector<Element *> *elements_; - }; - - // Hash function for Subset to Fst states. Subset elements are not - // sorted in this implementation, so the hash must be invariant - // under subset reordering. - class SubsetKey { - public: - size_t operator()(const Subset* subset) const { - size_t hash = 0; - for (typename Subset::const_iterator iter = subset->begin(); - iter != subset->end(); - ++iter) { - const Element &element = *iter; - int lshift = element.state_id % (CHAR_BIT * sizeof(size_t) - 1) + 1; - int rshift = CHAR_BIT * sizeof(size_t) - lshift; - size_t n = element.state_id; - hash ^= n << lshift ^ n >> rshift ^ element.weight.Hash(); - } - return hash; - } - }; - float delta_; // Quantization delta for subset weights const vector<Weight> *in_dist_; // Distance to final NFA states vector<Weight> *out_dist_; // Distance to final DFA states D common_divisor_; + F *filter_; + T *state_table_; - // Used to test equivalence of subsets. vector<Element *> elements_; - // Maps from StateId to Subset. - vector<Subset *> subsets_; - - // Hashes from Subset to its StateId in the output automaton. - typedef unordered_map<Subset *, StateId, SubsetKey, SubsetEqual> - SubsetHash; - - // Hashes from Label to Subsets corr. to destination states of current state. - SubsetHash subset_hash_; - - void operator=(const DeterminizeFsaImpl<A, D> &); // disallow + void operator=(const DeterminizeFsaImpl<A, D, F, T> &); // disallow }; @@ -569,7 +688,7 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> { // the Gallic semiring as an acceptor whose weights contain the output // strings and using acceptor determinization above to determinize // that acceptor. -template <class A, StringType S> +template <class A, StringType S, class D, class F, class T> class DeterminizeFstImpl : public DeterminizeFstImplBase<A> { public: using FstImpl<A>::SetProperties; @@ -588,17 +707,18 @@ class DeterminizeFstImpl : public DeterminizeFstImplBase<A> { typedef ArcMapFst<A, ToArc, ToMapper> ToFst; typedef ArcMapFst<ToArc, A, FromMapper> FromFst; - typedef GallicCommonDivisor<Label, Weight, S> CommonDivisor; + typedef GallicCommonDivisor<Label, Weight, S, D> CommonDivisor; typedef GallicFactor<Label, Weight, S> FactorIterator; - DeterminizeFstImpl(const Fst<A> &fst, const DeterminizeFstOptions<A> &opts) + DeterminizeFstImpl(const Fst<A> &fst, + const DeterminizeFstOptions<A, D, F, T> &opts) : DeterminizeFstImplBase<A>(fst, opts), delta_(opts.delta), subsequential_label_(opts.subsequential_label) { Init(GetFst()); } - DeterminizeFstImpl(const DeterminizeFstImpl<A, S> &impl) + DeterminizeFstImpl(const DeterminizeFstImpl<A, S, D, F, T> &impl) : DeterminizeFstImplBase<A>(impl), delta_(impl.delta_), subsequential_label_(impl.subsequential_label_) { @@ -607,8 +727,8 @@ class DeterminizeFstImpl : public DeterminizeFstImplBase<A> { ~DeterminizeFstImpl() { delete from_fst_; } - virtual DeterminizeFstImpl<A, S> *Copy() { - return new DeterminizeFstImpl<A, S>(*this); + virtual DeterminizeFstImpl<A, S, D, F, T> *Copy() { + return new DeterminizeFstImpl<A, S, D, F, T>(*this); } uint64 Properties() const { return Properties(kFstProperties); } @@ -642,7 +762,7 @@ class DeterminizeFstImpl : public DeterminizeFstImplBase<A> { Label subsequential_label_; FromFst *from_fst_; - void operator=(const DeterminizeFstImpl<A, S> &); // disallow + void operator=(const DeterminizeFstImpl<A, S, D, F, T> &); // disallow }; @@ -673,7 +793,8 @@ class DeterminizeFst : public ImplToFst< DeterminizeFstImplBase<A> > { public: friend class ArcIterator< DeterminizeFst<A> >; friend class StateIterator< DeterminizeFst<A> >; - template <class B, StringType S> friend class DeterminizeFstImpl; + template <class B, StringType S, class D, class F, class T> + friend class DeterminizeFstImpl; typedef A Arc; typedef typename A::Weight Weight; @@ -684,33 +805,47 @@ class DeterminizeFst : public ImplToFst< DeterminizeFstImplBase<A> > { using ImplToFst<Impl>::SetImpl; - explicit DeterminizeFst( - const Fst<A> &fst, - const DeterminizeFstOptions<A> &opts = DeterminizeFstOptions<A>()) { + explicit DeterminizeFst(const Fst<A> &fst) { + typedef DefaultCommonDivisor<Weight> D; + typedef IdentityDeterminizeFilter<A> F; + typedef DefaultDeterminizeStateTable<A> T; + DeterminizeFstOptions<A, D, F, T> opts; if (fst.Properties(kAcceptor, true)) { // Calls implementation for acceptors. - typedef DefaultCommonDivisor<Weight> D; - SetImpl(new DeterminizeFsaImpl<A, D>(fst, D(), 0, 0, opts)); + SetImpl(new DeterminizeFsaImpl<A, D, F, T>(fst, 0, 0, opts)); } else { // Calls implementation for transducers. - SetImpl(new DeterminizeFstImpl<A, STRING_LEFT_RESTRICT>(fst, opts)); + SetImpl(new + DeterminizeFstImpl<A, STRING_LEFT_RESTRICT, D, F, T>(fst, opts)); + } + } + + template <class D, class F, class T> + DeterminizeFst(const Fst<A> &fst, + const DeterminizeFstOptions<A, D, F, T> &opts) { + if (fst.Properties(kAcceptor, true)) { + // Calls implementation for acceptors. + SetImpl(new DeterminizeFsaImpl<A, D, F, T>(fst, 0, 0, opts)); + } else { + // Calls implementation for transducers. + SetImpl(new + DeterminizeFstImpl<A, STRING_LEFT_RESTRICT, D, F, T>(fst, opts)); } } // This acceptor-only version additionally computes the distance to // final states in the output if provided with those distances for the // input. Useful for e.g. unique N-shortest paths. - DeterminizeFst( - const Fst<A> &fst, - const vector<Weight> &in_dist, vector<Weight> *out_dist, - const DeterminizeFstOptions<A> &opts = DeterminizeFstOptions<A>()) { + template <class D, class F, class T> + DeterminizeFst(const Fst<A> &fst, + const vector<Weight> *in_dist, vector<Weight> *out_dist, + const DeterminizeFstOptions<A, D, F, T> &opts) { if (!fst.Properties(kAcceptor, true)) { FSTERROR() << "DeterminizeFst:" << " distance to final states computed for acceptors only"; GetImpl()->SetProperties(kError, kError); } - typedef DefaultCommonDivisor<Weight> D; - SetImpl(new DeterminizeFsaImpl<A, D>(fst, D(), &in_dist, out_dist, opts)); + SetImpl(new DeterminizeFsaImpl<A, D, F, T>(fst, in_dist, out_dist, opts)); } // See Fst<>::Copy() for doc. @@ -733,14 +868,6 @@ class DeterminizeFst : public ImplToFst< DeterminizeFstImplBase<A> > { } private: - // This private version is for passing the common divisor to - // FSA determinization. - template <class D> - DeterminizeFst(const Fst<A> &fst, const D &common_div, - const DeterminizeFstOptions<A> &opts) - : ImplToFst<Impl>( - new DeterminizeFsaImpl<A, D>(fst, common_div, 0, 0, opts)) {} - // Makes visible to friends. Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } @@ -750,17 +877,18 @@ class DeterminizeFst : public ImplToFst< DeterminizeFstImplBase<A> > { // Initialization of transducer determinization implementation. which // is defined after DeterminizeFst since it calls it. -template <class A, StringType S> -void DeterminizeFstImpl<A, S>::Init(const Fst<A> &fst) { +template <class A, StringType S, class D, class F, class T> +void DeterminizeFstImpl<A, S, D, F, T>::Init(const Fst<A> &fst) { // Mapper to an acceptor. ToFst to_fst(fst, ToMapper()); - // Determinize acceptor. + // Determinizes acceptor. // This recursive call terminates since it passes the common divisor // to a private constructor. CacheOptions copts(GetCacheGc(), GetCacheLimit()); - DeterminizeFstOptions<ToArc> dopts(copts, delta_); - DeterminizeFst<ToArc> det_fsa(to_fst, CommonDivisor(), dopts); + DeterminizeFstOptions<ToArc, CommonDivisor> dopts(copts, delta_); + // Uses acceptor-only constructor to avoid template recursion + DeterminizeFst<ToArc> det_fsa(to_fst, 0, 0, dopts); // Mapper back to transducer. FactorWeightOptions<ToArc> fopts(CacheOptions(true, 0), delta_, @@ -832,7 +960,7 @@ struct DeterminizeOptions { // Determinizes a weighted transducer. This version writes the // determinized Fst to an output MutableFst. The result will be an -// equivalent FSt that has the property that no state has two +// equivalent FST that has the property that no state has two // transitions with the same input label. For this algorithm, epsilon // transitions are treated as regular symbols (cf. RmEpsilon). // @@ -866,7 +994,7 @@ void Determinize(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, if (ifst.Properties(kAcceptor, false)) { vector<Weight> idistance, odistance; ShortestDistance(ifst, &idistance, true); - DeterminizeFst<Arc> dfst(ifst, idistance, &odistance, nopts); + DeterminizeFst<Arc> dfst(ifst, &idistance, &odistance, nopts); PruneOptions< Arc, AnyArcFilter<Arc> > popts(opts.weight_threshold, opts.state_threshold, AnyArcFilter<Arc>(), diff --git a/src/include/fst/edit-fst.h b/src/include/fst/edit-fst.h index 303cb24..bd33b9d 100644 --- a/src/include/fst/edit-fst.h +++ b/src/include/fst/edit-fst.h @@ -25,6 +25,10 @@ using std::vector; #include <fst/cache.h> +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; + namespace fst { // The EditFst class enables non-destructive edit operations on a wrapped @@ -431,7 +435,8 @@ class EditFstImpl : public FstImpl<A> { // A copy constructor for this implementation class, used to implement // the Copy() method of the Fst interface. EditFstImpl(const EditFstImpl &impl) - : wrapped_(static_cast<WrappedFstT *>(impl.wrapped_->Copy(true))), + : FstImpl<A>(), + wrapped_(static_cast<WrappedFstT *>(impl.wrapped_->Copy(true))), data_(impl.data_) { data_->IncrRefCount(); SetProperties(impl.Properties()); diff --git a/src/include/fst/epsnormalize.h b/src/include/fst/epsnormalize.h index 8187737..9d178b1 100644 --- a/src/include/fst/epsnormalize.h +++ b/src/include/fst/epsnormalize.h @@ -24,7 +24,6 @@ #include <tr1/unordered_map> using std::tr1::unordered_map; using std::tr1::unordered_multimap; -#include <fst/slist.h> #include <fst/factor-weight.h> diff --git a/src/include/fst/equivalent.h b/src/include/fst/equivalent.h index 7f8708a..e28fea1 100644 --- a/src/include/fst/equivalent.h +++ b/src/include/fst/equivalent.h @@ -23,6 +23,7 @@ #include <algorithm> #include <deque> +using std::deque; #include <tr1/unordered_map> using std::tr1::unordered_map; using std::tr1::unordered_multimap; diff --git a/src/include/fst/extensions/far/extract.h b/src/include/fst/extensions/far/extract.h index d6f92ff..95866de 100644 --- a/src/include/fst/extensions/far/extract.h +++ b/src/include/fst/extensions/far/extract.h @@ -32,51 +32,106 @@ using std::vector; namespace fst { template<class Arc> +inline void FarWriteFst(const Fst<Arc>* fst, string key, + string* okey, int* nrep, + const int32 &generate_filenames, int i, + const string &filename_prefix, + const string &filename_suffix) { + if (key == *okey) + ++*nrep; + else + *nrep = 0; + + *okey = key; + + string ofilename; + if (generate_filenames) { + ostringstream tmp; + tmp.width(generate_filenames); + tmp.fill('0'); + tmp << i; + ofilename = tmp.str(); + } else { + if (*nrep > 0) { + ostringstream tmp; + tmp << '.' << nrep; + key.append(tmp.str().data(), tmp.str().size()); + } + ofilename = key; + } + fst->Write(filename_prefix + ofilename + filename_suffix); +} + +template<class Arc> void FarExtract(const vector<string> &ifilenames, const int32 &generate_filenames, - const string &begin_key, - const string &end_key, + const string &keys, + const string &key_separator, + const string &range_delimiter, const string &filename_prefix, const string &filename_suffix) { FarReader<Arc> *far_reader = FarReader<Arc>::Open(ifilenames); if (!far_reader) return; - if (!begin_key.empty()) - far_reader->Find(begin_key); - string okey; int nrep = 0; - for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) { - string key = far_reader->GetKey(); - if (!end_key.empty() && end_key < key) - break; - const Fst<Arc> &fst = far_reader->GetFst(); - - if (key == okey) - ++nrep; - else - nrep = 0; - okey = key; - - string ofilename; - if (generate_filenames) { - ostringstream tmp; - tmp.width(generate_filenames); - tmp.fill('0'); - tmp << i; - ofilename = tmp.str(); - } else { - if (nrep > 0) { - ostringstream tmp; - tmp << '.' << nrep; - key.append(tmp.str().data(), tmp.str().size()); + vector<char *> key_vector; + // User has specified a set of fsts to extract, where some of the "fsts" could + // be ranges. + if (!keys.empty()) { + char *keys_cstr = new char[keys.size()+1]; + strcpy(keys_cstr, keys.c_str()); + SplitToVector(keys_cstr, key_separator.c_str(), &key_vector, true); + int i = 0; + for (int k = 0; k < key_vector.size(); ++k, ++i) { + string key = string(key_vector[k]); + char *key_cstr = new char[key.size()+1]; + strcpy(key_cstr, key.c_str()); + vector<char *> range_vector; + SplitToVector(key_cstr, range_delimiter.c_str(), &range_vector, false); + if (range_vector.size() == 1) { // Not a range + if (!far_reader->Find(key)) { + LOG(ERROR) << "FarExtract: Cannot find key: " << key; + return; + } + const Fst<Arc> &fst = far_reader->GetFst(); + FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i, + filename_prefix, filename_suffix); + } else if (range_vector.size() == 2) { // A legal range + string begin_key = string(range_vector[0]); + string end_key = string(range_vector[1]); + if (begin_key.empty() || end_key.empty()) { + LOG(ERROR) << "FarExtract: Illegal range specification: " << key; + return; + } + if (!far_reader->Find(begin_key)) { + LOG(ERROR) << "FarExtract: Cannot find key: " << begin_key; + return; + } + for ( ; !far_reader->Done(); far_reader->Next(), ++i) { + string ikey = far_reader->GetKey(); + if (end_key < ikey) break; + const Fst<Arc> &fst = far_reader->GetFst(); + FarWriteFst(&fst, ikey, &okey, &nrep, generate_filenames, i, + filename_prefix, filename_suffix); + } + } else { + LOG(ERROR) << "FarExtract: Illegal range specification: " << key; + return; } - ofilename = key; + delete key_cstr; } - fst.Write(filename_prefix + ofilename + filename_suffix); + delete keys_cstr; + return; + } + // Nothing specified: extract everything. + for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) { + string key = far_reader->GetKey(); + const Fst<Arc> &fst = far_reader->GetFst(); + FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i, + filename_prefix, filename_suffix); } - return; } diff --git a/src/include/fst/extensions/far/far.h b/src/include/fst/extensions/far/far.h index acce76e..737f1b8 100644 --- a/src/include/fst/extensions/far/far.h +++ b/src/include/fst/extensions/far/far.h @@ -273,13 +273,10 @@ FarWriter<A> *FarWriter<A>::Create(const string &filename, FarType type) { return STListFarWriter<A>::Create(filename); case FAR_STTABLE: return STTableFarWriter<A>::Create(filename); - break; case FAR_STLIST: return STListFarWriter<A>::Create(filename); - break; case FAR_FST: return FstFarWriter<A>::Create(filename); - break; default: LOG(ERROR) << "FarWriter::Create: unknown far type"; return 0; diff --git a/src/include/fst/extensions/far/farscript.h b/src/include/fst/extensions/far/farscript.h index 3a9c145..cfd9167 100644 --- a/src/include/fst/extensions/far/farscript.h +++ b/src/include/fst/extensions/far/farscript.h @@ -173,18 +173,22 @@ bool FarEqual(const string &filename1, typedef args::Package<const vector<string> &, int32, const string&, const string&, const string&, - const string&> FarExtractArgs; + const string&, const string&> FarExtractArgs; template<class Arc> void FarExtract(FarExtractArgs *args) { fst::FarExtract<Arc>( - args->arg1, args->arg2, args->arg3, args->arg4, args->arg5, args->arg6); + args->arg1, args->arg2, args->arg3, args->arg4, args->arg5, args->arg6, + args->arg7); } void FarExtract(const vector<string> &ifilenames, const string &arc_type, - int32 generate_filenames, const string &begin_key, - const string &end_key, const string &filename_prefix, + int32 generate_filenames, + const string &keys, + const string &key_separator, + const string &range_delimiter, + const string &filename_prefix, const string &filename_suffix); typedef args::Package<const vector<string> &, const string &, diff --git a/src/include/fst/extensions/far/stlist.h b/src/include/fst/extensions/far/stlist.h index 1cdc80c..ff3d98b 100644 --- a/src/include/fst/extensions/far/stlist.h +++ b/src/include/fst/extensions/far/stlist.h @@ -145,13 +145,13 @@ class STListReader { ReadType(*streams_[i], &magic_number); ReadType(*streams_[i], &file_version); if (magic_number != kSTListMagicNumber) { - FSTERROR() << "STListReader::STTableReader: wrong file type: " + FSTERROR() << "STListReader::STListReader: wrong file type: " << filenames[i]; error_ = true; return; } if (file_version != kSTListFileVersion) { - FSTERROR() << "STListReader::STTableReader: wrong file version: " + FSTERROR() << "STListReader::STListReader: wrong file version: " << filenames[i]; error_ = true; return; @@ -161,7 +161,7 @@ class STListReader { if (!key.empty()) heap_.push(make_pair(key, i)); if (!*streams_[i]) { - FSTERROR() << "STTableReader: error reading file: " << sources_[i]; + FSTERROR() << "STListReader: error reading file: " << sources_[i]; error_ = true; return; } @@ -170,7 +170,7 @@ class STListReader { size_t current = heap_.top().second; entry_ = entry_reader_(*streams_[current]); if (!entry_ || !*streams_[current]) { - FSTERROR() << "STTableReader: error reading entry for key: " + FSTERROR() << "STListReader: error reading entry for key: " << heap_.top().first << ", file: " << sources_[current]; error_ = true; } @@ -219,7 +219,7 @@ class STListReader { heap_.pop(); ReadType(*(streams_[current]), &key); if (!*streams_[current]) { - FSTERROR() << "STTableReader: error reading file: " + FSTERROR() << "STListReader: error reading file: " << sources_[current]; error_ = true; return; @@ -233,7 +233,7 @@ class STListReader { delete entry_; entry_ = entry_reader_(*streams_[current]); if (!entry_ || !*streams_[current]) { - FSTERROR() << "STTableReader: error reading entry for key: " + FSTERROR() << "STListReader: error reading entry for key: " << heap_.top().first << ", file: " << sources_[current]; error_ = true; } @@ -267,8 +267,8 @@ class STListReader { // String-type list header reading function template on the entry header // type 'H' having a member function: // Read(istream &strm, const string &filename); -// Checks that 'filename' is an STTable and call the H::Read() on the last -// entry in the STTable. +// Checks that 'filename' is an STList and call the H::Read() on the last +// entry in the STList. // Does not support reading from stdin. template <class H> bool ReadSTListHeader(const string &filename, H *header) { @@ -281,18 +281,18 @@ bool ReadSTListHeader(const string &filename, H *header) { ReadType(strm, &magic_number); ReadType(strm, &file_version); if (magic_number != kSTListMagicNumber) { - LOG(ERROR) << "ReadSTTableHeader: wrong file type: " << filename; + LOG(ERROR) << "ReadSTListHeader: wrong file type: " << filename; return false; } if (file_version != kSTListFileVersion) { - LOG(ERROR) << "ReadSTTableHeader: wrong file version: " << filename; + LOG(ERROR) << "ReadSTListHeader: wrong file version: " << filename; return false; } string key; ReadType(strm, &key); header->Read(strm, filename + ":" + key); if (!strm) { - LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename; + LOG(ERROR) << "ReadSTListHeader: error reading file: " << filename; return false; } return true; diff --git a/src/include/fst/extensions/ngram/ngram-fst.h b/src/include/fst/extensions/ngram/ngram-fst.h index eee664a..873ae6a 100644 --- a/src/include/fst/extensions/ngram/ngram-fst.h +++ b/src/include/fst/extensions/ngram/ngram-fst.h @@ -26,6 +26,7 @@ using std::vector; #include <fst/compat.h> #include <fst/fstlib.h> +#include <fst/mapped-file.h> #include <fst/extensions/ngram/bitmap-index.h> // NgramFst implements a n-gram language model based upon the LOUDS data @@ -76,7 +77,7 @@ class NGramFstImpl : public FstImpl<A> { typedef typename A::StateId StateId; typedef typename A::Weight Weight; - NGramFstImpl() : data_(0), owned_(false) { + NGramFstImpl() : data_region_(0), data_(0), owned_(false) { SetType("ngram"); SetInputSymbols(NULL); SetOutputSymbols(NULL); @@ -89,6 +90,7 @@ class NGramFstImpl : public FstImpl<A> { if (owned_) { delete [] data_; } + delete data_region_; } static NGramFstImpl<A>* Read(istream &strm, // NOLINT @@ -104,7 +106,8 @@ class NGramFstImpl : public FstImpl<A> { strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures)); strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final)); size_t size = Storage(num_states, num_futures, num_final); - char* data = new char[size]; + MappedFile *data_region = MappedFile::Allocate(size); + char *data = reinterpret_cast<char *>(data_region->mutable_data()); // Copy num_states, num_futures and num_final back into data. memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states)); memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures), @@ -116,7 +119,7 @@ class NGramFstImpl : public FstImpl<A> { delete impl; return NULL; } - impl->Init(data, true /* owned */); + impl->Init(data, false, data_region); return impl; } @@ -126,7 +129,7 @@ class NGramFstImpl : public FstImpl<A> { hdr.SetStart(Start()); hdr.SetNumStates(num_states_); WriteHeader(strm, opts, kFileVersion, &hdr); - strm.write(data_, Storage(num_states_, num_futures_, num_final_)); + strm.write(data_, StorageSize()); return strm; } @@ -223,11 +226,23 @@ class NGramFstImpl : public FstImpl<A> { // Access to the underlying representation const char* GetData(size_t* data_size) const { - *data_size = Storage(num_states_, num_futures_, num_final_); + *data_size = StorageSize(); return data_; } - void Init(const char* data, bool owned); + void Init(const char* data, bool owned, MappedFile *file = 0); + + const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const { + SetInstFuture(s, inst); + SetInstContext(inst); + return inst->context_; + } + + size_t StorageSize() const { + return Storage(num_states_, num_futures_, num_final_); + } + + void GetStates(const vector<Label>& context, vector<StateId> *states) const; private: StateId Transition(const vector<Label> &context, Label future) const; @@ -242,6 +257,7 @@ class NGramFstImpl : public FstImpl<A> { // Minimum file format version supported. static const int kMinFileVersion = 4; + MappedFile *data_region_; const char* data_; bool owned_; // True if we own data_ uint64 num_states_, num_futures_, num_final_; @@ -261,7 +277,7 @@ class NGramFstImpl : public FstImpl<A> { template<typename A> NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) - : data_(0), owned_(false) { + : data_region_(0), data_(0), owned_(false) { typedef A Arc; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; @@ -286,12 +302,16 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) // epsilons. StateId unigram = fst.Start(); while (1) { - ArcIterator<Fst<A> > aiter(fst, unigram); - if (aiter.Done()) { - FSTERROR() << "Start state has no arcs"; + if (unigram == kNoStateId) { + FSTERROR() << "Could not identify unigram state."; SetProperties(kError, kError); return; } + ArcIterator<Fst<A> > aiter(fst, unigram); + if (aiter.Done()) { + LOG(WARNING) << "Unigram state " << unigram << " has no arcs."; + break; + } if (aiter.Value().ilabel != 0) break; unigram = aiter.Value().nextstate; } @@ -385,7 +405,8 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) Weight weight; Label label = kNoLabel; const size_t storage = Storage(num_states, num_futures, num_final); - char* data = new char[storage]; + MappedFile *data_region = MappedFile::Allocate(storage); + char *data = reinterpret_cast<char *>(data_region->mutable_data()); memset(data, 0, storage); size_t offset = 0; memcpy(data + offset, reinterpret_cast<char *>(&num_states), @@ -482,14 +503,17 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) return; } - Init(data, true /* owned */); + Init(data, false, data_region); } template<typename A> -inline void NGramFstImpl<A>::Init(const char* data, bool owned) { +inline void NGramFstImpl<A>::Init(const char* data, bool owned, + MappedFile *data_region) { if (owned_) { delete [] data_; } + delete data_region_; + data_region_ = data_region; owned_ = owned; data_ = data; size_t offset = 0; @@ -507,7 +531,7 @@ inline void NGramFstImpl<A>::Init(const char* data, bool owned) { future_ = reinterpret_cast<const uint64*>(data_ + offset); offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits); final_ = reinterpret_cast<const uint64*>(data_ + offset); - offset += BitmapIndex::StorageSize(num_states_ + 1) * sizeof(bits); + offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits); context_words_ = reinterpret_cast<const Label*>(data_ + offset); offset += (num_states_ + 1) * sizeof(*context_words_); future_words_ = reinterpret_cast<const Label*>(data_ + offset); @@ -538,10 +562,10 @@ inline void NGramFstImpl<A>::Init(const char* data, bool owned) { template<typename A> inline typename A::StateId NGramFstImpl<A>::Transition( const vector<Label> &context, Label future) const { - size_t num_children = root_num_children_; const Label *children = root_children_; - const Label *loc = lower_bound(children, children + num_children, future); - if (loc == children + num_children || *loc != future) { + const Label *loc = lower_bound(children, children + root_num_children_, + future); + if (loc == children + root_num_children_ || *loc != future) { return context_index_.Rank1(0); } size_t node = root_first_child_ + loc - children; @@ -551,7 +575,6 @@ inline typename A::StateId NGramFstImpl<A>::Transition( return context_index_.Rank1(node); } size_t last_child = context_index_.Select0(node_rank + 1) - 1; - num_children = last_child - first_child + 1; for (int word = context.size() - 1; word >= 0; --word) { children = context_words_ + context_index_.Rank1(first_child); loc = lower_bound(children, children + last_child - first_child + 1, @@ -569,6 +592,42 @@ inline typename A::StateId NGramFstImpl<A>::Transition( return context_index_.Rank1(node); } +template<typename A> +inline void NGramFstImpl<A>::GetStates( + const vector<Label> &context, + vector<typename A::StateId>* states) const { + states->clear(); + states->push_back(0); + typename vector<Label>::const_reverse_iterator cit = context.rbegin(); + const Label *children = root_children_; + const Label *loc = lower_bound(children, children + root_num_children_, *cit); + if (loc == children + root_num_children_ || *loc != *cit) return; + size_t node = root_first_child_ + loc - children; + states->push_back(context_index_.Rank1(node)); + if (context.size() == 1) return; + size_t node_rank = context_index_.Rank1(node); + size_t first_child = context_index_.Select0(node_rank) + 1; + ++cit; + if (context_index_.Get(first_child) != false) { + size_t last_child = context_index_.Select0(node_rank + 1) - 1; + while (cit != context.rend()) { + children = context_words_ + context_index_.Rank1(first_child); + loc = lower_bound(children, children + last_child - first_child + 1, + *cit); + if (loc == children + last_child - first_child + 1 || *loc != *cit) { + break; + } + ++cit; + node = first_child + loc - children; + states->push_back(context_index_.Rank1(node)); + node_rank = context_index_.Rank1(node); + first_child = context_index_.Select0(node_rank) + 1; + if (context_index_.Get(first_child) == false) break; + last_child = context_index_.Select0(node_rank + 1) - 1; + } + } +} + /*****************************************************************************/ template<class A> class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { @@ -597,7 +656,7 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { // Non-standard constructor to initialize NGramFst directly from data. NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) { - GetImpl()->Init(data, owned); + GetImpl()->Init(data, owned, NULL); } // Get method that gets the data associated with Init(). @@ -605,6 +664,16 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { return GetImpl()->GetData(data_size); } + const vector<Label> GetContext(StateId s) const { + return GetImpl()->GetContext(s, &inst_); + } + + // Consumes as much as possible of context from right to left, returns the + // the states corresponding to the increasingly conditioned input sequence. + void GetStates(const vector<Label>& context, vector<StateId> *state) const { + return GetImpl()->GetStates(context, state); + } + virtual size_t NumArcs(StateId s) const { return GetImpl()->NumArcs(s, &inst_); } @@ -650,6 +719,10 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { return new NGramFstMatcher<A>(*this, match_type); } + size_t StorageSize() const { + return GetImpl()->StorageSize(); + } + private: explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {} diff --git a/src/include/fst/extensions/pdt/compose.h b/src/include/fst/extensions/pdt/compose.h index 364d76f..c856c6d 100644 --- a/src/include/fst/extensions/pdt/compose.h +++ b/src/include/fst/extensions/pdt/compose.h @@ -21,82 +21,469 @@ #ifndef FST_EXTENSIONS_PDT_COMPOSE_H__ #define FST_EXTENSIONS_PDT_COMPOSE_H__ +#include <list> + +#include <fst/extensions/pdt/pdt.h> #include <fst/compose.h> namespace fst { +// Return paren arcs for Find(kNoLabel). +const uint32 kParenList = 0x00000001; + +// Return a kNolabel loop for Find(paren). +const uint32 kParenLoop = 0x00000002; + +// This class is a matcher that treats parens as multi-epsilon labels. +// It is most efficient if the parens are in a range non-overlapping with +// the non-paren labels. +template <class F> +class ParenMatcher { + public: + typedef SortedMatcher<F> M; + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + ParenMatcher(const FST &fst, MatchType match_type, + uint32 flags = (kParenLoop | kParenList)) + : matcher_(fst, match_type), + match_type_(match_type), + flags_(flags) { + if (match_type == MATCH_INPUT) { + loop_.ilabel = kNoLabel; + loop_.olabel = 0; + } else { + loop_.ilabel = 0; + loop_.olabel = kNoLabel; + } + loop_.weight = Weight::One(); + loop_.nextstate = kNoStateId; + } + + ParenMatcher(const ParenMatcher<F> &matcher, bool safe = false) + : matcher_(matcher.matcher_, safe), + match_type_(matcher.match_type_), + flags_(matcher.flags_), + open_parens_(matcher.open_parens_), + close_parens_(matcher.close_parens_), + loop_(matcher.loop_) { + loop_.nextstate = kNoStateId; + } + + ParenMatcher<F> *Copy(bool safe = false) const { + return new ParenMatcher<F>(*this, safe); + } + + MatchType Type(bool test) const { return matcher_.Type(test); } + + void SetState(StateId s) { + matcher_.SetState(s); + loop_.nextstate = s; + } + + bool Find(Label match_label); + + bool Done() const { + return done_; + } + + const Arc& Value() const { + return paren_loop_ ? loop_ : matcher_.Value(); + } + + void Next(); + + const FST &GetFst() const { return matcher_.GetFst(); } + + uint64 Properties(uint64 props) const { return matcher_.Properties(props); } + + uint32 Flags() const { return matcher_.Flags(); } + + void AddOpenParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad open paren label: 0"; + } else { + open_parens_.Insert(label); + } + } + + void AddCloseParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad close paren label: 0"; + } else { + close_parens_.Insert(label); + } + } + + void RemoveOpenParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad open paren label: 0"; + } else { + open_parens_.Erase(label); + } + } + + void RemoveCloseParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad close paren label: 0"; + } else { + close_parens_.Erase(label); + } + } + + void ClearOpenParens() { + open_parens_.Clear(); + } + + void ClearCloseParens() { + close_parens_.Clear(); + } + + bool IsOpenParen(Label label) const { + return open_parens_.Member(label); + } + + bool IsCloseParen(Label label) const { + return close_parens_.Member(label); + } + + private: + // Advances matcher to next open paren if it exists, returning true. + // O.w. returns false. + bool NextOpenParen(); + + // Advances matcher to next open paren if it exists, returning true. + // O.w. returns false. + bool NextCloseParen(); + + M matcher_; + MatchType match_type_; // Type of match to perform + uint32 flags_; + + // open paren label set + CompactSet<Label, kNoLabel> open_parens_; + + // close paren label set + CompactSet<Label, kNoLabel> close_parens_; + + + bool open_paren_list_; // Matching open paren list + bool close_paren_list_; // Matching close paren list + bool paren_loop_; // Current arc is the implicit paren loop + mutable Arc loop_; // For non-consuming symbols + bool done_; // Matching done + + void operator=(const ParenMatcher<F> &); // Disallow +}; + +template <class M> inline +bool ParenMatcher<M>::Find(Label match_label) { + open_paren_list_ = false; + close_paren_list_ = false; + paren_loop_ = false; + done_ = false; + + // Returns all parenthesis arcs + if (match_label == kNoLabel && (flags_ & kParenList)) { + if (open_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(open_parens_.LowerBound()); + open_paren_list_ = NextOpenParen(); + if (open_paren_list_) return true; + } + if (close_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(close_parens_.LowerBound()); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return true; + } + } + + // Returns 'implicit' paren loop + if (match_label > 0 && (flags_ & kParenLoop) && + (IsOpenParen(match_label) || IsCloseParen(match_label))) { + paren_loop_ = true; + return true; + } + + // Returns all other labels + if (matcher_.Find(match_label)) + return true; + + done_ = true; + return false; +} + +template <class F> inline +void ParenMatcher<F>::Next() { + if (paren_loop_) { + paren_loop_ = false; + done_ = true; + } else if (open_paren_list_) { + matcher_.Next(); + open_paren_list_ = NextOpenParen(); + if (open_paren_list_) return; + + if (close_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(close_parens_.LowerBound()); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return; + } + done_ = !matcher_.Find(kNoLabel); + } else if (close_paren_list_) { + matcher_.Next(); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return; + done_ = !matcher_.Find(kNoLabel); + } else { + matcher_.Next(); + done_ = matcher_.Done(); + } +} + +// Advances matcher to next open paren if it exists, returning true. +// O.w. returns false. +template <class F> inline +bool ParenMatcher<F>::NextOpenParen() { + for (; !matcher_.Done(); matcher_.Next()) { + Label label = match_type_ == MATCH_INPUT ? + matcher_.Value().ilabel : matcher_.Value().olabel; + if (label > open_parens_.UpperBound()) + return false; + if (IsOpenParen(label)) + return true; + } + return false; +} + +// Advances matcher to next close paren if it exists, returning true. +// O.w. returns false. +template <class F> inline +bool ParenMatcher<F>::NextCloseParen() { + for (; !matcher_.Done(); matcher_.Next()) { + Label label = match_type_ == MATCH_INPUT ? + matcher_.Value().ilabel : matcher_.Value().olabel; + if (label > close_parens_.UpperBound()) + return false; + if (IsCloseParen(label)) + return true; + } + return false; +} + + +template <class F> +class ParenFilter { + public: + typedef typename F::FST1 FST1; + typedef typename F::FST2 FST2; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename F::Matcher1 Matcher1; + typedef typename F::Matcher2 Matcher2; + typedef typename F::FilterState FilterState1; + typedef StateId StackId; + typedef PdtStack<StackId, Label> ParenStack; + typedef IntegerFilterState<StackId> FilterState2; + typedef PairFilterState<FilterState1, FilterState2> FilterState; + typedef ParenFilter<F> Filter; + + ParenFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = 0, Matcher2 *matcher2 = 0, + const vector<pair<Label, Label> > *parens = 0, + bool expand = false, bool keep_parens = true) + : filter_(fst1, fst2, matcher1, matcher2), + parens_(parens ? *parens : vector<pair<Label, Label> >()), + expand_(expand), + keep_parens_(keep_parens), + f_(FilterState::NoState()), + stack_(parens_), + paren_id_(-1) { + if (parens) { + for (size_t i = 0; i < parens->size(); ++i) { + const pair<Label, Label> &p = (*parens)[i]; + parens_.push_back(p); + GetMatcher1()->AddOpenParen(p.first); + GetMatcher2()->AddOpenParen(p.first); + if (!expand_) { + GetMatcher1()->AddCloseParen(p.second); + GetMatcher2()->AddCloseParen(p.second); + } + } + } + } + + ParenFilter(const Filter &filter, bool safe = false) + : filter_(filter.filter_, safe), + parens_(filter.parens_), + expand_(filter.expand_), + keep_parens_(filter.keep_parens_), + f_(FilterState::NoState()), + stack_(filter.parens_), + paren_id_(-1) { } + + FilterState Start() const { + return FilterState(filter_.Start(), FilterState2(0)); + } + + void SetState(StateId s1, StateId s2, const FilterState &f) { + f_ = f; + filter_.SetState(s1, s2, f_.GetState1()); + if (!expand_) + return; + + ssize_t paren_id = stack_.Top(f.GetState2().GetState()); + if (paren_id != paren_id_) { + if (paren_id_ != -1) { + GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second); + GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second); + } + paren_id_ = paren_id; + if (paren_id_ != -1) { + GetMatcher1()->AddCloseParen(parens_[paren_id_].second); + GetMatcher2()->AddCloseParen(parens_[paren_id_].second); + } + } + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + FilterState1 f1 = filter_.FilterArc(arc1, arc2); + const FilterState2 &f2 = f_.GetState2(); + if (f1 == FilterState1::NoState()) + return FilterState::NoState(); + + if (arc1->olabel == kNoLabel && arc2->ilabel) { // arc2 parentheses + if (keep_parens_) { + arc1->ilabel = arc2->ilabel; + } else if (arc2->ilabel) { + arc2->olabel = arc1->ilabel; + } + return FilterParen(arc2->ilabel, f1, f2); + } else if (arc2->ilabel == kNoLabel && arc1->olabel) { // arc1 parentheses + if (keep_parens_) { + arc2->olabel = arc1->olabel; + } else { + arc1->ilabel = arc2->olabel; + } + return FilterParen(arc1->olabel, f1, f2); + } else { + return FilterState(f1, f2); + } + } + + void FilterFinal(Weight *w1, Weight *w2) const { + if (f_.GetState2().GetState() != 0) + *w1 = Weight::Zero(); + filter_.FilterFinal(w1, w2); + } + + // Return resp matchers. Ownership stays with filter. + Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); } + Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); } + + uint64 Properties(uint64 iprops) const { + uint64 oprops = filter_.Properties(iprops); + return oprops & kILabelInvariantProperties & kOLabelInvariantProperties; + } + + private: + const FilterState FilterParen(Label label, const FilterState1 &f1, + const FilterState2 &f2) const { + if (!expand_) + return FilterState(f1, f2); + + StackId stack_id = stack_.Find(f2.GetState(), label); + if (stack_id < 0) { + return FilterState::NoState(); + } else { + return FilterState(f1, FilterState2(stack_id)); + } + } + + F filter_; + vector<pair<Label, Label> > parens_; + bool expand_; // Expands to FST + bool keep_parens_; // Retains parentheses in output + FilterState f_; // Current filter state + mutable ParenStack stack_; + ssize_t paren_id_; +}; + // Class to setup composition options for PDT composition. // Default is for the PDT as the first composition argument. template <class Arc, bool left_pdt = true> -class PdtComposeOptions : public +class PdtComposeFstOptions : public ComposeFstOptions<Arc, - MultiEpsMatcher< Matcher<Fst<Arc> > >, - MultiEpsFilter<AltSequenceComposeFilter< - MultiEpsMatcher< - Matcher<Fst<Arc> > > > > > { + ParenMatcher< Fst<Arc> >, + ParenFilter<AltSequenceComposeFilter< + ParenMatcher< Fst<Arc> > > > > { public: typedef typename Arc::Label Label; - typedef MultiEpsMatcher< Matcher<Fst<Arc> > > PdtMatcher; - typedef MultiEpsFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter; + typedef ParenMatcher< Fst<Arc> > PdtMatcher; + typedef ParenFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter; typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions; using COptions::matcher1; using COptions::matcher2; using COptions::filter; - PdtComposeOptions(const Fst<Arc> &ifst1, + PdtComposeFstOptions(const Fst<Arc> &ifst1, const vector<pair<Label, Label> > &parens, - const Fst<Arc> &ifst2) { - matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsList); - matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsLoop); - - // Treat parens as multi-epsilons when composing. - for (size_t i = 0; i < parens.size(); ++i) { - matcher1->AddMultiEpsLabel(parens[i].first); - matcher1->AddMultiEpsLabel(parens[i].second); - matcher2->AddMultiEpsLabel(parens[i].first); - matcher2->AddMultiEpsLabel(parens[i].second); - } + const Fst<Arc> &ifst2, bool expand = false, + bool keep_parens = true) { + matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenList); + matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenLoop); - filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true); + filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, + expand, keep_parens); } }; // Class to setup composition options for PDT with FST composition. // Specialization is for the FST as the first composition argument. template <class Arc> -class PdtComposeOptions<Arc, false> : public +class PdtComposeFstOptions<Arc, false> : public ComposeFstOptions<Arc, - MultiEpsMatcher< Matcher<Fst<Arc> > >, - MultiEpsFilter<SequenceComposeFilter< - MultiEpsMatcher< - Matcher<Fst<Arc> > > > > > { + ParenMatcher< Fst<Arc> >, + ParenFilter<SequenceComposeFilter< + ParenMatcher< Fst<Arc> > > > > { public: typedef typename Arc::Label Label; - typedef MultiEpsMatcher< Matcher<Fst<Arc> > > PdtMatcher; - typedef MultiEpsFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter; + typedef ParenMatcher< Fst<Arc> > PdtMatcher; + typedef ParenFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter; typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions; using COptions::matcher1; using COptions::matcher2; using COptions::filter; - PdtComposeOptions(const Fst<Arc> &ifst1, - const Fst<Arc> &ifst2, - const vector<pair<Label, Label> > &parens) { - matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsLoop); - matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsList); - - // Treat parens as multi-epsilons when composing. - for (size_t i = 0; i < parens.size(); ++i) { - matcher1->AddMultiEpsLabel(parens[i].first); - matcher1->AddMultiEpsLabel(parens[i].second); - matcher2->AddMultiEpsLabel(parens[i].first); - matcher2->AddMultiEpsLabel(parens[i].second); - } + PdtComposeFstOptions(const Fst<Arc> &ifst1, + const Fst<Arc> &ifst2, + const vector<pair<Label, Label> > &parens, + bool expand = false, bool keep_parens = true) { + matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenLoop); + matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenList); - filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true); + filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, + expand, keep_parens); } }; +enum PdtComposeFilter { + PAREN_FILTER, // Bar-Hillel construction; keeps parentheses + EXPAND_FILTER, // Bar-Hillel + expansion; removes parentheses + EXPAND_PAREN_FILTER, // Bar-Hillel + expansion; keeps parentheses +}; + +struct PdtComposeOptions { + bool connect; // Connect output + PdtComposeFilter filter_type; // Which pre-defined filter to use + + explicit PdtComposeOptions(bool c, PdtComposeFilter ft = PAREN_FILTER) + : connect(c), filter_type(ft) {} + PdtComposeOptions() : connect(true), filter_type(PAREN_FILTER) {} +}; // Composes pushdown transducer (PDT) encoded as an FST (1st arg) and // an FST (2nd arg) with the result also a PDT encoded as an Fst. (3rd arg). @@ -110,16 +497,17 @@ void Compose(const Fst<Arc> &ifst1, typename Arc::Label> > &parens, const Fst<Arc> &ifst2, MutableFst<Arc> *ofst, - const ComposeOptions &opts = ComposeOptions()) { - - PdtComposeOptions<Arc, true> copts(ifst1, parens, ifst2); + const PdtComposeOptions &opts = PdtComposeOptions()) { + bool expand = opts.filter_type != PAREN_FILTER; + bool keep_parens = opts.filter_type != EXPAND_FILTER; + PdtComposeFstOptions<Arc, true> copts(ifst1, parens, ifst2, + expand, keep_parens); copts.gc_limit = 0; *ofst = ComposeFst<Arc>(ifst1, ifst2, copts); if (opts.connect) Connect(ofst); } - // Composes an FST (1st arg) and pushdown transducer (PDT) encoded as // an FST (2nd arg) with the result also a PDT encoded as an Fst (3rd arg). // In the PDTs, some transitions are labeled with open or close @@ -132,9 +520,11 @@ void Compose(const Fst<Arc> &ifst1, const vector<pair<typename Arc::Label, typename Arc::Label> > &parens, MutableFst<Arc> *ofst, - const ComposeOptions &opts = ComposeOptions()) { - - PdtComposeOptions<Arc, false> copts(ifst1, ifst2, parens); + const PdtComposeOptions &opts = PdtComposeOptions()) { + bool expand = opts.filter_type != PAREN_FILTER; + bool keep_parens = opts.filter_type != EXPAND_FILTER; + PdtComposeFstOptions<Arc, false> copts(ifst1, ifst2, parens, + expand, keep_parens); copts.gc_limit = 0; *ofst = ComposeFst<Arc>(ifst1, ifst2, copts); if (opts.connect) diff --git a/src/include/fst/extensions/pdt/pdt.h b/src/include/fst/extensions/pdt/pdt.h index 6649f55..c56afbd 100644 --- a/src/include/fst/extensions/pdt/pdt.h +++ b/src/include/fst/extensions/pdt/pdt.h @@ -27,6 +27,7 @@ using std::tr1::unordered_multimap; #include <map> #include <set> +#include <fst/compat.h> #include <fst/state-table.h> #include <fst/fst.h> diff --git a/src/include/fst/extensions/pdt/pdtscript.h b/src/include/fst/extensions/pdt/pdtscript.h index c2a1cf4..84bb27e 100644 --- a/src/include/fst/extensions/pdt/pdtscript.h +++ b/src/include/fst/extensions/pdt/pdtscript.h @@ -48,7 +48,7 @@ typedef args::Package<const FstClass &, const FstClass &, const vector<pair<int64, int64> >&, MutableFstClass *, - const ComposeOptions &, + const PdtComposeOptions &, bool> PdtComposeArgs; template<class Arc> @@ -76,7 +76,7 @@ void PdtCompose(const FstClass & ifst1, const FstClass & ifst2, const vector<pair<int64, int64> > &parens, MutableFstClass *ofst, - const ComposeOptions &copts, + const PdtComposeOptions &copts, bool left_pdt); // PDT EXPAND diff --git a/src/include/fst/extensions/pdt/replace.h b/src/include/fst/extensions/pdt/replace.h index a85d0fe..9081400 100644 --- a/src/include/fst/extensions/pdt/replace.h +++ b/src/include/fst/extensions/pdt/replace.h @@ -21,6 +21,10 @@ #ifndef FST_EXTENSIONS_PDT_REPLACE_H__ #define FST_EXTENSIONS_PDT_REPLACE_H__ +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; + #include <fst/replace.h> namespace fst { @@ -62,11 +66,14 @@ void Replace(const vector<pair<typename Arc::Label, label2id[ifst_array[i].first] = i; Label max_label = kNoLabel; + size_t max_non_term_count = 0; - deque<size_t> non_term_queue; // Queue of non-terminals to replace - unordered_set<Label> non_term_set; // Set of non-terminals to replace + // Queue of non-terminals to replace + deque<size_t> non_term_queue; + // Map of non-terminals to replace to count + unordered_map<Label, size_t> non_term_map; non_term_queue.push_back(root); - non_term_set.insert(root); + non_term_map[root] = 1;; // PDT state corr. to ith replace FST start state. vector<StateId> fst_start(ifst_array.size(), kNoLabel); @@ -107,10 +114,11 @@ void Replace(const vector<pair<typename Arc::Label, size_t nfst_id = it->second; if (ifst_array[nfst_id].second->Start() == -1) continue; - if (non_term_set.count(arc.olabel) == 0) { + size_t count = non_term_map[arc.olabel]++; + if (count == 0) non_term_queue.push_back(arc.olabel); - non_term_set.insert(arc.olabel); - } + if (count > max_non_term_count) + max_non_term_count = count; } arc.nextstate += soff; ofst->AddArc(os, arc); @@ -134,7 +142,8 @@ void Replace(const vector<pair<typename Arc::Label, // # of parenthesis pairs per fst. vector<size_t> nparens(ifst_array.size(), 0); // Initial open parenthesis label - Label first_paren = max_label + 1; + Label first_open_paren = max_label + 1; + Label first_close_paren = max_label + max_non_term_count + 1; for (StateIterator< Fst<Arc> > siter(*ofst); !siter.Done(); siter.Next()) { @@ -158,8 +167,8 @@ void Replace(const vector<pair<typename Arc::Label, close_paren = (*parens)[paren_id].second; } else { size_t paren_id = nparens[nfst_id]++; - open_paren = first_paren + 2 * paren_id; - close_paren = open_paren + 1; + open_paren = first_open_paren + paren_id; + close_paren = first_close_paren + paren_id; paren_map[paren_key] = paren_id; if (paren_id >= parens->size()) parens->push_back(make_pair(open_paren, close_paren)); diff --git a/src/include/fst/factor-weight.h b/src/include/fst/factor-weight.h index 97440e1..685155c 100644 --- a/src/include/fst/factor-weight.h +++ b/src/include/fst/factor-weight.h @@ -25,7 +25,6 @@ #include <tr1/unordered_map> using std::tr1::unordered_map; using std::tr1::unordered_multimap; -#include <fst/slist.h> #include <string> #include <utility> using std::pair; using std::make_pair; diff --git a/src/include/fst/fst-decl.h b/src/include/fst/fst-decl.h index 0e2cdf1..f27ded8 100644 --- a/src/include/fst/fst-decl.h +++ b/src/include/fst/fst-decl.h @@ -56,7 +56,6 @@ template <class A> class ClosureFst; template <class A> class ComposeFst; template <class A> class ConcatFst; template <class A> class DeterminizeFst; -template <class A> class DeterminizeFst; template <class A> class DifferenceFst; template <class A> class IntersectFst; template <class A> class InvertFst; diff --git a/src/include/fst/fst.h b/src/include/fst/fst.h index dd11e4f..150fc4e 100644 --- a/src/include/fst/fst.h +++ b/src/include/fst/fst.h @@ -53,6 +53,11 @@ template <class A> class ArcIteratorData; template <class A> class MatcherBase; struct FstReadOptions { + // FileReadMode(s) are advisory, there are many conditions than prevent a + // file from being mapped, READ mode will be selected in these cases with + // a warning indicating why it was chosen. + enum FileReadMode { READ, MAP }; + string source; // Where you're reading from const FstHeader *header; // Pointer to Fst header. If non-zero, use // this info (don't read a stream header) @@ -60,19 +65,20 @@ struct FstReadOptions { // this info (read and skip stream isymbols) const SymbolTable* osymbols; // Pointer to output symbols. If non-zero, use // this info (read and skip stream osymbols) + FileReadMode mode; // Read or map files (advisory, if possible) - explicit FstReadOptions(const string& src = "<unspecfied>", + explicit FstReadOptions(const string& src = "<unspecified>", const FstHeader *hdr = 0, const SymbolTable* isym = 0, - const SymbolTable* osym = 0) - : source(src), header(hdr), isymbols(isym), osymbols(osym) {} + const SymbolTable* osym = 0); explicit FstReadOptions(const string& src, const SymbolTable* isym, - const SymbolTable* osym = 0) - : source(src), header(0), isymbols(isym), osymbols(osym) {} -}; + const SymbolTable* osym = 0); + // Helper function to convert strings FileReadModes into their enum value. + static FileReadMode ReadMode(const string &mode); +}; struct FstWriteOptions { string source; // Where you're writing to diff --git a/src/include/fst/interval-set.h b/src/include/fst/interval-set.h index cf6ac54..c4362f2 100644 --- a/src/include/fst/interval-set.h +++ b/src/include/fst/interval-set.h @@ -81,12 +81,12 @@ class IntervalSet { const vector<Interval> *Intervals() const { return &intervals_; } - const bool Empty() const { return intervals_.empty(); } + bool Empty() const { return intervals_.empty(); } - const T Size() const { return intervals_.size(); } + T Size() const { return intervals_.size(); } // Number of points in the intervals (undefined if not normalized). - const T Count() const { return count_; } + T Count() const { return count_; } void Clear() { intervals_.clear(); diff --git a/src/include/fst/mapped-file.h b/src/include/fst/mapped-file.h new file mode 100644 index 0000000..d61bc14 --- /dev/null +++ b/src/include/fst/mapped-file.h @@ -0,0 +1,83 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: sorenj@google.com (Jeffrey Sorensen) + +#ifndef FST_LIB_MAPPED_FILE_H_ +#define FST_LIB_MAPPED_FILE_H_ + +#include <unistd.h> +#include <sys/mman.h> + +#include <fst/fst.h> +#include <iostream> +#include <fstream> +#include <sstream> + +DECLARE_int32(fst_arch_alignment); // defined in mapped-file.h + +namespace fst { + +// A memory region is a simple abstraction for allocated memory or data from +// mmap'ed files. If mmap equals NULL, then data represents an owned region of +// size bytes. Otherwise, mmap and size refer to the mapping and data is a +// casted pointer to a region contained within [mmap, mmap + size). +// If size is 0, then mmap refers and data refer to a block of memory managed +// externally by some other allocator. +struct MemoryRegion { + void *data; + void *mmap; + size_t size; +}; + +class MappedFile { + public: + virtual ~MappedFile(); + + void* mutable_data() const { + return reinterpret_cast<void*>(region_.data); + } + + const void* data() const { + return reinterpret_cast<void*>(region_.data); + } + + // Returns a MappedFile object that contains the contents of the input + // stream s starting from the current file position with size bytes. + // The file name must also be provided in the FstReadOptions as opts.source + // or else mapping will fail. If mapping is not possible, then a MappedFile + // object with a new[]'ed block of memory will be created. + static MappedFile* Map(istream* s, const FstReadOptions& opts, size_t size); + + // Creates a MappedFile object with a new[]'ed block of memory of size. + // RECOMMENDED FOR INTERNAL USE ONLY, may change in future releases. + static MappedFile* Allocate(size_t size); + + // Creates a MappedFile object pointing to a borrowed reference to data. + // This block of memory is not owned by the MappedFile object and will not + // be freed. + // RECOMMENDED FOR INTERNAL USE ONLY, may change in future releases. + static MappedFile* Borrow(void *data); + + static const int kArchAlignment; + + private: + explicit MappedFile(const MemoryRegion ®ion); + + MemoryRegion region_; + DISALLOW_COPY_AND_ASSIGN(MappedFile); +}; +} // namespace fst + +#endif // FST_LIB_MAPPED_FILE_H_ diff --git a/src/include/fst/matcher.h b/src/include/fst/matcher.h index 5ab3d26..89ed9be 100644 --- a/src/include/fst/matcher.h +++ b/src/include/fst/matcher.h @@ -75,6 +75,8 @@ namespace fst { // bool Done() const; // No more matches. // const A& Value() const; // Current arc (when !Done) // void Next(); // Advance to next arc (when !Done) +// // Initially and after SetState() the iterator methods +// // have undefined behavior until Find() is called. // // // Return matcher FST. // const F& GetFst() const; @@ -223,13 +225,44 @@ class SortedMatcher : public MatcherBase<typename F::Arc> { loop_.nextstate = s; } - bool Find(Label match_label); + bool Find(Label match_label) { + exact_match_ = true; + if (error_) { + current_loop_ = false; + match_label_ = kNoLabel; + return false; + } + current_loop_ = match_label == 0; + match_label_ = match_label == kNoLabel ? 0 : match_label; + if (Search()) { + return true; + } else { + return current_loop_; + } + } + + // Positions matcher to the first position where inserting + // match_label would maintain the sort order. + void LowerBound(Label match_label) { + exact_match_ = false; + current_loop_ = false; + if (error_) { + match_label_ = kNoLabel; + return; + } + match_label_ = match_label; + Search(); + } + // After Find(), returns false if no more exact matches. + // After LowerBound(), returns false if no more arcs. bool Done() const { if (current_loop_) return false; if (aiter_->Done()) return true; + if (!exact_match_) + return false; aiter_->SetFlags( match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue, kArcValueFlags); @@ -261,6 +294,8 @@ class SortedMatcher : public MatcherBase<typename F::Arc> { return outprops; } + size_t Position() const { return aiter_ ? aiter_->Position() : 0; } + private: virtual void SetState_(StateId s) { SetState(s); } virtual bool Find_(Label label) { return Find(label); } @@ -268,6 +303,8 @@ class SortedMatcher : public MatcherBase<typename F::Arc> { virtual const Arc& Value_() const { return Value(); } virtual void Next_() { Next(); } + bool Search(); + const F *fst_; StateId s_; // Current state ArcIterator<F> *aiter_; // Iterator for current state @@ -277,20 +314,16 @@ class SortedMatcher : public MatcherBase<typename F::Arc> { size_t narcs_; // Current state arc count Arc loop_; // For non-consuming symbols bool current_loop_; // Current arc is the implicit loop + bool exact_match_; // Exact match or lower bound? bool error_; // Error encountered void operator=(const SortedMatcher<F> &); // Disallow }; +// Returns true iff match to match_label_. Positions arc iterator at +// lower bound regardless. template <class F> inline -bool SortedMatcher<F>::Find(Label match_label) { - if (error_) { - current_loop_ = false; - match_label_ = kNoLabel; - return false; - } - current_loop_ = match_label == 0; - match_label_ = match_label == kNoLabel ? 0 : match_label; +bool SortedMatcher<F>::Search() { aiter_->SetFlags( match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue, kArcValueFlags); @@ -321,7 +354,8 @@ bool SortedMatcher<F>::Find(Label match_label) { return true; } } - return current_loop_; + aiter_->Seek(low); + return false; } else { // Linear search for match. for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) { @@ -333,7 +367,7 @@ bool SortedMatcher<F>::Find(Label match_label) { if (label > match_label_) break; } - return current_loop_; + return false; } } @@ -1047,16 +1081,19 @@ class MultiEpsMatcher { } } + void RemoveMultiEpsLabel(Label label) { + if (label == 0) { + FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0"; + } else { + multi_eps_labels_.Erase(label); + } + } + void ClearMultiEpsLabels() { multi_eps_labels_.Clear(); } private: - // Specialized for 'set' - log lookup - bool IsMultiEps(const set<Label> &multi_eps_labels, Label label) const { - return multi_eps_labels.Find(label) != multi_eps_labels.end(); - } - M *matcher_; uint32 flags_; bool own_matcher_; // Does this class delete the matcher? diff --git a/src/include/fst/queue.h b/src/include/fst/queue.h index e31f087..95a082d 100644 --- a/src/include/fst/queue.h +++ b/src/include/fst/queue.h @@ -23,6 +23,7 @@ #define FST_LIB_QUEUE_H__ #include <deque> +using std::deque; #include <vector> using std::vector; @@ -791,13 +792,13 @@ struct TrivialStateEquivClass { }; -// Pruning queue discipline: Enqueues a state 's' only when its -// shortest distance (so far), as specified by 'distance', is less -// than (as specified by 'comp') the shortest distance Times() the -// 'threshold' to any state in the same equivalence class, as -// specified by the function object 'class_func'. The underlying -// queue discipline is specified by 'queue'. The ownership of 'queue' -// is given to this class. +// Distance-based pruning queue discipline: Enqueues a state 's' +// only when its shortest distance (so far), as specified by +// 'distance', is less than (as specified by 'comp') the shortest +// distance Times() the 'threshold' to any state in the same +// equivalence class, as specified by the function object +// 'class_func'. The underlying queue discipline is specified by +// 'queue'. The ownership of 'queue' is given to this class. template <typename Q, typename L, typename C> class PruneQueue : public QueueBase<typename Q::StateId> { public: @@ -884,6 +885,54 @@ class NaturalPruneQueue : }; +// Filter-based pruning queue discipline: Enqueues a state 's' only +// if allowed by the filter, specified by the function object 'state_filter'. +// The underlying queue discipline is specified by 'queue'. The ownership +// of 'queue' is given to this class. +template <typename Q, typename F> +class FilterQueue : public QueueBase<typename Q::StateId> { + public: + typedef typename Q::StateId StateId; + + FilterQueue(Q *queue, const F &state_filter) + : QueueBase<StateId>(OTHER_QUEUE), + queue_(queue), + state_filter_(state_filter) {} + + ~FilterQueue() { delete queue_; } + + StateId Head() const { return queue_->Head(); } + + // Enqueues only if allowed by state filter. + void Enqueue(StateId s) { + if (state_filter_(s)) { + queue_->Enqueue(s); + } + } + + void Dequeue() { queue_->Dequeue(); } + + void Update(StateId s) {} + bool Empty() const { return queue_->Empty(); } + void Clear() { queue_->Clear(); } + + private: + // This allows base-class virtual access to non-virtual derived- + // class members of the same name. It makes the derived class more + // efficient to use but unsafe to further derive. + virtual StateId Head_() const { return Head(); } + virtual void Enqueue_(StateId s) { Enqueue(s); } + virtual void Dequeue_() { Dequeue(); } + virtual void Update_(StateId s) { Update(s); } + virtual bool Empty_() const { return Empty(); } + virtual void Clear_() { return Clear(); } + + Q *queue_; + const F &state_filter_; // Filter to prune states + + DISALLOW_COPY_AND_ASSIGN(FilterQueue); +}; + } // namespace fst #endif diff --git a/src/include/fst/relabel.h b/src/include/fst/relabel.h index 685d42a..dc675b6 100644 --- a/src/include/fst/relabel.h +++ b/src/include/fst/relabel.h @@ -34,6 +34,10 @@ using std::vector; #include <fst/test-properties.h> +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; + namespace fst { // diff --git a/src/include/fst/script/convert.h b/src/include/fst/script/convert.h index 2c70a70..4a3ce6b 100644 --- a/src/include/fst/script/convert.h +++ b/src/include/fst/script/convert.h @@ -34,7 +34,7 @@ void Convert(ConvertArgs *args) { const string &new_type = args->args.arg2; Fst<Arc> *result = Convert(fst, new_type); - args->retval = new FstClass(result); + args->retval = new FstClass(*result); delete result; } diff --git a/src/include/fst/script/disambiguate.h b/src/include/fst/script/disambiguate.h new file mode 100644 index 0000000..e42a9c2 --- /dev/null +++ b/src/include/fst/script/disambiguate.h @@ -0,0 +1,68 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_DISAMBIGUATE_H_ +#define FST_SCRIPT_DISAMBIGUATE_H_ + +#include <fst/disambiguate.h> +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/weight-class.h> + +namespace fst { +namespace script { + +struct DisambiguateOptions { + float delta; + WeightClass weight_threshold; + int64 state_threshold; + int64 subsequential_label; + + explicit DisambiguateOptions(float d = fst::kDelta, + WeightClass w = + fst::script::WeightClass::Zero(), + int64 n = fst::kNoStateId, int64 l = 0) + : delta(d), weight_threshold(w), state_threshold(n), + subsequential_label(l) {} +}; + +typedef args::Package<const FstClass&, MutableFstClass*, + const DisambiguateOptions &> DisambiguateArgs; + +template<class Arc> +void Disambiguate(DisambiguateArgs *args) { + const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + const DisambiguateOptions &opts = args->arg3; + + fst::DisambiguateOptions<Arc> detargs; + detargs.delta = opts.delta; + detargs.weight_threshold = + *(opts.weight_threshold.GetWeight<typename Arc::Weight>()); + detargs.state_threshold = opts.state_threshold; + detargs.subsequential_label = opts.subsequential_label; + + Disambiguate(ifst, ofst, detargs); +} + +void Disambiguate(const FstClass &ifst, MutableFstClass *ofst, + const DisambiguateOptions &opts = + fst::script::DisambiguateOptions()); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_DISAMBIGUATE_H_ diff --git a/src/include/fst/script/fst-class.h b/src/include/fst/script/fst-class.h index a820c1c..fe2cf53 100644 --- a/src/include/fst/script/fst-class.h +++ b/src/include/fst/script/fst-class.h @@ -52,8 +52,8 @@ class FstClassBase { virtual const string &WeightType() const = 0; virtual const SymbolTable *InputSymbols() const = 0; virtual const SymbolTable *OutputSymbols() const = 0; - virtual void Write(const string& fname) const = 0; - virtual void Write(ostream &ostr, const FstWriteOptions &opts) const = 0; + virtual bool Write(const string& fname) const = 0; + virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const = 0; virtual uint64 Properties(uint64 mask, bool test) const = 0; virtual ~FstClassBase() { } }; @@ -82,6 +82,8 @@ class FstClassImpl : public FstClassImplBase { bool should_own = false) : impl_(should_own ? impl : impl->Copy()) { } + explicit FstClassImpl(const Fst<Arc> &impl) : impl_(impl.Copy()) { } + virtual const string &ArcType() const { return Arc::Type(); } @@ -112,12 +114,12 @@ class FstClassImpl : public FstClassImplBase { static_cast<MutableFst<Arc> *>(impl_)->SetOutputSymbols(os); } - virtual void Write(const string &fname) const { - impl_->Write(fname); + virtual bool Write(const string &fname) const { + return impl_->Write(fname); } - virtual void Write(ostream &ostr, const FstWriteOptions &opts) const { - impl_->Write(ostr, opts); + virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const { + return impl_->Write(ostr, opts); } virtual uint64 Properties(uint64 mask, bool test) const { @@ -166,10 +168,10 @@ class FstClass : public FstClassBase { } template<class Arc> - explicit FstClass(Fst<Arc> *fst) : impl_(new FstClassImpl<Arc>(fst)) { + explicit FstClass(const Fst<Arc> &fst) : impl_(new FstClassImpl<Arc>(fst)) { } - explicit FstClass(const FstClass &other) : impl_(other.impl_->Copy()) { } + FstClass(const FstClass &other) : impl_(other.impl_->Copy()) { } FstClass &operator=(const FstClass &other) { delete impl_; @@ -201,12 +203,12 @@ class FstClass : public FstClassBase { return impl_->WeightType(); } - virtual void Write(const string &fname) const { - impl_->Write(fname); + virtual bool Write(const string &fname) const { + return impl_->Write(fname); } - virtual void Write(ostream &ostr, const FstWriteOptions &opts) const { - impl_->Write(ostr, opts); + virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const { + return impl_->Write(ostr, opts); } virtual uint64 Properties(uint64 mask, bool test) const { @@ -253,7 +255,7 @@ class FstClass : public FstClassBase { if (!u) { return 0; } else { - FstClassT *r = new FstClassT(u); + FstClassT *r = new FstClassT(*u); delete u; return r; } @@ -276,7 +278,7 @@ class FstClass : public FstClassBase { class MutableFstClass : public FstClass { public: template<class Arc> - explicit MutableFstClass(MutableFst<Arc> *fst) : + explicit MutableFstClass(const MutableFst<Arc> &fst) : FstClass(fst) { } template<class Arc> @@ -294,18 +296,18 @@ class MutableFstClass : public FstClass { if (!mfst) { return 0; } else { - MutableFstClass *retval = new MutableFstClass(mfst); + MutableFstClass *retval = new MutableFstClass(*mfst); delete mfst; return retval; } } - virtual void Write(const string &fname) const { - GetImpl()->Write(fname); + virtual bool Write(const string &fname) const { + return GetImpl()->Write(fname); } - virtual void Write(ostream &ostr, const FstWriteOptions &opts) const { - GetImpl()->Write(ostr, opts); + virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const { + return GetImpl()->Write(ostr, opts); } static MutableFstClass *Read(const string &fname, bool convert = false); @@ -344,7 +346,7 @@ class VectorFstClass : public MutableFstClass { explicit VectorFstClass(const string &arc_type); template<class Arc> - explicit VectorFstClass(VectorFst<Arc> *fst) : + explicit VectorFstClass(const VectorFst<Arc> &fst) : MutableFstClass(fst) { } template<class Arc> @@ -354,7 +356,7 @@ class VectorFstClass : public MutableFstClass { if (!vfst) { return 0; } else { - VectorFstClass *retval = new VectorFstClass(vfst); + VectorFstClass *retval = new VectorFstClass(*vfst); delete vfst; return retval; } diff --git a/src/include/fst/script/map.h b/src/include/fst/script/map.h index 2332074..3caaa9f 100644 --- a/src/include/fst/script/map.h +++ b/src/include/fst/script/map.h @@ -59,46 +59,54 @@ void Map(MapArgs *args) { float delta = args->args.arg3; typename Arc::Weight w = *(args->args.arg4.GetWeight<typename Arc::Weight>()); + Fst<Arc> *fst = NULL; + Fst<LogArc> *lfst = NULL; + Fst<Log64Arc> *l64fst = NULL; + Fst<StdArc> *sfst = NULL; if (map_type == ARC_SUM_MAPPER) { - args->retval = new FstClass( - script::StateMap(ifst, ArcSumMapper<Arc>(ifst))); + args->retval = new FstClass(*(fst = + script::StateMap(ifst, ArcSumMapper<Arc>(ifst)))); } else if (map_type == IDENTITY_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, IdentityArcMapper<Arc>())); + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, IdentityArcMapper<Arc>()))); } else if (map_type == INVERT_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, InvertWeightMapper<Arc>())); + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, InvertWeightMapper<Arc>()))); } else if (map_type == PLUS_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, PlusMapper<Arc>(w))); + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, PlusMapper<Arc>(w)))); } else if (map_type == QUANTIZE_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, QuantizeMapper<Arc>(delta))); + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, QuantizeMapper<Arc>(delta)))); } else if (map_type == RMWEIGHT_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, RmWeightMapper<Arc>())); + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, RmWeightMapper<Arc>()))); } else if (map_type == SUPERFINAL_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, SuperFinalMapper<Arc>())); + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, SuperFinalMapper<Arc>()))); } else if (map_type == TIMES_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, TimesMapper<Arc>(w))); + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, TimesMapper<Arc>(w)))); } else if (map_type == TO_LOG_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, WeightConvertMapper<Arc, LogArc>())); + args->retval = new FstClass(*(lfst = + script::ArcMap(ifst, WeightConvertMapper<Arc, LogArc>()))); } else if (map_type == TO_LOG64_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, WeightConvertMapper<Arc, Log64Arc>())); + args->retval = new FstClass(*(l64fst = + script::ArcMap(ifst, WeightConvertMapper<Arc, Log64Arc>()))); } else if (map_type == TO_STD_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, WeightConvertMapper<Arc, StdArc>())); + args->retval = new FstClass(*(sfst = + script::ArcMap(ifst, WeightConvertMapper<Arc, StdArc>()))); } else { FSTERROR() << "Error: unknown/unsupported mapper type: " << map_type; VectorFst<Arc> *ofst = new VectorFst<Arc>; ofst->SetProperties(kError, kError); - args->retval = new FstClass(ofst); + args->retval = new FstClass(*(fst =ofst)); } + delete sfst; + delete l64fst; + delete lfst; + delete fst; } diff --git a/src/include/fst/script/shortest-distance.h b/src/include/fst/script/shortest-distance.h index 5fc2976..39c5045 100644 --- a/src/include/fst/script/shortest-distance.h +++ b/src/include/fst/script/shortest-distance.h @@ -175,11 +175,11 @@ void ShortestDistance(ShortestDistanceArgs1 *args) { return; case FIFO_QUEUE: - ShortestDistanceHelper<Arc, FifoQueue<StateId> >(args); + ShortestDistanceHelper<Arc, FifoQueue<StateId> >(args); return; case LIFO_QUEUE: - ShortestDistanceHelper<Arc, LifoQueue<StateId> >(args); + ShortestDistanceHelper<Arc, LifoQueue<StateId> >(args); return; case SHORTEST_FIRST_QUEUE: @@ -188,11 +188,11 @@ void ShortestDistance(ShortestDistanceArgs1 *args) { return; case STATE_ORDER_QUEUE: - ShortestDistanceHelper<Arc, StateOrderQueue<StateId> >(args); + ShortestDistanceHelper<Arc, StateOrderQueue<StateId> >(args); return; case TOP_ORDER_QUEUE: - ShortestDistanceHelper<Arc, TopOrderQueue<StateId> >(args); + ShortestDistanceHelper<Arc, TopOrderQueue<StateId> >(args); return; } } diff --git a/src/include/fst/script/weight-class.h b/src/include/fst/script/weight-class.h index 228216d..b9f7ddf 100644 --- a/src/include/fst/script/weight-class.h +++ b/src/include/fst/script/weight-class.h @@ -128,6 +128,13 @@ class WeightClass { return w; } + const string &Type() const { + if (impl_) return impl_->Type(); + static const string no_type = "none"; + return no_type; + } + + ~WeightClass() { if (impl_) delete impl_; } private: enum ElementType { ZERO, ONE, OTHER }; diff --git a/src/include/fst/shortest-distance.h b/src/include/fst/shortest-distance.h index 9320c4c..ec47a14 100644 --- a/src/include/fst/shortest-distance.h +++ b/src/include/fst/shortest-distance.h @@ -22,6 +22,7 @@ #define FST_LIB_SHORTEST_DISTANCE_H__ #include <deque> +using std::deque; #include <vector> using std::vector; diff --git a/src/include/fst/shortest-path.h b/src/include/fst/shortest-path.h index f12970c..9cd13d9 100644 --- a/src/include/fst/shortest-path.h +++ b/src/include/fst/shortest-path.h @@ -458,7 +458,7 @@ void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, } else { vector<Weight> ddistance; DeterminizeFstOptions<ReverseArc> dopts(opts.delta); - DeterminizeFst<ReverseArc> dfst(rfst, *distance, &ddistance, dopts); + DeterminizeFst<ReverseArc> dfst(rfst, distance, &ddistance, dopts); NShortestPath(dfst, ofst, ddistance, n, opts.delta, opts.weight_threshold, opts.state_threshold); } diff --git a/src/include/fst/state-map.h b/src/include/fst/state-map.h index 2c59e1d..0e65d74 100644 --- a/src/include/fst/state-map.h +++ b/src/include/fst/state-map.h @@ -191,8 +191,6 @@ class StateMapFstImpl : public CacheImpl<B> { using FstImpl<B>::SetInputSymbols; using FstImpl<B>::SetOutputSymbols; - using VectorFstBaseImpl<typename CacheImpl<B>::State>::NumStates; - using CacheImpl<B>::PushArc; using CacheImpl<B>::HasArcs; using CacheImpl<B>::HasFinal; @@ -535,7 +533,7 @@ class ArcUniqueMapper { explicit ArcUniqueMapper(const Fst<A> &fst) : fst_(fst), i_(0) {} // Allows updating Fst argument; pass only if changed. - ArcUniqueMapper(const ArcSumMapper<A> &mapper, + ArcUniqueMapper(const ArcUniqueMapper<A> &mapper, const Fst<A> *fst = 0) : fst_(fst ? *fst : mapper.fst_), i_(0) {} diff --git a/src/include/fst/state-table.h b/src/include/fst/state-table.h index 7d863a0..d8107a1 100644 --- a/src/include/fst/state-table.h +++ b/src/include/fst/state-table.h @@ -22,6 +22,7 @@ #define FST_LIB_STATE_TABLE_H__ #include <deque> +using std::deque; #include <vector> using std::vector; @@ -58,14 +59,14 @@ namespace fst { // struct StateTuple { // typedef typename S StateId; // -// // Required constructor. +// // Required constructors. // StateTuple(); +// StateTuple(const StateTuple &); // }; // An implementation using a hash map for the tuple to state ID mapping. -// The state tuple T must have == defined and the default constructor -// must produce a tuple that will never be seen. H is the hash function. +// The state tuple T must have == defined. H is the hash function. template <class T, class H> class HashStateTable : public HashBiTable<typename T::StateId, T, H> { public: @@ -76,15 +77,18 @@ class HashStateTable : public HashBiTable<typename T::StateId, T, H> { using HashBiTable<StateId, T, H>::Size; HashStateTable() : HashBiTable<StateId, T, H>() {} + + // Reserves space for table_size elements. + explicit HashStateTable(size_t table_size) + : HashBiTable<StateId, T, H>(table_size) {} + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } const StateTuple &Tuple(StateId s) const { return FindEntry(s); } }; -// An implementation using a hash set for the tuple to state ID -// mapping. The state tuple T must have == defined and the default -// constructor must produce a tuple that will never be seen. H is the -// hash function. +// An implementation using a hash map for the tuple to state ID mapping. +// The state tuple T must have == defined. H is the hash function. template <class T, class H> class CompactHashStateTable : public CompactHashBiTable<typename T::StateId, T, H> { @@ -97,7 +101,7 @@ class CompactHashStateTable CompactHashStateTable() : CompactHashBiTable<StateId, T, H>() {} - // Reserves space for table_size elements. + // Reserves space for 'table_size' elements. explicit CompactHashStateTable(size_t table_size) : CompactHashBiTable<StateId, T, H>(table_size) {} @@ -122,7 +126,10 @@ class VectorStateTable using VectorBiTable<StateId, T, FP>::Size; using VectorBiTable<StateId, T, FP>::Fingerprint; - explicit VectorStateTable(FP *fp = 0) : VectorBiTable<StateId, T, FP>(fp) {} + // Reserves space for 'table_size' elements. + explicit VectorStateTable(FP *fp = 0, size_t table_size = 0) + : VectorBiTable<StateId, T, FP>(fp, table_size) {} + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } const StateTuple &Tuple(StateId s) const { return FindEntry(s); } }; @@ -268,7 +275,9 @@ class GenericComposeStateTable : public H { GenericComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2) {} - GenericComposeStateTable(const GenericComposeStateTable<A, F> &table) {} + // Reserves space for 'table_size' elements. + GenericComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2, + size_t table_size) : H(table_size) {} bool Error() const { return false; } @@ -342,17 +351,18 @@ VectorStateTable<ComposeStateTuple<typename A::StateId, F>, typedef typename A::StateId StateId; typedef F FilterState; typedef ComposeStateTuple<StateId, F> StateTuple; + typedef VectorStateTable<StateTuple, + ComposeFingerprint<StateId, F> > StateTable; - ProductComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2) - : VectorStateTable<ComposeStateTuple<StateId, F>, - ComposeFingerprint<StateId, F> > - (new ComposeFingerprint<StateId, F>(CountStates(fst1), - CountStates(fst2))) { } + // Reserves space for 'table_size' elements. + ProductComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2, + size_t table_size = 0) + : StateTable(new ComposeFingerprint<StateId, F>(CountStates(fst1), + CountStates(fst2)), + table_size) {} ProductComposeStateTable(const ProductComposeStateTable<A, F> &table) - : VectorStateTable<ComposeStateTuple<StateId, F>, - ComposeFingerprint<StateId, F> > - (new ComposeFingerprint<StateId, F>(table.Fingerprint())) {} + : StateTable(new ComposeFingerprint<StateId, F>(table.Fingerprint())) {} bool Error() const { return false; } @@ -375,6 +385,8 @@ VectorStateTable<ComposeStateTuple<typename A::StateId, F>, typedef typename A::StateId StateId; typedef F FilterState; typedef ComposeStateTuple<StateId, F> StateTuple; + typedef VectorStateTable<StateTuple, + ComposeState1Fingerprint<StateId, F> > StateTable; StringDetComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2) : error_(false) { @@ -389,7 +401,7 @@ VectorStateTable<ComposeStateTuple<typename A::StateId, F>, } StringDetComposeStateTable(const StringDetComposeStateTable<A, F> &table) - : error_(table.error_) {} + : StateTable(table), error_(table.error_) {} bool Error() const { return error_; } @@ -409,12 +421,14 @@ VectorStateTable<ComposeStateTuple<typename A::StateId, F>, template <typename A, typename F> class DetStringComposeStateTable : public VectorStateTable<ComposeStateTuple<typename A::StateId, F>, - ComposeState1Fingerprint<typename A::StateId, F> > { + ComposeState2Fingerprint<typename A::StateId, F> > { public: typedef A Arc; typedef typename A::StateId StateId; typedef F FilterState; typedef ComposeStateTuple<StateId, F> StateTuple; + typedef VectorStateTable<StateTuple, + ComposeState2Fingerprint<StateId, F> > StateTable; DetStringComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2) :error_(false) { @@ -429,7 +443,7 @@ VectorStateTable<ComposeStateTuple<typename A::StateId, F>, } DetStringComposeStateTable(const DetStringComposeStateTable<A, F> &table) - : error_(table.error_) {} + : StateTable(table), error_(table.error_) {} bool Error() const { return error_; } @@ -456,8 +470,6 @@ ErasableStateTable<ComposeStateTuple<typename A::StateId, F>, ErasableComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2) {} - ErasableComposeStateTable(const ErasableComposeStateTable<A, F> &table) {} - bool Error() const { return false; } private: diff --git a/src/include/fst/string.h b/src/include/fst/string.h index d51182e..9eaf7a3 100644 --- a/src/include/fst/string.h +++ b/src/include/fst/string.h @@ -57,6 +57,15 @@ class StringCompiler { return true; } + template <class F> + bool operator()(const string &s, F *fst, Weight w) const { + vector<Label> labels; + if (!ConvertStringToLabels(s, &labels)) + return false; + Compile(labels, fst, w); + return true; + } + private: bool ConvertStringToLabels(const string &str, vector<Label> *labels) const { labels->clear(); @@ -83,22 +92,35 @@ class StringCompiler { return true; } - void Compile(const vector<Label> &labels, MutableFst<A> *fst) const { + void Compile(const vector<Label> &labels, MutableFst<A> *fst, + const Weight &weight = Weight::One()) const { fst->DeleteStates(); while (fst->NumStates() <= labels.size()) fst->AddState(); for (size_t i = 0; i < labels.size(); ++i) fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1)); fst->SetStart(0); - fst->SetFinal(labels.size(), Weight::One()); + fst->SetFinal(labels.size(), weight); } template <class Unsigned> - void Compile(const vector<Label> &labels, CompactFst<A, StringCompactor<A>, - Unsigned> *fst) const { + void Compile(const vector<Label> &labels, + CompactFst<A, StringCompactor<A>, Unsigned> *fst) const { fst->SetCompactElements(labels.begin(), labels.end()); } + template <class Unsigned> + void Compile(const vector<Label> &labels, + CompactFst<A, WeightedStringCompactor<A>, Unsigned> *fst, + const Weight &weight = Weight::One()) const { + vector<pair<Label, Weight> > compacts; + compacts.reserve(labels.size()); + for (size_t i = 0; i < labels.size(); ++i) + compacts.push_back(make_pair(labels[i], Weight::One())); + compacts.back().second = weight; + fst->SetCompactElements(compacts.begin(), compacts.end()); + } + bool ConvertSymbolToLabel(const char *s, Label* output) const { int64 n; if (syms_) { @@ -167,6 +189,7 @@ class StringPrinter { } *output = sstrm.str(); } else if (token_type_ == BYTE) { + output->reserve(labels_.size()); for (size_t i = 0; i < labels_.size(); ++i) { output->push_back(labels_[i]); } diff --git a/src/include/fst/util.h b/src/include/fst/util.h index 1f6046b..4eb8fba 100644 --- a/src/include/fst/util.h +++ b/src/include/fst/util.h @@ -268,17 +268,17 @@ void WeightToStr(Weight w, string *s) { s->append(strm.str().data(), strm.str().size()); } -// Utilities for reading/writing label pairs +// Utilities for reading/writing integer pairs (typically labels) // Returns true on success -template <typename Label> -bool ReadLabelPairs(const string& filename, - vector<pair<Label, Label> >* pairs, +template <typename I> +bool ReadIntPairs(const string& filename, + vector<pair<I, I> >* pairs, bool allow_negative = false) { ifstream strm(filename.c_str()); if (!strm) { - LOG(ERROR) << "ReadLabelPairs: Can't open file: " << filename; + LOG(ERROR) << "ReadIntPairs: Can't open file: " << filename; return false; } @@ -291,33 +291,34 @@ bool ReadLabelPairs(const string& filename, ++nline; vector<char *> col; SplitToVector(line, "\n\t ", &col, true); - if (col.size() == 0 || col[0][0] == '\0') // empty line + // empty line or comment? + if (col.size() == 0 || col[0][0] == '\0' || col[0][0] == '#') continue; if (col.size() != 2) { - LOG(ERROR) << "ReadLabelPairs: Bad number of columns, " + LOG(ERROR) << "ReadIntPairs: Bad number of columns, " << "file = " << filename << ", line = " << nline; return false; } bool err; - Label frmlabel = StrToInt64(col[0], filename, nline, allow_negative, &err); + I i1 = StrToInt64(col[0], filename, nline, allow_negative, &err); if (err) return false; - Label tolabel = StrToInt64(col[1], filename, nline, allow_negative, &err); + I i2 = StrToInt64(col[1], filename, nline, allow_negative, &err); if (err) return false; - pairs->push_back(make_pair(frmlabel, tolabel)); + pairs->push_back(make_pair(i1, i2)); } return true; } // Returns true on success -template <typename Label> -bool WriteLabelPairs(const string& filename, - const vector<pair<Label, Label> >& pairs) { +template <typename I> +bool WriteIntPairs(const string& filename, + const vector<pair<I, I> >& pairs) { ostream *strm = &cout; if (!filename.empty()) { strm = new ofstream(filename.c_str()); if (!*strm) { - LOG(ERROR) << "WriteLabelPairs: Can't open file: " << filename; + LOG(ERROR) << "WriteIntPairs: Can't open file: " << filename; return false; } } @@ -326,7 +327,7 @@ bool WriteLabelPairs(const string& filename, *strm << pairs[n].first << "\t" << pairs[n].second << "\n"; if (!*strm) { - LOG(ERROR) << "WriteLabelPairs: Write failed: " + LOG(ERROR) << "WriteIntPairs: Write failed: " << (filename.empty() ? "standard output" : filename); return false; } @@ -335,6 +336,21 @@ bool WriteLabelPairs(const string& filename, return true; } +// Utilities for reading/writing label pairs + +template <typename Label> +bool ReadLabelPairs(const string& filename, + vector<pair<Label, Label> >* pairs, + bool allow_negative = false) { + return ReadIntPairs(filename, pairs, allow_negative); +} + +template <typename Label> +bool WriteLabelPairs(const string& filename, + vector<pair<Label, Label> >& pairs) { + return WriteIntPairs(filename, pairs); +} + // Utilities for converting a type name to a legal C symbol. void ConvertToLegalCSymbol(string *s); @@ -344,8 +360,8 @@ void ConvertToLegalCSymbol(string *s); // UTILITIES FOR STREAM I/O // -bool AlignInput(istream &strm, int align); -bool AlignOutput(ostream &strm, int align); +bool AlignInput(istream &strm); +bool AlignOutput(ostream &strm); // // UTILITIES FOR PROTOCOL BUFFER I/O @@ -380,6 +396,17 @@ public: max_key_ = key; } + void Erase(Key key) { + set_.erase(key); + if (set_.empty()) { + min_key_ = max_key_ = NoKey; + } else if (key == min_key_) { + ++min_key_; + } else if (key == max_key_) { + --max_key_; + } + } + void Clear() { set_.clear(); min_key_ = max_key_ = NoKey; @@ -393,10 +420,26 @@ public: return set_.find(key); } + bool Member(Key key) const { + if (min_key_ == NoKey || key < min_key_ || max_key_ < key) { + return false; // out of range + } else if (min_key_ != NoKey && max_key_ + 1 == min_key_ + set_.size()) { + return true; // dense range + } else { + return set_.find(key) != set_.end(); + } + } + const_iterator Begin() const { return set_.begin(); } const_iterator End() const { return set_.end(); } + // All stored keys are greater than or equal to this value. + Key LowerBound() const { return min_key_; } + + // All stored keys are less than or equal to this value. + Key UpperBound() const { return max_key_; } + private: set<Key> set_; Key min_key_; diff --git a/src/include/fst/visit.h b/src/include/fst/visit.h index a02d86a..5f5059a 100644 --- a/src/include/fst/visit.h +++ b/src/include/fst/visit.h @@ -238,18 +238,28 @@ class CopyVisitor { }; -// Visits input FST up to a state limit following queue order. +// Visits input FST up to a state limit following queue order. If +// 'access_only' is true, aborts on visiting first state not +// accessible from the initial state. template <class A> class PartialVisitor { public: typedef A Arc; typedef typename A::StateId StateId; - explicit PartialVisitor(StateId maxvisit) : maxvisit_(maxvisit) {} + explicit PartialVisitor(StateId maxvisit, bool access_only = false) + : maxvisit_(maxvisit), + access_only_(access_only), + start_(kNoStateId) {} - void InitVisit(const Fst<A> &ifst) { nvisit_ = 0; } + void InitVisit(const Fst<A> &ifst) { + nvisit_ = 0; + start_ = ifst.Start(); + } - bool InitState(StateId s, StateId) { + bool InitState(StateId s, StateId root) { + if (access_only_ && root != start_) + return false; ++nvisit_; return nvisit_ <= maxvisit_; } @@ -262,7 +272,10 @@ class PartialVisitor { private: StateId maxvisit_; + bool access_only_; StateId nvisit_; + StateId start_; + }; |