diff options
author | Ian Hodson <idh@google.com> | 2012-05-30 21:27:06 +0100 |
---|---|---|
committer | Ian Hodson <idh@google.com> | 2012-05-30 22:47:36 +0100 |
commit | f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2 (patch) | |
tree | b131ed907f9b2d5af09c0983b651e9e69bc6aab9 /src/include/fst/extensions/pdt | |
parent | a92766f0a6ba4fac46cd6fd3856ef20c3b204f0d (diff) | |
download | openfst-f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2.tar.gz |
Add openfst to external, as used by GoogleTTSandroid-sdk-support_r11android-cts-4.2_r2android-cts-4.2_r1android-cts-4.1_r4android-cts-4.1_r2android-cts-4.1_r1android-4.2_r1android-4.2.2_r1.2android-4.2.2_r1.1android-4.2.2_r1android-4.2.1_r1.2android-4.2.1_r1.1android-4.2.1_r1android-4.1.2_r2.1android-4.1.2_r2android-4.1.2_r1android-4.1.1_r6.1android-4.1.1_r6android-4.1.1_r5android-4.1.1_r4android-4.1.1_r3android-4.1.1_r2android-4.1.1_r1.1android-4.1.1_r1tools_r22tools_r21jb-releasejb-mr1.1-releasejb-mr1.1-dev-plus-aospjb-mr1.1-devjb-mr1-releasejb-mr1-dev-plus-aospjb-mr1-devjb-mr0-releasejb-dev
Moved from GoogleTTS
Change-Id: I6bc6bdadaa53bd0f810b88443339f6d899502cc8
Diffstat (limited to 'src/include/fst/extensions/pdt')
-rw-r--r-- | src/include/fst/extensions/pdt/collection.h | 122 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/compose.h | 146 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/expand.h | 975 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/info.h | 175 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/paren.h | 496 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/pdt.h | 212 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/pdtlib.h | 30 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/pdtscript.h | 284 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/replace.h | 192 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/reverse.h | 58 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/shortest-path.h | 790 |
11 files changed, 3480 insertions, 0 deletions
diff --git a/src/include/fst/extensions/pdt/collection.h b/src/include/fst/extensions/pdt/collection.h new file mode 100644 index 0000000..26be504 --- /dev/null +++ b/src/include/fst/extensions/pdt/collection.h @@ -0,0 +1,122 @@ +// collection.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Class to store a collection of sets with elements of type T. + +#ifndef FST_EXTENSIONS_PDT_COLLECTION_H__ +#define FST_EXTENSIONS_PDT_COLLECTION_H__ + +#include <algorithm> +#include <vector> +using std::vector; + +#include <fst/bi-table.h> + +namespace fst { + +// Stores a collection of non-empty sets with elements of type T. A +// default constructor, equality ==, a total order <, and an STL-style +// hash class must be defined on the elements. Provides signed +// integer ID (of type I) of each unique set. The IDs are allocated +// starting from 0 in order. +template <class I, class T> +class Collection { + public: + struct Node { // Trie node + I node_id; // Root is kNoNodeId; + T element; + + Node() : node_id(kNoNodeId), element(T()) {} + Node(I i, const T &t) : node_id(i), element(t) {} + + bool operator==(const Node& n) const { + return n.node_id == node_id && n.element == element; + } + }; + + struct NodeHash { + size_t operator()(const Node &n) const { + return n.node_id + hash_(n.element) * kPrime; + } + }; + + typedef CompactHashBiTable<I, Node, NodeHash> NodeTable; + + class SetIterator { + public: + SetIterator(I id, Node node, NodeTable *node_table) + :id_(id), node_(node), node_table_(node_table) {} + + bool Done() const { return id_ == kNoNodeId; } + + const T &Element() const { return node_.element; } + + void Next() { + id_ = node_.node_id; + if (id_ != kNoNodeId) + node_ = node_table_->FindEntry(id_); + } + + private: + I id_; // Iterator set node id + Node node_; // Iterator set node + NodeTable *node_table_; + }; + + Collection() {} + + // Lookups integer ID from set. If it doesn't exist, then adds it. + // Set elements should be in strict order (and therefore unique). + I FindId(const vector<T> &set) { + I node_id = kNoNodeId; + for (ssize_t i = set.size() - 1; i >= 0; --i) { + Node node(node_id, set[i]); + node_id = node_table_.FindId(node); + } + return node_id; + } + + // Finds set given integer ID. Returns true if ID corresponds + // to set. Use iterators below to traverse result. + SetIterator FindSet(I id) { + if (id < 0 && id >= node_table_.Size()) { + return SetIterator(kNoNodeId, Node(kNoNodeId, T()), &node_table_); + } else { + return SetIterator(id, node_table_.FindEntry(id), &node_table_); + } + } + + private: + static const I kNoNodeId; + static const size_t kPrime; + static std::tr1::hash<T> hash_; + + NodeTable node_table_; + + DISALLOW_COPY_AND_ASSIGN(Collection); +}; + +template<class I, class T> const I Collection<I, T>::kNoNodeId = -1; + +template <class I, class T> const size_t Collection<I, T>::kPrime = 7853; + +template <class I, class T> std::tr1::hash<T> Collection<I, T>::hash_; + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_COLLECTION_H__ diff --git a/src/include/fst/extensions/pdt/compose.h b/src/include/fst/extensions/pdt/compose.h new file mode 100644 index 0000000..364d76f --- /dev/null +++ b/src/include/fst/extensions/pdt/compose.h @@ -0,0 +1,146 @@ +// compose.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Compose a PDT and an FST. + +#ifndef FST_EXTENSIONS_PDT_COMPOSE_H__ +#define FST_EXTENSIONS_PDT_COMPOSE_H__ + +#include <fst/compose.h> + +namespace fst { + +// Class to setup composition options for PDT composition. +// Default is for the PDT as the first composition argument. +template <class Arc, bool left_pdt = true> +class PdtComposeOptions : public +ComposeFstOptions<Arc, + MultiEpsMatcher< Matcher<Fst<Arc> > >, + MultiEpsFilter<AltSequenceComposeFilter< + MultiEpsMatcher< + Matcher<Fst<Arc> > > > > > { + public: + typedef typename Arc::Label Label; + typedef MultiEpsMatcher< Matcher<Fst<Arc> > > PdtMatcher; + typedef MultiEpsFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter; + typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions; + using COptions::matcher1; + using COptions::matcher2; + using COptions::filter; + + PdtComposeOptions(const Fst<Arc> &ifst1, + const vector<pair<Label, Label> > &parens, + const Fst<Arc> &ifst2) { + matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsList); + matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsLoop); + + // Treat parens as multi-epsilons when composing. + for (size_t i = 0; i < parens.size(); ++i) { + matcher1->AddMultiEpsLabel(parens[i].first); + matcher1->AddMultiEpsLabel(parens[i].second); + matcher2->AddMultiEpsLabel(parens[i].first); + matcher2->AddMultiEpsLabel(parens[i].second); + } + + filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true); + } +}; + +// Class to setup composition options for PDT with FST composition. +// Specialization is for the FST as the first composition argument. +template <class Arc> +class PdtComposeOptions<Arc, false> : public +ComposeFstOptions<Arc, + MultiEpsMatcher< Matcher<Fst<Arc> > >, + MultiEpsFilter<SequenceComposeFilter< + MultiEpsMatcher< + Matcher<Fst<Arc> > > > > > { + public: + typedef typename Arc::Label Label; + typedef MultiEpsMatcher< Matcher<Fst<Arc> > > PdtMatcher; + typedef MultiEpsFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter; + typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions; + using COptions::matcher1; + using COptions::matcher2; + using COptions::filter; + + PdtComposeOptions(const Fst<Arc> &ifst1, + const Fst<Arc> &ifst2, + const vector<pair<Label, Label> > &parens) { + matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kMultiEpsLoop); + matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kMultiEpsList); + + // Treat parens as multi-epsilons when composing. + for (size_t i = 0; i < parens.size(); ++i) { + matcher1->AddMultiEpsLabel(parens[i].first); + matcher1->AddMultiEpsLabel(parens[i].second); + matcher2->AddMultiEpsLabel(parens[i].first); + matcher2->AddMultiEpsLabel(parens[i].second); + } + + filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, true); + } +}; + + +// Composes pushdown transducer (PDT) encoded as an FST (1st arg) and +// an FST (2nd arg) with the result also a PDT encoded as an Fst. (3rd arg). +// In the PDTs, some transitions are labeled with open or close +// parentheses. To be interpreted as a PDT, the parens must balance on +// a path (see PdtExpand()). The open-close parenthesis label pairs +// are passed in 'parens'. +template <class Arc> +void Compose(const Fst<Arc> &ifst1, + const vector<pair<typename Arc::Label, + typename Arc::Label> > &parens, + const Fst<Arc> &ifst2, + MutableFst<Arc> *ofst, + const ComposeOptions &opts = ComposeOptions()) { + + PdtComposeOptions<Arc, true> copts(ifst1, parens, ifst2); + copts.gc_limit = 0; + *ofst = ComposeFst<Arc>(ifst1, ifst2, copts); + if (opts.connect) + Connect(ofst); +} + + +// Composes an FST (1st arg) and pushdown transducer (PDT) encoded as +// an FST (2nd arg) with the result also a PDT encoded as an Fst (3rd arg). +// In the PDTs, some transitions are labeled with open or close +// parentheses. To be interpreted as a PDT, the parens must balance on +// a path (see ExpandFst()). The open-close parenthesis label pairs +// are passed in 'parens'. +template <class Arc> +void Compose(const Fst<Arc> &ifst1, + const Fst<Arc> &ifst2, + const vector<pair<typename Arc::Label, + typename Arc::Label> > &parens, + MutableFst<Arc> *ofst, + const ComposeOptions &opts = ComposeOptions()) { + + PdtComposeOptions<Arc, false> copts(ifst1, ifst2, parens); + copts.gc_limit = 0; + *ofst = ComposeFst<Arc>(ifst1, ifst2, copts); + if (opts.connect) + Connect(ofst); +} + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_COMPOSE_H__ diff --git a/src/include/fst/extensions/pdt/expand.h b/src/include/fst/extensions/pdt/expand.h new file mode 100644 index 0000000..f464403 --- /dev/null +++ b/src/include/fst/extensions/pdt/expand.h @@ -0,0 +1,975 @@ +// expand.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Expand a PDT to an FST. + +#ifndef FST_EXTENSIONS_PDT_EXPAND_H__ +#define FST_EXTENSIONS_PDT_EXPAND_H__ + +#include <vector> +using std::vector; + +#include <fst/extensions/pdt/pdt.h> +#include <fst/extensions/pdt/paren.h> +#include <fst/extensions/pdt/shortest-path.h> +#include <fst/extensions/pdt/reverse.h> +#include <fst/cache.h> +#include <fst/mutable-fst.h> +#include <fst/queue.h> +#include <fst/state-table.h> +#include <fst/test-properties.h> + +namespace fst { + +template <class Arc> +struct ExpandFstOptions : public CacheOptions { + bool keep_parentheses; + PdtStack<typename Arc::StateId, typename Arc::Label> *stack; + PdtStateTable<typename Arc::StateId, typename Arc::StateId> *state_table; + + ExpandFstOptions( + const CacheOptions &opts = CacheOptions(), + bool kp = false, + PdtStack<typename Arc::StateId, typename Arc::Label> *s = 0, + PdtStateTable<typename Arc::StateId, typename Arc::StateId> *st = 0) + : CacheOptions(opts), keep_parentheses(kp), stack(s), state_table(st) {} +}; + +// Properties for an expanded PDT. +inline uint64 ExpandProperties(uint64 inprops) { + return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted); +} + + +// Implementation class for ExpandFst +template <class A> +class ExpandFstImpl + : public CacheImpl<A> { + public: + using FstImpl<A>::SetType; + using FstImpl<A>::SetProperties; + using FstImpl<A>::Properties; + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + + using CacheBaseImpl< CacheState<A> >::PushArc; + using CacheBaseImpl< CacheState<A> >::HasArcs; + using CacheBaseImpl< CacheState<A> >::HasFinal; + using CacheBaseImpl< CacheState<A> >::HasStart; + using CacheBaseImpl< CacheState<A> >::SetArcs; + using CacheBaseImpl< CacheState<A> >::SetFinal; + using CacheBaseImpl< CacheState<A> >::SetStart; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef StateId StackId; + typedef PdtStateTuple<StateId, StackId> StateTuple; + + ExpandFstImpl(const Fst<A> &fst, + const vector<pair<typename Arc::Label, + typename Arc::Label> > &parens, + const ExpandFstOptions<A> &opts) + : CacheImpl<A>(opts), fst_(fst.Copy()), + stack_(opts.stack ? opts.stack: new PdtStack<StateId, Label>(parens)), + state_table_(opts.state_table ? opts.state_table : + new PdtStateTable<StateId, StackId>()), + own_stack_(opts.stack == 0), own_state_table_(opts.state_table == 0), + keep_parentheses_(opts.keep_parentheses) { + SetType("expand"); + + uint64 props = fst.Properties(kFstProperties, false); + SetProperties(ExpandProperties(props), kCopyProperties); + + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + ExpandFstImpl(const ExpandFstImpl &impl) + : CacheImpl<A>(impl), + fst_(impl.fst_->Copy(true)), + stack_(new PdtStack<StateId, Label>(*impl.stack_)), + state_table_(new PdtStateTable<StateId, StackId>()), + own_stack_(true), own_state_table_(true), + keep_parentheses_(impl.keep_parentheses_) { + SetType("expand"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~ExpandFstImpl() { + delete fst_; + if (own_stack_) + delete stack_; + if (own_state_table_) + delete state_table_; + } + + StateId Start() { + if (!HasStart()) { + StateId s = fst_->Start(); + if (s == kNoStateId) + return kNoStateId; + StateTuple tuple(s, 0); + StateId start = state_table_->FindState(tuple); + SetStart(start); + } + return CacheImpl<A>::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + const StateTuple &tuple = state_table_->Tuple(s); + Weight w = fst_->Final(tuple.state_id); + if (w != Weight::Zero() && tuple.stack_id == 0) + SetFinal(s, w); + else + SetFinal(s, Weight::Zero()); + } + return CacheImpl<A>::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) { + ExpandState(s); + } + return CacheImpl<A>::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) + ExpandState(s); + return CacheImpl<A>::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) + ExpandState(s); + return CacheImpl<A>::NumOutputEpsilons(s); + } + + void InitArcIterator(StateId s, ArcIteratorData<A> *data) { + if (!HasArcs(s)) + ExpandState(s); + CacheImpl<A>::InitArcIterator(s, data); + } + + // Computes the outgoing transitions from a state, creating new destination + // states as needed. + void ExpandState(StateId s) { + StateTuple tuple = state_table_->Tuple(s); + for (ArcIterator< Fst<A> > aiter(*fst_, tuple.state_id); + !aiter.Done(); aiter.Next()) { + Arc arc = aiter.Value(); + StackId stack_id = stack_->Find(tuple.stack_id, arc.ilabel); + if (stack_id == -1) { + // Non-matching close parenthesis + continue; + } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) { + // Stack push/pop + arc.ilabel = arc.olabel = 0; + } + + StateTuple ntuple(arc.nextstate, stack_id); + arc.nextstate = state_table_->FindState(ntuple); + PushArc(s, arc); + } + SetArcs(s); + } + + const PdtStack<StackId, Label> &GetStack() const { return *stack_; } + + const PdtStateTable<StateId, StackId> &GetStateTable() const { + return *state_table_; + } + + private: + const Fst<A> *fst_; + + PdtStack<StackId, Label> *stack_; + PdtStateTable<StateId, StackId> *state_table_; + bool own_stack_; + bool own_state_table_; + bool keep_parentheses_; + + void operator=(const ExpandFstImpl<A> &); // disallow +}; + +// Expands a pushdown transducer (PDT) encoded as an FST into an FST. +// This version is a delayed Fst. In the PDT, some transitions are +// labeled with open or close parentheses. To be interpreted as a PDT, +// the parens must balance on a path. The open-close parenthesis label +// pairs are passed in 'parens'. The expansion enforces the +// parenthesis constraints. The PDT must be expandable as an FST. +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template <class A> +class ExpandFst : public ImplToFst< ExpandFstImpl<A> > { + public: + friend class ArcIterator< ExpandFst<A> >; + friend class StateIterator< ExpandFst<A> >; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef StateId StackId; + typedef CacheState<A> State; + typedef ExpandFstImpl<A> Impl; + + ExpandFst(const Fst<A> &fst, + const vector<pair<typename Arc::Label, + typename Arc::Label> > &parens) + : ImplToFst<Impl>(new Impl(fst, parens, ExpandFstOptions<A>())) {} + + ExpandFst(const Fst<A> &fst, + const vector<pair<typename Arc::Label, + typename Arc::Label> > &parens, + const ExpandFstOptions<A> &opts) + : ImplToFst<Impl>(new Impl(fst, parens, opts)) {} + + // See Fst<>::Copy() for doc. + ExpandFst(const ExpandFst<A> &fst, bool safe = false) + : ImplToFst<Impl>(fst, safe) {} + + // Get a copy of this ExpandFst. See Fst<>::Copy() for further doc. + virtual ExpandFst<A> *Copy(bool safe = false) const { + return new ExpandFst<A>(*this, safe); + } + + virtual inline void InitStateIterator(StateIteratorData<A> *data) const; + + virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { + GetImpl()->InitArcIterator(s, data); + } + + const PdtStack<StackId, Label> &GetStack() const { + return GetImpl()->GetStack(); + } + + const PdtStateTable<StateId, StackId> &GetStateTable() const { + return GetImpl()->GetStateTable(); + } + + private: + // Makes visible to friends. + Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + + void operator=(const ExpandFst<A> &fst); // Disallow +}; + + +// Specialization for ExpandFst. +template<class A> +class StateIterator< ExpandFst<A> > + : public CacheStateIterator< ExpandFst<A> > { + public: + explicit StateIterator(const ExpandFst<A> &fst) + : CacheStateIterator< ExpandFst<A> >(fst, fst.GetImpl()) {} +}; + + +// Specialization for ExpandFst. +template <class A> +class ArcIterator< ExpandFst<A> > + : public CacheArcIterator< ExpandFst<A> > { + public: + typedef typename A::StateId StateId; + + ArcIterator(const ExpandFst<A> &fst, StateId s) + : CacheArcIterator< ExpandFst<A> >(fst.GetImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) + fst.GetImpl()->ExpandState(s); + } + + private: + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + + +template <class A> inline +void ExpandFst<A>::InitStateIterator(StateIteratorData<A> *data) const +{ + data->base = new StateIterator< ExpandFst<A> >(*this); +} + +// +// PrunedExpand Class +// + +// Prunes the delayed expansion of a pushdown transducer (PDT) encoded +// as an FST into an FST. In the PDT, some transitions are labeled +// with open or close parentheses. To be interpreted as a PDT, the +// parens must balance on a path. The open-close parenthesis label +// pairs are passed in 'parens'. The expansion enforces the +// parenthesis constraints. +// +// The algorithm works by visiting the delayed ExpandFst using a +// shortest-stack first queue discipline and relies on the +// shortest-distance information computed using a reverse +// shortest-path call to perform the pruning. +// +// The algorithm maintains the same state ordering between the ExpandFst +// being visited 'efst_' and the result of pruning written into the +// MutableFst 'ofst_' to improve readability of the code. +// +template <class A> +class PrunedExpand { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + typedef StateId StackId; + typedef PdtStack<StackId, Label> Stack; + typedef PdtStateTable<StateId, StackId> StateTable; + typedef typename PdtBalanceData<Arc>::SetIterator SetIterator; + + // Constructor taking as input a PDT specified by 'ifst' and 'parens'. + // 'keep_parentheses' specifies whether parentheses are replaced by + // epsilons or not during the expansion. 'opts' is the cache options + // used to instantiate the underlying ExpandFst. + PrunedExpand(const Fst<A> &ifst, + const vector<pair<Label, Label> > &parens, + bool keep_parentheses = false, + const CacheOptions &opts = CacheOptions()) + : ifst_(ifst.Copy()), + keep_parentheses_(keep_parentheses), + stack_(parens), + efst_(ifst, parens, + ExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)), + queue_(state_table_, stack_, stack_length_, distance_, fdistance_) { + Reverse(*ifst_, parens, &rfst_); + VectorFst<Arc> path; + reverse_shortest_path_ = new SP( + rfst_, parens, + PdtShortestPathOptions<A, FifoQueue<StateId> >(true, false)); + reverse_shortest_path_->ShortestPath(&path); + balance_data_ = reverse_shortest_path_->GetBalanceData()->Reverse( + rfst_.NumStates(), 10, -1); + + InitCloseParenMultimap(parens); + } + + ~PrunedExpand() { + delete ifst_; + delete reverse_shortest_path_; + delete balance_data_; + } + + // Expands and prunes with weight threshold 'threshold' the input PDT. + // Writes the result in 'ofst'. + void Expand(MutableFst<A> *ofst, const Weight &threshold); + + private: + static const uint8 kEnqueued; + static const uint8 kExpanded; + static const uint8 kSourceState; + + // Comparison functor used by the queue: + // 1. states corresponding to shortest stack first, + // 2. among stacks of the same length, reverse lexicographic order is used, + // 3. among states with the same stack, shortest-first order is used. + class StackCompare { + public: + StackCompare(const StateTable &st, + const Stack &s, const vector<StackId> &sl, + const vector<Weight> &d, const vector<Weight> &fd) + : state_table_(st), stack_(s), stack_length_(sl), + distance_(d), fdistance_(fd) {} + + bool operator()(StateId s1, StateId s2) const { + StackId si1 = state_table_.Tuple(s1).stack_id; + StackId si2 = state_table_.Tuple(s2).stack_id; + if (stack_length_[si1] < stack_length_[si2]) + return true; + if (stack_length_[si1] > stack_length_[si2]) + return false; + // If stack id equal, use A* + if (si1 == si2) { + Weight w1 = (s1 < distance_.size()) && (s1 < fdistance_.size()) ? + Times(distance_[s1], fdistance_[s1]) : Weight::Zero(); + Weight w2 = (s2 < distance_.size()) && (s2 < fdistance_.size()) ? + Times(distance_[s2], fdistance_[s2]) : Weight::Zero(); + return less_(w1, w2); + } + // If lenghts are equal, use reverse lexico. + for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) { + if (stack_.Top(si1) < stack_.Top(si2)) return true; + if (stack_.Top(si1) > stack_.Top(si2)) return false; + } + return false; + } + + private: + const StateTable &state_table_; + const Stack &stack_; + const vector<StackId> &stack_length_; + const vector<Weight> &distance_; + const vector<Weight> &fdistance_; + NaturalLess<Weight> less_; + }; + + class ShortestStackFirstQueue + : public ShortestFirstQueue<StateId, StackCompare> { + public: + ShortestStackFirstQueue( + const PdtStateTable<StateId, StackId> &st, + const Stack &s, + const vector<StackId> &sl, + const vector<Weight> &d, const vector<Weight> &fd) + : ShortestFirstQueue<StateId, StackCompare>( + StackCompare(st, s, sl, d, fd)) {} + }; + + + void InitCloseParenMultimap(const vector<pair<Label, Label> > &parens); + Weight DistanceToDest(StateId state, StateId source) const; + uint8 Flags(StateId s) const; + void SetFlags(StateId s, uint8 flags, uint8 mask); + Weight Distance(StateId s) const; + void SetDistance(StateId s, Weight w); + Weight FinalDistance(StateId s) const; + void SetFinalDistance(StateId s, Weight w); + StateId SourceState(StateId s) const; + void SetSourceState(StateId s, StateId p); + void AddStateAndEnqueue(StateId s); + void Relax(StateId s, const A &arc, Weight w); + bool PruneArc(StateId s, const A &arc); + void ProcStart(); + void ProcFinal(StateId s); + bool ProcNonParen(StateId s, const A &arc, bool add_arc); + bool ProcOpenParen(StateId s, const A &arc, StackId si, StackId nsi); + bool ProcCloseParen(StateId s, const A &arc); + void ProcDestStates(StateId s, StackId si); + + Fst<A> *ifst_; // Input PDT + VectorFst<Arc> rfst_; // Reversed PDT + bool keep_parentheses_; // Keep parentheses in ofst? + StateTable state_table_; // State table for efst_ + Stack stack_; // Stack trie + ExpandFst<Arc> efst_; // Expanded PDT + vector<StackId> stack_length_; // Length of stack for given stack id + vector<Weight> distance_; // Distance from initial state in efst_/ofst + vector<Weight> fdistance_; // Distance to final states in efst_/ofst + ShortestStackFirstQueue queue_; // Queue used to visit efst_ + vector<uint8> flags_; // Status flags for states in efst_/ofst + vector<StateId> sources_; // PDT source state for each expanded state + + typedef PdtShortestPath<Arc, FifoQueue<StateId> > SP; + typedef typename SP::CloseParenMultimap ParenMultimap; + SP *reverse_shortest_path_; // Shortest path for rfst_ + PdtBalanceData<Arc> *balance_data_; // Not owned by shortest_path_ + ParenMultimap close_paren_multimap_; // Maps open paren arcs to + // balancing close paren arcs. + + MutableFst<Arc> *ofst_; // Output fst + Weight limit_; // Weight limit + + typedef unordered_map<StateId, Weight> DestMap; + DestMap dest_map_; + StackId current_stack_id_; + // 'current_stack_id_' is the stack id of the states currently at the top + // of queue, i.e., the states currently being popped and processed. + // 'dest_map_' maps a state 's' in 'ifst_' that is the source + // of a close parentheses matching the top of 'current_stack_id_; to + // the shortest-distance from '(s, current_stack_id_)' to the final + // states in 'efst_'. + ssize_t current_paren_id_; // Paren id at top of current stack + ssize_t cached_stack_id_; + StateId cached_source_; + slist<pair<StateId, Weight> > cached_dest_list_; + // 'cached_dest_list_' contains the set of pair of destination + // states and weight to final states for source state + // 'cached_source_' and paren id 'cached_paren_id': the set of + // source state of a close parenthesis with paren id + // 'cached_paren_id' balancing an incoming open parenthesis with + // paren id 'cached_paren_id' in state 'cached_source_'. + + NaturalLess<Weight> less_; +}; + +template <class A> const uint8 PrunedExpand<A>::kEnqueued = 0x01; +template <class A> const uint8 PrunedExpand<A>::kExpanded = 0x02; +template <class A> const uint8 PrunedExpand<A>::kSourceState = 0x04; + + +// Initializes close paren multimap, mapping pairs (s,paren_id) to +// all the arcs out of s labeled with close parenthese for paren_id. +template <class A> +void PrunedExpand<A>::InitCloseParenMultimap( + const vector<pair<Label, Label> > &parens) { + unordered_map<Label, Label> paren_id_map; + for (Label i = 0; i < parens.size(); ++i) { + const pair<Label, Label> &p = parens[i]; + paren_id_map[p.first] = i; + paren_id_map[p.second] = i; + } + + for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + for (ArcIterator<Fst<Arc> > aiter(*ifst_, s); + !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + typename unordered_map<Label, Label>::const_iterator pit + = paren_id_map.find(arc.ilabel); + if (pit == paren_id_map.end()) continue; + if (arc.ilabel == parens[pit->second].second) { // Close paren + ParenState<Arc> paren_state(pit->second, s); + close_paren_multimap_.insert(make_pair(paren_state, arc)); + } + } + } +} + + +// Returns the weight of the shortest balanced path from 'source' to 'dest' +// in 'ifst_', 'dest' must be the source state of a close paren arc. +template <class A> +typename A::Weight PrunedExpand<A>::DistanceToDest(StateId source, + StateId dest) const { + typename SP::SearchState s(source + 1, dest + 1); + VLOG(2) << "D(" << source << ", " << dest << ") =" + << reverse_shortest_path_->GetShortestPathData().Distance(s); + return reverse_shortest_path_->GetShortestPathData().Distance(s); +} + +// Returns the flags for state 's' in 'ofst_'. +template <class A> +uint8 PrunedExpand<A>::Flags(StateId s) const { + return s < flags_.size() ? flags_[s] : 0; +} + +// Modifies the flags for state 's' in 'ofst_'. +template <class A> +void PrunedExpand<A>::SetFlags(StateId s, uint8 flags, uint8 mask) { + while (flags_.size() <= s) flags_.push_back(0); + flags_[s] &= ~mask; + flags_[s] |= flags & mask; +} + + +// Returns the shortest distance from the initial state to 's' in 'ofst_'. +template <class A> +typename A::Weight PrunedExpand<A>::Distance(StateId s) const { + return s < distance_.size() ? distance_[s] : Weight::Zero(); +} + +// Sets the shortest distance from the initial state to 's' in 'ofst_' to 'w'. +template <class A> +void PrunedExpand<A>::SetDistance(StateId s, Weight w) { + while (distance_.size() <= s ) distance_.push_back(Weight::Zero()); + distance_[s] = w; +} + + +// Returns the shortest distance from 's' to the final states in 'ofst_'. +template <class A> +typename A::Weight PrunedExpand<A>::FinalDistance(StateId s) const { + return s < fdistance_.size() ? fdistance_[s] : Weight::Zero(); +} + +// Sets the shortest distance from 's' to the final states in 'ofst_' to 'w'. +template <class A> +void PrunedExpand<A>::SetFinalDistance(StateId s, Weight w) { + while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero()); + fdistance_[s] = w; +} + +// Returns the PDT "source" state of state 's' in 'ofst_'. +template <class A> +typename A::StateId PrunedExpand<A>::SourceState(StateId s) const { + return s < sources_.size() ? sources_[s] : kNoStateId; +} + +// Sets the PDT "source" state of state 's' in 'ofst_' to state 'p' in 'ifst_'. +template <class A> +void PrunedExpand<A>::SetSourceState(StateId s, StateId p) { + while (sources_.size() <= s) sources_.push_back(kNoStateId); + sources_[s] = p; +} + +// Adds state 's' of 'efst_' to 'ofst_' and inserts it in the queue, +// modifying the flags for 's' accordingly. +template <class A> +void PrunedExpand<A>::AddStateAndEnqueue(StateId s) { + if (!(Flags(s) & (kEnqueued | kExpanded))) { + while (ofst_->NumStates() <= s) ofst_->AddState(); + queue_.Enqueue(s); + SetFlags(s, kEnqueued, kEnqueued); + } else if (Flags(s) & kEnqueued) { + queue_.Update(s); + } + // TODO(allauzen): Check everything is fine when kExpanded? +} + +// Relaxes arc 'arc' out of state 's' in 'ofst_': +// * if the distance to 's' times the weight of 'arc' is smaller than +// the currently stored distance for 'arc.nextstate', +// updates 'Distance(arc.nextstate)' with new estimate; +// * if 'fd' is less than the currently stored distance from 'arc.nextstate' +// to the final state, updates with new estimate. +template <class A> +void PrunedExpand<A>::Relax(StateId s, const A &arc, Weight fd) { + Weight nd = Times(Distance(s), arc.weight); + if (less_(nd, Distance(arc.nextstate))) { + SetDistance(arc.nextstate, nd); + SetSourceState(arc.nextstate, SourceState(s)); + } + if (less_(fd, FinalDistance(arc.nextstate))) + SetFinalDistance(arc.nextstate, fd); + VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to " + << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate) + << ", nd = " << nd; +} + +// Returns 'true' if the arc 'arc' out of state 's' in 'efst_' needs to +// be pruned. +template <class A> +bool PrunedExpand<A>::PruneArc(StateId s, const A &arc) { + VLOG(2) << "Prune ?"; + Weight fd = Weight::Zero(); + + if ((cached_source_ != SourceState(s)) || + (cached_stack_id_ != current_stack_id_)) { + cached_source_ = SourceState(s); + cached_stack_id_ = current_stack_id_; + cached_dest_list_.clear(); + if (cached_source_ != ifst_->Start()) { + for (SetIterator set_iter = + balance_data_->Find(current_paren_id_, cached_source_); + !set_iter.Done(); set_iter.Next()) { + StateId dest = set_iter.Element(); + typename DestMap::const_iterator iter = dest_map_.find(dest); + cached_dest_list_.push_front(*iter); + } + } else { + // TODO(allauzen): queue discipline should prevent this never + // from happening; replace by a check. + cached_dest_list_.push_front( + make_pair(rfst_.Start() -1, Weight::One())); + } + } + + for (typename slist<pair<StateId, Weight> >::const_iterator iter = + cached_dest_list_.begin(); + iter != cached_dest_list_.end(); + ++iter) { + fd = Plus(fd, + Times(DistanceToDest(state_table_.Tuple(arc.nextstate).state_id, + iter->first), + iter->second)); + } + Relax(s, arc, fd); + Weight w = Times(Distance(s), Times(arc.weight, fd)); + return less_(limit_, w); +} + +// Adds start state of 'efst_' to 'ofst_', enqueues it and initializes +// the distance data structures. +template <class A> +void PrunedExpand<A>::ProcStart() { + StateId s = efst_.Start(); + AddStateAndEnqueue(s); + ofst_->SetStart(s); + SetSourceState(s, ifst_->Start()); + + current_stack_id_ = 0; + current_paren_id_ = -1; + stack_length_.push_back(0); + dest_map_[rfst_.Start() - 1] = Weight::One(); // not needed + + cached_source_ = ifst_->Start(); + cached_stack_id_ = 0; + cached_dest_list_.push_front( + make_pair(rfst_.Start() -1, Weight::One())); + + PdtStateTuple<StateId, StackId> tuple(rfst_.Start() - 1, 0); + SetFinalDistance(state_table_.FindState(tuple), Weight::One()); + SetDistance(s, Weight::One()); + SetFinalDistance(s, DistanceToDest(ifst_->Start(), rfst_.Start() - 1)); + VLOG(2) << DistanceToDest(ifst_->Start(), rfst_.Start() - 1); +} + +// Makes 's' final in 'ofst_' if shortest accepting path ending in 's' +// is below threshold. +template <class A> +void PrunedExpand<A>::ProcFinal(StateId s) { + Weight final = efst_.Final(s); + if ((final == Weight::Zero()) || less_(limit_, Times(Distance(s), final))) + return; + ofst_->SetFinal(s, final); +} + +// Returns true when arc (or meta-arc) 'arc' out of 's' in 'efst_' is +// below the threshold. When 'add_arc' is true, 'arc' is added to 'ofst_'. +template <class A> +bool PrunedExpand<A>::ProcNonParen(StateId s, const A &arc, bool add_arc) { + VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate + << ", " << arc.ilabel << ":" << arc.olabel << " / " << arc.weight + << ", add_arc = " << (add_arc ? "true" : "false"); + if (PruneArc(s, arc)) return false; + if(add_arc) ofst_->AddArc(s, arc); + AddStateAndEnqueue(arc.nextstate); + return true; +} + +// Processes an open paren arc 'arc' out of state 's' in 'ofst_'. +// When 'arc' is labeled with an open paren, +// 1. considers each (shortest) balanced path starting in 's' by +// taking 'arc' and ending by a close paren balancing the open +// paren of 'arc' as a meta-arc, processes and prunes each meta-arc +// as a non-paren arc, inserting its destination to the queue; +// 2. if at least one of these meta-arcs has not been pruned, +// adds the destination of 'arc' to 'ofst_' as a new source state +// for the stack id 'nsi' and inserts it in the queue. +template <class A> +bool PrunedExpand<A>::ProcOpenParen(StateId s, const A &arc, StackId si, + StackId nsi) { + // Update the stack lenght when needed: |nsi| = |si| + 1. + while (stack_length_.size() <= nsi) stack_length_.push_back(-1); + if (stack_length_[nsi] == -1) + stack_length_[nsi] = stack_length_[si] + 1; + + StateId ns = arc.nextstate; + VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id + << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")"; + bool proc_arc = false; + Weight fd = Weight::Zero(); + ssize_t paren_id = stack_.ParenId(arc.ilabel); + slist<StateId> sources; + for (SetIterator set_iter = + balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id); + !set_iter.Done(); set_iter.Next()) { + sources.push_front(set_iter.Element()); + } + for (typename slist<StateId>::const_iterator sources_iter = sources.begin(); + sources_iter != sources.end(); + ++ sources_iter) { + StateId source = *sources_iter; + VLOG(2) << "Close paren source: " << source; + ParenState<Arc> paren_state(paren_id, source); + for (typename ParenMultimap::const_iterator iter = + close_paren_multimap_.find(paren_state); + iter != close_paren_multimap_.end() && paren_state == iter->first; + ++iter) { + Arc meta_arc = iter->second; + PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si); + meta_arc.nextstate = state_table_.FindState(tuple); + VLOG(2) << state_table_.Tuple(ns).state_id << ", " << source; + VLOG(2) << "Meta arc weight = " << arc.weight << " Times " + << DistanceToDest(state_table_.Tuple(ns).state_id, source) + << " Times " << meta_arc.weight; + meta_arc.weight = Times( + arc.weight, + Times(DistanceToDest(state_table_.Tuple(ns).state_id, source), + meta_arc.weight)); + proc_arc |= ProcNonParen(s, meta_arc, false); + fd = Plus(fd, Times( + Times( + DistanceToDest(state_table_.Tuple(ns).state_id, source), + iter->second.weight), + FinalDistance(meta_arc.nextstate))); + } + } + if (proc_arc) { + VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate; + ofst_->AddArc( + s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate)); + AddStateAndEnqueue(arc.nextstate); + Weight nd = Times(Distance(s), arc.weight); + if(less_(nd, Distance(arc.nextstate))) + SetDistance(arc.nextstate, nd); + // FinalDistance not necessary for source state since pruning + // decided using the meta-arcs above. But this is a problem with + // A*, hence: + if (less_(fd, FinalDistance(arc.nextstate))) + SetFinalDistance(arc.nextstate, fd); + SetFlags(arc.nextstate, kSourceState, kSourceState); + } + return proc_arc; +} + +// Checks that shortest path through close paren arc in 'efst_' is +// below threshold, if so adds it to 'ofst_'. +template <class A> +bool PrunedExpand<A>::ProcCloseParen(StateId s, const A &arc) { + Weight w = Times(Distance(s), + Times(arc.weight, FinalDistance(arc.nextstate))); + if (less_(limit_, w)) + return false; + ofst_->AddArc( + s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate)); + return true; +} + +// When 's' in 'ofst_' is a source state for stack id 'si', identifies +// all the corresponding possible destination states, that is, all the +// states in 'ifst_' that have an outgoing close paren arc balancing +// the incoming open paren taken to get to 's', and for each such +// state 't', computes the shortest distance from (t, si) to the final +// states in 'ofst_'. Stores this information in 'dest_map_'. +template <class A> +void PrunedExpand<A>::ProcDestStates(StateId s, StackId si) { + if (!(Flags(s) & kSourceState)) return; + if (si != current_stack_id_) { + dest_map_.clear(); + current_stack_id_ = si; + current_paren_id_ = stack_.Top(current_stack_id_); + VLOG(2) << "StackID " << si << " dequeued for first time"; + } + // TODO(allauzen): clean up source state business; rename current function to + // ProcSourceState. + SetSourceState(s, state_table_.Tuple(s).state_id); + + ssize_t paren_id = stack_.Top(si); + for (SetIterator set_iter = + balance_data_->Find(paren_id, state_table_.Tuple(s).state_id); + !set_iter.Done(); set_iter.Next()) { + StateId dest_state = set_iter.Element(); + if (dest_map_.find(dest_state) != dest_map_.end()) + continue; + Weight dest_weight = Weight::Zero(); + ParenState<Arc> paren_state(paren_id, dest_state); + for (typename ParenMultimap::const_iterator iter = + close_paren_multimap_.find(paren_state); + iter != close_paren_multimap_.end() && paren_state == iter->first; + ++iter) { + const Arc &arc = iter->second; + PdtStateTuple<StateId, StackId> tuple(arc.nextstate, stack_.Pop(si)); + dest_weight = Plus(dest_weight, + Times(arc.weight, + FinalDistance(state_table_.FindState(tuple)))); + } + dest_map_[dest_state] = dest_weight; + VLOG(2) << "State " << dest_state << " is a dest state for stack id " + << si << " with weight " << dest_weight; + } +} + +// Expands and prunes with weight threshold 'threshold' the input PDT. +// Writes the result in 'ofst'. +template <class A> +void PrunedExpand<A>::Expand( + MutableFst<A> *ofst, const typename A::Weight &threshold) { + ofst_ = ofst; + ofst_->DeleteStates(); + ofst_->SetInputSymbols(ifst_->InputSymbols()); + ofst_->SetOutputSymbols(ifst_->OutputSymbols()); + + limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold); + flags_.clear(); + + ProcStart(); + + while (!queue_.Empty()) { + StateId s = queue_.Head(); + queue_.Dequeue(); + SetFlags(s, kExpanded, kExpanded | kEnqueued); + VLOG(2) << s << " dequeued!"; + + ProcFinal(s); + StackId stack_id = state_table_.Tuple(s).stack_id; + ProcDestStates(s, stack_id); + + for (ArcIterator<ExpandFst<Arc> > aiter(efst_, s); + !aiter.Done(); + aiter.Next()) { + Arc arc = aiter.Value(); + StackId nextstack_id = state_table_.Tuple(arc.nextstate).stack_id; + if (stack_id == nextstack_id) + ProcNonParen(s, arc, true); + else if (stack_id == stack_.Pop(nextstack_id)) + ProcOpenParen(s, arc, stack_id, nextstack_id); + else + ProcCloseParen(s, arc); + } + VLOG(2) << "d[" << s << "] = " << Distance(s) + << ", fd[" << s << "] = " << FinalDistance(s); + } +} + +// +// Expand() Functions +// + +template <class Arc> +struct ExpandOptions { + bool connect; + bool keep_parentheses; + typename Arc::Weight weight_threshold; + + ExpandOptions(bool c = true, bool k = false, + typename Arc::Weight w = Arc::Weight::Zero()) + : connect(c), keep_parentheses(k), weight_threshold(w) {} +}; + +// Expands a pushdown transducer (PDT) encoded as an FST into an FST. +// This version writes the expanded PDT result to a MutableFst. +// In the PDT, some transitions are labeled with open or close +// parentheses. To be interpreted as a PDT, the parens must balance on +// a path. The open-close parenthesis label pairs are passed in +// 'parens'. The expansion enforces the parenthesis constraints. The +// PDT must be expandable as an FST. +template <class Arc> +void Expand( + const Fst<Arc> &ifst, + const vector<pair<typename Arc::Label, typename Arc::Label> > &parens, + MutableFst<Arc> *ofst, + const ExpandOptions<Arc> &opts) { + typedef typename Arc::Label Label; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef typename ExpandFst<Arc>::StackId StackId; + + ExpandFstOptions<Arc> eopts; + eopts.gc_limit = 0; + if (opts.weight_threshold == Weight::Zero()) { + eopts.keep_parentheses = opts.keep_parentheses; + *ofst = ExpandFst<Arc>(ifst, parens, eopts); + } else { + PrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses); + pruned_expand.Expand(ofst, opts.weight_threshold); + } + + if (opts.connect) + Connect(ofst); +} + +// Expands a pushdown transducer (PDT) encoded as an FST into an FST. +// This version writes the expanded PDT result to a MutableFst. +// In the PDT, some transitions are labeled with open or close +// parentheses. To be interpreted as a PDT, the parens must balance on +// a path. The open-close parenthesis label pairs are passed in +// 'parens'. The expansion enforces the parenthesis constraints. The +// PDT must be expandable as an FST. +template<class Arc> +void Expand( + const Fst<Arc> &ifst, + const vector<pair<typename Arc::Label, typename Arc::Label> > &parens, + MutableFst<Arc> *ofst, + bool connect = true, bool keep_parentheses = false) { + Expand(ifst, parens, ofst, ExpandOptions<Arc>(connect, keep_parentheses)); +} + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_EXPAND_H__ diff --git a/src/include/fst/extensions/pdt/info.h b/src/include/fst/extensions/pdt/info.h new file mode 100644 index 0000000..ef9a860 --- /dev/null +++ b/src/include/fst/extensions/pdt/info.h @@ -0,0 +1,175 @@ +// info.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Prints information about a PDT. + +#ifndef FST_EXTENSIONS_PDT_INFO_H__ +#define FST_EXTENSIONS_PDT_INFO_H__ + +#include <unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <tr1/unordered_set> +using std::tr1::unordered_set; +using std::tr1::unordered_multiset; +#include <vector> +using std::vector; + +#include <fst/fst.h> +#include <fst/extensions/pdt/pdt.h> + +namespace fst { + +// Compute various information about PDTs, helper class for pdtinfo.cc. +template <class A> class PdtInfo { +public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + PdtInfo(const Fst<A> &fst, + const vector<pair<typename A::Label, + typename A::Label> > &parens); + + const string& FstType() const { return fst_type_; } + const string& ArcType() const { return A::Type(); } + + int64 NumStates() const { return nstates_; } + int64 NumArcs() const { return narcs_; } + int64 NumOpenParens() const { return nopen_parens_; } + int64 NumCloseParens() const { return nclose_parens_; } + int64 NumUniqueOpenParens() const { return nuniq_open_parens_; } + int64 NumUniqueCloseParens() const { return nuniq_close_parens_; } + int64 NumOpenParenStates() const { return nopen_paren_states_; } + int64 NumCloseParenStates() const { return nclose_paren_states_; } + + private: + string fst_type_; + int64 nstates_; + int64 narcs_; + int64 nopen_parens_; + int64 nclose_parens_; + int64 nuniq_open_parens_; + int64 nuniq_close_parens_; + int64 nopen_paren_states_; + int64 nclose_paren_states_; + + DISALLOW_COPY_AND_ASSIGN(PdtInfo); +}; + +template <class A> +PdtInfo<A>::PdtInfo(const Fst<A> &fst, + const vector<pair<typename A::Label, + typename A::Label> > &parens) + : fst_type_(fst.Type()), + nstates_(0), + narcs_(0), + nopen_parens_(0), + nclose_parens_(0), + nuniq_open_parens_(0), + nuniq_close_parens_(0), + nopen_paren_states_(0), + nclose_paren_states_(0) { + unordered_map<Label, size_t> paren_map; + unordered_set<Label> paren_set; + unordered_set<StateId> open_paren_state_set; + unordered_set<StateId> close_paren_state_set; + + for (size_t i = 0; i < parens.size(); ++i) { + const pair<Label, Label> &p = parens[i]; + paren_map[p.first] = i; + paren_map[p.second] = i; + } + + for (StateIterator< Fst<A> > siter(fst); + !siter.Done(); + siter.Next()) { + ++nstates_; + StateId s = siter.Value(); + for (ArcIterator< Fst<A> > aiter(fst, s); + !aiter.Done(); + aiter.Next()) { + const A &arc = aiter.Value(); + ++narcs_; + typename unordered_map<Label, size_t>::const_iterator pit + = paren_map.find(arc.ilabel); + if (pit != paren_map.end()) { + Label open_paren = parens[pit->second].first; + Label close_paren = parens[pit->second].second; + if (arc.ilabel == open_paren) { + ++nopen_parens_; + if (!paren_set.count(open_paren)) { + ++nuniq_open_parens_; + paren_set.insert(open_paren); + } + if (!open_paren_state_set.count(arc.nextstate)) { + ++nopen_paren_states_; + open_paren_state_set.insert(arc.nextstate); + } + } else { + ++nclose_parens_; + if (!paren_set.count(close_paren)) { + ++nuniq_close_parens_; + paren_set.insert(close_paren); + } + if (!close_paren_state_set.count(s)) { + ++nclose_paren_states_; + close_paren_state_set.insert(s); + } + + } + } + } + } +} + + +template <class A> +void PrintPdtInfo(const PdtInfo<A> &pdtinfo) { + ios_base::fmtflags old = cout.setf(ios::left); + cout.width(50); + cout << "fst type" << pdtinfo.FstType().c_str() << endl; + cout.width(50); + cout << "arc type" << pdtinfo.ArcType().c_str() << endl; + cout.width(50); + cout << "# of states" << pdtinfo.NumStates() << endl; + cout.width(50); + cout << "# of arcs" << pdtinfo.NumArcs() << endl; + cout.width(50); + cout << "# of open parentheses" << pdtinfo.NumOpenParens() << endl; + cout.width(50); + cout << "# of close parentheses" << pdtinfo.NumCloseParens() << endl; + cout.width(50); + cout << "# of unique open parentheses" + << pdtinfo.NumUniqueOpenParens() << endl; + cout.width(50); + cout << "# of unique close parentheses" + << pdtinfo.NumUniqueCloseParens() << endl; + cout.width(50); + cout << "# of open parenthesis dest. states" + << pdtinfo.NumOpenParenStates() << endl; + cout.width(50); + cout << "# of close parenthesis source states" + << pdtinfo.NumCloseParenStates() << endl; + cout.setf(old); +} + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_INFO_H__ diff --git a/src/include/fst/extensions/pdt/paren.h b/src/include/fst/extensions/pdt/paren.h new file mode 100644 index 0000000..7b9887f --- /dev/null +++ b/src/include/fst/extensions/pdt/paren.h @@ -0,0 +1,496 @@ +// paren.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// Common classes for PDT parentheses + +// \file + +#ifndef FST_EXTENSIONS_PDT_PAREN_H_ +#define FST_EXTENSIONS_PDT_PAREN_H_ + +#include <algorithm> +#include <unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <tr1/unordered_set> +using std::tr1::unordered_set; +using std::tr1::unordered_multiset; +#include <set> + +#include <fst/extensions/pdt/pdt.h> +#include <fst/extensions/pdt/collection.h> +#include <fst/fst.h> +#include <fst/dfs-visit.h> + + +namespace fst { + +// +// ParenState: Pair of an open (close) parenthesis and +// its destination (source) state. +// + +template <class A> +class ParenState { + public: + typedef typename A::Label Label; + typedef typename A::StateId StateId; + + struct Hash { + size_t operator()(const ParenState<A> &p) const { + return p.paren_id + p.state_id * kPrime; + } + }; + + Label paren_id; // ID of open (close) paren + StateId state_id; // destination (source) state of open (close) paren + + ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {} + + ParenState(Label p, StateId s) : paren_id(p), state_id(s) {} + + bool operator==(const ParenState<A> &p) const { + if (&p == this) + return true; + return p.paren_id == this->paren_id && p.state_id == this->state_id; + } + + bool operator!=(const ParenState<A> &p) const { return !(p == *this); } + + bool operator<(const ParenState<A> &p) const { + return paren_id < this->paren.id || + (p.paren_id == this->paren.id && p.state_id < this->state_id); + } + + private: + static const size_t kPrime; +}; + +template <class A> +const size_t ParenState<A>::kPrime = 7853; + + +// Creates an FST-style iterator from STL map and iterator. +template <class M> +class MapIterator { + public: + typedef typename M::const_iterator StlIterator; + typedef typename M::value_type PairType; + typedef typename PairType::second_type ValueType; + + MapIterator(const M &m, StlIterator iter) + : map_(m), begin_(iter), iter_(iter) {} + + bool Done() const { + return iter_ == map_.end() || iter_->first != begin_->first; + } + + ValueType Value() const { return iter_->second; } + void Next() { ++iter_; } + void Reset() { iter_ = begin_; } + + private: + const M &map_; + StlIterator begin_; + StlIterator iter_; +}; + +// +// PdtParenReachable: Provides various parenthesis reachability information +// on a PDT. +// + +template <class A> +class PdtParenReachable { + public: + typedef typename A::StateId StateId; + typedef typename A::Label Label; + public: + // Maps from state ID to reachable paren IDs from (to) that state. + typedef unordered_multimap<StateId, Label> ParenMultiMap; + + // Maps from paren ID and state ID to reachable state set ID + typedef unordered_map<ParenState<A>, ssize_t, + typename ParenState<A>::Hash> StateSetMap; + + // Maps from paren ID and state ID to arcs exiting that state with that + // Label. + typedef unordered_multimap<ParenState<A>, A, + typename ParenState<A>::Hash> ParenArcMultiMap; + + typedef MapIterator<ParenMultiMap> ParenIterator; + + typedef MapIterator<ParenArcMultiMap> ParenArcIterator; + + typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator; + + // Computes close (open) parenthesis reachabilty information for + // a PDT with bounded stack. + PdtParenReachable(const Fst<A> &fst, + const vector<pair<Label, Label> > &parens, bool close) + : fst_(fst), + parens_(parens), + close_(close) { + for (Label i = 0; i < parens.size(); ++i) { + const pair<Label, Label> &p = parens[i]; + paren_id_map_[p.first] = i; + paren_id_map_[p.second] = i; + } + + if (close_) { + StateId start = fst.Start(); + if (start == kNoStateId) + return; + DFSearch(start, start); + } else { + FSTERROR() << "PdtParenReachable: open paren info not implemented"; + } + } + + // Given a state ID, returns an iterator over paren IDs + // for close (open) parens reachable from that state along balanced + // paths. + ParenIterator FindParens(StateId s) const { + return ParenIterator(paren_multimap_, paren_multimap_.find(s)); + } + + // Given a paren ID and a state ID s, returns an iterator over + // states that can be reached along balanced paths from (to) s that + // have have close (open) parentheses matching the paren ID exiting + // (entering) those states. + SetIterator FindStates(Label paren_id, StateId s) const { + ParenState<A> paren_state(paren_id, s); + typename StateSetMap::const_iterator id_it = set_map_.find(paren_state); + if (id_it == set_map_.end()) { + return state_sets_.FindSet(-1); + } else { + return state_sets_.FindSet(id_it->second); + } + } + + // Given a paren Id and a state ID s, return an iterator over + // arcs that exit (enter) s and are labeled with a close (open) + // parenthesis matching the paren ID. + ParenArcIterator FindParenArcs(Label paren_id, StateId s) const { + ParenState<A> paren_state(paren_id, s); + return ParenArcIterator(paren_arc_multimap_, + paren_arc_multimap_.find(paren_state)); + } + + private: + // DFS that gathers paren and state set information. + // Bool returns false when cycle detected. + bool DFSearch(StateId s, StateId start); + + // Unions state sets together gathered by the DFS. + void ComputeStateSet(StateId s); + + // Gather state set(s) from state 'nexts'. + void UpdateStateSet(StateId nexts, set<Label> *paren_set, + vector< set<StateId> > *state_sets) const; + + const Fst<A> &fst_; + const vector<pair<Label, Label> > &parens_; // Paren ID -> Labels + bool close_; // Close/open paren info? + unordered_map<Label, Label> paren_id_map_; // Paren labels -> ID + ParenMultiMap paren_multimap_; // Paren reachability + ParenArcMultiMap paren_arc_multimap_; // Paren Arcs + vector<char> state_color_; // DFS state + mutable Collection<ssize_t, StateId> state_sets_; // Reachable states -> ID + StateSetMap set_map_; // ID -> Reachable states + DISALLOW_COPY_AND_ASSIGN(PdtParenReachable); +}; + +// DFS that gathers paren and state set information. +template <class A> +bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) { + if (s >= state_color_.size()) + state_color_.resize(s + 1, kDfsWhite); + + if (state_color_[s] == kDfsBlack) + return true; + + if (state_color_[s] == kDfsGrey) + return false; + + state_color_[s] = kDfsGrey; + + for (ArcIterator<Fst<A> > aiter(fst_, s); + !aiter.Done(); + aiter.Next()) { + const A &arc = aiter.Value(); + + typename unordered_map<Label, Label>::const_iterator pit + = paren_id_map_.find(arc.ilabel); + if (pit != paren_id_map_.end()) { // paren? + Label paren_id = pit->second; + if (arc.ilabel == parens_[paren_id].first) { // open paren + DFSearch(arc.nextstate, arc.nextstate); + for (SetIterator set_iter = FindStates(paren_id, arc.nextstate); + !set_iter.Done(); set_iter.Next()) { + for (ParenArcIterator paren_arc_iter = + FindParenArcs(paren_id, set_iter.Element()); + !paren_arc_iter.Done(); + paren_arc_iter.Next()) { + const A &cparc = paren_arc_iter.Value(); + DFSearch(cparc.nextstate, start); + } + } + } + } else { // non-paren + if(!DFSearch(arc.nextstate, start)) { + FSTERROR() << "PdtReachable: Underlying cyclicity not supported"; + return true; + } + } + } + ComputeStateSet(s); + state_color_[s] = kDfsBlack; + return true; +} + +// Unions state sets together gathered by the DFS. +template <class A> +void PdtParenReachable<A>::ComputeStateSet(StateId s) { + set<Label> paren_set; + vector< set<StateId> > state_sets(parens_.size()); + for (ArcIterator< Fst<A> > aiter(fst_, s); + !aiter.Done(); + aiter.Next()) { + const A &arc = aiter.Value(); + + typename unordered_map<Label, Label>::const_iterator pit + = paren_id_map_.find(arc.ilabel); + if (pit != paren_id_map_.end()) { // paren? + Label paren_id = pit->second; + if (arc.ilabel == parens_[paren_id].first) { // open paren + for (SetIterator set_iter = + FindStates(paren_id, arc.nextstate); + !set_iter.Done(); set_iter.Next()) { + for (ParenArcIterator paren_arc_iter = + FindParenArcs(paren_id, set_iter.Element()); + !paren_arc_iter.Done(); + paren_arc_iter.Next()) { + const A &cparc = paren_arc_iter.Value(); + UpdateStateSet(cparc.nextstate, &paren_set, &state_sets); + } + } + } else { // close paren + paren_set.insert(paren_id); + state_sets[paren_id].insert(s); + ParenState<A> paren_state(paren_id, s); + paren_arc_multimap_.insert(make_pair(paren_state, arc)); + } + } else { // non-paren + UpdateStateSet(arc.nextstate, &paren_set, &state_sets); + } + } + + vector<StateId> state_set; + for (typename set<Label>::iterator paren_iter = paren_set.begin(); + paren_iter != paren_set.end(); ++paren_iter) { + state_set.clear(); + Label paren_id = *paren_iter; + paren_multimap_.insert(make_pair(s, paren_id)); + for (typename set<StateId>::iterator state_iter + = state_sets[paren_id].begin(); + state_iter != state_sets[paren_id].end(); + ++state_iter) { + state_set.push_back(*state_iter); + } + ParenState<A> paren_state(paren_id, s); + set_map_[paren_state] = state_sets_.FindId(state_set); + } +} + +// Gather state set(s) from state 'nexts'. +template <class A> +void PdtParenReachable<A>::UpdateStateSet( + StateId nexts, set<Label> *paren_set, + vector< set<StateId> > *state_sets) const { + for(ParenIterator paren_iter = FindParens(nexts); + !paren_iter.Done(); paren_iter.Next()) { + Label paren_id = paren_iter.Value(); + paren_set->insert(paren_id); + for (SetIterator set_iter = FindStates(paren_id, nexts); + !set_iter.Done(); set_iter.Next()) { + (*state_sets)[paren_id].insert(set_iter.Element()); + } + } +} + + +// Store balancing parenthesis data for a PDT. Allows on-the-fly +// construction (e.g. in PdtShortestPath) unlike PdtParenReachable above. +template <class A> +class PdtBalanceData { + public: + typedef typename A::StateId StateId; + typedef typename A::Label Label; + + // Hash set for open parens + typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet; + + // Maps from open paren destination state to parenthesis ID. + typedef unordered_multimap<StateId, Label> OpenParenMap; + + // Maps from open paren state to source states of matching close parens + typedef unordered_multimap<ParenState<A>, StateId, + typename ParenState<A>::Hash> CloseParenMap; + + // Maps from open paren state to close source set ID + typedef unordered_map<ParenState<A>, ssize_t, + typename ParenState<A>::Hash> CloseSourceMap; + + typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator; + + PdtBalanceData() {} + + void Clear() { + open_paren_map_.clear(); + close_paren_map_.clear(); + } + + // Adds an open parenthesis with destination state 'open_dest'. + void OpenInsert(Label paren_id, StateId open_dest) { + ParenState<A> key(paren_id, open_dest); + if (!open_paren_set_.count(key)) { + open_paren_set_.insert(key); + open_paren_map_.insert(make_pair(open_dest, paren_id)); + } + } + + // Adds a matching closing parenthesis with source state + // 'close_source' that balances an open_parenthesis with destination + // state 'open_dest' if OpenInsert() previously called + // (o.w. CloseInsert() does nothing). + void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) { + ParenState<A> key(paren_id, open_dest); + if (open_paren_set_.count(key)) + close_paren_map_.insert(make_pair(key, close_source)); + } + + // Find close paren source states matching an open parenthesis. + // Methods that follow, iterate through those matching states. + // Should be called only after FinishInsert(open_dest). + SetIterator Find(Label paren_id, StateId open_dest) { + ParenState<A> close_key(paren_id, open_dest); + typename CloseSourceMap::const_iterator id_it = + close_source_map_.find(close_key); + if (id_it == close_source_map_.end()) { + return close_source_sets_.FindSet(-1); + } else { + return close_source_sets_.FindSet(id_it->second); + } + } + + // Call when all open and close parenthesis insertions wrt open + // parentheses entering 'open_dest' are finished. Must be called + // before Find(open_dest). Stores close paren source state sets + // efficiently. + void FinishInsert(StateId open_dest) { + vector<StateId> close_sources; + for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest); + oit != open_paren_map_.end() && oit->first == open_dest;) { + Label paren_id = oit->second; + close_sources.clear(); + ParenState<A> okey(paren_id, open_dest); + open_paren_set_.erase(open_paren_set_.find(okey)); + for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey); + cit != close_paren_map_.end() && cit->first == okey;) { + close_sources.push_back(cit->second); + close_paren_map_.erase(cit++); + } + sort(close_sources.begin(), close_sources.end()); + typename vector<StateId>::iterator unique_end = + unique(close_sources.begin(), close_sources.end()); + close_sources.resize(unique_end - close_sources.begin()); + + if (!close_sources.empty()) + close_source_map_[okey] = close_source_sets_.FindId(close_sources); + open_paren_map_.erase(oit++); + } + } + + // Return a new balance data object representing the reversed balance + // information. + PdtBalanceData<A> *Reverse(StateId num_states, + StateId num_split, + StateId state_id_shift) const; + + private: + OpenParenSet open_paren_set_; // open par. at dest? + + OpenParenMap open_paren_map_; // open parens per state + ParenState<A> open_dest_; // cur open dest. state + typename OpenParenMap::const_iterator open_iter_; // cur open parens/state + + CloseParenMap close_paren_map_; // close states/open + // paren and state + + CloseSourceMap close_source_map_; // paren, state to set ID + mutable Collection<ssize_t, StateId> close_source_sets_; +}; + +// Return a new balance data object representing the reversed balance +// information. +template <class A> +PdtBalanceData<A> *PdtBalanceData<A>::Reverse( + StateId num_states, + StateId num_split, + StateId state_id_shift) const { + PdtBalanceData<A> *bd = new PdtBalanceData<A>; + unordered_set<StateId> close_sources; + StateId split_size = num_states / num_split; + + for (StateId i = 0; i < num_states; i+= split_size) { + close_sources.clear(); + + for (typename CloseSourceMap::const_iterator + sit = close_source_map_.begin(); + sit != close_source_map_.end(); + ++sit) { + ParenState<A> okey = sit->first; + StateId open_dest = okey.state_id; + Label paren_id = okey.paren_id; + for (SetIterator set_iter = close_source_sets_.FindSet(sit->second); + !set_iter.Done(); set_iter.Next()) { + StateId close_source = set_iter.Element(); + if ((close_source < i) || (close_source >= i + split_size)) + continue; + close_sources.insert(close_source + state_id_shift); + bd->OpenInsert(paren_id, close_source + state_id_shift); + bd->CloseInsert(paren_id, close_source + state_id_shift, + open_dest + state_id_shift); + } + } + + for (typename unordered_set<StateId>::const_iterator it + = close_sources.begin(); + it != close_sources.end(); + ++it) { + bd->FinishInsert(*it); + } + + } + return bd; +} + + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_PAREN_H_ diff --git a/src/include/fst/extensions/pdt/pdt.h b/src/include/fst/extensions/pdt/pdt.h new file mode 100644 index 0000000..171541f --- /dev/null +++ b/src/include/fst/extensions/pdt/pdt.h @@ -0,0 +1,212 @@ +// pdt.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Common classes for PDT expansion/traversal. + +#ifndef FST_EXTENSIONS_PDT_PDT_H__ +#define FST_EXTENSIONS_PDT_PDT_H__ + +#include <unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <map> +#include <set> + +#include <fst/state-table.h> +#include <fst/fst.h> + +namespace fst { + +// Provides bijection between parenthesis stacks and signed integral +// stack IDs. Each stack ID is unique to each distinct stack. The +// open-close parenthesis label pairs are passed in 'parens'. +template <typename K, typename L> +class PdtStack { + public: + typedef K StackId; + typedef L Label; + + // The stacks are stored in a tree. The nodes are stored in vector + // 'nodes_'. Each node represents the top of some stack and is + // ID'ed by its position in the vector. Its parent node represents + // the stack with the top 'popped' and its children are stored in + // 'child_map_' accessed by stack_id and label. The paren_id is + // the position in 'parens' of the parenthesis for that node. + struct StackNode { + StackId parent_id; + size_t paren_id; + + StackNode(StackId p, size_t i) : parent_id(p), paren_id(i) {} + }; + + PdtStack(const vector<pair<Label, Label> > &parens) + : parens_(parens), min_paren_(kNoLabel), max_paren_(kNoLabel) { + for (size_t i = 0; i < parens.size(); ++i) { + const pair<Label, Label> &p = parens[i]; + paren_map_[p.first] = i; + paren_map_[p.second] = i; + + if (min_paren_ == kNoLabel || p.first < min_paren_) + min_paren_ = p.first; + if (p.second < min_paren_) + min_paren_ = p.second; + + if (max_paren_ == kNoLabel || p.first > max_paren_) + max_paren_ = p.first; + if (p.second > max_paren_) + max_paren_ = p.second; + } + nodes_.push_back(StackNode(-1, -1)); // Tree root. + } + + // Returns stack ID given the current stack ID (0 if empty) and + // label read. 'Pushes' onto a stack if the label is an open + // parenthesis, returning the new stack ID. 'Pops' the stack if the + // label is a close parenthesis that matches the top of the stack, + // returning the parent stack ID. Returns -1 if label is an + // unmatched close parenthesis. Otherwise, returns the current stack + // ID. + StackId Find(StackId stack_id, Label label) { + if (min_paren_ == kNoLabel || label < min_paren_ || label > max_paren_) + return stack_id; // Non-paren. + + typename unordered_map<Label, size_t>::const_iterator pit + = paren_map_.find(label); + if (pit == paren_map_.end()) // Non-paren. + return stack_id; + ssize_t paren_id = pit->second; + + if (label == parens_[paren_id].first) { // Open paren. + StackId &child_id = child_map_[make_pair(stack_id, label)]; + if (child_id == 0) { // Child not found, push label. + child_id = nodes_.size(); + nodes_.push_back(StackNode(stack_id, paren_id)); + } + return child_id; + } + + const StackNode &node = nodes_[stack_id]; + if (paren_id == node.paren_id) // Matching close paren. + return node.parent_id; + + return -1; // Non-matching close paren. + } + + // Returns the stack ID obtained by "popping" the label at the top + // of the current stack ID. + StackId Pop(StackId stack_id) const { + return nodes_[stack_id].parent_id; + } + + // Returns the paren ID at the top of the stack for 'stack_id' + ssize_t Top(StackId stack_id) const { + return nodes_[stack_id].paren_id; + } + + ssize_t ParenId(Label label) const { + typename unordered_map<Label, size_t>::const_iterator pit + = paren_map_.find(label); + if (pit == paren_map_.end()) // Non-paren. + return -1; + return pit->second; + } + + private: + struct ChildHash { + size_t operator()(const pair<StackId, Label> &p) const { + return p.first + p.second * kPrime; + } + }; + + static const size_t kPrime; + + vector<pair<Label, Label> > parens_; + vector<StackNode> nodes_; + unordered_map<Label, size_t> paren_map_; + unordered_map<pair<StackId, Label>, + StackId, ChildHash> child_map_; // Child of stack node wrt label + Label min_paren_; // For faster paren. check + Label max_paren_; // For faster paren. check +}; + +template <typename T, typename L> +const size_t PdtStack<T, L>::kPrime = 7853; + + +// State tuple for PDT expansion +template <typename S, typename K> +struct PdtStateTuple { + typedef S StateId; + typedef K StackId; + + StateId state_id; + StackId stack_id; + + PdtStateTuple() + : state_id(kNoStateId), stack_id(-1) {} + + PdtStateTuple(StateId fs, StackId ss) + : state_id(fs), stack_id(ss) {} +}; + +// Equality of PDT state tuples. +template <typename S, typename K> +inline bool operator==(const PdtStateTuple<S, K>& x, + const PdtStateTuple<S, K>& y) { + if (&x == &y) + return true; + return x.state_id == y.state_id && x.stack_id == y.stack_id; +} + + +// Hash function object for PDT state tuples +template <class T> +class PdtStateHash { + public: + size_t operator()(const T &tuple) const { + return tuple.state_id + tuple.stack_id * kPrime; + } + + private: + static const size_t kPrime; +}; + +template <typename T> +const size_t PdtStateHash<T>::kPrime = 7853; + + +// Tuple to PDT state bijection. +template <class S, class K> +class PdtStateTable + : public CompactHashStateTable<PdtStateTuple<S, K>, + PdtStateHash<PdtStateTuple<S, K> > > { + public: + typedef S StateId; + typedef K StackId; + + PdtStateTable() {} + + PdtStateTable(const PdtStateTable<S, K> &table) {} + + private: + void operator=(const PdtStateTable<S, K> &table); // disallow +}; + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_PDT_H__ diff --git a/src/include/fst/extensions/pdt/pdtlib.h b/src/include/fst/extensions/pdt/pdtlib.h new file mode 100644 index 0000000..71c8123 --- /dev/null +++ b/src/include/fst/extensions/pdt/pdtlib.h @@ -0,0 +1,30 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: jpr@google.com (Jake Ratkiewicz) + +// This is an experimental push-down transducer(PDT) library. A PDT is +// encoded as an FST, where some transitions are labeled with open or close +// parentheses. To be interpreted as a PDT, the parentheses must balance on a +// path. + +#ifndef FST_EXTENSIONS_PDT_PDTLIB_H_ +#define FST_EXTENSIONS_PDT_PDTLIB_H_ + +#include <fst/extensions/pdt/pdt.h> +#include <fst/extensions/pdt/compose.h> +#include <fst/extensions/pdt/expand.h> +#include <fst/extensions/pdt/replace.h> + +#endif // FST_EXTENSIONS_PDT_PDTLIB_H_ diff --git a/src/include/fst/extensions/pdt/pdtscript.h b/src/include/fst/extensions/pdt/pdtscript.h new file mode 100644 index 0000000..c2a1cf4 --- /dev/null +++ b/src/include/fst/extensions/pdt/pdtscript.h @@ -0,0 +1,284 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: jpr@google.com (Jake Ratkiewicz) +// Convenience file for including all PDT operations at once, and/or +// registering them for new arc types. + +#ifndef FST_EXTENSIONS_PDT_PDTSCRIPT_H_ +#define FST_EXTENSIONS_PDT_PDTSCRIPT_H_ + +#include <utility> +using std::pair; using std::make_pair; +#include <vector> +using std::vector; + +#include <fst/compose.h> // for ComposeOptions +#include <fst/util.h> + +#include <fst/script/fst-class.h> +#include <fst/script/arg-packs.h> +#include <fst/script/shortest-path.h> + +#include <fst/extensions/pdt/compose.h> +#include <fst/extensions/pdt/expand.h> +#include <fst/extensions/pdt/info.h> +#include <fst/extensions/pdt/replace.h> +#include <fst/extensions/pdt/reverse.h> +#include <fst/extensions/pdt/shortest-path.h> + + +namespace fst { +namespace script { + +// PDT COMPOSE + +typedef args::Package<const FstClass &, + const FstClass &, + const vector<pair<int64, int64> >&, + MutableFstClass *, + const ComposeOptions &, + bool> PdtComposeArgs; + +template<class Arc> +void PdtCompose(PdtComposeArgs *args) { + const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>()); + const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg4->GetMutableFst<Arc>(); + + vector<pair<typename Arc::Label, typename Arc::Label> > parens( + args->arg3.size()); + + for (size_t i = 0; i < parens.size(); ++i) { + parens[i].first = args->arg3[i].first; + parens[i].second = args->arg3[i].second; + } + + if (args->arg6) { + Compose(ifst1, parens, ifst2, ofst, args->arg5); + } else { + Compose(ifst1, ifst2, parens, ofst, args->arg5); + } +} + +void PdtCompose(const FstClass & ifst1, + const FstClass & ifst2, + const vector<pair<int64, int64> > &parens, + MutableFstClass *ofst, + const ComposeOptions &copts, + bool left_pdt); + +// PDT EXPAND + +struct PdtExpandOptions { + bool connect; + bool keep_parentheses; + WeightClass weight_threshold; + + PdtExpandOptions(bool c = true, bool k = false, + WeightClass w = WeightClass::Zero()) + : connect(c), keep_parentheses(k), weight_threshold(w) {} +}; + +typedef args::Package<const FstClass &, + const vector<pair<int64, int64> >&, + MutableFstClass *, PdtExpandOptions> PdtExpandArgs; + +template<class Arc> +void PdtExpand(PdtExpandArgs *args) { + const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); + + vector<pair<typename Arc::Label, typename Arc::Label> > parens( + args->arg2.size()); + for (size_t i = 0; i < parens.size(); ++i) { + parens[i].first = args->arg2[i].first; + parens[i].second = args->arg2[i].second; + } + Expand(fst, parens, ofst, + ExpandOptions<Arc>( + args->arg4.connect, args->arg4.keep_parentheses, + *(args->arg4.weight_threshold.GetWeight<typename Arc::Weight>()))); +} + +void PdtExpand(const FstClass &ifst, + const vector<pair<int64, int64> > &parens, + MutableFstClass *ofst, const PdtExpandOptions &opts); + +void PdtExpand(const FstClass &ifst, + const vector<pair<int64, int64> > &parens, + MutableFstClass *ofst, bool connect); + +// PDT REPLACE + +typedef args::Package<const vector<pair<int64, const FstClass*> > &, + MutableFstClass *, + vector<pair<int64, int64> > *, + const int64 &> PdtReplaceArgs; +template<class Arc> +void PdtReplace(PdtReplaceArgs *args) { + vector<pair<typename Arc::Label, const Fst<Arc> *> > tuples( + args->arg1.size()); + for (size_t i = 0; i < tuples.size(); ++i) { + tuples[i].first = args->arg1[i].first; + tuples[i].second = (args->arg1[i].second)->GetFst<Arc>(); + } + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + vector<pair<typename Arc::Label, typename Arc::Label> > parens( + args->arg3->size()); + + for (size_t i = 0; i < parens.size(); ++i) { + parens[i].first = args->arg3->at(i).first; + parens[i].second = args->arg3->at(i).second; + } + Replace(tuples, ofst, &parens, args->arg4); + + // now copy parens back + args->arg3->resize(parens.size()); + for (size_t i = 0; i < parens.size(); ++i) { + (*args->arg3)[i].first = parens[i].first; + (*args->arg3)[i].second = parens[i].second; + } +} + +void PdtReplace(const vector<pair<int64, const FstClass*> > &fst_tuples, + MutableFstClass *ofst, + vector<pair<int64, int64> > *parens, + const int64 &root); + +// PDT REVERSE + +typedef args::Package<const FstClass &, + const vector<pair<int64, int64> >&, + MutableFstClass *> PdtReverseArgs; + +template<class Arc> +void PdtReverse(PdtReverseArgs *args) { + const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); + + vector<pair<typename Arc::Label, typename Arc::Label> > parens( + args->arg2.size()); + for (size_t i = 0; i < parens.size(); ++i) { + parens[i].first = args->arg2[i].first; + parens[i].second = args->arg2[i].second; + } + Reverse(fst, parens, ofst); +} + +void PdtReverse(const FstClass &ifst, + const vector<pair<int64, int64> > &parens, + MutableFstClass *ofst); + + +// PDT SHORTESTPATH + +struct PdtShortestPathOptions { + QueueType queue_type; + bool keep_parentheses; + bool path_gc; + + PdtShortestPathOptions(QueueType qt = FIFO_QUEUE, + bool kp = false, bool gc = true) + : queue_type(qt), keep_parentheses(kp), path_gc(gc) {} +}; + +typedef args::Package<const FstClass &, + const vector<pair<int64, int64> >&, + MutableFstClass *, + const PdtShortestPathOptions &> PdtShortestPathArgs; + +template<class Arc> +void PdtShortestPath(PdtShortestPathArgs *args) { + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); + const PdtShortestPathOptions &opts = args->arg4; + + + vector<pair<Label, Label> > parens(args->arg2.size()); + for (size_t i = 0; i < parens.size(); ++i) { + parens[i].first = args->arg2[i].first; + parens[i].second = args->arg2[i].second; + } + + switch (opts.queue_type) { + default: + FSTERROR() << "Unknown queue type: " << opts.queue_type; + case FIFO_QUEUE: { + typedef FifoQueue<StateId> Queue; + fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses, + opts.path_gc); + ShortestPath(fst, parens, ofst, spopts); + return; + } + case LIFO_QUEUE: { + typedef LifoQueue<StateId> Queue; + fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses, + opts.path_gc); + ShortestPath(fst, parens, ofst, spopts); + return; + } + case STATE_ORDER_QUEUE: { + typedef StateOrderQueue<StateId> Queue; + fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses, + opts.path_gc); + ShortestPath(fst, parens, ofst, spopts); + return; + } + } +} + +void PdtShortestPath(const FstClass &ifst, + const vector<pair<int64, int64> > &parens, + MutableFstClass *ofst, + const PdtShortestPathOptions &opts = + PdtShortestPathOptions()); + +// PRINT INFO + +typedef args::Package<const FstClass &, + const vector<pair<int64, int64> > &> PrintPdtInfoArgs; + +template<class Arc> +void PrintPdtInfo(PrintPdtInfoArgs *args) { + const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); + vector<pair<typename Arc::Label, typename Arc::Label> > parens( + args->arg2.size()); + for (size_t i = 0; i < parens.size(); ++i) { + parens[i].first = args->arg2[i].first; + parens[i].second = args->arg2[i].second; + } + PdtInfo<Arc> pdtinfo(fst, parens); + PrintPdtInfo(pdtinfo); +} + +void PrintPdtInfo(const FstClass &ifst, + const vector<pair<int64, int64> > &parens); + +} // namespace script +} // namespace fst + + +#define REGISTER_FST_PDT_OPERATIONS(ArcType) \ + REGISTER_FST_OPERATION(PdtCompose, ArcType, PdtComposeArgs); \ + REGISTER_FST_OPERATION(PdtExpand, ArcType, PdtExpandArgs); \ + REGISTER_FST_OPERATION(PdtReplace, ArcType, PdtReplaceArgs); \ + REGISTER_FST_OPERATION(PdtReverse, ArcType, PdtReverseArgs); \ + REGISTER_FST_OPERATION(PdtShortestPath, ArcType, PdtShortestPathArgs); \ + REGISTER_FST_OPERATION(PrintPdtInfo, ArcType, PrintPdtInfoArgs) +#endif // FST_EXTENSIONS_PDT_PDTSCRIPT_H_ diff --git a/src/include/fst/extensions/pdt/replace.h b/src/include/fst/extensions/pdt/replace.h new file mode 100644 index 0000000..a85d0fe --- /dev/null +++ b/src/include/fst/extensions/pdt/replace.h @@ -0,0 +1,192 @@ +// replace.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Recursively replace Fst arcs with other Fst(s) returning a PDT. + +#ifndef FST_EXTENSIONS_PDT_REPLACE_H__ +#define FST_EXTENSIONS_PDT_REPLACE_H__ + +#include <fst/replace.h> + +namespace fst { + +// Hash to paren IDs +template <typename S> +struct ReplaceParenHash { + size_t operator()(const pair<size_t, S> &p) const { + return p.first + p.second * kPrime; + } + private: + static const size_t kPrime = 7853; +}; + +template <typename S> const size_t ReplaceParenHash<S>::kPrime; + +// Builds a pushdown transducer (PDT) from an RTN specification +// identical to that in fst/lib/replace.h. The result is a PDT +// encoded as the FST 'ofst' where some transitions are labeled with +// open or close parentheses. To be interpreted as a PDT, the parens +// must balance on a path (see PdtExpand()). The open/close +// parenthesis label pairs are returned in 'parens'. +template <class Arc> +void Replace(const vector<pair<typename Arc::Label, + const Fst<Arc>* > >& ifst_array, + MutableFst<Arc> *ofst, + vector<pair<typename Arc::Label, + typename Arc::Label> > *parens, + typename Arc::Label root) { + typedef typename Arc::Label Label; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + ofst->DeleteStates(); + parens->clear(); + + unordered_map<Label, size_t> label2id; + for (size_t i = 0; i < ifst_array.size(); ++i) + label2id[ifst_array[i].first] = i; + + Label max_label = kNoLabel; + + deque<size_t> non_term_queue; // Queue of non-terminals to replace + unordered_set<Label> non_term_set; // Set of non-terminals to replace + non_term_queue.push_back(root); + non_term_set.insert(root); + + // PDT state corr. to ith replace FST start state. + vector<StateId> fst_start(ifst_array.size(), kNoLabel); + // PDT state, weight pairs corr. to ith replace FST final state & weights. + vector< vector<pair<StateId, Weight> > > fst_final(ifst_array.size()); + + // Builds single Fst combining all referenced input Fsts. Leaves in the + // non-termnals for now. Tabulate the PDT states that correspond to + // the start and final states of the input Fsts. + for (StateId soff = 0; !non_term_queue.empty(); soff = ofst->NumStates()) { + Label label = non_term_queue.front(); + non_term_queue.pop_front(); + size_t fst_id = label2id[label]; + + const Fst<Arc> *ifst = ifst_array[fst_id].second; + for (StateIterator< Fst<Arc> > siter(*ifst); + !siter.Done(); siter.Next()) { + StateId is = siter.Value(); + StateId os = ofst->AddState(); + if (is == ifst->Start()) { + fst_start[fst_id] = os; + if (label == root) + ofst->SetStart(os); + } + if (ifst->Final(is) != Weight::Zero()) { + if (label == root) + ofst->SetFinal(os, ifst->Final(is)); + fst_final[fst_id].push_back(make_pair(os, ifst->Final(is))); + } + for (ArcIterator< Fst<Arc> > aiter(*ifst, is); + !aiter.Done(); aiter.Next()) { + Arc arc = aiter.Value(); + if (max_label == kNoLabel || arc.olabel > max_label) + max_label = arc.olabel; + typename unordered_map<Label, size_t>::const_iterator it = + label2id.find(arc.olabel); + if (it != label2id.end()) { + size_t nfst_id = it->second; + if (ifst_array[nfst_id].second->Start() == -1) + continue; + if (non_term_set.count(arc.olabel) == 0) { + non_term_queue.push_back(arc.olabel); + non_term_set.insert(arc.olabel); + } + } + arc.nextstate += soff; + ofst->AddArc(os, arc); + } + } + } + + // Changes each non-terminal transition to an open parenthesis + // transition redirected to the PDT state that corresponds to the + // start state of the input FST for the non-terminal. Adds close parenthesis + // transitions from the PDT states corr. to the final states of the + // input FST for the non-terminal to the former destination state of the + // non-terminal transition. + + typedef MutableArcIterator< MutableFst<Arc> > MIter; + typedef unordered_map<pair<size_t, StateId >, size_t, + ReplaceParenHash<StateId> > ParenMap; + + // Parenthesis pair ID per fst, state pair. + ParenMap paren_map; + // # of parenthesis pairs per fst. + vector<size_t> nparens(ifst_array.size(), 0); + // Initial open parenthesis label + Label first_paren = max_label + 1; + + for (StateIterator< Fst<Arc> > siter(*ofst); + !siter.Done(); siter.Next()) { + StateId os = siter.Value(); + MIter *aiter = new MIter(ofst, os); + for (size_t n = 0; !aiter->Done(); aiter->Next(), ++n) { + Arc arc = aiter->Value(); + typename unordered_map<Label, size_t>::const_iterator lit = + label2id.find(arc.olabel); + if (lit != label2id.end()) { + size_t nfst_id = lit->second; + + // Get parentheses. Ensures distinct parenthesis pair per + // non-terminal and destination state but otherwise reuses them. + Label open_paren = kNoLabel, close_paren = kNoLabel; + pair<size_t, StateId> paren_key(nfst_id, arc.nextstate); + typename ParenMap::const_iterator pit = paren_map.find(paren_key); + if (pit != paren_map.end()) { + size_t paren_id = pit->second; + open_paren = (*parens)[paren_id].first; + close_paren = (*parens)[paren_id].second; + } else { + size_t paren_id = nparens[nfst_id]++; + open_paren = first_paren + 2 * paren_id; + close_paren = open_paren + 1; + paren_map[paren_key] = paren_id; + if (paren_id >= parens->size()) + parens->push_back(make_pair(open_paren, close_paren)); + } + + // Sets open parenthesis. + Arc sarc(open_paren, open_paren, arc.weight, fst_start[nfst_id]); + aiter->SetValue(sarc); + + // Adds close parentheses. + for (size_t i = 0; i < fst_final[nfst_id].size(); ++i) { + pair<StateId, Weight> &p = fst_final[nfst_id][i]; + Arc farc(close_paren, close_paren, p.second, arc.nextstate); + + ofst->AddArc(p.first, farc); + if (os == p.first) { // Invalidated iterator + delete aiter; + aiter = new MIter(ofst, os); + aiter->Seek(n); + } + } + } + } + delete aiter; + } +} + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_REPLACE_H__ diff --git a/src/include/fst/extensions/pdt/reverse.h b/src/include/fst/extensions/pdt/reverse.h new file mode 100644 index 0000000..b20e1c5 --- /dev/null +++ b/src/include/fst/extensions/pdt/reverse.h @@ -0,0 +1,58 @@ +// reverse.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Expand a PDT to an FST. + +#ifndef FST_EXTENSIONS_PDT_REVERSE_H__ +#define FST_EXTENSIONS_PDT_REVERSE_H__ + +#include <unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <vector> +using std::vector; + +#include <fst/mutable-fst.h> +#include <fst/relabel.h> +#include <fst/reverse.h> + +namespace fst { + +// Reverses a pushdown transducer (PDT) encoded as an FST. +template<class Arc, class RevArc> +void Reverse(const Fst<Arc> &ifst, + const vector<pair<typename Arc::Label, + typename Arc::Label> > &parens, + MutableFst<RevArc> *ofst) { + typedef typename Arc::Label Label; + + // Reverses FST + Reverse(ifst, ofst); + + // Exchanges open and close parenthesis pairs + vector<pair<Label, Label> > relabel_pairs; + for (size_t i = 0; i < parens.size(); ++i) { + relabel_pairs.push_back(make_pair(parens[i].first, parens[i].second)); + relabel_pairs.push_back(make_pair(parens[i].second, parens[i].first)); + } + Relabel(ofst, relabel_pairs, relabel_pairs); +} + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_REVERSE_H__ diff --git a/src/include/fst/extensions/pdt/shortest-path.h b/src/include/fst/extensions/pdt/shortest-path.h new file mode 100644 index 0000000..e90471b --- /dev/null +++ b/src/include/fst/extensions/pdt/shortest-path.h @@ -0,0 +1,790 @@ +// shortest-path.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Functions to find shortest paths in a PDT. + +#ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H__ +#define FST_EXTENSIONS_PDT_SHORTEST_PATH_H__ + +#include <fst/shortest-path.h> +#include <fst/extensions/pdt/paren.h> +#include <fst/extensions/pdt/pdt.h> + +#include <unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <tr1/unordered_set> +using std::tr1::unordered_set; +using std::tr1::unordered_multiset; +#include <stack> +#include <vector> +using std::vector; + +namespace fst { + +template <class Arc, class Queue> +struct PdtShortestPathOptions { + bool keep_parentheses; + bool path_gc; + + PdtShortestPathOptions(bool kp = false, bool gc = true) + : keep_parentheses(kp), path_gc(gc) {} +}; + + +// Class to store PDT shortest path results. Stores shortest path +// tree info 'Distance()', Parent(), and ArcParent() information keyed +// on two types: +// (1) By SearchState: This is a usual node in a shortest path tree but: +// (a) is w.r.t a PDT search state - a pair of a PDT state and +// a 'start' state, which is either the PDT start state or +// the destination state of an open parenthesis. +// (b) the Distance() is from this 'start' state to the search state. +// (c) Parent().state is kNoLabel for the 'start' state. +// +// (2) By ParenSpec: This connects shortest path trees depending on the +// the parenthesis taken. Given the parenthesis spec: +// (a) the Distance() is from the Parent() 'start' state to the +// parenthesis destination state. +// (b) the ArcParent() is the parenthesis arc. +template <class Arc> +class PdtShortestPathData { + public: + static const uint8 kFinal; + + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef typename Arc::Label Label; + + struct SearchState { + SearchState() : state(kNoStateId), start(kNoStateId) {} + + SearchState(StateId s, StateId t) : state(s), start(t) {} + + bool operator==(const SearchState &s) const { + if (&s == this) + return true; + return s.state == this->state && s.start == this->start; + } + + StateId state; // PDT state + StateId start; // PDT paren 'source' state + }; + + + // Specifies paren id, source and dest 'start' states of a paren. + // These are the 'start' states of the respective sub-graphs. + struct ParenSpec { + ParenSpec() + : paren_id(kNoLabel), src_start(kNoStateId), dest_start(kNoStateId) {} + + ParenSpec(Label id, StateId s, StateId d) + : paren_id(id), src_start(s), dest_start(d) {} + + Label paren_id; // Id of parenthesis + StateId src_start; // sub-graph 'start' state for paren source. + StateId dest_start; // sub-graph 'start' state for paren dest. + + bool operator==(const ParenSpec &x) const { + if (&x == this) + return true; + return x.paren_id == this->paren_id && + x.src_start == this->src_start && + x.dest_start == this->dest_start; + } + }; + + struct SearchData { + SearchData() : distance(Weight::Zero()), + parent(kNoStateId, kNoStateId), + paren_id(kNoLabel), + flags(0) {} + + Weight distance; // Distance to this state from PDT 'start' state + SearchState parent; // Parent state in shortest path tree + int16 paren_id; // If parent arc has paren, paren ID, o.w. kNoLabel + uint8 flags; // First byte reserved for PdtShortestPathData use + }; + + PdtShortestPathData(bool gc) + : state_(kNoStateId, kNoStateId), + paren_(kNoLabel, kNoStateId, kNoStateId), + gc_(gc), + nstates_(0), + ngc_(0), + finished_(false) {} + + ~PdtShortestPathData() { + VLOG(1) << "opm size: " << paren_map_.size(); + VLOG(1) << "# of search states: " << nstates_; + if (gc_) + VLOG(1) << "# of GC'd search states: " << ngc_; + } + + void Clear() { + search_map_.clear(); + search_multimap_.clear(); + paren_map_.clear(); + state_ = SearchState(kNoStateId, kNoStateId); + nstates_ = 0; + ngc_ = 0; + } + + Weight Distance(SearchState s) const { + SearchData *data = GetSearchData(s); + return data->distance; + } + + Weight Distance(const ParenSpec &paren) const { + SearchData *data = GetSearchData(paren); + return data->distance; + } + + SearchState Parent(SearchState s) const { + SearchData *data = GetSearchData(s); + return data->parent; + } + + SearchState Parent(const ParenSpec &paren) const { + SearchData *data = GetSearchData(paren); + return data->parent; + } + + Label ParenId(SearchState s) const { + SearchData *data = GetSearchData(s); + return data->paren_id; + } + + uint8 Flags(SearchState s) const { + SearchData *data = GetSearchData(s); + return data->flags; + } + + void SetDistance(SearchState s, Weight w) { + SearchData *data = GetSearchData(s); + data->distance = w; + } + + void SetDistance(const ParenSpec &paren, Weight w) { + SearchData *data = GetSearchData(paren); + data->distance = w; + } + + void SetParent(SearchState s, SearchState p) { + SearchData *data = GetSearchData(s); + data->parent = p; + } + + void SetParent(const ParenSpec &paren, SearchState p) { + SearchData *data = GetSearchData(paren); + data->parent = p; + } + + void SetParenId(SearchState s, Label p) { + if (p >= 32768) + FSTERROR() << "PdtShortestPathData: Paren ID does not fits in an int16"; + SearchData *data = GetSearchData(s); + data->paren_id = p; + } + + void SetFlags(SearchState s, uint8 f, uint8 mask) { + SearchData *data = GetSearchData(s); + data->flags &= ~mask; + data->flags |= f & mask; + } + + void GC(StateId s); + + void Finish() { finished_ = true; } + + private: + static const Arc kNoArc; + static const size_t kPrime0; + static const size_t kPrime1; + static const uint8 kInited; + static const uint8 kMarked; + + // Hash for search state + struct SearchStateHash { + size_t operator()(const SearchState &s) const { + return s.state + s.start * kPrime0; + } + }; + + // Hash for paren map + struct ParenHash { + size_t operator()(const ParenSpec &paren) const { + return paren.paren_id + paren.src_start * kPrime0 + + paren.dest_start * kPrime1; + } + }; + + typedef unordered_map<SearchState, SearchData, SearchStateHash> SearchMap; + + typedef unordered_multimap<StateId, StateId> SearchMultimap; + + // Hash map from paren spec to open paren data + typedef unordered_map<ParenSpec, SearchData, ParenHash> ParenMap; + + SearchData *GetSearchData(SearchState s) const { + if (s == state_) + return state_data_; + if (finished_) { + typename SearchMap::iterator it = search_map_.find(s); + if (it == search_map_.end()) + return &null_search_data_; + state_ = s; + return state_data_ = &(it->second); + } else { + state_ = s; + state_data_ = &search_map_[s]; + if (!(state_data_->flags & kInited)) { + ++nstates_; + if (gc_) + search_multimap_.insert(make_pair(s.start, s.state)); + state_data_->flags = kInited; + } + return state_data_; + } + } + + SearchData *GetSearchData(ParenSpec paren) const { + if (paren == paren_) + return paren_data_; + if (finished_) { + typename ParenMap::iterator it = paren_map_.find(paren); + if (it == paren_map_.end()) + return &null_search_data_; + paren_ = paren; + return state_data_ = &(it->second); + } else { + paren_ = paren; + return paren_data_ = &paren_map_[paren]; + } + } + + mutable SearchMap search_map_; // Maps from search state to data + mutable SearchMultimap search_multimap_; // Maps from 'start' to subgraph + mutable ParenMap paren_map_; // Maps paren spec to search data + mutable SearchState state_; // Last state accessed + mutable SearchData *state_data_; // Last state data accessed + mutable ParenSpec paren_; // Last paren spec accessed + mutable SearchData *paren_data_; // Last paren data accessed + bool gc_; // Allow GC? + mutable size_t nstates_; // Total number of search states + size_t ngc_; // Number of GC'd search states + mutable SearchData null_search_data_; // Null search data + bool finished_; // Read-only access when true + + DISALLOW_COPY_AND_ASSIGN(PdtShortestPathData); +}; + +// Deletes inaccessible search data from a given 'start' (open paren dest) +// state. Assumes 'final' (close paren source or PDT final) states have +// been flagged 'kFinal'. +template<class Arc> +void PdtShortestPathData<Arc>::GC(StateId start) { + if (!gc_) + return; + vector<StateId> final; + for (typename SearchMultimap::iterator mmit = search_multimap_.find(start); + mmit != search_multimap_.end() && mmit->first == start; + ++mmit) { + SearchState s(mmit->second, start); + const SearchData &data = search_map_[s]; + if (data.flags & kFinal) + final.push_back(s.state); + } + + // Mark phase + for (size_t i = 0; i < final.size(); ++i) { + SearchState s(final[i], start); + while (s.state != kNoLabel) { + SearchData *sdata = &search_map_[s]; + if (sdata->flags & kMarked) + break; + sdata->flags |= kMarked; + SearchState p = sdata->parent; + if (p.start != start && p.start != kNoLabel) { // entering sub-subgraph + ParenSpec paren(sdata->paren_id, s.start, p.start); + SearchData *pdata = &paren_map_[paren]; + s = pdata->parent; + } else { + s = p; + } + } + } + + // Sweep phase + typename SearchMultimap::iterator mmit = search_multimap_.find(start); + while (mmit != search_multimap_.end() && mmit->first == start) { + SearchState s(mmit->second, start); + typename SearchMap::iterator mit = search_map_.find(s); + const SearchData &data = mit->second; + if (!(data.flags & kMarked)) { + search_map_.erase(mit); + ++ngc_; + } + search_multimap_.erase(mmit++); + } +} + +template<class Arc> const Arc PdtShortestPathData<Arc>::kNoArc + = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId); + +template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime0 = 7853; + +template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime1 = 7867; + +template<class Arc> const uint8 PdtShortestPathData<Arc>::kInited = 0x01; + +template<class Arc> const uint8 PdtShortestPathData<Arc>::kFinal = 0x02; + +template<class Arc> const uint8 PdtShortestPathData<Arc>::kMarked = 0x04; + + +// This computes the single source shortest (balanced) path (SSSP) +// through a weighted PDT that has a bounded stack (i.e. is expandable +// as an FST). It is a generalization of the classic SSSP graph +// algorithm that removes a state s from a queue (defined by a +// user-provided queue type) and relaxes the destination states of +// transitions leaving s. In this PDT version, states that have +// entering open parentheses are treated as source states for a +// sub-graph SSSP problem with the shortest path up to the open +// parenthesis being first saved. When a close parenthesis is then +// encountered any balancing open parenthesis is examined for this +// saved information and multiplied back. In this way, each sub-graph +// is entered only once rather than repeatedly. If every state in the +// input PDT has the property that there is a unique 'start' state for +// it with entering open parentheses, then this algorithm is quite +// straight-forward. In general, this will not be the case, so the +// algorithm (implicitly) creates a new graph where each state is a +// pair of an original state and a possible parenthesis 'start' state +// for that state. +template<class Arc, class Queue> +class PdtShortestPath { + public: + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef typename Arc::Label Label; + + typedef PdtShortestPathData<Arc> SpData; + typedef typename SpData::SearchState SearchState; + typedef typename SpData::ParenSpec ParenSpec; + + typedef typename PdtParenReachable<Arc>::SetIterator StateSetIterator; + typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator; + + PdtShortestPath(const Fst<Arc> &ifst, + const vector<pair<Label, Label> > &parens, + const PdtShortestPathOptions<Arc, Queue> &opts) + : kFinal(SpData::kFinal), + ifst_(ifst.Copy()), + parens_(parens), + keep_parens_(opts.keep_parentheses), + start_(ifst.Start()), + sp_data_(opts.path_gc), + error_(false) { + + if ((Weight::Properties() & (kPath | kRightSemiring)) + != (kPath | kRightSemiring)) { + FSTERROR() << "SingleShortestPath: Weight needs to have the path" + << " property and be right distributive: " << Weight::Type(); + error_ = true; + } + + for (Label i = 0; i < parens.size(); ++i) { + const pair<Label, Label> &p = parens[i]; + paren_id_map_[p.first] = i; + paren_id_map_[p.second] = i; + } + }; + + ~PdtShortestPath() { + VLOG(1) << "# of input states: " << CountStates(*ifst_); + VLOG(1) << "# of enqueued: " << nenqueued_; + VLOG(1) << "cpmm size: " << close_paren_multimap_.size(); + delete ifst_; + } + + void ShortestPath(MutableFst<Arc> *ofst) { + Init(ofst); + GetDistance(start_); + GetPath(); + sp_data_.Finish(); + if (error_) ofst->SetProperties(kError, kError); + } + + const PdtShortestPathData<Arc> &GetShortestPathData() const { + return sp_data_; + } + + PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; } + + private: + static const Arc kNoArc; + static const uint8 kEnqueued; + static const uint8 kExpanded; + const uint8 kFinal; + + public: + // Hash multimap from close paren label to an paren arc. + typedef unordered_multimap<ParenState<Arc>, Arc, + typename ParenState<Arc>::Hash> CloseParenMultimap; + + const CloseParenMultimap &GetCloseParenMultimap() const { + return close_paren_multimap_; + } + + private: + void Init(MutableFst<Arc> *ofst); + void GetDistance(StateId start); + void ProcFinal(SearchState s); + void ProcArcs(SearchState s); + void ProcOpenParen(Label paren_id, SearchState s, Arc arc, Weight w); + void ProcCloseParen(Label paren_id, SearchState s, const Arc &arc, Weight w); + void ProcNonParen(SearchState s, const Arc &arc, Weight w); + void Relax(SearchState s, SearchState t, Arc arc, Weight w, Label paren_id); + void Enqueue(SearchState d); + void GetPath(); + Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open); + + Fst<Arc> *ifst_; + MutableFst<Arc> *ofst_; + const vector<pair<Label, Label> > &parens_; + bool keep_parens_; + Queue *state_queue_; // current state queue + StateId start_; + Weight f_distance_; + SearchState f_parent_; + SpData sp_data_; + unordered_map<Label, Label> paren_id_map_; + CloseParenMultimap close_paren_multimap_; + PdtBalanceData<Arc> balance_data_; + ssize_t nenqueued_; + bool error_; + + DISALLOW_COPY_AND_ASSIGN(PdtShortestPath); +}; + +template<class Arc, class Queue> +void PdtShortestPath<Arc, Queue>::Init(MutableFst<Arc> *ofst) { + ofst_ = ofst; + ofst->DeleteStates(); + ofst->SetInputSymbols(ifst_->InputSymbols()); + ofst->SetOutputSymbols(ifst_->OutputSymbols()); + + if (ifst_->Start() == kNoStateId) + return; + + f_distance_ = Weight::Zero(); + f_parent_ = SearchState(kNoStateId, kNoStateId); + + sp_data_.Clear(); + close_paren_multimap_.clear(); + balance_data_.Clear(); + nenqueued_ = 0; + + // Find open parens per destination state and close parens per source state. + for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + for (ArcIterator<Fst<Arc> > aiter(*ifst_, s); + !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + typename unordered_map<Label, Label>::const_iterator pit + = paren_id_map_.find(arc.ilabel); + if (pit != paren_id_map_.end()) { // Is a paren? + Label paren_id = pit->second; + if (arc.ilabel == parens_[paren_id].first) { // Open paren + balance_data_.OpenInsert(paren_id, arc.nextstate); + } else { // Close paren + ParenState<Arc> paren_state(paren_id, s); + close_paren_multimap_.insert(make_pair(paren_state, arc)); + } + } + } + } +} + +// Computes the shortest distance stored in a recursive way. Each +// sub-graph (i.e. different paren 'start' state) begins with weight One(). +template<class Arc, class Queue> +void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) { + if (start == kNoStateId) + return; + + Queue state_queue; + state_queue_ = &state_queue; + SearchState q(start, start); + Enqueue(q); + sp_data_.SetDistance(q, Weight::One()); + + while (!state_queue_->Empty()) { + StateId state = state_queue_->Head(); + state_queue_->Dequeue(); + SearchState s(state, start); + sp_data_.SetFlags(s, 0, kEnqueued); + ProcFinal(s); + ProcArcs(s); + sp_data_.SetFlags(s, kExpanded, kExpanded); + } + balance_data_.FinishInsert(start); + sp_data_.GC(start); +} + +// Updates best complete path. +template<class Arc, class Queue> +void PdtShortestPath<Arc, Queue>::ProcFinal(SearchState s) { + if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) { + Weight w = Times(sp_data_.Distance(s), + ifst_->Final(s.state)); + if (f_distance_ != Plus(f_distance_, w)) { + if (f_parent_.state != kNoStateId) + sp_data_.SetFlags(f_parent_, 0, kFinal); + sp_data_.SetFlags(s, kFinal, kFinal); + + f_distance_ = Plus(f_distance_, w); + f_parent_ = s; + } + } +} + +// Processes all arcs leaving the state s. +template<class Arc, class Queue> +void PdtShortestPath<Arc, Queue>::ProcArcs(SearchState s) { + for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state); + !aiter.Done(); + aiter.Next()) { + Arc arc = aiter.Value(); + Weight w = Times(sp_data_.Distance(s), arc.weight); + + typename unordered_map<Label, Label>::const_iterator pit + = paren_id_map_.find(arc.ilabel); + if (pit != paren_id_map_.end()) { // Is a paren? + Label paren_id = pit->second; + if (arc.ilabel == parens_[paren_id].first) + ProcOpenParen(paren_id, s, arc, w); + else + ProcCloseParen(paren_id, s, arc, w); + } else { + ProcNonParen(s, arc, w); + } + } +} + +// Saves the shortest path info for reaching this parenthesis +// and starts a new SSSP in the sub-graph pointed to by the parenthesis +// if previously unvisited. Otherwise it finds any previously encountered +// closing parentheses and relaxes them using the recursively stored +// shortest distance to them. +template<class Arc, class Queue> inline +void PdtShortestPath<Arc, Queue>::ProcOpenParen( + Label paren_id, SearchState s, Arc arc, Weight w) { + + SearchState d(arc.nextstate, arc.nextstate); + ParenSpec paren(paren_id, s.start, d.start); + Weight pdist = sp_data_.Distance(paren); + if (pdist != Plus(pdist, w)) { + sp_data_.SetDistance(paren, w); + sp_data_.SetParent(paren, s); + Weight dist = sp_data_.Distance(d); + if (dist == Weight::Zero()) { + Queue *state_queue = state_queue_; + GetDistance(d.start); + state_queue_ = state_queue; + } + for (CloseSourceIterator set_iter = + balance_data_.Find(paren_id, arc.nextstate); + !set_iter.Done(); set_iter.Next()) { + SearchState cpstate(set_iter.Element(), d.start); + ParenState<Arc> paren_state(paren_id, cpstate.state); + for (typename CloseParenMultimap::const_iterator cpit = + close_paren_multimap_.find(paren_state); + cpit != close_paren_multimap_.end() && paren_state == cpit->first; + ++cpit) { + const Arc &cparc = cpit->second; + Weight cpw = Times(w, Times(sp_data_.Distance(cpstate), + cparc.weight)); + Relax(cpstate, s, cparc, cpw, paren_id); + } + } + } +} + +// Saves the correspondence between each closing parenthesis and its +// balancing open parenthesis info. Relaxes any close parenthesis +// destination state that has a balancing previously encountered open +// parenthesis. +template<class Arc, class Queue> inline +void PdtShortestPath<Arc, Queue>::ProcCloseParen( + Label paren_id, SearchState s, const Arc &arc, Weight w) { + ParenState<Arc> paren_state(paren_id, s.start); + if (!(sp_data_.Flags(s) & kExpanded)) { + balance_data_.CloseInsert(paren_id, s.start, s.state); + sp_data_.SetFlags(s, kFinal, kFinal); + } +} + +// For non-parentheses, classical relaxation. +template<class Arc, class Queue> inline +void PdtShortestPath<Arc, Queue>::ProcNonParen( + SearchState s, const Arc &arc, Weight w) { + Relax(s, s, arc, w, kNoLabel); +} + +// Classical relaxation on the search graph for 'arc' from state 's'. +// State 't' is in the same sub-graph as the nextstate should be (i.e. +// has the same paren 'start'. +template<class Arc, class Queue> inline +void PdtShortestPath<Arc, Queue>::Relax( + SearchState s, SearchState t, Arc arc, Weight w, Label paren_id) { + SearchState d(arc.nextstate, t.start); + Weight dist = sp_data_.Distance(d); + if (dist != Plus(dist, w)) { + sp_data_.SetParent(d, s); + sp_data_.SetParenId(d, paren_id); + sp_data_.SetDistance(d, Plus(dist, w)); + Enqueue(d); + } +} + +template<class Arc, class Queue> inline +void PdtShortestPath<Arc, Queue>::Enqueue(SearchState s) { + if (!(sp_data_.Flags(s) & kEnqueued)) { + state_queue_->Enqueue(s.state); + sp_data_.SetFlags(s, kEnqueued, kEnqueued); + ++nenqueued_; + } else { + state_queue_->Update(s.state); + } +} + +// Follows parent pointers to find the shortest path. Uses a stack +// since the shortest distance is stored recursively. +template<class Arc, class Queue> +void PdtShortestPath<Arc, Queue>::GetPath() { + SearchState s = f_parent_, d = SearchState(kNoStateId, kNoStateId); + StateId s_p = kNoStateId, d_p = kNoStateId; + Arc arc(kNoArc); + Label paren_id = kNoLabel; + stack<ParenSpec> paren_stack; + while (s.state != kNoStateId) { + d_p = s_p; + s_p = ofst_->AddState(); + if (d.state == kNoStateId) { + ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state)); + } else { + if (paren_id != kNoLabel) { // paren? + if (arc.ilabel == parens_[paren_id].first) { // open paren + paren_stack.pop(); + } else { // close paren + ParenSpec paren(paren_id, d.start, s.start); + paren_stack.push(paren); + } + if (!keep_parens_) + arc.ilabel = arc.olabel = 0; + } + arc.nextstate = d_p; + ofst_->AddArc(s_p, arc); + } + d = s; + s = sp_data_.Parent(d); + paren_id = sp_data_.ParenId(d); + if (s.state != kNoStateId) { + arc = GetPathArc(s, d, paren_id, false); + } else if (!paren_stack.empty()) { + ParenSpec paren = paren_stack.top(); + s = sp_data_.Parent(paren); + paren_id = paren.paren_id; + arc = GetPathArc(s, d, paren_id, true); + } + } + ofst_->SetStart(s_p); + ofst_->SetProperties( + ShortestPathProperties(ofst_->Properties(kFstProperties, false)), + kFstProperties); +} + + +// Finds transition with least weight between two states with label matching +// paren_id and open/close paren type or a non-paren if kNoLabel. +template<class Arc, class Queue> +Arc PdtShortestPath<Arc, Queue>::GetPathArc( + SearchState s, SearchState d, Label paren_id, bool open_paren) { + Arc path_arc = kNoArc; + for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state); + !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.nextstate != d.state) + continue; + Label arc_paren_id = kNoLabel; + typename unordered_map<Label, Label>::const_iterator pit + = paren_id_map_.find(arc.ilabel); + if (pit != paren_id_map_.end()) { + arc_paren_id = pit->second; + bool arc_open_paren = arc.ilabel == parens_[arc_paren_id].first; + if (arc_open_paren != open_paren) + continue; + } + if (arc_paren_id != paren_id) + continue; + if (arc.weight == Plus(arc.weight, path_arc.weight)) + path_arc = arc; + } + if (path_arc.nextstate == kNoStateId) { + FSTERROR() << "PdtShortestPath::GetPathArc failed to find arc"; + error_ = true; + } + return path_arc; +} + +template<class Arc, class Queue> +const Arc PdtShortestPath<Arc, Queue>::kNoArc + = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId); + +template<class Arc, class Queue> +const uint8 PdtShortestPath<Arc, Queue>::kEnqueued = 0x10; + +template<class Arc, class Queue> +const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20; + +template<class Arc, class Queue> +void ShortestPath(const Fst<Arc> &ifst, + const vector<pair<typename Arc::Label, + typename Arc::Label> > &parens, + MutableFst<Arc> *ofst, + const PdtShortestPathOptions<Arc, Queue> &opts) { + PdtShortestPath<Arc, Queue> psp(ifst, parens, opts); + psp.ShortestPath(ofst); +} + +template<class Arc> +void ShortestPath(const Fst<Arc> &ifst, + const vector<pair<typename Arc::Label, + typename Arc::Label> > &parens, + MutableFst<Arc> *ofst) { + typedef FifoQueue<typename Arc::StateId> Queue; + PdtShortestPathOptions<Arc, Queue> opts; + PdtShortestPath<Arc, Queue> psp(ifst, parens, opts); + psp.ShortestPath(ofst); +} + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_SHORTEST_PATH_H__ |