diff options
author | Ian Hodson <idh@google.com> | 2012-05-30 21:27:06 +0100 |
---|---|---|
committer | Ian Hodson <idh@google.com> | 2012-05-30 22:47:36 +0100 |
commit | f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2 (patch) | |
tree | b131ed907f9b2d5af09c0983b651e9e69bc6aab9 /src/include/fst/extensions/pdt/expand.h | |
parent | a92766f0a6ba4fac46cd6fd3856ef20c3b204f0d (diff) | |
download | openfst-f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2.tar.gz |
Add openfst to external, as used by GoogleTTSandroid-sdk-support_r11android-cts-4.2_r2android-cts-4.2_r1android-cts-4.1_r4android-cts-4.1_r2android-cts-4.1_r1android-4.2_r1android-4.2.2_r1.2android-4.2.2_r1.1android-4.2.2_r1android-4.2.1_r1.2android-4.2.1_r1.1android-4.2.1_r1android-4.1.2_r2.1android-4.1.2_r2android-4.1.2_r1android-4.1.1_r6.1android-4.1.1_r6android-4.1.1_r5android-4.1.1_r4android-4.1.1_r3android-4.1.1_r2android-4.1.1_r1.1android-4.1.1_r1tools_r22tools_r21jb-releasejb-mr1.1-releasejb-mr1.1-dev-plus-aospjb-mr1.1-devjb-mr1-releasejb-mr1-dev-plus-aospjb-mr1-devjb-mr0-releasejb-dev
Moved from GoogleTTS
Change-Id: I6bc6bdadaa53bd0f810b88443339f6d899502cc8
Diffstat (limited to 'src/include/fst/extensions/pdt/expand.h')
-rw-r--r-- | src/include/fst/extensions/pdt/expand.h | 975 |
1 files changed, 975 insertions, 0 deletions
diff --git a/src/include/fst/extensions/pdt/expand.h b/src/include/fst/extensions/pdt/expand.h new file mode 100644 index 0000000..f464403 --- /dev/null +++ b/src/include/fst/extensions/pdt/expand.h @@ -0,0 +1,975 @@ +// expand.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Expand a PDT to an FST. + +#ifndef FST_EXTENSIONS_PDT_EXPAND_H__ +#define FST_EXTENSIONS_PDT_EXPAND_H__ + +#include <vector> +using std::vector; + +#include <fst/extensions/pdt/pdt.h> +#include <fst/extensions/pdt/paren.h> +#include <fst/extensions/pdt/shortest-path.h> +#include <fst/extensions/pdt/reverse.h> +#include <fst/cache.h> +#include <fst/mutable-fst.h> +#include <fst/queue.h> +#include <fst/state-table.h> +#include <fst/test-properties.h> + +namespace fst { + +template <class Arc> +struct ExpandFstOptions : public CacheOptions { + bool keep_parentheses; + PdtStack<typename Arc::StateId, typename Arc::Label> *stack; + PdtStateTable<typename Arc::StateId, typename Arc::StateId> *state_table; + + ExpandFstOptions( + const CacheOptions &opts = CacheOptions(), + bool kp = false, + PdtStack<typename Arc::StateId, typename Arc::Label> *s = 0, + PdtStateTable<typename Arc::StateId, typename Arc::StateId> *st = 0) + : CacheOptions(opts), keep_parentheses(kp), stack(s), state_table(st) {} +}; + +// Properties for an expanded PDT. +inline uint64 ExpandProperties(uint64 inprops) { + return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted); +} + + +// Implementation class for ExpandFst +template <class A> +class ExpandFstImpl + : public CacheImpl<A> { + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::Properties; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + + using CacheBaseImpl< CacheState<A> >::PushArc; + using CacheBaseImpl< CacheState<A> >::HasArcs; + using CacheBaseImpl< CacheState<A> >::HasFinal; + using CacheBaseImpl< CacheState<A> >::HasStart; + using CacheBaseImpl< CacheState<A> >::SetArcs; + using CacheBaseImpl< CacheState<A> >::SetFinal; + using CacheBaseImpl< CacheState<A> >::SetStart; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef StateId StackId; + typedef PdtStateTuple<StateId, StackId> StateTuple; + + ExpandFstImpl(const Fst<A> &fst, + const vector<pair<typename Arc::Label, + typename Arc::Label> > &parens, + const ExpandFstOptions<A> &opts) + : CacheImpl<A>(opts), fst_(fst.Copy()), + stack_(opts.stack ? opts.stack: new PdtStack<StateId, Label>(parens)), + state_table_(opts.state_table ? opts.state_table : + new PdtStateTable<StateId, StackId>()), + own_stack_(opts.stack == 0), own_state_table_(opts.state_table == 0), + keep_parentheses_(opts.keep_parentheses) { + SetType("expand"); + + uint64 props = fst.Properties(kFstProperties, false); + SetProperties(ExpandProperties(props), kCopyProperties); + + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + ExpandFstImpl(const ExpandFstImpl &impl) + : CacheImpl<A>(impl), + fst_(impl.fst_->Copy(true)), + stack_(new PdtStack<StateId, Label>(*impl.stack_)), + state_table_(new PdtStateTable<StateId, StackId>()), + own_stack_(true), own_state_table_(true), + keep_parentheses_(impl.keep_parentheses_) { + SetType("expand"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~ExpandFstImpl() { + delete fst_; + if (own_stack_) + delete stack_; + if (own_state_table_) + delete state_table_; + } + + StateId Start() { + if (!HasStart()) { + StateId s = fst_->Start(); + if (s == kNoStateId) + return kNoStateId; + StateTuple tuple(s, 0); + StateId start = state_table_->FindState(tuple); + SetStart(start); + } + return CacheImpl<A>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + const StateTuple &tuple = state_table_->Tuple(s); + Weight w = fst_->Final(tuple.state_id); + if (w != Weight::Zero() && tuple.stack_id == 0) + SetFinal(s, w); + else + SetFinal(s, Weight::Zero()); + } + return CacheImpl<A>::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) { + ExpandState(s); + } + return CacheImpl<A>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) + ExpandState(s); + return CacheImpl<A>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) + ExpandState(s); + return CacheImpl<A>::NumOutputEpsilons(s); + } + + void InitArcIterator(StateId s, ArcIteratorData<A> *data) { + if (!HasArcs(s)) + ExpandState(s); + CacheImpl<A>::InitArcIterator(s, data); + } + + // Computes the outgoing transitions from a state, creating new destination + // states as needed. + void ExpandState(StateId s) { + StateTuple tuple = state_table_->Tuple(s); + for (ArcIterator< Fst<A> > aiter(*fst_, tuple.state_id); + !aiter.Done(); aiter.Next()) { + Arc arc = aiter.Value(); + StackId stack_id = stack_->Find(tuple.stack_id, arc.ilabel); + if (stack_id == -1) { + // Non-matching close parenthesis + continue; + } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) { + // Stack push/pop + arc.ilabel = arc.olabel = 0; + } + + StateTuple ntuple(arc.nextstate, stack_id); + arc.nextstate = state_table_->FindState(ntuple); + PushArc(s, arc); + } + SetArcs(s); + } + + const PdtStack<StackId, Label> &GetStack() const { return *stack_; } + + const PdtStateTable<StateId, StackId> &GetStateTable() const { + return *state_table_; + } + + private: + const Fst<A> *fst_; + + PdtStack<StackId, Label> *stack_; + PdtStateTable<StateId, StackId> *state_table_; + bool own_stack_; + bool own_state_table_; + bool keep_parentheses_; + + void operator=(const ExpandFstImpl<A> &); // disallow +}; + +// Expands a pushdown transducer (PDT) encoded as an FST into an FST. +// This version is a delayed Fst. In the PDT, some transitions are +// labeled with open or close parentheses. To be interpreted as a PDT, +// the parens must balance on a path. The open-close parenthesis label +// pairs are passed in 'parens'. The expansion enforces the +// parenthesis constraints. The PDT must be expandable as an FST. +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A> +class ExpandFst : public ImplToFst< ExpandFstImpl<A> > { + public: + friend class ArcIterator< ExpandFst<A> >; + friend class StateIterator< ExpandFst<A> >; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef StateId StackId; + typedef CacheState<A> State; + typedef ExpandFstImpl<A> Impl; + + ExpandFst(const Fst<A> &fst, + const vector<pair<typename Arc::Label, + typename Arc::Label> > &parens) + : ImplToFst<Impl>(new Impl(fst, parens, ExpandFstOptions<A>())) {} + + ExpandFst(const Fst<A> &fst, + const vector<pair<typename Arc::Label, + typename Arc::Label> > &parens, + const ExpandFstOptions<A> &opts) + : ImplToFst<Impl>(new Impl(fst, parens, opts)) {} + + // See Fst<>::Copy() for doc. + ExpandFst(const ExpandFst<A> &fst, bool safe = false) + : ImplToFst<Impl>(fst, safe) {} + + // Get a copy of this ExpandFst. See Fst<>::Copy() for further doc. + virtual ExpandFst<A> *Copy(bool safe = false) const { + return new ExpandFst<A>(*this, safe); + } + + virtual inline void InitStateIterator(StateIteratorData<A> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + const PdtStack<StackId, Label> &GetStack() const { + return GetImpl()->GetStack(); + } + + const PdtStateTable<StateId, StackId> &GetStateTable() const { + return GetImpl()->GetStateTable(); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const ExpandFst<A> &fst); // Disallow +}; + + +// Specialization for ExpandFst. +template<class A> +class StateIterator< ExpandFst<A> > + : public CacheStateIterator< ExpandFst<A> > { + public: + explicit StateIterator(const ExpandFst<A> &fst) + : CacheStateIterator< ExpandFst<A> >(fst, fst.GetImpl()) {} +}; + + +// Specialization for ExpandFst. +template <class A> +class ArcIterator< ExpandFst<A> > + : public CacheArcIterator< ExpandFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const ExpandFst<A> &fst, StateId s) + : CacheArcIterator< ExpandFst<A> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->ExpandState(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + + +template <class A> inline +void ExpandFst<A>::InitStateIterator(StateIteratorData<A> *data) const +{ + data->base = new StateIterator< ExpandFst<A> >(*this); +} + +// +// PrunedExpand Class +// + +// Prunes the delayed expansion of a pushdown transducer (PDT) encoded +// as an FST into an FST. In the PDT, some transitions are labeled +// with open or close parentheses. To be interpreted as a PDT, the +// parens must balance on a path. The open-close parenthesis label +// pairs are passed in 'parens'. The expansion enforces the +// parenthesis constraints. +// +// The algorithm works by visiting the delayed ExpandFst using a +// shortest-stack first queue discipline and relies on the +// shortest-distance information computed using a reverse +// shortest-path call to perform the pruning. +// +// The algorithm maintains the same state ordering between the ExpandFst +// being visited 'efst_' and the result of pruning written into the +// MutableFst 'ofst_' to improve readability of the code. +// +template <class A> +class PrunedExpand { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + typedef StateId StackId; + typedef PdtStack<StackId, Label> Stack; + typedef PdtStateTable<StateId, StackId> StateTable; + typedef typename PdtBalanceData<Arc>::SetIterator SetIterator; + + // Constructor taking as input a PDT specified by 'ifst' and 'parens'. + // 'keep_parentheses' specifies whether parentheses are replaced by + // epsilons or not during the expansion. 'opts' is the cache options + // used to instantiate the underlying ExpandFst. + PrunedExpand(const Fst<A> &ifst, + const vector<pair<Label, Label> > &parens, + bool keep_parentheses = false, + const CacheOptions &opts = CacheOptions()) + : ifst_(ifst.Copy()), + keep_parentheses_(keep_parentheses), + stack_(parens), + efst_(ifst, parens, + ExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)), + queue_(state_table_, stack_, stack_length_, distance_, fdistance_) { + Reverse(*ifst_, parens, &rfst_); + VectorFst<Arc> path; + reverse_shortest_path_ = new SP( + rfst_, parens, + PdtShortestPathOptions<A, FifoQueue<StateId> >(true, false)); + reverse_shortest_path_->ShortestPath(&path); + balance_data_ = reverse_shortest_path_->GetBalanceData()->Reverse( + rfst_.NumStates(), 10, -1); + + InitCloseParenMultimap(parens); + } + + ~PrunedExpand() { + delete ifst_; + delete reverse_shortest_path_; + delete balance_data_; + } + + // Expands and prunes with weight threshold 'threshold' the input PDT. + // Writes the result in 'ofst'. + void Expand(MutableFst<A> *ofst, const Weight &threshold); + + private: + static const uint8 kEnqueued; + static const uint8 kExpanded; + static const uint8 kSourceState; + + // Comparison functor used by the queue: + // 1. states corresponding to shortest stack first, + // 2. among stacks of the same length, reverse lexicographic order is used, + // 3. among states with the same stack, shortest-first order is used. + class StackCompare { + public: + StackCompare(const StateTable &st, + const Stack &s, const vector<StackId> &sl, + const vector<Weight> &d, const vector<Weight> &fd) + : state_table_(st), stack_(s), stack_length_(sl), + distance_(d), fdistance_(fd) {} + + bool operator()(StateId s1, StateId s2) const { + StackId si1 = state_table_.Tuple(s1).stack_id; + StackId si2 = state_table_.Tuple(s2).stack_id; + if (stack_length_[si1] < stack_length_[si2]) + return true; + if (stack_length_[si1] > stack_length_[si2]) + return false; + // If stack id equal, use A* + if (si1 == si2) { + Weight w1 = (s1 < distance_.size()) && (s1 < fdistance_.size()) ? + Times(distance_[s1], fdistance_[s1]) : Weight::Zero(); + Weight w2 = (s2 < distance_.size()) && (s2 < fdistance_.size()) ? + Times(distance_[s2], fdistance_[s2]) : Weight::Zero(); + return less_(w1, w2); + } + // If lenghts are equal, use reverse lexico. + for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) { + if (stack_.Top(si1) < stack_.Top(si2)) return true; + if (stack_.Top(si1) > stack_.Top(si2)) return false; + } + return false; + } + + private: + const StateTable &state_table_; + const Stack &stack_; + const vector<StackId> &stack_length_; + const vector<Weight> &distance_; + const vector<Weight> &fdistance_; + NaturalLess<Weight> less_; + }; + + class ShortestStackFirstQueue + : public ShortestFirstQueue<StateId, StackCompare> { + public: + ShortestStackFirstQueue( + const PdtStateTable<StateId, StackId> &st, + const Stack &s, + const vector<StackId> &sl, + const vector<Weight> &d, const vector<Weight> &fd) + : ShortestFirstQueue<StateId, StackCompare>( + StackCompare(st, s, sl, d, fd)) {} + }; + + + void InitCloseParenMultimap(const vector<pair<Label, Label> > &parens); + Weight DistanceToDest(StateId state, StateId source) const; + uint8 Flags(StateId s) const; + void SetFlags(StateId s, uint8 flags, uint8 mask); + Weight Distance(StateId s) const; + void SetDistance(StateId s, Weight w); + Weight FinalDistance(StateId s) const; + void SetFinalDistance(StateId s, Weight w); + StateId SourceState(StateId s) const; + void SetSourceState(StateId s, StateId p); + void AddStateAndEnqueue(StateId s); + void Relax(StateId s, const A &arc, Weight w); + bool PruneArc(StateId s, const A &arc); + void ProcStart(); + void ProcFinal(StateId s); + bool ProcNonParen(StateId s, const A &arc, bool add_arc); + bool ProcOpenParen(StateId s, const A &arc, StackId si, StackId nsi); + bool ProcCloseParen(StateId s, const A &arc); + void ProcDestStates(StateId s, StackId si); + + Fst<A> *ifst_; // Input PDT + VectorFst<Arc> rfst_; // Reversed PDT + bool keep_parentheses_; // Keep parentheses in ofst? + StateTable state_table_; // State table for efst_ + Stack stack_; // Stack trie + ExpandFst<Arc> efst_; // Expanded PDT + vector<StackId> stack_length_; // Length of stack for given stack id + vector<Weight> distance_; // Distance from initial state in efst_/ofst + vector<Weight> fdistance_; // Distance to final states in efst_/ofst + ShortestStackFirstQueue queue_; // Queue used to visit efst_ + vector<uint8> flags_; // Status flags for states in efst_/ofst + vector<StateId> sources_; // PDT source state for each expanded state + + typedef PdtShortestPath<Arc, FifoQueue<StateId> > SP; + typedef typename SP::CloseParenMultimap ParenMultimap; + SP *reverse_shortest_path_; // Shortest path for rfst_ + PdtBalanceData<Arc> *balance_data_; // Not owned by shortest_path_ + ParenMultimap close_paren_multimap_; // Maps open paren arcs to + // balancing close paren arcs. + + MutableFst<Arc> *ofst_; // Output fst + Weight limit_; // Weight limit + + typedef unordered_map<StateId, Weight> DestMap; + DestMap dest_map_; + StackId current_stack_id_; + // 'current_stack_id_' is the stack id of the states currently at the top + // of queue, i.e., the states currently being popped and processed. + // 'dest_map_' maps a state 's' in 'ifst_' that is the source + // of a close parentheses matching the top of 'current_stack_id_; to + // the shortest-distance from '(s, current_stack_id_)' to the final + // states in 'efst_'. + ssize_t current_paren_id_; // Paren id at top of current stack + ssize_t cached_stack_id_; + StateId cached_source_; + slist<pair<StateId, Weight> > cached_dest_list_; + // 'cached_dest_list_' contains the set of pair of destination + // states and weight to final states for source state + // 'cached_source_' and paren id 'cached_paren_id': the set of + // source state of a close parenthesis with paren id + // 'cached_paren_id' balancing an incoming open parenthesis with + // paren id 'cached_paren_id' in state 'cached_source_'. + + NaturalLess<Weight> less_; +}; + +template <class A> const uint8 PrunedExpand<A>::kEnqueued = 0x01; +template <class A> const uint8 PrunedExpand<A>::kExpanded = 0x02; +template <class A> const uint8 PrunedExpand<A>::kSourceState = 0x04; + + +// Initializes close paren multimap, mapping pairs (s,paren_id) to +// all the arcs out of s labeled with close parenthese for paren_id. +template <class A> +void PrunedExpand<A>::InitCloseParenMultimap( + const vector<pair<Label, Label> > &parens) { + unordered_map<Label, Label> paren_id_map; + for (Label i = 0; i < parens.size(); ++i) { + const pair<Label, Label> &p = parens[i]; + paren_id_map[p.first] = i; + paren_id_map[p.second] = i; + } + + for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + for (ArcIterator<Fst<Arc> > aiter(*ifst_, s); + !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + typename unordered_map<Label, Label>::const_iterator pit + = paren_id_map.find(arc.ilabel); + if (pit == paren_id_map.end()) continue; + if (arc.ilabel == parens[pit->second].second) { // Close paren + ParenState<Arc> paren_state(pit->second, s); + close_paren_multimap_.insert(make_pair(paren_state, arc)); + } + } + } +} + + +// Returns the weight of the shortest balanced path from 'source' to 'dest' +// in 'ifst_', 'dest' must be the source state of a close paren arc. +template <class A> +typename A::Weight PrunedExpand<A>::DistanceToDest(StateId source, + StateId dest) const { + typename SP::SearchState s(source + 1, dest + 1); + VLOG(2) << "D(" << source << ", " << dest << ") =" + << reverse_shortest_path_->GetShortestPathData().Distance(s); + return reverse_shortest_path_->GetShortestPathData().Distance(s); +} + +// Returns the flags for state 's' in 'ofst_'. +template <class A> +uint8 PrunedExpand<A>::Flags(StateId s) const { + return s < flags_.size() ? flags_[s] : 0; +} + +// Modifies the flags for state 's' in 'ofst_'. +template <class A> +void PrunedExpand<A>::SetFlags(StateId s, uint8 flags, uint8 mask) { + while (flags_.size() <= s) flags_.push_back(0); + flags_[s] &= ~mask; + flags_[s] |= flags & mask; +} + + +// Returns the shortest distance from the initial state to 's' in 'ofst_'. +template <class A> +typename A::Weight PrunedExpand<A>::Distance(StateId s) const { + return s < distance_.size() ? distance_[s] : Weight::Zero(); +} + +// Sets the shortest distance from the initial state to 's' in 'ofst_' to 'w'. +template <class A> +void PrunedExpand<A>::SetDistance(StateId s, Weight w) { + while (distance_.size() <= s ) distance_.push_back(Weight::Zero()); + distance_[s] = w; +} + + +// Returns the shortest distance from 's' to the final states in 'ofst_'. +template <class A> +typename A::Weight PrunedExpand<A>::FinalDistance(StateId s) const { + return s < fdistance_.size() ? fdistance_[s] : Weight::Zero(); +} + +// Sets the shortest distance from 's' to the final states in 'ofst_' to 'w'. +template <class A> +void PrunedExpand<A>::SetFinalDistance(StateId s, Weight w) { + while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero()); + fdistance_[s] = w; +} + +// Returns the PDT "source" state of state 's' in 'ofst_'. +template <class A> +typename A::StateId PrunedExpand<A>::SourceState(StateId s) const { + return s < sources_.size() ? sources_[s] : kNoStateId; +} + +// Sets the PDT "source" state of state 's' in 'ofst_' to state 'p' in 'ifst_'. +template <class A> +void PrunedExpand<A>::SetSourceState(StateId s, StateId p) { + while (sources_.size() <= s) sources_.push_back(kNoStateId); + sources_[s] = p; +} + +// Adds state 's' of 'efst_' to 'ofst_' and inserts it in the queue, +// modifying the flags for 's' accordingly. +template <class A> +void PrunedExpand<A>::AddStateAndEnqueue(StateId s) { + if (!(Flags(s) & (kEnqueued | kExpanded))) { + while (ofst_->NumStates() <= s) ofst_->AddState(); + queue_.Enqueue(s); + SetFlags(s, kEnqueued, kEnqueued); + } else if (Flags(s) & kEnqueued) { + queue_.Update(s); + } + // TODO(allauzen): Check everything is fine when kExpanded? +} + +// Relaxes arc 'arc' out of state 's' in 'ofst_': +// * if the distance to 's' times the weight of 'arc' is smaller than +// the currently stored distance for 'arc.nextstate', +// updates 'Distance(arc.nextstate)' with new estimate; +// * if 'fd' is less than the currently stored distance from 'arc.nextstate' +// to the final state, updates with new estimate. +template <class A> +void PrunedExpand<A>::Relax(StateId s, const A &arc, Weight fd) { + Weight nd = Times(Distance(s), arc.weight); + if (less_(nd, Distance(arc.nextstate))) { + SetDistance(arc.nextstate, nd); + SetSourceState(arc.nextstate, SourceState(s)); + } + if (less_(fd, FinalDistance(arc.nextstate))) + SetFinalDistance(arc.nextstate, fd); + VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to " + << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate) + << ", nd = " << nd; +} + +// Returns 'true' if the arc 'arc' out of state 's' in 'efst_' needs to +// be pruned. +template <class A> +bool PrunedExpand<A>::PruneArc(StateId s, const A &arc) { + VLOG(2) << "Prune ?"; + Weight fd = Weight::Zero(); + + if ((cached_source_ != SourceState(s)) || + (cached_stack_id_ != current_stack_id_)) { + cached_source_ = SourceState(s); + cached_stack_id_ = current_stack_id_; + cached_dest_list_.clear(); + if (cached_source_ != ifst_->Start()) { + for (SetIterator set_iter = + balance_data_->Find(current_paren_id_, cached_source_); + !set_iter.Done(); set_iter.Next()) { + StateId dest = set_iter.Element(); + typename DestMap::const_iterator iter = dest_map_.find(dest); + cached_dest_list_.push_front(*iter); + } + } else { + // TODO(allauzen): queue discipline should prevent this never + // from happening; replace by a check. + cached_dest_list_.push_front( + make_pair(rfst_.Start() -1, Weight::One())); + } + } + + for (typename slist<pair<StateId, Weight> >::const_iterator iter = + cached_dest_list_.begin(); + iter != cached_dest_list_.end(); + ++iter) { + fd = Plus(fd, + Times(DistanceToDest(state_table_.Tuple(arc.nextstate).state_id, + iter->first), + iter->second)); + } + Relax(s, arc, fd); + Weight w = Times(Distance(s), Times(arc.weight, fd)); + return less_(limit_, w); +} + +// Adds start state of 'efst_' to 'ofst_', enqueues it and initializes +// the distance data structures. +template <class A> +void PrunedExpand<A>::ProcStart() { + StateId s = efst_.Start(); + AddStateAndEnqueue(s); + ofst_->SetStart(s); + SetSourceState(s, ifst_->Start()); + + current_stack_id_ = 0; + current_paren_id_ = -1; + stack_length_.push_back(0); + dest_map_[rfst_.Start() - 1] = Weight::One(); // not needed + + cached_source_ = ifst_->Start(); + cached_stack_id_ = 0; + cached_dest_list_.push_front( + make_pair(rfst_.Start() -1, Weight::One())); + + PdtStateTuple<StateId, StackId> tuple(rfst_.Start() - 1, 0); + SetFinalDistance(state_table_.FindState(tuple), Weight::One()); + SetDistance(s, Weight::One()); + SetFinalDistance(s, DistanceToDest(ifst_->Start(), rfst_.Start() - 1)); + VLOG(2) << DistanceToDest(ifst_->Start(), rfst_.Start() - 1); +} + +// Makes 's' final in 'ofst_' if shortest accepting path ending in 's' +// is below threshold. +template <class A> +void PrunedExpand<A>::ProcFinal(StateId s) { + Weight final = efst_.Final(s); + if ((final == Weight::Zero()) || less_(limit_, Times(Distance(s), final))) + return; + ofst_->SetFinal(s, final); +} + +// Returns true when arc (or meta-arc) 'arc' out of 's' in 'efst_' is +// below the threshold. When 'add_arc' is true, 'arc' is added to 'ofst_'. +template <class A> +bool PrunedExpand<A>::ProcNonParen(StateId s, const A &arc, bool add_arc) { + VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate + << ", " << arc.ilabel << ":" << arc.olabel << " / " << arc.weight + << ", add_arc = " << (add_arc ? "true" : "false"); + if (PruneArc(s, arc)) return false; + if(add_arc) ofst_->AddArc(s, arc); + AddStateAndEnqueue(arc.nextstate); + return true; +} + +// Processes an open paren arc 'arc' out of state 's' in 'ofst_'. +// When 'arc' is labeled with an open paren, +// 1. considers each (shortest) balanced path starting in 's' by +// taking 'arc' and ending by a close paren balancing the open +// paren of 'arc' as a meta-arc, processes and prunes each meta-arc +// as a non-paren arc, inserting its destination to the queue; +// 2. if at least one of these meta-arcs has not been pruned, +// adds the destination of 'arc' to 'ofst_' as a new source state +// for the stack id 'nsi' and inserts it in the queue. +template <class A> +bool PrunedExpand<A>::ProcOpenParen(StateId s, const A &arc, StackId si, + StackId nsi) { + // Update the stack lenght when needed: |nsi| = |si| + 1. + while (stack_length_.size() <= nsi) stack_length_.push_back(-1); + if (stack_length_[nsi] == -1) + stack_length_[nsi] = stack_length_[si] + 1; + + StateId ns = arc.nextstate; + VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id + << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")"; + bool proc_arc = false; + Weight fd = Weight::Zero(); + ssize_t paren_id = stack_.ParenId(arc.ilabel); + slist<StateId> sources; + for (SetIterator set_iter = + balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id); + !set_iter.Done(); set_iter.Next()) { + sources.push_front(set_iter.Element()); + } + for (typename slist<StateId>::const_iterator sources_iter = sources.begin(); + sources_iter != sources.end(); + ++ sources_iter) { + StateId source = *sources_iter; + VLOG(2) << "Close paren source: " << source; + ParenState<Arc> paren_state(paren_id, source); + for (typename ParenMultimap::const_iterator iter = + close_paren_multimap_.find(paren_state); + iter != close_paren_multimap_.end() && paren_state == iter->first; + ++iter) { + Arc meta_arc = iter->second; + PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si); + meta_arc.nextstate = state_table_.FindState(tuple); + VLOG(2) << state_table_.Tuple(ns).state_id << ", " << source; + VLOG(2) << "Meta arc weight = " << arc.weight << " Times " + << DistanceToDest(state_table_.Tuple(ns).state_id, source) + << " Times " << meta_arc.weight; + meta_arc.weight = Times( + arc.weight, + Times(DistanceToDest(state_table_.Tuple(ns).state_id, source), + meta_arc.weight)); + proc_arc |= ProcNonParen(s, meta_arc, false); + fd = Plus(fd, Times( + Times( + DistanceToDest(state_table_.Tuple(ns).state_id, source), + iter->second.weight), + FinalDistance(meta_arc.nextstate))); + } + } + if (proc_arc) { + VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate; + ofst_->AddArc( + s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate)); + AddStateAndEnqueue(arc.nextstate); + Weight nd = Times(Distance(s), arc.weight); + if(less_(nd, Distance(arc.nextstate))) + SetDistance(arc.nextstate, nd); + // FinalDistance not necessary for source state since pruning + // decided using the meta-arcs above. But this is a problem with + // A*, hence: + if (less_(fd, FinalDistance(arc.nextstate))) + SetFinalDistance(arc.nextstate, fd); + SetFlags(arc.nextstate, kSourceState, kSourceState); + } + return proc_arc; +} + +// Checks that shortest path through close paren arc in 'efst_' is +// below threshold, if so adds it to 'ofst_'. +template <class A> +bool PrunedExpand<A>::ProcCloseParen(StateId s, const A &arc) { + Weight w = Times(Distance(s), + Times(arc.weight, FinalDistance(arc.nextstate))); + if (less_(limit_, w)) + return false; + ofst_->AddArc( + s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate)); + return true; +} + +// When 's' in 'ofst_' is a source state for stack id 'si', identifies +// all the corresponding possible destination states, that is, all the +// states in 'ifst_' that have an outgoing close paren arc balancing +// the incoming open paren taken to get to 's', and for each such +// state 't', computes the shortest distance from (t, si) to the final +// states in 'ofst_'. Stores this information in 'dest_map_'. +template <class A> +void PrunedExpand<A>::ProcDestStates(StateId s, StackId si) { + if (!(Flags(s) & kSourceState)) return; + if (si != current_stack_id_) { + dest_map_.clear(); + current_stack_id_ = si; + current_paren_id_ = stack_.Top(current_stack_id_); + VLOG(2) << "StackID " << si << " dequeued for first time"; + } + // TODO(allauzen): clean up source state business; rename current function to + // ProcSourceState. + SetSourceState(s, state_table_.Tuple(s).state_id); + + ssize_t paren_id = stack_.Top(si); + for (SetIterator set_iter = + balance_data_->Find(paren_id, state_table_.Tuple(s).state_id); + !set_iter.Done(); set_iter.Next()) { + StateId dest_state = set_iter.Element(); + if (dest_map_.find(dest_state) != dest_map_.end()) + continue; + Weight dest_weight = Weight::Zero(); + ParenState<Arc> paren_state(paren_id, dest_state); + for (typename ParenMultimap::const_iterator iter = + close_paren_multimap_.find(paren_state); + iter != close_paren_multimap_.end() && paren_state == iter->first; + ++iter) { + const Arc &arc = iter->second; + PdtStateTuple<StateId, StackId> tuple(arc.nextstate, stack_.Pop(si)); + dest_weight = Plus(dest_weight, + Times(arc.weight, + FinalDistance(state_table_.FindState(tuple)))); + } + dest_map_[dest_state] = dest_weight; + VLOG(2) << "State " << dest_state << " is a dest state for stack id " + << si << " with weight " << dest_weight; + } +} + +// Expands and prunes with weight threshold 'threshold' the input PDT. +// Writes the result in 'ofst'. +template <class A> +void PrunedExpand<A>::Expand( + MutableFst<A> *ofst, const typename A::Weight &threshold) { + ofst_ = ofst; + ofst_->DeleteStates(); + ofst_->SetInputSymbols(ifst_->InputSymbols()); + ofst_->SetOutputSymbols(ifst_->OutputSymbols()); + + limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold); + flags_.clear(); + + ProcStart(); + + while (!queue_.Empty()) { + StateId s = queue_.Head(); + queue_.Dequeue(); + SetFlags(s, kExpanded, kExpanded | kEnqueued); + VLOG(2) << s << " dequeued!"; + + ProcFinal(s); + StackId stack_id = state_table_.Tuple(s).stack_id; + ProcDestStates(s, stack_id); + + for (ArcIterator<ExpandFst<Arc> > aiter(efst_, s); + !aiter.Done(); + aiter.Next()) { + Arc arc = aiter.Value(); + StackId nextstack_id = state_table_.Tuple(arc.nextstate).stack_id; + if (stack_id == nextstack_id) + ProcNonParen(s, arc, true); + else if (stack_id == stack_.Pop(nextstack_id)) + ProcOpenParen(s, arc, stack_id, nextstack_id); + else + ProcCloseParen(s, arc); + } + VLOG(2) << "d[" << s << "] = " << Distance(s) + << ", fd[" << s << "] = " << FinalDistance(s); + } +} + +// +// Expand() Functions +// + +template <class Arc> +struct ExpandOptions { + bool connect; + bool keep_parentheses; + typename Arc::Weight weight_threshold; + + ExpandOptions(bool c = true, bool k = false, + typename Arc::Weight w = Arc::Weight::Zero()) + : connect(c), keep_parentheses(k), weight_threshold(w) {} +}; + +// Expands a pushdown transducer (PDT) encoded as an FST into an FST. +// This version writes the expanded PDT result to a MutableFst. +// In the PDT, some transitions are labeled with open or close +// parentheses. To be interpreted as a PDT, the parens must balance on +// a path. The open-close parenthesis label pairs are passed in +// 'parens'. The expansion enforces the parenthesis constraints. The +// PDT must be expandable as an FST. +template <class Arc> +void Expand( + const Fst<Arc> &ifst, + const vector<pair<typename Arc::Label, typename Arc::Label> > &parens, + MutableFst<Arc> *ofst, + const ExpandOptions<Arc> &opts) { + typedef typename Arc::Label Label; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef typename ExpandFst<Arc>::StackId StackId; + + ExpandFstOptions<Arc> eopts; + eopts.gc_limit = 0; + if (opts.weight_threshold == Weight::Zero()) { + eopts.keep_parentheses = opts.keep_parentheses; + *ofst = ExpandFst<Arc>(ifst, parens, eopts); + } else { + PrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses); + pruned_expand.Expand(ofst, opts.weight_threshold); + } + + if (opts.connect) + Connect(ofst); +} + +// Expands a pushdown transducer (PDT) encoded as an FST into an FST. +// This version writes the expanded PDT result to a MutableFst. +// In the PDT, some transitions are labeled with open or close +// parentheses. To be interpreted as a PDT, the parens must balance on +// a path. The open-close parenthesis label pairs are passed in +// 'parens'. The expansion enforces the parenthesis constraints. The +// PDT must be expandable as an FST. +template<class Arc> +void Expand( + const Fst<Arc> &ifst, + const vector<pair<typename Arc::Label, typename Arc::Label> > &parens, + MutableFst<Arc> *ofst, + bool connect = true, bool keep_parentheses = false) { + Expand(ifst, parens, ofst, ExpandOptions<Arc>(connect, keep_parentheses)); +} + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_EXPAND_H__ |