From 5b6dc79427b8f7eeb6a7ff68034ab8548ce670ea Mon Sep 17 00:00:00 2001 From: Alexander Gutkin Date: Thu, 28 Feb 2013 00:24:20 +0000 Subject: Bumped OpenFST implementation to openfst-1.3.3-CL41851770. Updated OpenFST implementation to the most recent version used by Greco3 (corresponds to nlp::fst exported at Perforce CL 41851770). In particular this version has an improved PDT support. Change-Id: I5aadfc962297eef73922c67e7d57866f11ee7d81 --- src/include/fst/extensions/pdt/compose.h | 486 ++++++++++++++++++++++++++--- src/include/fst/extensions/pdt/pdt.h | 1 + src/include/fst/extensions/pdt/pdtscript.h | 4 +- src/include/fst/extensions/pdt/replace.h | 27 +- 4 files changed, 459 insertions(+), 59 deletions(-) (limited to 'src/include/fst/extensions/pdt') diff --git a/src/include/fst/extensions/pdt/compose.h b/src/include/fst/extensions/pdt/compose.h index 364d76f..c856c6d 100644 --- a/src/include/fst/extensions/pdt/compose.h +++ b/src/include/fst/extensions/pdt/compose.h @@ -21,82 +21,469 @@ #ifndef FST_EXTENSIONS_PDT_COMPOSE_H__ #define FST_EXTENSIONS_PDT_COMPOSE_H__ +#include + +#include #include namespace fst { +// Return paren arcs for Find(kNoLabel). +const uint32 kParenList = 0x00000001; + +// Return a kNolabel loop for Find(paren). +const uint32 kParenLoop = 0x00000002; + +// This class is a matcher that treats parens as multi-epsilon labels. +// It is most efficient if the parens are in a range non-overlapping with +// the non-paren labels. +template +class ParenMatcher { + public: + typedef SortedMatcher M; + typedef typename M::FST FST; + typedef typename M::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + ParenMatcher(const FST &fst, MatchType match_type, + uint32 flags = (kParenLoop | kParenList)) + : matcher_(fst, match_type), + match_type_(match_type), + flags_(flags) { + if (match_type == MATCH_INPUT) { + loop_.ilabel = kNoLabel; + loop_.olabel = 0; + } else { + loop_.ilabel = 0; + loop_.olabel = kNoLabel; + } + loop_.weight = Weight::One(); + loop_.nextstate = kNoStateId; + } + + ParenMatcher(const ParenMatcher &matcher, bool safe = false) + : matcher_(matcher.matcher_, safe), + match_type_(matcher.match_type_), + flags_(matcher.flags_), + open_parens_(matcher.open_parens_), + close_parens_(matcher.close_parens_), + loop_(matcher.loop_) { + loop_.nextstate = kNoStateId; + } + + ParenMatcher *Copy(bool safe = false) const { + return new ParenMatcher(*this, safe); + } + + MatchType Type(bool test) const { return matcher_.Type(test); } + + void SetState(StateId s) { + matcher_.SetState(s); + loop_.nextstate = s; + } + + bool Find(Label match_label); + + bool Done() const { + return done_; + } + + const Arc& Value() const { + return paren_loop_ ? loop_ : matcher_.Value(); + } + + void Next(); + + const FST &GetFst() const { return matcher_.GetFst(); } + + uint64 Properties(uint64 props) const { return matcher_.Properties(props); } + + uint32 Flags() const { return matcher_.Flags(); } + + void AddOpenParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad open paren label: 0"; + } else { + open_parens_.Insert(label); + } + } + + void AddCloseParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad close paren label: 0"; + } else { + close_parens_.Insert(label); + } + } + + void RemoveOpenParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad open paren label: 0"; + } else { + open_parens_.Erase(label); + } + } + + void RemoveCloseParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad close paren label: 0"; + } else { + close_parens_.Erase(label); + } + } + + void ClearOpenParens() { + open_parens_.Clear(); + } + + void ClearCloseParens() { + close_parens_.Clear(); + } + + bool IsOpenParen(Label label) const { + return open_parens_.Member(label); + } + + bool IsCloseParen(Label label) const { + return close_parens_.Member(label); + } + + private: + // Advances matcher to next open paren if it exists, returning true. + // O.w. returns false. + bool NextOpenParen(); + + // Advances matcher to next open paren if it exists, returning true. + // O.w. returns false. + bool NextCloseParen(); + + M matcher_; + MatchType match_type_; // Type of match to perform + uint32 flags_; + + // open paren label set + CompactSet open_parens_; + + // close paren label set + CompactSet close_parens_; + + + bool open_paren_list_; // Matching open paren list + bool close_paren_list_; // Matching close paren list + bool paren_loop_; // Current arc is the implicit paren loop + mutable Arc loop_; // For non-consuming symbols + bool done_; // Matching done + + void operator=(const ParenMatcher &); // Disallow +}; + +template inline +bool ParenMatcher::Find(Label match_label) { + open_paren_list_ = false; + close_paren_list_ = false; + paren_loop_ = false; + done_ = false; + + // Returns all parenthesis arcs + if (match_label == kNoLabel && (flags_ & kParenList)) { + if (open_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(open_parens_.LowerBound()); + open_paren_list_ = NextOpenParen(); + if (open_paren_list_) return true; + } + if (close_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(close_parens_.LowerBound()); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return true; + } + } + + // Returns 'implicit' paren loop + if (match_label > 0 && (flags_ & kParenLoop) && + (IsOpenParen(match_label) || IsCloseParen(match_label))) { + paren_loop_ = true; + return true; + } + + // Returns all other labels + if (matcher_.Find(match_label)) + return true; + + done_ = true; + return false; +} + +template inline +void ParenMatcher::Next() { + if (paren_loop_) { + paren_loop_ = false; + done_ = true; + } else if (open_paren_list_) { + matcher_.Next(); + open_paren_list_ = NextOpenParen(); + if (open_paren_list_) return; + + if (close_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(close_parens_.LowerBound()); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return; + } + done_ = !matcher_.Find(kNoLabel); + } else if (close_paren_list_) { + matcher_.Next(); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return; + done_ = !matcher_.Find(kNoLabel); + } else { + matcher_.Next(); + done_ = matcher_.Done(); + } +} + +// Advances matcher to next open paren if it exists, returning true. +// O.w. returns false. +template inline +bool ParenMatcher::NextOpenParen() { + for (; !matcher_.Done(); matcher_.Next()) { + Label label = match_type_ == MATCH_INPUT ? + matcher_.Value().ilabel : matcher_.Value().olabel; + if (label > open_parens_.UpperBound()) + return false; + if (IsOpenParen(label)) + return true; + } + return false; +} + +// Advances matcher to next close paren if it exists, returning true. +// O.w. returns false. +template inline +bool ParenMatcher::NextCloseParen() { + for (; !matcher_.Done(); matcher_.Next()) { + Label label = match_type_ == MATCH_INPUT ? + matcher_.Value().ilabel : matcher_.Value().olabel; + if (label > close_parens_.UpperBound()) + return false; + if (IsCloseParen(label)) + return true; + } + return false; +} + + +template +class ParenFilter { + public: + typedef typename F::FST1 FST1; + typedef typename F::FST2 FST2; + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename F::Matcher1 Matcher1; + typedef typename F::Matcher2 Matcher2; + typedef typename F::FilterState FilterState1; + typedef StateId StackId; + typedef PdtStack ParenStack; + typedef IntegerFilterState FilterState2; + typedef PairFilterState FilterState; + typedef ParenFilter Filter; + + ParenFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = 0, Matcher2 *matcher2 = 0, + const vector > *parens = 0, + bool expand = false, bool keep_parens = true) + : filter_(fst1, fst2, matcher1, matcher2), + parens_(parens ? *parens : vector >()), + expand_(expand), + keep_parens_(keep_parens), + f_(FilterState::NoState()), + stack_(parens_), + paren_id_(-1) { + if (parens) { + for (size_t i = 0; i < parens->size(); ++i) { + const pair &p = (*parens)[i]; + parens_.push_back(p); + GetMatcher1()->AddOpenParen(p.first); + GetMatcher2()->AddOpenParen(p.first); + if (!expand_) { + GetMatcher1()->AddCloseParen(p.second); + GetMatcher2()->AddCloseParen(p.second); + } + } + } + } + + ParenFilter(const Filter &filter, bool safe = false) + : filter_(filter.filter_, safe), + parens_(filter.parens_), + expand_(filter.expand_), + keep_parens_(filter.keep_parens_), + f_(FilterState::NoState()), + stack_(filter.parens_), + paren_id_(-1) { } + + FilterState Start() const { + return FilterState(filter_.Start(), FilterState2(0)); + } + + void SetState(StateId s1, StateId s2, const FilterState &f) { + f_ = f; + filter_.SetState(s1, s2, f_.GetState1()); + if (!expand_) + return; + + ssize_t paren_id = stack_.Top(f.GetState2().GetState()); + if (paren_id != paren_id_) { + if (paren_id_ != -1) { + GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second); + GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second); + } + paren_id_ = paren_id; + if (paren_id_ != -1) { + GetMatcher1()->AddCloseParen(parens_[paren_id_].second); + GetMatcher2()->AddCloseParen(parens_[paren_id_].second); + } + } + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + FilterState1 f1 = filter_.FilterArc(arc1, arc2); + const FilterState2 &f2 = f_.GetState2(); + if (f1 == FilterState1::NoState()) + return FilterState::NoState(); + + if (arc1->olabel == kNoLabel && arc2->ilabel) { // arc2 parentheses + if (keep_parens_) { + arc1->ilabel = arc2->ilabel; + } else if (arc2->ilabel) { + arc2->olabel = arc1->ilabel; + } + return FilterParen(arc2->ilabel, f1, f2); + } else if (arc2->ilabel == kNoLabel && arc1->olabel) { // arc1 parentheses + if (keep_parens_) { + arc2->olabel = arc1->olabel; + } else { + arc1->ilabel = arc2->olabel; + } + return FilterParen(arc1->olabel, f1, f2); + } else { + return FilterState(f1, f2); + } + } + + void FilterFinal(Weight *w1, Weight *w2) const { + if (f_.GetState2().GetState() != 0) + *w1 = Weight::Zero(); + filter_.FilterFinal(w1, w2); + } + + // Return resp matchers. Ownership stays with filter. + Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); } + Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); } + + uint64 Properties(uint64 iprops) const { + uint64 oprops = filter_.Properties(iprops); + return oprops & kILabelInvariantProperties & kOLabelInvariantProperties; + } + + private: + const FilterState FilterParen(Label label, const FilterState1 &f1, + const FilterState2 &f2) const { + if (!expand_) + return FilterState(f1, f2); + + StackId stack_id = stack_.Find(f2.GetState(), label); + if (stack_id < 0) { + return FilterState::NoState(); + } else { + return FilterState(f1, FilterState2(stack_id)); + } + } + + F filter_; + vector > parens_; + bool expand_; // Expands to FST + bool keep_parens_; // Retains parentheses in output + FilterState f_; // Current filter state + mutable ParenStack stack_; + ssize_t paren_id_; +}; + // Class to setup composition options for PDT composition. // Default is for the PDT as the first composition argument. template -class PdtComposeOptions : public +class PdtComposeFstOptions : public ComposeFstOptions > >, - MultiEpsFilter > > > > > { + ParenMatcher< Fst >, + ParenFilter > > > > { public: typedef typename Arc::Label Label; - typedef MultiEpsMatcher< Matcher > > PdtMatcher; - typedef MultiEpsFilter > PdtFilter; + typedef ParenMatcher< Fst > PdtMatcher; + typedef ParenFilter > PdtFilter; typedef ComposeFstOptions COptions; using COptions::matcher1; using COptions::matcher2; using COptions::filter; - PdtComposeOptions(const Fst &ifst1, + PdtComposeFstOptions(const Fst &ifst1, const vector > &parens, - const Fst &ifst2) { - matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsList); - matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsLoop); - - // Treat parens as multi-epsilons when composing. - for (size_t i = 0; i < parens.size(); ++i) { - matcher1->AddMultiEpsLabel(parens[i].first); - matcher1->AddMultiEpsLabel(parens[i].second); - matcher2->AddMultiEpsLabel(parens[i].first); - matcher2->AddMultiEpsLabel(parens[i].second); - } + const Fst &ifst2, bool expand = false, + bool keep_parens = true) { + matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenList); + matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenLoop); - filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true); + filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, + expand, keep_parens); } }; // Class to setup composition options for PDT with FST composition. // Specialization is for the FST as the first composition argument. template -class PdtComposeOptions : public +class PdtComposeFstOptions : public ComposeFstOptions > >, - MultiEpsFilter > > > > > { + ParenMatcher< Fst >, + ParenFilter > > > > { public: typedef typename Arc::Label Label; - typedef MultiEpsMatcher< Matcher > > PdtMatcher; - typedef MultiEpsFilter > PdtFilter; + typedef ParenMatcher< Fst > PdtMatcher; + typedef ParenFilter > PdtFilter; typedef ComposeFstOptions COptions; using COptions::matcher1; using COptions::matcher2; using COptions::filter; - PdtComposeOptions(const Fst &ifst1, - const Fst &ifst2, - const vector > &parens) { - matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsLoop); - matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsList); - - // Treat parens as multi-epsilons when composing. - for (size_t i = 0; i < parens.size(); ++i) { - matcher1->AddMultiEpsLabel(parens[i].first); - matcher1->AddMultiEpsLabel(parens[i].second); - matcher2->AddMultiEpsLabel(parens[i].first); - matcher2->AddMultiEpsLabel(parens[i].second); - } + PdtComposeFstOptions(const Fst &ifst1, + const Fst &ifst2, + const vector > &parens, + bool expand = false, bool keep_parens = true) { + matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenLoop); + matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenList); - filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true); + filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, + expand, keep_parens); } }; +enum PdtComposeFilter { + PAREN_FILTER, // Bar-Hillel construction; keeps parentheses + EXPAND_FILTER, // Bar-Hillel + expansion; removes parentheses + EXPAND_PAREN_FILTER, // Bar-Hillel + expansion; keeps parentheses +}; + +struct PdtComposeOptions { + bool connect; // Connect output + PdtComposeFilter filter_type; // Which pre-defined filter to use + + explicit PdtComposeOptions(bool c, PdtComposeFilter ft = PAREN_FILTER) + : connect(c), filter_type(ft) {} + PdtComposeOptions() : connect(true), filter_type(PAREN_FILTER) {} +}; // Composes pushdown transducer (PDT) encoded as an FST (1st arg) and // an FST (2nd arg) with the result also a PDT encoded as an Fst. (3rd arg). @@ -110,16 +497,17 @@ void Compose(const Fst &ifst1, typename Arc::Label> > &parens, const Fst &ifst2, MutableFst *ofst, - const ComposeOptions &opts = ComposeOptions()) { - - PdtComposeOptions copts(ifst1, parens, ifst2); + const PdtComposeOptions &opts = PdtComposeOptions()) { + bool expand = opts.filter_type != PAREN_FILTER; + bool keep_parens = opts.filter_type != EXPAND_FILTER; + PdtComposeFstOptions copts(ifst1, parens, ifst2, + expand, keep_parens); copts.gc_limit = 0; *ofst = ComposeFst(ifst1, ifst2, copts); if (opts.connect) Connect(ofst); } - // Composes an FST (1st arg) and pushdown transducer (PDT) encoded as // an FST (2nd arg) with the result also a PDT encoded as an Fst (3rd arg). // In the PDTs, some transitions are labeled with open or close @@ -132,9 +520,11 @@ void Compose(const Fst &ifst1, const vector > &parens, MutableFst *ofst, - const ComposeOptions &opts = ComposeOptions()) { - - PdtComposeOptions copts(ifst1, ifst2, parens); + const PdtComposeOptions &opts = PdtComposeOptions()) { + bool expand = opts.filter_type != PAREN_FILTER; + bool keep_parens = opts.filter_type != EXPAND_FILTER; + PdtComposeFstOptions copts(ifst1, ifst2, parens, + expand, keep_parens); copts.gc_limit = 0; *ofst = ComposeFst(ifst1, ifst2, copts); if (opts.connect) diff --git a/src/include/fst/extensions/pdt/pdt.h b/src/include/fst/extensions/pdt/pdt.h index 6649f55..c56afbd 100644 --- a/src/include/fst/extensions/pdt/pdt.h +++ b/src/include/fst/extensions/pdt/pdt.h @@ -27,6 +27,7 @@ using std::tr1::unordered_multimap; #include #include +#include #include #include diff --git a/src/include/fst/extensions/pdt/pdtscript.h b/src/include/fst/extensions/pdt/pdtscript.h index c2a1cf4..84bb27e 100644 --- a/src/include/fst/extensions/pdt/pdtscript.h +++ b/src/include/fst/extensions/pdt/pdtscript.h @@ -48,7 +48,7 @@ typedef args::Package >&, MutableFstClass *, - const ComposeOptions &, + const PdtComposeOptions &, bool> PdtComposeArgs; template @@ -76,7 +76,7 @@ void PdtCompose(const FstClass & ifst1, const FstClass & ifst2, const vector > &parens, MutableFstClass *ofst, - const ComposeOptions &copts, + const PdtComposeOptions &copts, bool left_pdt); // PDT EXPAND diff --git a/src/include/fst/extensions/pdt/replace.h b/src/include/fst/extensions/pdt/replace.h index a85d0fe..9081400 100644 --- a/src/include/fst/extensions/pdt/replace.h +++ b/src/include/fst/extensions/pdt/replace.h @@ -21,6 +21,10 @@ #ifndef FST_EXTENSIONS_PDT_REPLACE_H__ #define FST_EXTENSIONS_PDT_REPLACE_H__ +#include +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; + #include namespace fst { @@ -62,11 +66,14 @@ void Replace(const vector non_term_queue; // Queue of non-terminals to replace - unordered_set