aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/extensions/pdt/compose.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/extensions/pdt/compose.h')
-rw-r--r--src/include/fst/extensions/pdt/compose.h486
1 files changed, 438 insertions, 48 deletions
diff --git a/src/include/fst/extensions/pdt/compose.h b/src/include/fst/extensions/pdt/compose.h
index 364d76f..c856c6d 100644
--- a/src/include/fst/extensions/pdt/compose.h
+++ b/src/include/fst/extensions/pdt/compose.h
@@ -21,82 +21,469 @@
#ifndef FST_EXTENSIONS_PDT_COMPOSE_H__
#define FST_EXTENSIONS_PDT_COMPOSE_H__
+#include <list>
+
+#include <fst/extensions/pdt/pdt.h>
#include <fst/compose.h>
namespace fst {
+// Return paren arcs for Find(kNoLabel).
+const uint32 kParenList = 0x00000001;
+
+// Return a kNolabel loop for Find(paren).
+const uint32 kParenLoop = 0x00000002;
+
+// This class is a matcher that treats parens as multi-epsilon labels.
+// It is most efficient if the parens are in a range non-overlapping with
+// the non-paren labels.
+template <class F>
+class ParenMatcher {
+ public:
+ typedef SortedMatcher<F> M;
+ typedef typename M::FST FST;
+ typedef typename M::Arc Arc;
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Label Label;
+ typedef typename Arc::Weight Weight;
+
+ ParenMatcher(const FST &fst, MatchType match_type,
+ uint32 flags = (kParenLoop | kParenList))
+ : matcher_(fst, match_type),
+ match_type_(match_type),
+ flags_(flags) {
+ if (match_type == MATCH_INPUT) {
+ loop_.ilabel = kNoLabel;
+ loop_.olabel = 0;
+ } else {
+ loop_.ilabel = 0;
+ loop_.olabel = kNoLabel;
+ }
+ loop_.weight = Weight::One();
+ loop_.nextstate = kNoStateId;
+ }
+
+ ParenMatcher(const ParenMatcher<F> &matcher, bool safe = false)
+ : matcher_(matcher.matcher_, safe),
+ match_type_(matcher.match_type_),
+ flags_(matcher.flags_),
+ open_parens_(matcher.open_parens_),
+ close_parens_(matcher.close_parens_),
+ loop_(matcher.loop_) {
+ loop_.nextstate = kNoStateId;
+ }
+
+ ParenMatcher<F> *Copy(bool safe = false) const {
+ return new ParenMatcher<F>(*this, safe);
+ }
+
+ MatchType Type(bool test) const { return matcher_.Type(test); }
+
+ void SetState(StateId s) {
+ matcher_.SetState(s);
+ loop_.nextstate = s;
+ }
+
+ bool Find(Label match_label);
+
+ bool Done() const {
+ return done_;
+ }
+
+ const Arc& Value() const {
+ return paren_loop_ ? loop_ : matcher_.Value();
+ }
+
+ void Next();
+
+ const FST &GetFst() const { return matcher_.GetFst(); }
+
+ uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
+
+ uint32 Flags() const { return matcher_.Flags(); }
+
+ void AddOpenParen(Label label) {
+ if (label == 0) {
+ FSTERROR() << "ParenMatcher: Bad open paren label: 0";
+ } else {
+ open_parens_.Insert(label);
+ }
+ }
+
+ void AddCloseParen(Label label) {
+ if (label == 0) {
+ FSTERROR() << "ParenMatcher: Bad close paren label: 0";
+ } else {
+ close_parens_.Insert(label);
+ }
+ }
+
+ void RemoveOpenParen(Label label) {
+ if (label == 0) {
+ FSTERROR() << "ParenMatcher: Bad open paren label: 0";
+ } else {
+ open_parens_.Erase(label);
+ }
+ }
+
+ void RemoveCloseParen(Label label) {
+ if (label == 0) {
+ FSTERROR() << "ParenMatcher: Bad close paren label: 0";
+ } else {
+ close_parens_.Erase(label);
+ }
+ }
+
+ void ClearOpenParens() {
+ open_parens_.Clear();
+ }
+
+ void ClearCloseParens() {
+ close_parens_.Clear();
+ }
+
+ bool IsOpenParen(Label label) const {
+ return open_parens_.Member(label);
+ }
+
+ bool IsCloseParen(Label label) const {
+ return close_parens_.Member(label);
+ }
+
+ private:
+ // Advances matcher to next open paren if it exists, returning true.
+ // O.w. returns false.
+ bool NextOpenParen();
+
+ // Advances matcher to next open paren if it exists, returning true.
+ // O.w. returns false.
+ bool NextCloseParen();
+
+ M matcher_;
+ MatchType match_type_; // Type of match to perform
+ uint32 flags_;
+
+ // open paren label set
+ CompactSet<Label, kNoLabel> open_parens_;
+
+ // close paren label set
+ CompactSet<Label, kNoLabel> close_parens_;
+
+
+ bool open_paren_list_; // Matching open paren list
+ bool close_paren_list_; // Matching close paren list
+ bool paren_loop_; // Current arc is the implicit paren loop
+ mutable Arc loop_; // For non-consuming symbols
+ bool done_; // Matching done
+
+ void operator=(const ParenMatcher<F> &); // Disallow
+};
+
+template <class M> inline
+bool ParenMatcher<M>::Find(Label match_label) {
+ open_paren_list_ = false;
+ close_paren_list_ = false;
+ paren_loop_ = false;
+ done_ = false;
+
+ // Returns all parenthesis arcs
+ if (match_label == kNoLabel && (flags_ & kParenList)) {
+ if (open_parens_.LowerBound() != kNoLabel) {
+ matcher_.LowerBound(open_parens_.LowerBound());
+ open_paren_list_ = NextOpenParen();
+ if (open_paren_list_) return true;
+ }
+ if (close_parens_.LowerBound() != kNoLabel) {
+ matcher_.LowerBound(close_parens_.LowerBound());
+ close_paren_list_ = NextCloseParen();
+ if (close_paren_list_) return true;
+ }
+ }
+
+ // Returns 'implicit' paren loop
+ if (match_label > 0 && (flags_ & kParenLoop) &&
+ (IsOpenParen(match_label) || IsCloseParen(match_label))) {
+ paren_loop_ = true;
+ return true;
+ }
+
+ // Returns all other labels
+ if (matcher_.Find(match_label))
+ return true;
+
+ done_ = true;
+ return false;
+}
+
+template <class F> inline
+void ParenMatcher<F>::Next() {
+ if (paren_loop_) {
+ paren_loop_ = false;
+ done_ = true;
+ } else if (open_paren_list_) {
+ matcher_.Next();
+ open_paren_list_ = NextOpenParen();
+ if (open_paren_list_) return;
+
+ if (close_parens_.LowerBound() != kNoLabel) {
+ matcher_.LowerBound(close_parens_.LowerBound());
+ close_paren_list_ = NextCloseParen();
+ if (close_paren_list_) return;
+ }
+ done_ = !matcher_.Find(kNoLabel);
+ } else if (close_paren_list_) {
+ matcher_.Next();
+ close_paren_list_ = NextCloseParen();
+ if (close_paren_list_) return;
+ done_ = !matcher_.Find(kNoLabel);
+ } else {
+ matcher_.Next();
+ done_ = matcher_.Done();
+ }
+}
+
+// Advances matcher to next open paren if it exists, returning true.
+// O.w. returns false.
+template <class F> inline
+bool ParenMatcher<F>::NextOpenParen() {
+ for (; !matcher_.Done(); matcher_.Next()) {
+ Label label = match_type_ == MATCH_INPUT ?
+ matcher_.Value().ilabel : matcher_.Value().olabel;
+ if (label > open_parens_.UpperBound())
+ return false;
+ if (IsOpenParen(label))
+ return true;
+ }
+ return false;
+}
+
+// Advances matcher to next close paren if it exists, returning true.
+// O.w. returns false.
+template <class F> inline
+bool ParenMatcher<F>::NextCloseParen() {
+ for (; !matcher_.Done(); matcher_.Next()) {
+ Label label = match_type_ == MATCH_INPUT ?
+ matcher_.Value().ilabel : matcher_.Value().olabel;
+ if (label > close_parens_.UpperBound())
+ return false;
+ if (IsCloseParen(label))
+ return true;
+ }
+ return false;
+}
+
+
+template <class F>
+class ParenFilter {
+ public:
+ typedef typename F::FST1 FST1;
+ typedef typename F::FST2 FST2;
+ typedef typename F::Arc Arc;
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Label Label;
+ typedef typename Arc::Weight Weight;
+ typedef typename F::Matcher1 Matcher1;
+ typedef typename F::Matcher2 Matcher2;
+ typedef typename F::FilterState FilterState1;
+ typedef StateId StackId;
+ typedef PdtStack<StackId, Label> ParenStack;
+ typedef IntegerFilterState<StackId> FilterState2;
+ typedef PairFilterState<FilterState1, FilterState2> FilterState;
+ typedef ParenFilter<F> Filter;
+
+ ParenFilter(const FST1 &fst1, const FST2 &fst2,
+ Matcher1 *matcher1 = 0, Matcher2 *matcher2 = 0,
+ const vector<pair<Label, Label> > *parens = 0,
+ bool expand = false, bool keep_parens = true)
+ : filter_(fst1, fst2, matcher1, matcher2),
+ parens_(parens ? *parens : vector<pair<Label, Label> >()),
+ expand_(expand),
+ keep_parens_(keep_parens),
+ f_(FilterState::NoState()),
+ stack_(parens_),
+ paren_id_(-1) {
+ if (parens) {
+ for (size_t i = 0; i < parens->size(); ++i) {
+ const pair<Label, Label> &p = (*parens)[i];
+ parens_.push_back(p);
+ GetMatcher1()->AddOpenParen(p.first);
+ GetMatcher2()->AddOpenParen(p.first);
+ if (!expand_) {
+ GetMatcher1()->AddCloseParen(p.second);
+ GetMatcher2()->AddCloseParen(p.second);
+ }
+ }
+ }
+ }
+
+ ParenFilter(const Filter &filter, bool safe = false)
+ : filter_(filter.filter_, safe),
+ parens_(filter.parens_),
+ expand_(filter.expand_),
+ keep_parens_(filter.keep_parens_),
+ f_(FilterState::NoState()),
+ stack_(filter.parens_),
+ paren_id_(-1) { }
+
+ FilterState Start() const {
+ return FilterState(filter_.Start(), FilterState2(0));
+ }
+
+ void SetState(StateId s1, StateId s2, const FilterState &f) {
+ f_ = f;
+ filter_.SetState(s1, s2, f_.GetState1());
+ if (!expand_)
+ return;
+
+ ssize_t paren_id = stack_.Top(f.GetState2().GetState());
+ if (paren_id != paren_id_) {
+ if (paren_id_ != -1) {
+ GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second);
+ GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second);
+ }
+ paren_id_ = paren_id;
+ if (paren_id_ != -1) {
+ GetMatcher1()->AddCloseParen(parens_[paren_id_].second);
+ GetMatcher2()->AddCloseParen(parens_[paren_id_].second);
+ }
+ }
+ }
+
+ FilterState FilterArc(Arc *arc1, Arc *arc2) const {
+ FilterState1 f1 = filter_.FilterArc(arc1, arc2);
+ const FilterState2 &f2 = f_.GetState2();
+ if (f1 == FilterState1::NoState())
+ return FilterState::NoState();
+
+ if (arc1->olabel == kNoLabel && arc2->ilabel) { // arc2 parentheses
+ if (keep_parens_) {
+ arc1->ilabel = arc2->ilabel;
+ } else if (arc2->ilabel) {
+ arc2->olabel = arc1->ilabel;
+ }
+ return FilterParen(arc2->ilabel, f1, f2);
+ } else if (arc2->ilabel == kNoLabel && arc1->olabel) { // arc1 parentheses
+ if (keep_parens_) {
+ arc2->olabel = arc1->olabel;
+ } else {
+ arc1->ilabel = arc2->olabel;
+ }
+ return FilterParen(arc1->olabel, f1, f2);
+ } else {
+ return FilterState(f1, f2);
+ }
+ }
+
+ void FilterFinal(Weight *w1, Weight *w2) const {
+ if (f_.GetState2().GetState() != 0)
+ *w1 = Weight::Zero();
+ filter_.FilterFinal(w1, w2);
+ }
+
+ // Return resp matchers. Ownership stays with filter.
+ Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); }
+ Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); }
+
+ uint64 Properties(uint64 iprops) const {
+ uint64 oprops = filter_.Properties(iprops);
+ return oprops & kILabelInvariantProperties & kOLabelInvariantProperties;
+ }
+
+ private:
+ const FilterState FilterParen(Label label, const FilterState1 &f1,
+ const FilterState2 &f2) const {
+ if (!expand_)
+ return FilterState(f1, f2);
+
+ StackId stack_id = stack_.Find(f2.GetState(), label);
+ if (stack_id < 0) {
+ return FilterState::NoState();
+ } else {
+ return FilterState(f1, FilterState2(stack_id));
+ }
+ }
+
+ F filter_;
+ vector<pair<Label, Label> > parens_;
+ bool expand_; // Expands to FST
+ bool keep_parens_; // Retains parentheses in output
+ FilterState f_; // Current filter state
+ mutable ParenStack stack_;
+ ssize_t paren_id_;
+};
+
// Class to setup composition options for PDT composition.
// Default is for the PDT as the first composition argument.
template <class Arc, bool left_pdt = true>
-class PdtComposeOptions : public
+class PdtComposeFstOptions : public
ComposeFstOptions<Arc,
- MultiEpsMatcher< Matcher<Fst<Arc> > >,
- MultiEpsFilter<AltSequenceComposeFilter<
- MultiEpsMatcher<
- Matcher<Fst<Arc> > > > > > {
+ ParenMatcher< Fst<Arc> >,
+ ParenFilter<AltSequenceComposeFilter<
+ ParenMatcher< Fst<Arc> > > > > {
public:
typedef typename Arc::Label Label;
- typedef MultiEpsMatcher< Matcher<Fst<Arc> > > PdtMatcher;
- typedef MultiEpsFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter;
+ typedef ParenMatcher< Fst<Arc> > PdtMatcher;
+ typedef ParenFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter;
typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
using COptions::matcher1;
using COptions::matcher2;
using COptions::filter;
- PdtComposeOptions(const Fst<Arc> &ifst1,
+ PdtComposeFstOptions(const Fst<Arc> &ifst1,
const vector<pair<Label, Label> > &parens,
- const Fst<Arc> &ifst2) {
- matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsList);
- matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsLoop);
-
- // Treat parens as multi-epsilons when composing.
- for (size_t i = 0; i < parens.size(); ++i) {
- matcher1->AddMultiEpsLabel(parens[i].first);
- matcher1->AddMultiEpsLabel(parens[i].second);
- matcher2->AddMultiEpsLabel(parens[i].first);
- matcher2->AddMultiEpsLabel(parens[i].second);
- }
+ const Fst<Arc> &ifst2, bool expand = false,
+ bool keep_parens = true) {
+ matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenList);
+ matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenLoop);
- filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true);
+ filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
+ expand, keep_parens);
}
};
// Class to setup composition options for PDT with FST composition.
// Specialization is for the FST as the first composition argument.
template <class Arc>
-class PdtComposeOptions<Arc, false> : public
+class PdtComposeFstOptions<Arc, false> : public
ComposeFstOptions<Arc,
- MultiEpsMatcher< Matcher<Fst<Arc> > >,
- MultiEpsFilter<SequenceComposeFilter<
- MultiEpsMatcher<
- Matcher<Fst<Arc> > > > > > {
+ ParenMatcher< Fst<Arc> >,
+ ParenFilter<SequenceComposeFilter<
+ ParenMatcher< Fst<Arc> > > > > {
public:
typedef typename Arc::Label Label;
- typedef MultiEpsMatcher< Matcher<Fst<Arc> > > PdtMatcher;
- typedef MultiEpsFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter;
+ typedef ParenMatcher< Fst<Arc> > PdtMatcher;
+ typedef ParenFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter;
typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
using COptions::matcher1;
using COptions::matcher2;
using COptions::filter;
- PdtComposeOptions(const Fst<Arc> &ifst1,
- const Fst<Arc> &ifst2,
- const vector<pair<Label, Label> > &parens) {
- matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsLoop);
- matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsList);
-
- // Treat parens as multi-epsilons when composing.
- for (size_t i = 0; i < parens.size(); ++i) {
- matcher1->AddMultiEpsLabel(parens[i].first);
- matcher1->AddMultiEpsLabel(parens[i].second);
- matcher2->AddMultiEpsLabel(parens[i].first);
- matcher2->AddMultiEpsLabel(parens[i].second);
- }
+ PdtComposeFstOptions(const Fst<Arc> &ifst1,
+ const Fst<Arc> &ifst2,
+ const vector<pair<Label, Label> > &parens,
+ bool expand = false, bool keep_parens = true) {
+ matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenLoop);
+ matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenList);
- filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true);
+ filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
+ expand, keep_parens);
}
};
+enum PdtComposeFilter {
+ PAREN_FILTER, // Bar-Hillel construction; keeps parentheses
+ EXPAND_FILTER, // Bar-Hillel + expansion; removes parentheses
+ EXPAND_PAREN_FILTER, // Bar-Hillel + expansion; keeps parentheses
+};
+
+struct PdtComposeOptions {
+ bool connect; // Connect output
+ PdtComposeFilter filter_type; // Which pre-defined filter to use
+
+ explicit PdtComposeOptions(bool c, PdtComposeFilter ft = PAREN_FILTER)
+ : connect(c), filter_type(ft) {}
+ PdtComposeOptions() : connect(true), filter_type(PAREN_FILTER) {}
+};
// Composes pushdown transducer (PDT) encoded as an FST (1st arg) and
// an FST (2nd arg) with the result also a PDT encoded as an Fst. (3rd arg).
@@ -110,16 +497,17 @@ void Compose(const Fst<Arc> &ifst1,
typename Arc::Label> > &parens,
const Fst<Arc> &ifst2,
MutableFst<Arc> *ofst,
- const ComposeOptions &opts = ComposeOptions()) {
-
- PdtComposeOptions<Arc, true> copts(ifst1, parens, ifst2);
+ const PdtComposeOptions &opts = PdtComposeOptions()) {
+ bool expand = opts.filter_type != PAREN_FILTER;
+ bool keep_parens = opts.filter_type != EXPAND_FILTER;
+ PdtComposeFstOptions<Arc, true> copts(ifst1, parens, ifst2,
+ expand, keep_parens);
copts.gc_limit = 0;
*ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
if (opts.connect)
Connect(ofst);
}
-
// Composes an FST (1st arg) and pushdown transducer (PDT) encoded as
// an FST (2nd arg) with the result also a PDT encoded as an Fst (3rd arg).
// In the PDTs, some transitions are labeled with open or close
@@ -132,9 +520,11 @@ void Compose(const Fst<Arc> &ifst1,
const vector<pair<typename Arc::Label,
typename Arc::Label> > &parens,
MutableFst<Arc> *ofst,
- const ComposeOptions &opts = ComposeOptions()) {
-
- PdtComposeOptions<Arc, false> copts(ifst1, ifst2, parens);
+ const PdtComposeOptions &opts = PdtComposeOptions()) {
+ bool expand = opts.filter_type != PAREN_FILTER;
+ bool keep_parens = opts.filter_type != EXPAND_FILTER;
+ PdtComposeFstOptions<Arc, false> copts(ifst1, ifst2, parens,
+ expand, keep_parens);
copts.gc_limit = 0;
*ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
if (opts.connect)