diff options
Diffstat (limited to 'src/include/fst/extensions/pdt/paren.h')
-rw-r--r-- | src/include/fst/extensions/pdt/paren.h | 496 |
1 files changed, 496 insertions, 0 deletions
diff --git a/src/include/fst/extensions/pdt/paren.h b/src/include/fst/extensions/pdt/paren.h new file mode 100644 index 0000000..7b9887f --- /dev/null +++ b/src/include/fst/extensions/pdt/paren.h @@ -0,0 +1,496 @@ +// paren.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// Common classes for PDT parentheses + +// \file + +#ifndef FST_EXTENSIONS_PDT_PAREN_H_ +#define FST_EXTENSIONS_PDT_PAREN_H_ + +#include <algorithm> +#include <unordered_map> +using std::tr1::unordered_map; +using std::tr1::unordered_multimap; +#include <tr1/unordered_set> +using std::tr1::unordered_set; +using std::tr1::unordered_multiset; +#include <set> + +#include <fst/extensions/pdt/pdt.h> +#include <fst/extensions/pdt/collection.h> +#include <fst/fst.h> +#include <fst/dfs-visit.h> + + +namespace fst { + +// +// ParenState: Pair of an open (close) parenthesis and +// its destination (source) state. +// + +template <class A> +class ParenState { + public: + typedef typename A::Label Label; + typedef typename A::StateId StateId; + + struct Hash { + size_t operator()(const ParenState<A> &p) const { + return p.paren_id + p.state_id * kPrime; + } + }; + + Label paren_id; // ID of open (close) paren + StateId state_id; // destination (source) state of open (close) paren + + ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {} + + ParenState(Label p, StateId s) : paren_id(p), state_id(s) {} + + bool operator==(const ParenState<A> &p) const { + if (&p == this) + return true; + return p.paren_id == this->paren_id && p.state_id == this->state_id; + } + + bool operator!=(const ParenState<A> &p) const { return !(p == *this); } + + bool operator<(const ParenState<A> &p) const { + return paren_id < this->paren.id || + (p.paren_id == this->paren.id && p.state_id < this->state_id); + } + + private: + static const size_t kPrime; +}; + +template <class A> +const size_t ParenState<A>::kPrime = 7853; + + +// Creates an FST-style iterator from STL map and iterator. +template <class M> +class MapIterator { + public: + typedef typename M::const_iterator StlIterator; + typedef typename M::value_type PairType; + typedef typename PairType::second_type ValueType; + + MapIterator(const M &m, StlIterator iter) + : map_(m), begin_(iter), iter_(iter) {} + + bool Done() const { + return iter_ == map_.end() || iter_->first != begin_->first; + } + + ValueType Value() const { return iter_->second; } + void Next() { ++iter_; } + void Reset() { iter_ = begin_; } + + private: + const M &map_; + StlIterator begin_; + StlIterator iter_; +}; + +// +// PdtParenReachable: Provides various parenthesis reachability information +// on a PDT. +// + +template <class A> +class PdtParenReachable { + public: + typedef typename A::StateId StateId; + typedef typename A::Label Label; + public: + // Maps from state ID to reachable paren IDs from (to) that state. + typedef unordered_multimap<StateId, Label> ParenMultiMap; + + // Maps from paren ID and state ID to reachable state set ID + typedef unordered_map<ParenState<A>, ssize_t, + typename ParenState<A>::Hash> StateSetMap; + + // Maps from paren ID and state ID to arcs exiting that state with that + // Label. + typedef unordered_multimap<ParenState<A>, A, + typename ParenState<A>::Hash> ParenArcMultiMap; + + typedef MapIterator<ParenMultiMap> ParenIterator; + + typedef MapIterator<ParenArcMultiMap> ParenArcIterator; + + typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator; + + // Computes close (open) parenthesis reachabilty information for + // a PDT with bounded stack. + PdtParenReachable(const Fst<A> &fst, + const vector<pair<Label, Label> > &parens, bool close) + : fst_(fst), + parens_(parens), + close_(close) { + for (Label i = 0; i < parens.size(); ++i) { + const pair<Label, Label> &p = parens[i]; + paren_id_map_[p.first] = i; + paren_id_map_[p.second] = i; + } + + if (close_) { + StateId start = fst.Start(); + if (start == kNoStateId) + return; + DFSearch(start, start); + } else { + FSTERROR() << "PdtParenReachable: open paren info not implemented"; + } + } + + // Given a state ID, returns an iterator over paren IDs + // for close (open) parens reachable from that state along balanced + // paths. + ParenIterator FindParens(StateId s) const { + return ParenIterator(paren_multimap_, paren_multimap_.find(s)); + } + + // Given a paren ID and a state ID s, returns an iterator over + // states that can be reached along balanced paths from (to) s that + // have have close (open) parentheses matching the paren ID exiting + // (entering) those states. + SetIterator FindStates(Label paren_id, StateId s) const { + ParenState<A> paren_state(paren_id, s); + typename StateSetMap::const_iterator id_it = set_map_.find(paren_state); + if (id_it == set_map_.end()) { + return state_sets_.FindSet(-1); + } else { + return state_sets_.FindSet(id_it->second); + } + } + + // Given a paren Id and a state ID s, return an iterator over + // arcs that exit (enter) s and are labeled with a close (open) + // parenthesis matching the paren ID. + ParenArcIterator FindParenArcs(Label paren_id, StateId s) const { + ParenState<A> paren_state(paren_id, s); + return ParenArcIterator(paren_arc_multimap_, + paren_arc_multimap_.find(paren_state)); + } + + private: + // DFS that gathers paren and state set information. + // Bool returns false when cycle detected. + bool DFSearch(StateId s, StateId start); + + // Unions state sets together gathered by the DFS. + void ComputeStateSet(StateId s); + + // Gather state set(s) from state 'nexts'. + void UpdateStateSet(StateId nexts, set<Label> *paren_set, + vector< set<StateId> > *state_sets) const; + + const Fst<A> &fst_; + const vector<pair<Label, Label> > &parens_; // Paren ID -> Labels + bool close_; // Close/open paren info? + unordered_map<Label, Label> paren_id_map_; // Paren labels -> ID + ParenMultiMap paren_multimap_; // Paren reachability + ParenArcMultiMap paren_arc_multimap_; // Paren Arcs + vector<char> state_color_; // DFS state + mutable Collection<ssize_t, StateId> state_sets_; // Reachable states -> ID + StateSetMap set_map_; // ID -> Reachable states + DISALLOW_COPY_AND_ASSIGN(PdtParenReachable); +}; + +// DFS that gathers paren and state set information. +template <class A> +bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) { + if (s >= state_color_.size()) + state_color_.resize(s + 1, kDfsWhite); + + if (state_color_[s] == kDfsBlack) + return true; + + if (state_color_[s] == kDfsGrey) + return false; + + state_color_[s] = kDfsGrey; + + for (ArcIterator<Fst<A> > aiter(fst_, s); + !aiter.Done(); + aiter.Next()) { + const A &arc = aiter.Value(); + + typename unordered_map<Label, Label>::const_iterator pit + = paren_id_map_.find(arc.ilabel); + if (pit != paren_id_map_.end()) { // paren? + Label paren_id = pit->second; + if (arc.ilabel == parens_[paren_id].first) { // open paren + DFSearch(arc.nextstate, arc.nextstate); + for (SetIterator set_iter = FindStates(paren_id, arc.nextstate); + !set_iter.Done(); set_iter.Next()) { + for (ParenArcIterator paren_arc_iter = + FindParenArcs(paren_id, set_iter.Element()); + !paren_arc_iter.Done(); + paren_arc_iter.Next()) { + const A &cparc = paren_arc_iter.Value(); + DFSearch(cparc.nextstate, start); + } + } + } + } else { // non-paren + if(!DFSearch(arc.nextstate, start)) { + FSTERROR() << "PdtReachable: Underlying cyclicity not supported"; + return true; + } + } + } + ComputeStateSet(s); + state_color_[s] = kDfsBlack; + return true; +} + +// Unions state sets together gathered by the DFS. +template <class A> +void PdtParenReachable<A>::ComputeStateSet(StateId s) { + set<Label> paren_set; + vector< set<StateId> > state_sets(parens_.size()); + for (ArcIterator< Fst<A> > aiter(fst_, s); + !aiter.Done(); + aiter.Next()) { + const A &arc = aiter.Value(); + + typename unordered_map<Label, Label>::const_iterator pit + = paren_id_map_.find(arc.ilabel); + if (pit != paren_id_map_.end()) { // paren? + Label paren_id = pit->second; + if (arc.ilabel == parens_[paren_id].first) { // open paren + for (SetIterator set_iter = + FindStates(paren_id, arc.nextstate); + !set_iter.Done(); set_iter.Next()) { + for (ParenArcIterator paren_arc_iter = + FindParenArcs(paren_id, set_iter.Element()); + !paren_arc_iter.Done(); + paren_arc_iter.Next()) { + const A &cparc = paren_arc_iter.Value(); + UpdateStateSet(cparc.nextstate, &paren_set, &state_sets); + } + } + } else { // close paren + paren_set.insert(paren_id); + state_sets[paren_id].insert(s); + ParenState<A> paren_state(paren_id, s); + paren_arc_multimap_.insert(make_pair(paren_state, arc)); + } + } else { // non-paren + UpdateStateSet(arc.nextstate, &paren_set, &state_sets); + } + } + + vector<StateId> state_set; + for (typename set<Label>::iterator paren_iter = paren_set.begin(); + paren_iter != paren_set.end(); ++paren_iter) { + state_set.clear(); + Label paren_id = *paren_iter; + paren_multimap_.insert(make_pair(s, paren_id)); + for (typename set<StateId>::iterator state_iter + = state_sets[paren_id].begin(); + state_iter != state_sets[paren_id].end(); + ++state_iter) { + state_set.push_back(*state_iter); + } + ParenState<A> paren_state(paren_id, s); + set_map_[paren_state] = state_sets_.FindId(state_set); + } +} + +// Gather state set(s) from state 'nexts'. +template <class A> +void PdtParenReachable<A>::UpdateStateSet( + StateId nexts, set<Label> *paren_set, + vector< set<StateId> > *state_sets) const { + for(ParenIterator paren_iter = FindParens(nexts); + !paren_iter.Done(); paren_iter.Next()) { + Label paren_id = paren_iter.Value(); + paren_set->insert(paren_id); + for (SetIterator set_iter = FindStates(paren_id, nexts); + !set_iter.Done(); set_iter.Next()) { + (*state_sets)[paren_id].insert(set_iter.Element()); + } + } +} + + +// Store balancing parenthesis data for a PDT. Allows on-the-fly +// construction (e.g. in PdtShortestPath) unlike PdtParenReachable above. +template <class A> +class PdtBalanceData { + public: + typedef typename A::StateId StateId; + typedef typename A::Label Label; + + // Hash set for open parens + typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet; + + // Maps from open paren destination state to parenthesis ID. + typedef unordered_multimap<StateId, Label> OpenParenMap; + + // Maps from open paren state to source states of matching close parens + typedef unordered_multimap<ParenState<A>, StateId, + typename ParenState<A>::Hash> CloseParenMap; + + // Maps from open paren state to close source set ID + typedef unordered_map<ParenState<A>, ssize_t, + typename ParenState<A>::Hash> CloseSourceMap; + + typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator; + + PdtBalanceData() {} + + void Clear() { + open_paren_map_.clear(); + close_paren_map_.clear(); + } + + // Adds an open parenthesis with destination state 'open_dest'. + void OpenInsert(Label paren_id, StateId open_dest) { + ParenState<A> key(paren_id, open_dest); + if (!open_paren_set_.count(key)) { + open_paren_set_.insert(key); + open_paren_map_.insert(make_pair(open_dest, paren_id)); + } + } + + // Adds a matching closing parenthesis with source state + // 'close_source' that balances an open_parenthesis with destination + // state 'open_dest' if OpenInsert() previously called + // (o.w. CloseInsert() does nothing). + void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) { + ParenState<A> key(paren_id, open_dest); + if (open_paren_set_.count(key)) + close_paren_map_.insert(make_pair(key, close_source)); + } + + // Find close paren source states matching an open parenthesis. + // Methods that follow, iterate through those matching states. + // Should be called only after FinishInsert(open_dest). + SetIterator Find(Label paren_id, StateId open_dest) { + ParenState<A> close_key(paren_id, open_dest); + typename CloseSourceMap::const_iterator id_it = + close_source_map_.find(close_key); + if (id_it == close_source_map_.end()) { + return close_source_sets_.FindSet(-1); + } else { + return close_source_sets_.FindSet(id_it->second); + } + } + + // Call when all open and close parenthesis insertions wrt open + // parentheses entering 'open_dest' are finished. Must be called + // before Find(open_dest). Stores close paren source state sets + // efficiently. + void FinishInsert(StateId open_dest) { + vector<StateId> close_sources; + for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest); + oit != open_paren_map_.end() && oit->first == open_dest;) { + Label paren_id = oit->second; + close_sources.clear(); + ParenState<A> okey(paren_id, open_dest); + open_paren_set_.erase(open_paren_set_.find(okey)); + for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey); + cit != close_paren_map_.end() && cit->first == okey;) { + close_sources.push_back(cit->second); + close_paren_map_.erase(cit++); + } + sort(close_sources.begin(), close_sources.end()); + typename vector<StateId>::iterator unique_end = + unique(close_sources.begin(), close_sources.end()); + close_sources.resize(unique_end - close_sources.begin()); + + if (!close_sources.empty()) + close_source_map_[okey] = close_source_sets_.FindId(close_sources); + open_paren_map_.erase(oit++); + } + } + + // Return a new balance data object representing the reversed balance + // information. + PdtBalanceData<A> *Reverse(StateId num_states, + StateId num_split, + StateId state_id_shift) const; + + private: + OpenParenSet open_paren_set_; // open par. at dest? + + OpenParenMap open_paren_map_; // open parens per state + ParenState<A> open_dest_; // cur open dest. state + typename OpenParenMap::const_iterator open_iter_; // cur open parens/state + + CloseParenMap close_paren_map_; // close states/open + // paren and state + + CloseSourceMap close_source_map_; // paren, state to set ID + mutable Collection<ssize_t, StateId> close_source_sets_; +}; + +// Return a new balance data object representing the reversed balance +// information. +template <class A> +PdtBalanceData<A> *PdtBalanceData<A>::Reverse( + StateId num_states, + StateId num_split, + StateId state_id_shift) const { + PdtBalanceData<A> *bd = new PdtBalanceData<A>; + unordered_set<StateId> close_sources; + StateId split_size = num_states / num_split; + + for (StateId i = 0; i < num_states; i+= split_size) { + close_sources.clear(); + + for (typename CloseSourceMap::const_iterator + sit = close_source_map_.begin(); + sit != close_source_map_.end(); + ++sit) { + ParenState<A> okey = sit->first; + StateId open_dest = okey.state_id; + Label paren_id = okey.paren_id; + for (SetIterator set_iter = close_source_sets_.FindSet(sit->second); + !set_iter.Done(); set_iter.Next()) { + StateId close_source = set_iter.Element(); + if ((close_source < i) || (close_source >= i + split_size)) + continue; + close_sources.insert(close_source + state_id_shift); + bd->OpenInsert(paren_id, close_source + state_id_shift); + bd->CloseInsert(paren_id, close_source + state_id_shift, + open_dest + state_id_shift); + } + } + + for (typename unordered_set<StateId>::const_iterator it + = close_sources.begin(); + it != close_sources.end(); + ++it) { + bd->FinishInsert(*it); + } + + } + return bd; +} + + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_PAREN_H_ |