aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/extensions/pdt
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/extensions/pdt')
-rw-r--r--src/include/fst/extensions/pdt/collection.h122
-rw-r--r--src/include/fst/extensions/pdt/compose.h146
-rw-r--r--src/include/fst/extensions/pdt/expand.h975
-rw-r--r--src/include/fst/extensions/pdt/info.h175
-rw-r--r--src/include/fst/extensions/pdt/paren.h496
-rw-r--r--src/include/fst/extensions/pdt/pdt.h212
-rw-r--r--src/include/fst/extensions/pdt/pdtlib.h30
-rw-r--r--src/include/fst/extensions/pdt/pdtscript.h284
-rw-r--r--src/include/fst/extensions/pdt/replace.h192
-rw-r--r--src/include/fst/extensions/pdt/reverse.h58
-rw-r--r--src/include/fst/extensions/pdt/shortest-path.h790
11 files changed, 3480 insertions, 0 deletions
diff --git a/src/include/fst/extensions/pdt/collection.h b/src/include/fst/extensions/pdt/collection.h
new file mode 100644
index 0000000..26be504
--- /dev/null
+++ b/src/include/fst/extensions/pdt/collection.h
@@ -0,0 +1,122 @@
+// collection.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Class to store a collection of sets with elements of type T.
+
+#ifndef FST_EXTENSIONS_PDT_COLLECTION_H__
+#define FST_EXTENSIONS_PDT_COLLECTION_H__
+
+#include <algorithm>
+#include <vector>
+using std::vector;
+
+#include <fst/bi-table.h>
+
+namespace fst {
+
+// Stores a collection of non-empty sets with elements of type T. A
+// default constructor, equality ==, a total order <, and an STL-style
+// hash class must be defined on the elements. Provides signed
+// integer ID (of type I) of each unique set. The IDs are allocated
+// starting from 0 in order.
+template <class I, class T>
+class Collection {
+ public:
+ struct Node { // Trie node
+ I node_id; // Root is kNoNodeId;
+ T element;
+
+ Node() : node_id(kNoNodeId), element(T()) {}
+ Node(I i, const T &t) : node_id(i), element(t) {}
+
+ bool operator==(const Node& n) const {
+ return n.node_id == node_id && n.element == element;
+ }
+ };
+
+ struct NodeHash {
+ size_t operator()(const Node &n) const {
+ return n.node_id + hash_(n.element) * kPrime;
+ }
+ };
+
+ typedef CompactHashBiTable<I, Node, NodeHash> NodeTable;
+
+ class SetIterator {
+ public:
+ SetIterator(I id, Node node, NodeTable *node_table)
+ :id_(id), node_(node), node_table_(node_table) {}
+
+ bool Done() const { return id_ == kNoNodeId; }
+
+ const T &Element() const { return node_.element; }
+
+ void Next() {
+ id_ = node_.node_id;
+ if (id_ != kNoNodeId)
+ node_ = node_table_->FindEntry(id_);
+ }
+
+ private:
+ I id_; // Iterator set node id
+ Node node_; // Iterator set node
+ NodeTable *node_table_;
+ };
+
+ Collection() {}
+
+ // Lookups integer ID from set. If it doesn't exist, then adds it.
+ // Set elements should be in strict order (and therefore unique).
+ I FindId(const vector<T> &set) {
+ I node_id = kNoNodeId;
+ for (ssize_t i = set.size() - 1; i >= 0; --i) {
+ Node node(node_id, set[i]);
+ node_id = node_table_.FindId(node);
+ }
+ return node_id;
+ }
+
+ // Finds set given integer ID. Returns true if ID corresponds
+ // to set. Use iterators below to traverse result.
+ SetIterator FindSet(I id) {
+ if (id < 0 && id >= node_table_.Size()) {
+ return SetIterator(kNoNodeId, Node(kNoNodeId, T()), &node_table_);
+ } else {
+ return SetIterator(id, node_table_.FindEntry(id), &node_table_);
+ }
+ }
+
+ private:
+ static const I kNoNodeId;
+ static const size_t kPrime;
+ static std::tr1::hash<T> hash_;
+
+ NodeTable node_table_;
+
+ DISALLOW_COPY_AND_ASSIGN(Collection);
+};
+
+template<class I, class T> const I Collection<I, T>::kNoNodeId = -1;
+
+template <class I, class T> const size_t Collection<I, T>::kPrime = 7853;
+
+template <class I, class T> std::tr1::hash<T> Collection<I, T>::hash_;
+
+} // namespace fst
+
+#endif // FST_EXTENSIONS_PDT_COLLECTION_H__
diff --git a/src/include/fst/extensions/pdt/compose.h b/src/include/fst/extensions/pdt/compose.h
new file mode 100644
index 0000000..364d76f
--- /dev/null
+++ b/src/include/fst/extensions/pdt/compose.h
@@ -0,0 +1,146 @@
+// compose.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Compose a PDT and an FST.
+
+#ifndef FST_EXTENSIONS_PDT_COMPOSE_H__
+#define FST_EXTENSIONS_PDT_COMPOSE_H__
+
+#include <fst/compose.h>
+
+namespace fst {
+
+// Class to setup composition options for PDT composition.
+// Default is for the PDT as the first composition argument.
+template <class Arc, bool left_pdt = true>
+class PdtComposeOptions : public
+ComposeFstOptions<Arc,
+ MultiEpsMatcher< Matcher<Fst<Arc> > >,
+ MultiEpsFilter<AltSequenceComposeFilter<
+ MultiEpsMatcher<
+ Matcher<Fst<Arc> > > > > > {
+ public:
+ typedef typename Arc::Label Label;
+ typedef MultiEpsMatcher< Matcher<Fst<Arc> > > PdtMatcher;
+ typedef MultiEpsFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter;
+ typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
+ using COptions::matcher1;
+ using COptions::matcher2;
+ using COptions::filter;
+
+ PdtComposeOptions(const Fst<Arc> &ifst1,
+ const vector<pair<Label, Label> > &parens,
+ const Fst<Arc> &ifst2) {
+ matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsList);
+ matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsLoop);
+
+ // Treat parens as multi-epsilons when composing.
+ for (size_t i = 0; i < parens.size(); ++i) {
+ matcher1->AddMultiEpsLabel(parens[i].first);
+ matcher1->AddMultiEpsLabel(parens[i].second);
+ matcher2->AddMultiEpsLabel(parens[i].first);
+ matcher2->AddMultiEpsLabel(parens[i].second);
+ }
+
+ filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true);
+ }
+};
+
+// Class to setup composition options for PDT with FST composition.
+// Specialization is for the FST as the first composition argument.
+template <class Arc>
+class PdtComposeOptions<Arc, false> : public
+ComposeFstOptions<Arc,
+ MultiEpsMatcher< Matcher<Fst<Arc> > >,
+ MultiEpsFilter<SequenceComposeFilter<
+ MultiEpsMatcher<
+ Matcher<Fst<Arc> > > > > > {
+ public:
+ typedef typename Arc::Label Label;
+ typedef MultiEpsMatcher< Matcher<Fst<Arc> > > PdtMatcher;
+ typedef MultiEpsFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter;
+ typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
+ using COptions::matcher1;
+ using COptions::matcher2;
+ using COptions::filter;
+
+ PdtComposeOptions(const Fst<Arc> &ifst1,
+ const Fst<Arc> &ifst2,
+ const vector<pair<Label, Label> > &parens) {
+ matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsLoop);
+ matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsList);
+
+ // Treat parens as multi-epsilons when composing.
+ for (size_t i = 0; i < parens.size(); ++i) {
+ matcher1->AddMultiEpsLabel(parens[i].first);
+ matcher1->AddMultiEpsLabel(parens[i].second);
+ matcher2->AddMultiEpsLabel(parens[i].first);
+ matcher2->AddMultiEpsLabel(parens[i].second);
+ }
+
+ filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true);
+ }
+};
+
+
+// Composes pushdown transducer (PDT) encoded as an FST (1st arg) and
+// an FST (2nd arg) with the result also a PDT encoded as an Fst. (3rd arg).
+// In the PDTs, some transitions are labeled with open or close
+// parentheses. To be interpreted as a PDT, the parens must balance on
+// a path (see PdtExpand()). The open-close parenthesis label pairs
+// are passed in 'parens'.
+template <class Arc>
+void Compose(const Fst<Arc> &ifst1,
+ const vector<pair<typename Arc::Label,
+ typename Arc::Label> > &parens,
+ const Fst<Arc> &ifst2,
+ MutableFst<Arc> *ofst,
+ const ComposeOptions &opts = ComposeOptions()) {
+
+ PdtComposeOptions<Arc, true> copts(ifst1, parens, ifst2);
+ copts.gc_limit = 0;
+ *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
+ if (opts.connect)
+ Connect(ofst);
+}
+
+
+// Composes an FST (1st arg) and pushdown transducer (PDT) encoded as
+// an FST (2nd arg) with the result also a PDT encoded as an Fst (3rd arg).
+// In the PDTs, some transitions are labeled with open or close
+// parentheses. To be interpreted as a PDT, the parens must balance on
+// a path (see ExpandFst()). The open-close parenthesis label pairs
+// are passed in 'parens'.
+template <class Arc>
+void Compose(const Fst<Arc> &ifst1,
+ const Fst<Arc> &ifst2,
+ const vector<pair<typename Arc::Label,
+ typename Arc::Label> > &parens,
+ MutableFst<Arc> *ofst,
+ const ComposeOptions &opts = ComposeOptions()) {
+
+ PdtComposeOptions<Arc, false> copts(ifst1, ifst2, parens);
+ copts.gc_limit = 0;
+ *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
+ if (opts.connect)
+ Connect(ofst);
+}
+
+} // namespace fst
+
+#endif // FST_EXTENSIONS_PDT_COMPOSE_H__
diff --git a/src/include/fst/extensions/pdt/expand.h b/src/include/fst/extensions/pdt/expand.h
new file mode 100644
index 0000000..f464403
--- /dev/null
+++ b/src/include/fst/extensions/pdt/expand.h
@@ -0,0 +1,975 @@
+// expand.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Expand a PDT to an FST.
+
+#ifndef FST_EXTENSIONS_PDT_EXPAND_H__
+#define FST_EXTENSIONS_PDT_EXPAND_H__
+
+#include <vector>
+using std::vector;
+
+#include <fst/extensions/pdt/pdt.h>
+#include <fst/extensions/pdt/paren.h>
+#include <fst/extensions/pdt/shortest-path.h>
+#include <fst/extensions/pdt/reverse.h>
+#include <fst/cache.h>
+#include <fst/mutable-fst.h>
+#include <fst/queue.h>
+#include <fst/state-table.h>
+#include <fst/test-properties.h>
+
+namespace fst {
+
+template <class Arc>
+struct ExpandFstOptions : public CacheOptions {
+ bool keep_parentheses;
+ PdtStack<typename Arc::StateId, typename Arc::Label> *stack;
+ PdtStateTable<typename Arc::StateId, typename Arc::StateId> *state_table;
+
+ ExpandFstOptions(
+ const CacheOptions &opts = CacheOptions(),
+ bool kp = false,
+ PdtStack<typename Arc::StateId, typename Arc::Label> *s = 0,
+ PdtStateTable<typename Arc::StateId, typename Arc::StateId> *st = 0)
+ : CacheOptions(opts), keep_parentheses(kp), stack(s), state_table(st) {}
+};
+
+// Properties for an expanded PDT.
+inline uint64 ExpandProperties(uint64 inprops) {
+ return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted);
+}
+
+
+// Implementation class for ExpandFst
+template <class A>
+class ExpandFstImpl
+ : public CacheImpl<A> {
+ public:
+ using FstImpl<A>::SetType;
+ using FstImpl<A>::SetProperties;
+ using FstImpl<A>::Properties;
+ using FstImpl<A>::SetInputSymbols;
+ using FstImpl<A>::SetOutputSymbols;
+
+ using CacheBaseImpl< CacheState<A> >::PushArc;
+ using CacheBaseImpl< CacheState<A> >::HasArcs;
+ using CacheBaseImpl< CacheState<A> >::HasFinal;
+ using CacheBaseImpl< CacheState<A> >::HasStart;
+ using CacheBaseImpl< CacheState<A> >::SetArcs;
+ using CacheBaseImpl< CacheState<A> >::SetFinal;
+ using CacheBaseImpl< CacheState<A> >::SetStart;
+
+ typedef A Arc;
+ typedef typename A::Label Label;
+ typedef typename A::Weight Weight;
+ typedef typename A::StateId StateId;
+ typedef StateId StackId;
+ typedef PdtStateTuple<StateId, StackId> StateTuple;
+
+ ExpandFstImpl(const Fst<A> &fst,
+ const vector<pair<typename Arc::Label,
+ typename Arc::Label> > &parens,
+ const ExpandFstOptions<A> &opts)
+ : CacheImpl<A>(opts), fst_(fst.Copy()),
+ stack_(opts.stack ? opts.stack: new PdtStack<StateId, Label>(parens)),
+ state_table_(opts.state_table ? opts.state_table :
+ new PdtStateTable<StateId, StackId>()),
+ own_stack_(opts.stack == 0), own_state_table_(opts.state_table == 0),
+ keep_parentheses_(opts.keep_parentheses) {
+ SetType("expand");
+
+ uint64 props = fst.Properties(kFstProperties, false);
+ SetProperties(ExpandProperties(props), kCopyProperties);
+
+ SetInputSymbols(fst.InputSymbols());
+ SetOutputSymbols(fst.OutputSymbols());
+ }
+
+ ExpandFstImpl(const ExpandFstImpl &impl)
+ : CacheImpl<A>(impl),
+ fst_(impl.fst_->Copy(true)),
+ stack_(new PdtStack<StateId, Label>(*impl.stack_)),
+ state_table_(new PdtStateTable<StateId, StackId>()),
+ own_stack_(true), own_state_table_(true),
+ keep_parentheses_(impl.keep_parentheses_) {
+ SetType("expand");
+ SetProperties(impl.Properties(), kCopyProperties);
+ SetInputSymbols(impl.InputSymbols());
+ SetOutputSymbols(impl.OutputSymbols());
+ }
+
+ ~ExpandFstImpl() {
+ delete fst_;
+ if (own_stack_)
+ delete stack_;
+ if (own_state_table_)
+ delete state_table_;
+ }
+
+ StateId Start() {
+ if (!HasStart()) {
+ StateId s = fst_->Start();
+ if (s == kNoStateId)
+ return kNoStateId;
+ StateTuple tuple(s, 0);
+ StateId start = state_table_->FindState(tuple);
+ SetStart(start);
+ }
+ return CacheImpl<A>::Start();
+ }
+
+ Weight Final(StateId s) {
+ if (!HasFinal(s)) {
+ const StateTuple &tuple = state_table_->Tuple(s);
+ Weight w = fst_->Final(tuple.state_id);
+ if (w != Weight::Zero() && tuple.stack_id == 0)
+ SetFinal(s, w);
+ else
+ SetFinal(s, Weight::Zero());
+ }
+ return CacheImpl<A>::Final(s);
+ }
+
+ size_t NumArcs(StateId s) {
+ if (!HasArcs(s)) {
+ ExpandState(s);
+ }
+ return CacheImpl<A>::NumArcs(s);
+ }
+
+ size_t NumInputEpsilons(StateId s) {
+ if (!HasArcs(s))
+ ExpandState(s);
+ return CacheImpl<A>::NumInputEpsilons(s);
+ }
+
+ size_t NumOutputEpsilons(StateId s) {
+ if (!HasArcs(s))
+ ExpandState(s);
+ return CacheImpl<A>::NumOutputEpsilons(s);
+ }
+
+ void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
+ if (!HasArcs(s))
+ ExpandState(s);
+ CacheImpl<A>::InitArcIterator(s, data);
+ }
+
+ // Computes the outgoing transitions from a state, creating new destination
+ // states as needed.
+ void ExpandState(StateId s) {
+ StateTuple tuple = state_table_->Tuple(s);
+ for (ArcIterator< Fst<A> > aiter(*fst_, tuple.state_id);
+ !aiter.Done(); aiter.Next()) {
+ Arc arc = aiter.Value();
+ StackId stack_id = stack_->Find(tuple.stack_id, arc.ilabel);
+ if (stack_id == -1) {
+ // Non-matching close parenthesis
+ continue;
+ } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) {
+ // Stack push/pop
+ arc.ilabel = arc.olabel = 0;
+ }
+
+ StateTuple ntuple(arc.nextstate, stack_id);
+ arc.nextstate = state_table_->FindState(ntuple);
+ PushArc(s, arc);
+ }
+ SetArcs(s);
+ }
+
+ const PdtStack<StackId, Label> &GetStack() const { return *stack_; }
+
+ const PdtStateTable<StateId, StackId> &GetStateTable() const {
+ return *state_table_;
+ }
+
+ private:
+ const Fst<A> *fst_;
+
+ PdtStack<StackId, Label> *stack_;
+ PdtStateTable<StateId, StackId> *state_table_;
+ bool own_stack_;
+ bool own_state_table_;
+ bool keep_parentheses_;
+
+ void operator=(const ExpandFstImpl<A> &); // disallow
+};
+
+// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
+// This version is a delayed Fst. In the PDT, some transitions are
+// labeled with open or close parentheses. To be interpreted as a PDT,
+// the parens must balance on a path. The open-close parenthesis label
+// pairs are passed in 'parens'. The expansion enforces the
+// parenthesis constraints. The PDT must be expandable as an FST.
+//
+// This class attaches interface to implementation and handles
+// reference counting, delegating most methods to ImplToFst.
+template <class A>
+class ExpandFst : public ImplToFst< ExpandFstImpl<A> > {
+ public:
+ friend class ArcIterator< ExpandFst<A> >;
+ friend class StateIterator< ExpandFst<A> >;
+
+ typedef A Arc;
+ typedef typename A::Label Label;
+ typedef typename A::Weight Weight;
+ typedef typename A::StateId StateId;
+ typedef StateId StackId;
+ typedef CacheState<A> State;
+ typedef ExpandFstImpl<A> Impl;
+
+ ExpandFst(const Fst<A> &fst,
+ const vector<pair<typename Arc::Label,
+ typename Arc::Label> > &parens)
+ : ImplToFst<Impl>(new Impl(fst, parens, ExpandFstOptions<A>())) {}
+
+ ExpandFst(const Fst<A> &fst,
+ const vector<pair<typename Arc::Label,
+ typename Arc::Label> > &parens,
+ const ExpandFstOptions<A> &opts)
+ : ImplToFst<Impl>(new Impl(fst, parens, opts)) {}
+
+ // See Fst<>::Copy() for doc.
+ ExpandFst(const ExpandFst<A> &fst, bool safe = false)
+ : ImplToFst<Impl>(fst, safe) {}
+
+ // Get a copy of this ExpandFst. See Fst<>::Copy() for further doc.
+ virtual ExpandFst<A> *Copy(bool safe = false) const {
+ return new ExpandFst<A>(*this, safe);
+ }
+
+ virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
+
+ virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
+ GetImpl()->InitArcIterator(s, data);
+ }
+
+ const PdtStack<StackId, Label> &GetStack() const {
+ return GetImpl()->GetStack();
+ }
+
+ const PdtStateTable<StateId, StackId> &GetStateTable() const {
+ return GetImpl()->GetStateTable();
+ }
+
+ private:
+ // Makes visible to friends.
+ Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
+
+ void operator=(const ExpandFst<A> &fst); // Disallow
+};
+
+
+// Specialization for ExpandFst.
+template<class A>
+class StateIterator< ExpandFst<A> >
+ : public CacheStateIterator< ExpandFst<A> > {
+ public:
+ explicit StateIterator(const ExpandFst<A> &fst)
+ : CacheStateIterator< ExpandFst<A> >(fst, fst.GetImpl()) {}
+};
+
+
+// Specialization for ExpandFst.
+template <class A>
+class ArcIterator< ExpandFst<A> >
+ : public CacheArcIterator< ExpandFst<A> > {
+ public:
+ typedef typename A::StateId StateId;
+
+ ArcIterator(const ExpandFst<A> &fst, StateId s)
+ : CacheArcIterator< ExpandFst<A> >(fst.GetImpl(), s) {
+ if (!fst.GetImpl()->HasArcs(s))
+ fst.GetImpl()->ExpandState(s);
+ }
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(ArcIterator);
+};
+
+
+template <class A> inline
+void ExpandFst<A>::InitStateIterator(StateIteratorData<A> *data) const
+{
+ data->base = new StateIterator< ExpandFst<A> >(*this);
+}
+
+//
+// PrunedExpand Class
+//
+
+// Prunes the delayed expansion of a pushdown transducer (PDT) encoded
+// as an FST into an FST. In the PDT, some transitions are labeled
+// with open or close parentheses. To be interpreted as a PDT, the
+// parens must balance on a path. The open-close parenthesis label
+// pairs are passed in 'parens'. The expansion enforces the
+// parenthesis constraints.
+//
+// The algorithm works by visiting the delayed ExpandFst using a
+// shortest-stack first queue discipline and relies on the
+// shortest-distance information computed using a reverse
+// shortest-path call to perform the pruning.
+//
+// The algorithm maintains the same state ordering between the ExpandFst
+// being visited 'efst_' and the result of pruning written into the
+// MutableFst 'ofst_' to improve readability of the code.
+//
+template <class A>
+class PrunedExpand {
+ public:
+ typedef A Arc;
+ typedef typename A::Label Label;
+ typedef typename A::StateId StateId;
+ typedef typename A::Weight Weight;
+ typedef StateId StackId;
+ typedef PdtStack<StackId, Label> Stack;
+ typedef PdtStateTable<StateId, StackId> StateTable;
+ typedef typename PdtBalanceData<Arc>::SetIterator SetIterator;
+
+ // Constructor taking as input a PDT specified by 'ifst' and 'parens'.
+ // 'keep_parentheses' specifies whether parentheses are replaced by
+ // epsilons or not during the expansion. 'opts' is the cache options
+ // used to instantiate the underlying ExpandFst.
+ PrunedExpand(const Fst<A> &ifst,
+ const vector<pair<Label, Label> > &parens,
+ bool keep_parentheses = false,
+ const CacheOptions &opts = CacheOptions())
+ : ifst_(ifst.Copy()),
+ keep_parentheses_(keep_parentheses),
+ stack_(parens),
+ efst_(ifst, parens,
+ ExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)),
+ queue_(state_table_, stack_, stack_length_, distance_, fdistance_) {
+ Reverse(*ifst_, parens, &rfst_);
+ VectorFst<Arc> path;
+ reverse_shortest_path_ = new SP(
+ rfst_, parens,
+ PdtShortestPathOptions<A, FifoQueue<StateId> >(true, false));
+ reverse_shortest_path_->ShortestPath(&path);
+ balance_data_ = reverse_shortest_path_->GetBalanceData()->Reverse(
+ rfst_.NumStates(), 10, -1);
+
+ InitCloseParenMultimap(parens);
+ }
+
+ ~PrunedExpand() {
+ delete ifst_;
+ delete reverse_shortest_path_;
+ delete balance_data_;
+ }
+
+ // Expands and prunes with weight threshold 'threshold' the input PDT.
+ // Writes the result in 'ofst'.
+ void Expand(MutableFst<A> *ofst, const Weight &threshold);
+
+ private:
+ static const uint8 kEnqueued;
+ static const uint8 kExpanded;
+ static const uint8 kSourceState;
+
+ // Comparison functor used by the queue:
+ // 1. states corresponding to shortest stack first,
+ // 2. among stacks of the same length, reverse lexicographic order is used,
+ // 3. among states with the same stack, shortest-first order is used.
+ class StackCompare {
+ public:
+ StackCompare(const StateTable &st,
+ const Stack &s, const vector<StackId> &sl,
+ const vector<Weight> &d, const vector<Weight> &fd)
+ : state_table_(st), stack_(s), stack_length_(sl),
+ distance_(d), fdistance_(fd) {}
+
+ bool operator()(StateId s1, StateId s2) const {
+ StackId si1 = state_table_.Tuple(s1).stack_id;
+ StackId si2 = state_table_.Tuple(s2).stack_id;
+ if (stack_length_[si1] < stack_length_[si2])
+ return true;
+ if (stack_length_[si1] > stack_length_[si2])
+ return false;
+ // If stack id equal, use A*
+ if (si1 == si2) {
+ Weight w1 = (s1 < distance_.size()) && (s1 < fdistance_.size()) ?
+ Times(distance_[s1], fdistance_[s1]) : Weight::Zero();
+ Weight w2 = (s2 < distance_.size()) && (s2 < fdistance_.size()) ?
+ Times(distance_[s2], fdistance_[s2]) : Weight::Zero();
+ return less_(w1, w2);
+ }
+ // If lenghts are equal, use reverse lexico.
+ for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) {
+ if (stack_.Top(si1) < stack_.Top(si2)) return true;
+ if (stack_.Top(si1) > stack_.Top(si2)) return false;
+ }
+ return false;
+ }
+
+ private:
+ const StateTable &state_table_;
+ const Stack &stack_;
+ const vector<StackId> &stack_length_;
+ const vector<Weight> &distance_;
+ const vector<Weight> &fdistance_;
+ NaturalLess<Weight> less_;
+ };
+
+ class ShortestStackFirstQueue
+ : public ShortestFirstQueue<StateId, StackCompare> {
+ public:
+ ShortestStackFirstQueue(
+ const PdtStateTable<StateId, StackId> &st,
+ const Stack &s,
+ const vector<StackId> &sl,
+ const vector<Weight> &d, const vector<Weight> &fd)
+ : ShortestFirstQueue<StateId, StackCompare>(
+ StackCompare(st, s, sl, d, fd)) {}
+ };
+
+
+ void InitCloseParenMultimap(const vector<pair<Label, Label> > &parens);
+ Weight DistanceToDest(StateId state, StateId source) const;
+ uint8 Flags(StateId s) const;
+ void SetFlags(StateId s, uint8 flags, uint8 mask);
+ Weight Distance(StateId s) const;
+ void SetDistance(StateId s, Weight w);
+ Weight FinalDistance(StateId s) const;
+ void SetFinalDistance(StateId s, Weight w);
+ StateId SourceState(StateId s) const;
+ void SetSourceState(StateId s, StateId p);
+ void AddStateAndEnqueue(StateId s);
+ void Relax(StateId s, const A &arc, Weight w);
+ bool PruneArc(StateId s, const A &arc);
+ void ProcStart();
+ void ProcFinal(StateId s);
+ bool ProcNonParen(StateId s, const A &arc, bool add_arc);
+ bool ProcOpenParen(StateId s, const A &arc, StackId si, StackId nsi);
+ bool ProcCloseParen(StateId s, const A &arc);
+ void ProcDestStates(StateId s, StackId si);
+
+ Fst<A> *ifst_; // Input PDT
+ VectorFst<Arc> rfst_; // Reversed PDT
+ bool keep_parentheses_; // Keep parentheses in ofst?
+ StateTable state_table_; // State table for efst_
+ Stack stack_; // Stack trie
+ ExpandFst<Arc> efst_; // Expanded PDT
+ vector<StackId> stack_length_; // Length of stack for given stack id
+ vector<Weight> distance_; // Distance from initial state in efst_/ofst
+ vector<Weight> fdistance_; // Distance to final states in efst_/ofst
+ ShortestStackFirstQueue queue_; // Queue used to visit efst_
+ vector<uint8> flags_; // Status flags for states in efst_/ofst
+ vector<StateId> sources_; // PDT source state for each expanded state
+
+ typedef PdtShortestPath<Arc, FifoQueue<StateId> > SP;
+ typedef typename SP::CloseParenMultimap ParenMultimap;
+ SP *reverse_shortest_path_; // Shortest path for rfst_
+ PdtBalanceData<Arc> *balance_data_; // Not owned by shortest_path_
+ ParenMultimap close_paren_multimap_; // Maps open paren arcs to
+ // balancing close paren arcs.
+
+ MutableFst<Arc> *ofst_; // Output fst
+ Weight limit_; // Weight limit
+
+ typedef unordered_map<StateId, Weight> DestMap;
+ DestMap dest_map_;
+ StackId current_stack_id_;
+ // 'current_stack_id_' is the stack id of the states currently at the top
+ // of queue, i.e., the states currently being popped and processed.
+ // 'dest_map_' maps a state 's' in 'ifst_' that is the source
+ // of a close parentheses matching the top of 'current_stack_id_; to
+ // the shortest-distance from '(s, current_stack_id_)' to the final
+ // states in 'efst_'.
+ ssize_t current_paren_id_; // Paren id at top of current stack
+ ssize_t cached_stack_id_;
+ StateId cached_source_;
+ slist<pair<StateId, Weight> > cached_dest_list_;
+ // 'cached_dest_list_' contains the set of pair of destination
+ // states and weight to final states for source state
+ // 'cached_source_' and paren id 'cached_paren_id': the set of
+ // source state of a close parenthesis with paren id
+ // 'cached_paren_id' balancing an incoming open parenthesis with
+ // paren id 'cached_paren_id' in state 'cached_source_'.
+
+ NaturalLess<Weight> less_;
+};
+
+template <class A> const uint8 PrunedExpand<A>::kEnqueued = 0x01;
+template <class A> const uint8 PrunedExpand<A>::kExpanded = 0x02;
+template <class A> const uint8 PrunedExpand<A>::kSourceState = 0x04;
+
+
+// Initializes close paren multimap, mapping pairs (s,paren_id) to
+// all the arcs out of s labeled with close parenthese for paren_id.
+template <class A>
+void PrunedExpand<A>::InitCloseParenMultimap(
+ const vector<pair<Label, Label> > &parens) {
+ unordered_map<Label, Label> paren_id_map;
+ for (Label i = 0; i < parens.size(); ++i) {
+ const pair<Label, Label> &p = parens[i];
+ paren_id_map[p.first] = i;
+ paren_id_map[p.second] = i;
+ }
+
+ for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
+ StateId s = siter.Value();
+ for (ArcIterator<Fst<Arc> > aiter(*ifst_, s);
+ !aiter.Done(); aiter.Next()) {
+ const Arc &arc = aiter.Value();
+ typename unordered_map<Label, Label>::const_iterator pit
+ = paren_id_map.find(arc.ilabel);
+ if (pit == paren_id_map.end()) continue;
+ if (arc.ilabel == parens[pit->second].second) { // Close paren
+ ParenState<Arc> paren_state(pit->second, s);
+ close_paren_multimap_.insert(make_pair(paren_state, arc));
+ }
+ }
+ }
+}
+
+
+// Returns the weight of the shortest balanced path from 'source' to 'dest'
+// in 'ifst_', 'dest' must be the source state of a close paren arc.
+template <class A>
+typename A::Weight PrunedExpand<A>::DistanceToDest(StateId source,
+ StateId dest) const {
+ typename SP::SearchState s(source + 1, dest + 1);
+ VLOG(2) << "D(" << source << ", " << dest << ") ="
+ << reverse_shortest_path_->GetShortestPathData().Distance(s);
+ return reverse_shortest_path_->GetShortestPathData().Distance(s);
+}
+
+// Returns the flags for state 's' in 'ofst_'.
+template <class A>
+uint8 PrunedExpand<A>::Flags(StateId s) const {
+ return s < flags_.size() ? flags_[s] : 0;
+}
+
+// Modifies the flags for state 's' in 'ofst_'.
+template <class A>
+void PrunedExpand<A>::SetFlags(StateId s, uint8 flags, uint8 mask) {
+ while (flags_.size() <= s) flags_.push_back(0);
+ flags_[s] &= ~mask;
+ flags_[s] |= flags & mask;
+}
+
+
+// Returns the shortest distance from the initial state to 's' in 'ofst_'.
+template <class A>
+typename A::Weight PrunedExpand<A>::Distance(StateId s) const {
+ return s < distance_.size() ? distance_[s] : Weight::Zero();
+}
+
+// Sets the shortest distance from the initial state to 's' in 'ofst_' to 'w'.
+template <class A>
+void PrunedExpand<A>::SetDistance(StateId s, Weight w) {
+ while (distance_.size() <= s ) distance_.push_back(Weight::Zero());
+ distance_[s] = w;
+}
+
+
+// Returns the shortest distance from 's' to the final states in 'ofst_'.
+template <class A>
+typename A::Weight PrunedExpand<A>::FinalDistance(StateId s) const {
+ return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
+}
+
+// Sets the shortest distance from 's' to the final states in 'ofst_' to 'w'.
+template <class A>
+void PrunedExpand<A>::SetFinalDistance(StateId s, Weight w) {
+ while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero());
+ fdistance_[s] = w;
+}
+
+// Returns the PDT "source" state of state 's' in 'ofst_'.
+template <class A>
+typename A::StateId PrunedExpand<A>::SourceState(StateId s) const {
+ return s < sources_.size() ? sources_[s] : kNoStateId;
+}
+
+// Sets the PDT "source" state of state 's' in 'ofst_' to state 'p' in 'ifst_'.
+template <class A>
+void PrunedExpand<A>::SetSourceState(StateId s, StateId p) {
+ while (sources_.size() <= s) sources_.push_back(kNoStateId);
+ sources_[s] = p;
+}
+
+// Adds state 's' of 'efst_' to 'ofst_' and inserts it in the queue,
+// modifying the flags for 's' accordingly.
+template <class A>
+void PrunedExpand<A>::AddStateAndEnqueue(StateId s) {
+ if (!(Flags(s) & (kEnqueued | kExpanded))) {
+ while (ofst_->NumStates() <= s) ofst_->AddState();
+ queue_.Enqueue(s);
+ SetFlags(s, kEnqueued, kEnqueued);
+ } else if (Flags(s) & kEnqueued) {
+ queue_.Update(s);
+ }
+ // TODO(allauzen): Check everything is fine when kExpanded?
+}
+
+// Relaxes arc 'arc' out of state 's' in 'ofst_':
+// * if the distance to 's' times the weight of 'arc' is smaller than
+// the currently stored distance for 'arc.nextstate',
+// updates 'Distance(arc.nextstate)' with new estimate;
+// * if 'fd' is less than the currently stored distance from 'arc.nextstate'
+// to the final state, updates with new estimate.
+template <class A>
+void PrunedExpand<A>::Relax(StateId s, const A &arc, Weight fd) {
+ Weight nd = Times(Distance(s), arc.weight);
+ if (less_(nd, Distance(arc.nextstate))) {
+ SetDistance(arc.nextstate, nd);
+ SetSourceState(arc.nextstate, SourceState(s));
+ }
+ if (less_(fd, FinalDistance(arc.nextstate)))
+ SetFinalDistance(arc.nextstate, fd);
+ VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to "
+ << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate)
+ << ", nd = " << nd;
+}
+
+// Returns 'true' if the arc 'arc' out of state 's' in 'efst_' needs to
+// be pruned.
+template <class A>
+bool PrunedExpand<A>::PruneArc(StateId s, const A &arc) {
+ VLOG(2) << "Prune ?";
+ Weight fd = Weight::Zero();
+
+ if ((cached_source_ != SourceState(s)) ||
+ (cached_stack_id_ != current_stack_id_)) {
+ cached_source_ = SourceState(s);
+ cached_stack_id_ = current_stack_id_;
+ cached_dest_list_.clear();
+ if (cached_source_ != ifst_->Start()) {
+ for (SetIterator set_iter =
+ balance_data_->Find(current_paren_id_, cached_source_);
+ !set_iter.Done(); set_iter.Next()) {
+ StateId dest = set_iter.Element();
+ typename DestMap::const_iterator iter = dest_map_.find(dest);
+ cached_dest_list_.push_front(*iter);
+ }
+ } else {
+ // TODO(allauzen): queue discipline should prevent this never
+ // from happening; replace by a check.
+ cached_dest_list_.push_front(
+ make_pair(rfst_.Start() -1, Weight::One()));
+ }
+ }
+
+ for (typename slist<pair<StateId, Weight> >::const_iterator iter =
+ cached_dest_list_.begin();
+ iter != cached_dest_list_.end();
+ ++iter) {
+ fd = Plus(fd,
+ Times(DistanceToDest(state_table_.Tuple(arc.nextstate).state_id,
+ iter->first),
+ iter->second));
+ }
+ Relax(s, arc, fd);
+ Weight w = Times(Distance(s), Times(arc.weight, fd));
+ return less_(limit_, w);
+}
+
+// Adds start state of 'efst_' to 'ofst_', enqueues it and initializes
+// the distance data structures.
+template <class A>
+void PrunedExpand<A>::ProcStart() {
+ StateId s = efst_.Start();
+ AddStateAndEnqueue(s);
+ ofst_->SetStart(s);
+ SetSourceState(s, ifst_->Start());
+
+ current_stack_id_ = 0;
+ current_paren_id_ = -1;
+ stack_length_.push_back(0);
+ dest_map_[rfst_.Start() - 1] = Weight::One(); // not needed
+
+ cached_source_ = ifst_->Start();
+ cached_stack_id_ = 0;
+ cached_dest_list_.push_front(
+ make_pair(rfst_.Start() -1, Weight::One()));
+
+ PdtStateTuple<StateId, StackId> tuple(rfst_.Start() - 1, 0);
+ SetFinalDistance(state_table_.FindState(tuple), Weight::One());
+ SetDistance(s, Weight::One());
+ SetFinalDistance(s, DistanceToDest(ifst_->Start(), rfst_.Start() - 1));
+ VLOG(2) << DistanceToDest(ifst_->Start(), rfst_.Start() - 1);
+}
+
+// Makes 's' final in 'ofst_' if shortest accepting path ending in 's'
+// is below threshold.
+template <class A>
+void PrunedExpand<A>::ProcFinal(StateId s) {
+ Weight final = efst_.Final(s);
+ if ((final == Weight::Zero()) || less_(limit_, Times(Distance(s), final)))
+ return;
+ ofst_->SetFinal(s, final);
+}
+
+// Returns true when arc (or meta-arc) 'arc' out of 's' in 'efst_' is
+// below the threshold. When 'add_arc' is true, 'arc' is added to 'ofst_'.
+template <class A>
+bool PrunedExpand<A>::ProcNonParen(StateId s, const A &arc, bool add_arc) {
+ VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate
+ << ", " << arc.ilabel << ":" << arc.olabel << " / " << arc.weight
+ << ", add_arc = " << (add_arc ? "true" : "false");
+ if (PruneArc(s, arc)) return false;
+ if(add_arc) ofst_->AddArc(s, arc);
+ AddStateAndEnqueue(arc.nextstate);
+ return true;
+}
+
+// Processes an open paren arc 'arc' out of state 's' in 'ofst_'.
+// When 'arc' is labeled with an open paren,
+// 1. considers each (shortest) balanced path starting in 's' by
+// taking 'arc' and ending by a close paren balancing the open
+// paren of 'arc' as a meta-arc, processes and prunes each meta-arc
+// as a non-paren arc, inserting its destination to the queue;
+// 2. if at least one of these meta-arcs has not been pruned,
+// adds the destination of 'arc' to 'ofst_' as a new source state
+// for the stack id 'nsi' and inserts it in the queue.
+template <class A>
+bool PrunedExpand<A>::ProcOpenParen(StateId s, const A &arc, StackId si,
+ StackId nsi) {
+ // Update the stack lenght when needed: |nsi| = |si| + 1.
+ while (stack_length_.size() <= nsi) stack_length_.push_back(-1);
+ if (stack_length_[nsi] == -1)
+ stack_length_[nsi] = stack_length_[si] + 1;
+
+ StateId ns = arc.nextstate;
+ VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id
+ << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")";
+ bool proc_arc = false;
+ Weight fd = Weight::Zero();
+ ssize_t paren_id = stack_.ParenId(arc.ilabel);
+ slist<StateId> sources;
+ for (SetIterator set_iter =
+ balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id);
+ !set_iter.Done(); set_iter.Next()) {
+ sources.push_front(set_iter.Element());
+ }
+ for (typename slist<StateId>::const_iterator sources_iter = sources.begin();
+ sources_iter != sources.end();
+ ++ sources_iter) {
+ StateId source = *sources_iter;
+ VLOG(2) << "Close paren source: " << source;
+ ParenState<Arc> paren_state(paren_id, source);
+ for (typename ParenMultimap::const_iterator iter =
+ close_paren_multimap_.find(paren_state);
+ iter != close_paren_multimap_.end() && paren_state == iter->first;
+ ++iter) {
+ Arc meta_arc = iter->second;
+ PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si);
+ meta_arc.nextstate = state_table_.FindState(tuple);
+ VLOG(2) << state_table_.Tuple(ns).state_id << ", " << source;
+ VLOG(2) << "Meta arc weight = " << arc.weight << " Times "
+ << DistanceToDest(state_table_.Tuple(ns).state_id, source)
+ << " Times " << meta_arc.weight;
+ meta_arc.weight = Times(
+ arc.weight,
+ Times(DistanceToDest(state_table_.Tuple(ns).state_id, source),
+ meta_arc.weight));
+ proc_arc |= ProcNonParen(s, meta_arc, false);
+ fd = Plus(fd, Times(
+ Times(
+ DistanceToDest(state_table_.Tuple(ns).state_id, source),
+ iter->second.weight),
+ FinalDistance(meta_arc.nextstate)));
+ }
+ }
+ if (proc_arc) {
+ VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate;
+ ofst_->AddArc(
+ s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
+ AddStateAndEnqueue(arc.nextstate);
+ Weight nd = Times(Distance(s), arc.weight);
+ if(less_(nd, Distance(arc.nextstate)))
+ SetDistance(arc.nextstate, nd);
+ // FinalDistance not necessary for source state since pruning
+ // decided using the meta-arcs above. But this is a problem with
+ // A*, hence:
+ if (less_(fd, FinalDistance(arc.nextstate)))
+ SetFinalDistance(arc.nextstate, fd);
+ SetFlags(arc.nextstate, kSourceState, kSourceState);
+ }
+ return proc_arc;
+}
+
+// Checks that shortest path through close paren arc in 'efst_' is
+// below threshold, if so adds it to 'ofst_'.
+template <class A>
+bool PrunedExpand<A>::ProcCloseParen(StateId s, const A &arc) {
+ Weight w = Times(Distance(s),
+ Times(arc.weight, FinalDistance(arc.nextstate)));
+ if (less_(limit_, w))
+ return false;
+ ofst_->AddArc(
+ s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
+ return true;
+}
+
+// When 's' in 'ofst_' is a source state for stack id 'si', identifies
+// all the corresponding possible destination states, that is, all the
+// states in 'ifst_' that have an outgoing close paren arc balancing
+// the incoming open paren taken to get to 's', and for each such
+// state 't', computes the shortest distance from (t, si) to the final
+// states in 'ofst_'. Stores this information in 'dest_map_'.
+template <class A>
+void PrunedExpand<A>::ProcDestStates(StateId s, StackId si) {
+ if (!(Flags(s) & kSourceState)) return;
+ if (si != current_stack_id_) {
+ dest_map_.clear();
+ current_stack_id_ = si;
+ current_paren_id_ = stack_.Top(current_stack_id_);
+ VLOG(2) << "StackID " << si << " dequeued for first time";
+ }
+ // TODO(allauzen): clean up source state business; rename current function to
+ // ProcSourceState.
+ SetSourceState(s, state_table_.Tuple(s).state_id);
+
+ ssize_t paren_id = stack_.Top(si);
+ for (SetIterator set_iter =
+ balance_data_->Find(paren_id, state_table_.Tuple(s).state_id);
+ !set_iter.Done(); set_iter.Next()) {
+ StateId dest_state = set_iter.Element();
+ if (dest_map_.find(dest_state) != dest_map_.end())
+ continue;
+ Weight dest_weight = Weight::Zero();
+ ParenState<Arc> paren_state(paren_id, dest_state);
+ for (typename ParenMultimap::const_iterator iter =
+ close_paren_multimap_.find(paren_state);
+ iter != close_paren_multimap_.end() && paren_state == iter->first;
+ ++iter) {
+ const Arc &arc = iter->second;
+ PdtStateTuple<StateId, StackId> tuple(arc.nextstate, stack_.Pop(si));
+ dest_weight = Plus(dest_weight,
+ Times(arc.weight,
+ FinalDistance(state_table_.FindState(tuple))));
+ }
+ dest_map_[dest_state] = dest_weight;
+ VLOG(2) << "State " << dest_state << " is a dest state for stack id "
+ << si << " with weight " << dest_weight;
+ }
+}
+
+// Expands and prunes with weight threshold 'threshold' the input PDT.
+// Writes the result in 'ofst'.
+template <class A>
+void PrunedExpand<A>::Expand(
+ MutableFst<A> *ofst, const typename A::Weight &threshold) {
+ ofst_ = ofst;
+ ofst_->DeleteStates();
+ ofst_->SetInputSymbols(ifst_->InputSymbols());
+ ofst_->SetOutputSymbols(ifst_->OutputSymbols());
+
+ limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold);
+ flags_.clear();
+
+ ProcStart();
+
+ while (!queue_.Empty()) {
+ StateId s = queue_.Head();
+ queue_.Dequeue();
+ SetFlags(s, kExpanded, kExpanded | kEnqueued);
+ VLOG(2) << s << " dequeued!";
+
+ ProcFinal(s);
+ StackId stack_id = state_table_.Tuple(s).stack_id;
+ ProcDestStates(s, stack_id);
+
+ for (ArcIterator<ExpandFst<Arc> > aiter(efst_, s);
+ !aiter.Done();
+ aiter.Next()) {
+ Arc arc = aiter.Value();
+ StackId nextstack_id = state_table_.Tuple(arc.nextstate).stack_id;
+ if (stack_id == nextstack_id)
+ ProcNonParen(s, arc, true);
+ else if (stack_id == stack_.Pop(nextstack_id))
+ ProcOpenParen(s, arc, stack_id, nextstack_id);
+ else
+ ProcCloseParen(s, arc);
+ }
+ VLOG(2) << "d[" << s << "] = " << Distance(s)
+ << ", fd[" << s << "] = " << FinalDistance(s);
+ }
+}
+
+//
+// Expand() Functions
+//
+
+template <class Arc>
+struct ExpandOptions {
+ bool connect;
+ bool keep_parentheses;
+ typename Arc::Weight weight_threshold;
+
+ ExpandOptions(bool c = true, bool k = false,
+ typename Arc::Weight w = Arc::Weight::Zero())
+ : connect(c), keep_parentheses(k), weight_threshold(w) {}
+};
+
+// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
+// This version writes the expanded PDT result to a MutableFst.
+// In the PDT, some transitions are labeled with open or close
+// parentheses. To be interpreted as a PDT, the parens must balance on
+// a path. The open-close parenthesis label pairs are passed in
+// 'parens'. The expansion enforces the parenthesis constraints. The
+// PDT must be expandable as an FST.
+template <class Arc>
+void Expand(
+ const Fst<Arc> &ifst,
+ const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
+ MutableFst<Arc> *ofst,
+ const ExpandOptions<Arc> &opts) {
+ typedef typename Arc::Label Label;
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Weight Weight;
+ typedef typename ExpandFst<Arc>::StackId StackId;
+
+ ExpandFstOptions<Arc> eopts;
+ eopts.gc_limit = 0;
+ if (opts.weight_threshold == Weight::Zero()) {
+ eopts.keep_parentheses = opts.keep_parentheses;
+ *ofst = ExpandFst<Arc>(ifst, parens, eopts);
+ } else {
+ PrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses);
+ pruned_expand.Expand(ofst, opts.weight_threshold);
+ }
+
+ if (opts.connect)
+ Connect(ofst);
+}
+
+// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
+// This version writes the expanded PDT result to a MutableFst.
+// In the PDT, some transitions are labeled with open or close
+// parentheses. To be interpreted as a PDT, the parens must balance on
+// a path. The open-close parenthesis label pairs are passed in
+// 'parens'. The expansion enforces the parenthesis constraints. The
+// PDT must be expandable as an FST.
+template<class Arc>
+void Expand(
+ const Fst<Arc> &ifst,
+ const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
+ MutableFst<Arc> *ofst,
+ bool connect = true, bool keep_parentheses = false) {
+ Expand(ifst, parens, ofst, ExpandOptions<Arc>(connect, keep_parentheses));
+}
+
+} // namespace fst
+
+#endif // FST_EXTENSIONS_PDT_EXPAND_H__
diff --git a/src/include/fst/extensions/pdt/info.h b/src/include/fst/extensions/pdt/info.h
new file mode 100644
index 0000000..ef9a860
--- /dev/null
+++ b/src/include/fst/extensions/pdt/info.h
@@ -0,0 +1,175 @@
+// info.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Prints information about a PDT.
+
+#ifndef FST_EXTENSIONS_PDT_INFO_H__
+#define FST_EXTENSIONS_PDT_INFO_H__
+
+#include <unordered_map>
+using std::tr1::unordered_map;
+using std::tr1::unordered_multimap;
+#include <tr1/unordered_set>
+using std::tr1::unordered_set;
+using std::tr1::unordered_multiset;
+#include <vector>
+using std::vector;
+
+#include <fst/fst.h>
+#include <fst/extensions/pdt/pdt.h>
+
+namespace fst {
+
+// Compute various information about PDTs, helper class for pdtinfo.cc.
+template <class A> class PdtInfo {
+public:
+ typedef A Arc;
+ typedef typename A::StateId StateId;
+ typedef typename A::Label Label;
+ typedef typename A::Weight Weight;
+
+ PdtInfo(const Fst<A> &fst,
+ const vector<pair<typename A::Label,
+ typename A::Label> > &parens);
+
+ const string& FstType() const { return fst_type_; }
+ const string& ArcType() const { return A::Type(); }
+
+ int64 NumStates() const { return nstates_; }
+ int64 NumArcs() const { return narcs_; }
+ int64 NumOpenParens() const { return nopen_parens_; }
+ int64 NumCloseParens() const { return nclose_parens_; }
+ int64 NumUniqueOpenParens() const { return nuniq_open_parens_; }
+ int64 NumUniqueCloseParens() const { return nuniq_close_parens_; }
+ int64 NumOpenParenStates() const { return nopen_paren_states_; }
+ int64 NumCloseParenStates() const { return nclose_paren_states_; }
+
+ private:
+ string fst_type_;
+ int64 nstates_;
+ int64 narcs_;
+ int64 nopen_parens_;
+ int64 nclose_parens_;
+ int64 nuniq_open_parens_;
+ int64 nuniq_close_parens_;
+ int64 nopen_paren_states_;
+ int64 nclose_paren_states_;
+
+ DISALLOW_COPY_AND_ASSIGN(PdtInfo);
+};
+
+template <class A>
+PdtInfo<A>::PdtInfo(const Fst<A> &fst,
+ const vector<pair<typename A::Label,
+ typename A::Label> > &parens)
+ : fst_type_(fst.Type()),
+ nstates_(0),
+ narcs_(0),
+ nopen_parens_(0),
+ nclose_parens_(0),
+ nuniq_open_parens_(0),
+ nuniq_close_parens_(0),
+ nopen_paren_states_(0),
+ nclose_paren_states_(0) {
+ unordered_map<Label, size_t> paren_map;
+ unordered_set<Label> paren_set;
+ unordered_set<StateId> open_paren_state_set;
+ unordered_set<StateId> close_paren_state_set;
+
+ for (size_t i = 0; i < parens.size(); ++i) {
+ const pair<Label, Label> &p = parens[i];
+ paren_map[p.first] = i;
+ paren_map[p.second] = i;
+ }
+
+ for (StateIterator< Fst<A> > siter(fst);
+ !siter.Done();
+ siter.Next()) {
+ ++nstates_;
+ StateId s = siter.Value();
+ for (ArcIterator< Fst<A> > aiter(fst, s);
+ !aiter.Done();
+ aiter.Next()) {
+ const A &arc = aiter.Value();
+ ++narcs_;
+ typename unordered_map<Label, size_t>::const_iterator pit
+ = paren_map.find(arc.ilabel);
+ if (pit != paren_map.end()) {
+ Label open_paren = parens[pit->second].first;
+ Label close_paren = parens[pit->second].second;
+ if (arc.ilabel == open_paren) {
+ ++nopen_parens_;
+ if (!paren_set.count(open_paren)) {
+ ++nuniq_open_parens_;
+ paren_set.insert(open_paren);
+ }
+ if (!open_paren_state_set.count(arc.nextstate)) {
+ ++nopen_paren_states_;
+ open_paren_state_set.insert(arc.nextstate);
+ }
+ } else {
+ ++nclose_parens_;
+ if (!paren_set.count(close_paren)) {
+ ++nuniq_close_parens_;
+ paren_set.insert(close_paren);
+ }
+ if (!close_paren_state_set.count(s)) {
+ ++nclose_paren_states_;
+ close_paren_state_set.insert(s);
+ }
+
+ }
+ }
+ }
+ }
+}
+
+
+template <class A>
+void PrintPdtInfo(const PdtInfo<A> &pdtinfo) {
+ ios_base::fmtflags old = cout.setf(ios::left);
+ cout.width(50);
+ cout << "fst type" << pdtinfo.FstType().c_str() << endl;
+ cout.width(50);
+ cout << "arc type" << pdtinfo.ArcType().c_str() << endl;
+ cout.width(50);
+ cout << "# of states" << pdtinfo.NumStates() << endl;
+ cout.width(50);
+ cout << "# of arcs" << pdtinfo.NumArcs() << endl;
+ cout.width(50);
+ cout << "# of open parentheses" << pdtinfo.NumOpenParens() << endl;
+ cout.width(50);
+ cout << "# of close parentheses" << pdtinfo.NumCloseParens() << endl;
+ cout.width(50);
+ cout << "# of unique open parentheses"
+ << pdtinfo.NumUniqueOpenParens() << endl;
+ cout.width(50);
+ cout << "# of unique close parentheses"
+ << pdtinfo.NumUniqueCloseParens() << endl;
+ cout.width(50);
+ cout << "# of open parenthesis dest. states"
+ << pdtinfo.NumOpenParenStates() << endl;
+ cout.width(50);
+ cout << "# of close parenthesis source states"
+ << pdtinfo.NumCloseParenStates() << endl;
+ cout.setf(old);
+}
+
+} // namespace fst
+
+#endif // FST_EXTENSIONS_PDT_INFO_H__
diff --git a/src/include/fst/extensions/pdt/paren.h b/src/include/fst/extensions/pdt/paren.h
new file mode 100644
index 0000000..7b9887f
--- /dev/null
+++ b/src/include/fst/extensions/pdt/paren.h
@@ -0,0 +1,496 @@
+// paren.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// Common classes for PDT parentheses
+
+// \file
+
+#ifndef FST_EXTENSIONS_PDT_PAREN_H_
+#define FST_EXTENSIONS_PDT_PAREN_H_
+
+#include <algorithm>
+#include <unordered_map>
+using std::tr1::unordered_map;
+using std::tr1::unordered_multimap;
+#include <tr1/unordered_set>
+using std::tr1::unordered_set;
+using std::tr1::unordered_multiset;
+#include <set>
+
+#include <fst/extensions/pdt/pdt.h>
+#include <fst/extensions/pdt/collection.h>
+#include <fst/fst.h>
+#include <fst/dfs-visit.h>
+
+
+namespace fst {
+
+//
+// ParenState: Pair of an open (close) parenthesis and
+// its destination (source) state.
+//
+
+template <class A>
+class ParenState {
+ public:
+ typedef typename A::Label Label;
+ typedef typename A::StateId StateId;
+
+ struct Hash {
+ size_t operator()(const ParenState<A> &p) const {
+ return p.paren_id + p.state_id * kPrime;
+ }
+ };
+
+ Label paren_id; // ID of open (close) paren
+ StateId state_id; // destination (source) state of open (close) paren
+
+ ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {}
+
+ ParenState(Label p, StateId s) : paren_id(p), state_id(s) {}
+
+ bool operator==(const ParenState<A> &p) const {
+ if (&p == this)
+ return true;
+ return p.paren_id == this->paren_id && p.state_id == this->state_id;
+ }
+
+ bool operator!=(const ParenState<A> &p) const { return !(p == *this); }
+
+ bool operator<(const ParenState<A> &p) const {
+ return paren_id < this->paren.id ||
+ (p.paren_id == this->paren.id && p.state_id < this->state_id);
+ }
+
+ private:
+ static const size_t kPrime;
+};
+
+template <class A>
+const size_t ParenState<A>::kPrime = 7853;
+
+
+// Creates an FST-style iterator from STL map and iterator.
+template <class M>
+class MapIterator {
+ public:
+ typedef typename M::const_iterator StlIterator;
+ typedef typename M::value_type PairType;
+ typedef typename PairType::second_type ValueType;
+
+ MapIterator(const M &m, StlIterator iter)
+ : map_(m), begin_(iter), iter_(iter) {}
+
+ bool Done() const {
+ return iter_ == map_.end() || iter_->first != begin_->first;
+ }
+
+ ValueType Value() const { return iter_->second; }
+ void Next() { ++iter_; }
+ void Reset() { iter_ = begin_; }
+
+ private:
+ const M &map_;
+ StlIterator begin_;
+ StlIterator iter_;
+};
+
+//
+// PdtParenReachable: Provides various parenthesis reachability information
+// on a PDT.
+//
+
+template <class A>
+class PdtParenReachable {
+ public:
+ typedef typename A::StateId StateId;
+ typedef typename A::Label Label;
+ public:
+ // Maps from state ID to reachable paren IDs from (to) that state.
+ typedef unordered_multimap<StateId, Label> ParenMultiMap;
+
+ // Maps from paren ID and state ID to reachable state set ID
+ typedef unordered_map<ParenState<A>, ssize_t,
+ typename ParenState<A>::Hash> StateSetMap;
+
+ // Maps from paren ID and state ID to arcs exiting that state with that
+ // Label.
+ typedef unordered_multimap<ParenState<A>, A,
+ typename ParenState<A>::Hash> ParenArcMultiMap;
+
+ typedef MapIterator<ParenMultiMap> ParenIterator;
+
+ typedef MapIterator<ParenArcMultiMap> ParenArcIterator;
+
+ typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
+
+ // Computes close (open) parenthesis reachabilty information for
+ // a PDT with bounded stack.
+ PdtParenReachable(const Fst<A> &fst,
+ const vector<pair<Label, Label> > &parens, bool close)
+ : fst_(fst),
+ parens_(parens),
+ close_(close) {
+ for (Label i = 0; i < parens.size(); ++i) {
+ const pair<Label, Label> &p = parens[i];
+ paren_id_map_[p.first] = i;
+ paren_id_map_[p.second] = i;
+ }
+
+ if (close_) {
+ StateId start = fst.Start();
+ if (start == kNoStateId)
+ return;
+ DFSearch(start, start);
+ } else {
+ FSTERROR() << "PdtParenReachable: open paren info not implemented";
+ }
+ }
+
+ // Given a state ID, returns an iterator over paren IDs
+ // for close (open) parens reachable from that state along balanced
+ // paths.
+ ParenIterator FindParens(StateId s) const {
+ return ParenIterator(paren_multimap_, paren_multimap_.find(s));
+ }
+
+ // Given a paren ID and a state ID s, returns an iterator over
+ // states that can be reached along balanced paths from (to) s that
+ // have have close (open) parentheses matching the paren ID exiting
+ // (entering) those states.
+ SetIterator FindStates(Label paren_id, StateId s) const {
+ ParenState<A> paren_state(paren_id, s);
+ typename StateSetMap::const_iterator id_it = set_map_.find(paren_state);
+ if (id_it == set_map_.end()) {
+ return state_sets_.FindSet(-1);
+ } else {
+ return state_sets_.FindSet(id_it->second);
+ }
+ }
+
+ // Given a paren Id and a state ID s, return an iterator over
+ // arcs that exit (enter) s and are labeled with a close (open)
+ // parenthesis matching the paren ID.
+ ParenArcIterator FindParenArcs(Label paren_id, StateId s) const {
+ ParenState<A> paren_state(paren_id, s);
+ return ParenArcIterator(paren_arc_multimap_,
+ paren_arc_multimap_.find(paren_state));
+ }
+
+ private:
+ // DFS that gathers paren and state set information.
+ // Bool returns false when cycle detected.
+ bool DFSearch(StateId s, StateId start);
+
+ // Unions state sets together gathered by the DFS.
+ void ComputeStateSet(StateId s);
+
+ // Gather state set(s) from state 'nexts'.
+ void UpdateStateSet(StateId nexts, set<Label> *paren_set,
+ vector< set<StateId> > *state_sets) const;
+
+ const Fst<A> &fst_;
+ const vector<pair<Label, Label> > &parens_; // Paren ID -> Labels
+ bool close_; // Close/open paren info?
+ unordered_map<Label, Label> paren_id_map_; // Paren labels -> ID
+ ParenMultiMap paren_multimap_; // Paren reachability
+ ParenArcMultiMap paren_arc_multimap_; // Paren Arcs
+ vector<char> state_color_; // DFS state
+ mutable Collection<ssize_t, StateId> state_sets_; // Reachable states -> ID
+ StateSetMap set_map_; // ID -> Reachable states
+ DISALLOW_COPY_AND_ASSIGN(PdtParenReachable);
+};
+
+// DFS that gathers paren and state set information.
+template <class A>
+bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) {
+ if (s >= state_color_.size())
+ state_color_.resize(s + 1, kDfsWhite);
+
+ if (state_color_[s] == kDfsBlack)
+ return true;
+
+ if (state_color_[s] == kDfsGrey)
+ return false;
+
+ state_color_[s] = kDfsGrey;
+
+ for (ArcIterator<Fst<A> > aiter(fst_, s);
+ !aiter.Done();
+ aiter.Next()) {
+ const A &arc = aiter.Value();
+
+ typename unordered_map<Label, Label>::const_iterator pit
+ = paren_id_map_.find(arc.ilabel);
+ if (pit != paren_id_map_.end()) { // paren?
+ Label paren_id = pit->second;
+ if (arc.ilabel == parens_[paren_id].first) { // open paren
+ DFSearch(arc.nextstate, arc.nextstate);
+ for (SetIterator set_iter = FindStates(paren_id, arc.nextstate);
+ !set_iter.Done(); set_iter.Next()) {
+ for (ParenArcIterator paren_arc_iter =
+ FindParenArcs(paren_id, set_iter.Element());
+ !paren_arc_iter.Done();
+ paren_arc_iter.Next()) {
+ const A &cparc = paren_arc_iter.Value();
+ DFSearch(cparc.nextstate, start);
+ }
+ }
+ }
+ } else { // non-paren
+ if(!DFSearch(arc.nextstate, start)) {
+ FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
+ return true;
+ }
+ }
+ }
+ ComputeStateSet(s);
+ state_color_[s] = kDfsBlack;
+ return true;
+}
+
+// Unions state sets together gathered by the DFS.
+template <class A>
+void PdtParenReachable<A>::ComputeStateSet(StateId s) {
+ set<Label> paren_set;
+ vector< set<StateId> > state_sets(parens_.size());
+ for (ArcIterator< Fst<A> > aiter(fst_, s);
+ !aiter.Done();
+ aiter.Next()) {
+ const A &arc = aiter.Value();
+
+ typename unordered_map<Label, Label>::const_iterator pit
+ = paren_id_map_.find(arc.ilabel);
+ if (pit != paren_id_map_.end()) { // paren?
+ Label paren_id = pit->second;
+ if (arc.ilabel == parens_[paren_id].first) { // open paren
+ for (SetIterator set_iter =
+ FindStates(paren_id, arc.nextstate);
+ !set_iter.Done(); set_iter.Next()) {
+ for (ParenArcIterator paren_arc_iter =
+ FindParenArcs(paren_id, set_iter.Element());
+ !paren_arc_iter.Done();
+ paren_arc_iter.Next()) {
+ const A &cparc = paren_arc_iter.Value();
+ UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
+ }
+ }
+ } else { // close paren
+ paren_set.insert(paren_id);
+ state_sets[paren_id].insert(s);
+ ParenState<A> paren_state(paren_id, s);
+ paren_arc_multimap_.insert(make_pair(paren_state, arc));
+ }
+ } else { // non-paren
+ UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
+ }
+ }
+
+ vector<StateId> state_set;
+ for (typename set<Label>::iterator paren_iter = paren_set.begin();
+ paren_iter != paren_set.end(); ++paren_iter) {
+ state_set.clear();
+ Label paren_id = *paren_iter;
+ paren_multimap_.insert(make_pair(s, paren_id));
+ for (typename set<StateId>::iterator state_iter
+ = state_sets[paren_id].begin();
+ state_iter != state_sets[paren_id].end();
+ ++state_iter) {
+ state_set.push_back(*state_iter);
+ }
+ ParenState<A> paren_state(paren_id, s);
+ set_map_[paren_state] = state_sets_.FindId(state_set);
+ }
+}
+
+// Gather state set(s) from state 'nexts'.
+template <class A>
+void PdtParenReachable<A>::UpdateStateSet(
+ StateId nexts, set<Label> *paren_set,
+ vector< set<StateId> > *state_sets) const {
+ for(ParenIterator paren_iter = FindParens(nexts);
+ !paren_iter.Done(); paren_iter.Next()) {
+ Label paren_id = paren_iter.Value();
+ paren_set->insert(paren_id);
+ for (SetIterator set_iter = FindStates(paren_id, nexts);
+ !set_iter.Done(); set_iter.Next()) {
+ (*state_sets)[paren_id].insert(set_iter.Element());
+ }
+ }
+}
+
+
+// Store balancing parenthesis data for a PDT. Allows on-the-fly
+// construction (e.g. in PdtShortestPath) unlike PdtParenReachable above.
+template <class A>
+class PdtBalanceData {
+ public:
+ typedef typename A::StateId StateId;
+ typedef typename A::Label Label;
+
+ // Hash set for open parens
+ typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet;
+
+ // Maps from open paren destination state to parenthesis ID.
+ typedef unordered_multimap<StateId, Label> OpenParenMap;
+
+ // Maps from open paren state to source states of matching close parens
+ typedef unordered_multimap<ParenState<A>, StateId,
+ typename ParenState<A>::Hash> CloseParenMap;
+
+ // Maps from open paren state to close source set ID
+ typedef unordered_map<ParenState<A>, ssize_t,
+ typename ParenState<A>::Hash> CloseSourceMap;
+
+ typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
+
+ PdtBalanceData() {}
+
+ void Clear() {
+ open_paren_map_.clear();
+ close_paren_map_.clear();
+ }
+
+ // Adds an open parenthesis with destination state 'open_dest'.
+ void OpenInsert(Label paren_id, StateId open_dest) {
+ ParenState<A> key(paren_id, open_dest);
+ if (!open_paren_set_.count(key)) {
+ open_paren_set_.insert(key);
+ open_paren_map_.insert(make_pair(open_dest, paren_id));
+ }
+ }
+
+ // Adds a matching closing parenthesis with source state
+ // 'close_source' that balances an open_parenthesis with destination
+ // state 'open_dest' if OpenInsert() previously called
+ // (o.w. CloseInsert() does nothing).
+ void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) {
+ ParenState<A> key(paren_id, open_dest);
+ if (open_paren_set_.count(key))
+ close_paren_map_.insert(make_pair(key, close_source));
+ }
+
+ // Find close paren source states matching an open parenthesis.
+ // Methods that follow, iterate through those matching states.
+ // Should be called only after FinishInsert(open_dest).
+ SetIterator Find(Label paren_id, StateId open_dest) {
+ ParenState<A> close_key(paren_id, open_dest);
+ typename CloseSourceMap::const_iterator id_it =
+ close_source_map_.find(close_key);
+ if (id_it == close_source_map_.end()) {
+ return close_source_sets_.FindSet(-1);
+ } else {
+ return close_source_sets_.FindSet(id_it->second);
+ }
+ }
+
+ // Call when all open and close parenthesis insertions wrt open
+ // parentheses entering 'open_dest' are finished. Must be called
+ // before Find(open_dest). Stores close paren source state sets
+ // efficiently.
+ void FinishInsert(StateId open_dest) {
+ vector<StateId> close_sources;
+ for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest);
+ oit != open_paren_map_.end() && oit->first == open_dest;) {
+ Label paren_id = oit->second;
+ close_sources.clear();
+ ParenState<A> okey(paren_id, open_dest);
+ open_paren_set_.erase(open_paren_set_.find(okey));
+ for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey);
+ cit != close_paren_map_.end() && cit->first == okey;) {
+ close_sources.push_back(cit->second);
+ close_paren_map_.erase(cit++);
+ }
+ sort(close_sources.begin(), close_sources.end());
+ typename vector<StateId>::iterator unique_end =
+ unique(close_sources.begin(), close_sources.end());
+ close_sources.resize(unique_end - close_sources.begin());
+
+ if (!close_sources.empty())
+ close_source_map_[okey] = close_source_sets_.FindId(close_sources);
+ open_paren_map_.erase(oit++);
+ }
+ }
+
+ // Return a new balance data object representing the reversed balance
+ // information.
+ PdtBalanceData<A> *Reverse(StateId num_states,
+ StateId num_split,
+ StateId state_id_shift) const;
+
+ private:
+ OpenParenSet open_paren_set_; // open par. at dest?
+
+ OpenParenMap open_paren_map_; // open parens per state
+ ParenState<A> open_dest_; // cur open dest. state
+ typename OpenParenMap::const_iterator open_iter_; // cur open parens/state
+
+ CloseParenMap close_paren_map_; // close states/open
+ // paren and state
+
+ CloseSourceMap close_source_map_; // paren, state to set ID
+ mutable Collection<ssize_t, StateId> close_source_sets_;
+};
+
+// Return a new balance data object representing the reversed balance
+// information.
+template <class A>
+PdtBalanceData<A> *PdtBalanceData<A>::Reverse(
+ StateId num_states,
+ StateId num_split,
+ StateId state_id_shift) const {
+ PdtBalanceData<A> *bd = new PdtBalanceData<A>;
+ unordered_set<StateId> close_sources;
+ StateId split_size = num_states / num_split;
+
+ for (StateId i = 0; i < num_states; i+= split_size) {
+ close_sources.clear();
+
+ for (typename CloseSourceMap::const_iterator
+ sit = close_source_map_.begin();
+ sit != close_source_map_.end();
+ ++sit) {
+ ParenState<A> okey = sit->first;
+ StateId open_dest = okey.state_id;
+ Label paren_id = okey.paren_id;
+ for (SetIterator set_iter = close_source_sets_.FindSet(sit->second);
+ !set_iter.Done(); set_iter.Next()) {
+ StateId close_source = set_iter.Element();
+ if ((close_source < i) || (close_source >= i + split_size))
+ continue;
+ close_sources.insert(close_source + state_id_shift);
+ bd->OpenInsert(paren_id, close_source + state_id_shift);
+ bd->CloseInsert(paren_id, close_source + state_id_shift,
+ open_dest + state_id_shift);
+ }
+ }
+
+ for (typename unordered_set<StateId>::const_iterator it
+ = close_sources.begin();
+ it != close_sources.end();
+ ++it) {
+ bd->FinishInsert(*it);
+ }
+
+ }
+ return bd;
+}
+
+
+} // namespace fst
+
+#endif // FST_EXTENSIONS_PDT_PAREN_H_
diff --git a/src/include/fst/extensions/pdt/pdt.h b/src/include/fst/extensions/pdt/pdt.h
new file mode 100644
index 0000000..171541f
--- /dev/null
+++ b/src/include/fst/extensions/pdt/pdt.h
@@ -0,0 +1,212 @@
+// pdt.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Common classes for PDT expansion/traversal.
+
+#ifndef FST_EXTENSIONS_PDT_PDT_H__
+#define FST_EXTENSIONS_PDT_PDT_H__
+
+#include <unordered_map>
+using std::tr1::unordered_map;
+using std::tr1::unordered_multimap;
+#include <map>
+#include <set>
+
+#include <fst/state-table.h>
+#include <fst/fst.h>
+
+namespace fst {
+
+// Provides bijection between parenthesis stacks and signed integral
+// stack IDs. Each stack ID is unique to each distinct stack. The
+// open-close parenthesis label pairs are passed in 'parens'.
+template <typename K, typename L>
+class PdtStack {
+ public:
+ typedef K StackId;
+ typedef L Label;
+
+ // The stacks are stored in a tree. The nodes are stored in vector
+ // 'nodes_'. Each node represents the top of some stack and is
+ // ID'ed by its position in the vector. Its parent node represents
+ // the stack with the top 'popped' and its children are stored in
+ // 'child_map_' accessed by stack_id and label. The paren_id is
+ // the position in 'parens' of the parenthesis for that node.
+ struct StackNode {
+ StackId parent_id;
+ size_t paren_id;
+
+ StackNode(StackId p, size_t i) : parent_id(p), paren_id(i) {}
+ };
+
+ PdtStack(const vector<pair<Label, Label> > &parens)
+ : parens_(parens), min_paren_(kNoLabel), max_paren_(kNoLabel) {
+ for (size_t i = 0; i < parens.size(); ++i) {
+ const pair<Label, Label> &p = parens[i];
+ paren_map_[p.first] = i;
+ paren_map_[p.second] = i;
+
+ if (min_paren_ == kNoLabel || p.first < min_paren_)
+ min_paren_ = p.first;
+ if (p.second < min_paren_)
+ min_paren_ = p.second;
+
+ if (max_paren_ == kNoLabel || p.first > max_paren_)
+ max_paren_ = p.first;
+ if (p.second > max_paren_)
+ max_paren_ = p.second;
+ }
+ nodes_.push_back(StackNode(-1, -1)); // Tree root.
+ }
+
+ // Returns stack ID given the current stack ID (0 if empty) and
+ // label read. 'Pushes' onto a stack if the label is an open
+ // parenthesis, returning the new stack ID. 'Pops' the stack if the
+ // label is a close parenthesis that matches the top of the stack,
+ // returning the parent stack ID. Returns -1 if label is an
+ // unmatched close parenthesis. Otherwise, returns the current stack
+ // ID.
+ StackId Find(StackId stack_id, Label label) {
+ if (min_paren_ == kNoLabel || label < min_paren_ || label > max_paren_)
+ return stack_id; // Non-paren.
+
+ typename unordered_map<Label, size_t>::const_iterator pit
+ = paren_map_.find(label);
+ if (pit == paren_map_.end()) // Non-paren.
+ return stack_id;
+ ssize_t paren_id = pit->second;
+
+ if (label == parens_[paren_id].first) { // Open paren.
+ StackId &child_id = child_map_[make_pair(stack_id, label)];
+ if (child_id == 0) { // Child not found, push label.
+ child_id = nodes_.size();
+ nodes_.push_back(StackNode(stack_id, paren_id));
+ }
+ return child_id;
+ }
+
+ const StackNode &node = nodes_[stack_id];
+ if (paren_id == node.paren_id) // Matching close paren.
+ return node.parent_id;
+
+ return -1; // Non-matching close paren.
+ }
+
+ // Returns the stack ID obtained by "popping" the label at the top
+ // of the current stack ID.
+ StackId Pop(StackId stack_id) const {
+ return nodes_[stack_id].parent_id;
+ }
+
+ // Returns the paren ID at the top of the stack for 'stack_id'
+ ssize_t Top(StackId stack_id) const {
+ return nodes_[stack_id].paren_id;
+ }
+
+ ssize_t ParenId(Label label) const {
+ typename unordered_map<Label, size_t>::const_iterator pit
+ = paren_map_.find(label);
+ if (pit == paren_map_.end()) // Non-paren.
+ return -1;
+ return pit->second;
+ }
+
+ private:
+ struct ChildHash {
+ size_t operator()(const pair<StackId, Label> &p) const {
+ return p.first + p.second * kPrime;
+ }
+ };
+
+ static const size_t kPrime;
+
+ vector<pair<Label, Label> > parens_;
+ vector<StackNode> nodes_;
+ unordered_map<Label, size_t> paren_map_;
+ unordered_map<pair<StackId, Label>,
+ StackId, ChildHash> child_map_; // Child of stack node wrt label
+ Label min_paren_; // For faster paren. check
+ Label max_paren_; // For faster paren. check
+};
+
+template <typename T, typename L>
+const size_t PdtStack<T, L>::kPrime = 7853;
+
+
+// State tuple for PDT expansion
+template <typename S, typename K>
+struct PdtStateTuple {
+ typedef S StateId;
+ typedef K StackId;
+
+ StateId state_id;
+ StackId stack_id;
+
+ PdtStateTuple()
+ : state_id(kNoStateId), stack_id(-1) {}
+
+ PdtStateTuple(StateId fs, StackId ss)
+ : state_id(fs), stack_id(ss) {}
+};
+
+// Equality of PDT state tuples.
+template <typename S, typename K>
+inline bool operator==(const PdtStateTuple<S, K>& x,
+ const PdtStateTuple<S, K>& y) {
+ if (&x == &y)
+ return true;
+ return x.state_id == y.state_id && x.stack_id == y.stack_id;
+}
+
+
+// Hash function object for PDT state tuples
+template <class T>
+class PdtStateHash {
+ public:
+ size_t operator()(const T &tuple) const {
+ return tuple.state_id + tuple.stack_id * kPrime;
+ }
+
+ private:
+ static const size_t kPrime;
+};
+
+template <typename T>
+const size_t PdtStateHash<T>::kPrime = 7853;
+
+
+// Tuple to PDT state bijection.
+template <class S, class K>
+class PdtStateTable
+ : public CompactHashStateTable<PdtStateTuple<S, K>,
+ PdtStateHash<PdtStateTuple<S, K> > > {
+ public:
+ typedef S StateId;
+ typedef K StackId;
+
+ PdtStateTable() {}
+
+ PdtStateTable(const PdtStateTable<S, K> &table) {}
+
+ private:
+ void operator=(const PdtStateTable<S, K> &table); // disallow
+};
+
+} // namespace fst
+
+#endif // FST_EXTENSIONS_PDT_PDT_H__
diff --git a/src/include/fst/extensions/pdt/pdtlib.h b/src/include/fst/extensions/pdt/pdtlib.h
new file mode 100644
index 0000000..71c8123
--- /dev/null
+++ b/src/include/fst/extensions/pdt/pdtlib.h
@@ -0,0 +1,30 @@
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: jpr@google.com (Jake Ratkiewicz)
+
+// This is an experimental push-down transducer(PDT) library. A PDT is
+// encoded as an FST, where some transitions are labeled with open or close
+// parentheses. To be interpreted as a PDT, the parentheses must balance on a
+// path.
+
+#ifndef FST_EXTENSIONS_PDT_PDTLIB_H_
+#define FST_EXTENSIONS_PDT_PDTLIB_H_
+
+#include <fst/extensions/pdt/pdt.h>
+#include <fst/extensions/pdt/compose.h>
+#include <fst/extensions/pdt/expand.h>
+#include <fst/extensions/pdt/replace.h>
+
+#endif // FST_EXTENSIONS_PDT_PDTLIB_H_
diff --git a/src/include/fst/extensions/pdt/pdtscript.h b/src/include/fst/extensions/pdt/pdtscript.h
new file mode 100644
index 0000000..c2a1cf4
--- /dev/null
+++ b/src/include/fst/extensions/pdt/pdtscript.h
@@ -0,0 +1,284 @@
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: jpr@google.com (Jake Ratkiewicz)
+// Convenience file for including all PDT operations at once, and/or
+// registering them for new arc types.
+
+#ifndef FST_EXTENSIONS_PDT_PDTSCRIPT_H_
+#define FST_EXTENSIONS_PDT_PDTSCRIPT_H_
+
+#include <utility>
+using std::pair; using std::make_pair;
+#include <vector>
+using std::vector;
+
+#include <fst/compose.h> // for ComposeOptions
+#include <fst/util.h>
+
+#include <fst/script/fst-class.h>
+#include <fst/script/arg-packs.h>
+#include <fst/script/shortest-path.h>
+
+#include <fst/extensions/pdt/compose.h>
+#include <fst/extensions/pdt/expand.h>
+#include <fst/extensions/pdt/info.h>
+#include <fst/extensions/pdt/replace.h>
+#include <fst/extensions/pdt/reverse.h>
+#include <fst/extensions/pdt/shortest-path.h>
+
+
+namespace fst {
+namespace script {
+
+// PDT COMPOSE
+
+typedef args::Package<const FstClass &,
+ const FstClass &,
+ const vector<pair<int64, int64> >&,
+ MutableFstClass *,
+ const ComposeOptions &,
+ bool> PdtComposeArgs;
+
+template<class Arc>
+void PdtCompose(PdtComposeArgs *args) {
+ const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>());
+ const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>());
+ MutableFst<Arc> *ofst = args->arg4->GetMutableFst<Arc>();
+
+ vector<pair<typename Arc::Label, typename Arc::Label> > parens(
+ args->arg3.size());
+
+ for (size_t i = 0; i < parens.size(); ++i) {
+ parens[i].first = args->arg3[i].first;
+ parens[i].second = args->arg3[i].second;
+ }
+
+ if (args->arg6) {
+ Compose(ifst1, parens, ifst2, ofst, args->arg5);
+ } else {
+ Compose(ifst1, ifst2, parens, ofst, args->arg5);
+ }
+}
+
+void PdtCompose(const FstClass & ifst1,
+ const FstClass & ifst2,
+ const vector<pair<int64, int64> > &parens,
+ MutableFstClass *ofst,
+ const ComposeOptions &copts,
+ bool left_pdt);
+
+// PDT EXPAND
+
+struct PdtExpandOptions {
+ bool connect;
+ bool keep_parentheses;
+ WeightClass weight_threshold;
+
+ PdtExpandOptions(bool c = true, bool k = false,
+ WeightClass w = WeightClass::Zero())
+ : connect(c), keep_parentheses(k), weight_threshold(w) {}
+};
+
+typedef args::Package<const FstClass &,
+ const vector<pair<int64, int64> >&,
+ MutableFstClass *, PdtExpandOptions> PdtExpandArgs;
+
+template<class Arc>
+void PdtExpand(PdtExpandArgs *args) {
+ const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
+ MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
+
+ vector<pair<typename Arc::Label, typename Arc::Label> > parens(
+ args->arg2.size());
+ for (size_t i = 0; i < parens.size(); ++i) {
+ parens[i].first = args->arg2[i].first;
+ parens[i].second = args->arg2[i].second;
+ }
+ Expand(fst, parens, ofst,
+ ExpandOptions<Arc>(
+ args->arg4.connect, args->arg4.keep_parentheses,
+ *(args->arg4.weight_threshold.GetWeight<typename Arc::Weight>())));
+}
+
+void PdtExpand(const FstClass &ifst,
+ const vector<pair<int64, int64> > &parens,
+ MutableFstClass *ofst, const PdtExpandOptions &opts);
+
+void PdtExpand(const FstClass &ifst,
+ const vector<pair<int64, int64> > &parens,
+ MutableFstClass *ofst, bool connect);
+
+// PDT REPLACE
+
+typedef args::Package<const vector<pair<int64, const FstClass*> > &,
+ MutableFstClass *,
+ vector<pair<int64, int64> > *,
+ const int64 &> PdtReplaceArgs;
+template<class Arc>
+void PdtReplace(PdtReplaceArgs *args) {
+ vector<pair<typename Arc::Label, const Fst<Arc> *> > tuples(
+ args->arg1.size());
+ for (size_t i = 0; i < tuples.size(); ++i) {
+ tuples[i].first = args->arg1[i].first;
+ tuples[i].second = (args->arg1[i].second)->GetFst<Arc>();
+ }
+ MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>();
+ vector<pair<typename Arc::Label, typename Arc::Label> > parens(
+ args->arg3->size());
+
+ for (size_t i = 0; i < parens.size(); ++i) {
+ parens[i].first = args->arg3->at(i).first;
+ parens[i].second = args->arg3->at(i).second;
+ }
+ Replace(tuples, ofst, &parens, args->arg4);
+
+ // now copy parens back
+ args->arg3->resize(parens.size());
+ for (size_t i = 0; i < parens.size(); ++i) {
+ (*args->arg3)[i].first = parens[i].first;
+ (*args->arg3)[i].second = parens[i].second;
+ }
+}
+
+void PdtReplace(const vector<pair<int64, const FstClass*> > &fst_tuples,
+ MutableFstClass *ofst,
+ vector<pair<int64, int64> > *parens,
+ const int64 &root);
+
+// PDT REVERSE
+
+typedef args::Package<const FstClass &,
+ const vector<pair<int64, int64> >&,
+ MutableFstClass *> PdtReverseArgs;
+
+template<class Arc>
+void PdtReverse(PdtReverseArgs *args) {
+ const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
+ MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
+
+ vector<pair<typename Arc::Label, typename Arc::Label> > parens(
+ args->arg2.size());
+ for (size_t i = 0; i < parens.size(); ++i) {
+ parens[i].first = args->arg2[i].first;
+ parens[i].second = args->arg2[i].second;
+ }
+ Reverse(fst, parens, ofst);
+}
+
+void PdtReverse(const FstClass &ifst,
+ const vector<pair<int64, int64> > &parens,
+ MutableFstClass *ofst);
+
+
+// PDT SHORTESTPATH
+
+struct PdtShortestPathOptions {
+ QueueType queue_type;
+ bool keep_parentheses;
+ bool path_gc;
+
+ PdtShortestPathOptions(QueueType qt = FIFO_QUEUE,
+ bool kp = false, bool gc = true)
+ : queue_type(qt), keep_parentheses(kp), path_gc(gc) {}
+};
+
+typedef args::Package<const FstClass &,
+ const vector<pair<int64, int64> >&,
+ MutableFstClass *,
+ const PdtShortestPathOptions &> PdtShortestPathArgs;
+
+template<class Arc>
+void PdtShortestPath(PdtShortestPathArgs *args) {
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Label Label;
+ typedef typename Arc::Weight Weight;
+
+ const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
+ MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
+ const PdtShortestPathOptions &opts = args->arg4;
+
+
+ vector<pair<Label, Label> > parens(args->arg2.size());
+ for (size_t i = 0; i < parens.size(); ++i) {
+ parens[i].first = args->arg2[i].first;
+ parens[i].second = args->arg2[i].second;
+ }
+
+ switch (opts.queue_type) {
+ default:
+ FSTERROR() << "Unknown queue type: " << opts.queue_type;
+ case FIFO_QUEUE: {
+ typedef FifoQueue<StateId> Queue;
+ fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
+ opts.path_gc);
+ ShortestPath(fst, parens, ofst, spopts);
+ return;
+ }
+ case LIFO_QUEUE: {
+ typedef LifoQueue<StateId> Queue;
+ fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
+ opts.path_gc);
+ ShortestPath(fst, parens, ofst, spopts);
+ return;
+ }
+ case STATE_ORDER_QUEUE: {
+ typedef StateOrderQueue<StateId> Queue;
+ fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
+ opts.path_gc);
+ ShortestPath(fst, parens, ofst, spopts);
+ return;
+ }
+ }
+}
+
+void PdtShortestPath(const FstClass &ifst,
+ const vector<pair<int64, int64> > &parens,
+ MutableFstClass *ofst,
+ const PdtShortestPathOptions &opts =
+ PdtShortestPathOptions());
+
+// PRINT INFO
+
+typedef args::Package<const FstClass &,
+ const vector<pair<int64, int64> > &> PrintPdtInfoArgs;
+
+template<class Arc>
+void PrintPdtInfo(PrintPdtInfoArgs *args) {
+ const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
+ vector<pair<typename Arc::Label, typename Arc::Label> > parens(
+ args->arg2.size());
+ for (size_t i = 0; i < parens.size(); ++i) {
+ parens[i].first = args->arg2[i].first;
+ parens[i].second = args->arg2[i].second;
+ }
+ PdtInfo<Arc> pdtinfo(fst, parens);
+ PrintPdtInfo(pdtinfo);
+}
+
+void PrintPdtInfo(const FstClass &ifst,
+ const vector<pair<int64, int64> > &parens);
+
+} // namespace script
+} // namespace fst
+
+
+#define REGISTER_FST_PDT_OPERATIONS(ArcType) \
+ REGISTER_FST_OPERATION(PdtCompose, ArcType, PdtComposeArgs); \
+ REGISTER_FST_OPERATION(PdtExpand, ArcType, PdtExpandArgs); \
+ REGISTER_FST_OPERATION(PdtReplace, ArcType, PdtReplaceArgs); \
+ REGISTER_FST_OPERATION(PdtReverse, ArcType, PdtReverseArgs); \
+ REGISTER_FST_OPERATION(PdtShortestPath, ArcType, PdtShortestPathArgs); \
+ REGISTER_FST_OPERATION(PrintPdtInfo, ArcType, PrintPdtInfoArgs)
+#endif // FST_EXTENSIONS_PDT_PDTSCRIPT_H_
diff --git a/src/include/fst/extensions/pdt/replace.h b/src/include/fst/extensions/pdt/replace.h
new file mode 100644
index 0000000..a85d0fe
--- /dev/null
+++ b/src/include/fst/extensions/pdt/replace.h
@@ -0,0 +1,192 @@
+// replace.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Recursively replace Fst arcs with other Fst(s) returning a PDT.
+
+#ifndef FST_EXTENSIONS_PDT_REPLACE_H__
+#define FST_EXTENSIONS_PDT_REPLACE_H__
+
+#include <fst/replace.h>
+
+namespace fst {
+
+// Hash to paren IDs
+template <typename S>
+struct ReplaceParenHash {
+ size_t operator()(const pair<size_t, S> &p) const {
+ return p.first + p.second * kPrime;
+ }
+ private:
+ static const size_t kPrime = 7853;
+};
+
+template <typename S> const size_t ReplaceParenHash<S>::kPrime;
+
+// Builds a pushdown transducer (PDT) from an RTN specification
+// identical to that in fst/lib/replace.h. The result is a PDT
+// encoded as the FST 'ofst' where some transitions are labeled with
+// open or close parentheses. To be interpreted as a PDT, the parens
+// must balance on a path (see PdtExpand()). The open/close
+// parenthesis label pairs are returned in 'parens'.
+template <class Arc>
+void Replace(const vector<pair<typename Arc::Label,
+ const Fst<Arc>* > >& ifst_array,
+ MutableFst<Arc> *ofst,
+ vector<pair<typename Arc::Label,
+ typename Arc::Label> > *parens,
+ typename Arc::Label root) {
+ typedef typename Arc::Label Label;
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Weight Weight;
+
+ ofst->DeleteStates();
+ parens->clear();
+
+ unordered_map<Label, size_t> label2id;
+ for (size_t i = 0; i < ifst_array.size(); ++i)
+ label2id[ifst_array[i].first] = i;
+
+ Label max_label = kNoLabel;
+
+ deque<size_t> non_term_queue; // Queue of non-terminals to replace
+ unordered_set<Label> non_term_set; // Set of non-terminals to replace
+ non_term_queue.push_back(root);
+ non_term_set.insert(root);
+
+ // PDT state corr. to ith replace FST start state.
+ vector<StateId> fst_start(ifst_array.size(), kNoLabel);
+ // PDT state, weight pairs corr. to ith replace FST final state & weights.
+ vector< vector<pair<StateId, Weight> > > fst_final(ifst_array.size());
+
+ // Builds single Fst combining all referenced input Fsts. Leaves in the
+ // non-termnals for now. Tabulate the PDT states that correspond to
+ // the start and final states of the input Fsts.
+ for (StateId soff = 0; !non_term_queue.empty(); soff = ofst->NumStates()) {
+ Label label = non_term_queue.front();
+ non_term_queue.pop_front();
+ size_t fst_id = label2id[label];
+
+ const Fst<Arc> *ifst = ifst_array[fst_id].second;
+ for (StateIterator< Fst<Arc> > siter(*ifst);
+ !siter.Done(); siter.Next()) {
+ StateId is = siter.Value();
+ StateId os = ofst->AddState();
+ if (is == ifst->Start()) {
+ fst_start[fst_id] = os;
+ if (label == root)
+ ofst->SetStart(os);
+ }
+ if (ifst->Final(is) != Weight::Zero()) {
+ if (label == root)
+ ofst->SetFinal(os, ifst->Final(is));
+ fst_final[fst_id].push_back(make_pair(os, ifst->Final(is)));
+ }
+ for (ArcIterator< Fst<Arc> > aiter(*ifst, is);
+ !aiter.Done(); aiter.Next()) {
+ Arc arc = aiter.Value();
+ if (max_label == kNoLabel || arc.olabel > max_label)
+ max_label = arc.olabel;
+ typename unordered_map<Label, size_t>::const_iterator it =
+ label2id.find(arc.olabel);
+ if (it != label2id.end()) {
+ size_t nfst_id = it->second;
+ if (ifst_array[nfst_id].second->Start() == -1)
+ continue;
+ if (non_term_set.count(arc.olabel) == 0) {
+ non_term_queue.push_back(arc.olabel);
+ non_term_set.insert(arc.olabel);
+ }
+ }
+ arc.nextstate += soff;
+ ofst->AddArc(os, arc);
+ }
+ }
+ }
+
+ // Changes each non-terminal transition to an open parenthesis
+ // transition redirected to the PDT state that corresponds to the
+ // start state of the input FST for the non-terminal. Adds close parenthesis
+ // transitions from the PDT states corr. to the final states of the
+ // input FST for the non-terminal to the former destination state of the
+ // non-terminal transition.
+
+ typedef MutableArcIterator< MutableFst<Arc> > MIter;
+ typedef unordered_map<pair<size_t, StateId >, size_t,
+ ReplaceParenHash<StateId> > ParenMap;
+
+ // Parenthesis pair ID per fst, state pair.
+ ParenMap paren_map;
+ // # of parenthesis pairs per fst.
+ vector<size_t> nparens(ifst_array.size(), 0);
+ // Initial open parenthesis label
+ Label first_paren = max_label + 1;
+
+ for (StateIterator< Fst<Arc> > siter(*ofst);
+ !siter.Done(); siter.Next()) {
+ StateId os = siter.Value();
+ MIter *aiter = new MIter(ofst, os);
+ for (size_t n = 0; !aiter->Done(); aiter->Next(), ++n) {
+ Arc arc = aiter->Value();
+ typename unordered_map<Label, size_t>::const_iterator lit =
+ label2id.find(arc.olabel);
+ if (lit != label2id.end()) {
+ size_t nfst_id = lit->second;
+
+ // Get parentheses. Ensures distinct parenthesis pair per
+ // non-terminal and destination state but otherwise reuses them.
+ Label open_paren = kNoLabel, close_paren = kNoLabel;
+ pair<size_t, StateId> paren_key(nfst_id, arc.nextstate);
+ typename ParenMap::const_iterator pit = paren_map.find(paren_key);
+ if (pit != paren_map.end()) {
+ size_t paren_id = pit->second;
+ open_paren = (*parens)[paren_id].first;
+ close_paren = (*parens)[paren_id].second;
+ } else {
+ size_t paren_id = nparens[nfst_id]++;
+ open_paren = first_paren + 2 * paren_id;
+ close_paren = open_paren + 1;
+ paren_map[paren_key] = paren_id;
+ if (paren_id >= parens->size())
+ parens->push_back(make_pair(open_paren, close_paren));
+ }
+
+ // Sets open parenthesis.
+ Arc sarc(open_paren, open_paren, arc.weight, fst_start[nfst_id]);
+ aiter->SetValue(sarc);
+
+ // Adds close parentheses.
+ for (size_t i = 0; i < fst_final[nfst_id].size(); ++i) {
+ pair<StateId, Weight> &p = fst_final[nfst_id][i];
+ Arc farc(close_paren, close_paren, p.second, arc.nextstate);
+
+ ofst->AddArc(p.first, farc);
+ if (os == p.first) { // Invalidated iterator
+ delete aiter;
+ aiter = new MIter(ofst, os);
+ aiter->Seek(n);
+ }
+ }
+ }
+ }
+ delete aiter;
+ }
+}
+
+} // namespace fst
+
+#endif // FST_EXTENSIONS_PDT_REPLACE_H__
diff --git a/src/include/fst/extensions/pdt/reverse.h b/src/include/fst/extensions/pdt/reverse.h
new file mode 100644
index 0000000..b20e1c5
--- /dev/null
+++ b/src/include/fst/extensions/pdt/reverse.h
@@ -0,0 +1,58 @@
+// reverse.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Expand a PDT to an FST.
+
+#ifndef FST_EXTENSIONS_PDT_REVERSE_H__
+#define FST_EXTENSIONS_PDT_REVERSE_H__
+
+#include <unordered_map>
+using std::tr1::unordered_map;
+using std::tr1::unordered_multimap;
+#include <vector>
+using std::vector;
+
+#include <fst/mutable-fst.h>
+#include <fst/relabel.h>
+#include <fst/reverse.h>
+
+namespace fst {
+
+// Reverses a pushdown transducer (PDT) encoded as an FST.
+template<class Arc, class RevArc>
+void Reverse(const Fst<Arc> &ifst,
+ const vector<pair<typename Arc::Label,
+ typename Arc::Label> > &parens,
+ MutableFst<RevArc> *ofst) {
+ typedef typename Arc::Label Label;
+
+ // Reverses FST
+ Reverse(ifst, ofst);
+
+ // Exchanges open and close parenthesis pairs
+ vector<pair<Label, Label> > relabel_pairs;
+ for (size_t i = 0; i < parens.size(); ++i) {
+ relabel_pairs.push_back(make_pair(parens[i].first, parens[i].second));
+ relabel_pairs.push_back(make_pair(parens[i].second, parens[i].first));
+ }
+ Relabel(ofst, relabel_pairs, relabel_pairs);
+}
+
+} // namespace fst
+
+#endif // FST_EXTENSIONS_PDT_REVERSE_H__
diff --git a/src/include/fst/extensions/pdt/shortest-path.h b/src/include/fst/extensions/pdt/shortest-path.h
new file mode 100644
index 0000000..e90471b
--- /dev/null
+++ b/src/include/fst/extensions/pdt/shortest-path.h
@@ -0,0 +1,790 @@
+// shortest-path.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Functions to find shortest paths in a PDT.
+
+#ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
+#define FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
+
+#include <fst/shortest-path.h>
+#include <fst/extensions/pdt/paren.h>
+#include <fst/extensions/pdt/pdt.h>
+
+#include <unordered_map>
+using std::tr1::unordered_map;
+using std::tr1::unordered_multimap;
+#include <tr1/unordered_set>
+using std::tr1::unordered_set;
+using std::tr1::unordered_multiset;
+#include <stack>
+#include <vector>
+using std::vector;
+
+namespace fst {
+
+template <class Arc, class Queue>
+struct PdtShortestPathOptions {
+ bool keep_parentheses;
+ bool path_gc;
+
+ PdtShortestPathOptions(bool kp = false, bool gc = true)
+ : keep_parentheses(kp), path_gc(gc) {}
+};
+
+
+// Class to store PDT shortest path results. Stores shortest path
+// tree info 'Distance()', Parent(), and ArcParent() information keyed
+// on two types:
+// (1) By SearchState: This is a usual node in a shortest path tree but:
+// (a) is w.r.t a PDT search state - a pair of a PDT state and
+// a 'start' state, which is either the PDT start state or
+// the destination state of an open parenthesis.
+// (b) the Distance() is from this 'start' state to the search state.
+// (c) Parent().state is kNoLabel for the 'start' state.
+//
+// (2) By ParenSpec: This connects shortest path trees depending on the
+// the parenthesis taken. Given the parenthesis spec:
+// (a) the Distance() is from the Parent() 'start' state to the
+// parenthesis destination state.
+// (b) the ArcParent() is the parenthesis arc.
+template <class Arc>
+class PdtShortestPathData {
+ public:
+ static const uint8 kFinal;
+
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Weight Weight;
+ typedef typename Arc::Label Label;
+
+ struct SearchState {
+ SearchState() : state(kNoStateId), start(kNoStateId) {}
+
+ SearchState(StateId s, StateId t) : state(s), start(t) {}
+
+ bool operator==(const SearchState &s) const {
+ if (&s == this)
+ return true;
+ return s.state == this->state && s.start == this->start;
+ }
+
+ StateId state; // PDT state
+ StateId start; // PDT paren 'source' state
+ };
+
+
+ // Specifies paren id, source and dest 'start' states of a paren.
+ // These are the 'start' states of the respective sub-graphs.
+ struct ParenSpec {
+ ParenSpec()
+ : paren_id(kNoLabel), src_start(kNoStateId), dest_start(kNoStateId) {}
+
+ ParenSpec(Label id, StateId s, StateId d)
+ : paren_id(id), src_start(s), dest_start(d) {}
+
+ Label paren_id; // Id of parenthesis
+ StateId src_start; // sub-graph 'start' state for paren source.
+ StateId dest_start; // sub-graph 'start' state for paren dest.
+
+ bool operator==(const ParenSpec &x) const {
+ if (&x == this)
+ return true;
+ return x.paren_id == this->paren_id &&
+ x.src_start == this->src_start &&
+ x.dest_start == this->dest_start;
+ }
+ };
+
+ struct SearchData {
+ SearchData() : distance(Weight::Zero()),
+ parent(kNoStateId, kNoStateId),
+ paren_id(kNoLabel),
+ flags(0) {}
+
+ Weight distance; // Distance to this state from PDT 'start' state
+ SearchState parent; // Parent state in shortest path tree
+ int16 paren_id; // If parent arc has paren, paren ID, o.w. kNoLabel
+ uint8 flags; // First byte reserved for PdtShortestPathData use
+ };
+
+ PdtShortestPathData(bool gc)
+ : state_(kNoStateId, kNoStateId),
+ paren_(kNoLabel, kNoStateId, kNoStateId),
+ gc_(gc),
+ nstates_(0),
+ ngc_(0),
+ finished_(false) {}
+
+ ~PdtShortestPathData() {
+ VLOG(1) << "opm size: " << paren_map_.size();
+ VLOG(1) << "# of search states: " << nstates_;
+ if (gc_)
+ VLOG(1) << "# of GC'd search states: " << ngc_;
+ }
+
+ void Clear() {
+ search_map_.clear();
+ search_multimap_.clear();
+ paren_map_.clear();
+ state_ = SearchState(kNoStateId, kNoStateId);
+ nstates_ = 0;
+ ngc_ = 0;
+ }
+
+ Weight Distance(SearchState s) const {
+ SearchData *data = GetSearchData(s);
+ return data->distance;
+ }
+
+ Weight Distance(const ParenSpec &paren) const {
+ SearchData *data = GetSearchData(paren);
+ return data->distance;
+ }
+
+ SearchState Parent(SearchState s) const {
+ SearchData *data = GetSearchData(s);
+ return data->parent;
+ }
+
+ SearchState Parent(const ParenSpec &paren) const {
+ SearchData *data = GetSearchData(paren);
+ return data->parent;
+ }
+
+ Label ParenId(SearchState s) const {
+ SearchData *data = GetSearchData(s);
+ return data->paren_id;
+ }
+
+ uint8 Flags(SearchState s) const {
+ SearchData *data = GetSearchData(s);
+ return data->flags;
+ }
+
+ void SetDistance(SearchState s, Weight w) {
+ SearchData *data = GetSearchData(s);
+ data->distance = w;
+ }
+
+ void SetDistance(const ParenSpec &paren, Weight w) {
+ SearchData *data = GetSearchData(paren);
+ data->distance = w;
+ }
+
+ void SetParent(SearchState s, SearchState p) {
+ SearchData *data = GetSearchData(s);
+ data->parent = p;
+ }
+
+ void SetParent(const ParenSpec &paren, SearchState p) {
+ SearchData *data = GetSearchData(paren);
+ data->parent = p;
+ }
+
+ void SetParenId(SearchState s, Label p) {
+ if (p >= 32768)
+ FSTERROR() << "PdtShortestPathData: Paren ID does not fits in an int16";
+ SearchData *data = GetSearchData(s);
+ data->paren_id = p;
+ }
+
+ void SetFlags(SearchState s, uint8 f, uint8 mask) {
+ SearchData *data = GetSearchData(s);
+ data->flags &= ~mask;
+ data->flags |= f & mask;
+ }
+
+ void GC(StateId s);
+
+ void Finish() { finished_ = true; }
+
+ private:
+ static const Arc kNoArc;
+ static const size_t kPrime0;
+ static const size_t kPrime1;
+ static const uint8 kInited;
+ static const uint8 kMarked;
+
+ // Hash for search state
+ struct SearchStateHash {
+ size_t operator()(const SearchState &s) const {
+ return s.state + s.start * kPrime0;
+ }
+ };
+
+ // Hash for paren map
+ struct ParenHash {
+ size_t operator()(const ParenSpec &paren) const {
+ return paren.paren_id + paren.src_start * kPrime0 +
+ paren.dest_start * kPrime1;
+ }
+ };
+
+ typedef unordered_map<SearchState, SearchData, SearchStateHash> SearchMap;
+
+ typedef unordered_multimap<StateId, StateId> SearchMultimap;
+
+ // Hash map from paren spec to open paren data
+ typedef unordered_map<ParenSpec, SearchData, ParenHash> ParenMap;
+
+ SearchData *GetSearchData(SearchState s) const {
+ if (s == state_)
+ return state_data_;
+ if (finished_) {
+ typename SearchMap::iterator it = search_map_.find(s);
+ if (it == search_map_.end())
+ return &null_search_data_;
+ state_ = s;
+ return state_data_ = &(it->second);
+ } else {
+ state_ = s;
+ state_data_ = &search_map_[s];
+ if (!(state_data_->flags & kInited)) {
+ ++nstates_;
+ if (gc_)
+ search_multimap_.insert(make_pair(s.start, s.state));
+ state_data_->flags = kInited;
+ }
+ return state_data_;
+ }
+ }
+
+ SearchData *GetSearchData(ParenSpec paren) const {
+ if (paren == paren_)
+ return paren_data_;
+ if (finished_) {
+ typename ParenMap::iterator it = paren_map_.find(paren);
+ if (it == paren_map_.end())
+ return &null_search_data_;
+ paren_ = paren;
+ return state_data_ = &(it->second);
+ } else {
+ paren_ = paren;
+ return paren_data_ = &paren_map_[paren];
+ }
+ }
+
+ mutable SearchMap search_map_; // Maps from search state to data
+ mutable SearchMultimap search_multimap_; // Maps from 'start' to subgraph
+ mutable ParenMap paren_map_; // Maps paren spec to search data
+ mutable SearchState state_; // Last state accessed
+ mutable SearchData *state_data_; // Last state data accessed
+ mutable ParenSpec paren_; // Last paren spec accessed
+ mutable SearchData *paren_data_; // Last paren data accessed
+ bool gc_; // Allow GC?
+ mutable size_t nstates_; // Total number of search states
+ size_t ngc_; // Number of GC'd search states
+ mutable SearchData null_search_data_; // Null search data
+ bool finished_; // Read-only access when true
+
+ DISALLOW_COPY_AND_ASSIGN(PdtShortestPathData);
+};
+
+// Deletes inaccessible search data from a given 'start' (open paren dest)
+// state. Assumes 'final' (close paren source or PDT final) states have
+// been flagged 'kFinal'.
+template<class Arc>
+void PdtShortestPathData<Arc>::GC(StateId start) {
+ if (!gc_)
+ return;
+ vector<StateId> final;
+ for (typename SearchMultimap::iterator mmit = search_multimap_.find(start);
+ mmit != search_multimap_.end() && mmit->first == start;
+ ++mmit) {
+ SearchState s(mmit->second, start);
+ const SearchData &data = search_map_[s];
+ if (data.flags & kFinal)
+ final.push_back(s.state);
+ }
+
+ // Mark phase
+ for (size_t i = 0; i < final.size(); ++i) {
+ SearchState s(final[i], start);
+ while (s.state != kNoLabel) {
+ SearchData *sdata = &search_map_[s];
+ if (sdata->flags & kMarked)
+ break;
+ sdata->flags |= kMarked;
+ SearchState p = sdata->parent;
+ if (p.start != start && p.start != kNoLabel) { // entering sub-subgraph
+ ParenSpec paren(sdata->paren_id, s.start, p.start);
+ SearchData *pdata = &paren_map_[paren];
+ s = pdata->parent;
+ } else {
+ s = p;
+ }
+ }
+ }
+
+ // Sweep phase
+ typename SearchMultimap::iterator mmit = search_multimap_.find(start);
+ while (mmit != search_multimap_.end() && mmit->first == start) {
+ SearchState s(mmit->second, start);
+ typename SearchMap::iterator mit = search_map_.find(s);
+ const SearchData &data = mit->second;
+ if (!(data.flags & kMarked)) {
+ search_map_.erase(mit);
+ ++ngc_;
+ }
+ search_multimap_.erase(mmit++);
+ }
+}
+
+template<class Arc> const Arc PdtShortestPathData<Arc>::kNoArc
+ = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId);
+
+template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime0 = 7853;
+
+template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime1 = 7867;
+
+template<class Arc> const uint8 PdtShortestPathData<Arc>::kInited = 0x01;
+
+template<class Arc> const uint8 PdtShortestPathData<Arc>::kFinal = 0x02;
+
+template<class Arc> const uint8 PdtShortestPathData<Arc>::kMarked = 0x04;
+
+
+// This computes the single source shortest (balanced) path (SSSP)
+// through a weighted PDT that has a bounded stack (i.e. is expandable
+// as an FST). It is a generalization of the classic SSSP graph
+// algorithm that removes a state s from a queue (defined by a
+// user-provided queue type) and relaxes the destination states of
+// transitions leaving s. In this PDT version, states that have
+// entering open parentheses are treated as source states for a
+// sub-graph SSSP problem with the shortest path up to the open
+// parenthesis being first saved. When a close parenthesis is then
+// encountered any balancing open parenthesis is examined for this
+// saved information and multiplied back. In this way, each sub-graph
+// is entered only once rather than repeatedly. If every state in the
+// input PDT has the property that there is a unique 'start' state for
+// it with entering open parentheses, then this algorithm is quite
+// straight-forward. In general, this will not be the case, so the
+// algorithm (implicitly) creates a new graph where each state is a
+// pair of an original state and a possible parenthesis 'start' state
+// for that state.
+template<class Arc, class Queue>
+class PdtShortestPath {
+ public:
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Weight Weight;
+ typedef typename Arc::Label Label;
+
+ typedef PdtShortestPathData<Arc> SpData;
+ typedef typename SpData::SearchState SearchState;
+ typedef typename SpData::ParenSpec ParenSpec;
+
+ typedef typename PdtParenReachable<Arc>::SetIterator StateSetIterator;
+ typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator;
+
+ PdtShortestPath(const Fst<Arc> &ifst,
+ const vector<pair<Label, Label> > &parens,
+ const PdtShortestPathOptions<Arc, Queue> &opts)
+ : kFinal(SpData::kFinal),
+ ifst_(ifst.Copy()),
+ parens_(parens),
+ keep_parens_(opts.keep_parentheses),
+ start_(ifst.Start()),
+ sp_data_(opts.path_gc),
+ error_(false) {
+
+ if ((Weight::Properties() & (kPath | kRightSemiring))
+ != (kPath | kRightSemiring)) {
+ FSTERROR() << "SingleShortestPath: Weight needs to have the path"
+ << " property and be right distributive: " << Weight::Type();
+ error_ = true;
+ }
+
+ for (Label i = 0; i < parens.size(); ++i) {
+ const pair<Label, Label> &p = parens[i];
+ paren_id_map_[p.first] = i;
+ paren_id_map_[p.second] = i;
+ }
+ };
+
+ ~PdtShortestPath() {
+ VLOG(1) << "# of input states: " << CountStates(*ifst_);
+ VLOG(1) << "# of enqueued: " << nenqueued_;
+ VLOG(1) << "cpmm size: " << close_paren_multimap_.size();
+ delete ifst_;
+ }
+
+ void ShortestPath(MutableFst<Arc> *ofst) {
+ Init(ofst);
+ GetDistance(start_);
+ GetPath();
+ sp_data_.Finish();
+ if (error_) ofst->SetProperties(kError, kError);
+ }
+
+ const PdtShortestPathData<Arc> &GetShortestPathData() const {
+ return sp_data_;
+ }
+
+ PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; }
+
+ private:
+ static const Arc kNoArc;
+ static const uint8 kEnqueued;
+ static const uint8 kExpanded;
+ const uint8 kFinal;
+
+ public:
+ // Hash multimap from close paren label to an paren arc.
+ typedef unordered_multimap<ParenState<Arc>, Arc,
+ typename ParenState<Arc>::Hash> CloseParenMultimap;
+
+ const CloseParenMultimap &GetCloseParenMultimap() const {
+ return close_paren_multimap_;
+ }
+
+ private:
+ void Init(MutableFst<Arc> *ofst);
+ void GetDistance(StateId start);
+ void ProcFinal(SearchState s);
+ void ProcArcs(SearchState s);
+ void ProcOpenParen(Label paren_id, SearchState s, Arc arc, Weight w);
+ void ProcCloseParen(Label paren_id, SearchState s, const Arc &arc, Weight w);
+ void ProcNonParen(SearchState s, const Arc &arc, Weight w);
+ void Relax(SearchState s, SearchState t, Arc arc, Weight w, Label paren_id);
+ void Enqueue(SearchState d);
+ void GetPath();
+ Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open);
+
+ Fst<Arc> *ifst_;
+ MutableFst<Arc> *ofst_;
+ const vector<pair<Label, Label> > &parens_;
+ bool keep_parens_;
+ Queue *state_queue_; // current state queue
+ StateId start_;
+ Weight f_distance_;
+ SearchState f_parent_;
+ SpData sp_data_;
+ unordered_map<Label, Label> paren_id_map_;
+ CloseParenMultimap close_paren_multimap_;
+ PdtBalanceData<Arc> balance_data_;
+ ssize_t nenqueued_;
+ bool error_;
+
+ DISALLOW_COPY_AND_ASSIGN(PdtShortestPath);
+};
+
+template<class Arc, class Queue>
+void PdtShortestPath<Arc, Queue>::Init(MutableFst<Arc> *ofst) {
+ ofst_ = ofst;
+ ofst->DeleteStates();
+ ofst->SetInputSymbols(ifst_->InputSymbols());
+ ofst->SetOutputSymbols(ifst_->OutputSymbols());
+
+ if (ifst_->Start() == kNoStateId)
+ return;
+
+ f_distance_ = Weight::Zero();
+ f_parent_ = SearchState(kNoStateId, kNoStateId);
+
+ sp_data_.Clear();
+ close_paren_multimap_.clear();
+ balance_data_.Clear();
+ nenqueued_ = 0;
+
+ // Find open parens per destination state and close parens per source state.
+ for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
+ StateId s = siter.Value();
+ for (ArcIterator<Fst<Arc> > aiter(*ifst_, s);
+ !aiter.Done(); aiter.Next()) {
+ const Arc &arc = aiter.Value();
+ typename unordered_map<Label, Label>::const_iterator pit
+ = paren_id_map_.find(arc.ilabel);
+ if (pit != paren_id_map_.end()) { // Is a paren?
+ Label paren_id = pit->second;
+ if (arc.ilabel == parens_[paren_id].first) { // Open paren
+ balance_data_.OpenInsert(paren_id, arc.nextstate);
+ } else { // Close paren
+ ParenState<Arc> paren_state(paren_id, s);
+ close_paren_multimap_.insert(make_pair(paren_state, arc));
+ }
+ }
+ }
+ }
+}
+
+// Computes the shortest distance stored in a recursive way. Each
+// sub-graph (i.e. different paren 'start' state) begins with weight One().
+template<class Arc, class Queue>
+void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) {
+ if (start == kNoStateId)
+ return;
+
+ Queue state_queue;
+ state_queue_ = &state_queue;
+ SearchState q(start, start);
+ Enqueue(q);
+ sp_data_.SetDistance(q, Weight::One());
+
+ while (!state_queue_->Empty()) {
+ StateId state = state_queue_->Head();
+ state_queue_->Dequeue();
+ SearchState s(state, start);
+ sp_data_.SetFlags(s, 0, kEnqueued);
+ ProcFinal(s);
+ ProcArcs(s);
+ sp_data_.SetFlags(s, kExpanded, kExpanded);
+ }
+ balance_data_.FinishInsert(start);
+ sp_data_.GC(start);
+}
+
+// Updates best complete path.
+template<class Arc, class Queue>
+void PdtShortestPath<Arc, Queue>::ProcFinal(SearchState s) {
+ if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) {
+ Weight w = Times(sp_data_.Distance(s),
+ ifst_->Final(s.state));
+ if (f_distance_ != Plus(f_distance_, w)) {
+ if (f_parent_.state != kNoStateId)
+ sp_data_.SetFlags(f_parent_, 0, kFinal);
+ sp_data_.SetFlags(s, kFinal, kFinal);
+
+ f_distance_ = Plus(f_distance_, w);
+ f_parent_ = s;
+ }
+ }
+}
+
+// Processes all arcs leaving the state s.
+template<class Arc, class Queue>
+void PdtShortestPath<Arc, Queue>::ProcArcs(SearchState s) {
+ for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state);
+ !aiter.Done();
+ aiter.Next()) {
+ Arc arc = aiter.Value();
+ Weight w = Times(sp_data_.Distance(s), arc.weight);
+
+ typename unordered_map<Label, Label>::const_iterator pit
+ = paren_id_map_.find(arc.ilabel);
+ if (pit != paren_id_map_.end()) { // Is a paren?
+ Label paren_id = pit->second;
+ if (arc.ilabel == parens_[paren_id].first)
+ ProcOpenParen(paren_id, s, arc, w);
+ else
+ ProcCloseParen(paren_id, s, arc, w);
+ } else {
+ ProcNonParen(s, arc, w);
+ }
+ }
+}
+
+// Saves the shortest path info for reaching this parenthesis
+// and starts a new SSSP in the sub-graph pointed to by the parenthesis
+// if previously unvisited. Otherwise it finds any previously encountered
+// closing parentheses and relaxes them using the recursively stored
+// shortest distance to them.
+template<class Arc, class Queue> inline
+void PdtShortestPath<Arc, Queue>::ProcOpenParen(
+ Label paren_id, SearchState s, Arc arc, Weight w) {
+
+ SearchState d(arc.nextstate, arc.nextstate);
+ ParenSpec paren(paren_id, s.start, d.start);
+ Weight pdist = sp_data_.Distance(paren);
+ if (pdist != Plus(pdist, w)) {
+ sp_data_.SetDistance(paren, w);
+ sp_data_.SetParent(paren, s);
+ Weight dist = sp_data_.Distance(d);
+ if (dist == Weight::Zero()) {
+ Queue *state_queue = state_queue_;
+ GetDistance(d.start);
+ state_queue_ = state_queue;
+ }
+ for (CloseSourceIterator set_iter =
+ balance_data_.Find(paren_id, arc.nextstate);
+ !set_iter.Done(); set_iter.Next()) {
+ SearchState cpstate(set_iter.Element(), d.start);
+ ParenState<Arc> paren_state(paren_id, cpstate.state);
+ for (typename CloseParenMultimap::const_iterator cpit =
+ close_paren_multimap_.find(paren_state);
+ cpit != close_paren_multimap_.end() && paren_state == cpit->first;
+ ++cpit) {
+ const Arc &cparc = cpit->second;
+ Weight cpw = Times(w, Times(sp_data_.Distance(cpstate),
+ cparc.weight));
+ Relax(cpstate, s, cparc, cpw, paren_id);
+ }
+ }
+ }
+}
+
+// Saves the correspondence between each closing parenthesis and its
+// balancing open parenthesis info. Relaxes any close parenthesis
+// destination state that has a balancing previously encountered open
+// parenthesis.
+template<class Arc, class Queue> inline
+void PdtShortestPath<Arc, Queue>::ProcCloseParen(
+ Label paren_id, SearchState s, const Arc &arc, Weight w) {
+ ParenState<Arc> paren_state(paren_id, s.start);
+ if (!(sp_data_.Flags(s) & kExpanded)) {
+ balance_data_.CloseInsert(paren_id, s.start, s.state);
+ sp_data_.SetFlags(s, kFinal, kFinal);
+ }
+}
+
+// For non-parentheses, classical relaxation.
+template<class Arc, class Queue> inline
+void PdtShortestPath<Arc, Queue>::ProcNonParen(
+ SearchState s, const Arc &arc, Weight w) {
+ Relax(s, s, arc, w, kNoLabel);
+}
+
+// Classical relaxation on the search graph for 'arc' from state 's'.
+// State 't' is in the same sub-graph as the nextstate should be (i.e.
+// has the same paren 'start'.
+template<class Arc, class Queue> inline
+void PdtShortestPath<Arc, Queue>::Relax(
+ SearchState s, SearchState t, Arc arc, Weight w, Label paren_id) {
+ SearchState d(arc.nextstate, t.start);
+ Weight dist = sp_data_.Distance(d);
+ if (dist != Plus(dist, w)) {
+ sp_data_.SetParent(d, s);
+ sp_data_.SetParenId(d, paren_id);
+ sp_data_.SetDistance(d, Plus(dist, w));
+ Enqueue(d);
+ }
+}
+
+template<class Arc, class Queue> inline
+void PdtShortestPath<Arc, Queue>::Enqueue(SearchState s) {
+ if (!(sp_data_.Flags(s) & kEnqueued)) {
+ state_queue_->Enqueue(s.state);
+ sp_data_.SetFlags(s, kEnqueued, kEnqueued);
+ ++nenqueued_;
+ } else {
+ state_queue_->Update(s.state);
+ }
+}
+
+// Follows parent pointers to find the shortest path. Uses a stack
+// since the shortest distance is stored recursively.
+template<class Arc, class Queue>
+void PdtShortestPath<Arc, Queue>::GetPath() {
+ SearchState s = f_parent_, d = SearchState(kNoStateId, kNoStateId);
+ StateId s_p = kNoStateId, d_p = kNoStateId;
+ Arc arc(kNoArc);
+ Label paren_id = kNoLabel;
+ stack<ParenSpec> paren_stack;
+ while (s.state != kNoStateId) {
+ d_p = s_p;
+ s_p = ofst_->AddState();
+ if (d.state == kNoStateId) {
+ ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state));
+ } else {
+ if (paren_id != kNoLabel) { // paren?
+ if (arc.ilabel == parens_[paren_id].first) { // open paren
+ paren_stack.pop();
+ } else { // close paren
+ ParenSpec paren(paren_id, d.start, s.start);
+ paren_stack.push(paren);
+ }
+ if (!keep_parens_)
+ arc.ilabel = arc.olabel = 0;
+ }
+ arc.nextstate = d_p;
+ ofst_->AddArc(s_p, arc);
+ }
+ d = s;
+ s = sp_data_.Parent(d);
+ paren_id = sp_data_.ParenId(d);
+ if (s.state != kNoStateId) {
+ arc = GetPathArc(s, d, paren_id, false);
+ } else if (!paren_stack.empty()) {
+ ParenSpec paren = paren_stack.top();
+ s = sp_data_.Parent(paren);
+ paren_id = paren.paren_id;
+ arc = GetPathArc(s, d, paren_id, true);
+ }
+ }
+ ofst_->SetStart(s_p);
+ ofst_->SetProperties(
+ ShortestPathProperties(ofst_->Properties(kFstProperties, false)),
+ kFstProperties);
+}
+
+
+// Finds transition with least weight between two states with label matching
+// paren_id and open/close paren type or a non-paren if kNoLabel.
+template<class Arc, class Queue>
+Arc PdtShortestPath<Arc, Queue>::GetPathArc(
+ SearchState s, SearchState d, Label paren_id, bool open_paren) {
+ Arc path_arc = kNoArc;
+ for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state);
+ !aiter.Done();
+ aiter.Next()) {
+ const Arc &arc = aiter.Value();
+ if (arc.nextstate != d.state)
+ continue;
+ Label arc_paren_id = kNoLabel;
+ typename unordered_map<Label, Label>::const_iterator pit
+ = paren_id_map_.find(arc.ilabel);
+ if (pit != paren_id_map_.end()) {
+ arc_paren_id = pit->second;
+ bool arc_open_paren = arc.ilabel == parens_[arc_paren_id].first;
+ if (arc_open_paren != open_paren)
+ continue;
+ }
+ if (arc_paren_id != paren_id)
+ continue;
+ if (arc.weight == Plus(arc.weight, path_arc.weight))
+ path_arc = arc;
+ }
+ if (path_arc.nextstate == kNoStateId) {
+ FSTERROR() << "PdtShortestPath::GetPathArc failed to find arc";
+ error_ = true;
+ }
+ return path_arc;
+}
+
+template<class Arc, class Queue>
+const Arc PdtShortestPath<Arc, Queue>::kNoArc
+ = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId);
+
+template<class Arc, class Queue>
+const uint8 PdtShortestPath<Arc, Queue>::kEnqueued = 0x10;
+
+template<class Arc, class Queue>
+const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20;
+
+template<class Arc, class Queue>
+void ShortestPath(const Fst<Arc> &ifst,
+ const vector<pair<typename Arc::Label,
+ typename Arc::Label> > &parens,
+ MutableFst<Arc> *ofst,
+ const PdtShortestPathOptions<Arc, Queue> &opts) {
+ PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
+ psp.ShortestPath(ofst);
+}
+
+template<class Arc>
+void ShortestPath(const Fst<Arc> &ifst,
+ const vector<pair<typename Arc::Label,
+ typename Arc::Label> > &parens,
+ MutableFst<Arc> *ofst) {
+ typedef FifoQueue<typename Arc::StateId> Queue;
+ PdtShortestPathOptions<Arc, Queue> opts;
+ PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
+ psp.ShortestPath(ofst);
+}
+
+} // namespace fst
+
+#endif // FST_EXTENSIONS_PDT_SHORTEST_PATH_H__