aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/extensions
diff options
context:
space:
mode:
authorAlexander Gutkin <agutkin@google.com>2012-09-12 18:11:43 +0100
committerAlexander Gutkin <agutkin@google.com>2012-09-12 18:11:43 +0100
commitdfd8b8327b93660601d016cdc6f29f433b45a8d8 (patch)
tree968ec84b8e32ad73ec18d74334930f36b7471906 /src/include/fst/extensions
parentf4c12fce1ee58e670f9c3fce46c40296ba9ee8a2 (diff)
downloadopenfst-dfd8b8327b93660601d016cdc6f29f433b45a8d8.tar.gz
Updated OpenFST version to openfst-1.3.2-CL32004048 from Greco3.
Change-Id: I19b0db718256b35c0e3e5a7315f1ed6335e6dcac
Diffstat (limited to 'src/include/fst/extensions')
-rw-r--r--src/include/fst/extensions/far/compile-strings.h61
-rw-r--r--src/include/fst/extensions/far/equal.h99
-rw-r--r--src/include/fst/extensions/far/extract.h2
-rw-r--r--src/include/fst/extensions/far/far.h184
-rw-r--r--src/include/fst/extensions/far/farscript.h51
-rw-r--r--src/include/fst/extensions/far/info.h2
-rw-r--r--src/include/fst/extensions/far/print-strings.h28
-rw-r--r--src/include/fst/extensions/far/stlist.h9
-rw-r--r--src/include/fst/extensions/far/sttable.h1
-rw-r--r--src/include/fst/extensions/ngram/bitmap-index.h183
-rw-r--r--src/include/fst/extensions/ngram/ngram-fst.h912
-rw-r--r--src/include/fst/extensions/ngram/nthbit.h46
-rw-r--r--src/include/fst/extensions/pdt/collection.h33
-rw-r--r--src/include/fst/extensions/pdt/info.h2
-rw-r--r--src/include/fst/extensions/pdt/paren.h30
-rw-r--r--src/include/fst/extensions/pdt/shortest-path.h14
16 files changed, 1585 insertions, 72 deletions
diff --git a/src/include/fst/extensions/far/compile-strings.h b/src/include/fst/extensions/far/compile-strings.h
index d7f4d6b..ca247db 100644
--- a/src/include/fst/extensions/far/compile-strings.h
+++ b/src/include/fst/extensions/far/compile-strings.h
@@ -56,7 +56,7 @@ class StringReader {
const SymbolTable *syms = 0,
Label unknown_label = kNoStateId)
: nline_(0), strm_(istrm), source_(source), entry_type_(entry_type),
- token_type_(token_type), done_(false),
+ token_type_(token_type), symbols_(syms), done_(false),
compiler_(token_type, syms, unknown_label, allow_negative_labels) {
Next(); // Initialize the reader to the first input.
}
@@ -87,8 +87,12 @@ class StringReader {
done_ = true; // whitespace at the end of a file.
}
- VectorFst<A> *GetVectorFst() {
+ VectorFst<A> *GetVectorFst(bool keep_symbols = false) {
VectorFst<A> *fst = new VectorFst<A>;
+ if (keep_symbols) {
+ fst->SetInputSymbols(symbols_);
+ fst->SetOutputSymbols(symbols_);
+ }
if (compiler_(content_, fst)) {
return fst;
} else {
@@ -97,9 +101,16 @@ class StringReader {
}
}
- CompactFst<A, StringCompactor<A> > *GetCompactFst() {
- CompactFst<A, StringCompactor<A> > *fst =
- new CompactFst<A, StringCompactor<A> >;
+ CompactFst<A, StringCompactor<A> > *GetCompactFst(bool keep_symbols = false) {
+ CompactFst<A, StringCompactor<A> > *fst;
+ if (keep_symbols) {
+ VectorFst<A> tmp;
+ tmp.SetInputSymbols(symbols_);
+ tmp.SetOutputSymbols(symbols_);
+ fst = new CompactFst<A, StringCompactor<A> >(tmp);
+ } else {
+ fst = new CompactFst<A, StringCompactor<A> >;
+ }
if (compiler_(content_, fst)) {
return fst;
} else {
@@ -114,6 +125,7 @@ class StringReader {
string source_;
EntryType entry_type_;
TokenType token_type_;
+ const SymbolTable *symbols_;
bool done_;
StringCompiler<A> compiler_;
string content_; // The actual content of the input stream's next FST.
@@ -135,6 +147,8 @@ void FarCompileStrings(const vector<string> &in_fnames,
FarTokenType tt,
const string &symbols_fname,
const string &unknown_symbol,
+ bool keep_symbols,
+ bool initial_symbols,
bool allow_negative_labels,
bool file_list_input,
const string &key_prefix,
@@ -175,8 +189,9 @@ void FarCompileStrings(const vector<string> &in_fnames,
const SymbolTable *syms = 0;
typename Arc::Label unknown_label = kNoLabel;
if (!symbols_fname.empty()) {
- syms = SymbolTable::ReadText(symbols_fname,
- allow_negative_labels);
+ SymbolTableTextOptions opts;
+ opts.allow_negative = allow_negative_labels;
+ syms = SymbolTable::ReadText(symbols_fname, opts);
if (!syms) {
FSTERROR() << "FarCompileStrings: error reading symbol table: "
<< symbols_fname;
@@ -199,32 +214,47 @@ void FarCompileStrings(const vector<string> &in_fnames,
vector<string> inputs;
if (file_list_input) {
for (int i = 1; i < in_fnames.size(); ++i) {
- ifstream istrm(in_fnames[i].c_str());
+ istream *istrm = in_fnames.empty() ? &cin :
+ new ifstream(in_fnames[i].c_str());
string str;
- while (getline(istrm, str))
+ while (getline(*istrm, str))
inputs.push_back(str);
+ if (!in_fnames.empty())
+ delete istrm;
}
} else {
inputs = in_fnames;
}
for (int i = 0, n = 0; i < inputs.size(); ++i) {
+ if (generate_keys == 0 && inputs[i].empty()) {
+ FSTERROR() << "FarCompileStrings: read from a file instead of stdin or"
+ << " set the --generate_keys flags.";
+ delete far_writer;
+ delete syms;
+ return;
+ }
int key_size = generate_keys ? generate_keys :
(entry_type == StringReader<Arc>::FILE ? 1 :
KeySize(inputs[i].c_str()));
- ifstream istrm(inputs[i].c_str());
+ istream *istrm = inputs[i].empty() ? &cin :
+ new ifstream(inputs[i].c_str());
+ bool keep_syms = keep_symbols;
for (StringReader<Arc> reader(
- istrm, inputs[i], entry_type, token_type,
- allow_negative_labels, syms, unknown_label);
+ *istrm, inputs[i].empty() ? "stdin" : inputs[i],
+ entry_type, token_type, allow_negative_labels,
+ syms, unknown_label);
!reader.Done();
reader.Next()) {
++n;
const Fst<Arc> *fst;
if (compact)
- fst = reader.GetCompactFst();
+ fst = reader.GetCompactFst(keep_syms);
else
- fst = reader.GetVectorFst();
+ fst = reader.GetVectorFst(keep_syms);
+ if (initial_symbols)
+ keep_syms = false;
if (!fst) {
FSTERROR() << "FarCompileStrings: compiling string number " << n
<< " in file " << inputs[i] << " failed with token_type = "
@@ -236,6 +266,7 @@ void FarCompileStrings(const vector<string> &in_fnames,
(fet == FET_FILE ? "file" : "unknown"));
delete far_writer;
delete syms;
+ if (!inputs[i].empty()) delete istrm;
return;
}
ostringstream keybuf;
@@ -260,6 +291,8 @@ void FarCompileStrings(const vector<string> &in_fnames,
}
if (generate_keys == 0)
n = 0;
+ if (!inputs[i].empty())
+ delete istrm;
}
delete far_writer;
diff --git a/src/include/fst/extensions/far/equal.h b/src/include/fst/extensions/far/equal.h
new file mode 100644
index 0000000..be82e2d
--- /dev/null
+++ b/src/include/fst/extensions/far/equal.h
@@ -0,0 +1,99 @@
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: allauzen@google.com (Cyril Allauzen)
+
+#ifndef FST_EXTENSIONS_FAR_EQUAL_H_
+#define FST_EXTENSIONS_FAR_EQUAL_H_
+
+#include <string>
+
+#include <fst/extensions/far/far.h>
+#include <fst/equal.h>
+
+namespace fst {
+
+// Compares two FST archives entry by entry, in the order their readers
+// iterate. Returns true when both archives yield the same sequence of keys
+// and, for every key, Equal(fst1, fst2, delta) holds.
+//
+//   filename1, filename2: archives opened with FarReader<Arc>::Open.
+//   delta:     tolerance forwarded to Equal.
+//   begin_key: when non-empty, both readers are first positioned with
+//              Find(begin_key). If neither reader finds it the result is
+//              true; if only one does the result is false.
+//   end_key:   when non-empty, the comparison stops successfully at the
+//              first entry whose key in both archives sorts after end_key.
+//
+// Returns false, after a VLOG(1) message, when an archive cannot be opened
+// or at the first mismatch. Both readers are deleted on every exit path.
+template <class Arc>
+bool FarEqual(const string &filename1,
+              const string &filename2,
+              float delta = kDelta,
+              const string &begin_key = string(),
+              const string &end_key = string()) {
+  FarReader<Arc> *reader1 = FarReader<Arc>::Open(filename1);
+  FarReader<Arc> *reader2 = FarReader<Arc>::Open(filename2);
+  if (!reader1 || !reader2) {
+    delete reader1;
+    delete reader2;
+    VLOG(1) << "FarEqual: cannot open input Far file(s)";
+    return false;
+  }
+
+  if (!begin_key.empty()) {
+    bool find_begin1 = reader1->Find(begin_key);
+    bool find_begin2 = reader2->Find(begin_key);
+    if (!find_begin1 || !find_begin2) {
+      // Missing from both archives counts as equal; missing from one does not.
+      bool ret = !find_begin1 && !find_begin2;
+      if (!ret) {
+        VLOG(1) << "FarEqual: key \"" << begin_key << "\" missing from "
+                << (find_begin1 ? "second" : "first") << " archive.";
+      }
+      delete reader1;
+      delete reader2;
+      return ret;
+    }
+  }
+
+  for (; !reader1->Done() && !reader2->Done();
+       reader1->Next(), reader2->Next()) {
+    const string key1 = reader1->GetKey();
+    const string key2 = reader2->GetKey();
+    // Both readers have moved past the requested range: nothing left to check.
+    if (!end_key.empty() && end_key < key1 && end_key < key2) {
+      delete reader1;
+      delete reader2;
+      return true;
+    }
+    if (key1 != key2) {
+      VLOG(1) << "FarEqual: mismatched keys \""
+              << key1 << "\" <> \"" << key2 << "\".";
+      delete reader1;
+      delete reader2;
+      return false;
+    }
+    if (!Equal(reader1->GetFst(), reader2->GetFst(), delta)) {
+      VLOG(1) << "FarEqual: Fsts for key \"" << key1 << "\" are not equal.";
+      delete reader1;
+      delete reader2;
+      return false;
+    }
+  }
+
+  // The loop stops as soon as either reader is done. If one still has
+  // entries, its current key is absent from the archive that ran out first:
+  // reader1 done => the key comes from reader2 and the first archive lacks it.
+  if (!reader1->Done() || !reader2->Done()) {
+    VLOG(1) << "FarEqual: key \""
+            << (reader1->Done() ? reader2->GetKey() : reader1->GetKey())
+            << "\" missing from " << (reader1->Done() ? "first" : "second")
+            << " archive.";
+    delete reader1;
+    delete reader2;
+    return false;
+  }
+
+  delete reader1;
+  delete reader2;
+  return true;
+}
+
+} // namespace fst
+
+#endif // FST_EXTENSIONS_FAR_EQUAL_H_
diff --git a/src/include/fst/extensions/far/extract.h b/src/include/fst/extensions/far/extract.h
index 022ca60..d6f92ff 100644
--- a/src/include/fst/extensions/far/extract.h
+++ b/src/include/fst/extensions/far/extract.h
@@ -70,7 +70,7 @@ void FarExtract(const vector<string> &ifilenames,
if (nrep > 0) {
ostringstream tmp;
tmp << '.' << nrep;
- key += tmp.str();
+ key.append(tmp.str().data(), tmp.str().size());
}
ofilename = key;
}
diff --git a/src/include/fst/extensions/far/far.h b/src/include/fst/extensions/far/far.h
index 82b9e5c..acce76e 100644
--- a/src/include/fst/extensions/far/far.h
+++ b/src/include/fst/extensions/far/far.h
@@ -32,6 +32,13 @@ namespace fst {
enum FarEntryType { FET_LINE, FET_FILE };
enum FarTokenType { FTT_SYMBOL, FTT_BYTE, FTT_UTF8 };
+// Returns false if filename cannot be opened for reading; otherwise returns
+// whatever IsFstHeader reports for the opened stream.
+inline bool IsFst(const string &filename) {
+  ifstream header_strm(filename.c_str());
+  return header_strm && IsFstHeader(header_strm, filename);
+}
+
// FST archive header class
class FarHeader {
public:
@@ -40,8 +47,11 @@ class FarHeader {
bool Read(const string &filename) {
FstHeader fsthdr;
- if (filename.empty()) { // Header reading unsupported on stdin.
- return false;
+ if (filename.empty()) {
+ // Header reading unsupported on stdin. Assumes STList and StdArc.
+ fartype_ = "stlist";
+ arctype_ = "standard";
+ return true;
} else if (IsSTTable(filename)) { // Check if STTable
ReadSTTableHeader(filename, &fsthdr);
fartype_ = "sttable";
@@ -52,6 +62,12 @@ class FarHeader {
fartype_ = "sttable";
arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType();
return true;
+ } else if (IsFst(filename)) { // Check if Fst
+ ifstream istrm(filename.c_str());
+ fsthdr.Read(istrm, filename);
+ fartype_ = "fst";
+ arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType();
+ return true;
}
return false;
}
@@ -61,8 +77,12 @@ class FarHeader {
string arctype_;
};
-enum FarType { FAR_DEFAULT = 0, FAR_STTABLE = 1, FAR_STLIST = 2,
- FAR_SSTABLE = 3 };
+enum FarType {  // Archive kinds; FarWriter<A>::Create switches on this value.
+ FAR_DEFAULT = 0,  // Has its own case in FarWriter<A>::Create.
+ FAR_STTABLE = 1,  // Presumably the STTable-backed archive -- confirm.
+ FAR_STLIST = 2,  // FarWriter<A>::Create returns an STListFarWriter.
+ FAR_FST = 3,  // FarWriter<A>::Create returns an FstFarWriter.
+};
// This class creates an archive of FSTs.
template <class A>
@@ -153,7 +173,7 @@ class STTableFarWriter : public FarWriter<A> {
public:
typedef A Arc;
- static STTableFarWriter *Create(const string filename) {
+ static STTableFarWriter *Create(const string &filename) {
STTableWriter<Fst<A>, FstWriter<A> > *writer =
STTableWriter<Fst<A>, FstWriter<A> >::Create(filename);
return new STTableFarWriter(writer);
@@ -183,7 +203,7 @@ class STListFarWriter : public FarWriter<A> {
public:
typedef A Arc;
- static STListFarWriter *Create(const string filename) {
+ static STListFarWriter *Create(const string &filename) {
STListWriter<Fst<A>, FstWriter<A> > *writer =
STListWriter<Fst<A>, FstWriter<A> >::Create(filename);
return new STListFarWriter(writer);
@@ -209,6 +229,43 @@ class STListFarWriter : public FarWriter<A> {
template <class A>
+class FstFarWriter : public FarWriter<A> {  // FarWriter that writes one plain FST file.
+ public:
+ typedef A Arc;
+ // Only records the output filename; nothing is written until Add().
+ explicit FstFarWriter(const string &filename)
+ : filename_(filename), error_(false), written_(false) {}
+ // Returns a writer allocated with new.
+ static FstFarWriter *Create(const string &filename) {
+ return new FstFarWriter(filename);
+ }
+ // Writes the first FST passed in to filename_; key is not used.
+ void Add(const string &key, const Fst<A> &fst) {
+ if (written_) {  // A previous Add() already used the single slot.
+ LOG(WARNING) << "FstFarWriter::Add: only one Fst supported,"
+ << " subsequent entries discarded.";
+ } else {
+ error_ = !fst.Write(filename_);  // A false result from Write() is an error.
+ written_ = true;  // Set whether or not the write succeeded.
+ }
+ }
+ // Always FAR_FST for this writer.
+ FarType Type() const { return FAR_FST; }
+ // True once the Write() call made by Add() has returned false.
+ bool Error() const { return error_; }
+ // Empty: the class holds no stream or pointer members.
+ ~FstFarWriter() {}
+ // Members below are set by the constructor and by Add().
+ private:
+ string filename_;  // Destination handed to fst.Write().
+ bool error_;  // Negated result of the Write() call.
+ bool written_;  // Whether Add() has already been called once.
+ // Project macro; by its name it forbids copying and assignment.
+ DISALLOW_COPY_AND_ASSIGN(FstFarWriter);
+};
+
+
+template <class A>
FarWriter<A> *FarWriter<A>::Create(const string &filename, FarType type) {
switch(type) {
case FAR_DEFAULT:
@@ -220,6 +277,9 @@ FarWriter<A> *FarWriter<A>::Create(const string &filename, FarType type) {
case FAR_STLIST:
return STListFarWriter<A>::Create(filename);
break;
+ case FAR_FST:
+ return FstFarWriter<A>::Create(filename);
+ break;
default:
LOG(ERROR) << "FarWriter::Create: unknown far type";
return 0;
@@ -331,6 +391,114 @@ class STListFarReader : public FarReader<A> {
DISALLOW_COPY_AND_ASSIGN(STListFarReader);
};
+// FarReader over a list of plain FST files, each exposed as one archive
+// entry whose key is its filename. An empty filename selects stdin, which
+// may appear at most once. The filenames are sorted at construction, so
+// entries are visited in sorted key order. Reset() and Find() are rejected
+// (and put the reader in the error state) when stdin is one of the inputs.
+//
+// GetKey() and GetFst() must only be called while Done() is false.
+template <class A>
+class FstFarReader : public FarReader<A> {
+ public:
+  typedef A Arc;
+
+  // Convenience overload for a single input file.
+  static FstFarReader *Open(const string &filename) {
+    vector<string> filenames;
+    filenames.push_back(filename);
+    return new FstFarReader<A>(filenames);
+  }
+
+  static FstFarReader *Open(const vector<string> &filenames) {
+    return new FstFarReader<A>(filenames);
+  }
+
+  // Opens one stream per input and reads the first FST. Sets the error
+  // state if stdin is listed more than once; streams for the remaining
+  // inputs are then left null.
+  FstFarReader(const vector<string> &filenames)
+      : keys_(filenames), has_stdin_(false), pos_(0), fst_(0), error_(false) {
+    sort(keys_.begin(), keys_.end());
+    streams_.resize(keys_.size(), 0);
+    for (size_t i = 0; i < keys_.size(); ++i) {
+      if (keys_[i].empty()) {
+        if (!has_stdin_) {
+          streams_[i] = &cin;
+          has_stdin_ = true;
+        } else {
+          FSTERROR() << "FstFarReader::FstFarReader: stdin should only "
+                     << "appear once in the input file list.";
+          error_ = true;
+          return;
+        }
+      } else {
+        streams_[i] = new ifstream(
+            keys_[i].c_str(), ifstream::in | ifstream::binary);
+      }
+    }
+    if (pos_ >= keys_.size()) return;
+    ReadFst();
+  }
+
+  // Rewinds to the first entry and re-reads it.
+  void Reset() {
+    if (has_stdin_) {
+      FSTERROR() << "FstFarReader::Reset: operation not supported on stdin";
+      error_ = true;
+      return;
+    }
+    pos_ = 0;
+    ReadFst();
+  }
+
+  // TODO: key is not looked up yet. This currently rewinds to the first
+  // entry and reports success regardless of key.
+  bool Find(const string &key) {
+    if (has_stdin_) {
+      FSTERROR() << "FstFarReader::Find: operation not supported on stdin";
+      error_ = true;
+      return false;
+    }
+    pos_ = 0;
+    ReadFst();
+    return true;
+  }
+
+  // True after an error or once every entry has been visited.
+  bool Done() const { return error_ || pos_ >= keys_.size(); }
+
+  void Next() {
+    ++pos_;
+    ReadFst();
+  }
+
+  const string &GetKey() const {
+    return keys_[pos_];
+  }
+
+  const Fst<A> &GetFst() const {
+    return *fst_;
+  }
+
+  FarType Type() const { return FAR_FST; }
+
+  bool Error() const { return error_; }
+
+  ~FstFarReader() {
+    delete fst_;
+    for (size_t i = 0; i < streams_.size(); ++i) {
+      // cin is a global stream, not allocated by this reader; only the
+      // ifstreams created in the constructor are deleted.
+      if (streams_[i] != &cin)
+        delete streams_[i];
+    }
+  }
+
+ private:
+  // Replaces fst_ with the FST read from the stream at pos_, or leaves it
+  // null once pos_ is past the last entry.
+  void ReadFst() {
+    delete fst_;
+    // Reset before any early return so the destructor (or a later call)
+    // never deletes the same object a second time.
+    fst_ = 0;
+    if (pos_ >= keys_.size()) return;
+    streams_[pos_]->seekg(0);
+    fst_ = Fst<A>::Read(*streams_[pos_], FstReadOptions());
+    if (!fst_) {
+      FSTERROR() << "FstFarReader: error reading Fst from: " << keys_[pos_];
+      error_ = true;
+    }
+  }
+
+  vector<string> keys_;        // Sorted input filenames; "" denotes stdin.
+  vector<istream*> streams_;   // Parallel to keys_; null if never opened.
+  bool has_stdin_;             // Whether one of the inputs is stdin.
+  size_t pos_;                 // Index of the current entry in keys_.
+  mutable Fst<A> *fst_;        // FST for the current entry, or null.
+  mutable bool error_;
+
+  DISALLOW_COPY_AND_ASSIGN(FstFarReader);
+};
template <class A>
FarReader<A> *FarReader<A>::Open(const string &filename) {
@@ -340,6 +508,8 @@ FarReader<A> *FarReader<A>::Open(const string &filename) {
return STTableFarReader<A>::Open(filename);
else if (IsSTList(filename))
return STListFarReader<A>::Open(filename);
+ else if (IsFst(filename))
+ return FstFarReader<A>::Open(filename);
return 0;
}
@@ -352,6 +522,8 @@ FarReader<A> *FarReader<A>::Open(const vector<string> &filenames) {
return STTableFarReader<A>::Open(filenames);
else if (!filenames.empty() && IsSTList(filenames[0]))
return STListFarReader<A>::Open(filenames);
+ else if (!filenames.empty() && IsFst(filenames[0]))
+ return FstFarReader<A>::Open(filenames);
return 0;
}
diff --git a/src/include/fst/extensions/far/farscript.h b/src/include/fst/extensions/far/farscript.h
index 9c3b1ca..3a9c145 100644
--- a/src/include/fst/extensions/far/farscript.h
+++ b/src/include/fst/extensions/far/farscript.h
@@ -27,6 +27,7 @@ using std::vector;
#include <fst/script/arg-packs.h>
#include <fst/extensions/far/compile-strings.h>
#include <fst/extensions/far/create.h>
+#include <fst/extensions/far/equal.h>
#include <fst/extensions/far/extract.h>
#include <fst/extensions/far/info.h>
#include <fst/extensions/far/print-strings.h>
@@ -51,6 +52,8 @@ struct FarCompileStringsArgs {
const FarTokenType tt;
const string &symbols_fname;
const string &unknown_symbol;
+ const bool keep_symbols;
+ const bool initial_symbols;
const bool allow_negative_labels;
const bool file_list_input;
const string &key_prefix;
@@ -65,6 +68,8 @@ struct FarCompileStringsArgs {
FarTokenType tt,
const string &symbols_fname,
const string &unknown_symbol,
+ bool keep_symbols,
+ bool initial_symbols,
bool allow_negative_labels,
bool file_list_input,
const string &key_prefix,
@@ -72,6 +77,7 @@ struct FarCompileStringsArgs {
in_fnames(in_fnames), out_fname(out_fname), fst_type(fst_type),
far_type(far_type), generate_keys(generate_keys), fet(fet),
tt(tt), symbols_fname(symbols_fname), unknown_symbol(unknown_symbol),
+ keep_symbols(keep_symbols), initial_symbols(initial_symbols),
allow_negative_labels(allow_negative_labels),
file_list_input(file_list_input), key_prefix(key_prefix),
key_suffix(key_suffix) { }
@@ -82,7 +88,8 @@ void FarCompileStrings(FarCompileStringsArgs *args) {
fst::FarCompileStrings<Arc>(
args->in_fnames, args->out_fname, args->fst_type, args->far_type,
args->generate_keys, args->fet, args->tt, args->symbols_fname,
- args->unknown_symbol, args->allow_negative_labels, args->file_list_input,
+ args->unknown_symbol, args->keep_symbols, args->initial_symbols,
+ args->allow_negative_labels, args->file_list_input,
args->key_prefix, args->key_suffix);
}
@@ -97,6 +104,8 @@ void FarCompileStrings(
FarTokenType tt,
const string &symbols_fname,
const string &unknown_symbol,
+ bool keep_symbols,
+ bool initial_symbols,
bool allow_negative_labels,
bool file_list_input,
const string &key_prefix,
@@ -143,6 +152,25 @@ void FarCreate(const vector<string> &in_fnames,
const string &key_suffix);
+typedef args::Package<const string &, const string &, float,  // Five packed arguments,
+ const string &, const string &> FarEqualInnerArgs;  // read back as arg1..arg5 below.
+typedef args::WithReturnValue<bool, FarEqualInnerArgs> FarEqualArgs;  // Adds a bool retval.
+ // Arc-templated entry point: forwards arg1..arg5 to fst::FarEqual<Arc>.
+template <class Arc>
+void FarEqual(FarEqualArgs *args) {
+ args->retval = fst::FarEqual<Arc>(  // The result is handed back through retval.
+ args->args.arg1, args->args.arg2, args->args.arg3,
+ args->args.arg4, args->args.arg5);
+}
+ // Overload taking the arc type as a string. Declaration only in this header.
+bool FarEqual(const string &filename1,
+ const string &filename2,
+ const string &arc_type,
+ float delta = kDelta,
+ const string &begin_key = string(),
+ const string &end_key = string());
+
+
typedef args::Package<const vector<string> &, int32,
const string&, const string&, const string&,
const string&> FarExtractArgs;
@@ -180,7 +208,9 @@ struct FarPrintStringsArgs {
const string &begin_key;
const string &end_key;
const bool print_key;
+ const bool print_weight;
const string &symbols_fname;
+ const bool initial_symbols;
const int32 generate_filenames;
const string &filename_prefix;
const string &filename_suffix;
@@ -188,12 +218,14 @@ struct FarPrintStringsArgs {
FarPrintStringsArgs(
const vector<string> &ifilenames, const FarEntryType entry_type,
const FarTokenType token_type, const string &begin_key,
- const string &end_key, const bool print_key,
- const string &symbols_fname, const int32 generate_filenames,
+ const string &end_key, const bool print_key, const bool print_weight,
+ const string &symbols_fname, const bool initial_symbols,
+ const int32 generate_filenames,
const string &filename_prefix, const string &filename_suffix) :
ifilenames(ifilenames), entry_type(entry_type), token_type(token_type),
- begin_key(begin_key), end_key(end_key), print_key(print_key),
- symbols_fname(symbols_fname),
+ begin_key(begin_key), end_key(end_key),
+ print_key(print_key), print_weight(print_weight),
+ symbols_fname(symbols_fname), initial_symbols(initial_symbols),
generate_filenames(generate_filenames), filename_prefix(filename_prefix),
filename_suffix(filename_suffix) { }
};
@@ -202,9 +234,9 @@ template <class Arc>
void FarPrintStrings(FarPrintStringsArgs *args) {
fst::FarPrintStrings<Arc>(
args->ifilenames, args->entry_type, args->token_type,
- args->begin_key, args->end_key, args->print_key,
- args->symbols_fname, args->generate_filenames, args->filename_prefix,
- args->filename_suffix);
+ args->begin_key, args->end_key, args->print_key, args->print_weight,
+ args->symbols_fname, args->initial_symbols, args->generate_filenames,
+ args->filename_prefix, args->filename_suffix);
}
@@ -215,7 +247,9 @@ void FarPrintStrings(const vector<string> &ifilenames,
const string &begin_key,
const string &end_key,
const bool print_key,
+ const bool print_weight,
const string &symbols_fname,
+ const bool initial_symbols,
const int32 generate_filenames,
const string &filename_prefix,
const string &filename_suffix);
@@ -227,6 +261,7 @@ void FarPrintStrings(const vector<string> &ifilenames,
#define REGISTER_FST_FAR_OPERATIONS(ArcType) \
REGISTER_FST_OPERATION(FarCompileStrings, ArcType, FarCompileStringsArgs); \
REGISTER_FST_OPERATION(FarCreate, ArcType, FarCreateArgs); \
+ REGISTER_FST_OPERATION(FarEqual, ArcType, FarEqualArgs); \
REGISTER_FST_OPERATION(FarExtract, ArcType, FarExtractArgs); \
REGISTER_FST_OPERATION(FarInfo, ArcType, FarInfoArgs); \
REGISTER_FST_OPERATION(FarPrintStrings, ArcType, FarPrintStringsArgs)
diff --git a/src/include/fst/extensions/far/info.h b/src/include/fst/extensions/far/info.h
index f010546..100fe68 100644
--- a/src/include/fst/extensions/far/info.h
+++ b/src/include/fst/extensions/far/info.h
@@ -34,7 +34,7 @@ void CountStatesAndArcs(const Fst<Arc> &fst, size_t *nstate, size_t *narc) {
StateIterator<Fst<Arc> > siter(fst);
for (; !siter.Done(); siter.Next(), ++(*nstate)) {
ArcIterator<Fst<Arc> > aiter(fst, siter.Value());
- for (; !aiter.Done(); aiter.Next(), ++(*narc));
+ for (; !aiter.Done(); aiter.Next(), ++(*narc)) {}
}
}
diff --git a/src/include/fst/extensions/far/print-strings.h b/src/include/fst/extensions/far/print-strings.h
index aff1e51..dcc7351 100644
--- a/src/include/fst/extensions/far/print-strings.h
+++ b/src/include/fst/extensions/far/print-strings.h
@@ -27,17 +27,21 @@
using std::vector;
#include <fst/extensions/far/far.h>
+#include <fst/shortest-distance.h>
#include <fst/string.h>
+DECLARE_string(far_field_separator);
+
namespace fst {
template <class Arc>
void FarPrintStrings(
const vector<string> &ifilenames, const FarEntryType entry_type,
const FarTokenType far_token_type, const string &begin_key,
- const string &end_key, const bool print_key, const string &symbols_fname,
- const int32 generate_filenames, const string &filename_prefix,
- const string &filename_suffix) {
+ const string &end_key, const bool print_key, const bool print_weight,
+ const string &symbols_fname, const bool initial_symbols,
+ const int32 generate_filenames,
+ const string &filename_prefix, const string &filename_suffix) {
typename StringPrinter<Arc>::TokenType token_type;
if (far_token_type == FTT_SYMBOL) {
@@ -54,7 +58,9 @@ void FarPrintStrings(
const SymbolTable *syms = 0;
if (!symbols_fname.empty()) {
// allow negative flag?
- syms = SymbolTable::ReadText(symbols_fname, true);
+ SymbolTableTextOptions opts;
+ opts.allow_negative = true;
+ syms = SymbolTable::ReadText(symbols_fname, opts);
if (!syms) {
FSTERROR() << "FarPrintStrings: error reading symbol table: "
<< symbols_fname;
@@ -62,8 +68,6 @@ void FarPrintStrings(
}
}
- StringPrinter<Arc> string_printer(token_type, syms);
-
FarReader<Arc> *far_reader = FarReader<Arc>::Open(ifilenames);
if (!far_reader) return;
@@ -83,14 +87,21 @@ void FarPrintStrings(
okey = key;
const Fst<Arc> &fst = far_reader->GetFst();
+ if (i == 1 && initial_symbols && syms == 0 && fst.InputSymbols() != 0)
+ syms = fst.InputSymbols()->Copy();
string str;
VLOG(2) << "Handling key: " << key;
+ StringPrinter<Arc> string_printer(
+ token_type, syms ? syms : fst.InputSymbols());
string_printer(fst, &str);
if (entry_type == FET_LINE) {
if (print_key)
- cout << key << "\t";
- cout << str << endl;
+ cout << key << FLAGS_far_field_separator[0];
+ cout << str;
+ if (print_weight)
+ cout << FLAGS_far_field_separator[0] << ShortestDistance(fst);
+ cout << endl;
} else if (entry_type == FET_FILE) {
stringstream sstrm;
if (generate_filenames) {
@@ -117,6 +128,7 @@ void FarPrintStrings(
ostrm << "\n";
}
}
+ delete syms;
}
diff --git a/src/include/fst/extensions/far/stlist.h b/src/include/fst/extensions/far/stlist.h
index 4738181..1cdc80c 100644
--- a/src/include/fst/extensions/far/stlist.h
+++ b/src/include/fst/extensions/far/stlist.h
@@ -26,6 +26,7 @@
#include <iostream>
#include <fstream>
+#include <sstream>
#include <fst/util.h>
#include <algorithm>
@@ -58,7 +59,7 @@ class STListWriter {
explicit STListWriter(const string filename)
: stream_(
- filename.empty() ? &std::cout :
+ filename.empty() ? &cout :
new ofstream(filename.c_str(), ofstream::out | ofstream::binary)),
error_(false) {
WriteType(*stream_, kSTListMagicNumber);
@@ -92,7 +93,7 @@ class STListWriter {
~STListWriter() {
WriteType(*stream_, string());
- if (stream_ != &std::cout)
+ if (stream_ != &cout)
delete stream_;
}
@@ -127,7 +128,7 @@ class STListReader {
for (size_t i = 0; i < filenames.size(); ++i) {
if (filenames[i].empty()) {
if (!has_stdin) {
- streams_[i] = &std::cin;
+ streams_[i] = &cin;
sources_[i] = "stdin";
has_stdin = true;
} else {
@@ -177,7 +178,7 @@ class STListReader {
~STListReader() {
for (size_t i = 0; i < streams_.size(); ++i) {
- if (streams_[i] != &std::cin)
+ if (streams_[i] != &cin)
delete streams_[i];
}
if (entry_)
diff --git a/src/include/fst/extensions/far/sttable.h b/src/include/fst/extensions/far/sttable.h
index 3a03133..3ce0a4b 100644
--- a/src/include/fst/extensions/far/sttable.h
+++ b/src/include/fst/extensions/far/sttable.h
@@ -29,6 +29,7 @@
#include <algorithm>
#include <iostream>
#include <fstream>
+#include <sstream>
#include <fst/util.h>
namespace fst {
diff --git a/src/include/fst/extensions/ngram/bitmap-index.h b/src/include/fst/extensions/ngram/bitmap-index.h
new file mode 100644
index 0000000..f5a5ba7
--- /dev/null
+++ b/src/include/fst/extensions/ngram/bitmap-index.h
@@ -0,0 +1,183 @@
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: sorenj@google.com (Jeffrey Sorensen)
+
+#ifndef FST_EXTENSIONS_NGRAM_BITMAP_INDEX_H_
+#define FST_EXTENSIONS_NGRAM_BITMAP_INDEX_H_
+
+#include <vector>
+using std::vector;
+
+#include <fst/compat.h>
+
+// This class is a bitstring storage class with an index that allows
+// seeking to the Nth set or clear bit in time O(Log(N)) where N is
+// the length of the bit vector. In addition, it allows counting set or
+// clear bits over ranges in constant time.
+//
+// This is accomplished by maintaining a "secondary" index of limited
+// size in bits that maintains a running count of the number of bits set
+// in each block of bitmap data. A block is defined as the number of
+// uint64 values that can fit in the secondary index before an overflow
+// occurs.
+//
+// To handle overflows, a "primary" index containing a running count of
+// bits set in each block is also kept (stored as a vector of uint32).
+
+namespace fst {
+
+class BitmapIndex {  // Rank/select index over a bitmap held as uint64 words (bits_).
+ public:
+ static size_t StorageSize(size_t size) {  // Number of uint64 words needed for size bits.
+ return ((size + kStorageBlockMask) >> kStorageLogBitSize);
+ }
+ // Starts with no bitmap attached: null word pointer and zero size.
+ BitmapIndex() : bits_(NULL), size_(0) { }
+ // Returns the value of bit number index, read through bits_.
+ bool Get(size_t index) const {
+ return (bits_[index >> kStorageLogBitSize] &
+ (kOne << (index & kStorageBlockMask))) != 0;
+ }
+ // Sets bit number index in the raw word array bits; no index data is touched.
+ static void Set(uint64* bits, size_t index) {
+ bits[index >> kStorageLogBitSize] |= (kOne << (index & kStorageBlockMask));
+ }
+ // Clears bit number index in the raw word array bits; no index data is touched.
+ static void Clear(uint64* bits, size_t index) {
+ bits[index >> kStorageLogBitSize] &= ~(kOne << (index & kStorageBlockMask));
+ }
+ // Returns the bitmap length in bits.
+ size_t Bits() const {
+ return size_;
+ }
+ // Returns the bitmap length in uint64 words.
+ size_t ArraySize() const {
+ return StorageSize(size_);
+ }
+
+ // Returns the number of one bits in the bitmap (the last primary index entry).
+ size_t GetOnesCount() const {
+ return primary_index_[primary_index_size() - 1];  // NOTE(review): index wraps when Bits() == 0 -- confirm callers avoid this.
+ }
+
+ // Returns the number of one bits in positions 0 to end - 1.
+ // REQUIRES: end <= Bits()
+ size_t Rank1(size_t end) const;
+
+ // Returns the number of one bits in the range start to end - 1.
+ // REQUIRES: end <= Bits() (and start <= end, or the size_t result wraps)
+ size_t GetOnesCountInRange(size_t start, size_t end) const {
+ return Rank1(end) - Rank1(start);
+ }
+
+ // Returns the number of zero bits in positions 0 to end - 1.
+ // REQUIRES: end <= Bits()
+ size_t Rank0(size_t end) const {
+ return end - Rank1(end);
+ }
+
+ // Returns the number of zero bits in the range start to end - 1.
+ // REQUIRES: end <= Bits() (and start <= end, or the size_t result wraps)
+ size_t GetZeroesCountInRange(size_t start, size_t end) const {
+ return end - start - GetOnesCountInRange(start, end);
+ }
+
+ // Returns true if any bit between start inclusive and end exclusive
+ // is set. 0 <= start <= end <= Bits() is required.
+ //
+ bool TestRange(size_t start, size_t end) const {
+ return Rank1(end) > Rank1(start);
+ }
+
+ // Returns the offset to the nth set bit (zero based)
+ // or Bits() if index >= number of ones
+ size_t Select1(size_t bit_index) const;
+
+ // Returns the offset to the nth clear bit (zero based)
+ // or Bits() if index >= number of zeros (assumed to mirror Select1 -- confirm)
+ size_t Select0(size_t bit_index) const;
+
+ // Rebuilds the index for the associated bitmap; should be called
+ // whenever changes have been made to the bitmap or else behavior
+ // of the indexed bitmap methods will be undefined.
+ void BuildIndex(const uint64 *bits, size_t size);
+
+ // The secondary index accumulates counts until it could overflow;
+ // kSecondaryBlockSize below is the number of uint64 words whose bit
+ // total still fits in a uint16 counter.
+ static const uint64 kOne = 1;
+ static const uint32 kStorageBitSize = 64;
+ static const uint32 kStorageLogBitSize = 6;
+ static const uint32 kSecondaryBlockSize = ((1 << 16) - 1)
+ >> kStorageLogBitSize;
+
+ private:
+ static const uint32 kStorageBlockMask = kStorageBitSize - 1;
+
+ // returns, from the index, the count of ones up to array_index
+ size_t get_index_ones_count(size_t array_index) const;
+
+ // because the indexes, both primary and secondary, contain a running
+ // count of the population of one bits contained in [0,i), there is
+ // no reason to have an element in the zeroth position as this value would
+ // necessarily be zero. (The bits are indexed in a zero based way.) Thus
+ // we don't store the 0th element in either index. Both of the following
+ // functions, if greater than 0, must be decremented by one before retrieving
+ // the value from the corresponding array.
+ // returns the 1 + the block that contains the bitindex in question
+ // the inverted version works the same but looks for zeros using an inverted
+ // view of the index
+ size_t find_primary_block(size_t bit_index) const;
+
+ size_t find_inverted_primary_block(size_t bit_index) const;
+
+ // similarly, the secondary index (which resets its count to zero at
+ // the end of every kSecondaryBlockSize entries) does not store the element
+ // at 0. Note that the rem_bit_index parameter is the number of bits
+ // within the secondary block, after the bits accounted for by the primary
+ // block have been removed (i.e. the remaining bits) And, because we
+ // reset to zero with each new block, there is no need to store those
+ // actual zeros.
+ // returns 1 + the secondary block that contains the bitindex in question
+ size_t find_secondary_block(size_t block, size_t rem_bit_index) const;
+
+ size_t find_inverted_secondary_block(size_t block, size_t rem_bit_index)
+ const;
+
+ // We create a primary index based upon the number of secondary index
+ // blocks. The primary index uses fields wide enough to accommodate any
+ // index of the bitarray so cannot overflow
+ // The primary index is the actual running
+ // count of one bits set for all blocks (and, thus, all uint64s).
+ size_t primary_index_size() const {
+ return (ArraySize() + kSecondaryBlockSize - 1) / kSecondaryBlockSize;
+ }
+
+ const uint64* bits_;
+ size_t size_;
+
+ // The primary index contains the running popcount of all blocks
+ // which means the nth value contains the popcounts of
+ // [0,n*kSecondaryBlockSize], however, the 0th element is omitted.
+ vector<uint32> primary_index_;  // NOTE(review): uint32 entries, although the comments describe overflow-free counts -- confirm.
+ // The secondary index contains the running popcount of the associated
+ // bitmap. It is the same length (in units of uint16) as the
+ // bitmap's map is in units of uint64s.
+ vector<uint16> secondary_index_;
+};
+
+} // end namespace fst
+
+#endif // FST_EXTENSIONS_NGRAM_BITMAP_INDEX_H_
diff --git a/src/include/fst/extensions/ngram/ngram-fst.h b/src/include/fst/extensions/ngram/ngram-fst.h
new file mode 100644
index 0000000..eee664a
--- /dev/null
+++ b/src/include/fst/extensions/ngram/ngram-fst.h
@@ -0,0 +1,912 @@
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: sorenj@google.com (Jeffrey Sorensen)
+//
+#ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
+#define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
+
+#include <stddef.h>
+#include <string.h>
+#include <algorithm>
+#include <string>
+#include <vector>
+using std::vector;
+
+#include <fst/compat.h>
+#include <fst/fstlib.h>
+#include <fst/extensions/ngram/bitmap-index.h>
+
+ // NGramFst implements an n-gram language model based upon the LOUDS data
+ // structure. Please refer to "Unary Data Structures for Language Models"
+ // http://research.google.com/pubs/archive/37218.pdf
+
+namespace fst {
+template <class A> class NGramFst;
+template <class A> class NGramFstMatcher;
+
+ // Instance data containing mutable state for bookkeeping repeated access to
+ // the same state. The fields form three small caches, each valid only while
+ // its guard equals state_:
+ //   state_          guards num_futures_ and offset_ (NGramFstImpl::SetInstFuture)
+ //   node_state_     guards node_                    (NGramFstImpl::SetInstNode)
+ //   context_state_  guards context_                 (NGramFstImpl::SetInstContext)
+ template <class A>
+ struct NGramFstInst {
+   typedef typename A::Label Label;
+   typedef typename A::StateId StateId;
+   typedef typename A::Weight Weight;
+   StateId state_;          // State the cached values refer to.
+   size_t num_futures_;     // Number of non-epsilon arcs leaving state_.
+   size_t offset_;          // Index of state_'s first arc in the future arrays.
+   size_t node_;            // Bit position of state_'s node in the context bitmap.
+   StateId node_state_;     // State for which node_ was computed.
+   vector<Label> context_;  // Context labels collected walking up from node_.
+   StateId context_state_;  // State for which context_ was computed.
+
+   // Every member is given a value: instances are copied wholesale (by the
+   // matcher and the arc iterator) before any cache has been filled, and
+   // copying an indeterminate size_t is undefined behavior.
+   NGramFstInst()
+       : state_(kNoStateId), num_futures_(0), offset_(0), node_(0),
+         node_state_(kNoStateId), context_state_(kNoStateId) { }
+ };
+
+ // Implementation class for LOUDS based NgramFst interface.
+ //
+ // The whole model lives in one contiguous byte buffer (data_) laid out as
+ // computed by Storage(): three uint64 counts, three bitmaps (context, future,
+ // final), the context and future labels and, after alignment padding, the
+ // backoff, final and future weights. Init() points the typed members into
+ // that buffer and builds the rank/select indexes over the three bitmaps.
+ template <class A>
+ class NGramFstImpl : public FstImpl<A> {
+   using FstImpl<A>::SetInputSymbols;
+   using FstImpl<A>::SetOutputSymbols;
+   using FstImpl<A>::SetType;
+   using FstImpl<A>::WriteHeader;
+
+   friend class ArcIterator<NGramFst<A> >;
+   friend class NGramFstMatcher<A>;
+
+  public:
+   using FstImpl<A>::InputSymbols;
+   using FstImpl<A>::SetProperties;
+   using FstImpl<A>::Properties;
+
+   typedef A Arc;
+   typedef typename A::Label Label;
+   typedef typename A::StateId StateId;
+   typedef typename A::Weight Weight;
+
+   // Creates an implementation with no data attached. The counts start at
+   // zero so that NumStates() is well defined before Init() is called.
+   NGramFstImpl()
+       : data_(0), owned_(false), num_states_(0), num_futures_(0),
+         num_final_(0) {
+     SetType("ngram");
+     SetInputSymbols(NULL);
+     SetOutputSymbols(NULL);
+     SetProperties(kStaticProperties);
+   }
+
+   NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out);
+
+   // Releases the data buffer only if Init() was told that we own it.
+   ~NGramFstImpl() {
+     if (owned_) {
+       delete [] data_;
+     }
+   }
+
+   // Reads an implementation from strm. Returns NULL if the header, the three
+   // leading counts or the payload cannot be read; the partially built object
+   // and its buffer are freed on every failure path.
+   static NGramFstImpl<A>* Read(istream &strm,  // NOLINT
+                                const FstReadOptions &opts) {
+     NGramFstImpl<A>* impl = new NGramFstImpl();
+     FstHeader hdr;
+     if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) {
+       delete impl;
+       return 0;
+     }
+     uint64 num_states, num_futures, num_final;
+     const size_t offset = sizeof(num_states) + sizeof(num_futures) +
+         sizeof(num_final);
+     // Peek at num_states and num_futures to see how much more needs to be
+     // read.
+     strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states));
+     strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
+     strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
+     if (!strm) {
+       // The counts were not (fully) read; do not size an allocation from
+       // them.
+       delete impl;
+       return NULL;
+     }
+     size_t size = Storage(num_states, num_futures, num_final);
+     char* data = new char[size];
+     // Copy num_states, num_futures and num_final back into data.
+     memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
+     memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
+            sizeof(num_futures));
+     memcpy(data + sizeof(num_states) + sizeof(num_futures),
+            reinterpret_cast<char *>(&num_final), sizeof(num_final));
+     strm.read(data + offset, size - offset);
+     if (!strm) {
+       // impl does not own data yet, so both must be released here.
+       delete [] data;
+       delete impl;
+       return NULL;
+     }
+     impl->Init(data, true /* owned */);
+     return impl;
+   }
+
+   // Writes the header followed by the raw data buffer. Returns false if the
+   // stream is in a failed state afterwards.
+   bool Write(ostream &strm,  // NOLINT
+              const FstWriteOptions &opts) const {
+     FstHeader hdr;
+     hdr.SetStart(Start());
+     hdr.SetNumStates(num_states_);
+     WriteHeader(strm, opts, kFileVersion, &hdr);
+     strm.write(data_, Storage(num_states_, num_futures_, num_final_));
+     // Same value as the stream's boolean conversion, but also valid where
+     // that conversion is explicit (C++11 and later).
+     return !strm.fail();
+   }
+
+   // The start state is always state 1; state 0 is the unigram state.
+   StateId Start() const {
+     return 1;
+   }
+
+   // Final weight of state, or Weight::Zero() if its final bit is not set.
+   Weight Final(StateId state) const {
+     if (final_index_.Get(state)) {
+       return final_probs_[final_index_.Rank1(state)];
+     } else {
+       return Weight::Zero();
+     }
+   }
+
+   // With inst == NULL, returns the distance between the state's two
+   // delimiting zeros in the future bitmap, minus one. With an inst, the
+   // cached future count is used and one is added for every state except
+   // state 0.
+   size_t NumArcs(StateId state, NGramFstInst<A> *inst = NULL) const {
+     if (inst == NULL) {
+       const size_t next_zero = future_index_.Select0(state + 1);
+       const size_t this_zero = future_index_.Select0(state);
+       return next_zero - this_zero - 1;
+     }
+     SetInstFuture(state, inst);
+     return inst->num_futures_ + ((state == 0) ? 0 : 1);
+   }
+
+   size_t NumInputEpsilons(StateId state) const {
+     // State 0 has no parent, thus no backoff.
+     if (state == 0) return 0;
+     return 1;
+   }
+
+   size_t NumOutputEpsilons(StateId state) const {
+     return NumInputEpsilons(state);
+   }
+
+   StateId NumStates() const {
+     return num_states_;
+   }
+
+   void InitStateIterator(StateIteratorData<A>* data) const {
+     data->base = 0;
+     data->nstates = num_states_;
+   }
+
+   // Size in bytes of the serialized buffer for the given counts. The layout
+   // computed here must match the one produced by the constructor from an Fst
+   // and the one consumed by Init().
+   static size_t Storage(uint64 num_states, uint64 num_futures,
+                         uint64 num_final) {
+     size_t offset = sizeof(num_states) + sizeof(num_futures) +
+         sizeof(num_final);
+     offset += sizeof(uint64) * (
+         BitmapIndex::StorageSize(num_states * 2 + 1) +
+         BitmapIndex::StorageSize(num_futures + num_states + 1) +
+         BitmapIndex::StorageSize(num_states));
+     offset += (num_states + 1) * sizeof(Label) + num_futures * sizeof(Label);
+     // Pad for alignment, see
+     // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
+     offset = (offset + sizeof(Weight) - 1) & ~(sizeof(Weight) - 1);
+     offset += (num_states + 1) * sizeof(Weight) +
+         num_final * sizeof(Weight) + (num_futures + 1) * sizeof(Weight);
+     return offset;
+   }
+
+   // Fills inst's future cache (num_futures_, offset_) for state unless it
+   // already describes that state.
+   void SetInstFuture(StateId state, NGramFstInst<A> *inst) const {
+     if (inst->state_ != state) {
+       inst->state_ = state;
+       const size_t next_zero = future_index_.Select0(state + 1);
+       const size_t this_zero = future_index_.Select0(state);
+       inst->num_futures_ = next_zero - this_zero - 1;
+       inst->offset_ = future_index_.Rank1(this_zero + 1);
+     }
+   }
+
+   // Fills inst's node cache for inst->state_ unless it is already current.
+   void SetInstNode(NGramFstInst<A> *inst) const {
+     if (inst->node_state_ != inst->state_) {
+       inst->node_state_ = inst->state_;
+       inst->node_ = context_index_.Select1(inst->state_);
+     }
+   }
+
+   // Fills inst's context cache for inst->state_: the labels met while
+   // walking from the state's node up to node 0, nearest label first.
+   void SetInstContext(NGramFstInst<A> *inst) const {
+     SetInstNode(inst);
+     if (inst->context_state_ != inst->state_) {
+       inst->context_state_ = inst->state_;
+       inst->context_.clear();
+       size_t node = inst->node_;
+       while (node != 0) {
+         inst->context_.push_back(context_words_[context_index_.Rank1(node)]);
+         node = context_index_.Select1(context_index_.Rank0(node) - 1);
+       }
+     }
+   }
+
+   // Access to the underlying representation
+   const char* GetData(size_t* data_size) const {
+     *data_size = Storage(num_states_, num_futures_, num_final_);
+     return data_;
+   }
+
+   void Init(const char* data, bool owned);
+
+  private:
+   StateId Transition(const vector<Label> &context, Label future) const;
+
+   // Properties always true for this Fst class.
+   static const uint64 kStaticProperties = kAcceptor | kIDeterministic |
+       kODeterministic | kEpsilons | kIEpsilons | kOEpsilons | kILabelSorted |
+       kOLabelSorted | kWeighted | kCyclic | kInitialAcyclic | kNotTopSorted |
+       kAccessible | kCoAccessible | kNotString | kExpanded;
+   // Current file format version.
+   static const int kFileVersion = 4;
+   // Minimum file format version supported.
+   static const int kMinFileVersion = 4;
+
+   const char* data_;
+   bool owned_;  // True if we own data_
+   uint64 num_states_, num_futures_, num_final_;
+   size_t root_num_children_;
+   const Label *root_children_;
+   size_t root_first_child_;
+   // borrowed references
+   const uint64 *context_, *future_, *final_;
+   const Label *context_words_, *future_words_;
+   const Weight *backoff_, *final_probs_, *future_probs_;
+   BitmapIndex context_index_;
+   BitmapIndex future_index_;
+   BitmapIndex final_index_;
+
+   void operator=(const NGramFstImpl<A> &);  // Disallow
+ };
+
+ // Builds the compact LOUDS representation from an existing language-model
+ // fst.
+ //
+ // fst must be an acceptor that is input-deterministic, input-label sorted and
+ // has input epsilons; otherwise kError is set and no data is attached. If
+ // order_out is non-NULL it is resized to the number of states and
+ // (*order_out)[s] receives the state number that input state s gets in this
+ // fst. Every error path releases the temporary buffers it allocated.
+ template<typename A>
+ NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
+     : data_(0), owned_(false), num_states_(0), num_futures_(0),
+       num_final_(0) {
+   typedef A Arc;
+   typedef typename Arc::Label Label;
+   typedef typename Arc::Weight Weight;
+   typedef typename Arc::StateId StateId;
+   SetType("ngram");
+   SetInputSymbols(fst.InputSymbols());
+   SetOutputSymbols(fst.OutputSymbols());
+   SetProperties(kStaticProperties);
+
+   // Check basic requirements for an OpenGRM language model Fst.
+   int64 props = kAcceptor | kIDeterministic | kIEpsilons | kILabelSorted;
+   if (fst.Properties(props, true) != props) {
+     FSTERROR() << "NGramFst only accepts OpenGRM langauge models as input";
+     SetProperties(kError, kError);
+     return;
+   }
+
+   int64 num_states = CountStates(fst);
+   // One context label per input state. A vector (rather than new[]) is
+   // released on every early return and starts zero-filled, so states never
+   // reached by the search below do not carry indeterminate labels.
+   vector<Label> context(num_states);
+
+   // Find the unigram state by starting from the start state, following
+   // epsilons.
+   StateId unigram = fst.Start();
+   while (1) {
+     ArcIterator<Fst<A> > aiter(fst, unigram);
+     if (aiter.Done()) {
+       FSTERROR() << "Start state has no arcs";
+       SetProperties(kError, kError);
+       return;
+     }
+     if (aiter.Value().ilabel != 0) break;
+     unigram = aiter.Value().nextstate;
+   }
+
+   // Each state's context is determined by the subtree it is under from the
+   // unigram state.
+   queue<pair<StateId, Label> > label_queue;
+   vector<bool> visited(num_states);
+   // Force an epsilon link to the start state.
+   label_queue.push(make_pair(fst.Start(), 0));
+   for (ArcIterator<Fst<A> > aiter(fst, unigram);
+        !aiter.Done(); aiter.Next()) {
+     label_queue.push(make_pair(aiter.Value().nextstate, aiter.Value().ilabel));
+   }
+   // investigate states in breadth first fashion to assign context words.
+   while (!label_queue.empty()) {
+     pair<StateId, Label> &now = label_queue.front();
+     if (!visited[now.first]) {
+       context[now.first] = now.second;
+       visited[now.first] = true;
+       for (ArcIterator<Fst<A> > aiter(fst, now.first);
+            !aiter.Done(); aiter.Next()) {
+         const Arc &arc = aiter.Value();
+         if (arc.ilabel != 0) {
+           label_queue.push(make_pair(arc.nextstate, now.second));
+         }
+       }
+     }
+     label_queue.pop();
+   }
+   visited.clear();
+
+   // The arc from the start state should be assigned an epsilon to put it
+   // in front of the all other labels (which makes Start state 1 after
+   // unigram which is state 0).
+   context[fst.Start()] = 0;
+
+   // Build the tree of contexts fst by reversing the epsilon arcs from fst.
+   VectorFst<Arc> context_fst;
+   uint64 num_final = 0;
+   for (StateId i = 0; i < num_states; ++i) {
+     if (fst.Final(i) != Weight::Zero()) {
+       ++num_final;
+     }
+     context_fst.SetFinal(context_fst.AddState(), fst.Final(i));
+   }
+   context_fst.SetStart(unigram);
+   context_fst.SetInputSymbols(fst.InputSymbols());
+   context_fst.SetOutputSymbols(fst.OutputSymbols());
+   int64 num_context_arcs = 0;
+   int64 num_futures = 0;
+   for (StateIterator<Fst<A> > siter(fst); !siter.Done(); siter.Next()) {
+     const StateId &state = siter.Value();
+     num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state);
+     ArcIterator<Fst<A> > aiter(fst, state);
+     if (!aiter.Done()) {
+       const Arc &arc = aiter.Value();
+       // this arc goes from state to arc.nextstate, so create an arc from
+       // arc.nextstate to state to reverse it.
+       if (arc.ilabel == 0) {
+         context_fst.AddArc(arc.nextstate, Arc(context[state], context[state],
+                                               arc.weight, state));
+         num_context_arcs++;
+       }
+     }
+   }
+   if (num_context_arcs != context_fst.NumStates() - 1) {
+     FSTERROR() << "Number of contexts arcs != number of states - 1";
+     SetProperties(kError, kError);
+     return;
+   }
+   if (context_fst.NumStates() != num_states) {
+     FSTERROR() << "Number of contexts != number of states";
+     SetProperties(kError, kError);
+     return;
+   }
+   int64 context_props = context_fst.Properties(kIDeterministic |
+                                                kILabelSorted, true);
+   if (!(context_props & kIDeterministic)) {
+     FSTERROR() << "Input fst is not structured properly";
+     SetProperties(kError, kError);
+     return;
+   }
+   if (!(context_props & kILabelSorted)) {
+     ArcSort(&context_fst, ILabelCompare<Arc>());
+   }
+
+   // The context labels are no longer needed; give the memory back before
+   // the (potentially large) data buffer is allocated.
+   vector<Label>().swap(context);
+
+   // Carve the single data buffer into its sections. The order and sizes
+   // here must agree with Storage() and Init().
+   const size_t storage = Storage(num_states, num_futures, num_final);
+   char* data = new char[storage];
+   memset(data, 0, storage);
+   size_t offset = 0;
+   memcpy(data + offset, reinterpret_cast<char *>(&num_states),
+          sizeof(num_states));
+   offset += sizeof(num_states);
+   memcpy(data + offset, reinterpret_cast<char *>(&num_futures),
+          sizeof(num_futures));
+   offset += sizeof(num_futures);
+   memcpy(data + offset, reinterpret_cast<char *>(&num_final),
+          sizeof(num_final));
+   offset += sizeof(num_final);
+   uint64* context_bits = reinterpret_cast<uint64*>(data + offset);
+   offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(uint64);
+   uint64* future_bits = reinterpret_cast<uint64*>(data + offset);
+   offset +=
+       BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(uint64);
+   uint64* final_bits = reinterpret_cast<uint64*>(data + offset);
+   offset += BitmapIndex::StorageSize(num_states) * sizeof(uint64);
+   Label* context_words = reinterpret_cast<Label*>(data + offset);
+   offset += (num_states + 1) * sizeof(Label);
+   Label* future_words = reinterpret_cast<Label*>(data + offset);
+   offset += num_futures * sizeof(Label);
+   offset = (offset + sizeof(Weight) - 1) & ~(sizeof(Weight) - 1);
+   Weight* backoff = reinterpret_cast<Weight*>(data + offset);
+   offset += (num_states + 1) * sizeof(Weight);
+   Weight* final_probs = reinterpret_cast<Weight*>(data + offset);
+   offset += num_final * sizeof(Weight);
+   Weight* future_probs = reinterpret_cast<Weight*>(data + offset);
+   int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0,
+       final_bit = 0;
+
+   // pseudo-root bits
+   BitmapIndex::Set(context_bits, context_bit++);
+   ++context_bit;
+   context_words[context_arc] = kNoLabel;
+   backoff[context_arc] = Weight::Zero();
+   context_arc++;
+
+   ++future_bit;
+   if (order_out) {
+     order_out->clear();
+     order_out->resize(num_states);
+   }
+
+   // Breadth-first walk over the context tree; states are numbered in the
+   // order they are dequeued.
+   queue<StateId> context_q;
+   context_q.push(context_fst.Start());
+   StateId state_number = 0;
+   while (!context_q.empty()) {
+     const StateId state = context_q.front();
+     if (order_out) {
+       (*order_out)[state] = state_number;
+     }
+
+     const Weight final_weight = context_fst.Final(state);
+     if (final_weight != Weight::Zero()) {
+       BitmapIndex::Set(final_bits, state_number);
+       final_probs[final_bit] = final_weight;
+       ++final_bit;
+     }
+
+     // One set bit per child in the context tree, then a zero terminator.
+     for (ArcIterator<VectorFst<A> > aiter(context_fst, state);
+          !aiter.Done(); aiter.Next()) {
+       const Arc &arc = aiter.Value();
+       context_words[context_arc] = arc.ilabel;
+       backoff[context_arc] = arc.weight;
+       ++context_arc;
+       BitmapIndex::Set(context_bits, context_bit++);
+       context_q.push(arc.nextstate);
+     }
+     ++context_bit;
+
+     // One set bit per non-epsilon arc of the input state, then a zero.
+     for (ArcIterator<Fst<A> > aiter(fst, state); !aiter.Done(); aiter.Next()) {
+       const Arc &arc = aiter.Value();
+       if (arc.ilabel != 0) {
+         future_words[future_arc] = arc.ilabel;
+         future_probs[future_arc] = arc.weight;
+         ++future_arc;
+         BitmapIndex::Set(future_bits, future_bit++);
+       }
+     }
+     ++future_bit;
+     ++state_number;
+     context_q.pop();
+   }
+
+   if ((state_number != num_states) ||
+       (context_bit != num_states * 2 + 1) ||
+       (context_arc != num_states) ||
+       (future_arc != num_futures) ||
+       (future_bit != num_futures + num_states + 1) ||
+       (final_bit != num_final)) {
+     FSTERROR() << "Structure problems detected during construction";
+     SetProperties(kError, kError);
+     // Ownership was never handed to Init(), so free the buffer here.
+     delete [] data;
+     return;
+   }
+
+   Init(data, true /* owned */);
+ }
+
+ // Attaches this object to a serialized model buffer. Any buffer currently
+ // owned is released first. If owned is true the destructor (or a later
+ // Init) will delete [] data. The section offsets computed here must agree
+ // with Storage() and with the constructor that builds the buffer.
+ // Sets kError if the root has no children ("Missing unigrams").
+ template<typename A>
+ inline void NGramFstImpl<A>::Init(const char* data, bool owned) {
+   if (owned_) {
+     delete [] data_;
+   }
+   owned_ = owned;
+   data_ = data;
+   size_t offset = 0;
+   num_states_ = *(reinterpret_cast<const uint64*>(data_ + offset));
+   offset += sizeof(num_states_);
+   num_futures_ = *(reinterpret_cast<const uint64*>(data_ + offset));
+   offset += sizeof(num_futures_);
+   num_final_ = *(reinterpret_cast<const uint64*>(data_ + offset));
+   offset += sizeof(num_final_);
+   size_t context_bits = num_states_ * 2 + 1;
+   size_t future_bits = num_futures_ + num_states_ + 1;
+   context_ = reinterpret_cast<const uint64*>(data_ + offset);
+   offset += BitmapIndex::StorageSize(context_bits) * sizeof(uint64);
+   future_ = reinterpret_cast<const uint64*>(data_ + offset);
+   offset += BitmapIndex::StorageSize(future_bits) * sizeof(uint64);
+   final_ = reinterpret_cast<const uint64*>(data_ + offset);
+   // The final bitmap holds exactly num_states_ bits. Sizing it with
+   // num_states_ + 1 disagrees with Storage() and with the builder, and
+   // shifts every following section whenever num_states_ exactly fills its
+   // last bitmap word.
+   offset += BitmapIndex::StorageSize(num_states_) * sizeof(uint64);
+   context_words_ = reinterpret_cast<const Label*>(data_ + offset);
+   offset += (num_states_ + 1) * sizeof(*context_words_);
+   future_words_ = reinterpret_cast<const Label*>(data_ + offset);
+   offset += num_futures_ * sizeof(*future_words_);
+   offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1);
+   backoff_ = reinterpret_cast<const Weight*>(data_ + offset);
+   offset += (num_states_ + 1) * sizeof(*backoff_);
+   final_probs_ = reinterpret_cast<const Weight*>(data_ + offset);
+   offset += num_final_ * sizeof(*final_probs_);
+   future_probs_ = reinterpret_cast<const Weight*>(data_ + offset);
+
+   context_index_.BuildIndex(context_, context_bits);
+   future_index_.BuildIndex(future_, future_bits);
+   final_index_.BuildIndex(final_, num_states_);
+
+   // Cache the root's child range; Transition() starts every lookup there.
+   const size_t node_rank = context_index_.Rank1(0);
+   root_first_child_ = context_index_.Select0(node_rank) + 1;
+   if (context_index_.Get(root_first_child_) == false) {
+     FSTERROR() << "Missing unigrams";
+     SetProperties(kError, kError);
+     return;
+   }
+   const size_t last_child = context_index_.Select0(node_rank + 1) - 1;
+   root_num_children_ = last_child - root_first_child_ + 1;
+   root_children_ = context_words_ + context_index_.Rank1(root_first_child_);
+ }
+
+ // Returns the destination state for reading label future in a state whose
+ // context is given (nearest word last, as filled by SetInstContext).
+ //
+ // The lookup descends the context tree: first to the root child labelled
+ // future, then through the context words from the last one backwards, and
+ // stops at the deepest node that exists. The result is the Rank1 of that
+ // node; if not even the root child exists, Rank1(0) is returned.
+ template<typename A>
+ inline typename A::StateId NGramFstImpl<A>::Transition(
+     const vector<Label> &context, Label future) const {
+   const Label *children = root_children_;
+   const Label *children_end = children + root_num_children_;
+   const Label *loc = lower_bound(children, children_end, future);
+   if (loc == children_end || *loc != future) {
+     return context_index_.Rank1(0);
+   }
+
+   size_t node = root_first_child_ + (loc - children);
+   size_t node_rank = context_index_.Rank1(node);
+   size_t first_child = context_index_.Select0(node_rank) + 1;
+   if (context_index_.Get(first_child) == false) {
+     // The matched node is a leaf: nothing deeper to try.
+     return context_index_.Rank1(node);
+   }
+   size_t last_child = context_index_.Select0(node_rank + 1) - 1;
+
+   for (typename vector<Label>::const_reverse_iterator word = context.rbegin();
+        word != context.rend(); ++word) {
+     children = context_words_ + context_index_.Rank1(first_child);
+     children_end = children + (last_child - first_child + 1);
+     loc = lower_bound(children, children_end, *word);
+     if (loc == children_end || *loc != *word) {
+       break;
+     }
+     node = first_child + (loc - children);
+     node_rank = context_index_.Rank1(node);
+     first_child = context_index_.Select0(node_rank) + 1;
+     if (context_index_.Get(first_child) == false) break;
+     last_child = context_index_.Select0(node_rank + 1) - 1;
+   }
+   return context_index_.Rank1(node);
+ }
+
+/*****************************************************************************/
+ // Fst interface over a shared NGramFstImpl. The object also carries one
+ // mutable NGramFstInst (inst_) that caches per-state lookups made through
+ // NumArcs() and InitArcIterator().
+ template<class A>
+ class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
+   friend class ArcIterator<NGramFst<A> >;
+   friend class NGramFstMatcher<A>;
+
+  public:
+   typedef A Arc;
+   typedef typename A::StateId StateId;
+   typedef typename A::Label Label;
+   typedef typename A::Weight Weight;
+   typedef NGramFstImpl<A> Impl;
+
+   // Converts dst into the compact representation.
+   explicit NGramFst(const Fst<A> &dst)
+       : ImplToExpandedFst<Impl>(new Impl(dst, NULL)) {}
+
+   // As above; additionally reports the state renumbering through order_out
+   // (see the NGramFstImpl constructor).
+   NGramFst(const Fst<A> &fst, vector<StateId>* order_out)
+       : ImplToExpandedFst<Impl>(new Impl(fst, order_out)) {}
+
+   // Because the NGramFstImpl is a const stateless data structure, there
+   // is never a need to do anything beside copy the reference.
+   NGramFst(const NGramFst<A> &fst, bool safe = false)
+       : ImplToExpandedFst<Impl>(fst, false) {}
+
+   NGramFst() : ImplToExpandedFst<Impl>(new Impl()) {}
+
+   // Non-standard constructor to initialize NGramFst directly from data.
+   NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) {
+     GetImpl()->Init(data, owned);
+   }
+
+   // Get method that gets the data associated with Init().
+   const char* GetData(size_t* data_size) const {
+     return GetImpl()->GetData(data_size);
+   }
+
+   // Uses (and updates) the cached instance data, hence the mutable inst_.
+   virtual size_t NumArcs(StateId s) const {
+     return GetImpl()->NumArcs(s, &inst_);
+   }
+
+   virtual NGramFst<A>* Copy(bool safe = false) const {
+     return new NGramFst(*this, safe);
+   }
+
+   // Returns 0 if the implementation cannot be read from strm.
+   static NGramFst<A>* Read(istream &strm, const FstReadOptions &opts) {
+     Impl* impl = Impl::Read(strm, opts);
+     return impl ? new NGramFst<A>(impl) : 0;
+   }
+
+   // Reads from the named file, or from standard input if filename is empty.
+   static NGramFst<A>* Read(const string &filename) {
+     if (!filename.empty()) {
+       ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
+       if (!strm) {
+         LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename;
+         return 0;
+       }
+       return Read(strm, FstReadOptions(filename));
+     } else {
+       return Read(cin, FstReadOptions("standard input"));
+     }
+   }
+
+   virtual bool Write(ostream &strm, const FstWriteOptions &opts) const {
+     return GetImpl()->Write(strm, opts);
+   }
+
+   virtual bool Write(const string &filename) const {
+     return Fst<A>::WriteFile(filename);
+   }
+
+   virtual inline void InitStateIterator(StateIteratorData<A>* data) const {
+     GetImpl()->InitStateIterator(data);
+   }
+
+   virtual inline void InitArcIterator(
+       StateId s, ArcIteratorData<A>* data) const;
+
+   virtual MatcherBase<A>* InitMatcher(MatchType match_type) const {
+     return new NGramFstMatcher<A>(*this, match_type);
+   }
+
+  private:
+   explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {}
+
+   Impl* GetImpl() const {
+     return
+         ImplToExpandedFst<Impl, ExpandedFst<A> >::GetImpl();
+   }
+
+   // Forwards to the same base instantiation as GetImpl(); the Fst<A>
+   // instantiation is not a base of this class and would fail to compile as
+   // soon as this member was used.
+   void SetImpl(Impl* impl, bool own_impl = true) {
+     ImplToExpandedFst<Impl, ExpandedFst<A> >::SetImpl(impl, own_impl);
+   }
+
+   mutable NGramFstInst<A> inst_;
+ };
+
+ // Refreshes the shared instance cache for state s, then stores a newly
+ // allocated arc iterator for s in data->base. The iterator's constructor
+ // takes its own copy of inst_.
+ template <class A>
+ inline void NGramFst<A>::InitArcIterator(StateId s,
+                                          ArcIteratorData<A>* data) const {
+   Impl* const impl = GetImpl();
+   impl->SetInstFuture(s, &inst_);
+   impl->SetInstNode(&inst_);
+   data->base = new ArcIterator<NGramFst<A> >(*this, s);
+ }
+
+/*****************************************************************************/
+ // Matcher for NGramFst. A state has at most one arc per label, so every
+ // Find yields at most one real arc (arc_), optionally preceded by the
+ // implicit self-loop (loop_) when epsilon (label 0) is requested.
+ template <class A>
+ class NGramFstMatcher : public MatcherBase<A> {
+  public:
+   typedef A Arc;
+   typedef typename A::Label Label;
+   typedef typename A::StateId StateId;
+   typedef typename A::Weight Weight;
+
+   // Starts from a copy of the fst's cached instance data. done_ starts out
+   // true so that Done_() is well defined before the first Find_().
+   NGramFstMatcher(const NGramFst<A> &fst, MatchType match_type)
+       : fst_(fst), inst_(fst.inst_), match_type_(match_type),
+         done_(true), current_loop_(false),
+         loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
+     if (match_type_ == MATCH_OUTPUT) {
+       swap(loop_.ilabel, loop_.olabel);
+     }
+   }
+
+   // Copies the matcher's fst reference and cache, but not its position.
+   NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false)
+       : fst_(matcher.fst_), inst_(matcher.inst_),
+         match_type_(matcher.match_type_), done_(true), current_loop_(false),
+         loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
+     if (match_type_ == MATCH_OUTPUT) {
+       swap(loop_.ilabel, loop_.olabel);
+     }
+   }
+
+   virtual NGramFstMatcher<A>* Copy(bool safe = false) const {
+     return new NGramFstMatcher<A>(*this, safe);
+   }
+
+   virtual MatchType Type(bool test) const {
+     return match_type_;
+   }
+
+   virtual const Fst<A> &GetFst() const {
+     return fst_;
+   }
+
+   virtual uint64 Properties(uint64 props) const {
+     return props;
+   }
+
+  private:
+   virtual void SetState_(StateId s) {
+     fst_.GetImpl()->SetInstFuture(s, &inst_);
+     current_loop_ = false;
+   }
+
+   // Positions the matcher on the arcs of the current state that match
+   // label and returns true if there is at least one.
+   //   label == 0        : the implicit loop, then the backoff arc (if any).
+   //   label == kNoLabel : the backoff arc only (if any).
+   //   otherwise         : the single arc carrying that label (if any).
+   virtual bool Find_(Label label) {
+     const Label nolabel = kNoLabel;
+     const NGramFstImpl<A> *impl = fst_.GetImpl();
+     done_ = true;
+     // Only an epsilon request exposes the implicit loop. Resetting the flag
+     // here keeps a loop left over from an earlier Find_(0) from being
+     // reported as a match for a different label.
+     current_loop_ = (label == 0);
+     if (label == 0 || label == nolabel) {
+       if (label == 0) {
+         loop_.nextstate = inst_.state_;
+       }
+       // The unigram state has no epsilon arc.
+       if (inst_.state_ != 0) {
+         arc_.ilabel = arc_.olabel = 0;
+         impl->SetInstNode(&inst_);
+         arc_.nextstate = impl->context_index_.Rank1(
+             impl->context_index_.Select1(
+                 impl->context_index_.Rank0(inst_.node_) - 1));
+         arc_.weight = impl->backoff_[inst_.state_];
+         done_ = false;
+       }
+     } else {
+       // The state's futures are sorted by label, so binary search applies.
+       const Label *start = impl->future_words_ + inst_.offset_;
+       const Label *end = start + inst_.num_futures_;
+       const Label* search = lower_bound(start, end, label);
+       if (search != end && *search == label) {
+         size_t state = search - start;
+         arc_.ilabel = arc_.olabel = label;
+         arc_.weight = impl->future_probs_[inst_.offset_ + state];
+         impl->SetInstContext(&inst_);
+         arc_.nextstate = impl->Transition(inst_.context_, label);
+         done_ = false;
+       }
+     }
+     return !Done_();
+   }
+
+   virtual bool Done_() const {
+     return !current_loop_ && done_;
+   }
+
+   virtual const Arc& Value_() const {
+     return (current_loop_) ? loop_ : arc_;
+   }
+
+   // Steps past the implicit loop first (if pending), then past the arc.
+   virtual void Next_() {
+     if (current_loop_) {
+       current_loop_ = false;
+     } else {
+       done_ = true;
+     }
+   }
+
+   const NGramFst<A>& fst_;
+   NGramFstInst<A> inst_;
+   MatchType match_type_;  // Supplied by caller
+   bool done_;             // True when arc_ is not (or no longer) a match
+   Arc arc_;
+   bool current_loop_;     // Current arc is the implicit loop
+   Arc loop_;
+ };
+
+/*****************************************************************************/
+ // Arc iterator specialization for NGramFst. Arcs are materialized lazily:
+ // lazy_ holds the arc fields that still have to be computed for the current
+ // position and flags_ the fields the caller asked for. For every state
+ // whose node is not 0, position 0 is the epsilon (backoff) arc and the
+ // labelled arcs follow; for node 0 only the labelled arcs exist.
+ template<class A>
+ class ArcIterator<NGramFst<A> > : public ArcIteratorBase<A> {
+  public:
+   typedef A Arc;
+   typedef typename A::Label Label;
+   typedef typename A::StateId StateId;
+   typedef typename A::Weight Weight;
+
+   // Works on a private copy of the fst's cached instance data, refreshed
+   // for state.
+   ArcIterator(const NGramFst<A> &fst, StateId state)
+       : lazy_(~0), impl_(fst.GetImpl()), inst_(fst.inst_), i_(0),
+         flags_(kArcValueFlags) {
+     impl_->SetInstFuture(state, &inst_);
+     impl_->SetInstNode(&inst_);
+   }
+
+   bool Done() const {
+     const size_t num_arcs =
+         inst_.num_futures_ + ((inst_.node_ == 0) ? 0 : 1);
+     return i_ >= num_arcs;
+   }
+
+   // Computes whichever requested fields are still pending for the current
+   // position and returns the cached arc.
+   const Arc &Value() const {
+     const bool has_backoff = (inst_.node_ != 0);
+     const bool eps = has_backoff && i_ == 0;
+     // Index of the current arc among the state's labelled arcs.
+     const StateId state = has_backoff ? i_ - 1 : i_;
+
+     if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) {
+       const Label label =
+           eps ? 0 : impl_->future_words_[inst_.offset_ + state];
+       arc_.ilabel = arc_.olabel = label;
+       lazy_ &= ~(kArcILabelValue | kArcOLabelValue);
+     }
+
+     if (flags_ & lazy_ & kArcNextStateValue) {
+       if (eps) {
+         arc_.nextstate = impl_->context_index_.Rank1(
+             impl_->context_index_.Select1(
+                 impl_->context_index_.Rank0(inst_.node_) - 1));
+       } else {
+         // Cached inside inst_, so the context is only built the first time.
+         impl_->SetInstContext(&inst_);
+         arc_.nextstate =
+             impl_->Transition(inst_.context_,
+                               impl_->future_words_[inst_.offset_ + state]);
+       }
+       lazy_ &= ~kArcNextStateValue;
+     }
+
+     if (flags_ & lazy_ & kArcWeightValue) {
+       if (eps) {
+         arc_.weight = impl_->backoff_[inst_.state_];
+       } else {
+         arc_.weight = impl_->future_probs_[inst_.offset_ + state];
+       }
+       lazy_ &= ~kArcWeightValue;
+     }
+     return arc_;
+   }
+
+   // Moving to another position invalidates every cached arc field.
+   void Next() {
+     ++i_;
+     lazy_ = ~0;
+   }
+
+   size_t Position() const { return i_; }
+
+   void Reset() {
+     i_ = 0;
+     lazy_ = ~0;
+   }
+
+   void Seek(size_t a) {
+     if (i_ == a) return;  // Same position: keep the computed fields.
+     i_ = a;
+     lazy_ = ~0;
+   }
+
+   uint32 Flags() const {
+     return flags_;
+   }
+
+   void SetFlags(uint32 f, uint32 m) {
+     flags_ &= ~m;
+     flags_ |= (f & kArcValueFlags);
+   }
+
+  private:
+   virtual bool Done_() const { return Done(); }
+   virtual const Arc& Value_() const { return Value(); }
+   virtual void Next_() { Next(); }
+   virtual size_t Position_() const { return Position(); }
+   virtual void Reset_() { Reset(); }
+   virtual void Seek_(size_t a) { Seek(a); }
+   uint32 Flags_() const { return Flags(); }
+   void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); }
+
+   mutable Arc arc_;
+   mutable uint32 lazy_;
+   const NGramFstImpl<A> *impl_;
+   mutable NGramFstInst<A> inst_;
+
+   size_t i_;
+   uint32 flags_;
+
+   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
+ };
+
+/*****************************************************************************/
+ // Specialization for NGramFst; see generic version in fst.h
+ // for sample usage (but use the NGramFst type!). This version
+ // should inline. It simply counts from 0 up to fst.NumStates().
+ template <class A>
+ class StateIterator<NGramFst<A> > : public StateIteratorBase<A> {
+  public:
+   typedef typename A::StateId StateId;
+
+   explicit StateIterator(const NGramFst<A> &fst)
+       : cur_(0), end_(fst.NumStates()) { }
+
+   bool Done() const { return !(cur_ < end_); }
+   StateId Value() const { return cur_; }
+   void Next() { ++cur_; }
+   void Reset() { cur_ = 0; }
+
+  private:
+   virtual bool Done_() const { return Done(); }
+   virtual StateId Value_() const { return Value(); }
+   virtual void Next_() { Next(); }
+   virtual void Reset_() { Reset(); }
+
+   StateId cur_;  // Current state.
+   StateId end_;  // Number of states, captured at construction.
+
+   DISALLOW_COPY_AND_ASSIGN(StateIterator);
+ };
+} // namespace fst
+#endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
diff --git a/src/include/fst/extensions/ngram/nthbit.h b/src/include/fst/extensions/ngram/nthbit.h
new file mode 100644
index 0000000..d4a9a5a
--- /dev/null
+++ b/src/include/fst/extensions/ngram/nthbit.h
@@ -0,0 +1,46 @@
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: sorenj@google.com (Jeffrey Sorensen)
+// dr@google.com (Doug Rohde)
+
+#ifndef FST_EXTENSIONS_NGRAM_NTHBIT_H_
+#define FST_EXTENSIONS_NGRAM_NTHBIT_H_
+
+#include <fst/types.h>
+
+ // Lookup table indexed by a byte value; defined outside this header.
+ // NOTE(review): presumably entry b packs, in consecutive 4-bit fields, the
+ // positions of the set bits of b (field k for the (k+1)-th set bit) --
+ // confirm against the table's definition.
+ extern uint32 nth_bit_bit_offset[];
+
+ // Selects a bit of v by rank: r counts set bits starting at 1 from the
+ // least significant end. The search narrows to a 32-bit half, then a 16-bit
+ // quarter, then a single byte, and finishes with a table lookup on that
+ // byte. Written without branches: each step turns the comparison into an
+ // all-ones or all-zero mask. No range check is made on r here.
+ // Relies on the compiler builtin __builtin_popcount.
+ inline uint32 nth_bit(uint64 v, uint32 r) {
+ uint32 shift = 0;
+ // Set bits in the low 32 bits.
+ uint32 c = __builtin_popcount(v & 0xffffffff);
+ // mask is all ones when the wanted bit lies above this chunk, else zero.
+ uint32 mask = -(r > c);
+ // If so, discount the bits skipped and move the window up by 32.
+ r -= c & mask;
+ shift += (32 & mask);
+
+ // Same step on the low 16 bits of the selected half.
+ c = __builtin_popcount((v >> shift) & 0xffff);
+ mask = -(r > c);
+ r -= c & mask;
+ shift += (16 & mask);
+
+ // Same step on the low 8 bits of the selected quarter.
+ c = __builtin_popcount((v >> shift) & 0xff);
+ mask = -(r > c);
+ r -= c & mask;
+ shift += (8 & mask);
+
+ // Within the selected byte, read the 4-bit field at index r - 1 of that
+ // byte's table entry and add it to the byte's bit offset.
+ return shift + ((nth_bit_bit_offset[(v >> shift) & 0xff] >>
+ ((r - 1) << 2)) & 0xf);
+ }
+
+#endif // FST_EXTENSIONS_NGRAM_NTHBIT_H_
diff --git a/src/include/fst/extensions/pdt/collection.h b/src/include/fst/extensions/pdt/collection.h
index 26be504..24a443f 100644
--- a/src/include/fst/extensions/pdt/collection.h
+++ b/src/include/fst/extensions/pdt/collection.h
@@ -16,7 +16,7 @@
// Author: riley@google.com (Michael Riley)
//
// \file
-// Class to store a collection of sets with elements of type T.
+// Class to store a collection of ordered (multi-)sets with elements of type T.
#ifndef FST_EXTENSIONS_PDT_COLLECTION_H__
#define FST_EXTENSIONS_PDT_COLLECTION_H__
@@ -29,11 +29,11 @@ using std::vector;
namespace fst {
-// Stores a collection of non-empty sets with elements of type T. A
-// default constructor, equality ==, a total order <, and an STL-style
-// hash class must be defined on the elements. Provides signed
-// integer ID (of type I) of each unique set. The IDs are allocated
-// starting from 0 in order.
+// Stores a collection of non-empty, ordered (multi-)sets with elements
+// of type T. A default constructor, equality ==, and an STL-style
+// hash class must be defined on the elements. Provides signed integer
+// ID (of type I) of each unique set. The IDs are allocated starting
+// from 0 in order.
template <class I, class T>
class Collection {
public:
@@ -80,31 +80,34 @@ class Collection {
Collection() {}
- // Lookups integer ID from set. If it doesn't exist, then adds it.
- // Set elements should be in strict order (and therefore unique).
- I FindId(const vector<T> &set) {
+ // Looks up the integer ID of an ordered multi-set. If it doesn't exist
+ // and 'insert' is true, then adds it. Otherwise returns -1.
+ I FindId(const vector<T> &set, bool insert = true) {
I node_id = kNoNodeId;
for (ssize_t i = set.size() - 1; i >= 0; --i) {
Node node(node_id, set[i]);
- node_id = node_table_.FindId(node);
+ node_id = node_table_.FindId(node, insert);
+ if (node_id == -1) break;
}
return node_id;
}
- // Finds set given integer ID. Returns true if ID corresponds
- // to set. Use iterators below to traverse result.
+ // Finds ordered (multi-)set given integer ID. Returns set iterator
+ // to traverse result.
SetIterator FindSet(I id) {
- if (id < 0 && id >= node_table_.Size()) {
+ if (id < 0 || id >= node_table_.Size()) {
return SetIterator(kNoNodeId, Node(kNoNodeId, T()), &node_table_);
} else {
return SetIterator(id, node_table_.FindEntry(id), &node_table_);
}
}
+ I Size() const { return node_table_.Size(); }
+
private:
static const I kNoNodeId;
static const size_t kPrime;
- static std::tr1::hash<T> hash_;
+ static std::hash<T> hash_;
NodeTable node_table_;
@@ -115,7 +118,7 @@ template<class I, class T> const I Collection<I, T>::kNoNodeId = -1;
template <class I, class T> const size_t Collection<I, T>::kPrime = 7853;
-template <class I, class T> std::tr1::hash<T> Collection<I, T>::hash_;
+template <class I, class T> std::hash<T> Collection<I, T>::hash_;
} // namespace fst
diff --git a/src/include/fst/extensions/pdt/info.h b/src/include/fst/extensions/pdt/info.h
index ef9a860..55e76c4 100644
--- a/src/include/fst/extensions/pdt/info.h
+++ b/src/include/fst/extensions/pdt/info.h
@@ -24,7 +24,7 @@
#include <unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
-#include <tr1/unordered_set>
+#include <unordered_set>
using std::tr1::unordered_set;
using std::tr1::unordered_multiset;
#include <vector>
diff --git a/src/include/fst/extensions/pdt/paren.h b/src/include/fst/extensions/pdt/paren.h
index 7b9887f..a9d30c5 100644
--- a/src/include/fst/extensions/pdt/paren.h
+++ b/src/include/fst/extensions/pdt/paren.h
@@ -26,7 +26,7 @@
#include <unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
-#include <tr1/unordered_set>
+#include <unordered_set>
using std::tr1::unordered_set;
using std::tr1::unordered_multiset;
#include <set>
@@ -144,7 +144,8 @@ class PdtParenReachable {
const vector<pair<Label, Label> > &parens, bool close)
: fst_(fst),
parens_(parens),
- close_(close) {
+ close_(close),
+ error_(false) {
for (Label i = 0; i < parens.size(); ++i) {
const pair<Label, Label> &p = parens[i];
paren_id_map_[p.first] = i;
@@ -155,12 +156,18 @@ class PdtParenReachable {
StateId start = fst.Start();
if (start == kNoStateId)
return;
- DFSearch(start, start);
+ if (!DFSearch(start)) {
+ FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
+ error_ = true;
+ }
} else {
FSTERROR() << "PdtParenReachable: open paren info not implemented";
+ error_ = true;
}
}
+ bool const Error() { return error_; }
+
// Given a state ID, returns an iterator over paren IDs
// for close (open) parens reachable from that state along balanced
// paths.
@@ -194,7 +201,7 @@ class PdtParenReachable {
private:
// DFS that gathers paren and state set information.
// Bool returns false when cycle detected.
- bool DFSearch(StateId s, StateId start);
+ bool DFSearch(StateId s);
// Unions state sets together gathered by the DFS.
void ComputeStateSet(StateId s);
@@ -212,12 +219,13 @@ class PdtParenReachable {
vector<char> state_color_; // DFS state
mutable Collection<ssize_t, StateId> state_sets_; // Reachable states -> ID
StateSetMap set_map_; // ID -> Reachable states
+ bool error_;
DISALLOW_COPY_AND_ASSIGN(PdtParenReachable);
};
// DFS that gathers paren and state set information.
template <class A>
-bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) {
+bool PdtParenReachable<A>::DFSearch(StateId s) {
if (s >= state_color_.size())
state_color_.resize(s + 1, kDfsWhite);
@@ -239,7 +247,8 @@ bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) {
if (pit != paren_id_map_.end()) { // paren?
Label paren_id = pit->second;
if (arc.ilabel == parens_[paren_id].first) { // open paren
- DFSearch(arc.nextstate, arc.nextstate);
+ if (!DFSearch(arc.nextstate))
+ return false;
for (SetIterator set_iter = FindStates(paren_id, arc.nextstate);
!set_iter.Done(); set_iter.Next()) {
for (ParenArcIterator paren_arc_iter =
@@ -247,15 +256,14 @@ bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) {
!paren_arc_iter.Done();
paren_arc_iter.Next()) {
const A &cparc = paren_arc_iter.Value();
- DFSearch(cparc.nextstate, start);
+ if (!DFSearch(cparc.nextstate))
+ return false;
}
}
}
} else { // non-paren
- if(!DFSearch(arc.nextstate, start)) {
- FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
- return true;
- }
+ if(!DFSearch(arc.nextstate))
+ return false;
}
}
ComputeStateSet(s);
diff --git a/src/include/fst/extensions/pdt/shortest-path.h b/src/include/fst/extensions/pdt/shortest-path.h
index e90471b..85f94b8 100644
--- a/src/include/fst/extensions/pdt/shortest-path.h
+++ b/src/include/fst/extensions/pdt/shortest-path.h
@@ -28,7 +28,7 @@
#include <unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
-#include <tr1/unordered_set>
+#include <unordered_set>
using std::tr1::unordered_set;
using std::tr1::unordered_multiset;
#include <stack>
@@ -387,7 +387,6 @@ class PdtShortestPath {
typedef typename SpData::SearchState SearchState;
typedef typename SpData::ParenSpec ParenSpec;
- typedef typename PdtParenReachable<Arc>::SetIterator StateSetIterator;
typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator;
PdtShortestPath(const Fst<Arc> &ifst,
@@ -403,7 +402,7 @@ class PdtShortestPath {
if ((Weight::Properties() & (kPath | kRightSemiring))
!= (kPath | kRightSemiring)) {
- FSTERROR() << "SingleShortestPath: Weight needs to have the path"
+ FSTERROR() << "PdtShortestPath: Weight needs to have the path"
<< " property and be right distributive: " << Weight::Type();
error_ = true;
}
@@ -440,6 +439,7 @@ class PdtShortestPath {
static const Arc kNoArc;
static const uint8 kEnqueued;
static const uint8 kExpanded;
+ static const uint8 kFinished;
const uint8 kFinal;
public:
@@ -543,6 +543,7 @@ void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) {
ProcArcs(s);
sp_data_.SetFlags(s, kExpanded, kExpanded);
}
+ sp_data_.SetFlags(q, kFinished, kFinished);
balance_data_.FinishInsert(start);
sp_data_.GC(start);
}
@@ -607,7 +608,11 @@ void PdtShortestPath<Arc, Queue>::ProcOpenParen(
Queue *state_queue = state_queue_;
GetDistance(d.start);
state_queue_ = state_queue;
+ } else if (!(sp_data_.Flags(d) & kFinished)) {
+ FSTERROR() << "PdtShortestPath: open parenthesis recursion: not bounded stack";
+ error_ = true;
}
+
for (CloseSourceIterator set_iter =
balance_data_.Find(paren_id, arc.nextstate);
!set_iter.Done(); set_iter.Next()) {
@@ -765,6 +770,9 @@ template<class Arc, class Queue>
const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20;
template<class Arc, class Queue>
+const uint8 PdtShortestPath<Arc, Queue>::kFinished = 0x40;
+
+template<class Arc, class Queue>
void ShortestPath(const Fst<Arc> &ifst,
const vector<pair<typename Arc::Label,
typename Arc::Label> > &parens,