aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/encode.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/encode.h')
-rw-r--r--src/include/fst/encode.h599
1 files changed, 599 insertions, 0 deletions
diff --git a/src/include/fst/encode.h b/src/include/fst/encode.h
new file mode 100644
index 0000000..7245b45
--- /dev/null
+++ b/src/include/fst/encode.h
@@ -0,0 +1,599 @@
+// encode.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: johans@google.com (Johan Schalkwyk)
+//
+// \file
+// Classes to encode and decode an Fst.
+
+#ifndef FST_LIB_ENCODE_H__
+#define FST_LIB_ENCODE_H__
+
+#include <climits>
+#include <unordered_map>
+using std::tr1::unordered_map;
+using std::tr1::unordered_multimap;
+#include <string>
+#include <vector>
+using std::vector;
+
+#include <fst/arc-map.h>
+#include <fst/rmfinalepsilon.h>
+
+
+namespace fst {
+
+static const uint32 kEncodeLabels = 0x0001; // Fold the arc's output label into the encoded label
+static const uint32 kEncodeWeights = 0x0002; // Fold the arc's weight into the encoded label
+static const uint32 kEncodeFlags = 0x0003; // All non-internal flags (kEncodeLabels | kEncodeWeights)
+
+static const uint32 kEncodeHasISymbols = 0x0004; // For internal use: the table holds an input symbol table
+static const uint32 kEncodeHasOSymbols = 0x0008; // For internal use: the table holds an output symbol table
+
+enum EncodeType { ENCODE = 1, DECODE = 2 }; // Direction in which an EncodeMapper maps arcs
+
+// Identifies stream data as an encode table (and its endianness)
+static const int32 kEncodeMagicNumber = 2129983209;
+
+
+// EncodeTable holds the implementation details of the label/weight
+// encoding used when encoding and decoding Fsts. The mapping is
+// bidirectional: every distinct (ilabel, olabel, weight) tuple is assigned
+// a unique label, and that label can be mapped back to its tuple. Labels
+// are handed out consecutively, starting at 1, in order of first use.
+template <class A> class EncodeTable {
+ public:
+  typedef typename A::Label Label;
+  typedef typename A::Weight Weight;
+
+  // The data folded into one encoded label: an arc's input label, output
+  // label and weight.
+  struct Tuple {
+    Tuple() {}
+    Tuple(Label ilabel_, Label olabel_, Weight weight_)
+        : ilabel(ilabel_), olabel(olabel_), weight(weight_) {}
+    Tuple(const Tuple& other)
+        : ilabel(other.ilabel), olabel(other.olabel), weight(other.weight) {}
+
+    Label ilabel;
+    Label olabel;
+    Weight weight;
+  };
+
+  // Equality predicate for the hash map: compares the pointed-to tuples
+  // field by field rather than comparing the pointers.
+  class TupleEqual {
+   public:
+    bool operator()(const Tuple* x, const Tuple* y) const {
+      if (!(x->ilabel == y->ilabel)) return false;
+      if (!(x->olabel == y->olabel)) return false;
+      return x->weight == y->weight;
+    }
+  };
+
+  // Hash functor for the hash map. The input label is always hashed; the
+  // output label and the weight are mixed in only when the corresponding
+  // encode flag is set.
+  class TupleKey {
+   public:
+    TupleKey()
+        : encode_flags_(kEncodeLabels | kEncodeWeights) {}
+
+    TupleKey(const TupleKey& key)
+        : encode_flags_(key.encode_flags_) {}
+
+    explicit TupleKey(uint32 encode_flags)
+        : encode_flags_(encode_flags) {}
+
+    size_t operator()(const Tuple* x) const {
+      size_t result = x->ilabel;
+      if (encode_flags_ & kEncodeLabels)
+        result = Mix(result) ^ x->olabel;
+      if (encode_flags_ & kEncodeWeights)
+        result = Mix(result) ^ x->weight.Hash();
+      return result;
+    }
+
+   private:
+    // Combines h shifted left by 5 bits with its top 5 bits shifted down.
+    static size_t Mix(size_t h) {
+      const int lshift = 5;
+      const int rshift = CHAR_BIT * sizeof(size_t) - 5;
+      return h << lshift ^ h >> rshift;
+    }
+
+    int32 encode_flags_;
+  };
+
+  typedef unordered_map<const Tuple*,
+                        Label,
+                        TupleKey,
+                        TupleEqual> EncodeHash;
+
+  explicit EncodeTable(uint32 encode_flags)
+      : flags_(encode_flags),
+        encode_hash_(1024, TupleKey(encode_flags)),
+        isymbols_(0), osymbols_(0) {}
+
+  // Frees every stored tuple and both symbol table copies.
+  ~EncodeTable() {
+    typedef typename vector<Tuple*>::iterator TupleIter;
+    for (TupleIter it = encode_tuples_.begin();
+         it != encode_tuples_.end(); ++it) {
+      delete *it;
+    }
+    delete isymbols_;
+    delete osymbols_;
+  }
+
+  // Returns the encoded label for arc, assigning the next unused label
+  // (the new table size) if its tuple has not been seen before.
+  Label Encode(const A &arc) {
+    const Tuple key = MakeTuple(arc);
+    typename EncodeHash::const_iterator found = encode_hash_.find(&key);
+    if (found != encode_hash_.end()) return found->second;
+
+    encode_tuples_.push_back(new Tuple(key));
+    encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
+    return encode_tuples_.size();
+  }
+
+  // Looks up the encoded label for arc without modifying the table.
+  // Returns kNoLabel if the arc's tuple has not been encoded.
+  Label GetLabel(const A &arc) const {
+    const Tuple key = MakeTuple(arc);
+    typename EncodeHash::const_iterator found = encode_hash_.find(&key);
+    if (found != encode_hash_.end()) return found->second;
+    return kNoLabel;
+  }
+
+  // Maps an encoded label back to its tuple. Logs an error and returns 0
+  // when key is outside the range [1, Size()].
+  const Tuple* Decode(Label key) const {
+    const bool known = key >= 1 && key <= encode_tuples_.size();
+    if (known) return encode_tuples_[key - 1];
+    LOG(ERROR) << "EncodeTable::Decode: unknown decode key: " << key;
+    return 0;
+  }
+
+  // Number of distinct tuples encoded so far.
+  size_t Size() const { return encode_tuples_.size(); }
+
+  bool Write(ostream &strm, const string &source) const;
+
+  static EncodeTable<A> *Read(istream &strm, const string &source);
+
+  // Only the public encode flags; the internal symbol bits are masked off.
+  const uint32 flags() const { return flags_ & kEncodeFlags; }
+
+  int RefCount() const { return ref_count_.count(); }
+  int IncrRefCount() { return ref_count_.Incr(); }
+  int DecrRefCount() { return ref_count_.Decr(); }
+
+
+  SymbolTable *InputSymbols() const { return isymbols_; }
+
+  SymbolTable *OutputSymbols() const { return osymbols_; }
+
+  // Replaces the stored input symbol table with a copy of syms (or with
+  // none when syms is null) and keeps kEncodeHasISymbols in step.
+  void SetInputSymbols(const SymbolTable* syms) {
+    delete isymbols_;
+    if (!syms) {
+      isymbols_ = 0;
+      flags_ &= ~kEncodeHasISymbols;
+      return;
+    }
+    isymbols_ = syms->Copy();
+    flags_ |= kEncodeHasISymbols;
+  }
+
+  // Replaces the stored output symbol table with a copy of syms (or with
+  // none when syms is null) and keeps kEncodeHasOSymbols in step.
+  void SetOutputSymbols(const SymbolTable* syms) {
+    delete osymbols_;
+    if (!syms) {
+      osymbols_ = 0;
+      flags_ &= ~kEncodeHasOSymbols;
+      return;
+    }
+    osymbols_ = syms->Copy();
+    flags_ |= kEncodeHasOSymbols;
+  }
+
+ private:
+  // Builds the lookup tuple for arc. Fields that are not being encoded are
+  // replaced by neutral values (label 0, Weight::One()).
+  Tuple MakeTuple(const A &arc) const {
+    return Tuple(arc.ilabel,
+                 flags_ & kEncodeLabels ? arc.olabel : 0,
+                 flags_ & kEncodeWeights ? arc.weight : Weight::One());
+  }
+
+  uint32 flags_;
+  vector<Tuple*> encode_tuples_;
+  EncodeHash encode_hash_;
+  RefCounter ref_count_;
+  SymbolTable *isymbols_;  // Pre-encoded ilabel symbol table
+  SymbolTable *osymbols_;  // Pre-encoded olabel symbol table
+
+  DISALLOW_COPY_AND_ASSIGN(EncodeTable);
+};
+
+// Serializes the table to strm: the magic number, the raw flags, the tuple
+// count, every tuple in label order, then whichever symbol tables the
+// flags say are present. Returns false (after logging) if the stream is in
+// a failed state once flushed; source is used only in the log message.
+template <class A> inline
+bool EncodeTable<A>::Write(ostream &strm, const string &source) const {
+  WriteType(strm, kEncodeMagicNumber);
+  WriteType(strm, flags_);
+  const int64 num_tuples = encode_tuples_.size();
+  WriteType(strm, num_tuples);
+
+  typedef typename vector<Tuple*>::const_iterator TupleIter;
+  for (TupleIter it = encode_tuples_.begin();
+       it != encode_tuples_.end(); ++it) {
+    const Tuple &entry = **it;
+    WriteType(strm, entry.ilabel);
+    WriteType(strm, entry.olabel);
+    entry.weight.Write(strm);
+  }
+
+  if (flags_ & kEncodeHasISymbols)
+    isymbols_->Write(strm);
+
+  if (flags_ & kEncodeHasOSymbols)
+    osymbols_->Write(strm);
+
+  strm.flush();
+  if (strm) return true;
+
+  LOG(ERROR) << "EncodeTable::Write: write failed: " << source;
+  return false;
+}
+
+// Deserializes a table previously produced by Write().
+//
+// strm   - stream positioned at the start of the serialized table.
+// source - name of the stream, used only in log messages.
+//
+// Returns a newly allocated table, or 0 (after logging) if the header is
+// wrong, the stream fails or is truncated, or a symbol table announced by
+// the flags cannot be read. Nothing is leaked on any failure path.
+template <class A> inline
+EncodeTable<A> *EncodeTable<A>::Read(istream &strm, const string &source) {
+  int32 magic_number = 0;
+  ReadType(strm, &magic_number);
+  if (magic_number != kEncodeMagicNumber) {
+    LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source;
+    return 0;
+  }
+
+  // Read and validate the whole header before allocating anything, so a
+  // truncated stream never builds a table from indeterminate values.
+  uint32 flags = 0;
+  ReadType(strm, &flags);
+  int64 size = 0;
+  ReadType(strm, &size);
+  if (!strm || size < 0) {
+    LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
+    return 0;
+  }
+
+  EncodeTable<A> *table = new EncodeTable<A>(flags);
+
+  for (int64 i = 0; i < size; ++i) {
+    Tuple* tuple = new Tuple();
+    ReadType(strm, &tuple->ilabel);
+    ReadType(strm, &tuple->olabel);
+    tuple->weight.Read(strm);
+    if (!strm) {
+      LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
+      delete tuple;  // Not yet owned by the table.
+      delete table;  // Frees the tuples already stored.
+      return 0;
+    }
+    table->encode_tuples_.push_back(tuple);
+    table->encode_hash_[tuple] = table->encode_tuples_.size();
+  }
+
+  // A set flag with a null symbol table would make a later Write()
+  // dereference a null pointer, so a failed symbol read fails the Read.
+  if (flags & kEncodeHasISymbols) {
+    table->isymbols_ = SymbolTable::Read(strm, source);
+    if (!table->isymbols_) {
+      LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
+      delete table;
+      return 0;
+    }
+  }
+
+  if (flags & kEncodeHasOSymbols) {
+    table->osymbols_ = SymbolTable::Read(strm, source);
+    if (!table->osymbols_) {
+      LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
+      delete table;
+      return 0;
+    }
+  }
+
+  return table;
+}
+
+
+// A mapper to encode/decode weighted transducers. Encoding of an
+// Fst is useful for performing classical determinization or minimization
+// on a weighted transducer by treating it as an unweighted acceptor over
+// encoded labels.
+//
+// The Encode mapper stores the encoding in a local hash table (EncodeTable)
+// This table is shared (and reference counted) between the encoder and
+// decoder. A decoder has read only access to the EncodeTable.
+//
+// The EncodeMapper allows on the fly encoding of the machine. As the
+// EncodeTable is generated the same table may be used to decode the machine
+// on the fly. For example in the following sequence of operations
+//
+// Encode -> Determinize -> Decode
+//
+// we will use the encoding table generated during the encode step in the
+// decode, even though the encoding is not complete.
+//
+template <class A> class EncodeMapper {
+  typedef typename A::Weight Weight;
+  typedef typename A::Label Label;
+ public:
+  // Creates a mapper of the given type backed by a newly allocated, empty
+  // EncodeTable. flags selects what gets encoded (kEncodeLabels and/or
+  // kEncodeWeights).
+  EncodeMapper(uint32 flags, EncodeType type)
+      : flags_(flags),
+        type_(type),
+        table_(new EncodeTable<A>(flags)),
+        error_(false) {}
+
+  // Copy constructor: the copy shares mapper's table (its reference count
+  // is incremented) and starts with a cleared error state.
+  EncodeMapper(const EncodeMapper& mapper)
+      : flags_(mapper.flags_),
+        type_(mapper.type_),
+        table_(mapper.table_),
+        error_(false) {
+    table_->IncrRefCount();
+  }
+
+  // Copy constructor but setting the type, typically to DECODE. Unlike the
+  // plain copy constructor this one carries over mapper's error state.
+  EncodeMapper(const EncodeMapper& mapper, EncodeType type)
+      : flags_(mapper.flags_),
+        type_(type),
+        table_(mapper.table_),
+        error_(mapper.error_) {
+    table_->IncrRefCount();
+  }
+
+  // Drops this mapper's reference; deletes the table when DecrRefCount()
+  // reports zero.
+  ~EncodeMapper() {
+    if (!table_->DecrRefCount()) delete table_;
+  }
+
+  // Encodes or decodes a single arc, depending on type().
+  A operator()(const A &arc);
+
+  // A superfinal state is required only when weights are being encoded.
+  MapFinalAction FinalAction() const {
+    return (type_ == ENCODE && (flags_ & kEncodeWeights)) ?
+        MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL;
+  }
+
+  MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; }
+
+  MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;}
+
+  // Returns the subset of inprops kept by the mapping. The mask starts at
+  // kFstProperties and is narrowed once per enabled encode flag. kError is
+  // OR-ed in before masking if a decode error has been recorded.
+  uint64 Properties(uint64 inprops) {
+    uint64 outprops = inprops;
+    if (error_) outprops |= kError;
+
+    uint64 mask = kFstProperties;
+    if (flags_ & kEncodeLabels)
+      mask &= kILabelInvariantProperties & kOLabelInvariantProperties;
+    if (flags_ & kEncodeWeights)
+      mask &= kILabelInvariantProperties & kWeightInvariantProperties &
+          (type_ == ENCODE ? kAddSuperFinalProperties :
+           kRmSuperFinalProperties);
+
+    return outprops & mask;
+  }
+
+  const uint32 flags() const { return flags_; }
+  const EncodeType type() const { return type_; }
+  const EncodeTable<A> &table() const { return *table_; }
+
+  // Writes the encode table to strm; source is used only in log messages.
+  bool Write(ostream &strm, const string& source) {
+    return table_->Write(strm, source);
+  }
+
+  // Writes the encode table to a binary file. Returns false (after
+  // logging) if the file cannot be opened or the write fails.
+  bool Write(const string& filename) {
+    ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
+    if (!strm) {
+      LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
+      return false;
+    }
+    return Write(strm, filename);
+  }
+
+  // Reads an encode table from strm and wraps it in a new mapper of the
+  // given type. Returns 0 if the table cannot be read.
+  static EncodeMapper<A> *Read(istream &strm,
+                               const string& source,
+                               EncodeType type = ENCODE) {
+    EncodeTable<A> *table = EncodeTable<A>::Read(strm, source);
+    return table ? new EncodeMapper(table->flags(), type, table) : 0;
+  }
+
+  // Reads an encode table from a binary file. Returns NULL (after logging)
+  // if the file cannot be opened, or 0 if the table cannot be read.
+  static EncodeMapper<A> *Read(const string& filename,
+                               EncodeType type = ENCODE) {
+    ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
+    if (!strm) {
+      LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
+      return NULL;
+    }
+    return Read(strm, filename, type);
+  }
+
+  SymbolTable *InputSymbols() const { return table_->InputSymbols(); }
+
+  SymbolTable *OutputSymbols() const { return table_->OutputSymbols(); }
+
+  // The symbol tables are stored in the shared table, so setting them is
+  // visible through every mapper that shares it.
+  void SetInputSymbols(const SymbolTable* syms) {
+    table_->SetInputSymbols(syms);
+  }
+
+  void SetOutputSymbols(const SymbolTable* syms) {
+    table_->SetOutputSymbols(syms);
+  }
+
+ private:
+  uint32 flags_;
+  EncodeType type_;
+  EncodeTable<A>* table_;
+  bool error_;
+
+  // Wraps an existing table (used by Read()); the reference count is not
+  // incremented here. error_ must be initialized in this constructor too,
+  // because Properties() reads it.
+  explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table)
+      : flags_(flags), type_(type), table_(table), error_(false) {}
+  void operator=(const EncodeMapper &);  // Disallow.
+};
+
+// Maps one arc. When encoding, the fields selected by the flags are
+// replaced by a single label from the table. When decoding, the label is
+// expanded back into the stored tuple. Arcs with nextstate == kNoStateId
+// (presumably how the arc-map machinery presents final weights; confirm
+// against ArcMap) are passed through unchanged, except when encoding
+// weights and the weight is not Weight::Zero().
+template <class A> inline
+A EncodeMapper<A>::operator()(const A &arc) {
+  const bool no_next_state = arc.nextstate == kNoStateId;
+
+  if (type_ == ENCODE) {  // labels and/or weights to single label
+    if (no_next_state &&
+        (!(flags_ & kEncodeWeights) || arc.weight == Weight::Zero())) {
+      return arc;
+    }
+    Label label = table_->Encode(arc);
+    return A(label,
+             flags_ & kEncodeLabels ? label : arc.olabel,
+             flags_ & kEncodeWeights ? Weight::One() : arc.weight,
+             arc.nextstate);
+  }
+
+  // type_ == DECODE
+  if (no_next_state || arc.ilabel == 0) return arc;
+
+  // Report arcs that cannot have come from this encoding, then still try
+  // to decode them.
+  if ((flags_ & kEncodeLabels) && arc.ilabel != arc.olabel) {
+    FSTERROR() << "EncodeMapper: Label-encoded arc has different "
+        "input and output labels";
+    error_ = true;
+  }
+  if ((flags_ & kEncodeWeights) && arc.weight != Weight::One()) {
+    FSTERROR() <<
+        "EncodeMapper: Weight-encoded arc has non-trivial weight";
+    error_ = true;
+  }
+
+  const typename EncodeTable<A>::Tuple* decoded = table_->Decode(arc.ilabel);
+  if (decoded) {
+    return A(decoded->ilabel,
+             flags_ & kEncodeLabels ? decoded->olabel : arc.olabel,
+             flags_ & kEncodeWeights ? decoded->weight : arc.weight,
+             arc.nextstate);
+  }
+
+  FSTERROR() << "EncodeMapper: decode failed";
+  error_ = true;
+  return A(kNoLabel, kNoLabel, Weight::NoWeight(), arc.nextstate);
+}
+
+
+// Encodes fst in place using mapper. The Fst's current symbol tables are
+// stored in the mapper before the arcs are mapped.
+//
+// Complexity: O(nstates + narcs)
+template<class A> inline
+void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) {
+  EncodeMapper<A> &encoder = *mapper;
+  encoder.SetInputSymbols(fst->InputSymbols());
+  encoder.SetOutputSymbols(fst->OutputSymbols());
+  ArcMap(fst, mapper);
+}
+
+template<class A> inline // Decodes fst in place using the table shared with mapper
+void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) {
+ ArcMap(fst, EncodeMapper<A>(mapper, DECODE)); // Temporary DECODE-typed mapper sharing mapper's table
+ RmFinalEpsilon(fst);
+ fst->SetInputSymbols(mapper.InputSymbols()); // Reinstall the symbol tables held by the mapper
+ fst->SetOutputSymbols(mapper.OutputSymbols());
+}
+
+
+// On the fly label and/or weight encoding of an input Fst, built on ArcMapFst
+//
+// Complexity:
+// - Constructor: O(1)
+// - Traversal: O(nstates_visited + narcs_visited), assuming constant
+// time to visit an input state or arc.
+template <class A>
+class EncodeFst : public ArcMapFst<A, A, EncodeMapper<A> > {
+ public:
+ typedef A Arc;
+ typedef EncodeMapper<A> C;
+ typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl;
+ using ImplToFst<Impl>::GetImpl;
+
+ EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder) // Pointer form: also stores fst's symbol tables in *encoder
+ : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) {
+ encoder->SetInputSymbols(fst.InputSymbols());
+ encoder->SetOutputSymbols(fst.OutputSymbols());
+ }
+
+ EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder) // Reference form: encoder's symbol tables are not touched
+ : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) {}
+
+ // See Fst<>::Copy() for doc. The flag is forwarded to ArcMapFst as is.
+ EncodeFst(const EncodeFst<A> &fst, bool copy = false)
+ : ArcMapFst<A, A, C>(fst, copy) {}
+
+ // Get a copy of this EncodeFst. See Fst<>::Copy() for further doc.
+ virtual EncodeFst<A> *Copy(bool safe = false) const {
+ if (safe) { // Safe copies are rejected: log and mark this Fst's impl with kError
+ FSTERROR() << "EncodeFst::Copy(true): not allowed.";
+ GetImpl()->SetProperties(kError, kError);
+ }
+ return new EncodeFst(*this); // A non-safe copy is returned in either case
+ }
+};
+
+
+// On the fly label and/or weight decoding of an encoded input Fst
+//
+// Complexity:
+// - Constructor: O(1)
+// - Traversal: O(nstates_visited + narcs_visited), assuming constant
+// time to visit an input state or arc.
+template <class A>
+class DecodeFst : public ArcMapFst<A, A, EncodeMapper<A> > {
+ public:
+ typedef A Arc;
+ typedef EncodeMapper<A> C;
+ typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl;
+ using ImplToFst<Impl>::GetImpl;
+
+ DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder) // Wraps fst with a DECODE-typed copy of encoder
+ : ArcMapFst<A, A, C>(fst,
+ EncodeMapper<A>(encoder, DECODE),
+ ArcMapFstOptions()) {
+ GetImpl()->SetInputSymbols(encoder.InputSymbols()); // Expose the symbol tables held by the encoder
+ GetImpl()->SetOutputSymbols(encoder.OutputSymbols());
+ }
+
+ // See Fst<>::Copy() for doc. The safe flag is forwarded to ArcMapFst.
+ DecodeFst(const DecodeFst<A> &fst, bool safe = false)
+ : ArcMapFst<A, A, C>(fst, safe) {}
+
+ // Get a copy of this DecodeFst. See Fst<>::Copy() for further doc.
+ virtual DecodeFst<A> *Copy(bool safe = false) const {
+ return new DecodeFst(*this, safe);
+ }
+};
+
+
+// StateIterator specialization for EncodeFst; it only forwards to the ArcMapFst state iterator.
+template <class A>
+class StateIterator< EncodeFst<A> >
+ : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
+ public:
+ explicit StateIterator(const EncodeFst<A> &fst)
+ : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {}
+};
+
+
+// ArcIterator specialization for EncodeFst; it only forwards to the ArcMapFst arc iterator for state s.
+template <class A>
+class ArcIterator< EncodeFst<A> >
+ : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
+ public:
+ ArcIterator(const EncodeFst<A> &fst, typename A::StateId s)
+ : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {}
+};
+
+
+// StateIterator specialization for DecodeFst; it only forwards to the ArcMapFst state iterator.
+template <class A>
+class StateIterator< DecodeFst<A> >
+ : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
+ public:
+ explicit StateIterator(const DecodeFst<A> &fst)
+ : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {}
+};
+
+
+// ArcIterator specialization for DecodeFst; it only forwards to the ArcMapFst arc iterator for state s.
+template <class A>
+class ArcIterator< DecodeFst<A> >
+ : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
+ public:
+ ArcIterator(const DecodeFst<A> &fst, typename A::StateId s)
+ : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {}
+};
+
+
+// Useful aliases for the on-the-fly encoder and decoder when using StdArc.
+typedef EncodeFst<StdArc> StdEncodeFst;
+
+typedef DecodeFst<StdArc> StdDecodeFst;
+
+} // namespace fst
+
+#endif // FST_LIB_ENCODE_H__