diff options
author | Przemyslaw Szczepaniak <pszczepaniak@google.com> | 2013-03-04 11:30:34 +0000 |
---|---|---|
committer | Przemyslaw Szczepaniak <pszczepaniak@google.com> | 2013-03-04 11:30:34 +0000 |
commit | 5bf56ba7027cd5f22ff52d0138893f7a585135fb (patch) | |
tree | 19e17fc79b8873e66f211276d4dd169c480cede1 /src/include/fst/extensions/ngram/ngram-fst.h | |
parent | 3da1eb108d36da35333b2d655202791af854996b (diff) | |
parent | 5b6dc79427b8f7eeb6a7ff68034ab8548ce670ea (diff) | |
download | openfst-kitkat-cts-dev.tar.gz |
Merge remote-tracking branch 'goog/ics-ub-google-tts' into jb-mr2-dev
Refs: android-sdk-4.4.2_r1.0.1, android-sdk-4.4.2_r1, android-cts-4.4_r4, android-cts-4.4_r1, android-4.4w_r1, android-4.4_r1.2.0.1, android-4.4_r1.2, android-4.4_r1.1.0.1, android-4.4_r1.1, android-4.4_r1.0.1, android-4.4_r1, android-4.4_r0.9, android-4.4_r0.8, android-4.4_r0.7, android-4.4.4_r2.0.1, android-4.4.4_r2, android-4.4.4_r1.0.1, android-4.4.4_r1, android-4.4.3_r1.1.0.1, android-4.4.3_r1.1, android-4.4.3_r1.0.1, android-4.4.3_r1, android-4.4.2_r2.0.1, android-4.4.2_r2, android-4.4.2_r1.0.1, android-4.4.2_r1, android-4.4.1_r1.0.1, android-4.4.1_r1, android-4.3_r3.1, android-4.3_r3, android-4.3_r2.3, android-4.3_r2.2, android-4.3_r2.1, android-4.3_r2, android-4.3_r1.1, android-4.3_r1, android-4.3_r0.9.1, android-4.3_r0.9, android-4.3.1_r1, tools_r22.2, kitkat-wear, kitkat-release, kitkat-mr2.2-release, kitkat-mr2.1-release, kitkat-mr2-release, kitkat-mr1.1-release, kitkat-mr1-release, kitkat-dev, kitkat-cts-release, kitkat-cts-dev, jb-mr2.0.0-release, jb-mr2.0-release, jb-mr2-release, jb-mr2-dev, idea133-weekly-release, idea133
Diffstat (limited to 'src/include/fst/extensions/ngram/ngram-fst.h')
-rw-r--r-- | src/include/fst/extensions/ngram/ngram-fst.h | 111 |
1 files changed, 92 insertions, 19 deletions
diff --git a/src/include/fst/extensions/ngram/ngram-fst.h b/src/include/fst/extensions/ngram/ngram-fst.h index eee664a..873ae6a 100644 --- a/src/include/fst/extensions/ngram/ngram-fst.h +++ b/src/include/fst/extensions/ngram/ngram-fst.h @@ -26,6 +26,7 @@ using std::vector; #include <fst/compat.h> #include <fst/fstlib.h> +#include <fst/mapped-file.h> #include <fst/extensions/ngram/bitmap-index.h> // NgramFst implements a n-gram language model based upon the LOUDS data @@ -76,7 +77,7 @@ class NGramFstImpl : public FstImpl<A> { typedef typename A::StateId StateId; typedef typename A::Weight Weight; - NGramFstImpl() : data_(0), owned_(false) { + NGramFstImpl() : data_region_(0), data_(0), owned_(false) { SetType("ngram"); SetInputSymbols(NULL); SetOutputSymbols(NULL); @@ -89,6 +90,7 @@ class NGramFstImpl : public FstImpl<A> { if (owned_) { delete [] data_; } + delete data_region_; } static NGramFstImpl<A>* Read(istream &strm, // NOLINT @@ -104,7 +106,8 @@ class NGramFstImpl : public FstImpl<A> { strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures)); strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final)); size_t size = Storage(num_states, num_futures, num_final); - char* data = new char[size]; + MappedFile *data_region = MappedFile::Allocate(size); + char *data = reinterpret_cast<char *>(data_region->mutable_data()); // Copy num_states, num_futures and num_final back into data. 
memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states)); memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures), @@ -116,7 +119,7 @@ class NGramFstImpl : public FstImpl<A> { delete impl; return NULL; } - impl->Init(data, true /* owned */); + impl->Init(data, false, data_region); return impl; } @@ -126,7 +129,7 @@ class NGramFstImpl : public FstImpl<A> { hdr.SetStart(Start()); hdr.SetNumStates(num_states_); WriteHeader(strm, opts, kFileVersion, &hdr); - strm.write(data_, Storage(num_states_, num_futures_, num_final_)); + strm.write(data_, StorageSize()); return strm; } @@ -223,11 +226,23 @@ class NGramFstImpl : public FstImpl<A> { // Access to the underlying representation const char* GetData(size_t* data_size) const { - *data_size = Storage(num_states_, num_futures_, num_final_); + *data_size = StorageSize(); return data_; } - void Init(const char* data, bool owned); + void Init(const char* data, bool owned, MappedFile *file = 0); + + const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const { + SetInstFuture(s, inst); + SetInstContext(inst); + return inst->context_; + } + + size_t StorageSize() const { + return Storage(num_states_, num_futures_, num_final_); + } + + void GetStates(const vector<Label>& context, vector<StateId> *states) const; private: StateId Transition(const vector<Label> &context, Label future) const; @@ -242,6 +257,7 @@ class NGramFstImpl : public FstImpl<A> { // Minimum file format version supported. 
static const int kMinFileVersion = 4; + MappedFile *data_region_; const char* data_; bool owned_; // True if we own data_ uint64 num_states_, num_futures_, num_final_; @@ -261,7 +277,7 @@ class NGramFstImpl : public FstImpl<A> { template<typename A> NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) - : data_(0), owned_(false) { + : data_region_(0), data_(0), owned_(false) { typedef A Arc; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; @@ -286,12 +302,16 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) // epsilons. StateId unigram = fst.Start(); while (1) { - ArcIterator<Fst<A> > aiter(fst, unigram); - if (aiter.Done()) { - FSTERROR() << "Start state has no arcs"; + if (unigram == kNoStateId) { + FSTERROR() << "Could not identify unigram state."; SetProperties(kError, kError); return; } + ArcIterator<Fst<A> > aiter(fst, unigram); + if (aiter.Done()) { + LOG(WARNING) << "Unigram state " << unigram << " has no arcs."; + break; + } if (aiter.Value().ilabel != 0) break; unigram = aiter.Value().nextstate; } @@ -385,7 +405,8 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) Weight weight; Label label = kNoLabel; const size_t storage = Storage(num_states, num_futures, num_final); - char* data = new char[storage]; + MappedFile *data_region = MappedFile::Allocate(storage); + char *data = reinterpret_cast<char *>(data_region->mutable_data()); memset(data, 0, storage); size_t offset = 0; memcpy(data + offset, reinterpret_cast<char *>(&num_states), @@ -482,14 +503,17 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) return; } - Init(data, true /* owned */); + Init(data, false, data_region); } template<typename A> -inline void NGramFstImpl<A>::Init(const char* data, bool owned) { +inline void NGramFstImpl<A>::Init(const char* data, bool owned, + MappedFile *data_region) { if (owned_) { delete [] data_; } + delete data_region_; + 
data_region_ = data_region; owned_ = owned; data_ = data; size_t offset = 0; @@ -507,7 +531,7 @@ inline void NGramFstImpl<A>::Init(const char* data, bool owned) { future_ = reinterpret_cast<const uint64*>(data_ + offset); offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits); final_ = reinterpret_cast<const uint64*>(data_ + offset); - offset += BitmapIndex::StorageSize(num_states_ + 1) * sizeof(bits); + offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits); context_words_ = reinterpret_cast<const Label*>(data_ + offset); offset += (num_states_ + 1) * sizeof(*context_words_); future_words_ = reinterpret_cast<const Label*>(data_ + offset); @@ -538,10 +562,10 @@ inline void NGramFstImpl<A>::Init(const char* data, bool owned) { template<typename A> inline typename A::StateId NGramFstImpl<A>::Transition( const vector<Label> &context, Label future) const { - size_t num_children = root_num_children_; const Label *children = root_children_; - const Label *loc = lower_bound(children, children + num_children, future); - if (loc == children + num_children || *loc != future) { + const Label *loc = lower_bound(children, children + root_num_children_, + future); + if (loc == children + root_num_children_ || *loc != future) { return context_index_.Rank1(0); } size_t node = root_first_child_ + loc - children; @@ -551,7 +575,6 @@ inline typename A::StateId NGramFstImpl<A>::Transition( return context_index_.Rank1(node); } size_t last_child = context_index_.Select0(node_rank + 1) - 1; - num_children = last_child - first_child + 1; for (int word = context.size() - 1; word >= 0; --word) { children = context_words_ + context_index_.Rank1(first_child); loc = lower_bound(children, children + last_child - first_child + 1, @@ -569,6 +592,42 @@ inline typename A::StateId NGramFstImpl<A>::Transition( return context_index_.Rank1(node); } +template<typename A> +inline void NGramFstImpl<A>::GetStates( + const vector<Label> &context, + vector<typename A::StateId>* states) const 
{ + states->clear(); + states->push_back(0); + typename vector<Label>::const_reverse_iterator cit = context.rbegin(); + const Label *children = root_children_; + const Label *loc = lower_bound(children, children + root_num_children_, *cit); + if (loc == children + root_num_children_ || *loc != *cit) return; + size_t node = root_first_child_ + loc - children; + states->push_back(context_index_.Rank1(node)); + if (context.size() == 1) return; + size_t node_rank = context_index_.Rank1(node); + size_t first_child = context_index_.Select0(node_rank) + 1; + ++cit; + if (context_index_.Get(first_child) != false) { + size_t last_child = context_index_.Select0(node_rank + 1) - 1; + while (cit != context.rend()) { + children = context_words_ + context_index_.Rank1(first_child); + loc = lower_bound(children, children + last_child - first_child + 1, + *cit); + if (loc == children + last_child - first_child + 1 || *loc != *cit) { + break; + } + ++cit; + node = first_child + loc - children; + states->push_back(context_index_.Rank1(node)); + node_rank = context_index_.Rank1(node); + first_child = context_index_.Select0(node_rank) + 1; + if (context_index_.Get(first_child) == false) break; + last_child = context_index_.Select0(node_rank + 1) - 1; + } + } +} + /*****************************************************************************/ template<class A> class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { @@ -597,7 +656,7 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { // Non-standard constructor to initialize NGramFst directly from data. NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) { - GetImpl()->Init(data, owned); + GetImpl()->Init(data, owned, NULL); } // Get method that gets the data associated with Init(). 
@@ -605,6 +664,16 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { return GetImpl()->GetData(data_size); } + const vector<Label> GetContext(StateId s) const { + return GetImpl()->GetContext(s, &inst_); + } + + // Consumes as much as possible of context from right to left, returns the + // the states corresponding to the increasingly conditioned input sequence. + void GetStates(const vector<Label>& context, vector<StateId> *state) const { + return GetImpl()->GetStates(context, state); + } + virtual size_t NumArcs(StateId s) const { return GetImpl()->NumArcs(s, &inst_); } @@ -650,6 +719,10 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { return new NGramFstMatcher<A>(*this, match_type); } + size_t StorageSize() const { + return GetImpl()->StorageSize(); + } + private: explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {} |