aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst')
-rw-r--r--src/include/fst/arc-map.h4
-rw-r--r--src/include/fst/bi-table.h259
-rw-r--r--src/include/fst/cache.h366
-rw-r--r--src/include/fst/compact-fst.h89
-rw-r--r--src/include/fst/compose.h11
-rw-r--r--src/include/fst/const-fst.h64
-rw-r--r--src/include/fst/determinize.h508
-rw-r--r--src/include/fst/edit-fst.h7
-rw-r--r--src/include/fst/epsnormalize.h1
-rw-r--r--src/include/fst/equivalent.h1
-rw-r--r--src/include/fst/extensions/far/extract.h119
-rw-r--r--src/include/fst/extensions/far/far.h3
-rw-r--r--src/include/fst/extensions/far/farscript.h12
-rw-r--r--src/include/fst/extensions/far/stlist.h22
-rw-r--r--src/include/fst/extensions/ngram/ngram-fst.h111
-rw-r--r--src/include/fst/extensions/pdt/compose.h486
-rw-r--r--src/include/fst/extensions/pdt/pdt.h1
-rw-r--r--src/include/fst/extensions/pdt/pdtscript.h4
-rw-r--r--src/include/fst/extensions/pdt/replace.h27
-rw-r--r--src/include/fst/factor-weight.h1
-rw-r--r--src/include/fst/fst-decl.h1
-rw-r--r--src/include/fst/fst.h18
-rw-r--r--src/include/fst/interval-set.h6
-rw-r--r--src/include/fst/mapped-file.h83
-rw-r--r--src/include/fst/matcher.h69
-rw-r--r--src/include/fst/queue.h63
-rw-r--r--src/include/fst/relabel.h4
-rw-r--r--src/include/fst/script/convert.h2
-rw-r--r--src/include/fst/script/disambiguate.h68
-rw-r--r--src/include/fst/script/fst-class.h44
-rw-r--r--src/include/fst/script/map.h54
-rw-r--r--src/include/fst/script/shortest-distance.h8
-rw-r--r--src/include/fst/script/weight-class.h7
-rw-r--r--src/include/fst/shortest-distance.h1
-rw-r--r--src/include/fst/shortest-path.h2
-rw-r--r--src/include/fst/state-map.h4
-rw-r--r--src/include/fst/state-table.h58
-rw-r--r--src/include/fst/string.h31
-rw-r--r--src/include/fst/util.h77
-rw-r--r--src/include/fst/visit.h21
40 files changed, 1999 insertions, 718 deletions
diff --git a/src/include/fst/arc-map.h b/src/include/fst/arc-map.h
index 914f81c..c33b546 100644
--- a/src/include/fst/arc-map.h
+++ b/src/include/fst/arc-map.h
@@ -165,8 +165,8 @@ void ArcMap(MutableFst<A> *fst, C* mapper) {
} else {
fst->SetFinal(s, final_arc.weight);
}
- break;
}
+ break;
}
case MAP_REQUIRE_SUPERFINAL: {
if (s != superfinal) {
@@ -315,8 +315,6 @@ class ArcMapFstImpl : public CacheImpl<B> {
using FstImpl<B>::SetInputSymbols;
using FstImpl<B>::SetOutputSymbols;
- using VectorFstBaseImpl<typename CacheImpl<B>::State>::NumStates;
-
using CacheImpl<B>::PushArc;
using CacheImpl<B>::HasArcs;
using CacheImpl<B>::HasFinal;
diff --git a/src/include/fst/bi-table.h b/src/include/fst/bi-table.h
index bd37781..d220ce4 100644
--- a/src/include/fst/bi-table.h
+++ b/src/include/fst/bi-table.h
@@ -23,9 +23,15 @@
#define FST_LIB_BI_TABLE_H__
#include <deque>
+using std::deque;
+#include <functional>
#include <vector>
using std::vector;
+#include <tr1/unordered_set>
+using std::tr1::unordered_set;
+using std::tr1::unordered_multiset;
+
namespace fst {
// BI TABLES - these determine a bijective mapping between an
@@ -49,14 +55,33 @@ namespace fst {
// };
// An implementation using a hash map for the entry to ID mapping.
-// The entry T must have == defined and the default constructor
-// must produce an entry that will never be seen. H is the hash function.
-template <class I, class T, class H>
+// H is the hash function and E is the equality function.
+// If passed to the constructor, ownership is given to this class.
+
+template <class I, class T, class H, class E = std::equal_to<T> >
class HashBiTable {
public:
+ // Reserves space for 'table_size' elements.
+ explicit HashBiTable(size_t table_size = 0, H *h = 0, E *e = 0)
+ : hash_func_(h),
+ hash_equal_(e),
+ entry2id_(table_size, (h ? *h : H()), (e ? *e : E())) {
+ if (table_size)
+ id2entry_.reserve(table_size);
+ }
- HashBiTable() {
- T empty_entry;
+ HashBiTable(const HashBiTable<I, T, H, E> &table)
+ : hash_func_(table.hash_func_ ? new H(*table.hash_func_) : 0),
+ hash_equal_(table.hash_equal_ ? new E(*table.hash_equal_) : 0),
+ entry2id_(table.entry2id_.begin(), table.entry2id_.end(),
+ table.entry2id_.size(),
+ (hash_func_ ? *hash_func_ : H()),
+ (hash_equal_ ? *hash_equal_ : E())),
+ id2entry_(table.id2entry_) { }
+
+ ~HashBiTable() {
+ delete hash_func_;
+ delete hash_equal_;
}
I FindId(const T &entry, bool insert = true) {
@@ -79,39 +104,67 @@ class HashBiTable {
I Size() const { return id2entry_.size(); }
private:
- unordered_map<T, I, H> entry2id_;
+ H *hash_func_;
+ E *hash_equal_;
+ unordered_map<T, I, H, E> entry2id_;
vector<T> id2entry_;
- DISALLOW_COPY_AND_ASSIGN(HashBiTable);
+ void operator=(const HashBiTable<I, T, H, E> &table); // disallow
+};
+
+
+// Enables alternative hash set representations below.
+// typedef enum { HS_STL = 0, HS_DENSE = 1, HS_SPARSE = 2 } HSType;
+typedef enum { HS_STL = 0, HS_DENSE = 1, HS_SPARSE = 2 } HSType;
+
+// Default hash set is STL hash_set
+template<class K, class H, class E, HSType>
+struct HashSet : public unordered_set<K, H, E> {
+ HashSet(size_t n = 0, const H &h = H(), const E &e = E())
+ : unordered_set<K, H, E>(n, h, e) { }
+
+ void rehash(size_t n) { }
};
-// An implementation using a hash set for the entry to ID
-// mapping. The hash set holds 'keys' which are either the ID
-// or kCurrentKey. These keys can be mapped to entrys either by
-// looking up in the entry vector or, if kCurrentKey, in current_entry_
-// member. The hash and key equality functions map to entries first.
-// The entry T must have == defined and the default constructor
-// must produce a entry that will never be seen. H is the hash
-// function.
-template <class I, class T, class H>
+// An implementation using a hash set for the entry to ID mapping.
+// The hash set holds 'keys' which are either the ID or kCurrentKey.
+// These keys can be mapped to entrys either by looking up in the
+// entry vector or, if kCurrentKey, in current_entry_ member. The hash
+// and key equality functions map to entries first. H
+// is the hash function and E is the equality function. If passed to
+// the constructor, ownership is given to this class.
+template <class I, class T, class H,
+ class E = std::equal_to<T>, HSType HS = HS_DENSE>
class CompactHashBiTable {
public:
friend class HashFunc;
friend class HashEqual;
- CompactHashBiTable()
- : hash_func_(*this),
- hash_equal_(*this),
- keys_(0, hash_func_, hash_equal_) {
+ // Reserves space for 'table_size' elements.
+ explicit CompactHashBiTable(size_t table_size = 0, H *h = 0, E *e = 0)
+ : hash_func_(h),
+ hash_equal_(e),
+ compact_hash_func_(*this),
+ compact_hash_equal_(*this),
+ keys_(table_size, compact_hash_func_, compact_hash_equal_) {
+ if (table_size)
+ id2entry_.reserve(table_size);
}
- // Reserves space for table_size elements.
- explicit CompactHashBiTable(size_t table_size)
- : hash_func_(*this),
- hash_equal_(*this),
- keys_(table_size, hash_func_, hash_equal_) {
- id2entry_.reserve(table_size);
+ CompactHashBiTable(const CompactHashBiTable<I, T, H, E, HS> &table)
+ : hash_func_(table.hash_func_ ? new H(*table.hash_func_) : 0),
+ hash_equal_(table.hash_equal_ ? new E(*table.hash_equal_) : 0),
+ compact_hash_func_(*this),
+ compact_hash_equal_(*this),
+ keys_(table.keys_.size(), compact_hash_func_, compact_hash_equal_),
+ id2entry_(table.id2entry_) {
+ keys_.insert(table.keys_.begin(), table.keys_.end());
+ }
+
+ ~CompactHashBiTable() {
+ delete hash_func_;
+ delete hash_equal_;
}
I FindId(const T &entry, bool insert = true) {
@@ -132,20 +185,40 @@ class CompactHashBiTable {
}
const T &FindEntry(I s) const { return id2entry_[s]; }
+
I Size() const { return id2entry_.size(); }
+ // Clear content. With argument, erases last n IDs.
+ void Clear(ssize_t n = -1) {
+ if (n < 0 || n > id2entry_.size())
+ n = id2entry_.size();
+ while (n-- > 0) {
+ I key = id2entry_.size() - 1;
+ keys_.erase(key);
+ id2entry_.pop_back();
+ }
+ keys_.rehash(0);
+ }
+
private:
- static const I kEmptyKey; // -1
- static const I kCurrentKey; // -2
+ static const I kCurrentKey; // -1
+ static const I kEmptyKey; // -2
+ static const I kDeletedKey; // -3
class HashFunc {
public:
HashFunc(const CompactHashBiTable &ht) : ht_(&ht) {}
- size_t operator()(I k) const { return hf(ht_->Key2T(k)); }
+ size_t operator()(I k) const {
+ if (k >= kCurrentKey) {
+ return (*ht_->hash_func_)(ht_->Key2Entry(k));
+ } else {
+ return 0;
+ }
+ }
+
private:
const CompactHashBiTable *ht_;
- H hf;
};
class HashEqual {
@@ -153,38 +226,45 @@ class CompactHashBiTable {
HashEqual(const CompactHashBiTable &ht) : ht_(&ht) {}
bool operator()(I k1, I k2) const {
- return ht_->Key2T(k1) == ht_->Key2T(k2);
+ if (k1 >= kCurrentKey && k2 >= kCurrentKey) {
+ return (*ht_->hash_equal_)(ht_->Key2Entry(k1), ht_->Key2Entry(k2));
+ } else {
+ return k1 == k2;
+ }
}
private:
const CompactHashBiTable *ht_;
};
- typedef unordered_set<I, HashFunc, HashEqual> KeyHashSet;
+ typedef HashSet<I, HashFunc, HashEqual, HS> KeyHashSet;
- const T &Key2T(I k) const {
- if (k == kEmptyKey)
- return empty_entry_;
- else if (k == kCurrentKey)
+ const T &Key2Entry(I k) const {
+ if (k == kCurrentKey)
return *current_entry_;
else
return id2entry_[k];
}
- HashFunc hash_func_;
- HashEqual hash_equal_;
+ H *hash_func_;
+ E *hash_equal_;
+ HashFunc compact_hash_func_;
+ HashEqual compact_hash_equal_;
KeyHashSet keys_;
vector<T> id2entry_;
- const T empty_entry_;
const T *current_entry_;
- DISALLOW_COPY_AND_ASSIGN(CompactHashBiTable);
+ void operator=(const CompactHashBiTable<I, T, H, E, HS> &table); // disallow
};
-template <class I, class T, class H>
-const I CompactHashBiTable<I, T, H>::kEmptyKey = -1;
-template <class I, class T, class H>
-const I CompactHashBiTable<I, T, H>::kCurrentKey = -2;
+template <class I, class T, class H, class E, HSType HS>
+const I CompactHashBiTable<I, T, H, E, HS>::kCurrentKey = -1;
+
+template <class I, class T, class H, class E, HSType HS>
+const I CompactHashBiTable<I, T, H, E, HS>::kEmptyKey = -2;
+
+template <class I, class T, class H, class E, HSType HS>
+const I CompactHashBiTable<I, T, H, E, HS>::kDeletedKey = -3;
// An implementation using a vector for the entry to ID mapping.
@@ -196,7 +276,17 @@ const I CompactHashBiTable<I, T, H>::kCurrentKey = -2;
template <class I, class T, class FP>
class VectorBiTable {
public:
- explicit VectorBiTable(FP *fp = 0) : fp_(fp ? fp : new FP()) {}
+ // Reserves space for 'table_size' elements.
+ explicit VectorBiTable(FP *fp = 0, size_t table_size = 0)
+ : fp_(fp ? fp : new FP()) {
+ if (table_size)
+ id2entry_.reserve(table_size);
+ }
+
+ VectorBiTable(const VectorBiTable<I, T, FP> &table)
+ : fp_(table.fp_ ? new FP(*table.fp_) : 0),
+ fp2id_(table.fp2id_),
+ id2entry_(table.id2entry_) { }
~VectorBiTable() { delete fp_; }
@@ -227,7 +317,7 @@ class VectorBiTable {
vector<I> fp2id_;
vector<T> id2entry_;
- DISALLOW_COPY_AND_ASSIGN(VectorBiTable);
+ void operator=(const VectorBiTable<I, T, FP> &table); // disallow
};
@@ -235,20 +325,21 @@ class VectorBiTable {
// selecting functor S returns true for entries to be hashed in the
// vector. The fingerprinting functor FP returns a unique fingerprint
// for each entry to be hashed in the vector (these need to be
-// suitable for indexing in a vector). The hash functor H is used when
-// hashing entry into the compact hash table.
-template <class I, class T, class S, class FP, class H>
+// suitable for indexing in a vector). The hash functor H is used
+// when hashing entry into the compact hash table. If passed to the
+// constructor, ownership is given to this class.
+template <class I, class T, class S, class FP, class H, HSType HS = HS_DENSE>
class VectorHashBiTable {
public:
friend class HashFunc;
friend class HashEqual;
- VectorHashBiTable(S *s, FP *fp, H *h,
- size_t vector_size = 0,
- size_t entry_size = 0)
+ explicit VectorHashBiTable(S *s, FP *fp = 0, H *h = 0,
+ size_t vector_size = 0,
+ size_t entry_size = 0)
: selector_(s),
- fp_(fp),
- h_(h),
+ fp_(fp ? fp : new FP()),
+ h_(h ? h : new H()),
hash_func_(*this),
hash_equal_(*this),
keys_(0, hash_func_, hash_equal_) {
@@ -256,6 +347,18 @@ class VectorHashBiTable {
fp2id_.reserve(vector_size);
if (entry_size)
id2entry_.reserve(entry_size);
+ }
+
+ VectorHashBiTable(const VectorHashBiTable<I, T, S, FP, H, HS> &table)
+ : selector_(new S(table.s_)),
+ fp_(table.fp_ ? new FP(*table.fp_) : 0),
+ h_(table.h_ ? new H(*table.h_) : 0),
+ id2entry_(table.id2entry_),
+ fp2id_(table.fp2id_),
+ hash_func_(*this),
+ hash_equal_(*this),
+ keys_(table.keys_.size(), hash_func_, hash_equal_) {
+ keys_.insert(table.keys_.begin(), table.keys_.end());
}
~VectorHashBiTable() {
@@ -309,14 +412,20 @@ class VectorHashBiTable {
const H &Hash() const { return *h_; }
private:
- static const I kEmptyKey;
- static const I kCurrentKey;
+ static const I kCurrentKey; // -1
+ static const I kEmptyKey; // -2
class HashFunc {
public:
HashFunc(const VectorHashBiTable &ht) : ht_(&ht) {}
- size_t operator()(I k) const { return (*(ht_->h_))(ht_->Key2Entry(k)); }
+ size_t operator()(I k) const {
+ if (k >= kCurrentKey) {
+ return (*(ht_->h_))(ht_->Key2Entry(k));
+ } else {
+ return 0;
+ }
+ }
private:
const VectorHashBiTable *ht_;
};
@@ -326,53 +435,54 @@ class VectorHashBiTable {
HashEqual(const VectorHashBiTable &ht) : ht_(&ht) {}
bool operator()(I k1, I k2) const {
- return ht_->Key2Entry(k1) == ht_->Key2Entry(k2);
+ if (k1 >= kCurrentKey && k2 >= kCurrentKey) {
+ return ht_->Key2Entry(k1) == ht_->Key2Entry(k2);
+ } else {
+ return k1 == k2;
+ }
}
private:
const VectorHashBiTable *ht_;
};
- typedef unordered_set<I, HashFunc, HashEqual> KeyHashSet;
+ typedef HashSet<I, HashFunc, HashEqual, HS> KeyHashSet;
const T &Key2Entry(I k) const {
- if (k == kEmptyKey)
- return empty_entry_;
- else if (k == kCurrentKey)
+ if (k == kCurrentKey)
return *current_entry_;
else
return id2entry_[k];
}
-
S *selector_; // Returns true if entry hashed into vector
FP *fp_; // Fingerprint used when hashing entry into vector
H *h_; // Hash function used when hashing entry into hash_set
vector<T> id2entry_; // Maps state IDs to entry
- vector<I> fp2id_; // Maps entry fingerprints to IDs
+ vector<I> fp2id_; // Maps entry fingerprints to IDs
// Compact implementation of the hash table mapping entrys to
// state IDs using the hash function 'h_'
HashFunc hash_func_;
HashEqual hash_equal_;
KeyHashSet keys_;
- const T empty_entry_;
const T *current_entry_;
- DISALLOW_COPY_AND_ASSIGN(VectorHashBiTable);
+ // disallow
+ void operator=(const VectorHashBiTable<I, T, S, FP, H, HS> &table);
};
-template <class I, class T, class S, class FP, class H>
-const I VectorHashBiTable<I, T, S, FP, H>::kEmptyKey = -1;
+template <class I, class T, class S, class FP, class H, HSType HS>
+const I VectorHashBiTable<I, T, S, FP, H, HS>::kCurrentKey = -1;
-template <class I, class T, class S, class FP, class H>
-const I VectorHashBiTable<I, T, S, FP, H>::kCurrentKey = -2;
+template <class I, class T, class S, class FP, class H, HSType HS>
+const I VectorHashBiTable<I, T, S, FP, H, HS>::kEmptyKey = -3;
// An implementation using a hash map for the entry to ID
-// mapping. This version permits erasing of s. The entry T
-// must have == defined and its default constructor must produce a
-// entry that will never be seen. F is the hash function.
+// mapping. This version permits erasing of arbitrary states. The
+// entry T must have == defined and its default constructor must
+// produce a entry that will never be seen. F is the hash function.
template <class I, class T, class F>
class ErasableBiTable {
public:
@@ -413,7 +523,8 @@ class ErasableBiTable {
const T empty_entry_;
I first_; // I of first element in the deque;
- DISALLOW_COPY_AND_ASSIGN(ErasableBiTable);
+ // disallow
+ void operator=(const ErasableBiTable<I, T, F> &table); //disallow
};
} // namespace fst
diff --git a/src/include/fst/cache.h b/src/include/fst/cache.h
index 0177396..7c96fe1 100644
--- a/src/include/fst/cache.h
+++ b/src/include/fst/cache.h
@@ -89,14 +89,15 @@ struct DefaultCacheStateAllocator {
// CacheState below). This class is used to cache FST elements with
// the flags used to indicate what has been cached. Use HasStart()
// HasFinal(), and HasArcs() to determine if cached and SetStart(),
-// SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note you
-// must set the final weight even if the state is non-final to mark it as
-// cached. If the 'gc' option is 'false', cached items have the extent
-// of the FST - minimizing computation. If the 'gc' option is 'true',
-// garbage collection of states (not in use in an arc iterator) is
-// performed, in a rough approximation of LRU order, when 'gc_limit'
-// bytes is reached - controlling memory use. When 'gc_limit' is 0,
-// special optimizations apply - minimizing memory use.
+// SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note
+// you must set the final weight even if the state is non-final to
+// mark it as cached. If the 'gc' option is 'false', cached items have
+// the extent of the FST - minimizing computation. If the 'gc' option
+// is 'true', garbage collection of states (not in use in an arc
+// iterator and not 'protected') is performed, in a rough
+// approximation of LRU order, when 'gc_limit' bytes is reached -
+// controlling memory use. When 'gc_limit' is 0, special optimizations
+// apply - minimizing memory use.
template <class S, class C = DefaultCacheStateAllocator<S> >
class CacheBaseImpl : public VectorFstBaseImpl<S> {
@@ -111,8 +112,10 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
using FstImpl<Arc>::Properties;
using FstImpl<Arc>::SetProperties;
using VectorFstBaseImpl<State>::NumStates;
+ using VectorFstBaseImpl<State>::Start;
using VectorFstBaseImpl<State>::AddState;
using VectorFstBaseImpl<State>::SetState;
+ using VectorFstBaseImpl<State>::ReserveStates;
explicit CacheBaseImpl(C *allocator = 0)
: cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0),
@@ -120,27 +123,57 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
cache_gc_(FLAGS_fst_default_cache_gc), cache_size_(0),
cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit ||
FLAGS_fst_default_cache_gc_limit == 0 ?
- FLAGS_fst_default_cache_gc_limit : kMinCacheLimit) {
- allocator_ = allocator ? allocator : new C();
- }
+ FLAGS_fst_default_cache_gc_limit : kMinCacheLimit),
+ protect_(false) {
+ allocator_ = allocator ? allocator : new C();
+ }
explicit CacheBaseImpl(const CacheOptions &opts, C *allocator = 0)
: cache_start_(false), nknown_states_(0),
min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0),
cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ?
- opts.gc_limit : kMinCacheLimit) {
- allocator_ = allocator ? allocator : new C();
- }
+ opts.gc_limit : kMinCacheLimit),
+ protect_(false) {
+ allocator_ = allocator ? allocator : new C();
+ }
- // Preserve gc parameters, but initially cache nothing.
- CacheBaseImpl(const CacheBaseImpl &impl)
- : cache_start_(false), nknown_states_(0),
- min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
- cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0),
- cache_limit_(impl.cache_limit_) {
- allocator_ = new C();
+ // Preserve gc parameters. If preserve_cache true, also preserves
+ // cache data.
+ CacheBaseImpl(const CacheBaseImpl<S, C> &impl, bool preserve_cache = false)
+ : VectorFstBaseImpl<S>(), cache_start_(false), nknown_states_(0),
+ min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
+ cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0),
+ cache_limit_(impl.cache_limit_),
+ protect_(impl.protect_) {
+ allocator_ = new C();
+ if (preserve_cache) {
+ cache_start_ = impl.cache_start_;
+ nknown_states_ = impl.nknown_states_;
+ expanded_states_ = impl.expanded_states_;
+ min_unexpanded_state_id_ = impl.min_unexpanded_state_id_;
+ if (impl.cache_first_state_id_ != kNoStateId) {
+ cache_first_state_id_ = impl.cache_first_state_id_;
+ cache_first_state_ = allocator_->Allocate(cache_first_state_id_);
+ *cache_first_state_ = *impl.cache_first_state_;
}
+ cache_states_ = impl.cache_states_;
+ cache_size_ = impl.cache_size_;
+ ReserveStates(impl.NumStates());
+ for (StateId s = 0; s < impl.NumStates(); ++s) {
+ const S *state =
+ static_cast<const VectorFstBaseImpl<S> &>(impl).GetState(s);
+ if (state) {
+ S *copied_state = allocator_->Allocate(s);
+ *copied_state = *state;
+ AddState(copied_state);
+ } else {
+ AddState(0);
+ }
+ }
+ VectorFstBaseImpl<S>::SetStart(impl.Start());
+ }
+ }
~CacheBaseImpl() {
allocator_->Free(cache_first_state_, cache_first_state_id_);
@@ -174,49 +207,7 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
}
// Gets a state from its ID; add it if necessary.
- S *ExtendState(StateId s) {
- if (s == cache_first_state_id_) {
- return cache_first_state_; // Return 1st cached state
- } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
- cache_first_state_id_ = s; // Remember 1st cached state
- cache_first_state_ = allocator_->Allocate(s);
- return cache_first_state_;
- } else if (cache_first_state_id_ != kNoStateId &&
- cache_first_state_->ref_count == 0) {
- // With Default allocator, the Free and Allocate will reuse the same S*.
- allocator_->Free(cache_first_state_, cache_first_state_id_);
- cache_first_state_id_ = s;
- cache_first_state_ = allocator_->Allocate(s);
- return cache_first_state_; // Return 1st cached state
- } else {
- while (NumStates() <= s) // Add state to main cache
- AddState(0);
- if (!VectorFstBaseImpl<S>::GetState(s)) {
- SetState(s, allocator_->Allocate(s));
- if (cache_first_state_id_ != kNoStateId) { // Forget 1st cached state
- while (NumStates() <= cache_first_state_id_)
- AddState(0);
- SetState(cache_first_state_id_, cache_first_state_);
- if (cache_gc_) {
- cache_states_.push_back(cache_first_state_id_);
- cache_size_ += sizeof(S) +
- cache_first_state_->arcs.capacity() * sizeof(Arc);
- }
- cache_limit_ = kMinCacheLimit;
- cache_first_state_id_ = kNoStateId;
- cache_first_state_ = 0;
- }
- if (cache_gc_) {
- cache_states_.push_back(s);
- cache_size_ += sizeof(S);
- if (cache_size_ > cache_limit_)
- GC(s, false);
- }
- }
- S *state = VectorFstBaseImpl<S>::GetState(s);
- return state;
- }
- }
+ S *ExtendState(StateId s);
void SetStart(StateId s) {
VectorFstBaseImpl<S>::SetStart(s);
@@ -246,7 +237,8 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
const Arc *parc = state->arcs.empty() ? 0 : &(state->arcs.back());
SetProperties(AddArcProperties(Properties(), s, arc, parc));
state->flags |= kCacheModified;
- if (cache_gc_ && s != cache_first_state_id_) {
+ if (cache_gc_ && s != cache_first_state_id_ &&
+ !(state->flags & kCacheProtect)) {
cache_size_ += sizeof(Arc);
if (cache_size_ > cache_limit_)
GC(s, false);
@@ -278,7 +270,8 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
}
ExpandedState(s);
state->flags |= kCacheArcs | kCacheRecent | kCacheModified;
- if (cache_gc_ && s != cache_first_state_id_) {
+ if (cache_gc_ && s != cache_first_state_id_ &&
+ !(state->flags & kCacheProtect)) {
cache_size_ += arcs.capacity() * sizeof(Arc);
if (cache_size_ > cache_limit_)
GC(s, false);
@@ -300,18 +293,73 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
if (arcs[j].olabel == 0)
--state->noepsilons;
}
+
state->arcs.resize(arcs.size() - n);
SetProperties(DeleteArcsProperties(Properties()));
state->flags |= kCacheModified;
+ if (cache_gc_ && s != cache_first_state_id_ &&
+ !(state->flags & kCacheProtect)) {
+ cache_size_ -= n * sizeof(Arc);
+ }
}
void DeleteArcs(StateId s) {
S *state = ExtendState(s);
+ size_t n = state->arcs.size();
state->niepsilons = 0;
state->noepsilons = 0;
state->arcs.clear();
SetProperties(DeleteArcsProperties(Properties()));
state->flags |= kCacheModified;
+ if (cache_gc_ && s != cache_first_state_id_ &&
+ !(state->flags & kCacheProtect)) {
+ cache_size_ -= n * sizeof(Arc);
+ }
+ }
+
+ void DeleteStates(const vector<StateId> &dstates) {
+ size_t old_num_states = NumStates();
+ vector<StateId> newid(old_num_states, 0);
+ for (size_t i = 0; i < dstates.size(); ++i)
+ newid[dstates[i]] = kNoStateId;
+ StateId nstates = 0;
+ for (StateId s = 0; s < old_num_states; ++s) {
+ if (newid[s] != kNoStateId) {
+ newid[s] = nstates;
+ ++nstates;
+ }
+ }
+ // just for states_.resize(), does unnecessary walk.
+ VectorFstBaseImpl<S>::DeleteStates(dstates);
+ SetProperties(DeleteStatesProperties(Properties()));
+ // Update list of cached states.
+ typename list<StateId>::iterator siter = cache_states_.begin();
+ while (siter != cache_states_.end()) {
+ if (newid[*siter] != kNoStateId) {
+ *siter = newid[*siter];
+ ++siter;
+ } else {
+ cache_states_.erase(siter++);
+ }
+ }
+ }
+
+ void DeleteStates() {
+ cache_states_.clear();
+ allocator_->Free(cache_first_state_, cache_first_state_id_);
+ for (int s = 0; s < NumStates(); ++s) {
+ allocator_->Free(VectorFstBaseImpl<S>::GetState(s), s);
+ SetState(s, 0);
+ }
+ nknown_states_ = 0;
+ min_unexpanded_state_id_ = 0;
+ cache_first_state_id_ = kNoStateId;
+ cache_first_state_ = 0;
+ cache_size_ = 0;
+ cache_start_ = false;
+ VectorFstBaseImpl<State>::DeleteStates();
+ SetProperties(DeleteAllStatesProperties(Properties(),
+ kExpanded | kMutable));
}
// Is the start state cached?
@@ -390,48 +438,17 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
return min_unexpanded_state_id_;
}
- // Removes from cache_states_ and uncaches (not referenced-counted)
- // states that have not been accessed since the last GC until
- // cache_limit_/3 bytes are uncached. If that fails to free enough,
- // recurs uncaching recently visited states as well. If still
- // unable to free enough memory, then widens cache_limit_.
- void GC(StateId current, bool free_recent) {
- if (!cache_gc_)
- return;
- VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
- << "), free recently cached = " << free_recent
- << ", cache size = " << cache_size_
- << ", cache limit = " << cache_limit_ << "\n";
- typename list<StateId>::iterator siter = cache_states_.begin();
+ // Removes from cache_states_ and uncaches (not referenced-counted
+ // or protected) states that have not been accessed since the last
+ // GC until at most cache_fraction * cache_limit_ bytes are cached.
+ // If that fails to free enough, recurs uncaching recently visited
+ // states as well. If still unable to free enough memory, then
+ // widens cache_limit_ to fulfill condition.
+ void GC(StateId current, bool free_recent, float cache_fraction = 0.666);
- size_t cache_target = (2 * cache_limit_)/3 + 1;
- while (siter != cache_states_.end()) {
- StateId s = *siter;
- S* state = VectorFstBaseImpl<S>::GetState(s);
- if (cache_size_ > cache_target && state->ref_count == 0 &&
- (free_recent || !(state->flags & kCacheRecent)) && s != current) {
- cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
- allocator_->Free(state, s);
- SetState(s, 0);
- cache_states_.erase(siter++);
- } else {
- state->flags &= ~kCacheRecent;
- ++siter;
- }
- }
- if (!free_recent && cache_size_ > cache_target) {
- GC(current, true);
- } else {
- while (cache_size_ > cache_target) {
- cache_limit_ *= 2;
- cache_target *= 2;
- }
- }
- VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
- << "), free recently cached = " << free_recent
- << ", cache size = " << cache_size_
- << ", cache limit = " << cache_limit_ << "\n";
- }
+ // Setc/clears GC protection: if true, new states are protected
+ // from garbage collection.
+ void GCProtect(bool on) { protect_ = on; }
void ExpandedState(StateId s) {
if (s < min_unexpanded_state_id_)
@@ -441,26 +458,30 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
expanded_states_[s] = true;
}
+ C *GetAllocator() const {
+ return allocator_;
+ }
+
// Caching on/off switch, limit and size accessors.
bool GetCacheGc() const { return cache_gc_; }
size_t GetCacheLimit() const { return cache_limit_; }
size_t GetCacheSize() const { return cache_size_; }
private:
- static const size_t kMinCacheLimit = 8096; // Minimum (non-zero) cache limit
- static const uint32 kCacheFinal = 0x0001; // Final weight has been cached
- static const uint32 kCacheArcs = 0x0002; // Arcs have been cached
- static const uint32 kCacheRecent = 0x0004; // Mark as visited since GC
+ static const size_t kMinCacheLimit = 8096; // Minimum (non-zero) cache limit
+
+ static const uint32 kCacheFinal = 0x0001; // Final weight has been cached
+ static const uint32 kCacheArcs = 0x0002; // Arcs have been cached
+ static const uint32 kCacheRecent = 0x0004; // Mark as visited since GC
+ static const uint32 kCacheProtect = 0x0008; // Mark state as GC protected
public:
- static const uint32 kCacheModified = 0x0008; // Mark state as modified
+ static const uint32 kCacheModified = 0x0010; // Mark state as modified
static const uint32 kCacheFlags = kCacheFinal | kCacheArcs | kCacheRecent
- | kCacheModified;
-
- protected:
- C *allocator_; // used to allocate new states
+ | kCacheProtect | kCacheModified;
private:
+ C *allocator_; // used to allocate new states
mutable bool cache_start_; // Is the start state cached?
StateId nknown_states_; // # of known states
vector<bool> expanded_states_; // states that have been expanded
@@ -471,10 +492,113 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> {
bool cache_gc_; // enable GC
size_t cache_size_; // # of bytes cached
size_t cache_limit_; // # of bytes allowed before GC
+ bool protect_; // Protect new states from GC
- void operator=(const CacheBaseImpl<S> &impl); // disallow
+ void operator=(const CacheBaseImpl<S, C> &impl); // disallow
};
+// Gets a state from its ID; add it if necessary.
+template <class S, class C>
+S *CacheBaseImpl<S, C>::ExtendState(typename S::Arc::StateId s) {
+ // If 'protect_' true and a new state, protects from garbage collection.
+ if (s == cache_first_state_id_) {
+ return cache_first_state_; // Return 1st cached state
+ } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
+ cache_first_state_id_ = s; // Remember 1st cached state
+ cache_first_state_ = allocator_->Allocate(s);
+ if (protect_) cache_first_state_->flags |= kCacheProtect;
+ return cache_first_state_;
+ } else if (cache_first_state_id_ != kNoStateId &&
+ cache_first_state_->ref_count == 0 &&
+ !(cache_first_state_->flags & kCacheProtect)) {
+ // With Default allocator, the Free and Allocate will reuse the same S*.
+ allocator_->Free(cache_first_state_, cache_first_state_id_);
+ cache_first_state_id_ = s;
+ cache_first_state_ = allocator_->Allocate(s);
+ if (protect_) cache_first_state_->flags |= kCacheProtect;
+ return cache_first_state_; // Return 1st cached state
+ } else {
+ while (NumStates() <= s) // Add state to main cache
+ AddState(0);
+ S *state = VectorFstBaseImpl<S>::GetState(s);
+ if (!state) {
+ state = allocator_->Allocate(s);
+ if (protect_) state->flags |= kCacheProtect;
+ SetState(s, state);
+ if (cache_first_state_id_ != kNoStateId) { // Forget 1st cached state
+ while (NumStates() <= cache_first_state_id_)
+ AddState(0);
+ SetState(cache_first_state_id_, cache_first_state_);
+ if (cache_gc_ && !(cache_first_state_->flags & kCacheProtect)) {
+ cache_states_.push_back(cache_first_state_id_);
+ cache_size_ += sizeof(S) +
+ cache_first_state_->arcs.capacity() * sizeof(Arc);
+ }
+ cache_limit_ = kMinCacheLimit;
+ cache_first_state_id_ = kNoStateId;
+ cache_first_state_ = 0;
+ }
+ if (cache_gc_ && !protect_) {
+ cache_states_.push_back(s);
+ cache_size_ += sizeof(S);
+ if (cache_size_ > cache_limit_)
+ GC(s, false);
+ }
+ }
+ return state;
+ }
+}
+
+// Removes from cache_states_ and uncaches (not referenced-counted or
+// protected) states that have not been accessed since the last GC
+// until at most cache_fraction * cache_limit_ bytes are cached. If
+// that fails to free enough, recurs uncaching recently visited states
+// as well. If still unable to free enough memory, then widens cache_limit_
+// to fulfill condition.
+template <class S, class C>
+void CacheBaseImpl<S, C>::GC(typename S::Arc::StateId current,
+ bool free_recent, float cache_fraction) {
+ if (!cache_gc_)
+ return;
+ VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
+ << "), free recently cached = " << free_recent
+ << ", cache size = " << cache_size_
+ << ", cache frac = " << cache_fraction
+ << ", cache limit = " << cache_limit_ << "\n";
+ typename list<StateId>::iterator siter = cache_states_.begin();
+
+ size_t cache_target = cache_fraction * cache_limit_;
+ while (siter != cache_states_.end()) {
+ StateId s = *siter;
+ S* state = VectorFstBaseImpl<S>::GetState(s);
+ if (cache_size_ > cache_target && state->ref_count == 0 &&
+ (free_recent || !(state->flags & kCacheRecent)) && s != current) {
+ cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
+ allocator_->Free(state, s);
+ SetState(s, 0);
+ cache_states_.erase(siter++);
+ } else {
+ state->flags &= ~kCacheRecent;
+ ++siter;
+ }
+ }
+ if (!free_recent && cache_size_ > cache_target) { // recurses on recent
+ GC(current, true);
+ } else if (cache_target > 0) { // widens cache limit
+ while (cache_size_ > cache_target) {
+ cache_limit_ *= 2;
+ cache_target *= 2;
+ }
+ } else if (cache_size_ > 0) {
+ FSTERROR() << "CacheImpl:GC: Unable to free all cached states";
+ }
+ VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
+ << "), free recently cached = " << free_recent
+ << ", cache size = " << cache_size_
+ << ", cache frac = " << cache_fraction
+ << ", cache limit = " << cache_limit_ << "\n";
+}
+
template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheFinal;
template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheArcs;
template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheRecent;
@@ -516,7 +640,8 @@ class CacheImpl : public CacheBaseImpl< CacheState<A> > {
explicit CacheImpl(const CacheOptions &opts)
: CacheBaseImpl< CacheState<A> >(opts) {}
- CacheImpl(const CacheImpl<State> &impl) : CacheBaseImpl<State>(impl) {}
+ CacheImpl(const CacheImpl<A> &impl, bool preserve_cache = false)
+ : CacheBaseImpl<State>(impl, preserve_cache) {}
private:
void operator=(const CacheImpl<State> &impl); // disallow
@@ -536,12 +661,13 @@ class CacheStateIterator : public StateIteratorBase<typename F::Arc> {
typedef CacheBaseImpl<State> Impl;
CacheStateIterator(const F &fst, Impl *impl)
- : fst_(fst), impl_(impl), s_(0) {}
+ : fst_(fst), impl_(impl), s_(0) {
+ fst_.Start(); // force start state
+ }
bool Done() const {
if (s_ < impl_->NumKnownStates())
return false;
- fst_.Start(); // force start state
if (s_ < impl_->NumKnownStates())
return false;
for (StateId u = impl_->MinUnexpandedState();
diff --git a/src/include/fst/compact-fst.h b/src/include/fst/compact-fst.h
index 57c927e..6db3317 100644
--- a/src/include/fst/compact-fst.h
+++ b/src/include/fst/compact-fst.h
@@ -32,6 +32,7 @@ using std::vector;
#include <fst/cache.h>
#include <fst/expanded-fst.h>
#include <fst/fst-decl.h> // For optional argument declarations
+#include <fst/mapped-file.h>
#include <fst/matcher.h>
#include <fst/test-properties.h>
#include <fst/util.h>
@@ -134,7 +135,9 @@ class CompactFstData {
typedef U Unsigned;
CompactFstData()
- : states_(0),
+ : states_region_(0),
+ compacts_region_(0),
+ states_(0),
compacts_(0),
nstates_(0),
ncompacts_(0),
@@ -150,8 +153,14 @@ class CompactFstData {
const Compactor &compactor);
~CompactFstData() {
- delete[] states_;
- delete[] compacts_;
+ if (states_region_ == NULL) {
+ delete [] states_;
+ }
+ delete states_region_;
+ if (compacts_region_ == NULL) {
+ delete [] compacts_;
+ }
+ delete compacts_region_;
}
template <class Compactor>
@@ -175,10 +184,9 @@ class CompactFstData {
bool Error() const { return error_; }
- // Byte alignment for states and arcs in file format (version 1 only)
- static const int kFileAlign = 16;
-
private:
+ MappedFile *states_region_;
+ MappedFile *compacts_region_;
Unsigned *states_;
CompactElement *compacts_;
size_t nstates_;
@@ -190,13 +198,11 @@ class CompactFstData {
};
template <class E, class U>
-const int CompactFstData<E, U>::kFileAlign;
-
-
-template <class E, class U>
template <class A, class C>
CompactFstData<E, U>::CompactFstData(const Fst<A> &fst, const C &compactor)
- : states_(0),
+ : states_region_(0),
+ compacts_region_(0),
+ states_(0),
compacts_(0),
nstates_(0),
ncompacts_(0),
@@ -265,7 +271,9 @@ template <class Iterator, class C>
CompactFstData<E, U>::CompactFstData(const Iterator &begin,
const Iterator &end,
const C &compactor)
- : states_(0),
+ : states_region_(0),
+ compacts_region_(0),
+ states_(0),
compacts_(0),
nstates_(0),
ncompacts_(0),
@@ -361,42 +369,40 @@ CompactFstData<E, U> *CompactFstData<E, U>::Read(
data->narcs_ = hdr.NumArcs();
if (compactor.Size() == -1) {
- data->states_ = new Unsigned[data->nstates_ + 1];
- if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) &&
- !AlignInput(strm, kFileAlign)) {
+ if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) {
LOG(ERROR) << "CompactFst::Read: Alignment failed: " << opts.source;
delete data;
return 0;
}
- // TODO: memory map this
size_t b = (data->nstates_ + 1) * sizeof(Unsigned);
- strm.read(reinterpret_cast<char *>(data->states_), b);
- if (!strm) {
+ data->states_region_ = MappedFile::Map(&strm, opts, b);
+ if (!strm || data->states_region_ == NULL) {
LOG(ERROR) << "CompactFst::Read: Read failed: " << opts.source;
delete data;
return 0;
}
+ data->states_ = static_cast<Unsigned *>(
+ data->states_region_->mutable_data());
} else {
data->states_ = 0;
}
data->ncompacts_ = compactor.Size() == -1
? data->states_[data->nstates_]
: data->nstates_ * compactor.Size();
- data->compacts_ = new CompactElement[data->ncompacts_];
- // TODO: memory map this
- size_t b = data->ncompacts_ * sizeof(CompactElement);
- if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) &&
- !AlignInput(strm, kFileAlign)) {
+ if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) {
LOG(ERROR) << "CompactFst::Read: Alignment failed: " << opts.source;
delete data;
return 0;
}
- strm.read(reinterpret_cast<char *>(data->compacts_), b);
- if (!strm) {
+ size_t b = data->ncompacts_ * sizeof(CompactElement);
+ data->compacts_region_ = MappedFile::Map(&strm, opts, b);
+ if (!strm || data->compacts_region_ == NULL) {
LOG(ERROR) << "CompactFst::Read: Read failed: " << opts.source;
delete data;
return 0;
}
+ data->compacts_ = static_cast<CompactElement *>(
+ data->compacts_region_->mutable_data());
return data;
}
@@ -404,14 +410,14 @@ template<class E, class U>
bool CompactFstData<E, U>::Write(ostream &strm,
const FstWriteOptions &opts) const {
if (states_) {
- if (opts.align && !AlignOutput(strm, kFileAlign)) {
+ if (opts.align && !AlignOutput(strm)) {
LOG(ERROR) << "CompactFst::Write: Alignment failed: " << opts.source;
return false;
}
strm.write(reinterpret_cast<char *>(states_),
(nstates_ + 1) * sizeof(Unsigned));
}
- if (opts.align && !AlignOutput(strm, kFileAlign)) {
+ if (opts.align && !AlignOutput(strm)) {
LOG(ERROR) << "CompactFst::Write: Alignment failed: " << opts.source;
return false;
}
@@ -899,6 +905,17 @@ class CompactFst : public ImplToExpandedFst< CompactFstImpl<A, C, U> > {
ImplToFst< Impl, ExpandedFst<A> >::SetImpl(impl, own_impl);
}
+ // Use overloading to extract the type of the argument.
+ static Impl* GetImplIfCompactFst(const CompactFst<A, C, U> &compact_fst) {
+ return compact_fst.GetImpl();
+ }
+
+ // This does not give privileged treatment to subclasses of CompactFst.
+ template<typename NonCompactFst>
+ static Impl* GetImplIfCompactFst(const NonCompactFst& fst) {
+ return NULL;
+ }
+
void operator=(const CompactFst<A, C, U> &fst); // disallow
};
@@ -914,20 +931,16 @@ bool CompactFst<A, C, U>::WriteFst(const F &fst,
typedef U Unsigned;
typedef typename C::Element CompactElement;
typedef typename A::Weight Weight;
- static const int kFileAlign =
- CompactFstData<CompactElement, U>::kFileAlign;
int file_version = opts.align ?
CompactFstImpl<A, C, U>::kAlignedFileVersion :
CompactFstImpl<A, C, U>::kFileVersion;
size_t num_arcs = -1, num_states = -1, num_compacts = -1;
C first_pass_compactor = compactor;
- if (fst.Type() == CompactFst<A, C, U>().Type()) {
- const CompactFst<A, C, U> *compact_fst =
- reinterpret_cast<const CompactFst<A, C, U> *>(&fst);
- num_arcs = compact_fst->GetImpl()->Data()->NumArcs();
- num_states = compact_fst->GetImpl()->Data()->NumStates();
- num_compacts = compact_fst->GetImpl()->Data()->NumCompacts();
- first_pass_compactor = *compact_fst->GetImpl()->GetCompactor();
+ if (Impl* impl = GetImplIfCompactFst(fst)) {
+ num_arcs = impl->Data()->NumArcs();
+ num_states = impl->Data()->NumStates();
+ num_compacts = impl->Data()->NumCompacts();
+ first_pass_compactor = *impl->GetCompactor();
} else {
// A first pass is needed to compute the state of the compactor, which
// is saved ahead of the rest of the data structures. This unfortunately
@@ -971,7 +984,7 @@ bool CompactFst<A, C, U>::WriteFst(const F &fst,
&hdr);
first_pass_compactor.Write(strm);
if (first_pass_compactor.Size() == -1) {
- if (opts.align && !AlignOutput(strm, kFileAlign)) {
+ if (opts.align && !AlignOutput(strm)) {
LOG(ERROR) << "CompactFst::Write: Alignment failed: " << opts.source;
return false;
}
@@ -986,7 +999,7 @@ bool CompactFst<A, C, U>::WriteFst(const F &fst,
}
strm.write(reinterpret_cast<const char *>(&compacts), sizeof(compacts));
}
- if (opts.align && !AlignOutput(strm, kFileAlign)) {
+ if (opts.align && !AlignOutput(strm)) {
LOG(ERROR) << "Could not align file during write after writing states";
}
C second_pass_compactor = compactor;
diff --git a/src/include/fst/compose.h b/src/include/fst/compose.h
index dfdff0a..db5ea3a 100644
--- a/src/include/fst/compose.h
+++ b/src/include/fst/compose.h
@@ -122,7 +122,7 @@ class ComposeFstImplBase : public CacheImpl<A> {
ComposeFstImplBase(const Fst<A> &fst1, const Fst<A> &fst2,
const CacheOptions &opts)
- :CacheImpl<A>(opts) {
+ : CacheImpl<A>(opts) {
VLOG(2) << "ComposeFst(" << this << "): Begin";
SetType("compose");
@@ -137,7 +137,7 @@ class ComposeFstImplBase : public CacheImpl<A> {
}
ComposeFstImplBase(const ComposeFstImplBase<A> &impl)
- : CacheImpl<A>(impl) {
+ : CacheImpl<A>(impl, true) {
SetProperties(impl.Properties(), kCopyProperties);
SetInputSymbols(impl.InputSymbols());
SetOutputSymbols(impl.OutputSymbols());
@@ -275,6 +275,13 @@ class ComposeFstImpl : public ComposeFstImplBase<typename M1::Arc> {
OrderedExpand(s, fst2_, s2, fst1_, s1, matcher2_, true);
}
+ const FST1 &GetFst1() { return fst1_; }
+ const FST2 &GetFst2() { return fst2_; }
+ M1 *GetMatcher1() { return matcher1_; }
+ M2 *GetMatcher2() { return matcher2_; }
+ F *GetFilter() { return filter_; }
+ T *GetStateTable() { return state_table_; }
+
private:
// This does that actual matching of labels in the composition. The
// arguments are ordered so matching is called on state 'sa' of
diff --git a/src/include/fst/const-fst.h b/src/include/fst/const-fst.h
index 80efc8d..e6e85af 100644
--- a/src/include/fst/const-fst.h
+++ b/src/include/fst/const-fst.h
@@ -28,6 +28,7 @@ using std::vector;
#include <fst/expanded-fst.h>
#include <fst/fst-decl.h> // For optional argument declarations
+#include <fst/mapped-file.h>
#include <fst/test-properties.h>
#include <fst/util.h>
@@ -55,7 +56,8 @@ class ConstFstImpl : public FstImpl<A> {
typedef U Unsigned;
ConstFstImpl()
- : states_(0), arcs_(0), nstates_(0), narcs_(0), start_(kNoStateId) {
+ : states_region_(0), arcs_region_(0), states_(0), arcs_(0), nstates_(0),
+ narcs_(0), start_(kNoStateId) {
string type = "const";
if (sizeof(U) != sizeof(uint32)) {
string size;
@@ -69,8 +71,8 @@ class ConstFstImpl : public FstImpl<A> {
explicit ConstFstImpl(const Fst<A> &fst);
~ConstFstImpl() {
- delete[] states_;
- delete[] arcs_;
+ delete arcs_region_;
+ delete states_region_;
}
StateId Start() const { return start_; }
@@ -125,9 +127,9 @@ class ConstFstImpl : public FstImpl<A> {
static const int kAlignedFileVersion = 1;
// Minimum file format version supported
static const int kMinFileVersion = 1;
- // Byte alignment for states and arcs in file format (version 1 only)
- static const int kFileAlign = 16;
+ MappedFile *states_region_; // Mapped file for states
+ MappedFile *arcs_region_; // Mapped file for arcs
State *states_; // States represenation
A *arcs_; // Arcs representation
StateId nstates_; // Number of states
@@ -145,8 +147,6 @@ template <class A, class U>
const int ConstFstImpl<A, U>::kAlignedFileVersion;
template <class A, class U>
const int ConstFstImpl<A, U>::kMinFileVersion;
-template <class A, class U>
-const int ConstFstImpl<A, U>::kFileAlign;
template<class A, class U>
@@ -173,8 +173,10 @@ ConstFstImpl<A, U>::ConstFstImpl(const Fst<A> &fst) : nstates_(0), narcs_(0) {
aiter.Next())
++narcs_;
}
- states_ = new State[nstates_];
- arcs_ = new A[narcs_];
+ states_region_ = MappedFile::Allocate(nstates_ * sizeof(*states_));
+ arcs_region_ = MappedFile::Allocate(narcs_ * sizeof(*arcs_));
+ states_ = reinterpret_cast<State*>(states_region_->mutable_data());
+ arcs_ = reinterpret_cast<A*>(arcs_region_->mutable_data());
size_t pos = 0;
for (StateId s = 0; s < nstates_; ++s) {
states_[s].final = fst.Final(s);
@@ -210,39 +212,40 @@ ConstFstImpl<A, U> *ConstFstImpl<A, U>::Read(istream &strm,
impl->start_ = hdr.Start();
impl->nstates_ = hdr.NumStates();
impl->narcs_ = hdr.NumArcs();
- impl->states_ = new State[impl->nstates_];
- impl->arcs_ = new A[impl->narcs_];
// Ensures compatibility
if (hdr.Version() == kAlignedFileVersion)
hdr.SetFlags(hdr.GetFlags() | FstHeader::IS_ALIGNED);
- if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) &&
- !AlignInput(strm, kFileAlign)) {
+ if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) {
LOG(ERROR) << "ConstFst::Read: Alignment failed: " << opts.source;
delete impl;
return 0;
}
+
size_t b = impl->nstates_ * sizeof(typename ConstFstImpl<A, U>::State);
- strm.read(reinterpret_cast<char *>(impl->states_), b);
- if (!strm) {
+ impl->states_region_ = MappedFile::Map(&strm, opts, b);
+ if (!strm || impl->states_region_ == NULL) {
LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source;
delete impl;
return 0;
}
- if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) &&
- !AlignInput(strm, kFileAlign)) {
+ impl->states_ = reinterpret_cast<State*>(
+ impl->states_region_->mutable_data());
+ if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) {
LOG(ERROR) << "ConstFst::Read: Alignment failed: " << opts.source;
delete impl;
return 0;
}
+
b = impl->narcs_ * sizeof(A);
- strm.read(reinterpret_cast<char *>(impl->arcs_), b);
- if (!strm) {
+ impl->arcs_region_ = MappedFile::Map(&strm, opts, b);
+ if (!strm || impl->arcs_region_ == NULL) {
LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source;
delete impl;
return 0;
}
+ impl->arcs_ = reinterpret_cast<A*>(impl->arcs_region_->mutable_data());
return impl;
}
@@ -318,6 +321,17 @@ class ConstFst : public ImplToExpandedFst< ConstFstImpl<A, U> > {
ImplToFst< Impl, ExpandedFst<A> >::SetImpl(impl, own_impl);
}
+ // Use overloading to extract the type of the argument.
+ static Impl* GetImplIfConstFst(const ConstFst &const_fst) {
+ return const_fst.GetImpl();
+ }
+
+ // Note that this does not give privileged treatment to subtypes of ConstFst.
+ template<typename NonConstFst>
+ static Impl* GetImplIfConstFst(const NonConstFst& fst) {
+ return NULL;
+ }
+
void operator=(const ConstFst<A, U> &fst); // disallow
};
@@ -333,11 +347,9 @@ bool ConstFst<A, U>::WriteFst(const F &fst, ostream &strm,
size_t num_arcs = -1, num_states = -1;
size_t start_offset = 0;
bool update_header = true;
- if (fst.Type() == ConstFst<A, U>().Type()) {
- const ConstFst<A, U> *const_fst =
- reinterpret_cast<const ConstFst<A, U> *>(&fst);
- num_arcs = const_fst->GetImpl()->narcs_;
- num_states = const_fst->GetImpl()->nstates_;
+ if (Impl* impl = GetImplIfConstFst(fst)) {
+ num_arcs = impl->narcs_;
+ num_states = impl->nstates_;
update_header = false;
} else if ((start_offset = strm.tellp()) == -1) {
// precompute values needed for header when we cannot seek to rewrite it.
@@ -363,7 +375,7 @@ bool ConstFst<A, U>::WriteFst(const F &fst, ostream &strm,
ConstFstImpl<A, U>::kStaticProperties;
FstImpl<A>::WriteFstHeader(fst, strm, opts, file_version, type, properties,
&hdr);
- if (opts.align && !AlignOutput(strm, ConstFstImpl<A, U>::kFileAlign)) {
+ if (opts.align && !AlignOutput(strm)) {
LOG(ERROR) << "Could not align file during write after header";
return false;
}
@@ -381,7 +393,7 @@ bool ConstFst<A, U>::WriteFst(const F &fst, ostream &strm,
}
hdr.SetNumStates(states);
hdr.SetNumArcs(pos);
- if (opts.align && !AlignOutput(strm, ConstFstImpl<A, U>::kFileAlign)) {
+ if (opts.align && !AlignOutput(strm)) {
LOG(ERROR) << "Could not align file during write after writing states";
}
for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) {
diff --git a/src/include/fst/determinize.h b/src/include/fst/determinize.h
index a145e4a..9ff8723 100644
--- a/src/include/fst/determinize.h
+++ b/src/include/fst/determinize.h
@@ -33,9 +33,10 @@ using std::tr1::unordered_multimap;
#include <vector>
using std::vector;
+#include <fst/arc-map.h>
#include <fst/cache.h>
+#include <fst/bi-table.h>
#include <fst/factor-weight.h>
-#include <fst/arc-map.h>
#include <fst/prune.h>
#include <fst/test-properties.h>
@@ -108,24 +109,244 @@ class GallicCommonDivisor {
D weight_common_divisor_;
};
-// Options for finite-state transducer determinization.
+
+// Represents an element in a subset
+template <class A>
+struct DeterminizeElement {
+ typedef typename A::StateId StateId;
+ typedef typename A::Weight Weight;
+
+ DeterminizeElement() {}
+
+ DeterminizeElement(StateId s, Weight w) : state_id(s), weight(w) {}
+
+ bool operator==(const DeterminizeElement<A> & element) const {
+ return state_id == element.state_id && weight == element.weight;
+ }
+
+ bool operator<(const DeterminizeElement<A> & element) const {
+ return state_id < element.state_id ||
+ (state_id == element.state_id && weight == element.weight);
+ }
+
+ StateId state_id; // Input state Id
+ Weight weight; // Residual weight
+};
+
+
+//
+// DETERMINIZE FILTERS - these can be used in determinization to compute
+// transformations on the subsets prior to their being added as destination
+// states. The filter operates on a map between a label and the
+// corresponding destination subsets. The possibly modified map is
+// then used to construct the destination states for arcs exiting state 's'.
+// It must define the ordered map type LabelMap and have a default
+// and copy constructor.
+
+// A determinize filter that does not modify its input.
template <class Arc>
+struct IdentityDeterminizeFilter {
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Label Label;
+ typedef slist< DeterminizeElement<Arc> > Subset;
+ typedef map<Label, Subset*> LabelMap;
+
+ static uint64 Properties(uint64 props) { return props; }
+
+ void operator()(StateId s, LabelMap *label_map) {}
+};
+
+
+//
+// DETERMINIZATION STATE TABLES
+//
+// The determiziation state table has the form:
+//
+// template <class Arc>
+// class DeterminizeStateTable {
+// public:
+// typedef typename Arc::StateId StateId;
+// typedef DeterminizeElement<Arc> Element;
+// typedef slist<Element> Subset;
+//
+// // Required constuctor
+// DeterminizeStateTable();
+//
+// // Required copy constructor that does not copy state
+// DeterminizeStateTable(const DeterminizeStateTable<A,P> &table);
+//
+// // Lookup state ID by subset (not depending of the element order).
+// // If it doesn't exist, then add it. FindState takes
+// // ownership of the subset argument (so that it doesn't have to
+// // copy it if it creates a new state).
+// StateId FindState(Subset *subset);
+//
+// // Lookup subset by ID.
+// const Subset *FindSubset(StateId id) const;
+// };
+//
+
+// The default determinization state table based on the
+// compact hash bi-table.
+template <class Arc>
+class DefaultDeterminizeStateTable {
+ public:
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Label Label;
+ typedef typename Arc::Weight Weight;
+ typedef DeterminizeElement<Arc> Element;
+ typedef slist<Element> Subset;
+
+ explicit DefaultDeterminizeStateTable(size_t table_size = 0)
+ : table_size_(table_size),
+ subsets_(table_size_, new SubsetKey(), new SubsetEqual(&elements_)) { }
+
+ DefaultDeterminizeStateTable(const DefaultDeterminizeStateTable<Arc> &table)
+ : table_size_(table.table_size_),
+ subsets_(table_size_, new SubsetKey(), new SubsetEqual(&elements_)) { }
+
+ ~DefaultDeterminizeStateTable() {
+ for (StateId s = 0; s < subsets_.Size(); ++s)
+ delete subsets_.FindEntry(s);
+ }
+
+ // Finds the state corresponding to a subset. Only creates a new
+ // state if the subset is not found. FindState takes ownership of
+ // the subset argument (so that it doesn't have to copy it if it
+ // creates a new state).
+ StateId FindState(Subset *subset) {
+ StateId ns = subsets_.Size();
+ StateId s = subsets_.FindId(subset);
+ if (s != ns) delete subset; // subset found
+ return s;
+ }
+
+ const Subset* FindSubset(StateId s) { return subsets_.FindEntry(s); }
+
+ private:
+ // Comparison object for hashing Subset(s). Subsets are not sorted in this
+ // implementation, so ordering must not be assumed in the equivalence
+ // test.
+ class SubsetEqual {
+ public:
+ SubsetEqual() { // needed for compilation but should never be called
+ FSTERROR() << "SubsetEqual: default constructor not implemented";
+ }
+
+ // Constructor takes vector needed to check equality. See immediately
+ // below for constraints on it.
+ explicit SubsetEqual(vector<Element *> *elements)
+ : elements_(elements) {}
+
+ // At each call to operator(), the elements_ vector should contain
+ // only NULLs. When this operator returns, elements_ will still
+ // have this property.
+ bool operator()(Subset* subset1, Subset* subset2) const {
+ if (!subset1 && !subset2)
+ return true;
+ if ((subset1 && !subset2) || (!subset1 && subset2))
+ return false;
+
+ if (subset1->size() != subset2->size())
+ return false;
+
+ // Loads first subset elements in element vector.
+ for (typename Subset::iterator iter1 = subset1->begin();
+ iter1 != subset1->end();
+ ++iter1) {
+ Element &element1 = *iter1;
+ while (elements_->size() <= element1.state_id)
+ elements_->push_back(0);
+ (*elements_)[element1.state_id] = &element1;
+ }
+
+ // Checks second subset matches first via element vector.
+ for (typename Subset::iterator iter2 = subset2->begin();
+ iter2 != subset2->end();
+ ++iter2) {
+ Element &element2 = *iter2;
+ while (elements_->size() <= element2.state_id)
+ elements_->push_back(0);
+ Element *element1 = (*elements_)[element2.state_id];
+ if (!element1 || element1->weight != element2.weight) {
+ // Mismatch found. Resets element vector before returning false.
+ for (typename Subset::iterator iter1 = subset1->begin();
+ iter1 != subset1->end();
+ ++iter1)
+ (*elements_)[iter1->state_id] = 0;
+ return false;
+ } else {
+ (*elements_)[element2.state_id] = 0; // Clears entry
+ }
+ }
+ return true;
+ }
+ private:
+ vector<Element *> *elements_;
+ };
+
+ // Hash function for Subset to Fst states. Subset elements are not
+ // sorted in this implementation, so the hash must be invariant
+ // under subset reordering.
+ class SubsetKey {
+ public:
+ size_t operator()(const Subset* subset) const {
+ size_t hash = 0;
+ if (subset) {
+ for (typename Subset::const_iterator iter = subset->begin();
+ iter != subset->end();
+ ++iter) {
+ const Element &element = *iter;
+ int lshift = element.state_id % (CHAR_BIT * sizeof(size_t) - 1) + 1;
+ int rshift = CHAR_BIT * sizeof(size_t) - lshift;
+ size_t n = element.state_id;
+ hash ^= n << lshift ^ n >> rshift ^ element.weight.Hash();
+ }
+ }
+ return hash;
+ }
+ };
+
+ size_t table_size_;
+
+ typedef CompactHashBiTable<StateId, Subset *,
+ SubsetKey, SubsetEqual, HS_STL> SubsetTable;
+
+ SubsetTable subsets_;
+ vector<Element *> elements_;
+
+ void operator=(const DefaultDeterminizeStateTable<Arc> &); // disallow
+};
+
+// Options for finite-state transducer determinization templated on
+// the arc type, common divisor, the determinization filter and the
+// state table. DeterminizeFst takes ownership of the determinization
+// filter and state table if provided.
+template <class Arc,
+ class D = DefaultCommonDivisor<typename Arc::Weight>,
+ class F = IdentityDeterminizeFilter<Arc>,
+ class T = DefaultDeterminizeStateTable<Arc> >
struct DeterminizeFstOptions : CacheOptions {
typedef typename Arc::Label Label;
float delta; // Quantization delta for subset weights
Label subsequential_label; // Label used for residual final output
// when producing subsequential transducers.
+ F *filter; // Determinization filter
+ T *state_table; // Determinization state table
explicit DeterminizeFstOptions(const CacheOptions &opts,
- float del = kDelta,
- Label lab = 0)
- : CacheOptions(opts), delta(del), subsequential_label(lab) {}
-
- explicit DeterminizeFstOptions(float del = kDelta, Label lab = 0)
- : delta(del), subsequential_label(lab) {}
+ float del = kDelta, Label lab = 0,
+ F *filt = 0,
+ T *table = 0)
+ : CacheOptions(opts), delta(del), subsequential_label(lab),
+ filter(filt), state_table(table) {}
+
+ explicit DeterminizeFstOptions(float del = kDelta, Label lab = 0,
+ F *filt = 0, T *table = 0)
+ : delta(del), subsequential_label(lab), filter(filt),
+ state_table(table) {}
};
-
// Implementation of delayed DeterminizeFst. This base class is
// common to the variants that implement acceptor and transducer
// determinization.
@@ -149,15 +370,15 @@ class DeterminizeFstImplBase : public CacheImpl<A> {
typedef typename A::StateId StateId;
typedef CacheState<A> State;
+ template <class D, class F, class T>
DeterminizeFstImplBase(const Fst<A> &fst,
- const DeterminizeFstOptions<A> &opts)
+ const DeterminizeFstOptions<A, D, F, T> &opts)
: CacheImpl<A>(opts), fst_(fst.Copy()) {
SetType("determinize");
- uint64 props = fst.Properties(kFstProperties, false);
- SetProperties(DeterminizeProperties(props,
- opts.subsequential_label != 0),
- kCopyProperties);
-
+ uint64 iprops = fst.Properties(kFstProperties, false);
+ uint64 dprops = DeterminizeProperties(iprops,
+ opts.subsequential_label != 0);
+ SetProperties(F::Properties(dprops), kCopyProperties);
SetInputSymbols(fst.InputSymbols());
SetOutputSymbols(fst.OutputSymbols());
}
@@ -234,7 +455,7 @@ class DeterminizeFstImplBase : public CacheImpl<A> {
// Implementation of delayed determinization for weighted acceptors.
// It is templated on the arc type A and the common divisor D.
-template <class A, class D>
+template <class A, class D, class F, class T>
class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> {
public:
using FstImpl<A>::SetProperties;
@@ -244,27 +465,19 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> {
typedef typename A::Label Label;
typedef typename A::Weight Weight;
typedef typename A::StateId StateId;
-
- struct Element {
- Element() {}
-
- Element(StateId s, Weight w) : state_id(s), weight(w) {}
-
- StateId state_id; // Input state Id
- Weight weight; // Residual weight
- };
+ typedef DeterminizeElement<A> Element;
typedef slist<Element> Subset;
- typedef map<Label, Subset*> LabelMap;
+ typedef typename F::LabelMap LabelMap;
- DeterminizeFsaImpl(const Fst<A> &fst, D common_divisor,
+ DeterminizeFsaImpl(const Fst<A> &fst,
const vector<Weight> *in_dist, vector<Weight> *out_dist,
- const DeterminizeFstOptions<A> &opts)
+ const DeterminizeFstOptions<A, D, F, T> &opts)
: DeterminizeFstImplBase<A>(fst, opts),
delta_(opts.delta),
in_dist_(in_dist),
out_dist_(out_dist),
- common_divisor_(common_divisor),
- subset_hash_(0, SubsetKey(), SubsetEqual(&elements_)) {
+ filter_(opts.filter ? opts.filter : new F()),
+ state_table_(opts.state_table ? opts.state_table : new T()) {
if (!fst.Properties(kAcceptor, true)) {
FSTERROR() << "DeterminizeFst: argument not an acceptor";
SetProperties(kError, kError);
@@ -278,13 +491,13 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> {
out_dist_->clear();
}
- DeterminizeFsaImpl(const DeterminizeFsaImpl<A, D> &impl)
+ DeterminizeFsaImpl(const DeterminizeFsaImpl<A, D, F, T> &impl)
: DeterminizeFstImplBase<A>(impl),
delta_(impl.delta_),
in_dist_(0),
out_dist_(0),
- common_divisor_(impl.common_divisor_),
- subset_hash_(0, SubsetKey(), SubsetEqual(&elements_)) {
+ filter_(new F(*impl.filter_)),
+ state_table_(new T(*impl.state_table_)) {
if (impl.out_dist_) {
FSTERROR() << "DeterminizeFsaImpl: cannot copy with out_dist vector";
SetProperties(kError, kError);
@@ -292,12 +505,12 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> {
}
virtual ~DeterminizeFsaImpl() {
- for (int i = 0; i < subsets_.size(); ++i)
- delete subsets_[i];
+ delete filter_;
+ delete state_table_;
}
- virtual DeterminizeFsaImpl<A, D> *Copy() {
- return new DeterminizeFsaImpl<A, D>(*this);
+ virtual DeterminizeFsaImpl<A, D, F, T> *Copy() {
+ return new DeterminizeFsaImpl<A, D, F, T>(*this);
}
uint64 Properties() const { return Properties(kFstProperties); }
@@ -320,12 +533,12 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> {
}
virtual Weight ComputeFinal(StateId s) {
- Subset *subset = subsets_[s];
+ const Subset *subset = state_table_->FindSubset(s);
Weight final = Weight::Zero();
- for (typename Subset::iterator siter = subset->begin();
+ for (typename Subset::const_iterator siter = subset->begin();
siter != subset->end();
++siter) {
- Element &element = *siter;
+ const Element &element = *siter;
final = Plus(final, Times(element.weight,
GetFst().Final(element.state_id)));
if (!final.Member())
@@ -334,33 +547,9 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> {
return final;
}
- // Finds the state corresponding to a subset. Only creates a new state
- // if the subset is not found in the subset hash. FindState takes
- // ownership of the subset argument (so that it doesn't have to copy it
- // if it creates a new state).
- //
- // The method exploits the following device: all pairs stored in the
- // associative container subset_hash_ are of the form (subset,
- // id(subset) + 1), i.e. subset_hash_[subset] > 0 if subset has been
- // stored previously. For unassigned subsets, the call to
- // subset_hash_[subset] creates a new pair (subset, 0). As a result,
- // subset_hash_[subset] == 0 iff subset is new.
StateId FindState(Subset *subset) {
- StateId &assoc_value = subset_hash_[subset];
- if (assoc_value == 0) { // subset wasn't present; create new state
- StateId s = CreateState(subset);
- assoc_value = s + 1;
- return s;
- } else {
- delete subset;
- return assoc_value - 1; // NB: assoc_value = ID + 1
- }
- }
-
- StateId CreateState(Subset *subset) {
- StateId s = subsets_.size();
- subsets_.push_back(subset);
- if (in_dist_)
+ StateId s = state_table_->FindState(subset);
+ if (in_dist_ && out_dist_->size() <= s)
out_dist_->push_back(ComputeDistance(subset));
return s;
}
@@ -398,24 +587,35 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> {
// element weights include the input automaton label weights and the
// subsets may contain duplicate states.
void LabelSubsets(StateId s, LabelMap *label_map) {
- Subset *src_subset = subsets_[s];
+ const Subset *src_subset = state_table_->FindSubset(s);
- for (typename Subset::iterator siter = src_subset->begin();
+ for (typename Subset::const_iterator siter = src_subset->begin();
siter != src_subset->end();
++siter) {
- Element &src_element = *siter;
+ const Element &src_element = *siter;
for (ArcIterator< Fst<A> > aiter(GetFst(), src_element.state_id);
!aiter.Done();
aiter.Next()) {
const A &arc = aiter.Value();
Element dest_element(arc.nextstate,
Times(src_element.weight, arc.weight));
- Subset* &dest_subset = (*label_map)[arc.ilabel];
- if (dest_subset == 0)
+
+ // The LabelMap may be a e.g. multimap with more complex
+ // determinization filters, so we insert efficiently w/o using [].
+ typename LabelMap::iterator liter = label_map->lower_bound(arc.ilabel);
+ Subset* dest_subset;
+ if (liter == label_map->end() || liter->first != arc.ilabel) {
dest_subset = new Subset;
+ label_map->insert(liter, make_pair(arc.ilabel, dest_subset));
+ } else {
+ dest_subset = liter->second;
+ }
+
dest_subset->push_front(dest_element);
}
}
+ // Applies the determinization filter
+ (*filter_)(s, label_map);
}
// Adds an arc from state S to the destination state associated
@@ -469,98 +669,17 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> {
CacheImpl<A>::PushArc(s, arc);
}
- // Comparison object for hashing Subset(s). Subsets are not sorted in this
- // implementation, so ordering must not be assumed in the equivalence
- // test.
- class SubsetEqual {
- public:
- // Constructor takes vector needed to check equality. See immediately
- // below for constraints on it.
- explicit SubsetEqual(vector<Element *> *elements)
- : elements_(elements) {}
-
- // At each call to operator(), the elements_ vector should contain
- // only NULLs. When this operator returns, elements_ will still
- // have this property.
- bool operator()(Subset* subset1, Subset* subset2) const {
- if (subset1->size() != subset2->size())
- return false;
-
- // Loads first subset elements in element vector.
- for (typename Subset::iterator iter1 = subset1->begin();
- iter1 != subset1->end();
- ++iter1) {
- Element &element1 = *iter1;
- while (elements_->size() <= element1.state_id)
- elements_->push_back(0);
- (*elements_)[element1.state_id] = &element1;
- }
-
- // Checks second subset matches first via element vector.
- for (typename Subset::iterator iter2 = subset2->begin();
- iter2 != subset2->end();
- ++iter2) {
- Element &element2 = *iter2;
- while (elements_->size() <= element2.state_id)
- elements_->push_back(0);
- Element *element1 = (*elements_)[element2.state_id];
- if (!element1 || element1->weight != element2.weight) {
- // Mismatch found. Resets element vector before returning false.
- for (typename Subset::iterator iter1 = subset1->begin();
- iter1 != subset1->end();
- ++iter1)
- (*elements_)[iter1->state_id] = 0;
- return false;
- } else {
- (*elements_)[element2.state_id] = 0; // Clears entry
- }
- }
- return true;
- }
- private:
- vector<Element *> *elements_;
- };
-
- // Hash function for Subset to Fst states. Subset elements are not
- // sorted in this implementation, so the hash must be invariant
- // under subset reordering.
- class SubsetKey {
- public:
- size_t operator()(const Subset* subset) const {
- size_t hash = 0;
- for (typename Subset::const_iterator iter = subset->begin();
- iter != subset->end();
- ++iter) {
- const Element &element = *iter;
- int lshift = element.state_id % (CHAR_BIT * sizeof(size_t) - 1) + 1;
- int rshift = CHAR_BIT * sizeof(size_t) - lshift;
- size_t n = element.state_id;
- hash ^= n << lshift ^ n >> rshift ^ element.weight.Hash();
- }
- return hash;
- }
- };
-
float delta_; // Quantization delta for subset weights
const vector<Weight> *in_dist_; // Distance to final NFA states
vector<Weight> *out_dist_; // Distance to final DFA states
D common_divisor_;
+ F *filter_;
+ T *state_table_;
- // Used to test equivalence of subsets.
vector<Element *> elements_;
- // Maps from StateId to Subset.
- vector<Subset *> subsets_;
-
- // Hashes from Subset to its StateId in the output automaton.
- typedef unordered_map<Subset *, StateId, SubsetKey, SubsetEqual>
- SubsetHash;
-
- // Hashes from Label to Subsets corr. to destination states of current state.
- SubsetHash subset_hash_;
-
- void operator=(const DeterminizeFsaImpl<A, D> &); // disallow
+ void operator=(const DeterminizeFsaImpl<A, D, F, T> &); // disallow
};
@@ -569,7 +688,7 @@ class DeterminizeFsaImpl : public DeterminizeFstImplBase<A> {
// the Gallic semiring as an acceptor whose weights contain the output
// strings and using acceptor determinization above to determinize
// that acceptor.
-template <class A, StringType S>
+template <class A, StringType S, class D, class F, class T>
class DeterminizeFstImpl : public DeterminizeFstImplBase<A> {
public:
using FstImpl<A>::SetProperties;
@@ -588,17 +707,18 @@ class DeterminizeFstImpl : public DeterminizeFstImplBase<A> {
typedef ArcMapFst<A, ToArc, ToMapper> ToFst;
typedef ArcMapFst<ToArc, A, FromMapper> FromFst;
- typedef GallicCommonDivisor<Label, Weight, S> CommonDivisor;
+ typedef GallicCommonDivisor<Label, Weight, S, D> CommonDivisor;
typedef GallicFactor<Label, Weight, S> FactorIterator;
- DeterminizeFstImpl(const Fst<A> &fst, const DeterminizeFstOptions<A> &opts)
+ DeterminizeFstImpl(const Fst<A> &fst,
+ const DeterminizeFstOptions<A, D, F, T> &opts)
: DeterminizeFstImplBase<A>(fst, opts),
delta_(opts.delta),
subsequential_label_(opts.subsequential_label) {
Init(GetFst());
}
- DeterminizeFstImpl(const DeterminizeFstImpl<A, S> &impl)
+ DeterminizeFstImpl(const DeterminizeFstImpl<A, S, D, F, T> &impl)
: DeterminizeFstImplBase<A>(impl),
delta_(impl.delta_),
subsequential_label_(impl.subsequential_label_) {
@@ -607,8 +727,8 @@ class DeterminizeFstImpl : public DeterminizeFstImplBase<A> {
~DeterminizeFstImpl() { delete from_fst_; }
- virtual DeterminizeFstImpl<A, S> *Copy() {
- return new DeterminizeFstImpl<A, S>(*this);
+ virtual DeterminizeFstImpl<A, S, D, F, T> *Copy() {
+ return new DeterminizeFstImpl<A, S, D, F, T>(*this);
}
uint64 Properties() const { return Properties(kFstProperties); }
@@ -642,7 +762,7 @@ class DeterminizeFstImpl : public DeterminizeFstImplBase<A> {
Label subsequential_label_;
FromFst *from_fst_;
- void operator=(const DeterminizeFstImpl<A, S> &); // disallow
+ void operator=(const DeterminizeFstImpl<A, S, D, F, T> &); // disallow
};
@@ -673,7 +793,8 @@ class DeterminizeFst : public ImplToFst< DeterminizeFstImplBase<A> > {
public:
friend class ArcIterator< DeterminizeFst<A> >;
friend class StateIterator< DeterminizeFst<A> >;
- template <class B, StringType S> friend class DeterminizeFstImpl;
+ template <class B, StringType S, class D, class F, class T>
+ friend class DeterminizeFstImpl;
typedef A Arc;
typedef typename A::Weight Weight;
@@ -684,33 +805,47 @@ class DeterminizeFst : public ImplToFst< DeterminizeFstImplBase<A> > {
using ImplToFst<Impl>::SetImpl;
- explicit DeterminizeFst(
- const Fst<A> &fst,
- const DeterminizeFstOptions<A> &opts = DeterminizeFstOptions<A>()) {
+ explicit DeterminizeFst(const Fst<A> &fst) {
+ typedef DefaultCommonDivisor<Weight> D;
+ typedef IdentityDeterminizeFilter<A> F;
+ typedef DefaultDeterminizeStateTable<A> T;
+ DeterminizeFstOptions<A, D, F, T> opts;
if (fst.Properties(kAcceptor, true)) {
// Calls implementation for acceptors.
- typedef DefaultCommonDivisor<Weight> D;
- SetImpl(new DeterminizeFsaImpl<A, D>(fst, D(), 0, 0, opts));
+ SetImpl(new DeterminizeFsaImpl<A, D, F, T>(fst, 0, 0, opts));
} else {
// Calls implementation for transducers.
- SetImpl(new DeterminizeFstImpl<A, STRING_LEFT_RESTRICT>(fst, opts));
+ SetImpl(new
+ DeterminizeFstImpl<A, STRING_LEFT_RESTRICT, D, F, T>(fst, opts));
+ }
+ }
+
+ template <class D, class F, class T>
+ DeterminizeFst(const Fst<A> &fst,
+ const DeterminizeFstOptions<A, D, F, T> &opts) {
+ if (fst.Properties(kAcceptor, true)) {
+ // Calls implementation for acceptors.
+ SetImpl(new DeterminizeFsaImpl<A, D, F, T>(fst, 0, 0, opts));
+ } else {
+ // Calls implementation for transducers.
+ SetImpl(new
+ DeterminizeFstImpl<A, STRING_LEFT_RESTRICT, D, F, T>(fst, opts));
}
}
// This acceptor-only version additionally computes the distance to
// final states in the output if provided with those distances for the
// input. Useful for e.g. unique N-shortest paths.
- DeterminizeFst(
- const Fst<A> &fst,
- const vector<Weight> &in_dist, vector<Weight> *out_dist,
- const DeterminizeFstOptions<A> &opts = DeterminizeFstOptions<A>()) {
+ template <class D, class F, class T>
+ DeterminizeFst(const Fst<A> &fst,
+ const vector<Weight> *in_dist, vector<Weight> *out_dist,
+ const DeterminizeFstOptions<A, D, F, T> &opts) {
if (!fst.Properties(kAcceptor, true)) {
FSTERROR() << "DeterminizeFst:"
<< " distance to final states computed for acceptors only";
GetImpl()->SetProperties(kError, kError);
}
- typedef DefaultCommonDivisor<Weight> D;
- SetImpl(new DeterminizeFsaImpl<A, D>(fst, D(), &in_dist, out_dist, opts));
+ SetImpl(new DeterminizeFsaImpl<A, D, F, T>(fst, in_dist, out_dist, opts));
}
// See Fst<>::Copy() for doc.
@@ -733,14 +868,6 @@ class DeterminizeFst : public ImplToFst< DeterminizeFstImplBase<A> > {
}
private:
- // This private version is for passing the common divisor to
- // FSA determinization.
- template <class D>
- DeterminizeFst(const Fst<A> &fst, const D &common_div,
- const DeterminizeFstOptions<A> &opts)
- : ImplToFst<Impl>(
- new DeterminizeFsaImpl<A, D>(fst, common_div, 0, 0, opts)) {}
-
// Makes visible to friends.
Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
@@ -750,17 +877,18 @@ class DeterminizeFst : public ImplToFst< DeterminizeFstImplBase<A> > {
// Initialization of transducer determinization implementation. which
// is defined after DeterminizeFst since it calls it.
-template <class A, StringType S>
-void DeterminizeFstImpl<A, S>::Init(const Fst<A> &fst) {
+template <class A, StringType S, class D, class F, class T>
+void DeterminizeFstImpl<A, S, D, F, T>::Init(const Fst<A> &fst) {
// Mapper to an acceptor.
ToFst to_fst(fst, ToMapper());
- // Determinize acceptor.
+ // Determinizes acceptor.
// This recursive call terminates since it passes the common divisor
// to a private constructor.
CacheOptions copts(GetCacheGc(), GetCacheLimit());
- DeterminizeFstOptions<ToArc> dopts(copts, delta_);
- DeterminizeFst<ToArc> det_fsa(to_fst, CommonDivisor(), dopts);
+ DeterminizeFstOptions<ToArc, CommonDivisor> dopts(copts, delta_);
+ // Uses acceptor-only constructor to avoid template recursion
+ DeterminizeFst<ToArc> det_fsa(to_fst, 0, 0, dopts);
// Mapper back to transducer.
FactorWeightOptions<ToArc> fopts(CacheOptions(true, 0), delta_,
@@ -832,7 +960,7 @@ struct DeterminizeOptions {
// Determinizes a weighted transducer. This version writes the
// determinized Fst to an output MutableFst. The result will be an
-// equivalent FSt that has the property that no state has two
+// equivalent FST that has the property that no state has two
// transitions with the same input label. For this algorithm, epsilon
// transitions are treated as regular symbols (cf. RmEpsilon).
//
@@ -866,7 +994,7 @@ void Determinize(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
if (ifst.Properties(kAcceptor, false)) {
vector<Weight> idistance, odistance;
ShortestDistance(ifst, &idistance, true);
- DeterminizeFst<Arc> dfst(ifst, idistance, &odistance, nopts);
+ DeterminizeFst<Arc> dfst(ifst, &idistance, &odistance, nopts);
PruneOptions< Arc, AnyArcFilter<Arc> > popts(opts.weight_threshold,
opts.state_threshold,
AnyArcFilter<Arc>(),
diff --git a/src/include/fst/edit-fst.h b/src/include/fst/edit-fst.h
index 303cb24..bd33b9d 100644
--- a/src/include/fst/edit-fst.h
+++ b/src/include/fst/edit-fst.h
@@ -25,6 +25,10 @@ using std::vector;
#include <fst/cache.h>
+#include <tr1/unordered_map>
+using std::tr1::unordered_map;
+using std::tr1::unordered_multimap;
+
namespace fst {
// The EditFst class enables non-destructive edit operations on a wrapped
@@ -431,7 +435,8 @@ class EditFstImpl : public FstImpl<A> {
// A copy constructor for this implementation class, used to implement
// the Copy() method of the Fst interface.
EditFstImpl(const EditFstImpl &impl)
- : wrapped_(static_cast<WrappedFstT *>(impl.wrapped_->Copy(true))),
+ : FstImpl<A>(),
+ wrapped_(static_cast<WrappedFstT *>(impl.wrapped_->Copy(true))),
data_(impl.data_) {
data_->IncrRefCount();
SetProperties(impl.Properties());
diff --git a/src/include/fst/epsnormalize.h b/src/include/fst/epsnormalize.h
index 8187737..9d178b1 100644
--- a/src/include/fst/epsnormalize.h
+++ b/src/include/fst/epsnormalize.h
@@ -24,7 +24,6 @@
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
-#include <fst/slist.h>
#include <fst/factor-weight.h>
diff --git a/src/include/fst/equivalent.h b/src/include/fst/equivalent.h
index 7f8708a..e28fea1 100644
--- a/src/include/fst/equivalent.h
+++ b/src/include/fst/equivalent.h
@@ -23,6 +23,7 @@
#include <algorithm>
#include <deque>
+using std::deque;
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
diff --git a/src/include/fst/extensions/far/extract.h b/src/include/fst/extensions/far/extract.h
index d6f92ff..95866de 100644
--- a/src/include/fst/extensions/far/extract.h
+++ b/src/include/fst/extensions/far/extract.h
@@ -32,51 +32,106 @@ using std::vector;
namespace fst {
template<class Arc>
+inline void FarWriteFst(const Fst<Arc>* fst, string key,
+ string* okey, int* nrep,
+ const int32 &generate_filenames, int i,
+ const string &filename_prefix,
+ const string &filename_suffix) {
+ if (key == *okey)
+ ++*nrep;
+ else
+ *nrep = 0;
+
+ *okey = key;
+
+ string ofilename;
+ if (generate_filenames) {
+ ostringstream tmp;
+ tmp.width(generate_filenames);
+ tmp.fill('0');
+ tmp << i;
+ ofilename = tmp.str();
+ } else {
+ if (*nrep > 0) {
+ ostringstream tmp;
+ tmp << '.' << nrep;
+ key.append(tmp.str().data(), tmp.str().size());
+ }
+ ofilename = key;
+ }
+ fst->Write(filename_prefix + ofilename + filename_suffix);
+}
+
+template<class Arc>
void FarExtract(const vector<string> &ifilenames,
const int32 &generate_filenames,
- const string &begin_key,
- const string &end_key,
+ const string &keys,
+ const string &key_separator,
+ const string &range_delimiter,
const string &filename_prefix,
const string &filename_suffix) {
FarReader<Arc> *far_reader = FarReader<Arc>::Open(ifilenames);
if (!far_reader) return;
- if (!begin_key.empty())
- far_reader->Find(begin_key);
-
string okey;
int nrep = 0;
- for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) {
- string key = far_reader->GetKey();
- if (!end_key.empty() && end_key < key)
- break;
- const Fst<Arc> &fst = far_reader->GetFst();
-
- if (key == okey)
- ++nrep;
- else
- nrep = 0;
- okey = key;
-
- string ofilename;
- if (generate_filenames) {
- ostringstream tmp;
- tmp.width(generate_filenames);
- tmp.fill('0');
- tmp << i;
- ofilename = tmp.str();
- } else {
- if (nrep > 0) {
- ostringstream tmp;
- tmp << '.' << nrep;
- key.append(tmp.str().data(), tmp.str().size());
+ vector<char *> key_vector;
+ // User has specified a set of fsts to extract, where some of the "fsts" could
+ // be ranges.
+ if (!keys.empty()) {
+ char *keys_cstr = new char[keys.size()+1];
+ strcpy(keys_cstr, keys.c_str());
+ SplitToVector(keys_cstr, key_separator.c_str(), &key_vector, true);
+ int i = 0;
+ for (int k = 0; k < key_vector.size(); ++k, ++i) {
+ string key = string(key_vector[k]);
+ char *key_cstr = new char[key.size()+1];
+ strcpy(key_cstr, key.c_str());
+ vector<char *> range_vector;
+ SplitToVector(key_cstr, range_delimiter.c_str(), &range_vector, false);
+ if (range_vector.size() == 1) { // Not a range
+ if (!far_reader->Find(key)) {
+ LOG(ERROR) << "FarExtract: Cannot find key: " << key;
+ return;
+ }
+ const Fst<Arc> &fst = far_reader->GetFst();
+ FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i,
+ filename_prefix, filename_suffix);
+ } else if (range_vector.size() == 2) { // A legal range
+ string begin_key = string(range_vector[0]);
+ string end_key = string(range_vector[1]);
+ if (begin_key.empty() || end_key.empty()) {
+ LOG(ERROR) << "FarExtract: Illegal range specification: " << key;
+ return;
+ }
+ if (!far_reader->Find(begin_key)) {
+ LOG(ERROR) << "FarExtract: Cannot find key: " << begin_key;
+ return;
+ }
+ for ( ; !far_reader->Done(); far_reader->Next(), ++i) {
+ string ikey = far_reader->GetKey();
+ if (end_key < ikey) break;
+ const Fst<Arc> &fst = far_reader->GetFst();
+ FarWriteFst(&fst, ikey, &okey, &nrep, generate_filenames, i,
+ filename_prefix, filename_suffix);
+ }
+ } else {
+ LOG(ERROR) << "FarExtract: Illegal range specification: " << key;
+ return;
}
- ofilename = key;
+ delete key_cstr;
}
- fst.Write(filename_prefix + ofilename + filename_suffix);
+ delete keys_cstr;
+ return;
+ }
+ // Nothing specified: extract everything.
+ for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) {
+ string key = far_reader->GetKey();
+ const Fst<Arc> &fst = far_reader->GetFst();
+ FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i,
+ filename_prefix, filename_suffix);
}
-
return;
}
diff --git a/src/include/fst/extensions/far/far.h b/src/include/fst/extensions/far/far.h
index acce76e..737f1b8 100644
--- a/src/include/fst/extensions/far/far.h
+++ b/src/include/fst/extensions/far/far.h
@@ -273,13 +273,10 @@ FarWriter<A> *FarWriter<A>::Create(const string &filename, FarType type) {
return STListFarWriter<A>::Create(filename);
case FAR_STTABLE:
return STTableFarWriter<A>::Create(filename);
- break;
case FAR_STLIST:
return STListFarWriter<A>::Create(filename);
- break;
case FAR_FST:
return FstFarWriter<A>::Create(filename);
- break;
default:
LOG(ERROR) << "FarWriter::Create: unknown far type";
return 0;
diff --git a/src/include/fst/extensions/far/farscript.h b/src/include/fst/extensions/far/farscript.h
index 3a9c145..cfd9167 100644
--- a/src/include/fst/extensions/far/farscript.h
+++ b/src/include/fst/extensions/far/farscript.h
@@ -173,18 +173,22 @@ bool FarEqual(const string &filename1,
typedef args::Package<const vector<string> &, int32,
const string&, const string&, const string&,
- const string&> FarExtractArgs;
+ const string&, const string&> FarExtractArgs;
template<class Arc>
void FarExtract(FarExtractArgs *args) {
fst::FarExtract<Arc>(
- args->arg1, args->arg2, args->arg3, args->arg4, args->arg5, args->arg6);
+ args->arg1, args->arg2, args->arg3, args->arg4, args->arg5, args->arg6,
+ args->arg7);
}
void FarExtract(const vector<string> &ifilenames,
const string &arc_type,
- int32 generate_filenames, const string &begin_key,
- const string &end_key, const string &filename_prefix,
+ int32 generate_filenames,
+ const string &keys,
+ const string &key_separator,
+ const string &range_delimiter,
+ const string &filename_prefix,
const string &filename_suffix);
typedef args::Package<const vector<string> &, const string &,
diff --git a/src/include/fst/extensions/far/stlist.h b/src/include/fst/extensions/far/stlist.h
index 1cdc80c..ff3d98b 100644
--- a/src/include/fst/extensions/far/stlist.h
+++ b/src/include/fst/extensions/far/stlist.h
@@ -145,13 +145,13 @@ class STListReader {
ReadType(*streams_[i], &magic_number);
ReadType(*streams_[i], &file_version);
if (magic_number != kSTListMagicNumber) {
- FSTERROR() << "STListReader::STTableReader: wrong file type: "
+ FSTERROR() << "STListReader::STListReader: wrong file type: "
<< filenames[i];
error_ = true;
return;
}
if (file_version != kSTListFileVersion) {
- FSTERROR() << "STListReader::STTableReader: wrong file version: "
+ FSTERROR() << "STListReader::STListReader: wrong file version: "
<< filenames[i];
error_ = true;
return;
@@ -161,7 +161,7 @@ class STListReader {
if (!key.empty())
heap_.push(make_pair(key, i));
if (!*streams_[i]) {
- FSTERROR() << "STTableReader: error reading file: " << sources_[i];
+ FSTERROR() << "STListReader: error reading file: " << sources_[i];
error_ = true;
return;
}
@@ -170,7 +170,7 @@ class STListReader {
size_t current = heap_.top().second;
entry_ = entry_reader_(*streams_[current]);
if (!entry_ || !*streams_[current]) {
- FSTERROR() << "STTableReader: error reading entry for key: "
+ FSTERROR() << "STListReader: error reading entry for key: "
<< heap_.top().first << ", file: " << sources_[current];
error_ = true;
}
@@ -219,7 +219,7 @@ class STListReader {
heap_.pop();
ReadType(*(streams_[current]), &key);
if (!*streams_[current]) {
- FSTERROR() << "STTableReader: error reading file: "
+ FSTERROR() << "STListReader: error reading file: "
<< sources_[current];
error_ = true;
return;
@@ -233,7 +233,7 @@ class STListReader {
delete entry_;
entry_ = entry_reader_(*streams_[current]);
if (!entry_ || !*streams_[current]) {
- FSTERROR() << "STTableReader: error reading entry for key: "
+ FSTERROR() << "STListReader: error reading entry for key: "
<< heap_.top().first << ", file: " << sources_[current];
error_ = true;
}
@@ -267,8 +267,8 @@ class STListReader {
// String-type list header reading function template on the entry header
// type 'H' having a member function:
// Read(istream &strm, const string &filename);
-// Checks that 'filename' is an STTable and call the H::Read() on the last
-// entry in the STTable.
+// Checks that 'filename' is an STList and call the H::Read() on the last
+// entry in the STList.
// Does not support reading from stdin.
template <class H>
bool ReadSTListHeader(const string &filename, H *header) {
@@ -281,18 +281,18 @@ bool ReadSTListHeader(const string &filename, H *header) {
ReadType(strm, &magic_number);
ReadType(strm, &file_version);
if (magic_number != kSTListMagicNumber) {
- LOG(ERROR) << "ReadSTTableHeader: wrong file type: " << filename;
+ LOG(ERROR) << "ReadSTListHeader: wrong file type: " << filename;
return false;
}
if (file_version != kSTListFileVersion) {
- LOG(ERROR) << "ReadSTTableHeader: wrong file version: " << filename;
+ LOG(ERROR) << "ReadSTListHeader: wrong file version: " << filename;
return false;
}
string key;
ReadType(strm, &key);
header->Read(strm, filename + ":" + key);
if (!strm) {
- LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename;
+ LOG(ERROR) << "ReadSTListHeader: error reading file: " << filename;
return false;
}
return true;
diff --git a/src/include/fst/extensions/ngram/ngram-fst.h b/src/include/fst/extensions/ngram/ngram-fst.h
index eee664a..873ae6a 100644
--- a/src/include/fst/extensions/ngram/ngram-fst.h
+++ b/src/include/fst/extensions/ngram/ngram-fst.h
@@ -26,6 +26,7 @@ using std::vector;
#include <fst/compat.h>
#include <fst/fstlib.h>
+#include <fst/mapped-file.h>
#include <fst/extensions/ngram/bitmap-index.h>
// NgramFst implements a n-gram language model based upon the LOUDS data
@@ -76,7 +77,7 @@ class NGramFstImpl : public FstImpl<A> {
typedef typename A::StateId StateId;
typedef typename A::Weight Weight;
- NGramFstImpl() : data_(0), owned_(false) {
+ NGramFstImpl() : data_region_(0), data_(0), owned_(false) {
SetType("ngram");
SetInputSymbols(NULL);
SetOutputSymbols(NULL);
@@ -89,6 +90,7 @@ class NGramFstImpl : public FstImpl<A> {
if (owned_) {
delete [] data_;
}
+ delete data_region_;
}
static NGramFstImpl<A>* Read(istream &strm, // NOLINT
@@ -104,7 +106,8 @@ class NGramFstImpl : public FstImpl<A> {
strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
size_t size = Storage(num_states, num_futures, num_final);
- char* data = new char[size];
+ MappedFile *data_region = MappedFile::Allocate(size);
+ char *data = reinterpret_cast<char *>(data_region->mutable_data());
// Copy num_states, num_futures and num_final back into data.
memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
@@ -116,7 +119,7 @@ class NGramFstImpl : public FstImpl<A> {
delete impl;
return NULL;
}
- impl->Init(data, true /* owned */);
+ impl->Init(data, false, data_region);
return impl;
}
@@ -126,7 +129,7 @@ class NGramFstImpl : public FstImpl<A> {
hdr.SetStart(Start());
hdr.SetNumStates(num_states_);
WriteHeader(strm, opts, kFileVersion, &hdr);
- strm.write(data_, Storage(num_states_, num_futures_, num_final_));
+ strm.write(data_, StorageSize());
return strm;
}
@@ -223,11 +226,23 @@ class NGramFstImpl : public FstImpl<A> {
// Access to the underlying representation
const char* GetData(size_t* data_size) const {
- *data_size = Storage(num_states_, num_futures_, num_final_);
+ *data_size = StorageSize();
return data_;
}
- void Init(const char* data, bool owned);
+ void Init(const char* data, bool owned, MappedFile *file = 0);
+
+ const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const {
+ SetInstFuture(s, inst);
+ SetInstContext(inst);
+ return inst->context_;
+ }
+
+ size_t StorageSize() const {
+ return Storage(num_states_, num_futures_, num_final_);
+ }
+
+ void GetStates(const vector<Label>& context, vector<StateId> *states) const;
private:
StateId Transition(const vector<Label> &context, Label future) const;
@@ -242,6 +257,7 @@ class NGramFstImpl : public FstImpl<A> {
// Minimum file format version supported.
static const int kMinFileVersion = 4;
+ MappedFile *data_region_;
const char* data_;
bool owned_; // True if we own data_
uint64 num_states_, num_futures_, num_final_;
@@ -261,7 +277,7 @@ class NGramFstImpl : public FstImpl<A> {
template<typename A>
NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
- : data_(0), owned_(false) {
+ : data_region_(0), data_(0), owned_(false) {
typedef A Arc;
typedef typename Arc::Label Label;
typedef typename Arc::Weight Weight;
@@ -286,12 +302,16 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
// epsilons.
StateId unigram = fst.Start();
while (1) {
- ArcIterator<Fst<A> > aiter(fst, unigram);
- if (aiter.Done()) {
- FSTERROR() << "Start state has no arcs";
+ if (unigram == kNoStateId) {
+ FSTERROR() << "Could not identify unigram state.";
SetProperties(kError, kError);
return;
}
+ ArcIterator<Fst<A> > aiter(fst, unigram);
+ if (aiter.Done()) {
+ LOG(WARNING) << "Unigram state " << unigram << " has no arcs.";
+ break;
+ }
if (aiter.Value().ilabel != 0) break;
unigram = aiter.Value().nextstate;
}
@@ -385,7 +405,8 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
Weight weight;
Label label = kNoLabel;
const size_t storage = Storage(num_states, num_futures, num_final);
- char* data = new char[storage];
+ MappedFile *data_region = MappedFile::Allocate(storage);
+ char *data = reinterpret_cast<char *>(data_region->mutable_data());
memset(data, 0, storage);
size_t offset = 0;
memcpy(data + offset, reinterpret_cast<char *>(&num_states),
@@ -482,14 +503,17 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
return;
}
- Init(data, true /* owned */);
+ Init(data, false, data_region);
}
template<typename A>
-inline void NGramFstImpl<A>::Init(const char* data, bool owned) {
+inline void NGramFstImpl<A>::Init(const char* data, bool owned,
+ MappedFile *data_region) {
if (owned_) {
delete [] data_;
}
+ delete data_region_;
+ data_region_ = data_region;
owned_ = owned;
data_ = data;
size_t offset = 0;
@@ -507,7 +531,7 @@ inline void NGramFstImpl<A>::Init(const char* data, bool owned) {
future_ = reinterpret_cast<const uint64*>(data_ + offset);
offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits);
final_ = reinterpret_cast<const uint64*>(data_ + offset);
- offset += BitmapIndex::StorageSize(num_states_ + 1) * sizeof(bits);
+ offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits);
context_words_ = reinterpret_cast<const Label*>(data_ + offset);
offset += (num_states_ + 1) * sizeof(*context_words_);
future_words_ = reinterpret_cast<const Label*>(data_ + offset);
@@ -538,10 +562,10 @@ inline void NGramFstImpl<A>::Init(const char* data, bool owned) {
template<typename A>
inline typename A::StateId NGramFstImpl<A>::Transition(
const vector<Label> &context, Label future) const {
- size_t num_children = root_num_children_;
const Label *children = root_children_;
- const Label *loc = lower_bound(children, children + num_children, future);
- if (loc == children + num_children || *loc != future) {
+ const Label *loc = lower_bound(children, children + root_num_children_,
+ future);
+ if (loc == children + root_num_children_ || *loc != future) {
return context_index_.Rank1(0);
}
size_t node = root_first_child_ + loc - children;
@@ -551,7 +575,6 @@ inline typename A::StateId NGramFstImpl<A>::Transition(
return context_index_.Rank1(node);
}
size_t last_child = context_index_.Select0(node_rank + 1) - 1;
- num_children = last_child - first_child + 1;
for (int word = context.size() - 1; word >= 0; --word) {
children = context_words_ + context_index_.Rank1(first_child);
loc = lower_bound(children, children + last_child - first_child + 1,
@@ -569,6 +592,42 @@ inline typename A::StateId NGramFstImpl<A>::Transition(
return context_index_.Rank1(node);
}
+template<typename A>
+inline void NGramFstImpl<A>::GetStates(
+ const vector<Label> &context,
+ vector<typename A::StateId>* states) const {
+ states->clear();
+ states->push_back(0);
+ typename vector<Label>::const_reverse_iterator cit = context.rbegin();
+ const Label *children = root_children_;
+ const Label *loc = lower_bound(children, children + root_num_children_, *cit);
+ if (loc == children + root_num_children_ || *loc != *cit) return;
+ size_t node = root_first_child_ + loc - children;
+ states->push_back(context_index_.Rank1(node));
+ if (context.size() == 1) return;
+ size_t node_rank = context_index_.Rank1(node);
+ size_t first_child = context_index_.Select0(node_rank) + 1;
+ ++cit;
+ if (context_index_.Get(first_child) != false) {
+ size_t last_child = context_index_.Select0(node_rank + 1) - 1;
+ while (cit != context.rend()) {
+ children = context_words_ + context_index_.Rank1(first_child);
+ loc = lower_bound(children, children + last_child - first_child + 1,
+ *cit);
+ if (loc == children + last_child - first_child + 1 || *loc != *cit) {
+ break;
+ }
+ ++cit;
+ node = first_child + loc - children;
+ states->push_back(context_index_.Rank1(node));
+ node_rank = context_index_.Rank1(node);
+ first_child = context_index_.Select0(node_rank) + 1;
+ if (context_index_.Get(first_child) == false) break;
+ last_child = context_index_.Select0(node_rank + 1) - 1;
+ }
+ }
+}
+
/*****************************************************************************/
template<class A>
class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
@@ -597,7 +656,7 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
// Non-standard constructor to initialize NGramFst directly from data.
NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) {
- GetImpl()->Init(data, owned);
+ GetImpl()->Init(data, owned, NULL);
}
// Get method that gets the data associated with Init().
@@ -605,6 +664,16 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
return GetImpl()->GetData(data_size);
}
+ const vector<Label> GetContext(StateId s) const {
+ return GetImpl()->GetContext(s, &inst_);
+ }
+
+ // Consumes as much as possible of context from right to left, returns the
+ // the states corresponding to the increasingly conditioned input sequence.
+ void GetStates(const vector<Label>& context, vector<StateId> *state) const {
+ return GetImpl()->GetStates(context, state);
+ }
+
virtual size_t NumArcs(StateId s) const {
return GetImpl()->NumArcs(s, &inst_);
}
@@ -650,6 +719,10 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
return new NGramFstMatcher<A>(*this, match_type);
}
+ size_t StorageSize() const {
+ return GetImpl()->StorageSize();
+ }
+
private:
explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {}
diff --git a/src/include/fst/extensions/pdt/compose.h b/src/include/fst/extensions/pdt/compose.h
index 364d76f..c856c6d 100644
--- a/src/include/fst/extensions/pdt/compose.h
+++ b/src/include/fst/extensions/pdt/compose.h
@@ -21,82 +21,469 @@
#ifndef FST_EXTENSIONS_PDT_COMPOSE_H__
#define FST_EXTENSIONS_PDT_COMPOSE_H__
+#include <list>
+
+#include <fst/extensions/pdt/pdt.h>
#include <fst/compose.h>
namespace fst {
+// Return paren arcs for Find(kNoLabel).
+const uint32 kParenList = 0x00000001;
+
+// Return a kNolabel loop for Find(paren).
+const uint32 kParenLoop = 0x00000002;
+
+// This class is a matcher that treats parens as multi-epsilon labels.
+// It is most efficient if the parens are in a range non-overlapping with
+// the non-paren labels.
+template <class F>
+class ParenMatcher {
+ public:
+ typedef SortedMatcher<F> M;
+ typedef typename M::FST FST;
+ typedef typename M::Arc Arc;
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Label Label;
+ typedef typename Arc::Weight Weight;
+
+ ParenMatcher(const FST &fst, MatchType match_type,
+ uint32 flags = (kParenLoop | kParenList))
+ : matcher_(fst, match_type),
+ match_type_(match_type),
+ flags_(flags) {
+ if (match_type == MATCH_INPUT) {
+ loop_.ilabel = kNoLabel;
+ loop_.olabel = 0;
+ } else {
+ loop_.ilabel = 0;
+ loop_.olabel = kNoLabel;
+ }
+ loop_.weight = Weight::One();
+ loop_.nextstate = kNoStateId;
+ }
+
+ ParenMatcher(const ParenMatcher<F> &matcher, bool safe = false)
+ : matcher_(matcher.matcher_, safe),
+ match_type_(matcher.match_type_),
+ flags_(matcher.flags_),
+ open_parens_(matcher.open_parens_),
+ close_parens_(matcher.close_parens_),
+ loop_(matcher.loop_) {
+ loop_.nextstate = kNoStateId;
+ }
+
+ ParenMatcher<F> *Copy(bool safe = false) const {
+ return new ParenMatcher<F>(*this, safe);
+ }
+
+ MatchType Type(bool test) const { return matcher_.Type(test); }
+
+ void SetState(StateId s) {
+ matcher_.SetState(s);
+ loop_.nextstate = s;
+ }
+
+ bool Find(Label match_label);
+
+ bool Done() const {
+ return done_;
+ }
+
+ const Arc& Value() const {
+ return paren_loop_ ? loop_ : matcher_.Value();
+ }
+
+ void Next();
+
+ const FST &GetFst() const { return matcher_.GetFst(); }
+
+ uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
+
+ uint32 Flags() const { return matcher_.Flags(); }
+
+ void AddOpenParen(Label label) {
+ if (label == 0) {
+ FSTERROR() << "ParenMatcher: Bad open paren label: 0";
+ } else {
+ open_parens_.Insert(label);
+ }
+ }
+
+ void AddCloseParen(Label label) {
+ if (label == 0) {
+ FSTERROR() << "ParenMatcher: Bad close paren label: 0";
+ } else {
+ close_parens_.Insert(label);
+ }
+ }
+
+ void RemoveOpenParen(Label label) {
+ if (label == 0) {
+ FSTERROR() << "ParenMatcher: Bad open paren label: 0";
+ } else {
+ open_parens_.Erase(label);
+ }
+ }
+
+ void RemoveCloseParen(Label label) {
+ if (label == 0) {
+ FSTERROR() << "ParenMatcher: Bad close paren label: 0";
+ } else {
+ close_parens_.Erase(label);
+ }
+ }
+
+ void ClearOpenParens() {
+ open_parens_.Clear();
+ }
+
+ void ClearCloseParens() {
+ close_parens_.Clear();
+ }
+
+ bool IsOpenParen(Label label) const {
+ return open_parens_.Member(label);
+ }
+
+ bool IsCloseParen(Label label) const {
+ return close_parens_.Member(label);
+ }
+
+ private:
+ // Advances matcher to next open paren if it exists, returning true.
+ // O.w. returns false.
+ bool NextOpenParen();
+
+ // Advances matcher to next open paren if it exists, returning true.
+ // O.w. returns false.
+ bool NextCloseParen();
+
+ M matcher_;
+ MatchType match_type_; // Type of match to perform
+ uint32 flags_;
+
+ // open paren label set
+ CompactSet<Label, kNoLabel> open_parens_;
+
+ // close paren label set
+ CompactSet<Label, kNoLabel> close_parens_;
+
+
+ bool open_paren_list_; // Matching open paren list
+ bool close_paren_list_; // Matching close paren list
+ bool paren_loop_; // Current arc is the implicit paren loop
+ mutable Arc loop_; // For non-consuming symbols
+ bool done_; // Matching done
+
+ void operator=(const ParenMatcher<F> &); // Disallow
+};
+
+template <class M> inline
+bool ParenMatcher<M>::Find(Label match_label) {
+ open_paren_list_ = false;
+ close_paren_list_ = false;
+ paren_loop_ = false;
+ done_ = false;
+
+ // Returns all parenthesis arcs
+ if (match_label == kNoLabel && (flags_ & kParenList)) {
+ if (open_parens_.LowerBound() != kNoLabel) {
+ matcher_.LowerBound(open_parens_.LowerBound());
+ open_paren_list_ = NextOpenParen();
+ if (open_paren_list_) return true;
+ }
+ if (close_parens_.LowerBound() != kNoLabel) {
+ matcher_.LowerBound(close_parens_.LowerBound());
+ close_paren_list_ = NextCloseParen();
+ if (close_paren_list_) return true;
+ }
+ }
+
+ // Returns 'implicit' paren loop
+ if (match_label > 0 && (flags_ & kParenLoop) &&
+ (IsOpenParen(match_label) || IsCloseParen(match_label))) {
+ paren_loop_ = true;
+ return true;
+ }
+
+ // Returns all other labels
+ if (matcher_.Find(match_label))
+ return true;
+
+ done_ = true;
+ return false;
+}
+
+template <class F> inline
+void ParenMatcher<F>::Next() {
+ if (paren_loop_) {
+ paren_loop_ = false;
+ done_ = true;
+ } else if (open_paren_list_) {
+ matcher_.Next();
+ open_paren_list_ = NextOpenParen();
+ if (open_paren_list_) return;
+
+ if (close_parens_.LowerBound() != kNoLabel) {
+ matcher_.LowerBound(close_parens_.LowerBound());
+ close_paren_list_ = NextCloseParen();
+ if (close_paren_list_) return;
+ }
+ done_ = !matcher_.Find(kNoLabel);
+ } else if (close_paren_list_) {
+ matcher_.Next();
+ close_paren_list_ = NextCloseParen();
+ if (close_paren_list_) return;
+ done_ = !matcher_.Find(kNoLabel);
+ } else {
+ matcher_.Next();
+ done_ = matcher_.Done();
+ }
+}
+
+// Advances matcher to next open paren if it exists, returning true.
+// O.w. returns false.
+template <class F> inline
+bool ParenMatcher<F>::NextOpenParen() {
+ for (; !matcher_.Done(); matcher_.Next()) {
+ Label label = match_type_ == MATCH_INPUT ?
+ matcher_.Value().ilabel : matcher_.Value().olabel;
+ if (label > open_parens_.UpperBound())
+ return false;
+ if (IsOpenParen(label))
+ return true;
+ }
+ return false;
+}
+
+// Advances matcher to next close paren if it exists, returning true.
+// O.w. returns false.
+template <class F> inline
+bool ParenMatcher<F>::NextCloseParen() {
+ for (; !matcher_.Done(); matcher_.Next()) {
+ Label label = match_type_ == MATCH_INPUT ?
+ matcher_.Value().ilabel : matcher_.Value().olabel;
+ if (label > close_parens_.UpperBound())
+ return false;
+ if (IsCloseParen(label))
+ return true;
+ }
+ return false;
+}
+
+
+template <class F>
+class ParenFilter {
+ public:
+ typedef typename F::FST1 FST1;
+ typedef typename F::FST2 FST2;
+ typedef typename F::Arc Arc;
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Label Label;
+ typedef typename Arc::Weight Weight;
+ typedef typename F::Matcher1 Matcher1;
+ typedef typename F::Matcher2 Matcher2;
+ typedef typename F::FilterState FilterState1;
+ typedef StateId StackId;
+ typedef PdtStack<StackId, Label> ParenStack;
+ typedef IntegerFilterState<StackId> FilterState2;
+ typedef PairFilterState<FilterState1, FilterState2> FilterState;
+ typedef ParenFilter<F> Filter;
+
+ ParenFilter(const FST1 &fst1, const FST2 &fst2,
+ Matcher1 *matcher1 = 0, Matcher2 *matcher2 = 0,
+ const vector<pair<Label, Label> > *parens = 0,
+ bool expand = false, bool keep_parens = true)
+ : filter_(fst1, fst2, matcher1, matcher2),
+ parens_(parens ? *parens : vector<pair<Label, Label> >()),
+ expand_(expand),
+ keep_parens_(keep_parens),
+ f_(FilterState::NoState()),
+ stack_(parens_),
+ paren_id_(-1) {
+ if (parens) {
+ for (size_t i = 0; i < parens->size(); ++i) {
+ const pair<Label, Label> &p = (*parens)[i];
+ parens_.push_back(p);
+ GetMatcher1()->AddOpenParen(p.first);
+ GetMatcher2()->AddOpenParen(p.first);
+ if (!expand_) {
+ GetMatcher1()->AddCloseParen(p.second);
+ GetMatcher2()->AddCloseParen(p.second);
+ }
+ }
+ }
+ }
+
+ ParenFilter(const Filter &filter, bool safe = false)
+ : filter_(filter.filter_, safe),
+ parens_(filter.parens_),
+ expand_(filter.expand_),
+ keep_parens_(filter.keep_parens_),
+ f_(FilterState::NoState()),
+ stack_(filter.parens_),
+ paren_id_(-1) { }
+
+ FilterState Start() const {
+ return FilterState(filter_.Start(), FilterState2(0));
+ }
+
+ void SetState(StateId s1, StateId s2, const FilterState &f) {
+ f_ = f;
+ filter_.SetState(s1, s2, f_.GetState1());
+ if (!expand_)
+ return;
+
+ ssize_t paren_id = stack_.Top(f.GetState2().GetState());
+ if (paren_id != paren_id_) {
+ if (paren_id_ != -1) {
+ GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second);
+ GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second);
+ }
+ paren_id_ = paren_id;
+ if (paren_id_ != -1) {
+ GetMatcher1()->AddCloseParen(parens_[paren_id_].second);
+ GetMatcher2()->AddCloseParen(parens_[paren_id_].second);
+ }
+ }
+ }
+
+ FilterState FilterArc(Arc *arc1, Arc *arc2) const {
+ FilterState1 f1 = filter_.FilterArc(arc1, arc2);
+ const FilterState2 &f2 = f_.GetState2();
+ if (f1 == FilterState1::NoState())
+ return FilterState::NoState();
+
+ if (arc1->olabel == kNoLabel && arc2->ilabel) { // arc2 parentheses
+ if (keep_parens_) {
+ arc1->ilabel = arc2->ilabel;
+ } else if (arc2->ilabel) {
+ arc2->olabel = arc1->ilabel;
+ }
+ return FilterParen(arc2->ilabel, f1, f2);
+ } else if (arc2->ilabel == kNoLabel && arc1->olabel) { // arc1 parentheses
+ if (keep_parens_) {
+ arc2->olabel = arc1->olabel;
+ } else {
+ arc1->ilabel = arc2->olabel;
+ }
+ return FilterParen(arc1->olabel, f1, f2);
+ } else {
+ return FilterState(f1, f2);
+ }
+ }
+
+ void FilterFinal(Weight *w1, Weight *w2) const {
+ if (f_.GetState2().GetState() != 0)
+ *w1 = Weight::Zero();
+ filter_.FilterFinal(w1, w2);
+ }
+
+ // Return resp matchers. Ownership stays with filter.
+ Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); }
+ Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); }
+
+ uint64 Properties(uint64 iprops) const {
+ uint64 oprops = filter_.Properties(iprops);
+ return oprops & kILabelInvariantProperties & kOLabelInvariantProperties;
+ }
+
+ private:
+ const FilterState FilterParen(Label label, const FilterState1 &f1,
+ const FilterState2 &f2) const {
+ if (!expand_)
+ return FilterState(f1, f2);
+
+ StackId stack_id = stack_.Find(f2.GetState(), label);
+ if (stack_id < 0) {
+ return FilterState::NoState();
+ } else {
+ return FilterState(f1, FilterState2(stack_id));
+ }
+ }
+
+ F filter_;
+ vector<pair<Label, Label> > parens_;
+ bool expand_; // Expands to FST
+ bool keep_parens_; // Retains parentheses in output
+ FilterState f_; // Current filter state
+ mutable ParenStack stack_;
+ ssize_t paren_id_;
+};
+
// Class to setup composition options for PDT composition.
// Default is for the PDT as the first composition argument.
template <class Arc, bool left_pdt = true>
-class PdtComposeOptions : public
+class PdtComposeFstOptions : public
ComposeFstOptions<Arc,
- MultiEpsMatcher< Matcher<Fst<Arc> > >,
- MultiEpsFilter<AltSequenceComposeFilter<
- MultiEpsMatcher<
- Matcher<Fst<Arc> > > > > > {
+ ParenMatcher< Fst<Arc> >,
+ ParenFilter<AltSequenceComposeFilter<
+ ParenMatcher< Fst<Arc> > > > > {
public:
typedef typename Arc::Label Label;
- typedef MultiEpsMatcher< Matcher<Fst<Arc> > > PdtMatcher;
- typedef MultiEpsFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter;
+ typedef ParenMatcher< Fst<Arc> > PdtMatcher;
+ typedef ParenFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter;
typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
using COptions::matcher1;
using COptions::matcher2;
using COptions::filter;
- PdtComposeOptions(const Fst<Arc> &ifst1,
+ PdtComposeFstOptions(const Fst<Arc> &ifst1,
const vector<pair<Label, Label> > &parens,
- const Fst<Arc> &ifst2) {
- matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsList);
- matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsLoop);
-
- // Treat parens as multi-epsilons when composing.
- for (size_t i = 0; i < parens.size(); ++i) {
- matcher1->AddMultiEpsLabel(parens[i].first);
- matcher1->AddMultiEpsLabel(parens[i].second);
- matcher2->AddMultiEpsLabel(parens[i].first);
- matcher2->AddMultiEpsLabel(parens[i].second);
- }
+ const Fst<Arc> &ifst2, bool expand = false,
+ bool keep_parens = true) {
+ matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenList);
+ matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenLoop);
- filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true);
+ filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
+ expand, keep_parens);
}
};
// Class to setup composition options for PDT with FST composition.
// Specialization is for the FST as the first composition argument.
template <class Arc>
-class PdtComposeOptions<Arc, false> : public
+class PdtComposeFstOptions<Arc, false> : public
ComposeFstOptions<Arc,
- MultiEpsMatcher< Matcher<Fst<Arc> > >,
- MultiEpsFilter<SequenceComposeFilter<
- MultiEpsMatcher<
- Matcher<Fst<Arc> > > > > > {
+ ParenMatcher< Fst<Arc> >,
+ ParenFilter<SequenceComposeFilter<
+ ParenMatcher< Fst<Arc> > > > > {
public:
typedef typename Arc::Label Label;
- typedef MultiEpsMatcher< Matcher<Fst<Arc> > > PdtMatcher;
- typedef MultiEpsFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter;
+ typedef ParenMatcher< Fst<Arc> > PdtMatcher;
+ typedef ParenFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter;
typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
using COptions::matcher1;
using COptions::matcher2;
using COptions::filter;
- PdtComposeOptions(const Fst<Arc> &ifst1,
- const Fst<Arc> &ifst2,
- const vector<pair<Label, Label> > &parens) {
- matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsLoop);
- matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsList);
-
- // Treat parens as multi-epsilons when composing.
- for (size_t i = 0; i < parens.size(); ++i) {
- matcher1->AddMultiEpsLabel(parens[i].first);
- matcher1->AddMultiEpsLabel(parens[i].second);
- matcher2->AddMultiEpsLabel(parens[i].first);
- matcher2->AddMultiEpsLabel(parens[i].second);
- }
+ PdtComposeFstOptions(const Fst<Arc> &ifst1,
+ const Fst<Arc> &ifst2,
+ const vector<pair<Label, Label> > &parens,
+ bool expand = false, bool keep_parens = true) {
+ matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenLoop);
+ matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenList);
- filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true);
+ filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
+ expand, keep_parens);
}
};
+enum PdtComposeFilter {
+ PAREN_FILTER, // Bar-Hillel construction; keeps parentheses
+ EXPAND_FILTER, // Bar-Hillel + expansion; removes parentheses
+ EXPAND_PAREN_FILTER, // Bar-Hillel + expansion; keeps parentheses
+};
+
+struct PdtComposeOptions {
+ bool connect; // Connect output
+ PdtComposeFilter filter_type; // Which pre-defined filter to use
+
+ explicit PdtComposeOptions(bool c, PdtComposeFilter ft = PAREN_FILTER)
+ : connect(c), filter_type(ft) {}
+ PdtComposeOptions() : connect(true), filter_type(PAREN_FILTER) {}
+};
// Composes pushdown transducer (PDT) encoded as an FST (1st arg) and
// an FST (2nd arg) with the result also a PDT encoded as an Fst. (3rd arg).
@@ -110,16 +497,17 @@ void Compose(const Fst<Arc> &ifst1,
typename Arc::Label> > &parens,
const Fst<Arc> &ifst2,
MutableFst<Arc> *ofst,
- const ComposeOptions &opts = ComposeOptions()) {
-
- PdtComposeOptions<Arc, true> copts(ifst1, parens, ifst2);
+ const PdtComposeOptions &opts = PdtComposeOptions()) {
+ bool expand = opts.filter_type != PAREN_FILTER;
+ bool keep_parens = opts.filter_type != EXPAND_FILTER;
+ PdtComposeFstOptions<Arc, true> copts(ifst1, parens, ifst2,
+ expand, keep_parens);
copts.gc_limit = 0;
*ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
if (opts.connect)
Connect(ofst);
}
-
// Composes an FST (1st arg) and pushdown transducer (PDT) encoded as
// an FST (2nd arg) with the result also a PDT encoded as an Fst (3rd arg).
// In the PDTs, some transitions are labeled with open or close
@@ -132,9 +520,11 @@ void Compose(const Fst<Arc> &ifst1,
const vector<pair<typename Arc::Label,
typename Arc::Label> > &parens,
MutableFst<Arc> *ofst,
- const ComposeOptions &opts = ComposeOptions()) {
-
- PdtComposeOptions<Arc, false> copts(ifst1, ifst2, parens);
+ const PdtComposeOptions &opts = PdtComposeOptions()) {
+ bool expand = opts.filter_type != PAREN_FILTER;
+ bool keep_parens = opts.filter_type != EXPAND_FILTER;
+ PdtComposeFstOptions<Arc, false> copts(ifst1, ifst2, parens,
+ expand, keep_parens);
copts.gc_limit = 0;
*ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
if (opts.connect)
diff --git a/src/include/fst/extensions/pdt/pdt.h b/src/include/fst/extensions/pdt/pdt.h
index 6649f55..c56afbd 100644
--- a/src/include/fst/extensions/pdt/pdt.h
+++ b/src/include/fst/extensions/pdt/pdt.h
@@ -27,6 +27,7 @@ using std::tr1::unordered_multimap;
#include <map>
#include <set>
+#include <fst/compat.h>
#include <fst/state-table.h>
#include <fst/fst.h>
diff --git a/src/include/fst/extensions/pdt/pdtscript.h b/src/include/fst/extensions/pdt/pdtscript.h
index c2a1cf4..84bb27e 100644
--- a/src/include/fst/extensions/pdt/pdtscript.h
+++ b/src/include/fst/extensions/pdt/pdtscript.h
@@ -48,7 +48,7 @@ typedef args::Package<const FstClass &,
const FstClass &,
const vector<pair<int64, int64> >&,
MutableFstClass *,
- const ComposeOptions &,
+ const PdtComposeOptions &,
bool> PdtComposeArgs;
template<class Arc>
@@ -76,7 +76,7 @@ void PdtCompose(const FstClass & ifst1,
const FstClass & ifst2,
const vector<pair<int64, int64> > &parens,
MutableFstClass *ofst,
- const ComposeOptions &copts,
+ const PdtComposeOptions &copts,
bool left_pdt);
// PDT EXPAND
diff --git a/src/include/fst/extensions/pdt/replace.h b/src/include/fst/extensions/pdt/replace.h
index a85d0fe..9081400 100644
--- a/src/include/fst/extensions/pdt/replace.h
+++ b/src/include/fst/extensions/pdt/replace.h
@@ -21,6 +21,10 @@
#ifndef FST_EXTENSIONS_PDT_REPLACE_H__
#define FST_EXTENSIONS_PDT_REPLACE_H__
+#include <tr1/unordered_map>
+using std::tr1::unordered_map;
+using std::tr1::unordered_multimap;
+
#include <fst/replace.h>
namespace fst {
@@ -62,11 +66,14 @@ void Replace(const vector<pair<typename Arc::Label,
label2id[ifst_array[i].first] = i;
Label max_label = kNoLabel;
+ size_t max_non_term_count = 0;
- deque<size_t> non_term_queue; // Queue of non-terminals to replace
- unordered_set<Label> non_term_set; // Set of non-terminals to replace
+ // Queue of non-terminals to replace
+ deque<size_t> non_term_queue;
+ // Map of non-terminals to replace to count
+ unordered_map<Label, size_t> non_term_map;
non_term_queue.push_back(root);
- non_term_set.insert(root);
+ non_term_map[root] = 1;;
// PDT state corr. to ith replace FST start state.
vector<StateId> fst_start(ifst_array.size(), kNoLabel);
@@ -107,10 +114,11 @@ void Replace(const vector<pair<typename Arc::Label,
size_t nfst_id = it->second;
if (ifst_array[nfst_id].second->Start() == -1)
continue;
- if (non_term_set.count(arc.olabel) == 0) {
+ size_t count = non_term_map[arc.olabel]++;
+ if (count == 0)
non_term_queue.push_back(arc.olabel);
- non_term_set.insert(arc.olabel);
- }
+ if (count > max_non_term_count)
+ max_non_term_count = count;
}
arc.nextstate += soff;
ofst->AddArc(os, arc);
@@ -134,7 +142,8 @@ void Replace(const vector<pair<typename Arc::Label,
// # of parenthesis pairs per fst.
vector<size_t> nparens(ifst_array.size(), 0);
// Initial open parenthesis label
- Label first_paren = max_label + 1;
+ Label first_open_paren = max_label + 1;
+ Label first_close_paren = max_label + max_non_term_count + 1;
for (StateIterator< Fst<Arc> > siter(*ofst);
!siter.Done(); siter.Next()) {
@@ -158,8 +167,8 @@ void Replace(const vector<pair<typename Arc::Label,
close_paren = (*parens)[paren_id].second;
} else {
size_t paren_id = nparens[nfst_id]++;
- open_paren = first_paren + 2 * paren_id;
- close_paren = open_paren + 1;
+ open_paren = first_open_paren + paren_id;
+ close_paren = first_close_paren + paren_id;
paren_map[paren_key] = paren_id;
if (paren_id >= parens->size())
parens->push_back(make_pair(open_paren, close_paren));
diff --git a/src/include/fst/factor-weight.h b/src/include/fst/factor-weight.h
index 97440e1..685155c 100644
--- a/src/include/fst/factor-weight.h
+++ b/src/include/fst/factor-weight.h
@@ -25,7 +25,6 @@
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
-#include <fst/slist.h>
#include <string>
#include <utility>
using std::pair; using std::make_pair;
diff --git a/src/include/fst/fst-decl.h b/src/include/fst/fst-decl.h
index 0e2cdf1..f27ded8 100644
--- a/src/include/fst/fst-decl.h
+++ b/src/include/fst/fst-decl.h
@@ -56,7 +56,6 @@ template <class A> class ClosureFst;
template <class A> class ComposeFst;
template <class A> class ConcatFst;
template <class A> class DeterminizeFst;
-template <class A> class DeterminizeFst;
template <class A> class DifferenceFst;
template <class A> class IntersectFst;
template <class A> class InvertFst;
diff --git a/src/include/fst/fst.h b/src/include/fst/fst.h
index dd11e4f..150fc4e 100644
--- a/src/include/fst/fst.h
+++ b/src/include/fst/fst.h
@@ -53,6 +53,11 @@ template <class A> class ArcIteratorData;
template <class A> class MatcherBase;
struct FstReadOptions {
+ // FileReadMode(s) are advisory, there are many conditions than prevent a
+ // file from being mapped, READ mode will be selected in these cases with
+ // a warning indicating why it was chosen.
+ enum FileReadMode { READ, MAP };
+
string source; // Where you're reading from
const FstHeader *header; // Pointer to Fst header. If non-zero, use
// this info (don't read a stream header)
@@ -60,19 +65,20 @@ struct FstReadOptions {
// this info (read and skip stream isymbols)
const SymbolTable* osymbols; // Pointer to output symbols. If non-zero, use
// this info (read and skip stream osymbols)
+ FileReadMode mode; // Read or map files (advisory, if possible)
- explicit FstReadOptions(const string& src = "<unspecfied>",
+ explicit FstReadOptions(const string& src = "<unspecified>",
const FstHeader *hdr = 0,
const SymbolTable* isym = 0,
- const SymbolTable* osym = 0)
- : source(src), header(hdr), isymbols(isym), osymbols(osym) {}
+ const SymbolTable* osym = 0);
explicit FstReadOptions(const string& src,
const SymbolTable* isym,
- const SymbolTable* osym = 0)
- : source(src), header(0), isymbols(isym), osymbols(osym) {}
-};
+ const SymbolTable* osym = 0);
+ // Helper function to convert strings FileReadModes into their enum value.
+ static FileReadMode ReadMode(const string &mode);
+};
struct FstWriteOptions {
string source; // Where you're writing to
diff --git a/src/include/fst/interval-set.h b/src/include/fst/interval-set.h
index cf6ac54..c4362f2 100644
--- a/src/include/fst/interval-set.h
+++ b/src/include/fst/interval-set.h
@@ -81,12 +81,12 @@ class IntervalSet {
const vector<Interval> *Intervals() const { return &intervals_; }
- const bool Empty() const { return intervals_.empty(); }
+ bool Empty() const { return intervals_.empty(); }
- const T Size() const { return intervals_.size(); }
+ T Size() const { return intervals_.size(); }
// Number of points in the intervals (undefined if not normalized).
- const T Count() const { return count_; }
+ T Count() const { return count_; }
void Clear() {
intervals_.clear();
diff --git a/src/include/fst/mapped-file.h b/src/include/fst/mapped-file.h
new file mode 100644
index 0000000..d61bc14
--- /dev/null
+++ b/src/include/fst/mapped-file.h
@@ -0,0 +1,83 @@
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: sorenj@google.com (Jeffrey Sorensen)
+
+#ifndef FST_LIB_MAPPED_FILE_H_
+#define FST_LIB_MAPPED_FILE_H_
+
+#include <unistd.h>
+#include <sys/mman.h>
+
+#include <fst/fst.h>
+#include <iostream>
+#include <fstream>
+#include <sstream>
+
+DECLARE_int32(fst_arch_alignment); // defined in mapped-file.h
+
+namespace fst {
+
+// A memory region is a simple abstraction for allocated memory or data from
+// mmap'ed files. If mmap equals NULL, then data represents an owned region of
+// size bytes. Otherwise, mmap and size refer to the mapping and data is a
+// casted pointer to a region contained within [mmap, mmap + size).
+// If size is 0, then mmap refers and data refer to a block of memory managed
+// externally by some other allocator.
+struct MemoryRegion {
+ void *data;
+ void *mmap;
+ size_t size;
+};
+
+class MappedFile {
+ public:
+ virtual ~MappedFile();
+
+ void* mutable_data() const {
+ return reinterpret_cast<void*>(region_.data);
+ }
+
+ const void* data() const {
+ return reinterpret_cast<void*>(region_.data);
+ }
+
+ // Returns a MappedFile object that contains the contents of the input
+ // stream s starting from the current file position with size bytes.
+ // The file name must also be provided in the FstReadOptions as opts.source
+ // or else mapping will fail. If mapping is not possible, then a MappedFile
+ // object with a new[]'ed block of memory will be created.
+ static MappedFile* Map(istream* s, const FstReadOptions& opts, size_t size);
+
+ // Creates a MappedFile object with a new[]'ed block of memory of size.
+ // RECOMMENDED FOR INTERNAL USE ONLY, may change in future releases.
+ static MappedFile* Allocate(size_t size);
+
+ // Creates a MappedFile object pointing to a borrowed reference to data.
+ // This block of memory is not owned by the MappedFile object and will not
+ // be freed.
+ // RECOMMENDED FOR INTERNAL USE ONLY, may change in future releases.
+ static MappedFile* Borrow(void *data);
+
+ static const int kArchAlignment;
+
+ private:
+ explicit MappedFile(const MemoryRegion &region);
+
+ MemoryRegion region_;
+ DISALLOW_COPY_AND_ASSIGN(MappedFile);
+};
+} // namespace fst
+
+#endif // FST_LIB_MAPPED_FILE_H_
diff --git a/src/include/fst/matcher.h b/src/include/fst/matcher.h
index 5ab3d26..89ed9be 100644
--- a/src/include/fst/matcher.h
+++ b/src/include/fst/matcher.h
@@ -75,6 +75,8 @@ namespace fst {
// bool Done() const; // No more matches.
// const A& Value() const; // Current arc (when !Done)
// void Next(); // Advance to next arc (when !Done)
+// // Initially and after SetState() the iterator methods
+// // have undefined behavior until Find() is called.
//
// // Return matcher FST.
// const F& GetFst() const;
@@ -223,13 +225,44 @@ class SortedMatcher : public MatcherBase<typename F::Arc> {
loop_.nextstate = s;
}
- bool Find(Label match_label);
+ bool Find(Label match_label) {
+ exact_match_ = true;
+ if (error_) {
+ current_loop_ = false;
+ match_label_ = kNoLabel;
+ return false;
+ }
+ current_loop_ = match_label == 0;
+ match_label_ = match_label == kNoLabel ? 0 : match_label;
+ if (Search()) {
+ return true;
+ } else {
+ return current_loop_;
+ }
+ }
+
+ // Positions matcher to the first position where inserting
+ // match_label would maintain the sort order.
+ void LowerBound(Label match_label) {
+ exact_match_ = false;
+ current_loop_ = false;
+ if (error_) {
+ match_label_ = kNoLabel;
+ return;
+ }
+ match_label_ = match_label;
+ Search();
+ }
+ // After Find(), returns false if no more exact matches.
+ // After LowerBound(), returns false if no more arcs.
bool Done() const {
if (current_loop_)
return false;
if (aiter_->Done())
return true;
+ if (!exact_match_)
+ return false;
aiter_->SetFlags(
match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
kArcValueFlags);
@@ -261,6 +294,8 @@ class SortedMatcher : public MatcherBase<typename F::Arc> {
return outprops;
}
+ size_t Position() const { return aiter_ ? aiter_->Position() : 0; }
+
private:
virtual void SetState_(StateId s) { SetState(s); }
virtual bool Find_(Label label) { return Find(label); }
@@ -268,6 +303,8 @@ class SortedMatcher : public MatcherBase<typename F::Arc> {
virtual const Arc& Value_() const { return Value(); }
virtual void Next_() { Next(); }
+ bool Search();
+
const F *fst_;
StateId s_; // Current state
ArcIterator<F> *aiter_; // Iterator for current state
@@ -277,20 +314,16 @@ class SortedMatcher : public MatcherBase<typename F::Arc> {
size_t narcs_; // Current state arc count
Arc loop_; // For non-consuming symbols
bool current_loop_; // Current arc is the implicit loop
+ bool exact_match_; // Exact match or lower bound?
bool error_; // Error encountered
void operator=(const SortedMatcher<F> &); // Disallow
};
+// Returns true iff match to match_label_. Positions arc iterator at
+// lower bound regardless.
template <class F> inline
-bool SortedMatcher<F>::Find(Label match_label) {
- if (error_) {
- current_loop_ = false;
- match_label_ = kNoLabel;
- return false;
- }
- current_loop_ = match_label == 0;
- match_label_ = match_label == kNoLabel ? 0 : match_label;
+bool SortedMatcher<F>::Search() {
aiter_->SetFlags(
match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
kArcValueFlags);
@@ -321,7 +354,8 @@ bool SortedMatcher<F>::Find(Label match_label) {
return true;
}
}
- return current_loop_;
+ aiter_->Seek(low);
+ return false;
} else {
// Linear search for match.
for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) {
@@ -333,7 +367,7 @@ bool SortedMatcher<F>::Find(Label match_label) {
if (label > match_label_)
break;
}
- return current_loop_;
+ return false;
}
}
@@ -1047,16 +1081,19 @@ class MultiEpsMatcher {
}
}
+ void RemoveMultiEpsLabel(Label label) {
+ if (label == 0) {
+ FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
+ } else {
+ multi_eps_labels_.Erase(label);
+ }
+ }
+
void ClearMultiEpsLabels() {
multi_eps_labels_.Clear();
}
private:
- // Specialized for 'set' - log lookup
- bool IsMultiEps(const set<Label> &multi_eps_labels, Label label) const {
- return multi_eps_labels.Find(label) != multi_eps_labels.end();
- }
-
M *matcher_;
uint32 flags_;
bool own_matcher_; // Does this class delete the matcher?
diff --git a/src/include/fst/queue.h b/src/include/fst/queue.h
index e31f087..95a082d 100644
--- a/src/include/fst/queue.h
+++ b/src/include/fst/queue.h
@@ -23,6 +23,7 @@
#define FST_LIB_QUEUE_H__
#include <deque>
+using std::deque;
#include <vector>
using std::vector;
@@ -791,13 +792,13 @@ struct TrivialStateEquivClass {
};
-// Pruning queue discipline: Enqueues a state 's' only when its
-// shortest distance (so far), as specified by 'distance', is less
-// than (as specified by 'comp') the shortest distance Times() the
-// 'threshold' to any state in the same equivalence class, as
-// specified by the function object 'class_func'. The underlying
-// queue discipline is specified by 'queue'. The ownership of 'queue'
-// is given to this class.
+// Distance-based pruning queue discipline: Enqueues a state 's'
+// only when its shortest distance (so far), as specified by
+// 'distance', is less than (as specified by 'comp') the shortest
+// distance Times() the 'threshold' to any state in the same
+// equivalence class, as specified by the function object
+// 'class_func'. The underlying queue discipline is specified by
+// 'queue'. The ownership of 'queue' is given to this class.
template <typename Q, typename L, typename C>
class PruneQueue : public QueueBase<typename Q::StateId> {
public:
@@ -884,6 +885,54 @@ class NaturalPruneQueue :
};
+// Filter-based pruning queue discipline: Enqueues a state 's' only
+// if allowed by the filter, specified by the function object 'state_filter'.
+// The underlying queue discipline is specified by 'queue'. The ownership
+// of 'queue' is given to this class.
+template <typename Q, typename F>
+class FilterQueue : public QueueBase<typename Q::StateId> {
+ public:
+ typedef typename Q::StateId StateId;
+
+ FilterQueue(Q *queue, const F &state_filter)
+ : QueueBase<StateId>(OTHER_QUEUE),
+ queue_(queue),
+ state_filter_(state_filter) {}
+
+ ~FilterQueue() { delete queue_; }
+
+ StateId Head() const { return queue_->Head(); }
+
+ // Enqueues only if allowed by state filter.
+ void Enqueue(StateId s) {
+ if (state_filter_(s)) {
+ queue_->Enqueue(s);
+ }
+ }
+
+ void Dequeue() { queue_->Dequeue(); }
+
+ void Update(StateId s) {}
+ bool Empty() const { return queue_->Empty(); }
+ void Clear() { queue_->Clear(); }
+
+ private:
+ // This allows base-class virtual access to non-virtual derived-
+ // class members of the same name. It makes the derived class more
+ // efficient to use but unsafe to further derive.
+ virtual StateId Head_() const { return Head(); }
+ virtual void Enqueue_(StateId s) { Enqueue(s); }
+ virtual void Dequeue_() { Dequeue(); }
+ virtual void Update_(StateId s) { Update(s); }
+ virtual bool Empty_() const { return Empty(); }
+ virtual void Clear_() { return Clear(); }
+
+ Q *queue_;
+ const F &state_filter_; // Filter to prune states
+
+ DISALLOW_COPY_AND_ASSIGN(FilterQueue);
+};
+
} // namespace fst
#endif
diff --git a/src/include/fst/relabel.h b/src/include/fst/relabel.h
index 685d42a..dc675b6 100644
--- a/src/include/fst/relabel.h
+++ b/src/include/fst/relabel.h
@@ -34,6 +34,10 @@ using std::vector;
#include <fst/test-properties.h>
+#include <tr1/unordered_map>
+using std::tr1::unordered_map;
+using std::tr1::unordered_multimap;
+
namespace fst {
//
diff --git a/src/include/fst/script/convert.h b/src/include/fst/script/convert.h
index 2c70a70..4a3ce6b 100644
--- a/src/include/fst/script/convert.h
+++ b/src/include/fst/script/convert.h
@@ -34,7 +34,7 @@ void Convert(ConvertArgs *args) {
const string &new_type = args->args.arg2;
Fst<Arc> *result = Convert(fst, new_type);
- args->retval = new FstClass(result);
+ args->retval = new FstClass(*result);
delete result;
}
diff --git a/src/include/fst/script/disambiguate.h b/src/include/fst/script/disambiguate.h
new file mode 100644
index 0000000..e42a9c2
--- /dev/null
+++ b/src/include/fst/script/disambiguate.h
@@ -0,0 +1,68 @@
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Jake Ratkiewicz)
+
+#ifndef FST_SCRIPT_DISAMBIGUATE_H_
+#define FST_SCRIPT_DISAMBIGUATE_H_
+
+#include <fst/disambiguate.h>
+#include <fst/script/arg-packs.h>
+#include <fst/script/fst-class.h>
+#include <fst/script/weight-class.h>
+
+namespace fst {
+namespace script {
+
+struct DisambiguateOptions {
+ float delta;
+ WeightClass weight_threshold;
+ int64 state_threshold;
+ int64 subsequential_label;
+
+ explicit DisambiguateOptions(float d = fst::kDelta,
+ WeightClass w =
+ fst::script::WeightClass::Zero(),
+ int64 n = fst::kNoStateId, int64 l = 0)
+ : delta(d), weight_threshold(w), state_threshold(n),
+ subsequential_label(l) {}
+};
+
+typedef args::Package<const FstClass&, MutableFstClass*,
+ const DisambiguateOptions &> DisambiguateArgs;
+
+template<class Arc>
+void Disambiguate(DisambiguateArgs *args) {
+ const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>());
+ MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>();
+ const DisambiguateOptions &opts = args->arg3;
+
+ fst::DisambiguateOptions<Arc> detargs;
+ detargs.delta = opts.delta;
+ detargs.weight_threshold =
+ *(opts.weight_threshold.GetWeight<typename Arc::Weight>());
+ detargs.state_threshold = opts.state_threshold;
+ detargs.subsequential_label = opts.subsequential_label;
+
+ Disambiguate(ifst, ofst, detargs);
+}
+
+void Disambiguate(const FstClass &ifst, MutableFstClass *ofst,
+ const DisambiguateOptions &opts =
+ fst::script::DisambiguateOptions());
+
+} // namespace script
+} // namespace fst
+
+#endif // FST_SCRIPT_DISAMBIGUATE_H_
diff --git a/src/include/fst/script/fst-class.h b/src/include/fst/script/fst-class.h
index a820c1c..fe2cf53 100644
--- a/src/include/fst/script/fst-class.h
+++ b/src/include/fst/script/fst-class.h
@@ -52,8 +52,8 @@ class FstClassBase {
virtual const string &WeightType() const = 0;
virtual const SymbolTable *InputSymbols() const = 0;
virtual const SymbolTable *OutputSymbols() const = 0;
- virtual void Write(const string& fname) const = 0;
- virtual void Write(ostream &ostr, const FstWriteOptions &opts) const = 0;
+ virtual bool Write(const string& fname) const = 0;
+ virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const = 0;
virtual uint64 Properties(uint64 mask, bool test) const = 0;
virtual ~FstClassBase() { }
};
@@ -82,6 +82,8 @@ class FstClassImpl : public FstClassImplBase {
bool should_own = false) :
impl_(should_own ? impl : impl->Copy()) { }
+ explicit FstClassImpl(const Fst<Arc> &impl) : impl_(impl.Copy()) { }
+
virtual const string &ArcType() const {
return Arc::Type();
}
@@ -112,12 +114,12 @@ class FstClassImpl : public FstClassImplBase {
static_cast<MutableFst<Arc> *>(impl_)->SetOutputSymbols(os);
}
- virtual void Write(const string &fname) const {
- impl_->Write(fname);
+ virtual bool Write(const string &fname) const {
+ return impl_->Write(fname);
}
- virtual void Write(ostream &ostr, const FstWriteOptions &opts) const {
- impl_->Write(ostr, opts);
+ virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const {
+ return impl_->Write(ostr, opts);
}
virtual uint64 Properties(uint64 mask, bool test) const {
@@ -166,10 +168,10 @@ class FstClass : public FstClassBase {
}
template<class Arc>
- explicit FstClass(Fst<Arc> *fst) : impl_(new FstClassImpl<Arc>(fst)) {
+ explicit FstClass(const Fst<Arc> &fst) : impl_(new FstClassImpl<Arc>(fst)) {
}
- explicit FstClass(const FstClass &other) : impl_(other.impl_->Copy()) { }
+ FstClass(const FstClass &other) : impl_(other.impl_->Copy()) { }
FstClass &operator=(const FstClass &other) {
delete impl_;
@@ -201,12 +203,12 @@ class FstClass : public FstClassBase {
return impl_->WeightType();
}
- virtual void Write(const string &fname) const {
- impl_->Write(fname);
+ virtual bool Write(const string &fname) const {
+ return impl_->Write(fname);
}
- virtual void Write(ostream &ostr, const FstWriteOptions &opts) const {
- impl_->Write(ostr, opts);
+ virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const {
+ return impl_->Write(ostr, opts);
}
virtual uint64 Properties(uint64 mask, bool test) const {
@@ -253,7 +255,7 @@ class FstClass : public FstClassBase {
if (!u) {
return 0;
} else {
- FstClassT *r = new FstClassT(u);
+ FstClassT *r = new FstClassT(*u);
delete u;
return r;
}
@@ -276,7 +278,7 @@ class FstClass : public FstClassBase {
class MutableFstClass : public FstClass {
public:
template<class Arc>
- explicit MutableFstClass(MutableFst<Arc> *fst) :
+ explicit MutableFstClass(const MutableFst<Arc> &fst) :
FstClass(fst) { }
template<class Arc>
@@ -294,18 +296,18 @@ class MutableFstClass : public FstClass {
if (!mfst) {
return 0;
} else {
- MutableFstClass *retval = new MutableFstClass(mfst);
+ MutableFstClass *retval = new MutableFstClass(*mfst);
delete mfst;
return retval;
}
}
- virtual void Write(const string &fname) const {
- GetImpl()->Write(fname);
+ virtual bool Write(const string &fname) const {
+ return GetImpl()->Write(fname);
}
- virtual void Write(ostream &ostr, const FstWriteOptions &opts) const {
- GetImpl()->Write(ostr, opts);
+ virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const {
+ return GetImpl()->Write(ostr, opts);
}
static MutableFstClass *Read(const string &fname, bool convert = false);
@@ -344,7 +346,7 @@ class VectorFstClass : public MutableFstClass {
explicit VectorFstClass(const string &arc_type);
template<class Arc>
- explicit VectorFstClass(VectorFst<Arc> *fst) :
+ explicit VectorFstClass(const VectorFst<Arc> &fst) :
MutableFstClass(fst) { }
template<class Arc>
@@ -354,7 +356,7 @@ class VectorFstClass : public MutableFstClass {
if (!vfst) {
return 0;
} else {
- VectorFstClass *retval = new VectorFstClass(vfst);
+ VectorFstClass *retval = new VectorFstClass(*vfst);
delete vfst;
return retval;
}
diff --git a/src/include/fst/script/map.h b/src/include/fst/script/map.h
index 2332074..3caaa9f 100644
--- a/src/include/fst/script/map.h
+++ b/src/include/fst/script/map.h
@@ -59,46 +59,54 @@ void Map(MapArgs *args) {
float delta = args->args.arg3;
typename Arc::Weight w = *(args->args.arg4.GetWeight<typename Arc::Weight>());
+ Fst<Arc> *fst = NULL;
+ Fst<LogArc> *lfst = NULL;
+ Fst<Log64Arc> *l64fst = NULL;
+ Fst<StdArc> *sfst = NULL;
if (map_type == ARC_SUM_MAPPER) {
- args->retval = new FstClass(
- script::StateMap(ifst, ArcSumMapper<Arc>(ifst)));
+ args->retval = new FstClass(*(fst =
+ script::StateMap(ifst, ArcSumMapper<Arc>(ifst))));
} else if (map_type == IDENTITY_MAPPER) {
- args->retval = new FstClass(
- script::ArcMap(ifst, IdentityArcMapper<Arc>()));
+ args->retval = new FstClass(*(fst =
+ script::ArcMap(ifst, IdentityArcMapper<Arc>())));
} else if (map_type == INVERT_MAPPER) {
- args->retval = new FstClass(
- script::ArcMap(ifst, InvertWeightMapper<Arc>()));
+ args->retval = new FstClass(*(fst =
+ script::ArcMap(ifst, InvertWeightMapper<Arc>())));
} else if (map_type == PLUS_MAPPER) {
- args->retval = new FstClass(
- script::ArcMap(ifst, PlusMapper<Arc>(w)));
+ args->retval = new FstClass(*(fst =
+ script::ArcMap(ifst, PlusMapper<Arc>(w))));
} else if (map_type == QUANTIZE_MAPPER) {
- args->retval = new FstClass(
- script::ArcMap(ifst, QuantizeMapper<Arc>(delta)));
+ args->retval = new FstClass(*(fst =
+ script::ArcMap(ifst, QuantizeMapper<Arc>(delta))));
} else if (map_type == RMWEIGHT_MAPPER) {
- args->retval = new FstClass(
- script::ArcMap(ifst, RmWeightMapper<Arc>()));
+ args->retval = new FstClass(*(fst =
+ script::ArcMap(ifst, RmWeightMapper<Arc>())));
} else if (map_type == SUPERFINAL_MAPPER) {
- args->retval = new FstClass(
- script::ArcMap(ifst, SuperFinalMapper<Arc>()));
+ args->retval = new FstClass(*(fst =
+ script::ArcMap(ifst, SuperFinalMapper<Arc>())));
} else if (map_type == TIMES_MAPPER) {
- args->retval = new FstClass(
- script::ArcMap(ifst, TimesMapper<Arc>(w)));
+ args->retval = new FstClass(*(fst =
+ script::ArcMap(ifst, TimesMapper<Arc>(w))));
} else if (map_type == TO_LOG_MAPPER) {
- args->retval = new FstClass(
- script::ArcMap(ifst, WeightConvertMapper<Arc, LogArc>()));
+ args->retval = new FstClass(*(lfst =
+ script::ArcMap(ifst, WeightConvertMapper<Arc, LogArc>())));
} else if (map_type == TO_LOG64_MAPPER) {
- args->retval = new FstClass(
- script::ArcMap(ifst, WeightConvertMapper<Arc, Log64Arc>()));
+ args->retval = new FstClass(*(l64fst =
+ script::ArcMap(ifst, WeightConvertMapper<Arc, Log64Arc>())));
} else if (map_type == TO_STD_MAPPER) {
- args->retval = new FstClass(
- script::ArcMap(ifst, WeightConvertMapper<Arc, StdArc>()));
+ args->retval = new FstClass(*(sfst =
+ script::ArcMap(ifst, WeightConvertMapper<Arc, StdArc>())));
} else {
FSTERROR() << "Error: unknown/unsupported mapper type: "
<< map_type;
VectorFst<Arc> *ofst = new VectorFst<Arc>;
ofst->SetProperties(kError, kError);
- args->retval = new FstClass(ofst);
+ args->retval = new FstClass(*(fst =ofst));
}
+ delete sfst;
+ delete l64fst;
+ delete lfst;
+ delete fst;
}
diff --git a/src/include/fst/script/shortest-distance.h b/src/include/fst/script/shortest-distance.h
index 5fc2976..39c5045 100644
--- a/src/include/fst/script/shortest-distance.h
+++ b/src/include/fst/script/shortest-distance.h
@@ -175,11 +175,11 @@ void ShortestDistance(ShortestDistanceArgs1 *args) {
return;
case FIFO_QUEUE:
- ShortestDistanceHelper<Arc, FifoQueue<StateId> >(args);
+ ShortestDistanceHelper<Arc, FifoQueue<StateId> >(args);
return;
case LIFO_QUEUE:
- ShortestDistanceHelper<Arc, LifoQueue<StateId> >(args);
+ ShortestDistanceHelper<Arc, LifoQueue<StateId> >(args);
return;
case SHORTEST_FIRST_QUEUE:
@@ -188,11 +188,11 @@ void ShortestDistance(ShortestDistanceArgs1 *args) {
return;
case STATE_ORDER_QUEUE:
- ShortestDistanceHelper<Arc, StateOrderQueue<StateId> >(args);
+ ShortestDistanceHelper<Arc, StateOrderQueue<StateId> >(args);
return;
case TOP_ORDER_QUEUE:
- ShortestDistanceHelper<Arc, TopOrderQueue<StateId> >(args);
+ ShortestDistanceHelper<Arc, TopOrderQueue<StateId> >(args);
return;
}
}
diff --git a/src/include/fst/script/weight-class.h b/src/include/fst/script/weight-class.h
index 228216d..b9f7ddf 100644
--- a/src/include/fst/script/weight-class.h
+++ b/src/include/fst/script/weight-class.h
@@ -128,6 +128,13 @@ class WeightClass {
return w;
}
+ const string &Type() const {
+ if (impl_) return impl_->Type();
+ static const string no_type = "none";
+ return no_type;
+ }
+
+
~WeightClass() { if (impl_) delete impl_; }
private:
enum ElementType { ZERO, ONE, OTHER };
diff --git a/src/include/fst/shortest-distance.h b/src/include/fst/shortest-distance.h
index 9320c4c..ec47a14 100644
--- a/src/include/fst/shortest-distance.h
+++ b/src/include/fst/shortest-distance.h
@@ -22,6 +22,7 @@
#define FST_LIB_SHORTEST_DISTANCE_H__
#include <deque>
+using std::deque;
#include <vector>
using std::vector;
diff --git a/src/include/fst/shortest-path.h b/src/include/fst/shortest-path.h
index f12970c..9cd13d9 100644
--- a/src/include/fst/shortest-path.h
+++ b/src/include/fst/shortest-path.h
@@ -458,7 +458,7 @@ void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
} else {
vector<Weight> ddistance;
DeterminizeFstOptions<ReverseArc> dopts(opts.delta);
- DeterminizeFst<ReverseArc> dfst(rfst, *distance, &ddistance, dopts);
+ DeterminizeFst<ReverseArc> dfst(rfst, distance, &ddistance, dopts);
NShortestPath(dfst, ofst, ddistance, n, opts.delta,
opts.weight_threshold, opts.state_threshold);
}
diff --git a/src/include/fst/state-map.h b/src/include/fst/state-map.h
index 2c59e1d..0e65d74 100644
--- a/src/include/fst/state-map.h
+++ b/src/include/fst/state-map.h
@@ -191,8 +191,6 @@ class StateMapFstImpl : public CacheImpl<B> {
using FstImpl<B>::SetInputSymbols;
using FstImpl<B>::SetOutputSymbols;
- using VectorFstBaseImpl<typename CacheImpl<B>::State>::NumStates;
-
using CacheImpl<B>::PushArc;
using CacheImpl<B>::HasArcs;
using CacheImpl<B>::HasFinal;
@@ -535,7 +533,7 @@ class ArcUniqueMapper {
explicit ArcUniqueMapper(const Fst<A> &fst) : fst_(fst), i_(0) {}
// Allows updating Fst argument; pass only if changed.
- ArcUniqueMapper(const ArcSumMapper<A> &mapper,
+ ArcUniqueMapper(const ArcUniqueMapper<A> &mapper,
const Fst<A> *fst = 0)
: fst_(fst ? *fst : mapper.fst_), i_(0) {}
diff --git a/src/include/fst/state-table.h b/src/include/fst/state-table.h
index 7d863a0..d8107a1 100644
--- a/src/include/fst/state-table.h
+++ b/src/include/fst/state-table.h
@@ -22,6 +22,7 @@
#define FST_LIB_STATE_TABLE_H__
#include <deque>
+using std::deque;
#include <vector>
using std::vector;
@@ -58,14 +59,14 @@ namespace fst {
// struct StateTuple {
// typedef typename S StateId;
//
-// // Required constructor.
+// // Required constructors.
// StateTuple();
+// StateTuple(const StateTuple &);
// };
// An implementation using a hash map for the tuple to state ID mapping.
-// The state tuple T must have == defined and the default constructor
-// must produce a tuple that will never be seen. H is the hash function.
+// The state tuple T must have == defined. H is the hash function.
template <class T, class H>
class HashStateTable : public HashBiTable<typename T::StateId, T, H> {
public:
@@ -76,15 +77,18 @@ class HashStateTable : public HashBiTable<typename T::StateId, T, H> {
using HashBiTable<StateId, T, H>::Size;
HashStateTable() : HashBiTable<StateId, T, H>() {}
+
+ // Reserves space for table_size elements.
+ explicit HashStateTable(size_t table_size)
+ : HashBiTable<StateId, T, H>(table_size) {}
+
StateId FindState(const StateTuple &tuple) { return FindId(tuple); }
const StateTuple &Tuple(StateId s) const { return FindEntry(s); }
};
-// An implementation using a hash set for the tuple to state ID
-// mapping. The state tuple T must have == defined and the default
-// constructor must produce a tuple that will never be seen. H is the
-// hash function.
+// An implementation using a hash map for the tuple to state ID mapping.
+// The state tuple T must have == defined. H is the hash function.
template <class T, class H>
class CompactHashStateTable
: public CompactHashBiTable<typename T::StateId, T, H> {
@@ -97,7 +101,7 @@ class CompactHashStateTable
CompactHashStateTable() : CompactHashBiTable<StateId, T, H>() {}
- // Reserves space for table_size elements.
+ // Reserves space for 'table_size' elements.
explicit CompactHashStateTable(size_t table_size)
: CompactHashBiTable<StateId, T, H>(table_size) {}
@@ -122,7 +126,10 @@ class VectorStateTable
using VectorBiTable<StateId, T, FP>::Size;
using VectorBiTable<StateId, T, FP>::Fingerprint;
- explicit VectorStateTable(FP *fp = 0) : VectorBiTable<StateId, T, FP>(fp) {}
+ // Reserves space for 'table_size' elements.
+ explicit VectorStateTable(FP *fp = 0, size_t table_size = 0)
+ : VectorBiTable<StateId, T, FP>(fp, table_size) {}
+
StateId FindState(const StateTuple &tuple) { return FindId(tuple); }
const StateTuple &Tuple(StateId s) const { return FindEntry(s); }
};
@@ -268,7 +275,9 @@ class GenericComposeStateTable : public H {
GenericComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2) {}
- GenericComposeStateTable(const GenericComposeStateTable<A, F> &table) {}
+ // Reserves space for 'table_size' elements.
+ GenericComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2,
+ size_t table_size) : H(table_size) {}
bool Error() const { return false; }
@@ -342,17 +351,18 @@ VectorStateTable<ComposeStateTuple<typename A::StateId, F>,
typedef typename A::StateId StateId;
typedef F FilterState;
typedef ComposeStateTuple<StateId, F> StateTuple;
+ typedef VectorStateTable<StateTuple,
+ ComposeFingerprint<StateId, F> > StateTable;
- ProductComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2)
- : VectorStateTable<ComposeStateTuple<StateId, F>,
- ComposeFingerprint<StateId, F> >
- (new ComposeFingerprint<StateId, F>(CountStates(fst1),
- CountStates(fst2))) { }
+ // Reserves space for 'table_size' elements.
+ ProductComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2,
+ size_t table_size = 0)
+ : StateTable(new ComposeFingerprint<StateId, F>(CountStates(fst1),
+ CountStates(fst2)),
+ table_size) {}
ProductComposeStateTable(const ProductComposeStateTable<A, F> &table)
- : VectorStateTable<ComposeStateTuple<StateId, F>,
- ComposeFingerprint<StateId, F> >
- (new ComposeFingerprint<StateId, F>(table.Fingerprint())) {}
+ : StateTable(new ComposeFingerprint<StateId, F>(table.Fingerprint())) {}
bool Error() const { return false; }
@@ -375,6 +385,8 @@ VectorStateTable<ComposeStateTuple<typename A::StateId, F>,
typedef typename A::StateId StateId;
typedef F FilterState;
typedef ComposeStateTuple<StateId, F> StateTuple;
+ typedef VectorStateTable<StateTuple,
+ ComposeState1Fingerprint<StateId, F> > StateTable;
StringDetComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2)
: error_(false) {
@@ -389,7 +401,7 @@ VectorStateTable<ComposeStateTuple<typename A::StateId, F>,
}
StringDetComposeStateTable(const StringDetComposeStateTable<A, F> &table)
- : error_(table.error_) {}
+ : StateTable(table), error_(table.error_) {}
bool Error() const { return error_; }
@@ -409,12 +421,14 @@ VectorStateTable<ComposeStateTuple<typename A::StateId, F>,
template <typename A, typename F>
class DetStringComposeStateTable : public
VectorStateTable<ComposeStateTuple<typename A::StateId, F>,
- ComposeState1Fingerprint<typename A::StateId, F> > {
+ ComposeState2Fingerprint<typename A::StateId, F> > {
public:
typedef A Arc;
typedef typename A::StateId StateId;
typedef F FilterState;
typedef ComposeStateTuple<StateId, F> StateTuple;
+ typedef VectorStateTable<StateTuple,
+ ComposeState2Fingerprint<StateId, F> > StateTable;
DetStringComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2)
:error_(false) {
@@ -429,7 +443,7 @@ VectorStateTable<ComposeStateTuple<typename A::StateId, F>,
}
DetStringComposeStateTable(const DetStringComposeStateTable<A, F> &table)
- : error_(table.error_) {}
+ : StateTable(table), error_(table.error_) {}
bool Error() const { return error_; }
@@ -456,8 +470,6 @@ ErasableStateTable<ComposeStateTuple<typename A::StateId, F>,
ErasableComposeStateTable(const Fst<A> &fst1, const Fst<A> &fst2) {}
- ErasableComposeStateTable(const ErasableComposeStateTable<A, F> &table) {}
-
bool Error() const { return false; }
private:
diff --git a/src/include/fst/string.h b/src/include/fst/string.h
index d51182e..9eaf7a3 100644
--- a/src/include/fst/string.h
+++ b/src/include/fst/string.h
@@ -57,6 +57,15 @@ class StringCompiler {
return true;
}
+ template <class F>
+ bool operator()(const string &s, F *fst, Weight w) const {
+ vector<Label> labels;
+ if (!ConvertStringToLabels(s, &labels))
+ return false;
+ Compile(labels, fst, w);
+ return true;
+ }
+
private:
bool ConvertStringToLabels(const string &str, vector<Label> *labels) const {
labels->clear();
@@ -83,22 +92,35 @@ class StringCompiler {
return true;
}
- void Compile(const vector<Label> &labels, MutableFst<A> *fst) const {
+ void Compile(const vector<Label> &labels, MutableFst<A> *fst,
+ const Weight &weight = Weight::One()) const {
fst->DeleteStates();
while (fst->NumStates() <= labels.size())
fst->AddState();
for (size_t i = 0; i < labels.size(); ++i)
fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
fst->SetStart(0);
- fst->SetFinal(labels.size(), Weight::One());
+ fst->SetFinal(labels.size(), weight);
}
template <class Unsigned>
- void Compile(const vector<Label> &labels, CompactFst<A, StringCompactor<A>,
- Unsigned> *fst) const {
+ void Compile(const vector<Label> &labels,
+ CompactFst<A, StringCompactor<A>, Unsigned> *fst) const {
fst->SetCompactElements(labels.begin(), labels.end());
}
+ template <class Unsigned>
+ void Compile(const vector<Label> &labels,
+ CompactFst<A, WeightedStringCompactor<A>, Unsigned> *fst,
+ const Weight &weight = Weight::One()) const {
+ vector<pair<Label, Weight> > compacts;
+ compacts.reserve(labels.size());
+ for (size_t i = 0; i < labels.size(); ++i)
+ compacts.push_back(make_pair(labels[i], Weight::One()));
+ compacts.back().second = weight;
+ fst->SetCompactElements(compacts.begin(), compacts.end());
+ }
+
bool ConvertSymbolToLabel(const char *s, Label* output) const {
int64 n;
if (syms_) {
@@ -167,6 +189,7 @@ class StringPrinter {
}
*output = sstrm.str();
} else if (token_type_ == BYTE) {
+ output->reserve(labels_.size());
for (size_t i = 0; i < labels_.size(); ++i) {
output->push_back(labels_[i]);
}
diff --git a/src/include/fst/util.h b/src/include/fst/util.h
index 1f6046b..4eb8fba 100644
--- a/src/include/fst/util.h
+++ b/src/include/fst/util.h
@@ -268,17 +268,17 @@ void WeightToStr(Weight w, string *s) {
s->append(strm.str().data(), strm.str().size());
}
-// Utilities for reading/writing label pairs
+// Utilities for reading/writing integer pairs (typically labels)
// Returns true on success
-template <typename Label>
-bool ReadLabelPairs(const string& filename,
- vector<pair<Label, Label> >* pairs,
+template <typename I>
+bool ReadIntPairs(const string& filename,
+ vector<pair<I, I> >* pairs,
bool allow_negative = false) {
ifstream strm(filename.c_str());
if (!strm) {
- LOG(ERROR) << "ReadLabelPairs: Can't open file: " << filename;
+ LOG(ERROR) << "ReadIntPairs: Can't open file: " << filename;
return false;
}
@@ -291,33 +291,34 @@ bool ReadLabelPairs(const string& filename,
++nline;
vector<char *> col;
SplitToVector(line, "\n\t ", &col, true);
- if (col.size() == 0 || col[0][0] == '\0') // empty line
+ // empty line or comment?
+ if (col.size() == 0 || col[0][0] == '\0' || col[0][0] == '#')
continue;
if (col.size() != 2) {
- LOG(ERROR) << "ReadLabelPairs: Bad number of columns, "
+ LOG(ERROR) << "ReadIntPairs: Bad number of columns, "
<< "file = " << filename << ", line = " << nline;
return false;
}
bool err;
- Label frmlabel = StrToInt64(col[0], filename, nline, allow_negative, &err);
+ I i1 = StrToInt64(col[0], filename, nline, allow_negative, &err);
if (err) return false;
- Label tolabel = StrToInt64(col[1], filename, nline, allow_negative, &err);
+ I i2 = StrToInt64(col[1], filename, nline, allow_negative, &err);
if (err) return false;
- pairs->push_back(make_pair(frmlabel, tolabel));
+ pairs->push_back(make_pair(i1, i2));
}
return true;
}
// Returns true on success
-template <typename Label>
-bool WriteLabelPairs(const string& filename,
- const vector<pair<Label, Label> >& pairs) {
+template <typename I>
+bool WriteIntPairs(const string& filename,
+ const vector<pair<I, I> >& pairs) {
ostream *strm = &cout;
if (!filename.empty()) {
strm = new ofstream(filename.c_str());
if (!*strm) {
- LOG(ERROR) << "WriteLabelPairs: Can't open file: " << filename;
+ LOG(ERROR) << "WriteIntPairs: Can't open file: " << filename;
return false;
}
}
@@ -326,7 +327,7 @@ bool WriteLabelPairs(const string& filename,
*strm << pairs[n].first << "\t" << pairs[n].second << "\n";
if (!*strm) {
- LOG(ERROR) << "WriteLabelPairs: Write failed: "
+ LOG(ERROR) << "WriteIntPairs: Write failed: "
<< (filename.empty() ? "standard output" : filename);
return false;
}
@@ -335,6 +336,21 @@ bool WriteLabelPairs(const string& filename,
return true;
}
+// Utilities for reading/writing label pairs
+
+template <typename Label>
+bool ReadLabelPairs(const string& filename,
+ vector<pair<Label, Label> >* pairs,
+ bool allow_negative = false) {
+ return ReadIntPairs(filename, pairs, allow_negative);
+}
+
+template <typename Label>
+bool WriteLabelPairs(const string& filename,
+ vector<pair<Label, Label> >& pairs) {
+ return WriteIntPairs(filename, pairs);
+}
+
// Utilities for converting a type name to a legal C symbol.
void ConvertToLegalCSymbol(string *s);
@@ -344,8 +360,8 @@ void ConvertToLegalCSymbol(string *s);
// UTILITIES FOR STREAM I/O
//
-bool AlignInput(istream &strm, int align);
-bool AlignOutput(ostream &strm, int align);
+bool AlignInput(istream &strm);
+bool AlignOutput(ostream &strm);
//
// UTILITIES FOR PROTOCOL BUFFER I/O
@@ -380,6 +396,17 @@ public:
max_key_ = key;
}
+ void Erase(Key key) {
+ set_.erase(key);
+ if (set_.empty()) {
+ min_key_ = max_key_ = NoKey;
+ } else if (key == min_key_) {
+ ++min_key_;
+ } else if (key == max_key_) {
+ --max_key_;
+ }
+ }
+
void Clear() {
set_.clear();
min_key_ = max_key_ = NoKey;
@@ -393,10 +420,26 @@ public:
return set_.find(key);
}
+ bool Member(Key key) const {
+ if (min_key_ == NoKey || key < min_key_ || max_key_ < key) {
+ return false; // out of range
+ } else if (min_key_ != NoKey && max_key_ + 1 == min_key_ + set_.size()) {
+ return true; // dense range
+ } else {
+ return set_.find(key) != set_.end();
+ }
+ }
+
const_iterator Begin() const { return set_.begin(); }
const_iterator End() const { return set_.end(); }
+ // All stored keys are greater than or equal to this value.
+ Key LowerBound() const { return min_key_; }
+
+ // All stored keys are less than or equal to this value.
+ Key UpperBound() const { return max_key_; }
+
private:
set<Key> set_;
Key min_key_;
diff --git a/src/include/fst/visit.h b/src/include/fst/visit.h
index a02d86a..5f5059a 100644
--- a/src/include/fst/visit.h
+++ b/src/include/fst/visit.h
@@ -238,18 +238,28 @@ class CopyVisitor {
};
-// Visits input FST up to a state limit following queue order.
+// Visits input FST up to a state limit following queue order. If
+// 'access_only' is true, aborts on visiting first state not
+// accessible from the initial state.
template <class A>
class PartialVisitor {
public:
typedef A Arc;
typedef typename A::StateId StateId;
- explicit PartialVisitor(StateId maxvisit) : maxvisit_(maxvisit) {}
+ explicit PartialVisitor(StateId maxvisit, bool access_only = false)
+ : maxvisit_(maxvisit),
+ access_only_(access_only),
+ start_(kNoStateId) {}
- void InitVisit(const Fst<A> &ifst) { nvisit_ = 0; }
+ void InitVisit(const Fst<A> &ifst) {
+ nvisit_ = 0;
+ start_ = ifst.Start();
+ }
- bool InitState(StateId s, StateId) {
+ bool InitState(StateId s, StateId root) {
+ if (access_only_ && root != start_)
+ return false;
++nvisit_;
return nvisit_ <= maxvisit_;
}
@@ -262,7 +272,10 @@ class PartialVisitor {
private:
StateId maxvisit_;
+ bool access_only_;
StateId nvisit_;
+ StateId start_;
+
};