aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/extensions
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/extensions')
-rw-r--r--src/include/fst/extensions/far/extract.h119
-rw-r--r--src/include/fst/extensions/far/far.h3
-rw-r--r--src/include/fst/extensions/far/farscript.h12
-rw-r--r--src/include/fst/extensions/far/stlist.h22
-rw-r--r--src/include/fst/extensions/ngram/ngram-fst.h111
-rw-r--r--src/include/fst/extensions/pdt/compose.h486
-rw-r--r--src/include/fst/extensions/pdt/pdt.h1
-rw-r--r--src/include/fst/extensions/pdt/pdtscript.h4
-rw-r--r--src/include/fst/extensions/pdt/replace.h27
9 files changed, 657 insertions, 128 deletions
diff --git a/src/include/fst/extensions/far/extract.h b/src/include/fst/extensions/far/extract.h
index d6f92ff..95866de 100644
--- a/src/include/fst/extensions/far/extract.h
+++ b/src/include/fst/extensions/far/extract.h
@@ -32,51 +32,106 @@ using std::vector;
namespace fst {
template<class Arc>
+inline void FarWriteFst(const Fst<Arc>* fst, string key,
+ string* okey, int* nrep,
+ const int32 &generate_filenames, int i,
+ const string &filename_prefix,
+ const string &filename_suffix) {
+ if (key == *okey)
+ ++*nrep;
+ else
+ *nrep = 0;
+
+ *okey = key;
+
+ string ofilename;
+ if (generate_filenames) {
+ ostringstream tmp;
+ tmp.width(generate_filenames);
+ tmp.fill('0');
+ tmp << i;
+ ofilename = tmp.str();
+ } else {
+ if (*nrep > 0) {
+ ostringstream tmp;
+ tmp << '.' << nrep;
+ key.append(tmp.str().data(), tmp.str().size());
+ }
+ ofilename = key;
+ }
+ fst->Write(filename_prefix + ofilename + filename_suffix);
+}
+
+template<class Arc>
void FarExtract(const vector<string> &ifilenames,
const int32 &generate_filenames,
- const string &begin_key,
- const string &end_key,
+ const string &keys,
+ const string &key_separator,
+ const string &range_delimiter,
const string &filename_prefix,
const string &filename_suffix) {
FarReader<Arc> *far_reader = FarReader<Arc>::Open(ifilenames);
if (!far_reader) return;
- if (!begin_key.empty())
- far_reader->Find(begin_key);
-
string okey;
int nrep = 0;
- for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) {
- string key = far_reader->GetKey();
- if (!end_key.empty() && end_key < key)
- break;
- const Fst<Arc> &fst = far_reader->GetFst();
-
- if (key == okey)
- ++nrep;
- else
- nrep = 0;
- okey = key;
-
- string ofilename;
- if (generate_filenames) {
- ostringstream tmp;
- tmp.width(generate_filenames);
- tmp.fill('0');
- tmp << i;
- ofilename = tmp.str();
- } else {
- if (nrep > 0) {
- ostringstream tmp;
- tmp << '.' << nrep;
- key.append(tmp.str().data(), tmp.str().size());
+ vector<char *> key_vector;
+ // User has specified a set of fsts to extract, where some of the "fsts" could
+ // be ranges.
+ if (!keys.empty()) {
+ char *keys_cstr = new char[keys.size()+1];
+ strcpy(keys_cstr, keys.c_str());
+ SplitToVector(keys_cstr, key_separator.c_str(), &key_vector, true);
+ int i = 0;
+ for (int k = 0; k < key_vector.size(); ++k, ++i) {
+ string key = string(key_vector[k]);
+ char *key_cstr = new char[key.size()+1];
+ strcpy(key_cstr, key.c_str());
+ vector<char *> range_vector;
+ SplitToVector(key_cstr, range_delimiter.c_str(), &range_vector, false);
+ if (range_vector.size() == 1) { // Not a range
+ if (!far_reader->Find(key)) {
+ LOG(ERROR) << "FarExtract: Cannot find key: " << key;
+ return;
+ }
+ const Fst<Arc> &fst = far_reader->GetFst();
+ FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i,
+ filename_prefix, filename_suffix);
+ } else if (range_vector.size() == 2) { // A legal range
+ string begin_key = string(range_vector[0]);
+ string end_key = string(range_vector[1]);
+ if (begin_key.empty() || end_key.empty()) {
+ LOG(ERROR) << "FarExtract: Illegal range specification: " << key;
+ return;
+ }
+ if (!far_reader->Find(begin_key)) {
+ LOG(ERROR) << "FarExtract: Cannot find key: " << begin_key;
+ return;
+ }
+ for ( ; !far_reader->Done(); far_reader->Next(), ++i) {
+ string ikey = far_reader->GetKey();
+ if (end_key < ikey) break;
+ const Fst<Arc> &fst = far_reader->GetFst();
+ FarWriteFst(&fst, ikey, &okey, &nrep, generate_filenames, i,
+ filename_prefix, filename_suffix);
+ }
+ } else {
+ LOG(ERROR) << "FarExtract: Illegal range specification: " << key;
+ return;
}
- ofilename = key;
+ delete key_cstr;
}
- fst.Write(filename_prefix + ofilename + filename_suffix);
+ delete keys_cstr;
+ return;
+ }
+ // Nothing specified: extract everything.
+ for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) {
+ string key = far_reader->GetKey();
+ const Fst<Arc> &fst = far_reader->GetFst();
+ FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i,
+ filename_prefix, filename_suffix);
}
-
return;
}
diff --git a/src/include/fst/extensions/far/far.h b/src/include/fst/extensions/far/far.h
index acce76e..737f1b8 100644
--- a/src/include/fst/extensions/far/far.h
+++ b/src/include/fst/extensions/far/far.h
@@ -273,13 +273,10 @@ FarWriter<A> *FarWriter<A>::Create(const string &filename, FarType type) {
return STListFarWriter<A>::Create(filename);
case FAR_STTABLE:
return STTableFarWriter<A>::Create(filename);
- break;
case FAR_STLIST:
return STListFarWriter<A>::Create(filename);
- break;
case FAR_FST:
return FstFarWriter<A>::Create(filename);
- break;
default:
LOG(ERROR) << "FarWriter::Create: unknown far type";
return 0;
diff --git a/src/include/fst/extensions/far/farscript.h b/src/include/fst/extensions/far/farscript.h
index 3a9c145..cfd9167 100644
--- a/src/include/fst/extensions/far/farscript.h
+++ b/src/include/fst/extensions/far/farscript.h
@@ -173,18 +173,22 @@ bool FarEqual(const string &filename1,
typedef args::Package<const vector<string> &, int32,
const string&, const string&, const string&,
- const string&> FarExtractArgs;
+ const string&, const string&> FarExtractArgs;
template<class Arc>
void FarExtract(FarExtractArgs *args) {
fst::FarExtract<Arc>(
- args->arg1, args->arg2, args->arg3, args->arg4, args->arg5, args->arg6);
+ args->arg1, args->arg2, args->arg3, args->arg4, args->arg5, args->arg6,
+ args->arg7);
}
void FarExtract(const vector<string> &ifilenames,
const string &arc_type,
- int32 generate_filenames, const string &begin_key,
- const string &end_key, const string &filename_prefix,
+ int32 generate_filenames,
+ const string &keys,
+ const string &key_separator,
+ const string &range_delimiter,
+ const string &filename_prefix,
const string &filename_suffix);
typedef args::Package<const vector<string> &, const string &,
diff --git a/src/include/fst/extensions/far/stlist.h b/src/include/fst/extensions/far/stlist.h
index 1cdc80c..ff3d98b 100644
--- a/src/include/fst/extensions/far/stlist.h
+++ b/src/include/fst/extensions/far/stlist.h
@@ -145,13 +145,13 @@ class STListReader {
ReadType(*streams_[i], &magic_number);
ReadType(*streams_[i], &file_version);
if (magic_number != kSTListMagicNumber) {
- FSTERROR() << "STListReader::STTableReader: wrong file type: "
+ FSTERROR() << "STListReader::STListReader: wrong file type: "
<< filenames[i];
error_ = true;
return;
}
if (file_version != kSTListFileVersion) {
- FSTERROR() << "STListReader::STTableReader: wrong file version: "
+ FSTERROR() << "STListReader::STListReader: wrong file version: "
<< filenames[i];
error_ = true;
return;
@@ -161,7 +161,7 @@ class STListReader {
if (!key.empty())
heap_.push(make_pair(key, i));
if (!*streams_[i]) {
- FSTERROR() << "STTableReader: error reading file: " << sources_[i];
+ FSTERROR() << "STListReader: error reading file: " << sources_[i];
error_ = true;
return;
}
@@ -170,7 +170,7 @@ class STListReader {
size_t current = heap_.top().second;
entry_ = entry_reader_(*streams_[current]);
if (!entry_ || !*streams_[current]) {
- FSTERROR() << "STTableReader: error reading entry for key: "
+ FSTERROR() << "STListReader: error reading entry for key: "
<< heap_.top().first << ", file: " << sources_[current];
error_ = true;
}
@@ -219,7 +219,7 @@ class STListReader {
heap_.pop();
ReadType(*(streams_[current]), &key);
if (!*streams_[current]) {
- FSTERROR() << "STTableReader: error reading file: "
+ FSTERROR() << "STListReader: error reading file: "
<< sources_[current];
error_ = true;
return;
@@ -233,7 +233,7 @@ class STListReader {
delete entry_;
entry_ = entry_reader_(*streams_[current]);
if (!entry_ || !*streams_[current]) {
- FSTERROR() << "STTableReader: error reading entry for key: "
+ FSTERROR() << "STListReader: error reading entry for key: "
<< heap_.top().first << ", file: " << sources_[current];
error_ = true;
}
@@ -267,8 +267,8 @@ class STListReader {
// String-type list header reading function template on the entry header
// type 'H' having a member function:
// Read(istream &strm, const string &filename);
-// Checks that 'filename' is an STTable and call the H::Read() on the last
-// entry in the STTable.
+// Checks that 'filename' is an STList and call the H::Read() on the last
+// entry in the STList.
// Does not support reading from stdin.
template <class H>
bool ReadSTListHeader(const string &filename, H *header) {
@@ -281,18 +281,18 @@ bool ReadSTListHeader(const string &filename, H *header) {
ReadType(strm, &magic_number);
ReadType(strm, &file_version);
if (magic_number != kSTListMagicNumber) {
- LOG(ERROR) << "ReadSTTableHeader: wrong file type: " << filename;
+ LOG(ERROR) << "ReadSTListHeader: wrong file type: " << filename;
return false;
}
if (file_version != kSTListFileVersion) {
- LOG(ERROR) << "ReadSTTableHeader: wrong file version: " << filename;
+ LOG(ERROR) << "ReadSTListHeader: wrong file version: " << filename;
return false;
}
string key;
ReadType(strm, &key);
header->Read(strm, filename + ":" + key);
if (!strm) {
- LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename;
+ LOG(ERROR) << "ReadSTListHeader: error reading file: " << filename;
return false;
}
return true;
diff --git a/src/include/fst/extensions/ngram/ngram-fst.h b/src/include/fst/extensions/ngram/ngram-fst.h
index eee664a..873ae6a 100644
--- a/src/include/fst/extensions/ngram/ngram-fst.h
+++ b/src/include/fst/extensions/ngram/ngram-fst.h
@@ -26,6 +26,7 @@ using std::vector;
#include <fst/compat.h>
#include <fst/fstlib.h>
+#include <fst/mapped-file.h>
#include <fst/extensions/ngram/bitmap-index.h>
// NgramFst implements a n-gram language model based upon the LOUDS data
@@ -76,7 +77,7 @@ class NGramFstImpl : public FstImpl<A> {
typedef typename A::StateId StateId;
typedef typename A::Weight Weight;
- NGramFstImpl() : data_(0), owned_(false) {
+ NGramFstImpl() : data_region_(0), data_(0), owned_(false) {
SetType("ngram");
SetInputSymbols(NULL);
SetOutputSymbols(NULL);
@@ -89,6 +90,7 @@ class NGramFstImpl : public FstImpl<A> {
if (owned_) {
delete [] data_;
}
+ delete data_region_;
}
static NGramFstImpl<A>* Read(istream &strm, // NOLINT
@@ -104,7 +106,8 @@ class NGramFstImpl : public FstImpl<A> {
strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
size_t size = Storage(num_states, num_futures, num_final);
- char* data = new char[size];
+ MappedFile *data_region = MappedFile::Allocate(size);
+ char *data = reinterpret_cast<char *>(data_region->mutable_data());
// Copy num_states, num_futures and num_final back into data.
memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
@@ -116,7 +119,7 @@ class NGramFstImpl : public FstImpl<A> {
delete impl;
return NULL;
}
- impl->Init(data, true /* owned */);
+ impl->Init(data, false, data_region);
return impl;
}
@@ -126,7 +129,7 @@ class NGramFstImpl : public FstImpl<A> {
hdr.SetStart(Start());
hdr.SetNumStates(num_states_);
WriteHeader(strm, opts, kFileVersion, &hdr);
- strm.write(data_, Storage(num_states_, num_futures_, num_final_));
+ strm.write(data_, StorageSize());
return strm;
}
@@ -223,11 +226,23 @@ class NGramFstImpl : public FstImpl<A> {
// Access to the underlying representation
const char* GetData(size_t* data_size) const {
- *data_size = Storage(num_states_, num_futures_, num_final_);
+ *data_size = StorageSize();
return data_;
}
- void Init(const char* data, bool owned);
+ void Init(const char* data, bool owned, MappedFile *file = 0);
+
+ const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const {
+ SetInstFuture(s, inst);
+ SetInstContext(inst);
+ return inst->context_;
+ }
+
+ size_t StorageSize() const {
+ return Storage(num_states_, num_futures_, num_final_);
+ }
+
+ void GetStates(const vector<Label>& context, vector<StateId> *states) const;
private:
StateId Transition(const vector<Label> &context, Label future) const;
@@ -242,6 +257,7 @@ class NGramFstImpl : public FstImpl<A> {
// Minimum file format version supported.
static const int kMinFileVersion = 4;
+ MappedFile *data_region_;
const char* data_;
bool owned_; // True if we own data_
uint64 num_states_, num_futures_, num_final_;
@@ -261,7 +277,7 @@ class NGramFstImpl : public FstImpl<A> {
template<typename A>
NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
- : data_(0), owned_(false) {
+ : data_region_(0), data_(0), owned_(false) {
typedef A Arc;
typedef typename Arc::Label Label;
typedef typename Arc::Weight Weight;
@@ -286,12 +302,16 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
// epsilons.
StateId unigram = fst.Start();
while (1) {
- ArcIterator<Fst<A> > aiter(fst, unigram);
- if (aiter.Done()) {
- FSTERROR() << "Start state has no arcs";
+ if (unigram == kNoStateId) {
+ FSTERROR() << "Could not identify unigram state.";
SetProperties(kError, kError);
return;
}
+ ArcIterator<Fst<A> > aiter(fst, unigram);
+ if (aiter.Done()) {
+ LOG(WARNING) << "Unigram state " << unigram << " has no arcs.";
+ break;
+ }
if (aiter.Value().ilabel != 0) break;
unigram = aiter.Value().nextstate;
}
@@ -385,7 +405,8 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
Weight weight;
Label label = kNoLabel;
const size_t storage = Storage(num_states, num_futures, num_final);
- char* data = new char[storage];
+ MappedFile *data_region = MappedFile::Allocate(storage);
+ char *data = reinterpret_cast<char *>(data_region->mutable_data());
memset(data, 0, storage);
size_t offset = 0;
memcpy(data + offset, reinterpret_cast<char *>(&num_states),
@@ -482,14 +503,17 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
return;
}
- Init(data, true /* owned */);
+ Init(data, false, data_region);
}
template<typename A>
-inline void NGramFstImpl<A>::Init(const char* data, bool owned) {
+inline void NGramFstImpl<A>::Init(const char* data, bool owned,
+ MappedFile *data_region) {
if (owned_) {
delete [] data_;
}
+ delete data_region_;
+ data_region_ = data_region;
owned_ = owned;
data_ = data;
size_t offset = 0;
@@ -507,7 +531,7 @@ inline void NGramFstImpl<A>::Init(const char* data, bool owned) {
future_ = reinterpret_cast<const uint64*>(data_ + offset);
offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits);
final_ = reinterpret_cast<const uint64*>(data_ + offset);
- offset += BitmapIndex::StorageSize(num_states_ + 1) * sizeof(bits);
+ offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits);
context_words_ = reinterpret_cast<const Label*>(data_ + offset);
offset += (num_states_ + 1) * sizeof(*context_words_);
future_words_ = reinterpret_cast<const Label*>(data_ + offset);
@@ -538,10 +562,10 @@ inline void NGramFstImpl<A>::Init(const char* data, bool owned) {
template<typename A>
inline typename A::StateId NGramFstImpl<A>::Transition(
const vector<Label> &context, Label future) const {
- size_t num_children = root_num_children_;
const Label *children = root_children_;
- const Label *loc = lower_bound(children, children + num_children, future);
- if (loc == children + num_children || *loc != future) {
+ const Label *loc = lower_bound(children, children + root_num_children_,
+ future);
+ if (loc == children + root_num_children_ || *loc != future) {
return context_index_.Rank1(0);
}
size_t node = root_first_child_ + loc - children;
@@ -551,7 +575,6 @@ inline typename A::StateId NGramFstImpl<A>::Transition(
return context_index_.Rank1(node);
}
size_t last_child = context_index_.Select0(node_rank + 1) - 1;
- num_children = last_child - first_child + 1;
for (int word = context.size() - 1; word >= 0; --word) {
children = context_words_ + context_index_.Rank1(first_child);
loc = lower_bound(children, children + last_child - first_child + 1,
@@ -569,6 +592,42 @@ inline typename A::StateId NGramFstImpl<A>::Transition(
return context_index_.Rank1(node);
}
+template<typename A>
+inline void NGramFstImpl<A>::GetStates(
+ const vector<Label> &context,
+ vector<typename A::StateId>* states) const {
+ states->clear();
+ states->push_back(0);
+ typename vector<Label>::const_reverse_iterator cit = context.rbegin();
+ const Label *children = root_children_;
+ const Label *loc = lower_bound(children, children + root_num_children_, *cit);
+ if (loc == children + root_num_children_ || *loc != *cit) return;
+ size_t node = root_first_child_ + loc - children;
+ states->push_back(context_index_.Rank1(node));
+ if (context.size() == 1) return;
+ size_t node_rank = context_index_.Rank1(node);
+ size_t first_child = context_index_.Select0(node_rank) + 1;
+ ++cit;
+ if (context_index_.Get(first_child) != false) {
+ size_t last_child = context_index_.Select0(node_rank + 1) - 1;
+ while (cit != context.rend()) {
+ children = context_words_ + context_index_.Rank1(first_child);
+ loc = lower_bound(children, children + last_child - first_child + 1,
+ *cit);
+ if (loc == children + last_child - first_child + 1 || *loc != *cit) {
+ break;
+ }
+ ++cit;
+ node = first_child + loc - children;
+ states->push_back(context_index_.Rank1(node));
+ node_rank = context_index_.Rank1(node);
+ first_child = context_index_.Select0(node_rank) + 1;
+ if (context_index_.Get(first_child) == false) break;
+ last_child = context_index_.Select0(node_rank + 1) - 1;
+ }
+ }
+}
+
/*****************************************************************************/
template<class A>
class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
@@ -597,7 +656,7 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
// Non-standard constructor to initialize NGramFst directly from data.
NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) {
- GetImpl()->Init(data, owned);
+ GetImpl()->Init(data, owned, NULL);
}
// Get method that gets the data associated with Init().
@@ -605,6 +664,16 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
return GetImpl()->GetData(data_size);
}
+ const vector<Label> GetContext(StateId s) const {
+ return GetImpl()->GetContext(s, &inst_);
+ }
+
+ // Consumes as much as possible of context from right to left, returns the
+ // the states corresponding to the increasingly conditioned input sequence.
+ void GetStates(const vector<Label>& context, vector<StateId> *state) const {
+ return GetImpl()->GetStates(context, state);
+ }
+
virtual size_t NumArcs(StateId s) const {
return GetImpl()->NumArcs(s, &inst_);
}
@@ -650,6 +719,10 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
return new NGramFstMatcher<A>(*this, match_type);
}
+ size_t StorageSize() const {
+ return GetImpl()->StorageSize();
+ }
+
private:
explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {}
diff --git a/src/include/fst/extensions/pdt/compose.h b/src/include/fst/extensions/pdt/compose.h
index 364d76f..c856c6d 100644
--- a/src/include/fst/extensions/pdt/compose.h
+++ b/src/include/fst/extensions/pdt/compose.h
@@ -21,82 +21,469 @@
#ifndef FST_EXTENSIONS_PDT_COMPOSE_H__
#define FST_EXTENSIONS_PDT_COMPOSE_H__
+#include <list>
+
+#include <fst/extensions/pdt/pdt.h>
#include <fst/compose.h>
namespace fst {
+// Return paren arcs for Find(kNoLabel).
+const uint32 kParenList = 0x00000001;
+
+// Return a kNolabel loop for Find(paren).
+const uint32 kParenLoop = 0x00000002;
+
+// This class is a matcher that treats parens as multi-epsilon labels.
+// It is most efficient if the parens are in a range non-overlapping with
+// the non-paren labels.
+template <class F>
+class ParenMatcher {
+ public:
+ typedef SortedMatcher<F> M;
+ typedef typename M::FST FST;
+ typedef typename M::Arc Arc;
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Label Label;
+ typedef typename Arc::Weight Weight;
+
+ ParenMatcher(const FST &fst, MatchType match_type,
+ uint32 flags = (kParenLoop | kParenList))
+ : matcher_(fst, match_type),
+ match_type_(match_type),
+ flags_(flags) {
+ if (match_type == MATCH_INPUT) {
+ loop_.ilabel = kNoLabel;
+ loop_.olabel = 0;
+ } else {
+ loop_.ilabel = 0;
+ loop_.olabel = kNoLabel;
+ }
+ loop_.weight = Weight::One();
+ loop_.nextstate = kNoStateId;
+ }
+
+ ParenMatcher(const ParenMatcher<F> &matcher, bool safe = false)
+ : matcher_(matcher.matcher_, safe),
+ match_type_(matcher.match_type_),
+ flags_(matcher.flags_),
+ open_parens_(matcher.open_parens_),
+ close_parens_(matcher.close_parens_),
+ loop_(matcher.loop_) {
+ loop_.nextstate = kNoStateId;
+ }
+
+ ParenMatcher<F> *Copy(bool safe = false) const {
+ return new ParenMatcher<F>(*this, safe);
+ }
+
+ MatchType Type(bool test) const { return matcher_.Type(test); }
+
+ void SetState(StateId s) {
+ matcher_.SetState(s);
+ loop_.nextstate = s;
+ }
+
+ bool Find(Label match_label);
+
+ bool Done() const {
+ return done_;
+ }
+
+ const Arc& Value() const {
+ return paren_loop_ ? loop_ : matcher_.Value();
+ }
+
+ void Next();
+
+ const FST &GetFst() const { return matcher_.GetFst(); }
+
+ uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
+
+ uint32 Flags() const { return matcher_.Flags(); }
+
+ void AddOpenParen(Label label) {
+ if (label == 0) {
+ FSTERROR() << "ParenMatcher: Bad open paren label: 0";
+ } else {
+ open_parens_.Insert(label);
+ }
+ }
+
+ void AddCloseParen(Label label) {
+ if (label == 0) {
+ FSTERROR() << "ParenMatcher: Bad close paren label: 0";
+ } else {
+ close_parens_.Insert(label);
+ }
+ }
+
+ void RemoveOpenParen(Label label) {
+ if (label == 0) {
+ FSTERROR() << "ParenMatcher: Bad open paren label: 0";
+ } else {
+ open_parens_.Erase(label);
+ }
+ }
+
+ void RemoveCloseParen(Label label) {
+ if (label == 0) {
+ FSTERROR() << "ParenMatcher: Bad close paren label: 0";
+ } else {
+ close_parens_.Erase(label);
+ }
+ }
+
+ void ClearOpenParens() {
+ open_parens_.Clear();
+ }
+
+ void ClearCloseParens() {
+ close_parens_.Clear();
+ }
+
+ bool IsOpenParen(Label label) const {
+ return open_parens_.Member(label);
+ }
+
+ bool IsCloseParen(Label label) const {
+ return close_parens_.Member(label);
+ }
+
+ private:
+ // Advances matcher to next open paren if it exists, returning true.
+ // O.w. returns false.
+ bool NextOpenParen();
+
+ // Advances matcher to next open paren if it exists, returning true.
+ // O.w. returns false.
+ bool NextCloseParen();
+
+ M matcher_;
+ MatchType match_type_; // Type of match to perform
+ uint32 flags_;
+
+ // open paren label set
+ CompactSet<Label, kNoLabel> open_parens_;
+
+ // close paren label set
+ CompactSet<Label, kNoLabel> close_parens_;
+
+
+ bool open_paren_list_; // Matching open paren list
+ bool close_paren_list_; // Matching close paren list
+ bool paren_loop_; // Current arc is the implicit paren loop
+ mutable Arc loop_; // For non-consuming symbols
+ bool done_; // Matching done
+
+ void operator=(const ParenMatcher<F> &); // Disallow
+};
+
+template <class M> inline
+bool ParenMatcher<M>::Find(Label match_label) {
+ open_paren_list_ = false;
+ close_paren_list_ = false;
+ paren_loop_ = false;
+ done_ = false;
+
+ // Returns all parenthesis arcs
+ if (match_label == kNoLabel && (flags_ & kParenList)) {
+ if (open_parens_.LowerBound() != kNoLabel) {
+ matcher_.LowerBound(open_parens_.LowerBound());
+ open_paren_list_ = NextOpenParen();
+ if (open_paren_list_) return true;
+ }
+ if (close_parens_.LowerBound() != kNoLabel) {
+ matcher_.LowerBound(close_parens_.LowerBound());
+ close_paren_list_ = NextCloseParen();
+ if (close_paren_list_) return true;
+ }
+ }
+
+ // Returns 'implicit' paren loop
+ if (match_label > 0 && (flags_ & kParenLoop) &&
+ (IsOpenParen(match_label) || IsCloseParen(match_label))) {
+ paren_loop_ = true;
+ return true;
+ }
+
+ // Returns all other labels
+ if (matcher_.Find(match_label))
+ return true;
+
+ done_ = true;
+ return false;
+}
+
+template <class F> inline
+void ParenMatcher<F>::Next() {
+ if (paren_loop_) {
+ paren_loop_ = false;
+ done_ = true;
+ } else if (open_paren_list_) {
+ matcher_.Next();
+ open_paren_list_ = NextOpenParen();
+ if (open_paren_list_) return;
+
+ if (close_parens_.LowerBound() != kNoLabel) {
+ matcher_.LowerBound(close_parens_.LowerBound());
+ close_paren_list_ = NextCloseParen();
+ if (close_paren_list_) return;
+ }
+ done_ = !matcher_.Find(kNoLabel);
+ } else if (close_paren_list_) {
+ matcher_.Next();
+ close_paren_list_ = NextCloseParen();
+ if (close_paren_list_) return;
+ done_ = !matcher_.Find(kNoLabel);
+ } else {
+ matcher_.Next();
+ done_ = matcher_.Done();
+ }
+}
+
+// Advances matcher to next open paren if it exists, returning true.
+// O.w. returns false.
+template <class F> inline
+bool ParenMatcher<F>::NextOpenParen() {
+ for (; !matcher_.Done(); matcher_.Next()) {
+ Label label = match_type_ == MATCH_INPUT ?
+ matcher_.Value().ilabel : matcher_.Value().olabel;
+ if (label > open_parens_.UpperBound())
+ return false;
+ if (IsOpenParen(label))
+ return true;
+ }
+ return false;
+}
+
+// Advances matcher to next close paren if it exists, returning true.
+// O.w. returns false.
+template <class F> inline
+bool ParenMatcher<F>::NextCloseParen() {
+ for (; !matcher_.Done(); matcher_.Next()) {
+ Label label = match_type_ == MATCH_INPUT ?
+ matcher_.Value().ilabel : matcher_.Value().olabel;
+ if (label > close_parens_.UpperBound())
+ return false;
+ if (IsCloseParen(label))
+ return true;
+ }
+ return false;
+}
+
+
+template <class F>
+class ParenFilter {
+ public:
+ typedef typename F::FST1 FST1;
+ typedef typename F::FST2 FST2;
+ typedef typename F::Arc Arc;
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Label Label;
+ typedef typename Arc::Weight Weight;
+ typedef typename F::Matcher1 Matcher1;
+ typedef typename F::Matcher2 Matcher2;
+ typedef typename F::FilterState FilterState1;
+ typedef StateId StackId;
+ typedef PdtStack<StackId, Label> ParenStack;
+ typedef IntegerFilterState<StackId> FilterState2;
+ typedef PairFilterState<FilterState1, FilterState2> FilterState;
+ typedef ParenFilter<F> Filter;
+
+ ParenFilter(const FST1 &fst1, const FST2 &fst2,
+ Matcher1 *matcher1 = 0, Matcher2 *matcher2 = 0,
+ const vector<pair<Label, Label> > *parens = 0,
+ bool expand = false, bool keep_parens = true)
+ : filter_(fst1, fst2, matcher1, matcher2),
+ parens_(parens ? *parens : vector<pair<Label, Label> >()),
+ expand_(expand),
+ keep_parens_(keep_parens),
+ f_(FilterState::NoState()),
+ stack_(parens_),
+ paren_id_(-1) {
+ if (parens) {
+ for (size_t i = 0; i < parens->size(); ++i) {
+ const pair<Label, Label> &p = (*parens)[i];
+ parens_.push_back(p);
+ GetMatcher1()->AddOpenParen(p.first);
+ GetMatcher2()->AddOpenParen(p.first);
+ if (!expand_) {
+ GetMatcher1()->AddCloseParen(p.second);
+ GetMatcher2()->AddCloseParen(p.second);
+ }
+ }
+ }
+ }
+
+ ParenFilter(const Filter &filter, bool safe = false)
+ : filter_(filter.filter_, safe),
+ parens_(filter.parens_),
+ expand_(filter.expand_),
+ keep_parens_(filter.keep_parens_),
+ f_(FilterState::NoState()),
+ stack_(filter.parens_),
+ paren_id_(-1) { }
+
+ FilterState Start() const {
+ return FilterState(filter_.Start(), FilterState2(0));
+ }
+
+ void SetState(StateId s1, StateId s2, const FilterState &f) {
+ f_ = f;
+ filter_.SetState(s1, s2, f_.GetState1());
+ if (!expand_)
+ return;
+
+ ssize_t paren_id = stack_.Top(f.GetState2().GetState());
+ if (paren_id != paren_id_) {
+ if (paren_id_ != -1) {
+ GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second);
+ GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second);
+ }
+ paren_id_ = paren_id;
+ if (paren_id_ != -1) {
+ GetMatcher1()->AddCloseParen(parens_[paren_id_].second);
+ GetMatcher2()->AddCloseParen(parens_[paren_id_].second);
+ }
+ }
+ }
+
+ FilterState FilterArc(Arc *arc1, Arc *arc2) const {
+ FilterState1 f1 = filter_.FilterArc(arc1, arc2);
+ const FilterState2 &f2 = f_.GetState2();
+ if (f1 == FilterState1::NoState())
+ return FilterState::NoState();
+
+ if (arc1->olabel == kNoLabel && arc2->ilabel) { // arc2 parentheses
+ if (keep_parens_) {
+ arc1->ilabel = arc2->ilabel;
+ } else if (arc2->ilabel) {
+ arc2->olabel = arc1->ilabel;
+ }
+ return FilterParen(arc2->ilabel, f1, f2);
+ } else if (arc2->ilabel == kNoLabel && arc1->olabel) { // arc1 parentheses
+ if (keep_parens_) {
+ arc2->olabel = arc1->olabel;
+ } else {
+ arc1->ilabel = arc2->olabel;
+ }
+ return FilterParen(arc1->olabel, f1, f2);
+ } else {
+ return FilterState(f1, f2);
+ }
+ }
+
+ void FilterFinal(Weight *w1, Weight *w2) const {
+ if (f_.GetState2().GetState() != 0)
+ *w1 = Weight::Zero();
+ filter_.FilterFinal(w1, w2);
+ }
+
+ // Return resp matchers. Ownership stays with filter.
+ Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); }
+ Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); }
+
+ uint64 Properties(uint64 iprops) const {
+ uint64 oprops = filter_.Properties(iprops);
+ return oprops & kILabelInvariantProperties & kOLabelInvariantProperties;
+ }
+
+ private:
+ const FilterState FilterParen(Label label, const FilterState1 &f1,
+ const FilterState2 &f2) const {
+ if (!expand_)
+ return FilterState(f1, f2);
+
+ StackId stack_id = stack_.Find(f2.GetState(), label);
+ if (stack_id < 0) {
+ return FilterState::NoState();
+ } else {
+ return FilterState(f1, FilterState2(stack_id));
+ }
+ }
+
+ F filter_;
+ vector<pair<Label, Label> > parens_;
+ bool expand_; // Expands to FST
+ bool keep_parens_; // Retains parentheses in output
+ FilterState f_; // Current filter state
+ mutable ParenStack stack_;
+ ssize_t paren_id_;
+};
+
// Class to setup composition options for PDT composition.
// Default is for the PDT as the first composition argument.
template <class Arc, bool left_pdt = true>
-class PdtComposeOptions : public
+class PdtComposeFstOptions : public
ComposeFstOptions<Arc,
- MultiEpsMatcher< Matcher<Fst<Arc> > >,
- MultiEpsFilter<AltSequenceComposeFilter<
- MultiEpsMatcher<
- Matcher<Fst<Arc> > > > > > {
+ ParenMatcher< Fst<Arc> >,
+ ParenFilter<AltSequenceComposeFilter<
+ ParenMatcher< Fst<Arc> > > > > {
public:
typedef typename Arc::Label Label;
- typedef MultiEpsMatcher< Matcher<Fst<Arc> > > PdtMatcher;
- typedef MultiEpsFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter;
+ typedef ParenMatcher< Fst<Arc> > PdtMatcher;
+ typedef ParenFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter;
typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
using COptions::matcher1;
using COptions::matcher2;
using COptions::filter;
- PdtComposeOptions(const Fst<Arc> &ifst1,
+ PdtComposeFstOptions(const Fst<Arc> &ifst1,
const vector<pair<Label, Label> > &parens,
- const Fst<Arc> &ifst2) {
- matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsList);
- matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsLoop);
-
- // Treat parens as multi-epsilons when composing.
- for (size_t i = 0; i < parens.size(); ++i) {
- matcher1->AddMultiEpsLabel(parens[i].first);
- matcher1->AddMultiEpsLabel(parens[i].second);
- matcher2->AddMultiEpsLabel(parens[i].first);
- matcher2->AddMultiEpsLabel(parens[i].second);
- }
+ const Fst<Arc> &ifst2, bool expand = false,
+ bool keep_parens = true) {
+ matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenList);
+ matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenLoop);
- filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true);
+ filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
+ expand, keep_parens);
}
};
// Class to setup composition options for PDT with FST composition.
// Specialization is for the FST as the first composition argument.
template <class Arc>
-class PdtComposeOptions<Arc, false> : public
+class PdtComposeFstOptions<Arc, false> : public
ComposeFstOptions<Arc,
- MultiEpsMatcher< Matcher<Fst<Arc> > >,
- MultiEpsFilter<SequenceComposeFilter<
- MultiEpsMatcher<
- Matcher<Fst<Arc> > > > > > {
+ ParenMatcher< Fst<Arc> >,
+ ParenFilter<SequenceComposeFilter<
+ ParenMatcher< Fst<Arc> > > > > {
public:
typedef typename Arc::Label Label;
- typedef MultiEpsMatcher< Matcher<Fst<Arc> > > PdtMatcher;
- typedef MultiEpsFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter;
+ typedef ParenMatcher< Fst<Arc> > PdtMatcher;
+ typedef ParenFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter;
typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
using COptions::matcher1;
using COptions::matcher2;
using COptions::filter;
- PdtComposeOptions(const Fst<Arc> &ifst1,
- const Fst<Arc> &ifst2,
- const vector<pair<Label, Label> > &parens) {
- matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsLoop);
- matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsList);
-
- // Treat parens as multi-epsilons when composing.
- for (size_t i = 0; i < parens.size(); ++i) {
- matcher1->AddMultiEpsLabel(parens[i].first);
- matcher1->AddMultiEpsLabel(parens[i].second);
- matcher2->AddMultiEpsLabel(parens[i].first);
- matcher2->AddMultiEpsLabel(parens[i].second);
- }
+ PdtComposeFstOptions(const Fst<Arc> &ifst1,
+ const Fst<Arc> &ifst2,
+ const vector<pair<Label, Label> > &parens,
+ bool expand = false, bool keep_parens = true) {
+ matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenLoop);
+ matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenList);
- filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true);
+ filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
+ expand, keep_parens);
}
};
+enum PdtComposeFilter {
+ PAREN_FILTER, // Bar-Hillel construction; keeps parentheses
+ EXPAND_FILTER, // Bar-Hillel + expansion; removes parentheses
+ EXPAND_PAREN_FILTER, // Bar-Hillel + expansion; keeps parentheses
+};
+
+struct PdtComposeOptions {
+ bool connect; // Connect output
+ PdtComposeFilter filter_type; // Which pre-defined filter to use
+
+ explicit PdtComposeOptions(bool c, PdtComposeFilter ft = PAREN_FILTER)
+ : connect(c), filter_type(ft) {}
+ PdtComposeOptions() : connect(true), filter_type(PAREN_FILTER) {}
+};
// Composes pushdown transducer (PDT) encoded as an FST (1st arg) and
// an FST (2nd arg) with the result also a PDT encoded as an Fst. (3rd arg).
@@ -110,16 +497,17 @@ void Compose(const Fst<Arc> &ifst1,
typename Arc::Label> > &parens,
const Fst<Arc> &ifst2,
MutableFst<Arc> *ofst,
- const ComposeOptions &opts = ComposeOptions()) {
-
- PdtComposeOptions<Arc, true> copts(ifst1, parens, ifst2);
+ const PdtComposeOptions &opts = PdtComposeOptions()) {
+ bool expand = opts.filter_type != PAREN_FILTER;
+ bool keep_parens = opts.filter_type != EXPAND_FILTER;
+ PdtComposeFstOptions<Arc, true> copts(ifst1, parens, ifst2,
+ expand, keep_parens);
copts.gc_limit = 0;
*ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
if (opts.connect)
Connect(ofst);
}
-
// Composes an FST (1st arg) and pushdown transducer (PDT) encoded as
// an FST (2nd arg) with the result also a PDT encoded as an Fst (3rd arg).
// In the PDTs, some transitions are labeled with open or close
@@ -132,9 +520,11 @@ void Compose(const Fst<Arc> &ifst1,
const vector<pair<typename Arc::Label,
typename Arc::Label> > &parens,
MutableFst<Arc> *ofst,
- const ComposeOptions &opts = ComposeOptions()) {
-
- PdtComposeOptions<Arc, false> copts(ifst1, ifst2, parens);
+ const PdtComposeOptions &opts = PdtComposeOptions()) {
+ bool expand = opts.filter_type != PAREN_FILTER;
+ bool keep_parens = opts.filter_type != EXPAND_FILTER;
+ PdtComposeFstOptions<Arc, false> copts(ifst1, ifst2, parens,
+ expand, keep_parens);
copts.gc_limit = 0;
*ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
if (opts.connect)
diff --git a/src/include/fst/extensions/pdt/pdt.h b/src/include/fst/extensions/pdt/pdt.h
index 6649f55..c56afbd 100644
--- a/src/include/fst/extensions/pdt/pdt.h
+++ b/src/include/fst/extensions/pdt/pdt.h
@@ -27,6 +27,7 @@ using std::tr1::unordered_multimap;
#include <map>
#include <set>
+#include <fst/compat.h>
#include <fst/state-table.h>
#include <fst/fst.h>
diff --git a/src/include/fst/extensions/pdt/pdtscript.h b/src/include/fst/extensions/pdt/pdtscript.h
index c2a1cf4..84bb27e 100644
--- a/src/include/fst/extensions/pdt/pdtscript.h
+++ b/src/include/fst/extensions/pdt/pdtscript.h
@@ -48,7 +48,7 @@ typedef args::Package<const FstClass &,
const FstClass &,
const vector<pair<int64, int64> >&,
MutableFstClass *,
- const ComposeOptions &,
+ const PdtComposeOptions &,
bool> PdtComposeArgs;
template<class Arc>
@@ -76,7 +76,7 @@ void PdtCompose(const FstClass & ifst1,
const FstClass & ifst2,
const vector<pair<int64, int64> > &parens,
MutableFstClass *ofst,
- const ComposeOptions &copts,
+ const PdtComposeOptions &copts,
bool left_pdt);
// PDT EXPAND
diff --git a/src/include/fst/extensions/pdt/replace.h b/src/include/fst/extensions/pdt/replace.h
index a85d0fe..9081400 100644
--- a/src/include/fst/extensions/pdt/replace.h
+++ b/src/include/fst/extensions/pdt/replace.h
@@ -21,6 +21,10 @@
#ifndef FST_EXTENSIONS_PDT_REPLACE_H__
#define FST_EXTENSIONS_PDT_REPLACE_H__
+#include <tr1/unordered_map>
+using std::tr1::unordered_map;
+using std::tr1::unordered_multimap;
+
#include <fst/replace.h>
namespace fst {
@@ -62,11 +66,14 @@ void Replace(const vector<pair<typename Arc::Label,
label2id[ifst_array[i].first] = i;
Label max_label = kNoLabel;
+ size_t max_non_term_count = 0;
- deque<size_t> non_term_queue; // Queue of non-terminals to replace
- unordered_set<Label> non_term_set; // Set of non-terminals to replace
+ // Queue of non-terminals to replace
+ deque<size_t> non_term_queue;
+ // Map of non-terminals to replace to count
+ unordered_map<Label, size_t> non_term_map;
non_term_queue.push_back(root);
- non_term_set.insert(root);
+ non_term_map[root] = 1;;
// PDT state corr. to ith replace FST start state.
vector<StateId> fst_start(ifst_array.size(), kNoLabel);
@@ -107,10 +114,11 @@ void Replace(const vector<pair<typename Arc::Label,
size_t nfst_id = it->second;
if (ifst_array[nfst_id].second->Start() == -1)
continue;
- if (non_term_set.count(arc.olabel) == 0) {
+ size_t count = non_term_map[arc.olabel]++;
+ if (count == 0)
non_term_queue.push_back(arc.olabel);
- non_term_set.insert(arc.olabel);
- }
+ if (count > max_non_term_count)
+ max_non_term_count = count;
}
arc.nextstate += soff;
ofst->AddArc(os, arc);
@@ -134,7 +142,8 @@ void Replace(const vector<pair<typename Arc::Label,
// # of parenthesis pairs per fst.
vector<size_t> nparens(ifst_array.size(), 0);
// Initial open parenthesis label
- Label first_paren = max_label + 1;
+ Label first_open_paren = max_label + 1;
+ Label first_close_paren = max_label + max_non_term_count + 1;
for (StateIterator< Fst<Arc> > siter(*ofst);
!siter.Done(); siter.Next()) {
@@ -158,8 +167,8 @@ void Replace(const vector<pair<typename Arc::Label,
close_paren = (*parens)[paren_id].second;
} else {
size_t paren_id = nparens[nfst_id]++;
- open_paren = first_paren + 2 * paren_id;
- close_paren = open_paren + 1;
+ open_paren = first_open_paren + paren_id;
+ close_paren = first_close_paren + paren_id;
paren_map[paren_key] = paren_id;
if (paren_id >= parens->size())
parens->push_back(make_pair(open_paren, close_paren));