aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/extensions/pdt
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/extensions/pdt')
-rw-r--r--src/include/fst/extensions/pdt/collection.h33
-rw-r--r--src/include/fst/extensions/pdt/info.h2
-rw-r--r--src/include/fst/extensions/pdt/paren.h30
-rw-r--r--src/include/fst/extensions/pdt/shortest-path.h14
4 files changed, 49 insertions, 30 deletions
diff --git a/src/include/fst/extensions/pdt/collection.h b/src/include/fst/extensions/pdt/collection.h
index 26be504..24a443f 100644
--- a/src/include/fst/extensions/pdt/collection.h
+++ b/src/include/fst/extensions/pdt/collection.h
@@ -16,7 +16,7 @@
// Author: riley@google.com (Michael Riley)
//
// \file
-// Class to store a collection of sets with elements of type T.
+// Class to store a collection of ordered (multi-)sets with elements of type T.
#ifndef FST_EXTENSIONS_PDT_COLLECTION_H__
#define FST_EXTENSIONS_PDT_COLLECTION_H__
@@ -29,11 +29,11 @@ using std::vector;
namespace fst {
-// Stores a collection of non-empty sets with elements of type T. A
-// default constructor, equality ==, a total order <, and an STL-style
-// hash class must be defined on the elements. Provides signed
-// integer ID (of type I) of each unique set. The IDs are allocated
-// starting from 0 in order.
+// Stores a collection of non-empty, ordered (multi-)sets with elements
+// of type T. A default constructor, equality ==, and an STL-style
+// hash class must be defined on the elements. Provides signed integer
+// ID (of type I) of each unique set. The IDs are allocated starting
+// from 0 in order.
template <class I, class T>
class Collection {
public:
@@ -80,31 +80,34 @@ class Collection {
Collection() {}
- // Lookups integer ID from set. If it doesn't exist, then adds it.
- // Set elements should be in strict order (and therefore unique).
- I FindId(const vector<T> &set) {
+ // Looks up the integer ID of an ordered multi-set. If it doesn't exist
+ // and 'insert' is true, adds it; if 'insert' is false, returns -1.
+ I FindId(const vector<T> &set, bool insert = true) {
I node_id = kNoNodeId;
for (ssize_t i = set.size() - 1; i >= 0; --i) {
Node node(node_id, set[i]);
- node_id = node_table_.FindId(node);
+ node_id = node_table_.FindId(node, insert);
+ if (node_id == -1) break;
}
return node_id;
}
- // Finds set given integer ID. Returns true if ID corresponds
- // to set. Use iterators below to traverse result.
+ // Finds ordered (multi-)set given integer ID. Returns set iterator
+ // to traverse result.
SetIterator FindSet(I id) {
- if (id < 0 && id >= node_table_.Size()) {
+ if (id < 0 || id >= node_table_.Size()) {
return SetIterator(kNoNodeId, Node(kNoNodeId, T()), &node_table_);
} else {
return SetIterator(id, node_table_.FindEntry(id), &node_table_);
}
}
+ I Size() const { return node_table_.Size(); }
+
private:
static const I kNoNodeId;
static const size_t kPrime;
- static std::tr1::hash<T> hash_;
+ static std::hash<T> hash_;
NodeTable node_table_;
@@ -115,7 +118,7 @@ template<class I, class T> const I Collection<I, T>::kNoNodeId = -1;
template <class I, class T> const size_t Collection<I, T>::kPrime = 7853;
-template <class I, class T> std::tr1::hash<T> Collection<I, T>::hash_;
+template <class I, class T> std::hash<T> Collection<I, T>::hash_;
} // namespace fst
diff --git a/src/include/fst/extensions/pdt/info.h b/src/include/fst/extensions/pdt/info.h
index ef9a860..55e76c4 100644
--- a/src/include/fst/extensions/pdt/info.h
+++ b/src/include/fst/extensions/pdt/info.h
@@ -24,7 +24,7 @@
#include <unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
-#include <tr1/unordered_set>
+#include <unordered_set>
using std::tr1::unordered_set;
using std::tr1::unordered_multiset;
#include <vector>
diff --git a/src/include/fst/extensions/pdt/paren.h b/src/include/fst/extensions/pdt/paren.h
index 7b9887f..a9d30c5 100644
--- a/src/include/fst/extensions/pdt/paren.h
+++ b/src/include/fst/extensions/pdt/paren.h
@@ -26,7 +26,7 @@
#include <unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
-#include <tr1/unordered_set>
+#include <unordered_set>
using std::tr1::unordered_set;
using std::tr1::unordered_multiset;
#include <set>
@@ -144,7 +144,8 @@ class PdtParenReachable {
const vector<pair<Label, Label> > &parens, bool close)
: fst_(fst),
parens_(parens),
- close_(close) {
+ close_(close),
+ error_(false) {
for (Label i = 0; i < parens.size(); ++i) {
const pair<Label, Label> &p = parens[i];
paren_id_map_[p.first] = i;
@@ -155,12 +156,18 @@ class PdtParenReachable {
StateId start = fst.Start();
if (start == kNoStateId)
return;
- DFSearch(start, start);
+ if (!DFSearch(start)) {
+ FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
+ error_ = true;
+ }
} else {
FSTERROR() << "PdtParenReachable: open paren info not implemented";
+ error_ = true;
}
}
+ bool Error() const { return error_; }  // True if construction flagged an error.
+
// Given a state ID, returns an iterator over paren IDs
// for close (open) parens reachable from that state along balanced
// paths.
@@ -194,7 +201,7 @@ class PdtParenReachable {
private:
// DFS that gathers paren and state set information.
// Bool returns false when cycle detected.
- bool DFSearch(StateId s, StateId start);
+ bool DFSearch(StateId s);
// Unions state sets together gathered by the DFS.
void ComputeStateSet(StateId s);
@@ -212,12 +219,13 @@ class PdtParenReachable {
vector<char> state_color_; // DFS state
mutable Collection<ssize_t, StateId> state_sets_; // Reachable states -> ID
StateSetMap set_map_; // ID -> Reachable states
+ bool error_;
DISALLOW_COPY_AND_ASSIGN(PdtParenReachable);
};
// DFS that gathers paren and state set information.
template <class A>
-bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) {
+bool PdtParenReachable<A>::DFSearch(StateId s) {
if (s >= state_color_.size())
state_color_.resize(s + 1, kDfsWhite);
@@ -239,7 +247,8 @@ bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) {
if (pit != paren_id_map_.end()) { // paren?
Label paren_id = pit->second;
if (arc.ilabel == parens_[paren_id].first) { // open paren
- DFSearch(arc.nextstate, arc.nextstate);
+ if (!DFSearch(arc.nextstate))
+ return false;
for (SetIterator set_iter = FindStates(paren_id, arc.nextstate);
!set_iter.Done(); set_iter.Next()) {
for (ParenArcIterator paren_arc_iter =
@@ -247,15 +256,14 @@ bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) {
!paren_arc_iter.Done();
paren_arc_iter.Next()) {
const A &cparc = paren_arc_iter.Value();
- DFSearch(cparc.nextstate, start);
+ if (!DFSearch(cparc.nextstate))
+ return false;
}
}
}
} else { // non-paren
- if(!DFSearch(arc.nextstate, start)) {
- FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
- return true;
- }
+ if(!DFSearch(arc.nextstate))
+ return false;
}
}
ComputeStateSet(s);
diff --git a/src/include/fst/extensions/pdt/shortest-path.h b/src/include/fst/extensions/pdt/shortest-path.h
index e90471b..85f94b8 100644
--- a/src/include/fst/extensions/pdt/shortest-path.h
+++ b/src/include/fst/extensions/pdt/shortest-path.h
@@ -28,7 +28,7 @@
#include <unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
-#include <tr1/unordered_set>
+#include <unordered_set>
using std::tr1::unordered_set;
using std::tr1::unordered_multiset;
#include <stack>
@@ -387,7 +387,6 @@ class PdtShortestPath {
typedef typename SpData::SearchState SearchState;
typedef typename SpData::ParenSpec ParenSpec;
- typedef typename PdtParenReachable<Arc>::SetIterator StateSetIterator;
typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator;
PdtShortestPath(const Fst<Arc> &ifst,
@@ -403,7 +402,7 @@ class PdtShortestPath {
if ((Weight::Properties() & (kPath | kRightSemiring))
!= (kPath | kRightSemiring)) {
- FSTERROR() << "SingleShortestPath: Weight needs to have the path"
+ FSTERROR() << "PdtShortestPath: Weight needs to have the path"
<< " property and be right distributive: " << Weight::Type();
error_ = true;
}
@@ -440,6 +439,7 @@ class PdtShortestPath {
static const Arc kNoArc;
static const uint8 kEnqueued;
static const uint8 kExpanded;
+ static const uint8 kFinished;
const uint8 kFinal;
public:
@@ -543,6 +543,7 @@ void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) {
ProcArcs(s);
sp_data_.SetFlags(s, kExpanded, kExpanded);
}
+ sp_data_.SetFlags(q, kFinished, kFinished);
balance_data_.FinishInsert(start);
sp_data_.GC(start);
}
@@ -607,7 +608,11 @@ void PdtShortestPath<Arc, Queue>::ProcOpenParen(
Queue *state_queue = state_queue_;
GetDistance(d.start);
state_queue_ = state_queue;
+ } else if (!(sp_data_.Flags(d) & kFinished)) {
+ FSTERROR() << "PdtShortestPath: open parenthesis recursion: not bounded stack";
+ error_ = true;
}
+
for (CloseSourceIterator set_iter =
balance_data_.Find(paren_id, arc.nextstate);
!set_iter.Done(); set_iter.Next()) {
@@ -765,6 +770,9 @@ template<class Arc, class Queue>
const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20;
template<class Arc, class Queue>
+const uint8 PdtShortestPath<Arc, Queue>::kFinished = 0x40;
+
+template<class Arc, class Queue>
void ShortestPath(const Fst<Arc> &ifst,
const vector<pair<typename Arc::Label,
typename Arc::Label> > &parens,