diff options
Diffstat (limited to 'src/include/fst/extensions/pdt')
-rw-r--r-- | src/include/fst/extensions/pdt/collection.h | 33 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/info.h | 2 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/paren.h | 30 | ||||
-rw-r--r-- | src/include/fst/extensions/pdt/shortest-path.h | 14 |
4 files changed, 49 insertions, 30 deletions
diff --git a/src/include/fst/extensions/pdt/collection.h b/src/include/fst/extensions/pdt/collection.h index 26be504..24a443f 100644 --- a/src/include/fst/extensions/pdt/collection.h +++ b/src/include/fst/extensions/pdt/collection.h @@ -16,7 +16,7 @@ // Author: riley@google.com (Michael Riley) // // \file -// Class to store a collection of sets with elements of type T. +// Class to store a collection of ordered (multi-)sets with elements of type T. #ifndef FST_EXTENSIONS_PDT_COLLECTION_H__ #define FST_EXTENSIONS_PDT_COLLECTION_H__ @@ -29,11 +29,11 @@ using std::vector; namespace fst { -// Stores a collection of non-empty sets with elements of type T. A -// default constructor, equality ==, a total order <, and an STL-style -// hash class must be defined on the elements. Provides signed -// integer ID (of type I) of each unique set. The IDs are allocated -// starting from 0 in order. +// Stores a collection of non-empty, ordered (multi-)sets with elements +// of type T. A default constructor, equality ==, and an STL-style +// hash class must be defined on the elements. Provides signed integer +// ID (of type I) of each unique set. The IDs are allocated starting +// from 0 in order. template <class I, class T> class Collection { public: @@ -80,31 +80,34 @@ class Collection { Collection() {} - // Lookups integer ID from set. If it doesn't exist, then adds it. - // Set elements should be in strict order (and therefore unique). - I FindId(const vector<T> &set) { + // Lookups integer ID from ordered multi-set. If it doesn't exist + // and 'insert' is true, then adds it. Otherwise returns -1. + I FindId(const vector<T> &set, bool insert = true) { I node_id = kNoNodeId; for (ssize_t i = set.size() - 1; i >= 0; --i) { Node node(node_id, set[i]); - node_id = node_table_.FindId(node); + node_id = node_table_.FindId(node, insert); + if (node_id == -1) break; } return node_id; } - // Finds set given integer ID. Returns true if ID corresponds - // to set. Use iterators below to traverse result. + // Finds ordered (multi-)set given integer ID. Returns set iterator + // to traverse result. SetIterator FindSet(I id) { - if (id < 0 && id >= node_table_.Size()) { + if (id < 0 || id >= node_table_.Size()) { return SetIterator(kNoNodeId, Node(kNoNodeId, T()), &node_table_); } else { return SetIterator(id, node_table_.FindEntry(id), &node_table_); } } + I Size() const { return node_table_.Size(); } + private: static const I kNoNodeId; static const size_t kPrime; - static std::tr1::hash<T> hash_; + static std::hash<T> hash_; NodeTable node_table_; @@ -115,7 +118,7 @@ template<class I, class T> const I Collection<I, T>::kNoNodeId = -1; template <class I, class T> const size_t Collection<I, T>::kPrime = 7853; -template <class I, class T> std::tr1::hash<T> Collection<I, T>::hash_; +template <class I, class T> std::hash<T> Collection<I, T>::hash_; } // namespace fst diff --git a/src/include/fst/extensions/pdt/info.h b/src/include/fst/extensions/pdt/info.h index ef9a860..55e76c4 100644 --- a/src/include/fst/extensions/pdt/info.h +++ b/src/include/fst/extensions/pdt/info.h @@ -24,7 +24,7 @@ #include <unordered_map> using std::tr1::unordered_map; using std::tr1::unordered_multimap; -#include <tr1/unordered_set> +#include <unordered_set> using std::tr1::unordered_set; using std::tr1::unordered_multiset; #include <vector> diff --git a/src/include/fst/extensions/pdt/paren.h b/src/include/fst/extensions/pdt/paren.h index 7b9887f..a9d30c5 100644 --- a/src/include/fst/extensions/pdt/paren.h +++ b/src/include/fst/extensions/pdt/paren.h @@ -26,7 +26,7 @@ #include <unordered_map> using std::tr1::unordered_map; using std::tr1::unordered_multimap; -#include <tr1/unordered_set> +#include <unordered_set> using std::tr1::unordered_set; using std::tr1::unordered_multiset; #include <set> @@ -144,7 +144,8 @@ class PdtParenReachable { const vector<pair<Label, Label> > &parens, bool close) : fst_(fst), parens_(parens), - close_(close) { + close_(close), + error_(false) { for (Label i = 0; i < parens.size(); ++i) { const pair<Label, Label> &p = parens[i]; paren_id_map_[p.first] = i; @@ -155,12 +156,18 @@ class PdtParenReachable { StateId start = fst.Start(); if (start == kNoStateId) return; - DFSearch(start, start); + if (!DFSearch(start)) { + FSTERROR() << "PdtReachable: Underlying cyclicity not supported"; + error_ = true; + } } else { FSTERROR() << "PdtParenReachable: open paren info not implemented"; + error_ = true; } } + bool const Error() { return error_; } + // Given a state ID, returns an iterator over paren IDs // for close (open) parens reachable from that state along balanced // paths. @@ -194,7 +201,7 @@ class PdtParenReachable { private: // DFS that gathers paren and state set information. // Bool returns false when cycle detected. - bool DFSearch(StateId s, StateId start); + bool DFSearch(StateId s); // Unions state sets together gathered by the DFS. void ComputeStateSet(StateId s); @@ -212,12 +219,13 @@ class PdtParenReachable { vector<char> state_color_; // DFS state mutable Collection<ssize_t, StateId> state_sets_; // Reachable states -> ID StateSetMap set_map_; // ID -> Reachable states + bool error_; DISALLOW_COPY_AND_ASSIGN(PdtParenReachable); }; // DFS that gathers paren and state set information. template <class A> -bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) { +bool PdtParenReachable<A>::DFSearch(StateId s) { if (s >= state_color_.size()) state_color_.resize(s + 1, kDfsWhite); @@ -239,7 +247,8 @@ bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) { if (pit != paren_id_map_.end()) { // paren? Label paren_id = pit->second; if (arc.ilabel == parens_[paren_id].first) { // open paren - DFSearch(arc.nextstate, arc.nextstate); + if (!DFSearch(arc.nextstate)) + return false; for (SetIterator set_iter = FindStates(paren_id, arc.nextstate); !set_iter.Done(); set_iter.Next()) { for (ParenArcIterator paren_arc_iter = @@ -247,15 +256,14 @@ bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) { !paren_arc_iter.Done(); paren_arc_iter.Next()) { const A &cparc = paren_arc_iter.Value(); - DFSearch(cparc.nextstate, start); + if (!DFSearch(cparc.nextstate)) + return false; } } } } else { // non-paren - if(!DFSearch(arc.nextstate, start)) { - FSTERROR() << "PdtReachable: Underlying cyclicity not supported"; - return true; - } + if(!DFSearch(arc.nextstate)) + return false; } } ComputeStateSet(s); diff --git a/src/include/fst/extensions/pdt/shortest-path.h b/src/include/fst/extensions/pdt/shortest-path.h index e90471b..85f94b8 100644 --- a/src/include/fst/extensions/pdt/shortest-path.h +++ b/src/include/fst/extensions/pdt/shortest-path.h @@ -28,7 +28,7 @@ #include <unordered_map> using std::tr1::unordered_map; using std::tr1::unordered_multimap; -#include <tr1/unordered_set> +#include <unordered_set> using std::tr1::unordered_set; using std::tr1::unordered_multiset; #include <stack> @@ -387,7 +387,6 @@ class PdtShortestPath { typedef typename SpData::SearchState SearchState; typedef typename SpData::ParenSpec ParenSpec; - typedef typename PdtParenReachable<Arc>::SetIterator StateSetIterator; typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator; PdtShortestPath(const Fst<Arc> &ifst, @@ -403,7 +402,7 @@ class PdtShortestPath { if ((Weight::Properties() & (kPath | kRightSemiring)) != (kPath | kRightSemiring)) { - FSTERROR() << "SingleShortestPath: Weight needs to have the path" + FSTERROR() << "PdtShortestPath: Weight needs to have the path" << " property and be right distributive: " << Weight::Type(); error_ = true; } @@ -440,6 +439,7 @@ class PdtShortestPath { static const Arc kNoArc; static const uint8 kEnqueued; static const uint8 kExpanded; + static const uint8 kFinished; const uint8 kFinal; public: @@ -543,6 +543,7 @@ void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) { ProcArcs(s); sp_data_.SetFlags(s, kExpanded, kExpanded); } + sp_data_.SetFlags(q, kFinished, kFinished); balance_data_.FinishInsert(start); sp_data_.GC(start); } @@ -607,7 +608,11 @@ void PdtShortestPath<Arc, Queue>::ProcOpenParen( Queue *state_queue = state_queue_; GetDistance(d.start); state_queue_ = state_queue; + } else if (!(sp_data_.Flags(d) & kFinished)) { + FSTERROR() << "PdtShortestPath: open parenthesis recursion: not bounded stack"; + error_ = true; } + for (CloseSourceIterator set_iter = balance_data_.Find(paren_id, arc.nextstate); !set_iter.Done(); set_iter.Next()) { @@ -765,6 +770,9 @@ template<class Arc, class Queue> const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20; template<class Arc, class Queue> +const uint8 PdtShortestPath<Arc, Queue>::kFinished = 0x40; + +template<class Arc, class Queue> void ShortestPath(const Fst<Arc> &ifst, const vector<pair<typename Arc::Label, typename Arc::Label> > &parens, |