aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/extensions/pdt/expand.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/extensions/pdt/expand.h')
-rw-r--r--src/include/fst/extensions/pdt/expand.h975
1 files changed, 975 insertions, 0 deletions
diff --git a/src/include/fst/extensions/pdt/expand.h b/src/include/fst/extensions/pdt/expand.h
new file mode 100644
index 0000000..f464403
--- /dev/null
+++ b/src/include/fst/extensions/pdt/expand.h
@@ -0,0 +1,975 @@
+// expand.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Expand a PDT to an FST.
+
+#ifndef FST_EXTENSIONS_PDT_EXPAND_H__
+#define FST_EXTENSIONS_PDT_EXPAND_H__
+
+#include <vector>
+using std::vector;
+
+#include <fst/extensions/pdt/pdt.h>
+#include <fst/extensions/pdt/paren.h>
+#include <fst/extensions/pdt/shortest-path.h>
+#include <fst/extensions/pdt/reverse.h>
+#include <fst/cache.h>
+#include <fst/mutable-fst.h>
+#include <fst/queue.h>
+#include <fst/state-table.h>
+#include <fst/test-properties.h>
+
+namespace fst {
+
+template <class Arc>
+struct ExpandFstOptions : public CacheOptions {
+ bool keep_parentheses;
+ PdtStack<typename Arc::StateId, typename Arc::Label> *stack;
+ PdtStateTable<typename Arc::StateId, typename Arc::StateId> *state_table;
+
+ ExpandFstOptions(
+ const CacheOptions &opts = CacheOptions(),
+ bool kp = false,
+ PdtStack<typename Arc::StateId, typename Arc::Label> *s = 0,
+ PdtStateTable<typename Arc::StateId, typename Arc::StateId> *st = 0)
+ : CacheOptions(opts), keep_parentheses(kp), stack(s), state_table(st) {}
+};
+
+// Properties for an expanded PDT.
+inline uint64 ExpandProperties(uint64 inprops) {
+ return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted);
+}
+
+
+// Implementation class for ExpandFst
+template <class A>
+class ExpandFstImpl
+ : public CacheImpl<A> {
+ public:
+ using FstImpl<A>::SetType;
+ using FstImpl<A>::SetProperties;
+ using FstImpl<A>::Properties;
+ using FstImpl<A>::SetInputSymbols;
+ using FstImpl<A>::SetOutputSymbols;
+
+ using CacheBaseImpl< CacheState<A> >::PushArc;
+ using CacheBaseImpl< CacheState<A> >::HasArcs;
+ using CacheBaseImpl< CacheState<A> >::HasFinal;
+ using CacheBaseImpl< CacheState<A> >::HasStart;
+ using CacheBaseImpl< CacheState<A> >::SetArcs;
+ using CacheBaseImpl< CacheState<A> >::SetFinal;
+ using CacheBaseImpl< CacheState<A> >::SetStart;
+
+ typedef A Arc;
+ typedef typename A::Label Label;
+ typedef typename A::Weight Weight;
+ typedef typename A::StateId StateId;
+ typedef StateId StackId;
+ typedef PdtStateTuple<StateId, StackId> StateTuple;
+
+ ExpandFstImpl(const Fst<A> &fst,
+ const vector<pair<typename Arc::Label,
+ typename Arc::Label> > &parens,
+ const ExpandFstOptions<A> &opts)
+ : CacheImpl<A>(opts), fst_(fst.Copy()),
+ stack_(opts.stack ? opts.stack: new PdtStack<StateId, Label>(parens)),
+ state_table_(opts.state_table ? opts.state_table :
+ new PdtStateTable<StateId, StackId>()),
+ own_stack_(opts.stack == 0), own_state_table_(opts.state_table == 0),
+ keep_parentheses_(opts.keep_parentheses) {
+ SetType("expand");
+
+ uint64 props = fst.Properties(kFstProperties, false);
+ SetProperties(ExpandProperties(props), kCopyProperties);
+
+ SetInputSymbols(fst.InputSymbols());
+ SetOutputSymbols(fst.OutputSymbols());
+ }
+
+ ExpandFstImpl(const ExpandFstImpl &impl)
+ : CacheImpl<A>(impl),
+ fst_(impl.fst_->Copy(true)),
+ stack_(new PdtStack<StateId, Label>(*impl.stack_)),
+ state_table_(new PdtStateTable<StateId, StackId>()),
+ own_stack_(true), own_state_table_(true),
+ keep_parentheses_(impl.keep_parentheses_) {
+ SetType("expand");
+ SetProperties(impl.Properties(), kCopyProperties);
+ SetInputSymbols(impl.InputSymbols());
+ SetOutputSymbols(impl.OutputSymbols());
+ }
+
+ ~ExpandFstImpl() {
+ delete fst_;
+ if (own_stack_)
+ delete stack_;
+ if (own_state_table_)
+ delete state_table_;
+ }
+
+ StateId Start() {
+ if (!HasStart()) {
+ StateId s = fst_->Start();
+ if (s == kNoStateId)
+ return kNoStateId;
+ StateTuple tuple(s, 0);
+ StateId start = state_table_->FindState(tuple);
+ SetStart(start);
+ }
+ return CacheImpl<A>::Start();
+ }
+
+ Weight Final(StateId s) {
+ if (!HasFinal(s)) {
+ const StateTuple &tuple = state_table_->Tuple(s);
+ Weight w = fst_->Final(tuple.state_id);
+ if (w != Weight::Zero() && tuple.stack_id == 0)
+ SetFinal(s, w);
+ else
+ SetFinal(s, Weight::Zero());
+ }
+ return CacheImpl<A>::Final(s);
+ }
+
+ size_t NumArcs(StateId s) {
+ if (!HasArcs(s)) {
+ ExpandState(s);
+ }
+ return CacheImpl<A>::NumArcs(s);
+ }
+
+ size_t NumInputEpsilons(StateId s) {
+ if (!HasArcs(s))
+ ExpandState(s);
+ return CacheImpl<A>::NumInputEpsilons(s);
+ }
+
+ size_t NumOutputEpsilons(StateId s) {
+ if (!HasArcs(s))
+ ExpandState(s);
+ return CacheImpl<A>::NumOutputEpsilons(s);
+ }
+
+ void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
+ if (!HasArcs(s))
+ ExpandState(s);
+ CacheImpl<A>::InitArcIterator(s, data);
+ }
+
+ // Computes the outgoing transitions from a state, creating new destination
+ // states as needed.
+ void ExpandState(StateId s) {
+ StateTuple tuple = state_table_->Tuple(s);
+ for (ArcIterator< Fst<A> > aiter(*fst_, tuple.state_id);
+ !aiter.Done(); aiter.Next()) {
+ Arc arc = aiter.Value();
+ StackId stack_id = stack_->Find(tuple.stack_id, arc.ilabel);
+ if (stack_id == -1) {
+ // Non-matching close parenthesis
+ continue;
+ } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) {
+ // Stack push/pop
+ arc.ilabel = arc.olabel = 0;
+ }
+
+ StateTuple ntuple(arc.nextstate, stack_id);
+ arc.nextstate = state_table_->FindState(ntuple);
+ PushArc(s, arc);
+ }
+ SetArcs(s);
+ }
+
+ const PdtStack<StackId, Label> &GetStack() const { return *stack_; }
+
+ const PdtStateTable<StateId, StackId> &GetStateTable() const {
+ return *state_table_;
+ }
+
+ private:
+ const Fst<A> *fst_;
+
+ PdtStack<StackId, Label> *stack_;
+ PdtStateTable<StateId, StackId> *state_table_;
+ bool own_stack_;
+ bool own_state_table_;
+ bool keep_parentheses_;
+
+ void operator=(const ExpandFstImpl<A> &); // disallow
+};
+
+// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
+// This version is a delayed Fst. In the PDT, some transitions are
+// labeled with open or close parentheses. To be interpreted as a PDT,
+// the parens must balance on a path. The open-close parenthesis label
+// pairs are passed in 'parens'. The expansion enforces the
+// parenthesis constraints. The PDT must be expandable as an FST.
+//
+// This class attaches interface to implementation and handles
+// reference counting, delegating most methods to ImplToFst.
+template <class A>
+class ExpandFst : public ImplToFst< ExpandFstImpl<A> > {
+ public:
+ friend class ArcIterator< ExpandFst<A> >;
+ friend class StateIterator< ExpandFst<A> >;
+
+ typedef A Arc;
+ typedef typename A::Label Label;
+ typedef typename A::Weight Weight;
+ typedef typename A::StateId StateId;
+ typedef StateId StackId;
+ typedef CacheState<A> State;
+ typedef ExpandFstImpl<A> Impl;
+
+ ExpandFst(const Fst<A> &fst,
+ const vector<pair<typename Arc::Label,
+ typename Arc::Label> > &parens)
+ : ImplToFst<Impl>(new Impl(fst, parens, ExpandFstOptions<A>())) {}
+
+ ExpandFst(const Fst<A> &fst,
+ const vector<pair<typename Arc::Label,
+ typename Arc::Label> > &parens,
+ const ExpandFstOptions<A> &opts)
+ : ImplToFst<Impl>(new Impl(fst, parens, opts)) {}
+
+ // See Fst<>::Copy() for doc.
+ ExpandFst(const ExpandFst<A> &fst, bool safe = false)
+ : ImplToFst<Impl>(fst, safe) {}
+
+ // Get a copy of this ExpandFst. See Fst<>::Copy() for further doc.
+ virtual ExpandFst<A> *Copy(bool safe = false) const {
+ return new ExpandFst<A>(*this, safe);
+ }
+
+ virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
+
+ virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
+ GetImpl()->InitArcIterator(s, data);
+ }
+
+ const PdtStack<StackId, Label> &GetStack() const {
+ return GetImpl()->GetStack();
+ }
+
+ const PdtStateTable<StateId, StackId> &GetStateTable() const {
+ return GetImpl()->GetStateTable();
+ }
+
+ private:
+ // Makes visible to friends.
+ Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
+
+ void operator=(const ExpandFst<A> &fst); // Disallow
+};
+
+
+// Specialization for ExpandFst.
+template<class A>
+class StateIterator< ExpandFst<A> >
+ : public CacheStateIterator< ExpandFst<A> > {
+ public:
+ explicit StateIterator(const ExpandFst<A> &fst)
+ : CacheStateIterator< ExpandFst<A> >(fst, fst.GetImpl()) {}
+};
+
+
+// Specialization for ExpandFst.
+template <class A>
+class ArcIterator< ExpandFst<A> >
+ : public CacheArcIterator< ExpandFst<A> > {
+ public:
+ typedef typename A::StateId StateId;
+
+ ArcIterator(const ExpandFst<A> &fst, StateId s)
+ : CacheArcIterator< ExpandFst<A> >(fst.GetImpl(), s) {
+ if (!fst.GetImpl()->HasArcs(s))
+ fst.GetImpl()->ExpandState(s);
+ }
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(ArcIterator);
+};
+
+
+template <class A> inline
+void ExpandFst<A>::InitStateIterator(StateIteratorData<A> *data) const
+{
+ data->base = new StateIterator< ExpandFst<A> >(*this);
+}
+
+//
+// PrunedExpand Class
+//
+
+// Prunes the delayed expansion of a pushdown transducer (PDT) encoded
+// as an FST into an FST. In the PDT, some transitions are labeled
+// with open or close parentheses. To be interpreted as a PDT, the
+// parens must balance on a path. The open-close parenthesis label
+// pairs are passed in 'parens'. The expansion enforces the
+// parenthesis constraints.
+//
+// The algorithm works by visiting the delayed ExpandFst using a
+// shortest-stack first queue discipline and relies on the
+// shortest-distance information computed using a reverse
+// shortest-path call to perform the pruning.
+//
+// The algorithm maintains the same state ordering between the ExpandFst
+// being visited 'efst_' and the result of pruning written into the
+// MutableFst 'ofst_' to improve readability of the code.
+//
+template <class A>
+class PrunedExpand {
+ public:
+ typedef A Arc;
+ typedef typename A::Label Label;
+ typedef typename A::StateId StateId;
+ typedef typename A::Weight Weight;
+ typedef StateId StackId;
+ typedef PdtStack<StackId, Label> Stack;
+ typedef PdtStateTable<StateId, StackId> StateTable;
+ typedef typename PdtBalanceData<Arc>::SetIterator SetIterator;
+
+ // Constructor taking as input a PDT specified by 'ifst' and 'parens'.
+ // 'keep_parentheses' specifies whether parentheses are replaced by
+ // epsilons or not during the expansion. 'opts' is the cache options
+ // used to instantiate the underlying ExpandFst.
+ PrunedExpand(const Fst<A> &ifst,
+ const vector<pair<Label, Label> > &parens,
+ bool keep_parentheses = false,
+ const CacheOptions &opts = CacheOptions())
+ : ifst_(ifst.Copy()),
+ keep_parentheses_(keep_parentheses),
+ stack_(parens),
+ efst_(ifst, parens,
+ ExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)),
+ queue_(state_table_, stack_, stack_length_, distance_, fdistance_) {
+ Reverse(*ifst_, parens, &rfst_);
+ VectorFst<Arc> path;
+ reverse_shortest_path_ = new SP(
+ rfst_, parens,
+ PdtShortestPathOptions<A, FifoQueue<StateId> >(true, false));
+ reverse_shortest_path_->ShortestPath(&path);
+ balance_data_ = reverse_shortest_path_->GetBalanceData()->Reverse(
+ rfst_.NumStates(), 10, -1);
+
+ InitCloseParenMultimap(parens);
+ }
+
+ ~PrunedExpand() {
+ delete ifst_;
+ delete reverse_shortest_path_;
+ delete balance_data_;
+ }
+
+ // Expands and prunes with weight threshold 'threshold' the input PDT.
+ // Writes the result in 'ofst'.
+ void Expand(MutableFst<A> *ofst, const Weight &threshold);
+
+ private:
+ static const uint8 kEnqueued;
+ static const uint8 kExpanded;
+ static const uint8 kSourceState;
+
+ // Comparison functor used by the queue:
+ // 1. states corresponding to shortest stack first,
+ // 2. among stacks of the same length, reverse lexicographic order is used,
+ // 3. among states with the same stack, shortest-first order is used.
+ class StackCompare {
+ public:
+ StackCompare(const StateTable &st,
+ const Stack &s, const vector<StackId> &sl,
+ const vector<Weight> &d, const vector<Weight> &fd)
+ : state_table_(st), stack_(s), stack_length_(sl),
+ distance_(d), fdistance_(fd) {}
+
+ bool operator()(StateId s1, StateId s2) const {
+ StackId si1 = state_table_.Tuple(s1).stack_id;
+ StackId si2 = state_table_.Tuple(s2).stack_id;
+ if (stack_length_[si1] < stack_length_[si2])
+ return true;
+ if (stack_length_[si1] > stack_length_[si2])
+ return false;
+ // If stack id equal, use A*
+ if (si1 == si2) {
+ Weight w1 = (s1 < distance_.size()) && (s1 < fdistance_.size()) ?
+ Times(distance_[s1], fdistance_[s1]) : Weight::Zero();
+ Weight w2 = (s2 < distance_.size()) && (s2 < fdistance_.size()) ?
+ Times(distance_[s2], fdistance_[s2]) : Weight::Zero();
+ return less_(w1, w2);
+ }
+ // If lenghts are equal, use reverse lexico.
+ for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) {
+ if (stack_.Top(si1) < stack_.Top(si2)) return true;
+ if (stack_.Top(si1) > stack_.Top(si2)) return false;
+ }
+ return false;
+ }
+
+ private:
+ const StateTable &state_table_;
+ const Stack &stack_;
+ const vector<StackId> &stack_length_;
+ const vector<Weight> &distance_;
+ const vector<Weight> &fdistance_;
+ NaturalLess<Weight> less_;
+ };
+
+ class ShortestStackFirstQueue
+ : public ShortestFirstQueue<StateId, StackCompare> {
+ public:
+ ShortestStackFirstQueue(
+ const PdtStateTable<StateId, StackId> &st,
+ const Stack &s,
+ const vector<StackId> &sl,
+ const vector<Weight> &d, const vector<Weight> &fd)
+ : ShortestFirstQueue<StateId, StackCompare>(
+ StackCompare(st, s, sl, d, fd)) {}
+ };
+
+
+ void InitCloseParenMultimap(const vector<pair<Label, Label> > &parens);
+ Weight DistanceToDest(StateId state, StateId source) const;
+ uint8 Flags(StateId s) const;
+ void SetFlags(StateId s, uint8 flags, uint8 mask);
+ Weight Distance(StateId s) const;
+ void SetDistance(StateId s, Weight w);
+ Weight FinalDistance(StateId s) const;
+ void SetFinalDistance(StateId s, Weight w);
+ StateId SourceState(StateId s) const;
+ void SetSourceState(StateId s, StateId p);
+ void AddStateAndEnqueue(StateId s);
+ void Relax(StateId s, const A &arc, Weight w);
+ bool PruneArc(StateId s, const A &arc);
+ void ProcStart();
+ void ProcFinal(StateId s);
+ bool ProcNonParen(StateId s, const A &arc, bool add_arc);
+ bool ProcOpenParen(StateId s, const A &arc, StackId si, StackId nsi);
+ bool ProcCloseParen(StateId s, const A &arc);
+ void ProcDestStates(StateId s, StackId si);
+
+ Fst<A> *ifst_; // Input PDT
+ VectorFst<Arc> rfst_; // Reversed PDT
+ bool keep_parentheses_; // Keep parentheses in ofst?
+ StateTable state_table_; // State table for efst_
+ Stack stack_; // Stack trie
+ ExpandFst<Arc> efst_; // Expanded PDT
+ vector<StackId> stack_length_; // Length of stack for given stack id
+ vector<Weight> distance_; // Distance from initial state in efst_/ofst
+ vector<Weight> fdistance_; // Distance to final states in efst_/ofst
+ ShortestStackFirstQueue queue_; // Queue used to visit efst_
+ vector<uint8> flags_; // Status flags for states in efst_/ofst
+ vector<StateId> sources_; // PDT source state for each expanded state
+
+ typedef PdtShortestPath<Arc, FifoQueue<StateId> > SP;
+ typedef typename SP::CloseParenMultimap ParenMultimap;
+ SP *reverse_shortest_path_; // Shortest path for rfst_
+ PdtBalanceData<Arc> *balance_data_; // Not owned by shortest_path_
+ ParenMultimap close_paren_multimap_; // Maps open paren arcs to
+ // balancing close paren arcs.
+
+ MutableFst<Arc> *ofst_; // Output fst
+ Weight limit_; // Weight limit
+
+ typedef unordered_map<StateId, Weight> DestMap;
+ DestMap dest_map_;
+ StackId current_stack_id_;
+ // 'current_stack_id_' is the stack id of the states currently at the top
+ // of queue, i.e., the states currently being popped and processed.
+ // 'dest_map_' maps a state 's' in 'ifst_' that is the source
+ // of a close parentheses matching the top of 'current_stack_id_; to
+ // the shortest-distance from '(s, current_stack_id_)' to the final
+ // states in 'efst_'.
+ ssize_t current_paren_id_; // Paren id at top of current stack
+ ssize_t cached_stack_id_;
+ StateId cached_source_;
+ slist<pair<StateId, Weight> > cached_dest_list_;
+ // 'cached_dest_list_' contains the set of pair of destination
+ // states and weight to final states for source state
+ // 'cached_source_' and paren id 'cached_paren_id': the set of
+ // source state of a close parenthesis with paren id
+ // 'cached_paren_id' balancing an incoming open parenthesis with
+ // paren id 'cached_paren_id' in state 'cached_source_'.
+
+ NaturalLess<Weight> less_;
+};
+
+template <class A> const uint8 PrunedExpand<A>::kEnqueued = 0x01;
+template <class A> const uint8 PrunedExpand<A>::kExpanded = 0x02;
+template <class A> const uint8 PrunedExpand<A>::kSourceState = 0x04;
+
+
+// Initializes close paren multimap, mapping pairs (s,paren_id) to
+// all the arcs out of s labeled with close parenthese for paren_id.
+template <class A>
+void PrunedExpand<A>::InitCloseParenMultimap(
+ const vector<pair<Label, Label> > &parens) {
+ unordered_map<Label, Label> paren_id_map;
+ for (Label i = 0; i < parens.size(); ++i) {
+ const pair<Label, Label> &p = parens[i];
+ paren_id_map[p.first] = i;
+ paren_id_map[p.second] = i;
+ }
+
+ for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
+ StateId s = siter.Value();
+ for (ArcIterator<Fst<Arc> > aiter(*ifst_, s);
+ !aiter.Done(); aiter.Next()) {
+ const Arc &arc = aiter.Value();
+ typename unordered_map<Label, Label>::const_iterator pit
+ = paren_id_map.find(arc.ilabel);
+ if (pit == paren_id_map.end()) continue;
+ if (arc.ilabel == parens[pit->second].second) { // Close paren
+ ParenState<Arc> paren_state(pit->second, s);
+ close_paren_multimap_.insert(make_pair(paren_state, arc));
+ }
+ }
+ }
+}
+
+
+// Returns the weight of the shortest balanced path from 'source' to 'dest'
+// in 'ifst_', 'dest' must be the source state of a close paren arc.
+template <class A>
+typename A::Weight PrunedExpand<A>::DistanceToDest(StateId source,
+ StateId dest) const {
+ typename SP::SearchState s(source + 1, dest + 1);
+ VLOG(2) << "D(" << source << ", " << dest << ") ="
+ << reverse_shortest_path_->GetShortestPathData().Distance(s);
+ return reverse_shortest_path_->GetShortestPathData().Distance(s);
+}
+
+// Returns the flags for state 's' in 'ofst_'.
+template <class A>
+uint8 PrunedExpand<A>::Flags(StateId s) const {
+ return s < flags_.size() ? flags_[s] : 0;
+}
+
+// Modifies the flags for state 's' in 'ofst_'.
+template <class A>
+void PrunedExpand<A>::SetFlags(StateId s, uint8 flags, uint8 mask) {
+ while (flags_.size() <= s) flags_.push_back(0);
+ flags_[s] &= ~mask;
+ flags_[s] |= flags & mask;
+}
+
+
+// Returns the shortest distance from the initial state to 's' in 'ofst_'.
+template <class A>
+typename A::Weight PrunedExpand<A>::Distance(StateId s) const {
+ return s < distance_.size() ? distance_[s] : Weight::Zero();
+}
+
+// Sets the shortest distance from the initial state to 's' in 'ofst_' to 'w'.
+template <class A>
+void PrunedExpand<A>::SetDistance(StateId s, Weight w) {
+ while (distance_.size() <= s ) distance_.push_back(Weight::Zero());
+ distance_[s] = w;
+}
+
+
+// Returns the shortest distance from 's' to the final states in 'ofst_'.
+template <class A>
+typename A::Weight PrunedExpand<A>::FinalDistance(StateId s) const {
+ return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
+}
+
+// Sets the shortest distance from 's' to the final states in 'ofst_' to 'w'.
+template <class A>
+void PrunedExpand<A>::SetFinalDistance(StateId s, Weight w) {
+ while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero());
+ fdistance_[s] = w;
+}
+
+// Returns the PDT "source" state of state 's' in 'ofst_'.
+template <class A>
+typename A::StateId PrunedExpand<A>::SourceState(StateId s) const {
+ return s < sources_.size() ? sources_[s] : kNoStateId;
+}
+
+// Sets the PDT "source" state of state 's' in 'ofst_' to state 'p' in 'ifst_'.
+template <class A>
+void PrunedExpand<A>::SetSourceState(StateId s, StateId p) {
+ while (sources_.size() <= s) sources_.push_back(kNoStateId);
+ sources_[s] = p;
+}
+
+// Adds state 's' of 'efst_' to 'ofst_' and inserts it in the queue,
+// modifying the flags for 's' accordingly.
+template <class A>
+void PrunedExpand<A>::AddStateAndEnqueue(StateId s) {
+ if (!(Flags(s) & (kEnqueued | kExpanded))) {
+ while (ofst_->NumStates() <= s) ofst_->AddState();
+ queue_.Enqueue(s);
+ SetFlags(s, kEnqueued, kEnqueued);
+ } else if (Flags(s) & kEnqueued) {
+ queue_.Update(s);
+ }
+ // TODO(allauzen): Check everything is fine when kExpanded?
+}
+
+// Relaxes arc 'arc' out of state 's' in 'ofst_':
+// * if the distance to 's' times the weight of 'arc' is smaller than
+// the currently stored distance for 'arc.nextstate',
+// updates 'Distance(arc.nextstate)' with new estimate;
+// * if 'fd' is less than the currently stored distance from 'arc.nextstate'
+// to the final state, updates with new estimate.
+template <class A>
+void PrunedExpand<A>::Relax(StateId s, const A &arc, Weight fd) {
+ Weight nd = Times(Distance(s), arc.weight);
+ if (less_(nd, Distance(arc.nextstate))) {
+ SetDistance(arc.nextstate, nd);
+ SetSourceState(arc.nextstate, SourceState(s));
+ }
+ if (less_(fd, FinalDistance(arc.nextstate)))
+ SetFinalDistance(arc.nextstate, fd);
+ VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to "
+ << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate)
+ << ", nd = " << nd;
+}
+
+// Returns 'true' if the arc 'arc' out of state 's' in 'efst_' needs to
+// be pruned.
+template <class A>
+bool PrunedExpand<A>::PruneArc(StateId s, const A &arc) {
+ VLOG(2) << "Prune ?";
+ Weight fd = Weight::Zero();
+
+ if ((cached_source_ != SourceState(s)) ||
+ (cached_stack_id_ != current_stack_id_)) {
+ cached_source_ = SourceState(s);
+ cached_stack_id_ = current_stack_id_;
+ cached_dest_list_.clear();
+ if (cached_source_ != ifst_->Start()) {
+ for (SetIterator set_iter =
+ balance_data_->Find(current_paren_id_, cached_source_);
+ !set_iter.Done(); set_iter.Next()) {
+ StateId dest = set_iter.Element();
+ typename DestMap::const_iterator iter = dest_map_.find(dest);
+ cached_dest_list_.push_front(*iter);
+ }
+ } else {
+ // TODO(allauzen): queue discipline should prevent this never
+ // from happening; replace by a check.
+ cached_dest_list_.push_front(
+ make_pair(rfst_.Start() -1, Weight::One()));
+ }
+ }
+
+ for (typename slist<pair<StateId, Weight> >::const_iterator iter =
+ cached_dest_list_.begin();
+ iter != cached_dest_list_.end();
+ ++iter) {
+ fd = Plus(fd,
+ Times(DistanceToDest(state_table_.Tuple(arc.nextstate).state_id,
+ iter->first),
+ iter->second));
+ }
+ Relax(s, arc, fd);
+ Weight w = Times(Distance(s), Times(arc.weight, fd));
+ return less_(limit_, w);
+}
+
+// Adds start state of 'efst_' to 'ofst_', enqueues it and initializes
+// the distance data structures.
+template <class A>
+void PrunedExpand<A>::ProcStart() {
+ StateId s = efst_.Start();
+ AddStateAndEnqueue(s);
+ ofst_->SetStart(s);
+ SetSourceState(s, ifst_->Start());
+
+ current_stack_id_ = 0;
+ current_paren_id_ = -1;
+ stack_length_.push_back(0);
+ dest_map_[rfst_.Start() - 1] = Weight::One(); // not needed
+
+ cached_source_ = ifst_->Start();
+ cached_stack_id_ = 0;
+ cached_dest_list_.push_front(
+ make_pair(rfst_.Start() -1, Weight::One()));
+
+ PdtStateTuple<StateId, StackId> tuple(rfst_.Start() - 1, 0);
+ SetFinalDistance(state_table_.FindState(tuple), Weight::One());
+ SetDistance(s, Weight::One());
+ SetFinalDistance(s, DistanceToDest(ifst_->Start(), rfst_.Start() - 1));
+ VLOG(2) << DistanceToDest(ifst_->Start(), rfst_.Start() - 1);
+}
+
+// Makes 's' final in 'ofst_' if shortest accepting path ending in 's'
+// is below threshold.
+template <class A>
+void PrunedExpand<A>::ProcFinal(StateId s) {
+ Weight final = efst_.Final(s);
+ if ((final == Weight::Zero()) || less_(limit_, Times(Distance(s), final)))
+ return;
+ ofst_->SetFinal(s, final);
+}
+
+// Returns true when arc (or meta-arc) 'arc' out of 's' in 'efst_' is
+// below the threshold. When 'add_arc' is true, 'arc' is added to 'ofst_'.
+template <class A>
+bool PrunedExpand<A>::ProcNonParen(StateId s, const A &arc, bool add_arc) {
+ VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate
+ << ", " << arc.ilabel << ":" << arc.olabel << " / " << arc.weight
+ << ", add_arc = " << (add_arc ? "true" : "false");
+ if (PruneArc(s, arc)) return false;
+ if(add_arc) ofst_->AddArc(s, arc);
+ AddStateAndEnqueue(arc.nextstate);
+ return true;
+}
+
+// Processes an open paren arc 'arc' out of state 's' in 'ofst_'.
+// When 'arc' is labeled with an open paren,
+// 1. considers each (shortest) balanced path starting in 's' by
+// taking 'arc' and ending by a close paren balancing the open
+// paren of 'arc' as a meta-arc, processes and prunes each meta-arc
+// as a non-paren arc, inserting its destination to the queue;
+// 2. if at least one of these meta-arcs has not been pruned,
+// adds the destination of 'arc' to 'ofst_' as a new source state
+// for the stack id 'nsi' and inserts it in the queue.
+template <class A>
+bool PrunedExpand<A>::ProcOpenParen(StateId s, const A &arc, StackId si,
+ StackId nsi) {
+ // Update the stack lenght when needed: |nsi| = |si| + 1.
+ while (stack_length_.size() <= nsi) stack_length_.push_back(-1);
+ if (stack_length_[nsi] == -1)
+ stack_length_[nsi] = stack_length_[si] + 1;
+
+ StateId ns = arc.nextstate;
+ VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id
+ << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")";
+ bool proc_arc = false;
+ Weight fd = Weight::Zero();
+ ssize_t paren_id = stack_.ParenId(arc.ilabel);
+ slist<StateId> sources;
+ for (SetIterator set_iter =
+ balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id);
+ !set_iter.Done(); set_iter.Next()) {
+ sources.push_front(set_iter.Element());
+ }
+ for (typename slist<StateId>::const_iterator sources_iter = sources.begin();
+ sources_iter != sources.end();
+ ++ sources_iter) {
+ StateId source = *sources_iter;
+ VLOG(2) << "Close paren source: " << source;
+ ParenState<Arc> paren_state(paren_id, source);
+ for (typename ParenMultimap::const_iterator iter =
+ close_paren_multimap_.find(paren_state);
+ iter != close_paren_multimap_.end() && paren_state == iter->first;
+ ++iter) {
+ Arc meta_arc = iter->second;
+ PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si);
+ meta_arc.nextstate = state_table_.FindState(tuple);
+ VLOG(2) << state_table_.Tuple(ns).state_id << ", " << source;
+ VLOG(2) << "Meta arc weight = " << arc.weight << " Times "
+ << DistanceToDest(state_table_.Tuple(ns).state_id, source)
+ << " Times " << meta_arc.weight;
+ meta_arc.weight = Times(
+ arc.weight,
+ Times(DistanceToDest(state_table_.Tuple(ns).state_id, source),
+ meta_arc.weight));
+ proc_arc |= ProcNonParen(s, meta_arc, false);
+ fd = Plus(fd, Times(
+ Times(
+ DistanceToDest(state_table_.Tuple(ns).state_id, source),
+ iter->second.weight),
+ FinalDistance(meta_arc.nextstate)));
+ }
+ }
+ if (proc_arc) {
+ VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate;
+ ofst_->AddArc(
+ s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
+ AddStateAndEnqueue(arc.nextstate);
+ Weight nd = Times(Distance(s), arc.weight);
+ if(less_(nd, Distance(arc.nextstate)))
+ SetDistance(arc.nextstate, nd);
+ // FinalDistance not necessary for source state since pruning
+ // decided using the meta-arcs above. But this is a problem with
+ // A*, hence:
+ if (less_(fd, FinalDistance(arc.nextstate)))
+ SetFinalDistance(arc.nextstate, fd);
+ SetFlags(arc.nextstate, kSourceState, kSourceState);
+ }
+ return proc_arc;
+}
+
+// Checks that shortest path through close paren arc in 'efst_' is
+// below threshold, if so adds it to 'ofst_'.
+template <class A>
+bool PrunedExpand<A>::ProcCloseParen(StateId s, const A &arc) {
+ Weight w = Times(Distance(s),
+ Times(arc.weight, FinalDistance(arc.nextstate)));
+ if (less_(limit_, w))
+ return false;
+ ofst_->AddArc(
+ s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
+ return true;
+}
+
+// When 's' in 'ofst_' is a source state for stack id 'si', identifies
+// all the corresponding possible destination states, that is, all the
+// states in 'ifst_' that have an outgoing close paren arc balancing
+// the incoming open paren taken to get to 's', and for each such
+// state 't', computes the shortest distance from (t, si) to the final
+// states in 'ofst_'. Stores this information in 'dest_map_'.
+template <class A>
+void PrunedExpand<A>::ProcDestStates(StateId s, StackId si) {
+ if (!(Flags(s) & kSourceState)) return;
+ if (si != current_stack_id_) {
+ dest_map_.clear();
+ current_stack_id_ = si;
+ current_paren_id_ = stack_.Top(current_stack_id_);
+ VLOG(2) << "StackID " << si << " dequeued for first time";
+ }
+ // TODO(allauzen): clean up source state business; rename current function to
+ // ProcSourceState.
+ SetSourceState(s, state_table_.Tuple(s).state_id);
+
+ ssize_t paren_id = stack_.Top(si);
+ for (SetIterator set_iter =
+ balance_data_->Find(paren_id, state_table_.Tuple(s).state_id);
+ !set_iter.Done(); set_iter.Next()) {
+ StateId dest_state = set_iter.Element();
+ if (dest_map_.find(dest_state) != dest_map_.end())
+ continue;
+ Weight dest_weight = Weight::Zero();
+ ParenState<Arc> paren_state(paren_id, dest_state);
+ for (typename ParenMultimap::const_iterator iter =
+ close_paren_multimap_.find(paren_state);
+ iter != close_paren_multimap_.end() && paren_state == iter->first;
+ ++iter) {
+ const Arc &arc = iter->second;
+ PdtStateTuple<StateId, StackId> tuple(arc.nextstate, stack_.Pop(si));
+ dest_weight = Plus(dest_weight,
+ Times(arc.weight,
+ FinalDistance(state_table_.FindState(tuple))));
+ }
+ dest_map_[dest_state] = dest_weight;
+ VLOG(2) << "State " << dest_state << " is a dest state for stack id "
+ << si << " with weight " << dest_weight;
+ }
+}
+
+// Expands and prunes with weight threshold 'threshold' the input PDT.
+// Writes the result in 'ofst'.
+template <class A>
+void PrunedExpand<A>::Expand(
+ MutableFst<A> *ofst, const typename A::Weight &threshold) {
+ ofst_ = ofst;
+ ofst_->DeleteStates();
+ ofst_->SetInputSymbols(ifst_->InputSymbols());
+ ofst_->SetOutputSymbols(ifst_->OutputSymbols());
+
+ limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold);
+ flags_.clear();
+
+ ProcStart();
+
+ while (!queue_.Empty()) {
+ StateId s = queue_.Head();
+ queue_.Dequeue();
+ SetFlags(s, kExpanded, kExpanded | kEnqueued);
+ VLOG(2) << s << " dequeued!";
+
+ ProcFinal(s);
+ StackId stack_id = state_table_.Tuple(s).stack_id;
+ ProcDestStates(s, stack_id);
+
+ for (ArcIterator<ExpandFst<Arc> > aiter(efst_, s);
+ !aiter.Done();
+ aiter.Next()) {
+ Arc arc = aiter.Value();
+ StackId nextstack_id = state_table_.Tuple(arc.nextstate).stack_id;
+ if (stack_id == nextstack_id)
+ ProcNonParen(s, arc, true);
+ else if (stack_id == stack_.Pop(nextstack_id))
+ ProcOpenParen(s, arc, stack_id, nextstack_id);
+ else
+ ProcCloseParen(s, arc);
+ }
+ VLOG(2) << "d[" << s << "] = " << Distance(s)
+ << ", fd[" << s << "] = " << FinalDistance(s);
+ }
+}
+
+//
+// Expand() Functions
+//
+
+template <class Arc>
+struct ExpandOptions {
+ bool connect;
+ bool keep_parentheses;
+ typename Arc::Weight weight_threshold;
+
+ ExpandOptions(bool c = true, bool k = false,
+ typename Arc::Weight w = Arc::Weight::Zero())
+ : connect(c), keep_parentheses(k), weight_threshold(w) {}
+};
+
+// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
+// This version writes the expanded PDT result to a MutableFst.
+// In the PDT, some transitions are labeled with open or close
+// parentheses. To be interpreted as a PDT, the parens must balance on
+// a path. The open-close parenthesis label pairs are passed in
+// 'parens'. The expansion enforces the parenthesis constraints. The
+// PDT must be expandable as an FST.
+template <class Arc>
+void Expand(
+ const Fst<Arc> &ifst,
+ const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
+ MutableFst<Arc> *ofst,
+ const ExpandOptions<Arc> &opts) {
+ typedef typename Arc::Label Label;
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Weight Weight;
+ typedef typename ExpandFst<Arc>::StackId StackId;
+
+ ExpandFstOptions<Arc> eopts;
+ eopts.gc_limit = 0;
+ if (opts.weight_threshold == Weight::Zero()) {
+ eopts.keep_parentheses = opts.keep_parentheses;
+ *ofst = ExpandFst<Arc>(ifst, parens, eopts);
+ } else {
+ PrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses);
+ pruned_expand.Expand(ofst, opts.weight_threshold);
+ }
+
+ if (opts.connect)
+ Connect(ofst);
+}
+
+// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
+// This version writes the expanded PDT result to a MutableFst.
+// In the PDT, some transitions are labeled with open or close
+// parentheses. To be interpreted as a PDT, the parens must balance on
+// a path. The open-close parenthesis label pairs are passed in
+// 'parens'. The expansion enforces the parenthesis constraints. The
+// PDT must be expandable as an FST.
+template<class Arc>
+void Expand(
+ const Fst<Arc> &ifst,
+ const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
+ MutableFst<Arc> *ofst,
+ bool connect = true, bool keep_parentheses = false) {
+ Expand(ifst, parens, ofst, ExpandOptions<Arc>(connect, keep_parentheses));
+}
+
+} // namespace fst
+
+#endif // FST_EXTENSIONS_PDT_EXPAND_H__