diff options
Diffstat (limited to 'src/include/fst/string.h')
-rw-r--r-- | src/include/fst/string.h | 31 |
1 files changed, 27 insertions, 4 deletions
diff --git a/src/include/fst/string.h b/src/include/fst/string.h index d51182e..9eaf7a3 100644 --- a/src/include/fst/string.h +++ b/src/include/fst/string.h @@ -57,6 +57,15 @@ class StringCompiler { return true; } + template <class F> + bool operator()(const string &s, F *fst, Weight w) const { + vector<Label> labels; + if (!ConvertStringToLabels(s, &labels)) + return false; + Compile(labels, fst, w); + return true; + } + private: bool ConvertStringToLabels(const string &str, vector<Label> *labels) const { labels->clear(); @@ -83,22 +92,35 @@ class StringCompiler { return true; } - void Compile(const vector<Label> &labels, MutableFst<A> *fst) const { + void Compile(const vector<Label> &labels, MutableFst<A> *fst, + const Weight &weight = Weight::One()) const { fst->DeleteStates(); while (fst->NumStates() <= labels.size()) fst->AddState(); for (size_t i = 0; i < labels.size(); ++i) fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1)); fst->SetStart(0); - fst->SetFinal(labels.size(), Weight::One()); + fst->SetFinal(labels.size(), weight); } template <class Unsigned> - void Compile(const vector<Label> &labels, CompactFst<A, StringCompactor<A>, - Unsigned> *fst) const { + void Compile(const vector<Label> &labels, + CompactFst<A, StringCompactor<A>, Unsigned> *fst) const { fst->SetCompactElements(labels.begin(), labels.end()); } + template <class Unsigned> + void Compile(const vector<Label> &labels, + CompactFst<A, WeightedStringCompactor<A>, Unsigned> *fst, + const Weight &weight = Weight::One()) const { + vector<pair<Label, Weight> > compacts; + compacts.reserve(labels.size()); + for (size_t i = 0; i < labels.size(); ++i) + compacts.push_back(make_pair(labels[i], Weight::One())); + compacts.back().second = weight; + fst->SetCompactElements(compacts.begin(), compacts.end()); + } + bool ConvertSymbolToLabel(const char *s, Label* output) const { int64 n; if (syms_) { @@ -167,6 +189,7 @@ class StringPrinter { } *output = sstrm.str(); } else if (token_type_ == BYTE) { + output->reserve(labels_.size()); for (size_t i = 0; i < labels_.size(); ++i) { output->push_back(labels_[i]); } |