aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/extensions/ngram/ngram-fst.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/extensions/ngram/ngram-fst.h')
-rw-r--r--src/include/fst/extensions/ngram/ngram-fst.h111
1 files changed, 92 insertions, 19 deletions
diff --git a/src/include/fst/extensions/ngram/ngram-fst.h b/src/include/fst/extensions/ngram/ngram-fst.h
index eee664a..873ae6a 100644
--- a/src/include/fst/extensions/ngram/ngram-fst.h
+++ b/src/include/fst/extensions/ngram/ngram-fst.h
@@ -26,6 +26,7 @@ using std::vector;
#include <fst/compat.h>
#include <fst/fstlib.h>
+#include <fst/mapped-file.h>
#include <fst/extensions/ngram/bitmap-index.h>
// NgramFst implements a n-gram language model based upon the LOUDS data
@@ -76,7 +77,7 @@ class NGramFstImpl : public FstImpl<A> {
typedef typename A::StateId StateId;
typedef typename A::Weight Weight;
- NGramFstImpl() : data_(0), owned_(false) {
+ NGramFstImpl() : data_region_(0), data_(0), owned_(false) {
SetType("ngram");
SetInputSymbols(NULL);
SetOutputSymbols(NULL);
@@ -89,6 +90,7 @@ class NGramFstImpl : public FstImpl<A> {
if (owned_) {
delete [] data_;
}
+ delete data_region_;
}
static NGramFstImpl<A>* Read(istream &strm, // NOLINT
@@ -104,7 +106,8 @@ class NGramFstImpl : public FstImpl<A> {
strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
size_t size = Storage(num_states, num_futures, num_final);
- char* data = new char[size];
+ MappedFile *data_region = MappedFile::Allocate(size);
+ char *data = reinterpret_cast<char *>(data_region->mutable_data());
// Copy num_states, num_futures and num_final back into data.
memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
@@ -116,7 +119,7 @@ class NGramFstImpl : public FstImpl<A> {
delete impl;
return NULL;
}
- impl->Init(data, true /* owned */);
+ impl->Init(data, false, data_region);
return impl;
}
@@ -126,7 +129,7 @@ class NGramFstImpl : public FstImpl<A> {
hdr.SetStart(Start());
hdr.SetNumStates(num_states_);
WriteHeader(strm, opts, kFileVersion, &hdr);
- strm.write(data_, Storage(num_states_, num_futures_, num_final_));
+ strm.write(data_, StorageSize());
return strm;
}
@@ -223,11 +226,23 @@ class NGramFstImpl : public FstImpl<A> {
// Access to the underlying representation
const char* GetData(size_t* data_size) const {
- *data_size = Storage(num_states_, num_futures_, num_final_);
+ *data_size = StorageSize();  // Byte count from Storage(num_states_, num_futures_, num_final_), as on the removed line.
return data_;
}
- void Init(const char* data, bool owned);
+ void Init(const char* data, bool owned, MappedFile *file = 0);  // `file` (may be null) is stored in data_region_ and deleted by this object.
+
+ const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const {  // Returns a reference to inst->context_ after updating inst for state s.
+ SetInstFuture(s, inst);  // NOTE(review): presumably positions inst on state s; helper not visible here.
+ SetInstContext(inst);  // NOTE(review): presumably fills inst->context_; helper not visible here.
+ return inst->context_;  // Reference into *inst, so it is only as stable as inst's contents.
+ }
+
+ size_t StorageSize() const {  // Result of Storage() for this impl's current state, future and final counts.
+ return Storage(num_states_, num_futures_, num_final_);
+ }
+
+ void GetStates(const vector<Label>& context, vector<StateId> *states) const;
private:
StateId Transition(const vector<Label> &context, Label future) const;
@@ -242,6 +257,7 @@ class NGramFstImpl : public FstImpl<A> {
// Minimum file format version supported.
static const int kMinFileVersion = 4;
+ MappedFile *data_region_;
const char* data_;
bool owned_; // True if we own data_
uint64 num_states_, num_futures_, num_final_;
@@ -261,7 +277,7 @@ class NGramFstImpl : public FstImpl<A> {
template<typename A>
NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
- : data_(0), owned_(false) {
+ : data_region_(0), data_(0), owned_(false) {
typedef A Arc;
typedef typename Arc::Label Label;
typedef typename Arc::Weight Weight;
@@ -286,12 +302,16 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
// epsilons.
StateId unigram = fst.Start();
while (1) {
- ArcIterator<Fst<A> > aiter(fst, unigram);
- if (aiter.Done()) {
- FSTERROR() << "Start state has no arcs";
+ if (unigram == kNoStateId) {
+ FSTERROR() << "Could not identify unigram state.";
SetProperties(kError, kError);
return;
}
+ ArcIterator<Fst<A> > aiter(fst, unigram);
+ if (aiter.Done()) {
+ LOG(WARNING) << "Unigram state " << unigram << " has no arcs.";
+ break;
+ }
if (aiter.Value().ilabel != 0) break;
unigram = aiter.Value().nextstate;
}
@@ -385,7 +405,8 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
Weight weight;
Label label = kNoLabel;
const size_t storage = Storage(num_states, num_futures, num_final);
- char* data = new char[storage];
+ MappedFile *data_region = MappedFile::Allocate(storage);
+ char *data = reinterpret_cast<char *>(data_region->mutable_data());
memset(data, 0, storage);
size_t offset = 0;
memcpy(data + offset, reinterpret_cast<char *>(&num_states),
@@ -482,14 +503,17 @@ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
return;
}
- Init(data, true /* owned */);
+ Init(data, false, data_region);
}
template<typename A>
-inline void NGramFstImpl<A>::Init(const char* data, bool owned) {
+inline void NGramFstImpl<A>::Init(const char* data, bool owned,
+ MappedFile *data_region) {
if (owned_) {
delete [] data_;
}
+ delete data_region_;
+ data_region_ = data_region;
owned_ = owned;
data_ = data;
size_t offset = 0;
@@ -507,7 +531,7 @@ inline void NGramFstImpl<A>::Init(const char* data, bool owned) {
future_ = reinterpret_cast<const uint64*>(data_ + offset);
offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits);
final_ = reinterpret_cast<const uint64*>(data_ + offset);
- offset += BitmapIndex::StorageSize(num_states_ + 1) * sizeof(bits);
+ offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits);
context_words_ = reinterpret_cast<const Label*>(data_ + offset);
offset += (num_states_ + 1) * sizeof(*context_words_);
future_words_ = reinterpret_cast<const Label*>(data_ + offset);
@@ -538,10 +562,10 @@ inline void NGramFstImpl<A>::Init(const char* data, bool owned) {
template<typename A>
inline typename A::StateId NGramFstImpl<A>::Transition(
const vector<Label> &context, Label future) const {
- size_t num_children = root_num_children_;
const Label *children = root_children_;
- const Label *loc = lower_bound(children, children + num_children, future);
- if (loc == children + num_children || *loc != future) {
+ const Label *loc = lower_bound(children, children + root_num_children_,
+ future);
+ if (loc == children + root_num_children_ || *loc != future) {
return context_index_.Rank1(0);
}
size_t node = root_first_child_ + loc - children;
@@ -551,7 +575,6 @@ inline typename A::StateId NGramFstImpl<A>::Transition(
return context_index_.Rank1(node);
}
size_t last_child = context_index_.Select0(node_rank + 1) - 1;
- num_children = last_child - first_child + 1;
for (int word = context.size() - 1; word >= 0; --word) {
children = context_words_ + context_index_.Rank1(first_child);
loc = lower_bound(children, children + last_child - first_child + 1,
@@ -569,6 +592,42 @@ inline typename A::StateId NGramFstImpl<A>::Transition(
return context_index_.Rank1(node);
}
+// Collects the states reached by matching `context` from its last label
+// towards its first, one label at a time.
+//
+// `states` is cleared and always receives state 0 first; one more state is
+// appended for every label that could be matched, so the result holds between
+// 1 and context.size() + 1 entries. Matching stops at the first label with no
+// matching child, or when context_index_.Get(first_child) is false
+// (presumably: the current node has no children).
+//
+// An empty `context` is valid and yields just {0}.
+template<typename A>
+inline void NGramFstImpl<A>::GetStates(
+    const vector<Label> &context,
+    vector<typename A::StateId>* states) const {
+  states->clear();
+  states->push_back(0);
+  // Nothing to consume. Without this guard the dereference of
+  // context.rbegin() below would read past an empty vector.
+  if (context.empty()) return;
+  typename vector<Label>::const_reverse_iterator cit = context.rbegin();
+  // First label: binary search among the root's children.
+  const Label *children = root_children_;
+  const Label *children_end = children + root_num_children_;
+  const Label *loc = lower_bound(children, children_end, *cit);
+  if (loc == children_end || *loc != *cit) return;
+  size_t node = root_first_child_ + loc - children;
+  // Rank1(node) is both the state id to report and the key for Select0.
+  size_t node_rank = context_index_.Rank1(node);
+  states->push_back(node_rank);
+  if (context.size() == 1) return;
+  size_t first_child = context_index_.Select0(node_rank) + 1;
+  ++cit;
+  if (context_index_.Get(first_child) != false) {
+    size_t last_child = context_index_.Select0(node_rank + 1) - 1;
+    while (cit != context.rend()) {
+      // Remaining labels: binary search among the current node's children.
+      children = context_words_ + context_index_.Rank1(first_child);
+      children_end = children + (last_child - first_child + 1);
+      loc = lower_bound(children, children_end, *cit);
+      if (loc == children_end || *loc != *cit) {
+        break;
+      }
+      ++cit;
+      node = first_child + loc - children;
+      node_rank = context_index_.Rank1(node);
+      states->push_back(node_rank);
+      first_child = context_index_.Select0(node_rank) + 1;
+      if (context_index_.Get(first_child) == false) break;
+      last_child = context_index_.Select0(node_rank + 1) - 1;
+    }
+  }
+}
+
/*****************************************************************************/
template<class A>
class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
@@ -597,7 +656,7 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
// Non-standard constructor to initialize NGramFst directly from data.
NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) {
- GetImpl()->Init(data, owned);
+ GetImpl()->Init(data, owned, NULL);  // NULL: no MappedFile accompanies caller-supplied data, so the impl has no region to delete.
}
// Get method that gets the data associated with Init().
@@ -605,6 +664,16 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
return GetImpl()->GetData(data_size);
}
+ const vector<Label> GetContext(StateId s) const {  // Returns by value, so the caller gets its own copy of the context for state s.
+ return GetImpl()->GetContext(s, &inst_);  // Uses this object's inst_ as the impl's per-instance scratch state.
+ }
+
+ // Consumes as much as possible of context from right to left, returns the
+ // states corresponding to the increasingly conditioned input sequence.
+ void GetStates(const vector<Label>& context, vector<StateId> *state) const {  // `state` is an output vector; it may receive several states.
+ return GetImpl()->GetStates(context, state);  // Forwards unchanged to the impl; the returned expression is void.
+ }
+
virtual size_t NumArcs(StateId s) const {
return GetImpl()->NumArcs(s, &inst_);
}
@@ -650,6 +719,10 @@ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
return new NGramFstMatcher<A>(*this, match_type);
}
+ size_t StorageSize() const {  // Forwards to the impl's StorageSize().
+ return GetImpl()->StorageSize();
+ }
+
private:
explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {}