// shortest-path.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: riley@google.com (Michael Riley) // // \file // Functions to find shortest paths in a PDT. #ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H__ #define FST_EXTENSIONS_PDT_SHORTEST_PATH_H__ #include #include #include #include using std::tr1::unordered_map; using std::tr1::unordered_multimap; #include using std::tr1::unordered_set; using std::tr1::unordered_multiset; #include #include using std::vector; namespace fst { template struct PdtShortestPathOptions { bool keep_parentheses; bool path_gc; PdtShortestPathOptions(bool kp = false, bool gc = true) : keep_parentheses(kp), path_gc(gc) {} }; // Class to store PDT shortest path results. Stores shortest path // tree info 'Distance()', Parent(), and ArcParent() information keyed // on two types: // (1) By SearchState: This is a usual node in a shortest path tree but: // (a) is w.r.t a PDT search state - a pair of a PDT state and // a 'start' state, which is either the PDT start state or // the destination state of an open parenthesis. // (b) the Distance() is from this 'start' state to the search state. // (c) Parent().state is kNoLabel for the 'start' state. // // (2) By ParenSpec: This connects shortest path trees depending on the // the parenthesis taken. Given the parenthesis spec: // (a) the Distance() is from the Parent() 'start' state to the // parenthesis destination state. // (b) the ArcParent() is the parenthesis arc. template class PdtShortestPathData { public: static const uint8 kFinal; typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; typedef typename Arc::Label Label; struct SearchState { SearchState() : state(kNoStateId), start(kNoStateId) {} SearchState(StateId s, StateId t) : state(s), start(t) {} bool operator==(const SearchState &s) const { if (&s == this) return true; return s.state == this->state && s.start == this->start; } StateId state; // PDT state StateId start; // PDT paren 'source' state }; // Specifies paren id, source and dest 'start' states of a paren. // These are the 'start' states of the respective sub-graphs. struct ParenSpec { ParenSpec() : paren_id(kNoLabel), src_start(kNoStateId), dest_start(kNoStateId) {} ParenSpec(Label id, StateId s, StateId d) : paren_id(id), src_start(s), dest_start(d) {} Label paren_id; // Id of parenthesis StateId src_start; // sub-graph 'start' state for paren source. StateId dest_start; // sub-graph 'start' state for paren dest. bool operator==(const ParenSpec &x) const { if (&x == this) return true; return x.paren_id == this->paren_id && x.src_start == this->src_start && x.dest_start == this->dest_start; } }; struct SearchData { SearchData() : distance(Weight::Zero()), parent(kNoStateId, kNoStateId), paren_id(kNoLabel), flags(0) {} Weight distance; // Distance to this state from PDT 'start' state SearchState parent; // Parent state in shortest path tree int16 paren_id; // If parent arc has paren, paren ID, o.w. kNoLabel uint8 flags; // First byte reserved for PdtShortestPathData use }; PdtShortestPathData(bool gc) : state_(kNoStateId, kNoStateId), paren_(kNoLabel, kNoStateId, kNoStateId), gc_(gc), nstates_(0), ngc_(0), finished_(false) {} ~PdtShortestPathData() { VLOG(1) << "opm size: " << paren_map_.size(); VLOG(1) << "# of search states: " << nstates_; if (gc_) VLOG(1) << "# of GC'd search states: " << ngc_; } void Clear() { search_map_.clear(); search_multimap_.clear(); paren_map_.clear(); state_ = SearchState(kNoStateId, kNoStateId); nstates_ = 0; ngc_ = 0; } Weight Distance(SearchState s) const { SearchData *data = GetSearchData(s); return data->distance; } Weight Distance(const ParenSpec &paren) const { SearchData *data = GetSearchData(paren); return data->distance; } SearchState Parent(SearchState s) const { SearchData *data = GetSearchData(s); return data->parent; } SearchState Parent(const ParenSpec &paren) const { SearchData *data = GetSearchData(paren); return data->parent; } Label ParenId(SearchState s) const { SearchData *data = GetSearchData(s); return data->paren_id; } uint8 Flags(SearchState s) const { SearchData *data = GetSearchData(s); return data->flags; } void SetDistance(SearchState s, Weight w) { SearchData *data = GetSearchData(s); data->distance = w; } void SetDistance(const ParenSpec &paren, Weight w) { SearchData *data = GetSearchData(paren); data->distance = w; } void SetParent(SearchState s, SearchState p) { SearchData *data = GetSearchData(s); data->parent = p; } void SetParent(const ParenSpec &paren, SearchState p) { SearchData *data = GetSearchData(paren); data->parent = p; } void SetParenId(SearchState s, Label p) { if (p >= 32768) FSTERROR() << "PdtShortestPathData: Paren ID does not fits in an int16"; SearchData *data = GetSearchData(s); data->paren_id = p; } void SetFlags(SearchState s, uint8 f, uint8 mask) { SearchData *data = GetSearchData(s); data->flags &= ~mask; data->flags |= f & mask; } void GC(StateId s); void Finish() { finished_ = true; } private: static const Arc kNoArc; static const size_t kPrime0; static const size_t kPrime1; static const uint8 kInited; static const uint8 kMarked; // Hash for search state struct SearchStateHash { size_t operator()(const SearchState &s) const { return s.state + s.start * kPrime0; } }; // Hash for paren map struct ParenHash { size_t operator()(const ParenSpec &paren) const { return paren.paren_id + paren.src_start * kPrime0 + paren.dest_start * kPrime1; } }; typedef unordered_map SearchMap; typedef unordered_multimap SearchMultimap; // Hash map from paren spec to open paren data typedef unordered_map ParenMap; SearchData *GetSearchData(SearchState s) const { if (s == state_) return state_data_; if (finished_) { typename SearchMap::iterator it = search_map_.find(s); if (it == search_map_.end()) return &null_search_data_; state_ = s; return state_data_ = &(it->second); } else { state_ = s; state_data_ = &search_map_[s]; if (!(state_data_->flags & kInited)) { ++nstates_; if (gc_) search_multimap_.insert(make_pair(s.start, s.state)); state_data_->flags = kInited; } return state_data_; } } SearchData *GetSearchData(ParenSpec paren) const { if (paren == paren_) return paren_data_; if (finished_) { typename ParenMap::iterator it = paren_map_.find(paren); if (it == paren_map_.end()) return &null_search_data_; paren_ = paren; return state_data_ = &(it->second); } else { paren_ = paren; return paren_data_ = &paren_map_[paren]; } } mutable SearchMap search_map_; // Maps from search state to data mutable SearchMultimap search_multimap_; // Maps from 'start' to subgraph mutable ParenMap paren_map_; // Maps paren spec to search data mutable SearchState state_; // Last state accessed mutable SearchData *state_data_; // Last state data accessed mutable ParenSpec paren_; // Last paren spec accessed mutable SearchData *paren_data_; // Last paren data accessed bool gc_; // Allow GC? mutable size_t nstates_; // Total number of search states size_t ngc_; // Number of GC'd search states mutable SearchData null_search_data_; // Null search data bool finished_; // Read-only access when true DISALLOW_COPY_AND_ASSIGN(PdtShortestPathData); }; // Deletes inaccessible search data from a given 'start' (open paren dest) // state. Assumes 'final' (close paren source or PDT final) states have // been flagged 'kFinal'. template void PdtShortestPathData::GC(StateId start) { if (!gc_) return; vector final; for (typename SearchMultimap::iterator mmit = search_multimap_.find(start); mmit != search_multimap_.end() && mmit->first == start; ++mmit) { SearchState s(mmit->second, start); const SearchData &data = search_map_[s]; if (data.flags & kFinal) final.push_back(s.state); } // Mark phase for (size_t i = 0; i < final.size(); ++i) { SearchState s(final[i], start); while (s.state != kNoLabel) { SearchData *sdata = &search_map_[s]; if (sdata->flags & kMarked) break; sdata->flags |= kMarked; SearchState p = sdata->parent; if (p.start != start && p.start != kNoLabel) { // entering sub-subgraph ParenSpec paren(sdata->paren_id, s.start, p.start); SearchData *pdata = &paren_map_[paren]; s = pdata->parent; } else { s = p; } } } // Sweep phase typename SearchMultimap::iterator mmit = search_multimap_.find(start); while (mmit != search_multimap_.end() && mmit->first == start) { SearchState s(mmit->second, start); typename SearchMap::iterator mit = search_map_.find(s); const SearchData &data = mit->second; if (!(data.flags & kMarked)) { search_map_.erase(mit); ++ngc_; } search_multimap_.erase(mmit++); } } template const Arc PdtShortestPathData::kNoArc = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId); template const size_t PdtShortestPathData::kPrime0 = 7853; template const size_t PdtShortestPathData::kPrime1 = 7867; template const uint8 PdtShortestPathData::kInited = 0x01; template const uint8 PdtShortestPathData::kFinal = 0x02; template const uint8 PdtShortestPathData::kMarked = 0x04; // This computes the single source shortest (balanced) path (SSSP) // through a weighted PDT that has a bounded stack (i.e. is expandable // as an FST). It is a generalization of the classic SSSP graph // algorithm that removes a state s from a queue (defined by a // user-provided queue type) and relaxes the destination states of // transitions leaving s. In this PDT version, states that have // entering open parentheses are treated as source states for a // sub-graph SSSP problem with the shortest path up to the open // parenthesis being first saved. When a close parenthesis is then // encountered any balancing open parenthesis is examined for this // saved information and multiplied back. In this way, each sub-graph // is entered only once rather than repeatedly. If every state in the // input PDT has the property that there is a unique 'start' state for // it with entering open parentheses, then this algorithm is quite // straight-forward. In general, this will not be the case, so the // algorithm (implicitly) creates a new graph where each state is a // pair of an original state and a possible parenthesis 'start' state // for that state. template class PdtShortestPath { public: typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; typedef typename Arc::Label Label; typedef PdtShortestPathData SpData; typedef typename SpData::SearchState SearchState; typedef typename SpData::ParenSpec ParenSpec; typedef typename PdtBalanceData::SetIterator CloseSourceIterator; PdtShortestPath(const Fst &ifst, const vector > &parens, const PdtShortestPathOptions &opts) : kFinal(SpData::kFinal), ifst_(ifst.Copy()), parens_(parens), keep_parens_(opts.keep_parentheses), start_(ifst.Start()), sp_data_(opts.path_gc), error_(false) { if ((Weight::Properties() & (kPath | kRightSemiring)) != (kPath | kRightSemiring)) { FSTERROR() << "PdtShortestPath: Weight needs to have the path" << " property and be right distributive: " << Weight::Type(); error_ = true; } for (Label i = 0; i < parens.size(); ++i) { const pair &p = parens[i]; paren_id_map_[p.first] = i; paren_id_map_[p.second] = i; } }; ~PdtShortestPath() { VLOG(1) << "# of input states: " << CountStates(*ifst_); VLOG(1) << "# of enqueued: " << nenqueued_; VLOG(1) << "cpmm size: " << close_paren_multimap_.size(); delete ifst_; } void ShortestPath(MutableFst *ofst) { Init(ofst); GetDistance(start_); GetPath(); sp_data_.Finish(); if (error_) ofst->SetProperties(kError, kError); } const PdtShortestPathData &GetShortestPathData() const { return sp_data_; } PdtBalanceData *GetBalanceData() { return &balance_data_; } private: static const Arc kNoArc; static const uint8 kEnqueued; static const uint8 kExpanded; static const uint8 kFinished; const uint8 kFinal; public: // Hash multimap from close paren label to an paren arc. typedef unordered_multimap, Arc, typename ParenState::Hash> CloseParenMultimap; const CloseParenMultimap &GetCloseParenMultimap() const { return close_paren_multimap_; } private: void Init(MutableFst *ofst); void GetDistance(StateId start); void ProcFinal(SearchState s); void ProcArcs(SearchState s); void ProcOpenParen(Label paren_id, SearchState s, Arc arc, Weight w); void ProcCloseParen(Label paren_id, SearchState s, const Arc &arc, Weight w); void ProcNonParen(SearchState s, const Arc &arc, Weight w); void Relax(SearchState s, SearchState t, Arc arc, Weight w, Label paren_id); void Enqueue(SearchState d); void GetPath(); Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open); Fst *ifst_; MutableFst *ofst_; const vector > &parens_; bool keep_parens_; Queue *state_queue_; // current state queue StateId start_; Weight f_distance_; SearchState f_parent_; SpData sp_data_; unordered_map paren_id_map_; CloseParenMultimap close_paren_multimap_; PdtBalanceData balance_data_; ssize_t nenqueued_; bool error_; DISALLOW_COPY_AND_ASSIGN(PdtShortestPath); }; template void PdtShortestPath::Init(MutableFst *ofst) { ofst_ = ofst; ofst->DeleteStates(); ofst->SetInputSymbols(ifst_->InputSymbols()); ofst->SetOutputSymbols(ifst_->OutputSymbols()); if (ifst_->Start() == kNoStateId) return; f_distance_ = Weight::Zero(); f_parent_ = SearchState(kNoStateId, kNoStateId); sp_data_.Clear(); close_paren_multimap_.clear(); balance_data_.Clear(); nenqueued_ = 0; // Find open parens per destination state and close parens per source state. for (StateIterator > siter(*ifst_); !siter.Done(); siter.Next()) { StateId s = siter.Value(); for (ArcIterator > aiter(*ifst_, s); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); typename unordered_map::const_iterator pit = paren_id_map_.find(arc.ilabel); if (pit != paren_id_map_.end()) { // Is a paren? Label paren_id = pit->second; if (arc.ilabel == parens_[paren_id].first) { // Open paren balance_data_.OpenInsert(paren_id, arc.nextstate); } else { // Close paren ParenState paren_state(paren_id, s); close_paren_multimap_.insert(make_pair(paren_state, arc)); } } } } } // Computes the shortest distance stored in a recursive way. Each // sub-graph (i.e. different paren 'start' state) begins with weight One(). template void PdtShortestPath::GetDistance(StateId start) { if (start == kNoStateId) return; Queue state_queue; state_queue_ = &state_queue; SearchState q(start, start); Enqueue(q); sp_data_.SetDistance(q, Weight::One()); while (!state_queue_->Empty()) { StateId state = state_queue_->Head(); state_queue_->Dequeue(); SearchState s(state, start); sp_data_.SetFlags(s, 0, kEnqueued); ProcFinal(s); ProcArcs(s); sp_data_.SetFlags(s, kExpanded, kExpanded); } sp_data_.SetFlags(q, kFinished, kFinished); balance_data_.FinishInsert(start); sp_data_.GC(start); } // Updates best complete path. template void PdtShortestPath::ProcFinal(SearchState s) { if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) { Weight w = Times(sp_data_.Distance(s), ifst_->Final(s.state)); if (f_distance_ != Plus(f_distance_, w)) { if (f_parent_.state != kNoStateId) sp_data_.SetFlags(f_parent_, 0, kFinal); sp_data_.SetFlags(s, kFinal, kFinal); f_distance_ = Plus(f_distance_, w); f_parent_ = s; } } } // Processes all arcs leaving the state s. template void PdtShortestPath::ProcArcs(SearchState s) { for (ArcIterator< Fst > aiter(*ifst_, s.state); !aiter.Done(); aiter.Next()) { Arc arc = aiter.Value(); Weight w = Times(sp_data_.Distance(s), arc.weight); typename unordered_map::const_iterator pit = paren_id_map_.find(arc.ilabel); if (pit != paren_id_map_.end()) { // Is a paren? Label paren_id = pit->second; if (arc.ilabel == parens_[paren_id].first) ProcOpenParen(paren_id, s, arc, w); else ProcCloseParen(paren_id, s, arc, w); } else { ProcNonParen(s, arc, w); } } } // Saves the shortest path info for reaching this parenthesis // and starts a new SSSP in the sub-graph pointed to by the parenthesis // if previously unvisited. Otherwise it finds any previously encountered // closing parentheses and relaxes them using the recursively stored // shortest distance to them. template inline void PdtShortestPath::ProcOpenParen( Label paren_id, SearchState s, Arc arc, Weight w) { SearchState d(arc.nextstate, arc.nextstate); ParenSpec paren(paren_id, s.start, d.start); Weight pdist = sp_data_.Distance(paren); if (pdist != Plus(pdist, w)) { sp_data_.SetDistance(paren, w); sp_data_.SetParent(paren, s); Weight dist = sp_data_.Distance(d); if (dist == Weight::Zero()) { Queue *state_queue = state_queue_; GetDistance(d.start); state_queue_ = state_queue; } else if (!(sp_data_.Flags(d) & kFinished)) { FSTERROR() << "PdtShortestPath: open parenthesis recursion: not bounded stack"; error_ = true; } for (CloseSourceIterator set_iter = balance_data_.Find(paren_id, arc.nextstate); !set_iter.Done(); set_iter.Next()) { SearchState cpstate(set_iter.Element(), d.start); ParenState paren_state(paren_id, cpstate.state); for (typename CloseParenMultimap::const_iterator cpit = close_paren_multimap_.find(paren_state); cpit != close_paren_multimap_.end() && paren_state == cpit->first; ++cpit) { const Arc &cparc = cpit->second; Weight cpw = Times(w, Times(sp_data_.Distance(cpstate), cparc.weight)); Relax(cpstate, s, cparc, cpw, paren_id); } } } } // Saves the correspondence between each closing parenthesis and its // balancing open parenthesis info. Relaxes any close parenthesis // destination state that has a balancing previously encountered open // parenthesis. template inline void PdtShortestPath::ProcCloseParen( Label paren_id, SearchState s, const Arc &arc, Weight w) { ParenState paren_state(paren_id, s.start); if (!(sp_data_.Flags(s) & kExpanded)) { balance_data_.CloseInsert(paren_id, s.start, s.state); sp_data_.SetFlags(s, kFinal, kFinal); } } // For non-parentheses, classical relaxation. template inline void PdtShortestPath::ProcNonParen( SearchState s, const Arc &arc, Weight w) { Relax(s, s, arc, w, kNoLabel); } // Classical relaxation on the search graph for 'arc' from state 's'. // State 't' is in the same sub-graph as the nextstate should be (i.e. // has the same paren 'start'. template inline void PdtShortestPath::Relax( SearchState s, SearchState t, Arc arc, Weight w, Label paren_id) { SearchState d(arc.nextstate, t.start); Weight dist = sp_data_.Distance(d); if (dist != Plus(dist, w)) { sp_data_.SetParent(d, s); sp_data_.SetParenId(d, paren_id); sp_data_.SetDistance(d, Plus(dist, w)); Enqueue(d); } } template inline void PdtShortestPath::Enqueue(SearchState s) { if (!(sp_data_.Flags(s) & kEnqueued)) { state_queue_->Enqueue(s.state); sp_data_.SetFlags(s, kEnqueued, kEnqueued); ++nenqueued_; } else { state_queue_->Update(s.state); } } // Follows parent pointers to find the shortest path. Uses a stack // since the shortest distance is stored recursively. template void PdtShortestPath::GetPath() { SearchState s = f_parent_, d = SearchState(kNoStateId, kNoStateId); StateId s_p = kNoStateId, d_p = kNoStateId; Arc arc(kNoArc); Label paren_id = kNoLabel; stack paren_stack; while (s.state != kNoStateId) { d_p = s_p; s_p = ofst_->AddState(); if (d.state == kNoStateId) { ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state)); } else { if (paren_id != kNoLabel) { // paren? if (arc.ilabel == parens_[paren_id].first) { // open paren paren_stack.pop(); } else { // close paren ParenSpec paren(paren_id, d.start, s.start); paren_stack.push(paren); } if (!keep_parens_) arc.ilabel = arc.olabel = 0; } arc.nextstate = d_p; ofst_->AddArc(s_p, arc); } d = s; s = sp_data_.Parent(d); paren_id = sp_data_.ParenId(d); if (s.state != kNoStateId) { arc = GetPathArc(s, d, paren_id, false); } else if (!paren_stack.empty()) { ParenSpec paren = paren_stack.top(); s = sp_data_.Parent(paren); paren_id = paren.paren_id; arc = GetPathArc(s, d, paren_id, true); } } ofst_->SetStart(s_p); ofst_->SetProperties( ShortestPathProperties(ofst_->Properties(kFstProperties, false)), kFstProperties); } // Finds transition with least weight between two states with label matching // paren_id and open/close paren type or a non-paren if kNoLabel. template Arc PdtShortestPath::GetPathArc( SearchState s, SearchState d, Label paren_id, bool open_paren) { Arc path_arc = kNoArc; for (ArcIterator< Fst > aiter(*ifst_, s.state); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); if (arc.nextstate != d.state) continue; Label arc_paren_id = kNoLabel; typename unordered_map::const_iterator pit = paren_id_map_.find(arc.ilabel); if (pit != paren_id_map_.end()) { arc_paren_id = pit->second; bool arc_open_paren = arc.ilabel == parens_[arc_paren_id].first; if (arc_open_paren != open_paren) continue; } if (arc_paren_id != paren_id) continue; if (arc.weight == Plus(arc.weight, path_arc.weight)) path_arc = arc; } if (path_arc.nextstate == kNoStateId) { FSTERROR() << "PdtShortestPath::GetPathArc failed to find arc"; error_ = true; } return path_arc; } template const Arc PdtShortestPath::kNoArc = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId); template const uint8 PdtShortestPath::kEnqueued = 0x10; template const uint8 PdtShortestPath::kExpanded = 0x20; template const uint8 PdtShortestPath::kFinished = 0x40; template void ShortestPath(const Fst &ifst, const vector > &parens, MutableFst *ofst, const PdtShortestPathOptions &opts) { PdtShortestPath psp(ifst, parens, opts); psp.ShortestPath(ofst); } template void ShortestPath(const Fst &ifst, const vector > &parens, MutableFst *ofst) { typedef FifoQueue Queue; PdtShortestPathOptions opts; PdtShortestPath psp(ifst, parens, opts); psp.ShortestPath(ofst); } } // namespace fst #endif // FST_EXTENSIONS_PDT_SHORTEST_PATH_H__