diff options
Diffstat (limited to 'src/include/fst/extensions')
-rw-r--r-- | src/include/fst/extensions/far/extract.h | 119 | ||||
-rw-r--r-- | src/include/fst/extensions/far/far.h | 3 | ||||
-rw-r--r-- | src/include/fst/extensions/far/farscript.h | 12 | ||||
-rw-r--r-- | src/include/fst/extensions/far/stlist.h | 22 | ||||
-rw-r--r-- | src/include/fst/extensions/ngram/ngram-fst.h | 111 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/compose.h | 486 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/pdt.h | 1 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/pdtscript.h | 4 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/replace.h | 27 |
9 files changed, 657 insertions, 128 deletions
diff --git a/src/include/fst/extensions/far/extract.h b/src/include/fst/extensions/far/extract.h index d6f92ff..95866de 100644 --- a/src/include/fst/extensions/far/extract.h +++ b/src/include/fst/extensions/far/extract.h @@ -32,51 +32,106 @@ using std::vector; namespace fst { template<class Arc> +inline void FarWriteFst(const Fst<Arc>* fst, string key, + string* okey, int* nrep, + const int32 &generate_filenames, int i, + const string &filename_prefix, + const string &filename_suffix) { + if (key == *okey) + ++*nrep; + else + *nrep = 0; + + *okey = key; + + string ofilename; + if (generate_filenames) { + ostringstream tmp; + tmp.width(generate_filenames); + tmp.fill('0'); + tmp << i; + ofilename = tmp.str(); + } else { + if (*nrep > 0) { + ostringstream tmp; + tmp << '.' << nrep; + key.append(tmp.str().data(), tmp.str().size()); + } + ofilename = key; + } + fst->Write(filename_prefix + ofilename + filename_suffix); +} + +template<class Arc> void FarExtract(const vector<string> &ifilenames, const int32 &generate_filenames, - const string &begin_key, - const string &end_key, + const string &keys, + const string &key_separator, + const string &range_delimiter, const string &filename_prefix, const string &filename_suffix) { FarReader<Arc> *far_reader = FarReader<Arc>::Open(ifilenames); if (!far_reader) return; - if (!begin_key.empty()) - far_reader->Find(begin_key); - string okey; int nrep = 0; - for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) { - string key = far_reader->GetKey(); - if (!end_key.empty() && end_key < key) - break; - const Fst<Arc> &fst = far_reader->GetFst(); - - if (key == okey) - ++nrep; - else - nrep = 0; - okey = key; - - string ofilename; - if (generate_filenames) { - ostringstream tmp; - tmp.width(generate_filenames); - tmp.fill('0'); - tmp << i; - ofilename = tmp.str(); - } else { - if (nrep > 0) { - ostringstream tmp; - tmp << '.' << nrep; - key.append(tmp.str().data(), tmp.str().size()); + vector<char *> key_vector; + // User has specified a set of fsts to extract, where some of the "fsts" could + // be ranges. + if (!keys.empty()) { + char *keys_cstr = new char[keys.size()+1]; + strcpy(keys_cstr, keys.c_str()); + SplitToVector(keys_cstr, key_separator.c_str(), &key_vector, true); + int i = 0; + for (int k = 0; k < key_vector.size(); ++k, ++i) { + string key = string(key_vector[k]); + char *key_cstr = new char[key.size()+1]; + strcpy(key_cstr, key.c_str()); + vector<char *> range_vector; + SplitToVector(key_cstr, range_delimiter.c_str(), &range_vector, false); + if (range_vector.size() == 1) { // Not a range + if (!far_reader->Find(key)) { + LOG(ERROR) << "FarExtract: Cannot find key: " << key; + return; + } + const Fst<Arc> &fst = far_reader->GetFst(); + FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i, + filename_prefix, filename_suffix); + } else if (range_vector.size() == 2) { // A legal range + string begin_key = string(range_vector[0]); + string end_key = string(range_vector[1]); + if (begin_key.empty() || end_key.empty()) { + LOG(ERROR) << "FarExtract: Illegal range specification: " << key; + return; + } + if (!far_reader->Find(begin_key)) { + LOG(ERROR) << "FarExtract: Cannot find key: " << begin_key; + return; + } + for ( ; !far_reader->Done(); far_reader->Next(), ++i) { + string ikey = far_reader->GetKey(); + if (end_key < ikey) break; + const Fst<Arc> &fst = far_reader->GetFst(); + FarWriteFst(&fst, ikey, &okey, &nrep, generate_filenames, i, + filename_prefix, filename_suffix); + } + } else { + LOG(ERROR) << "FarExtract: Illegal range specification: " << key; + return; } - ofilename = key; + delete key_cstr; } - fst.Write(filename_prefix + ofilename + filename_suffix); + delete keys_cstr; + return; + } + // Nothing specified: extract everything. + for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) { + string key = far_reader->GetKey(); + const Fst<Arc> &fst = far_reader->GetFst(); + FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i, + filename_prefix, filename_suffix); } - return; } diff --git a/src/include/fst/extensions/far/far.h b/src/include/fst/extensions/far/far.h index acce76e..737f1b8 100644 --- a/src/include/fst/extensions/far/far.h +++ b/src/include/fst/extensions/far/far.h @@ -273,13 +273,10 @@ FarWriter<A> *FarWriter<A>::Create(const string &filename, FarType type) { return STListFarWriter<A>::Create(filename); case FAR_STTABLE: return STTableFarWriter<A>::Create(filename); - break; case FAR_STLIST: return STListFarWriter<A>::Create(filename); - break; case FAR_FST: return FstFarWriter<A>::Create(filename); - break; default: LOG(ERROR) << "FarWriter::Create: unknown far type"; return 0; diff --git a/src/include/fst/extensions/far/farscript.h b/src/include/fst/extensions/far/farscript.h index 3a9c145..cfd9167 100644 --- a/src/include/fst/extensions/far/farscript.h +++ b/src/include/fst/extensions/far/farscript.h @@ -173,18 +173,22 @@ bool FarEqual(const string &filename1, typedef args::Package<const vector<string> &, int32, const string&, const string&, const string&, - const string&> FarExtractArgs; + const string&, const string&> FarExtractArgs; template<class Arc> void FarExtract(FarExtractArgs *args) { fst::FarExtract<Arc>( - args->arg1, args->arg2, args->arg3, args->arg4, args->arg5, args->arg6); + args->arg1, args->arg2, args->arg3, args->arg4, args->arg5, args->arg6, + args->arg7); } void FarExtract(const vector<string> &ifilenames, const string &arc_type, - int32 generate_filenames, const string &begin_key, - const string &end_key, const string &filename_prefix, + int32 generate_filenames, + const string &keys, + const string &key_separator, + const string &range_delimiter, + const string &filename_prefix, const string &filename_suffix); typedef args::Package<const vector<string> &, const string &, diff --git a/src/include/fst/extensions/far/stlist.h b/src/include/fst/extensions/far/stlist.h index 1cdc80c..ff3d98b 100644 --- a/src/include/fst/extensions/far/stlist.h +++ b/src/include/fst/extensions/far/stlist.h @@ -145,13 +145,13 @@ class STListReader { ReadType(*streams_[i], &magic_number); ReadType(*streams_[i], &file_version); if (magic_number != kSTListMagicNumber) { - FSTERROR() << "STListReader::STTableReader: wrong file type: " + FSTERROR() << "STListReader::STListReader: wrong file type: " << filenames[i]; error_ = true; return; } if (file_version != kSTListFileVersion) { - FSTERROR() << "STListReader::STTableReader: wrong file version: " + FSTERROR() << "STListReader::STListReader: wrong file version: " << filenames[i]; error_ = true; return; @@ -161,7 +161,7 @@ class STListReader { if (!key.empty()) heap_.push(make_pair(key, i)); if (!*streams_[i]) { - FSTERROR() << "STTableReader: error reading file: " << sources_[i]; + FSTERROR() << "STListReader: error reading file: " << sources_[i]; error_ = true; return; } @@ -170,7 +170,7 @@ class STListReader { size_t current = heap_.top().second; entry_ = entry_reader_(*streams_[current]); if (!entry_ || !*streams_[current]) { - FSTERROR() << "STTableReader: error reading entry for key: " + FSTERROR() << "STListReader: error reading entry for key: " << heap_.top().first << ", file: " << sources_[current]; error_ = true; } @@ -219,7 +219,7 @@ class STListReader { heap_.pop(); ReadType(*(streams_[current]), &key); if (!*streams_[current]) { - FSTERROR() << "STTableReader: error reading file: " + FSTERROR() << "STListReader: error reading file: " << sources_[current]; error_ = true; return; @@ -233,7 +233,7 @@ class STListReader { delete entry_; entry_ = entry_reader_(*streams_[current]); if (!entry_ || !*streams_[current]) { - FSTERROR() << "STTableReader: error reading entry for key: " + FSTERROR() << "STListReader: error reading entry for key: " << heap_.top().first << ", file: " << sources_[current]; error_ = true; } @@ -267,8 +267,8 @@ class STListReader { // String-type list header reading function template on the entry header // type 'H' having a member function: // Read(istream &strm, const string &filename); -// Checks that 'filename' is an STTable and call the H::Read() on the last -// entry in the STTable. +// Checks that 'filename' is an STList and call the H::Read() on the last +// entry in the STList. // Does not support reading from stdin. template <class H> bool ReadSTListHeader(const string &filename, H *header) { @@ -281,18 +281,18 @@ bool ReadSTListHeader(const string &filename, H *header) { ReadType(strm, &magic_number); ReadType(strm, &file_version); if (magic_number != kSTListMagicNumber) { - LOG(ERROR) << "ReadSTTableHeader: wrong file type: " << filename; + LOG(ERROR) << "ReadSTListHeader: wrong file type: " << filename; return false; } if (file_version != kSTListFileVersion) { - LOG(ERROR) << "ReadSTTableHeader: wrong file version: " << filename; + LOG(ERROR) << "ReadSTListHeader: wrong file version: " << filename; return false; } string key; ReadType(strm, &key); header->Read(strm, filename + ":" + key); if (!strm) { - LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename; + LOG(ERROR) << "ReadSTListHeader: error reading file: " << filename; return false; } return true; diff --git a/src/include/fst/extensions/ngram/ngram-fst.h b/src/include/fst/extensions/ngram/ngram-fst.h index eee664a..873ae6a 100644 --- a/src/include/fst/extensions/ngram/ngram-fst.h +++ b/src/include/fst/extensions/ngram/ngram-fst.h @@ -26,6 +26,7 @@ using std::vector; #include <fst/compat.h> #include <fst/fstlib.h> +#include <fst/mapped-file.h> #include <fst/extensions/ngram/bitmap-index.h> // NgramFst implements a n-gram language model based upon the LOUDS data @@ -76,7 +77,7 @@ class NGramFstImpl : public FstImpl<A> { typedef typename A::StateId StateId; typedef typename A::Weight Weight; - NGramFstImpl() : data_(0), owned_(false) { + NGramFstImpl() : data_region_(0), data_(0), owned_(false) { SetType("ngram"); SetInputSymbols(NULL); SetOutputSymbols(NULL); @@ -89,6 +90,7 @@ class NGramFstImpl : public FstImpl<A> { if (owned_) { delete [] data_; } + delete data_region_; } static NGramFstImpl<A>* Read(istream &strm, // NOLINT @@ -104,7 +106,8 @@ class NGramFstImpl : public FstImpl<A> { strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures)); strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final)); size_t size = Storage(num_states, num_futures, num_final); - char* data = new char[size]; + MappedFile *data_region = MappedFile::Allocate(size); + char *data = reinterpret_cast<char *>(data_region->mutable_data()); // Copy num_states, num_futures and num_final back into data. memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states)); memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures), @@ -116,7 +119,7 @@ class NGramFstImpl : public FstImpl<A> { delete impl; return NULL; } - impl->Init(data, true /* owned */); + impl->Init(data, false, data_region); return impl; } @@ -126,7 +129,7 @@ class NGramFstImpl : public FstImpl<A> { hdr.SetStart(Start()); hdr.SetNumStates(num_states_); WriteHeader(strm, opts, kFileVersion, &hdr); - strm.write(data_, Storage(num_states_, num_futures_, num_final_)); + strm.write(data_, StorageSize()); return strm; } @@ -223,11 +226,23 @@ class NGramFstImpl : public FstImpl<A> { // Access to the underlying representation const char* GetData(size_t* data_size) const { - *data_size = Storage(num_states_, num_futures_, num_final_); + *data_size = StorageSize(); return data_; } - void Init(const char* data, bool owned); + void Init(const char* data, bool owned, MappedFile *file = 0); + + const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const { + SetInstFuture(s, inst); + SetInstContext(inst); + return inst->context_; + } + + size_t StorageSize() const { + return Storage(num_states_, num_futures_, num_final_); + } + + void GetStates(const vector<Label>& context, vector<StateId> *states) const; private: StateId Transition(const vector<Label> &context, Label future) const; @@ -242,6 +257,7 @@ class NGramFstImpl : public FstImpl<A> { // Minimum file format version supported. static const int kMinFileVersion = 4; + MappedFile *data_region_; const char* data_; bool owned_; // True if we own data_ uint64 num_states_, num_futures_, num_final_; @@ -261,7 +277,7 @@ class NGramFstImpl : public FstImpl<A> { template<typename A> NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) - : data_(0), owned_(false) { + : data_region_(0), data_(0), owned_(false) { typedef A Arc; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; @@ -286,12 +302,16 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) // epsilons. StateId unigram = fst.Start(); while (1) { - ArcIterator<Fst<A> > aiter(fst, unigram); - if (aiter.Done()) { - FSTERROR() << "Start state has no arcs"; + if (unigram == kNoStateId) { + FSTERROR() << "Could not identify unigram state."; SetProperties(kError, kError); return; } + ArcIterator<Fst<A> > aiter(fst, unigram); + if (aiter.Done()) { + LOG(WARNING) << "Unigram state " << unigram << " has no arcs."; + break; + } if (aiter.Value().ilabel != 0) break; unigram = aiter.Value().nextstate; } @@ -385,7 +405,8 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) Weight weight; Label label = kNoLabel; const size_t storage = Storage(num_states, num_futures, num_final); - char* data = new char[storage]; + MappedFile *data_region = MappedFile::Allocate(storage); + char *data = reinterpret_cast<char *>(data_region->mutable_data()); memset(data, 0, storage); size_t offset = 0; memcpy(data + offset, reinterpret_cast<char *>(&num_states), @@ -482,14 +503,17 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) return; } - Init(data, true /* owned */); + Init(data, false, data_region); } template<typename A> -inline void NGramFstImpl<A>::Init(const char* data, bool owned) { +inline void NGramFstImpl<A>::Init(const char* data, bool owned, + MappedFile *data_region) { if (owned_) { delete [] data_; } + delete data_region_; + data_region_ = data_region; owned_ = owned; data_ = data; size_t offset = 0; @@ -507,7 +531,7 @@ inline void NGramFstImpl<A>::Init(const char* data, bool owned) { future_ = reinterpret_cast<const uint64*>(data_ + offset); offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits); final_ = reinterpret_cast<const uint64*>(data_ + offset); - offset += BitmapIndex::StorageSize(num_states_ + 1) * sizeof(bits); + offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits); context_words_ = reinterpret_cast<const Label*>(data_ + offset); offset += (num_states_ + 1) * sizeof(*context_words_); future_words_ = reinterpret_cast<const Label*>(data_ + offset); @@ -538,10 +562,10 @@ inline void NGramFstImpl<A>::Init(const char* data, bool owned) { template<typename A> inline typename A::StateId NGramFstImpl<A>::Transition( const vector<Label> &context, Label future) const { - size_t num_children = root_num_children_; const Label *children = root_children_; - const Label *loc = lower_bound(children, children + num_children, future); - if (loc == children + num_children || *loc != future) { + const Label *loc = lower_bound(children, children + root_num_children_, + future); + if (loc == children + root_num_children_ || *loc != future) { return context_index_.Rank1(0); } size_t node = root_first_child_ + loc - children; @@ -551,7 +575,6 @@ inline typename A::StateId NGramFstImpl<A>::Transition( return context_index_.Rank1(node); } size_t last_child = context_index_.Select0(node_rank + 1) - 1; - num_children = last_child - first_child + 1; for (int word = context.size() - 1; word >= 0; --word) { children = context_words_ + context_index_.Rank1(first_child); loc = lower_bound(children, children + last_child - first_child + 1, @@ -569,6 +592,42 @@ inline typename A::StateId NGramFstImpl<A>::Transition( return context_index_.Rank1(node); } +template<typename A> +inline void NGramFstImpl<A>::GetStates( + const vector<Label> &context, + vector<typename A::StateId>* states) const { + states->clear(); + states->push_back(0); + typename vector<Label>::const_reverse_iterator cit = context.rbegin(); + const Label *children = root_children_; + const Label *loc = lower_bound(children, children + root_num_children_, *cit); + if (loc == children + root_num_children_ || *loc != *cit) return; + size_t node = root_first_child_ + loc - children; + states->push_back(context_index_.Rank1(node)); + if (context.size() == 1) return; + size_t node_rank = context_index_.Rank1(node); + size_t first_child = context_index_.Select0(node_rank) + 1; + ++cit; + if (context_index_.Get(first_child) != false) { + size_t last_child = context_index_.Select0(node_rank + 1) - 1; + while (cit != context.rend()) { + children = context_words_ + context_index_.Rank1(first_child); + loc = lower_bound(children, children + last_child - first_child + 1, + *cit); + if (loc == children + last_child - first_child + 1 || *loc != *cit) { + break; + } + ++cit; + node = first_child + loc - children; + states->push_back(context_index_.Rank1(node)); + node_rank = context_index_.Rank1(node); + first_child = context_index_.Select0(node_rank) + 1; + if (context_index_.Get(first_child) == false) break; + last_child = context_index_.Select0(node_rank + 1) - 1; + } + } +} + /*****************************************************************************/ template<class A> class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { @@ -597,7 +656,7 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { // Non-standard constructor to initialize NGramFst directly from data. NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) { - GetImpl()->Init(data, owned); + GetImpl()->Init(data, owned, NULL); } // Get method that gets the data associated with Init(). @@ -605,6 +664,16 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { return GetImpl()->GetData(data_size); } + const vector<Label> GetContext(StateId s) const { + return GetImpl()->GetContext(s, &inst_); + } + + // Consumes as much as possible of context from right to left, returns the + // the states corresponding to the increasingly conditioned input sequence. + void GetStates(const vector<Label>& context, vector<StateId> *state) const { + return GetImpl()->GetStates(context, state); + } + virtual size_t NumArcs(StateId s) const { return GetImpl()->NumArcs(s, &inst_); } @@ -650,6 +719,10 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { return new NGramFstMatcher<A>(*this, match_type); } + size_t StorageSize() const { + return GetImpl()->StorageSize(); + } + private: explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {} diff --git a/src/include/fst/extensions/pdt/compose.h b/src/include/fst/extensions/pdt/compose.h index 364d76f..c856c6d 100644 --- a/src/include/fst/extensions/pdt/compose.h +++ b/src/include/fst/extensions/pdt/compose.h @@ -21,82 +21,469 @@ #ifndef FST_EXTENSIONS_PDT_COMPOSE_H__ #define FST_EXTENSIONS_PDT_COMPOSE_H__ +#include <list> + +#include <fst/extensions/pdt/pdt.h> #include <fst/compose.h> namespace fst { +// Return paren arcs for Find(kNoLabel). +const uint32 kParenList = 0x00000001; + +// Return a kNolabel loop for Find(paren). +const uint32 kParenLoop = 0x00000002; + +// This class is a matcher that treats parens as multi-epsilon labels. +// It is most efficient if the parens are in a range non-overlapping with +// the non-paren labels. +template <class F> +class ParenMatcher { + public: + typedef SortedMatcher<F> M; + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + ParenMatcher(const FST &fst, MatchType match_type, + uint32 flags = (kParenLoop | kParenList)) + : matcher_(fst, match_type), + match_type_(match_type), + flags_(flags) { + if (match_type == MATCH_INPUT) { + loop_.ilabel = kNoLabel; + loop_.olabel = 0; + } else { + loop_.ilabel = 0; + loop_.olabel = kNoLabel; + } + loop_.weight = Weight::One(); + loop_.nextstate = kNoStateId; + } + + ParenMatcher(const ParenMatcher<F> &matcher, bool safe = false) + : matcher_(matcher.matcher_, safe), + match_type_(matcher.match_type_), + flags_(matcher.flags_), + open_parens_(matcher.open_parens_), + close_parens_(matcher.close_parens_), + loop_(matcher.loop_) { + loop_.nextstate = kNoStateId; + } + + ParenMatcher<F> *Copy(bool safe = false) const { + return new ParenMatcher<F>(*this, safe); + } + + MatchType Type(bool test) const { return matcher_.Type(test); } + + void SetState(StateId s) { + matcher_.SetState(s); + loop_.nextstate = s; + } + + bool Find(Label match_label); + + bool Done() const { + return done_; + } + + const Arc& Value() const { + return paren_loop_ ? loop_ : matcher_.Value(); + } + + void Next(); + + const FST &GetFst() const { return matcher_.GetFst(); } + + uint64 Properties(uint64 props) const { return matcher_.Properties(props); } + + uint32 Flags() const { return matcher_.Flags(); } + + void AddOpenParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad open paren label: 0"; + } else { + open_parens_.Insert(label); + } + } + + void AddCloseParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad close paren label: 0"; + } else { + close_parens_.Insert(label); + } + } + + void RemoveOpenParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad open paren label: 0"; + } else { + open_parens_.Erase(label); + } + } + + void RemoveCloseParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad close paren label: 0"; + } else { + close_parens_.Erase(label); + } + } + + void ClearOpenParens() { + open_parens_.Clear(); + } + + void ClearCloseParens() { + close_parens_.Clear(); + } + + bool IsOpenParen(Label label) const { + return open_parens_.Member(label); + } + + bool IsCloseParen(Label label) const { + return close_parens_.Member(label); + } + + private: + // Advances matcher to next open paren if it exists, returning true. + // O.w. returns false. + bool NextOpenParen(); + + // Advances matcher to next open paren if it exists, returning true. + // O.w. returns false. + bool NextCloseParen(); + + M matcher_; + MatchType match_type_; // Type of match to perform + uint32 flags_; + + // open paren label set + CompactSet<Label, kNoLabel> open_parens_; + + // close paren label set + CompactSet<Label, kNoLabel> close_parens_; + + + bool open_paren_list_; // Matching open paren list + bool close_paren_list_; // Matching close paren list + bool paren_loop_; // Current arc is the implicit paren loop + mutable Arc loop_; // For non-consuming symbols + bool done_; // Matching done + + void operator=(const ParenMatcher<F> &); // Disallow +}; + +template <class M> inline +bool ParenMatcher<M>::Find(Label match_label) { + open_paren_list_ = false; + close_paren_list_ = false; + paren_loop_ = false; + done_ = false; + + // Returns all parenthesis arcs + if (match_label == kNoLabel && (flags_ & kParenList)) { + if (open_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(open_parens_.LowerBound()); + open_paren_list_ = NextOpenParen(); + if (open_paren_list_) return true; + } + if (close_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(close_parens_.LowerBound()); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return true; + } + } + + // Returns 'implicit' paren loop + if (match_label > 0 && (flags_ & kParenLoop) && + (IsOpenParen(match_label) || IsCloseParen(match_label))) { + paren_loop_ = true; + return true; + } + + // Returns all other labels + if (matcher_.Find(match_label)) + return true; + + done_ = true; + return false; +} + +template <class F> inline +void ParenMatcher<F>::Next() { + if (paren_loop_) { + paren_loop_ = false; + done_ = true; + } else if (open_paren_list_) { + matcher_.Next(); + open_paren_list_ = NextOpenParen(); + if (open_paren_list_) return; + + if (close_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(close_parens_.LowerBound()); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return; + } + done_ = !matcher_.Find(kNoLabel); + } else if (close_paren_list_) { + matcher_.Next(); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return; + done_ = !matcher_.Find(kNoLabel); + } else { + matcher_.Next(); + done_ = matcher_.Done(); + } +} + +// Advances matcher to next open paren if it exists, returning true. +// O.w. returns false. +template <class F> inline +bool ParenMatcher<F>::NextOpenParen() { + for (; !matcher_.Done(); matcher_.Next()) { + Label label = match_type_ == MATCH_INPUT ? + matcher_.Value().ilabel : matcher_.Value().olabel; + if (label > open_parens_.UpperBound()) + return false; + if (IsOpenParen(label)) + return true; + } + return false; +} + +// Advances matcher to next close paren if it exists, returning true. +// O.w. returns false. +template <class F> inline +bool ParenMatcher<F>::NextCloseParen() { + for (; !matcher_.Done(); matcher_.Next()) { + Label label = match_type_ == MATCH_INPUT ? + matcher_.Value().ilabel : matcher_.Value().olabel; + if (label > close_parens_.UpperBound()) + return false; + if (IsCloseParen(label)) + return true; + } + return false; +} + + +template <class F> +class ParenFilter { + public: + typedef typename F::FST1 FST1; + typedef typename F::FST2 FST2; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename F::Matcher1 Matcher1; + typedef typename F::Matcher2 Matcher2; + typedef typename F::FilterState FilterState1; + typedef StateId StackId; + typedef PdtStack<StackId, Label> ParenStack; + typedef IntegerFilterState<StackId> FilterState2; + typedef PairFilterState<FilterState1, FilterState2> FilterState; + typedef ParenFilter<F> Filter; + + ParenFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = 0, Matcher2 *matcher2 = 0, + const vector<pair<Label, Label> > *parens = 0, + bool expand = false, bool keep_parens = true) + : filter_(fst1, fst2, matcher1, matcher2), + parens_(parens ? *parens : vector<pair<Label, Label> >()), + expand_(expand), + keep_parens_(keep_parens), + f_(FilterState::NoState()), + stack_(parens_), + paren_id_(-1) { + if (parens) { + for (size_t i = 0; i < parens->size(); ++i) { + const pair<Label, Label> &p = (*parens)[i]; + parens_.push_back(p); + GetMatcher1()->AddOpenParen(p.first); + GetMatcher2()->AddOpenParen(p.first); + if (!expand_) { + GetMatcher1()->AddCloseParen(p.second); + GetMatcher2()->AddCloseParen(p.second); + } + } + } + } + + ParenFilter(const Filter &filter, bool safe = false) + : filter_(filter.filter_, safe), + parens_(filter.parens_), + expand_(filter.expand_), + keep_parens_(filter.keep_parens_), + f_(FilterState::NoState()), + stack_(filter.parens_), + paren_id_(-1) { } + + FilterState Start() const { + return FilterState(filter_.Start(), FilterState2(0)); + } + + void SetState(StateId s1, StateId s2, const FilterState &f) { + f_ = f; + filter_.SetState(s1, s2, f_.GetState1()); + if (!expand_) + return; + + ssize_t paren_id = stack_.Top(f.GetState2().GetState()); + if (paren_id != paren_id_) { + if (paren_id_ != -1) { + GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second); + GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second); + } + paren_id_ = paren_id; + if (paren_id_ != -1) { + GetMatcher1()->AddCloseParen(parens_[paren_id_].second); + GetMatcher2()->AddCloseParen(parens_[paren_id_].second); + } + } + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + FilterState1 f1 = filter_.FilterArc(arc1, arc2); + const FilterState2 &f2 = f_.GetState2(); + if (f1 == FilterState1::NoState()) + return FilterState::NoState(); + + if (arc1->olabel == kNoLabel && arc2->ilabel) { // arc2 parentheses + if (keep_parens_) { + arc1->ilabel = arc2->ilabel; + } else if (arc2->ilabel) { + arc2->olabel = arc1->ilabel; + } + return FilterParen(arc2->ilabel, f1, f2); + } else if (arc2->ilabel == kNoLabel && arc1->olabel) { // arc1 parentheses + if (keep_parens_) { + arc2->olabel = arc1->olabel; + } else { + arc1->ilabel = arc2->olabel; + } + return FilterParen(arc1->olabel, f1, f2); + } else { + return FilterState(f1, f2); + } + } + + void FilterFinal(Weight *w1, Weight *w2) const { + if (f_.GetState2().GetState() != 0) + *w1 = Weight::Zero(); + filter_.FilterFinal(w1, w2); + } + + // Return resp matchers. Ownership stays with filter. + Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); } + Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); } + + uint64 Properties(uint64 iprops) const { + uint64 oprops = filter_.Properties(iprops); + return oprops & kILabelInvariantProperties & kOLabelInvariantProperties; + } + + private: + const FilterState FilterParen(Label label, const FilterState1 &f1, + const FilterState2 &f2) const { + if (!expand_) + return FilterState(f1, f2); + + StackId stack_id = stack_.Find(f2.GetState(), label); + if (stack_id < 0) { + return FilterState::NoState(); + } else { + return FilterState(f1, FilterState2(stack_id)); + } + } + + F filter_; + vector<pair<Label, Label> > parens_; + bool expand_; // Expands to FST + bool keep_parens_; // Retains parentheses in output + FilterState f_; // Current filter state + mutable ParenStack stack_; + ssize_t paren_id_; +}; + // Class to setup composition options for PDT composition. // Default is for the PDT as the first composition argument. template <class Arc, bool left_pdt = true> -class PdtComposeOptions : public +class PdtComposeFstOptions : public ComposeFstOptions<Arc, - MultiEpsMatcher< Matcher<Fst<Arc> > >, - MultiEpsFilter<AltSequenceComposeFilter< - MultiEpsMatcher< - Matcher<Fst<Arc> > > > > > { + ParenMatcher< Fst<Arc> >, + ParenFilter<AltSequenceComposeFilter< + ParenMatcher< Fst<Arc> > > > > { public: typedef typename Arc::Label Label; - typedef MultiEpsMatcher< Matcher<Fst<Arc> > > PdtMatcher; - typedef MultiEpsFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter; + typedef ParenMatcher< Fst<Arc> > PdtMatcher; + typedef ParenFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter; typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions; using COptions::matcher1; using COptions::matcher2; using COptions::filter; - PdtComposeOptions(const Fst<Arc> &ifst1, + PdtComposeFstOptions(const Fst<Arc> &ifst1, const vector<pair<Label, Label> > &parens, - const Fst<Arc> &ifst2) { - matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsList); - matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsLoop); - - // Treat parens as multi-epsilons when composing. - for (size_t i = 0; i < parens.size(); ++i) { - matcher1->AddMultiEpsLabel(parens[i].first); - matcher1->AddMultiEpsLabel(parens[i].second); - matcher2->AddMultiEpsLabel(parens[i].first); - matcher2->AddMultiEpsLabel(parens[i].second); - } + const Fst<Arc> &ifst2, bool expand = false, + bool keep_parens = true) { + matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenList); + matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenLoop); - filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true); + filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, + expand, keep_parens); } }; // Class to setup composition options for PDT with FST composition. // Specialization is for the FST as the first composition argument. template <class Arc> -class PdtComposeOptions<Arc, false> : public +class PdtComposeFstOptions<Arc, false> : public ComposeFstOptions<Arc, - MultiEpsMatcher< Matcher<Fst<Arc> > >, - MultiEpsFilter<SequenceComposeFilter< - MultiEpsMatcher< - Matcher<Fst<Arc> > > > > > { + ParenMatcher< Fst<Arc> >, + ParenFilter<SequenceComposeFilter< + ParenMatcher< Fst<Arc> > > > > { public: typedef typename Arc::Label Label; - typedef MultiEpsMatcher< Matcher<Fst<Arc> > > PdtMatcher; - typedef MultiEpsFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter; + typedef ParenMatcher< Fst<Arc> > PdtMatcher; + typedef ParenFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter; typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions; using COptions::matcher1; using COptions::matcher2; using COptions::filter; - PdtComposeOptions(const Fst<Arc> &ifst1, - const Fst<Arc> &ifst2, - const vector<pair<Label, Label> > &parens) { - matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsLoop); - matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsList); - - // Treat parens as multi-epsilons when composing. - for (size_t i = 0; i < parens.size(); ++i) { - matcher1->AddMultiEpsLabel(parens[i].first); - matcher1->AddMultiEpsLabel(parens[i].second); - matcher2->AddMultiEpsLabel(parens[i].first); - matcher2->AddMultiEpsLabel(parens[i].second); - } + PdtComposeFstOptions(const Fst<Arc> &ifst1, + const Fst<Arc> &ifst2, + const vector<pair<Label, Label> > &parens, + bool expand = false, bool keep_parens = true) { + matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenLoop); + matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenList); - filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true); + filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, + expand, keep_parens); } }; +enum PdtComposeFilter { + PAREN_FILTER, // Bar-Hillel construction; keeps parentheses + EXPAND_FILTER, // Bar-Hillel + expansion; removes parentheses + EXPAND_PAREN_FILTER, // Bar-Hillel + expansion; keeps parentheses +}; + +struct PdtComposeOptions { + bool connect; // Connect output + PdtComposeFilter filter_type; // Which pre-defined filter to use + + explicit PdtComposeOptions(bool c, PdtComposeFilter ft = PAREN_FILTER) + : connect(c), filter_type(ft) {} + PdtComposeOptions() : connect(true), filter_type(PAREN_FILTER) {} +}; // Composes pushdown transducer (PDT) encoded as an FST (1st arg) and // an FST (2nd arg) with the result also a PDT encoded as an Fst. (3rd arg). @@ -110,16 +497,17 @@ void Compose(const Fst<Arc> &ifst1, typename Arc::Label> > &parens, const Fst<Arc> &ifst2, MutableFst<Arc> *ofst, - const ComposeOptions &opts = ComposeOptions()) { - - PdtComposeOptions<Arc, true> copts(ifst1, parens, ifst2); + const PdtComposeOptions &opts = PdtComposeOptions()) { + bool expand = opts.filter_type != PAREN_FILTER; + bool keep_parens = opts.filter_type != EXPAND_FILTER; + PdtComposeFstOptions<Arc, true> copts(ifst1, parens, ifst2, + expand, keep_parens); copts.gc_limit = 0; *ofst = ComposeFst<Arc>(ifst1, ifst2, copts); if (opts.connect) Connect(ofst); } - // Composes an FST (1st arg) and pushdown transducer (PDT) encoded as // an FST (2nd arg) with the result also a PDT encoded as an Fst (3rd arg). // In the PDTs, some transitions are labeled with open or close @@ -132,9 +520,11 @@ void Compose(const Fst<Arc> &ifst1, const vector<pair<typename Arc::Label, typename Arc::Label> > &parens, MutableFst<Arc> *ofst, - const ComposeOptions &opts = ComposeOptions()) { - - PdtComposeOptions<Arc, false> copts(ifst1, ifst2, parens); + const PdtComposeOptions &opts = PdtComposeOptions()) { + bool expand = opts.filter_type != PAREN_FILTER; + bool keep_parens = opts.filter_type != EXPAND_FILTER; + PdtComposeFstOptions<Arc, false> copts(ifst1, ifst2, parens, + expand, keep_parens); copts.gc_limit = 0; *ofst = ComposeFst<Arc>(ifst1, ifst2, copts); if (opts.connect) diff --git a/src/include/fst/extensions/pdt/pdt.h b/src/include/fst/extensions/pdt/pdt.h index 6649f55..c56afbd 100644 --- a/src/include/fst/extensions/pdt/pdt.h +++ b/src/include/fst/extensions/pdt/pdt.h @@ -27,6 +27,7 @@ using std::tr1::unordered_multimap; #include <map> #include <set> +#include <fst/compat.h> #include <fst/state-table.h> #include <fst/fst.h> diff --git a/src/include/fst/extensions/pdt/pdtscript.h b/src/include/fst/extensions/pdt/pdtscript.h index c2a1cf4..84bb27e 100644 --- a/src/include/fst/extensions/pdt/pdtscript.h +++ b/src/include/fst/extensions/pdt/pdtscript.h @@ -48,7 +48,7 @@ typedef args::Package<const FstClass &, const FstClass &, const vector<pair<int64, int64> >&, MutableFstClass *, - const ComposeOptions &, + const PdtComposeOptions &, bool> PdtComposeArgs; template<class Arc> @@ -76,7 +76,7 @@ void PdtCompose(const FstClass & ifst1, const FstClass & ifst2, const vector<pair<int64, int64> > &parens, MutableFstClass *ofst, - const ComposeOptions &copts, + const PdtComposeOptions &copts, bool left_pdt); // PDT EXPAND diff --git a/src/include/fst/extensions/pdt/replace.h b/src/include/fst/extensions/pdt/replace.h index a85d0fe..9081400 100644 --- a/src/include/fst/extensions/pdt/replace.h +++ b/src/include/fst/extensions/pdt/replace.h @@ -21,6 +21,10 @@ #ifndef FST_EXTENSIONS_PDT_REPLACE_H__ #define FST_EXTENSIONS_PDT_REPLACE_H__ +#include <tr1/unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; + #include <fst/replace.h> namespace fst { @@ -62,11 +66,14 @@ void Replace(const vector<pair<typename Arc::Label, label2id[ifst_array[i].first] = i; Label max_label = kNoLabel; + size_t max_non_term_count = 0; - deque<size_t> non_term_queue; // Queue of non-terminals to replace - unordered_set<Label> non_term_set; // Set of non-terminals to replace + // Queue of non-terminals to replace + deque<size_t> non_term_queue; + // Map of non-terminals to replace to count + unordered_map<Label, size_t> non_term_map; non_term_queue.push_back(root); - non_term_set.insert(root); + non_term_map[root] = 1;; // PDT state corr. to ith replace FST start state. vector<StateId> fst_start(ifst_array.size(), kNoLabel); @@ -107,10 +114,11 @@ void Replace(const vector<pair<typename Arc::Label, size_t nfst_id = it->second; if (ifst_array[nfst_id].second->Start() == -1) continue; - if (non_term_set.count(arc.olabel) == 0) { + size_t count = non_term_map[arc.olabel]++; + if (count == 0) non_term_queue.push_back(arc.olabel); - non_term_set.insert(arc.olabel); - } + if (count > max_non_term_count) + max_non_term_count = count; } arc.nextstate += soff; ofst->AddArc(os, arc); @@ -134,7 +142,8 @@ void Replace(const vector<pair<typename Arc::Label, // # of parenthesis pairs per fst. vector<size_t> nparens(ifst_array.size(), 0); // Initial open parenthesis label - Label first_paren = max_label + 1; + Label first_open_paren = max_label + 1; + Label first_close_paren = max_label + max_non_term_count + 1; for (StateIterator< Fst<Arc> > siter(*ofst); !siter.Done(); siter.Next()) { @@ -158,8 +167,8 @@ void Replace(const vector<pair<typename Arc::Label, close_paren = (*parens)[paren_id].second; } else { size_t paren_id = nparens[nfst_id]++; - open_paren = first_paren + 2 * paren_id; - close_paren = open_paren + 1; + open_paren = first_open_paren + paren_id; + close_paren = first_close_paren + paren_id; paren_map[paren_key] = paren_id; if (paren_id >= parens->size()) parens->push_back(make_pair(open_paren, close_paren)); |