aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/extensions/pdt/paren.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/extensions/pdt/paren.h')
-rw-r--r--src/include/fst/extensions/pdt/paren.h496
1 files changed, 496 insertions, 0 deletions
diff --git a/src/include/fst/extensions/pdt/paren.h b/src/include/fst/extensions/pdt/paren.h
new file mode 100644
index 0000000..7b9887f
--- /dev/null
+++ b/src/include/fst/extensions/pdt/paren.h
@@ -0,0 +1,496 @@
+// paren.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// Common classes for PDT parentheses
+
+// \file
+
+#ifndef FST_EXTENSIONS_PDT_PAREN_H_
+#define FST_EXTENSIONS_PDT_PAREN_H_
+
+#include <algorithm>
+#include <unordered_map>
+using std::tr1::unordered_map;
+using std::tr1::unordered_multimap;
+#include <tr1/unordered_set>
+using std::tr1::unordered_set;
+using std::tr1::unordered_multiset;
+#include <set>
+
+#include <fst/extensions/pdt/pdt.h>
+#include <fst/extensions/pdt/collection.h>
+#include <fst/fst.h>
+#include <fst/dfs-visit.h>
+
+
+namespace fst {
+
+//
+// ParenState: Pair of an open (close) parenthesis and
+// its destination (source) state.
+//
+
+template <class A>
+class ParenState {
+ public:
+ typedef typename A::Label Label;
+ typedef typename A::StateId StateId;
+
+ struct Hash {
+ size_t operator()(const ParenState<A> &p) const {
+ return p.paren_id + p.state_id * kPrime;
+ }
+ };
+
+ Label paren_id; // ID of open (close) paren
+ StateId state_id; // destination (source) state of open (close) paren
+
+ ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {}
+
+ ParenState(Label p, StateId s) : paren_id(p), state_id(s) {}
+
+ bool operator==(const ParenState<A> &p) const {
+ if (&p == this)
+ return true;
+ return p.paren_id == this->paren_id && p.state_id == this->state_id;
+ }
+
+ bool operator!=(const ParenState<A> &p) const { return !(p == *this); }
+
+ bool operator<(const ParenState<A> &p) const {
+ return paren_id < this->paren.id ||
+ (p.paren_id == this->paren.id && p.state_id < this->state_id);
+ }
+
+ private:
+ static const size_t kPrime;
+};
+
+template <class A>
+const size_t ParenState<A>::kPrime = 7853;
+
+
+// Creates an FST-style iterator from STL map and iterator.
+template <class M>
+class MapIterator {
+ public:
+ typedef typename M::const_iterator StlIterator;
+ typedef typename M::value_type PairType;
+ typedef typename PairType::second_type ValueType;
+
+ MapIterator(const M &m, StlIterator iter)
+ : map_(m), begin_(iter), iter_(iter) {}
+
+ bool Done() const {
+ return iter_ == map_.end() || iter_->first != begin_->first;
+ }
+
+ ValueType Value() const { return iter_->second; }
+ void Next() { ++iter_; }
+ void Reset() { iter_ = begin_; }
+
+ private:
+ const M &map_;
+ StlIterator begin_;
+ StlIterator iter_;
+};
+
+//
+// PdtParenReachable: Provides various parenthesis reachability information
+// on a PDT.
+//
+
+template <class A>
+class PdtParenReachable {
+ public:
+ typedef typename A::StateId StateId;
+ typedef typename A::Label Label;
+ public:
+ // Maps from state ID to reachable paren IDs from (to) that state.
+ typedef unordered_multimap<StateId, Label> ParenMultiMap;
+
+ // Maps from paren ID and state ID to reachable state set ID
+ typedef unordered_map<ParenState<A>, ssize_t,
+ typename ParenState<A>::Hash> StateSetMap;
+
+ // Maps from paren ID and state ID to arcs exiting that state with that
+ // Label.
+ typedef unordered_multimap<ParenState<A>, A,
+ typename ParenState<A>::Hash> ParenArcMultiMap;
+
+ typedef MapIterator<ParenMultiMap> ParenIterator;
+
+ typedef MapIterator<ParenArcMultiMap> ParenArcIterator;
+
+ typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
+
+ // Computes close (open) parenthesis reachabilty information for
+ // a PDT with bounded stack.
+ PdtParenReachable(const Fst<A> &fst,
+ const vector<pair<Label, Label> > &parens, bool close)
+ : fst_(fst),
+ parens_(parens),
+ close_(close) {
+ for (Label i = 0; i < parens.size(); ++i) {
+ const pair<Label, Label> &p = parens[i];
+ paren_id_map_[p.first] = i;
+ paren_id_map_[p.second] = i;
+ }
+
+ if (close_) {
+ StateId start = fst.Start();
+ if (start == kNoStateId)
+ return;
+ DFSearch(start, start);
+ } else {
+ FSTERROR() << "PdtParenReachable: open paren info not implemented";
+ }
+ }
+
+ // Given a state ID, returns an iterator over paren IDs
+ // for close (open) parens reachable from that state along balanced
+ // paths.
+ ParenIterator FindParens(StateId s) const {
+ return ParenIterator(paren_multimap_, paren_multimap_.find(s));
+ }
+
+ // Given a paren ID and a state ID s, returns an iterator over
+ // states that can be reached along balanced paths from (to) s that
+ // have have close (open) parentheses matching the paren ID exiting
+ // (entering) those states.
+ SetIterator FindStates(Label paren_id, StateId s) const {
+ ParenState<A> paren_state(paren_id, s);
+ typename StateSetMap::const_iterator id_it = set_map_.find(paren_state);
+ if (id_it == set_map_.end()) {
+ return state_sets_.FindSet(-1);
+ } else {
+ return state_sets_.FindSet(id_it->second);
+ }
+ }
+
+ // Given a paren Id and a state ID s, return an iterator over
+ // arcs that exit (enter) s and are labeled with a close (open)
+ // parenthesis matching the paren ID.
+ ParenArcIterator FindParenArcs(Label paren_id, StateId s) const {
+ ParenState<A> paren_state(paren_id, s);
+ return ParenArcIterator(paren_arc_multimap_,
+ paren_arc_multimap_.find(paren_state));
+ }
+
+ private:
+ // DFS that gathers paren and state set information.
+ // Bool returns false when cycle detected.
+ bool DFSearch(StateId s, StateId start);
+
+ // Unions state sets together gathered by the DFS.
+ void ComputeStateSet(StateId s);
+
+ // Gather state set(s) from state 'nexts'.
+ void UpdateStateSet(StateId nexts, set<Label> *paren_set,
+ vector< set<StateId> > *state_sets) const;
+
+ const Fst<A> &fst_;
+ const vector<pair<Label, Label> > &parens_; // Paren ID -> Labels
+ bool close_; // Close/open paren info?
+ unordered_map<Label, Label> paren_id_map_; // Paren labels -> ID
+ ParenMultiMap paren_multimap_; // Paren reachability
+ ParenArcMultiMap paren_arc_multimap_; // Paren Arcs
+ vector<char> state_color_; // DFS state
+ mutable Collection<ssize_t, StateId> state_sets_; // Reachable states -> ID
+ StateSetMap set_map_; // ID -> Reachable states
+ DISALLOW_COPY_AND_ASSIGN(PdtParenReachable);
+};
+
+// DFS that gathers paren and state set information.
+template <class A>
+bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) {
+ if (s >= state_color_.size())
+ state_color_.resize(s + 1, kDfsWhite);
+
+ if (state_color_[s] == kDfsBlack)
+ return true;
+
+ if (state_color_[s] == kDfsGrey)
+ return false;
+
+ state_color_[s] = kDfsGrey;
+
+ for (ArcIterator<Fst<A> > aiter(fst_, s);
+ !aiter.Done();
+ aiter.Next()) {
+ const A &arc = aiter.Value();
+
+ typename unordered_map<Label, Label>::const_iterator pit
+ = paren_id_map_.find(arc.ilabel);
+ if (pit != paren_id_map_.end()) { // paren?
+ Label paren_id = pit->second;
+ if (arc.ilabel == parens_[paren_id].first) { // open paren
+ DFSearch(arc.nextstate, arc.nextstate);
+ for (SetIterator set_iter = FindStates(paren_id, arc.nextstate);
+ !set_iter.Done(); set_iter.Next()) {
+ for (ParenArcIterator paren_arc_iter =
+ FindParenArcs(paren_id, set_iter.Element());
+ !paren_arc_iter.Done();
+ paren_arc_iter.Next()) {
+ const A &cparc = paren_arc_iter.Value();
+ DFSearch(cparc.nextstate, start);
+ }
+ }
+ }
+ } else { // non-paren
+ if(!DFSearch(arc.nextstate, start)) {
+ FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
+ return true;
+ }
+ }
+ }
+ ComputeStateSet(s);
+ state_color_[s] = kDfsBlack;
+ return true;
+}
+
+// Unions state sets together gathered by the DFS.
+template <class A>
+void PdtParenReachable<A>::ComputeStateSet(StateId s) {
+ set<Label> paren_set;
+ vector< set<StateId> > state_sets(parens_.size());
+ for (ArcIterator< Fst<A> > aiter(fst_, s);
+ !aiter.Done();
+ aiter.Next()) {
+ const A &arc = aiter.Value();
+
+ typename unordered_map<Label, Label>::const_iterator pit
+ = paren_id_map_.find(arc.ilabel);
+ if (pit != paren_id_map_.end()) { // paren?
+ Label paren_id = pit->second;
+ if (arc.ilabel == parens_[paren_id].first) { // open paren
+ for (SetIterator set_iter =
+ FindStates(paren_id, arc.nextstate);
+ !set_iter.Done(); set_iter.Next()) {
+ for (ParenArcIterator paren_arc_iter =
+ FindParenArcs(paren_id, set_iter.Element());
+ !paren_arc_iter.Done();
+ paren_arc_iter.Next()) {
+ const A &cparc = paren_arc_iter.Value();
+ UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
+ }
+ }
+ } else { // close paren
+ paren_set.insert(paren_id);
+ state_sets[paren_id].insert(s);
+ ParenState<A> paren_state(paren_id, s);
+ paren_arc_multimap_.insert(make_pair(paren_state, arc));
+ }
+ } else { // non-paren
+ UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
+ }
+ }
+
+ vector<StateId> state_set;
+ for (typename set<Label>::iterator paren_iter = paren_set.begin();
+ paren_iter != paren_set.end(); ++paren_iter) {
+ state_set.clear();
+ Label paren_id = *paren_iter;
+ paren_multimap_.insert(make_pair(s, paren_id));
+ for (typename set<StateId>::iterator state_iter
+ = state_sets[paren_id].begin();
+ state_iter != state_sets[paren_id].end();
+ ++state_iter) {
+ state_set.push_back(*state_iter);
+ }
+ ParenState<A> paren_state(paren_id, s);
+ set_map_[paren_state] = state_sets_.FindId(state_set);
+ }
+}
+
+// Gather state set(s) from state 'nexts'.
+template <class A>
+void PdtParenReachable<A>::UpdateStateSet(
+ StateId nexts, set<Label> *paren_set,
+ vector< set<StateId> > *state_sets) const {
+ for(ParenIterator paren_iter = FindParens(nexts);
+ !paren_iter.Done(); paren_iter.Next()) {
+ Label paren_id = paren_iter.Value();
+ paren_set->insert(paren_id);
+ for (SetIterator set_iter = FindStates(paren_id, nexts);
+ !set_iter.Done(); set_iter.Next()) {
+ (*state_sets)[paren_id].insert(set_iter.Element());
+ }
+ }
+}
+
+
+// Store balancing parenthesis data for a PDT. Allows on-the-fly
+// construction (e.g. in PdtShortestPath) unlike PdtParenReachable above.
+template <class A>
+class PdtBalanceData {
+ public:
+ typedef typename A::StateId StateId;
+ typedef typename A::Label Label;
+
+ // Hash set for open parens
+ typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet;
+
+ // Maps from open paren destination state to parenthesis ID.
+ typedef unordered_multimap<StateId, Label> OpenParenMap;
+
+ // Maps from open paren state to source states of matching close parens
+ typedef unordered_multimap<ParenState<A>, StateId,
+ typename ParenState<A>::Hash> CloseParenMap;
+
+ // Maps from open paren state to close source set ID
+ typedef unordered_map<ParenState<A>, ssize_t,
+ typename ParenState<A>::Hash> CloseSourceMap;
+
+ typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
+
+ PdtBalanceData() {}
+
+ void Clear() {
+ open_paren_map_.clear();
+ close_paren_map_.clear();
+ }
+
+ // Adds an open parenthesis with destination state 'open_dest'.
+ void OpenInsert(Label paren_id, StateId open_dest) {
+ ParenState<A> key(paren_id, open_dest);
+ if (!open_paren_set_.count(key)) {
+ open_paren_set_.insert(key);
+ open_paren_map_.insert(make_pair(open_dest, paren_id));
+ }
+ }
+
+ // Adds a matching closing parenthesis with source state
+ // 'close_source' that balances an open_parenthesis with destination
+ // state 'open_dest' if OpenInsert() previously called
+ // (o.w. CloseInsert() does nothing).
+ void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) {
+ ParenState<A> key(paren_id, open_dest);
+ if (open_paren_set_.count(key))
+ close_paren_map_.insert(make_pair(key, close_source));
+ }
+
+ // Find close paren source states matching an open parenthesis.
+ // Methods that follow, iterate through those matching states.
+ // Should be called only after FinishInsert(open_dest).
+ SetIterator Find(Label paren_id, StateId open_dest) {
+ ParenState<A> close_key(paren_id, open_dest);
+ typename CloseSourceMap::const_iterator id_it =
+ close_source_map_.find(close_key);
+ if (id_it == close_source_map_.end()) {
+ return close_source_sets_.FindSet(-1);
+ } else {
+ return close_source_sets_.FindSet(id_it->second);
+ }
+ }
+
+ // Call when all open and close parenthesis insertions wrt open
+ // parentheses entering 'open_dest' are finished. Must be called
+ // before Find(open_dest). Stores close paren source state sets
+ // efficiently.
+ void FinishInsert(StateId open_dest) {
+ vector<StateId> close_sources;
+ for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest);
+ oit != open_paren_map_.end() && oit->first == open_dest;) {
+ Label paren_id = oit->second;
+ close_sources.clear();
+ ParenState<A> okey(paren_id, open_dest);
+ open_paren_set_.erase(open_paren_set_.find(okey));
+ for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey);
+ cit != close_paren_map_.end() && cit->first == okey;) {
+ close_sources.push_back(cit->second);
+ close_paren_map_.erase(cit++);
+ }
+ sort(close_sources.begin(), close_sources.end());
+ typename vector<StateId>::iterator unique_end =
+ unique(close_sources.begin(), close_sources.end());
+ close_sources.resize(unique_end - close_sources.begin());
+
+ if (!close_sources.empty())
+ close_source_map_[okey] = close_source_sets_.FindId(close_sources);
+ open_paren_map_.erase(oit++);
+ }
+ }
+
+ // Return a new balance data object representing the reversed balance
+ // information.
+ PdtBalanceData<A> *Reverse(StateId num_states,
+ StateId num_split,
+ StateId state_id_shift) const;
+
+ private:
+ OpenParenSet open_paren_set_; // open par. at dest?
+
+ OpenParenMap open_paren_map_; // open parens per state
+ ParenState<A> open_dest_; // cur open dest. state
+ typename OpenParenMap::const_iterator open_iter_; // cur open parens/state
+
+ CloseParenMap close_paren_map_; // close states/open
+ // paren and state
+
+ CloseSourceMap close_source_map_; // paren, state to set ID
+ mutable Collection<ssize_t, StateId> close_source_sets_;
+};
+
+// Return a new balance data object representing the reversed balance
+// information.
+template <class A>
+PdtBalanceData<A> *PdtBalanceData<A>::Reverse(
+ StateId num_states,
+ StateId num_split,
+ StateId state_id_shift) const {
+ PdtBalanceData<A> *bd = new PdtBalanceData<A>;
+ unordered_set<StateId> close_sources;
+ StateId split_size = num_states / num_split;
+
+ for (StateId i = 0; i < num_states; i+= split_size) {
+ close_sources.clear();
+
+ for (typename CloseSourceMap::const_iterator
+ sit = close_source_map_.begin();
+ sit != close_source_map_.end();
+ ++sit) {
+ ParenState<A> okey = sit->first;
+ StateId open_dest = okey.state_id;
+ Label paren_id = okey.paren_id;
+ for (SetIterator set_iter = close_source_sets_.FindSet(sit->second);
+ !set_iter.Done(); set_iter.Next()) {
+ StateId close_source = set_iter.Element();
+ if ((close_source < i) || (close_source >= i + split_size))
+ continue;
+ close_sources.insert(close_source + state_id_shift);
+ bd->OpenInsert(paren_id, close_source + state_id_shift);
+ bd->CloseInsert(paren_id, close_source + state_id_shift,
+ open_dest + state_id_shift);
+ }
+ }
+
+ for (typename unordered_set<StateId>::const_iterator it
+ = close_sources.begin();
+ it != close_sources.end();
+ ++it) {
+ bd->FinishInsert(*it);
+ }
+
+ }
+ return bd;
+}
+
+
+} // namespace fst
+
+#endif // FST_EXTENSIONS_PDT_PAREN_H_