diff options
author | Alexander Gutkin <agutkin@google.com> | 2012-09-12 18:11:43 +0100 |
---|---|---|
committer | Alexander Gutkin <agutkin@google.com> | 2012-09-12 18:11:43 +0100 |
commit | dfd8b8327b93660601d016cdc6f29f433b45a8d8 (patch) | |
tree | 968ec84b8e32ad73ec18d74334930f36b7471906 /src/include | |
parent | f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2 (diff) | |
download | openfst-dfd8b8327b93660601d016cdc6f29f433b45a8d8.tar.gz |
Updated OpenFST version to openfst-1.3.2-CL32004048 from Greco3.
Change-Id: I19b0db718256b35c0e3e5a7315f1ed6335e6dcac
Diffstat (limited to 'src/include')
59 files changed, 2331 insertions, 433 deletions
diff --git a/src/include/Makefile.am b/src/include/Makefile.am index b4e6213..e9eb92f 100644 --- a/src/include/Makefile.am +++ b/src/include/Makefile.am @@ -1,10 +1,17 @@ if HAVE_FAR far_include_headers = fst/extensions/far/compile-strings.h \ -fst/extensions/far/create.h fst/extensions/far/extract.h \ -fst/extensions/far/far.h fst/extensions/far/farlib.h \ -fst/extensions/far/farscript.h fst/extensions/far/info.h \ -fst/extensions/far/main.h fst/extensions/far/print-strings.h \ -fst/extensions/far/stlist.h fst/extensions/far/sttable.h +fst/extensions/far/create.h fst/extensions/far/equal.h \ +fst/extensions/far/extract.h fst/extensions/far/far.h \ +fst/extensions/far/farlib.h fst/extensions/far/farscript.h \ +fst/extensions/far/info.h fst/extensions/far/main.h \ +fst/extensions/far/print-strings.h fst/extensions/far/stlist.h \ +fst/extensions/far/sttable.h +endif + +if HAVE_NGRAM +ngram_include_headers = fst/extensions/ngram/bitmap-index.h \ +fst/extensions/ngram/ngram-fst.h \ +fst/extensions/ngram/nthbit.h endif if HAVE_PDT @@ -62,6 +69,6 @@ fst/replace-util.h fst/icu.h fst/string.h fst/signed-log-weight.h \ fst/sparse-tuple-weight.h fst/sparse-power-weight.h fst/expectation-weight.h \ fst/symbol-table-ops.h fst/bi-table.h \ $(far_include_headers) \ +$(ngram_include_headers) \ $(pdt_include_headers) \ $(script_include_headers) - diff --git a/src/include/Makefile.in b/src/include/Makefile.in index ab6c28d..b4e3f86 100644 --- a/src/include/Makefile.in +++ b/src/include/Makefile.in @@ -82,11 +82,14 @@ am__nobase_include_HEADERS_DIST = fst/arc.h fst/determinize.h \ fst/sparse-power-weight.h fst/expectation-weight.h \ fst/symbol-table-ops.h fst/bi-table.h \ fst/extensions/far/compile-strings.h \ - fst/extensions/far/create.h fst/extensions/far/extract.h \ - fst/extensions/far/far.h fst/extensions/far/farlib.h \ - fst/extensions/far/farscript.h fst/extensions/far/info.h \ - fst/extensions/far/main.h fst/extensions/far/print-strings.h \ - fst/extensions/far/stlist.h 
fst/extensions/far/sttable.h \ + fst/extensions/far/create.h fst/extensions/far/equal.h \ + fst/extensions/far/extract.h fst/extensions/far/far.h \ + fst/extensions/far/farlib.h fst/extensions/far/farscript.h \ + fst/extensions/far/info.h fst/extensions/far/main.h \ + fst/extensions/far/print-strings.h fst/extensions/far/stlist.h \ + fst/extensions/far/sttable.h \ + fst/extensions/ngram/bitmap-index.h \ + fst/extensions/ngram/ngram-fst.h fst/extensions/ngram/nthbit.h \ fst/extensions/pdt/collection.h fst/extensions/pdt/compose.h \ fst/extensions/pdt/expand.h fst/extensions/pdt/info.h \ fst/extensions/pdt/paren.h fst/extensions/pdt/pdt.h \ @@ -264,11 +267,16 @@ top_build_prefix = @top_build_prefix@ top_builddir = @top_builddir@ top_srcdir = @top_srcdir@ @HAVE_FAR_TRUE@far_include_headers = fst/extensions/far/compile-strings.h \ -@HAVE_FAR_TRUE@fst/extensions/far/create.h fst/extensions/far/extract.h \ -@HAVE_FAR_TRUE@fst/extensions/far/far.h fst/extensions/far/farlib.h \ -@HAVE_FAR_TRUE@fst/extensions/far/farscript.h fst/extensions/far/info.h \ -@HAVE_FAR_TRUE@fst/extensions/far/main.h fst/extensions/far/print-strings.h \ -@HAVE_FAR_TRUE@fst/extensions/far/stlist.h fst/extensions/far/sttable.h +@HAVE_FAR_TRUE@fst/extensions/far/create.h fst/extensions/far/equal.h \ +@HAVE_FAR_TRUE@fst/extensions/far/extract.h fst/extensions/far/far.h \ +@HAVE_FAR_TRUE@fst/extensions/far/farlib.h fst/extensions/far/farscript.h \ +@HAVE_FAR_TRUE@fst/extensions/far/info.h fst/extensions/far/main.h \ +@HAVE_FAR_TRUE@fst/extensions/far/print-strings.h fst/extensions/far/stlist.h \ +@HAVE_FAR_TRUE@fst/extensions/far/sttable.h + +@HAVE_NGRAM_TRUE@ngram_include_headers = fst/extensions/ngram/bitmap-index.h \ +@HAVE_NGRAM_TRUE@fst/extensions/ngram/ngram-fst.h \ +@HAVE_NGRAM_TRUE@fst/extensions/ngram/nthbit.h @HAVE_PDT_TRUE@pdt_include_headers = fst/extensions/pdt/collection.h \ @HAVE_PDT_TRUE@fst/extensions/pdt/compose.h fst/extensions/pdt/expand.h \ @@ -323,6 +331,7 @@ fst/replace-util.h 
fst/icu.h fst/string.h fst/signed-log-weight.h \ fst/sparse-tuple-weight.h fst/sparse-power-weight.h fst/expectation-weight.h \ fst/symbol-table-ops.h fst/bi-table.h \ $(far_include_headers) \ +$(ngram_include_headers) \ $(pdt_include_headers) \ $(script_include_headers) diff --git a/src/include/fst/accumulator.h b/src/include/fst/accumulator.h index fcb960c..9801b93 100644 --- a/src/include/fst/accumulator.h +++ b/src/include/fst/accumulator.h @@ -258,7 +258,7 @@ class FastLogAccumulator { for(StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { StateId s = siter.Value(); if (fst.NumArcs(s) >= arc_limit_) { - double sum = FloatLimits<double>::kPosInfinity; + double sum = FloatLimits<double>::PosInfinity(); weight_positions.push_back(weight_position); weights.push_back(sum); ++weight_position; @@ -282,12 +282,12 @@ class FastLogAccumulator { private: double LogPosExp(double x) { - return x == FloatLimits<double>::kPosInfinity ? + return x == FloatLimits<double>::PosInfinity() ? 0.0 : log(1.0F + exp(-x)); } double LogMinusExp(double x) { - return x == FloatLimits<double>::kPosInfinity ? + return x == FloatLimits<double>::PosInfinity() ? 
0.0 : log(1.0F - exp(-x)); } @@ -302,7 +302,7 @@ class FastLogAccumulator { double LogPlus(double f1, Weight v) { double f2 = to_log_weight_(v).Value(); - if (f1 == FloatLimits<double>::kPosInfinity) + if (f1 == FloatLimits<double>::PosInfinity()) return f2; else if (f1 > f2) return f2 - LogPosExp(f1 - f2); @@ -317,7 +317,7 @@ class FastLogAccumulator { error_ = true; return Weight::NoWeight(); } - if (f2 == FloatLimits<double>::kPosInfinity) + if (f2 == FloatLimits<double>::PosInfinity()) return to_weight_(f1); else return to_weight_(f1 - LogMinusExp(f2 - f1)); @@ -485,7 +485,7 @@ class CacheLogAccumulator { if ((weights_ == 0) && (fst_->NumArcs(s) >= arc_limit_)) { weights_ = new vector<double>; weights_->reserve(fst_->NumArcs(s) + 1); - weights_->push_back(FloatLimits<double>::kPosInfinity); + weights_->push_back(FloatLimits<double>::PosInfinity()); data_->AddWeights(s, weights_); } } @@ -524,7 +524,7 @@ class CacheLogAccumulator { - weights_->begin() - 1; } else { size_t n = 0; - double x = FloatLimits<double>::kPosInfinity; + double x = FloatLimits<double>::PosInfinity(); for(aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) { x = LogPlus(x, aiter->Value().weight); if (x < w) break; @@ -537,12 +537,12 @@ class CacheLogAccumulator { private: double LogPosExp(double x) { - return x == FloatLimits<double>::kPosInfinity ? + return x == FloatLimits<double>::PosInfinity() ? 0.0 : log(1.0F + exp(-x)); } double LogMinusExp(double x) { - return x == FloatLimits<double>::kPosInfinity ? + return x == FloatLimits<double>::PosInfinity() ? 
0.0 : log(1.0F - exp(-x)); } @@ -557,7 +557,7 @@ class CacheLogAccumulator { double LogPlus(double f1, Weight v) { double f2 = to_log_weight_(v).Value(); - if (f1 == FloatLimits<double>::kPosInfinity) + if (f1 == FloatLimits<double>::PosInfinity()) return f2; else if (f1 > f2) return f2 - LogPosExp(f1 - f2); @@ -572,7 +572,7 @@ class CacheLogAccumulator { error_ = true; return Weight::NoWeight(); } - if (f2 == FloatLimits<double>::kPosInfinity) + if (f2 == FloatLimits<double>::PosInfinity()) return to_weight_(f1); else return to_weight_(f1 - LogMinusExp(f2 - f1)); diff --git a/src/include/fst/arc.h b/src/include/fst/arc.h index 56086c9..5f4014b 100644 --- a/src/include/fst/arc.h +++ b/src/include/fst/arc.h @@ -34,6 +34,7 @@ #include <fst/sparse-power-weight.h> #include <iostream> #include <fstream> +#include <sstream> #include <fst/string-weight.h> diff --git a/src/include/fst/arcsort.h b/src/include/fst/arcsort.h index 38f4f95..37a51dc 100644 --- a/src/include/fst/arcsort.h +++ b/src/include/fst/arcsort.h @@ -118,9 +118,11 @@ typedef CacheOptions ArcSortFstOptions; // and exclusive of caching. 
template <class A, class C> class ArcSortFst : public StateMapFst<A, A, ArcSortMapper<A, C> > { + using StateMapFst<A, A, ArcSortMapper<A, C> >::GetImpl; public: typedef A Arc; - typedef ArcSortMapper<A, C> M; + typedef typename Arc::StateId StateId; + typedef ArcSortMapper<A, C> M; ArcSortFst(const Fst<A> &fst, const C &comp) : StateMapFst<A, A, M>(fst, ArcSortMapper<A, C>(fst, comp)) {} @@ -136,6 +138,18 @@ class ArcSortFst : public StateMapFst<A, A, ArcSortMapper<A, C> > { virtual ArcSortFst<A, C> *Copy(bool safe = false) const { return new ArcSortFst(*this, safe); } + + virtual size_t NumArcs(StateId s) const { + return GetImpl()->GetFst().NumArcs(s); + } + + virtual size_t NumInputEpsilons(StateId s) const { + return GetImpl()->GetFst().NumInputEpsilons(s); + } + + virtual size_t NumOutputEpsilons(StateId s) const { + return GetImpl()->GetFst().NumOutputEpsilons(s); + } }; diff --git a/src/include/fst/bi-table.h b/src/include/fst/bi-table.h index dbb436c..bd37781 100644 --- a/src/include/fst/bi-table.h +++ b/src/include/fst/bi-table.h @@ -39,8 +39,9 @@ namespace fst { // // Required constructors. // BiTable(); // -// // Lookup integer ID from entry. If it doesn't exist, then add it. -// I FindId(const T &entry); +// // Lookup integer ID from entry. If it doesn't exist and 'insert' +// / is true, then add it. Otherwise return -1. +// I FindId(const T &entry, bool insert = true); // // Lookup entry from integer ID. // const T &FindEntry(I) const; // // # of stored entries. @@ -58,11 +59,15 @@ class HashBiTable { T empty_entry; } - I FindId(const T &entry) { + I FindId(const T &entry, bool insert = true) { I &id_ref = entry2id_[entry]; - if (id_ref == 0) { // T not found; store and assign it a new ID. 
- id2entry_.push_back(entry); - id_ref = id2entry_.size(); + if (id_ref == 0) { // T not found + if (insert) { // store and assign it a new ID + id2entry_.push_back(entry); + id_ref = id2entry_.size(); + } else { + return -1; + } } return id_ref - 1; // NB: id_ref = ID + 1 } @@ -109,14 +114,18 @@ class CompactHashBiTable { id2entry_.reserve(table_size); } - I FindId(const T &entry) { + I FindId(const T &entry, bool insert = true) { current_entry_ = &entry; typename KeyHashSet::const_iterator it = keys_.find(kCurrentKey); - if (it == keys_.end()) { - I key = id2entry_.size(); - id2entry_.push_back(entry); - keys_.insert(key); - return key; + if (it == keys_.end()) { // T not found + if (insert) { // store and assign it a new ID + I key = id2entry_.size(); + id2entry_.push_back(entry); + keys_.insert(key); + return key; + } else { + return -1; + } } else { return *it; } @@ -191,14 +200,18 @@ class VectorBiTable { ~VectorBiTable() { delete fp_; } - I FindId(const T &entry) { + I FindId(const T &entry, bool insert = true) { ssize_t fp = (*fp_)(entry); if (fp >= fp2id_.size()) fp2id_.resize(fp + 1); I &id_ref = fp2id_[fp]; - if (id_ref == 0) { // T not found; store and assign it a new ID. 
- id2entry_.push_back(entry); - id_ref = id2entry_.size(); + if (id_ref == 0) { // T not found + if (insert) { // store and assign it a new ID + id2entry_.push_back(entry); + id_ref = id2entry_.size(); + } else { + return -1; + } } return id_ref - 1; // NB: id_ref = ID + 1 } @@ -251,24 +264,32 @@ class VectorHashBiTable { delete h_; } - I FindId(const T &entry) { + I FindId(const T &entry, bool insert = true) { if ((*selector_)(entry)) { // Use the vector if 'selector_(entry) == true' uint64 fp = (*fp_)(entry); if (fp2id_.size() <= fp) fp2id_.resize(fp + 1, 0); - if (fp2id_[fp] == 0) { - id2entry_.push_back(entry); - fp2id_[fp] = id2entry_.size(); + if (fp2id_[fp] == 0) { // T not found + if (insert) { // store and assign it a new ID + id2entry_.push_back(entry); + fp2id_[fp] = id2entry_.size(); + } else { + return -1; + } } return fp2id_[fp] - 1; // NB: assoc_value = ID + 1 } else { // Use the hash table otherwise. current_entry_ = &entry; typename KeyHashSet::const_iterator it = keys_.find(kCurrentKey); if (it == keys_.end()) { - I key = id2entry_.size(); - id2entry_.push_back(entry); - keys_.insert(key); - return key; + if (insert) { + I key = id2entry_.size(); + id2entry_.push_back(entry); + keys_.insert(key); + return key; + } else { + return -1; + } } else { return *it; } @@ -357,11 +378,15 @@ class ErasableBiTable { public: ErasableBiTable() : first_(0) {} - I FindId(const T &entry) { + I FindId(const T &entry, bool insert = true) { I &id_ref = entry2id_[entry]; - if (id_ref == 0) { // T not found; store and assign it a new ID. 
- id2entry_.push_back(entry); - id_ref = id2entry_.size() + first_; + if (id_ref == 0) { // T not found + if (insert) { // store and assign it a new ID + id2entry_.push_back(entry); + id_ref = id2entry_.size() + first_; + } else { + return -1; + } } return id_ref - 1; // NB: id_ref = ID + 1 } diff --git a/src/include/fst/cache.h b/src/include/fst/cache.h index a6a92d4..0177396 100644 --- a/src/include/fst/cache.h +++ b/src/include/fst/cache.h @@ -292,13 +292,13 @@ class CacheBaseImpl : public VectorFstBaseImpl<S> { void DeleteArcs(StateId s, size_t n) { S *state = ExtendState(s); - const vector<Arc> &arcs = GetState(s)->arcs; + const vector<Arc> &arcs = state->arcs; for (size_t i = 0; i < n; ++i) { size_t j = arcs.size() - i - 1; if (arcs[j].ilabel == 0) - --GetState(s)->niepsilons; + --state->niepsilons; if (arcs[j].olabel == 0) - --GetState(s)->noepsilons; + --state->noepsilons; } state->arcs.resize(arcs.size() - n); SetProperties(DeleteArcsProperties(Properties())); @@ -503,9 +503,6 @@ struct CacheState { size_t noepsilons; // # of output epsilons mutable uint32 flags; mutable int ref_count; - - private: - DISALLOW_COPY_AND_ASSIGN(CacheState); }; // A CacheBaseImpl with a commonly used CacheState. diff --git a/src/include/fst/compact-fst.h b/src/include/fst/compact-fst.h index efa567a..57c927e 100644 --- a/src/include/fst/compact-fst.h +++ b/src/include/fst/compact-fst.h @@ -175,10 +175,10 @@ class CompactFstData { bool Error() const { return error_; } - private: // Byte alignment for states and arcs in file format (version 1 only) static const int kFileAlign = 16; + private: Unsigned *states_; CompactElement *compacts_; size_t nstates_; @@ -539,17 +539,16 @@ class CompactFstImpl : public CacheImpl<A> { } Weight Final(StateId s) { - if (!HasFinal(s)) { - Arc arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId); - if ((compactor_->Size() != -1) || - (data_->States(s) != data_->States(s + 1))) - arc = ComputeArc(s, - compactor_->Size() == -1 - ? 
data_->States(s) - : s * compactor_->Size()); - SetFinal(s, arc.ilabel == kNoLabel ? arc.weight : Weight::Zero()); - } - return CacheImpl<A>::Final(s); + if (HasFinal(s)) + return CacheImpl<A>::Final(s); + Arc arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId); + if ((compactor_->Size() != -1) || + (data_->States(s) != data_->States(s + 1))) + arc = ComputeArc(s, + compactor_->Size() == -1 + ? data_->States(s) + : s * compactor_->Size()); + return arc.ilabel == kNoLabel ? arc.weight : Weight::Zero(); } StateId NumStates() const { @@ -650,7 +649,6 @@ class CompactFstImpl : public CacheImpl<A> { // Ensures compatibility int file_version = opts.align ? kAlignedFileVersion : kFileVersion; WriteHeader(strm, opts, file_version, &hdr); - compactor_->Write(strm); return data_->Write(strm, opts); } @@ -678,9 +676,13 @@ class CompactFstImpl : public CacheImpl<A> { data_->States(s + 1) : (s + 1) * compactor_->Size(); for (size_t i = begin; i < end; ++i) { const Arc &arc = ComputeArc(s, i); - if (arc.ilabel == kNoLabel) continue; - PushArc(s, arc); + if (arc.ilabel == kNoLabel) + SetFinal(s, arc.weight); + else + PushArc(s, arc); } + if (!HasFinal(s)) + SetFinal(s, Weight::Zero()); SetArcs(s); } @@ -694,6 +696,9 @@ class CompactFstImpl : public CacheImpl<A> { C *GetCompactor() const { return compactor_; } CompactFstData<CompactElement, U> *Data() const { return data_; } + // Properties always true of this Fst class + static const uint64 kStaticProperties = kExpanded; + protected: template <class B, class D> explicit CompactFstImpl(const CompactFstImpl<B, D, U> &impl) @@ -710,6 +715,8 @@ class CompactFstImpl : public CacheImpl<A> { } private: + friend class CompactFst<A, C, U>; // allow access during write. 
+ void Init(const Fst<Arc> &fst) { string type = "compact"; if (sizeof(U) != sizeof(uint32)) { @@ -751,8 +758,6 @@ class CompactFstImpl : public CacheImpl<A> { SetProperties(kError, kError); } - // Properties always true of this Fst class - static const uint64 kStaticProperties = kExpanded; // Current unaligned file format version static const int kFileVersion = 2; // Current aligned file format version @@ -863,6 +868,10 @@ class CompactFst : public ImplToExpandedFst< CompactFstImpl<A, C, U> > { return Fst<A>::WriteFile(filename); } + template <class F> + static bool WriteFst(const F &fst, const C &compactor, ostream &strm, + const FstWriteOptions &opts); + virtual void InitStateIterator(StateIteratorData<A> *data) const { GetImpl()->InitStateIterator(data); } @@ -893,6 +902,115 @@ class CompactFst : public ImplToExpandedFst< CompactFstImpl<A, C, U> > { void operator=(const CompactFst<A, C, U> &fst); // disallow }; +// Writes Fst in Compact format, potentially with a pass over the machine +// before writing to compute the number of states and arcs. +// +template <class A, class C, class U> +template <class F> +bool CompactFst<A, C, U>::WriteFst(const F &fst, + const C &compactor, + ostream &strm, + const FstWriteOptions &opts) { + typedef U Unsigned; + typedef typename C::Element CompactElement; + typedef typename A::Weight Weight; + static const int kFileAlign = + CompactFstData<CompactElement, U>::kFileAlign; + int file_version = opts.align ? 
+ CompactFstImpl<A, C, U>::kAlignedFileVersion : + CompactFstImpl<A, C, U>::kFileVersion; + size_t num_arcs = -1, num_states = -1, num_compacts = -1; + C first_pass_compactor = compactor; + if (fst.Type() == CompactFst<A, C, U>().Type()) { + const CompactFst<A, C, U> *compact_fst = + reinterpret_cast<const CompactFst<A, C, U> *>(&fst); + num_arcs = compact_fst->GetImpl()->Data()->NumArcs(); + num_states = compact_fst->GetImpl()->Data()->NumStates(); + num_compacts = compact_fst->GetImpl()->Data()->NumCompacts(); + first_pass_compactor = *compact_fst->GetImpl()->GetCompactor(); + } else { + // A first pass is needed to compute the state of the compactor, which + // is saved ahead of the rest of the data structures. This unfortunately + // means forcing a complete double compaction when writing in this format. + // TODO(allauzen): eliminate mutable state from compactors. + num_arcs = 0; + num_states = 0; + for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { + const StateId s = siter.Value(); + ++num_states; + if (fst.Final(s) != Weight::Zero()) { + first_pass_compactor.Compact( + s, A(kNoLabel, kNoLabel, fst.Final(s), kNoStateId)); + } + for (ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) { + ++num_arcs; + first_pass_compactor.Compact(s, aiter.Value()); + } + } + } + FstHeader hdr; + hdr.SetStart(fst.Start()); + hdr.SetNumStates(num_states); + hdr.SetNumArcs(num_arcs); + string type = "compact"; + if (sizeof(U) != sizeof(uint32)) { + string size; + Int64ToStr(8 * sizeof(U), &size); + type += size; + } + type += "_"; + type += C::Type(); + uint64 copy_properties = fst.Properties(kCopyProperties, true); + if ((copy_properties & kError) || !compactor.Compatible(fst)) { + LOG(ERROR) << "fst incompatible with compactor"; + return false; + } + uint64 properties = copy_properties | + CompactFstImpl<A, C, U>::kStaticProperties; + FstImpl<A>::WriteFstHeader(fst, strm, opts, file_version, type, properties, + &hdr); + first_pass_compactor.Write(strm); + 
if (first_pass_compactor.Size() == -1) { + if (opts.align && !AlignOutput(strm, kFileAlign)) { + LOG(ERROR) << "CompactFst::Write: Alignment failed: " << opts.source; + return false; + } + Unsigned compacts = 0; + for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { + const StateId s = siter.Value(); + strm.write(reinterpret_cast<const char *>(&compacts), sizeof(compacts)); + if (fst.Final(s) != Weight::Zero()) { + ++compacts; + } + compacts += fst.NumArcs(s); + } + strm.write(reinterpret_cast<const char *>(&compacts), sizeof(compacts)); + } + if (opts.align && !AlignOutput(strm, kFileAlign)) { + LOG(ERROR) << "Could not align file during write after writing states"; + } + C second_pass_compactor = compactor; + CompactElement element; + for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { + const StateId s = siter.Value(); + if (fst.Final(s) != Weight::Zero()) { + element = second_pass_compactor.Compact( + s, A(kNoLabel, kNoLabel, fst.Final(s), kNoStateId)); + strm.write(reinterpret_cast<const char *>(&element), sizeof(element)); + } + for (ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) { + element = second_pass_compactor.Compact(s, aiter.Value()); + strm.write(reinterpret_cast<const char *>(&element), sizeof(element)); + } + } + strm.flush(); + if (!strm) { + LOG(ERROR) << "CompactFst write failed: " << opts.source; + return false; + } + return true; +} + // Specialization for CompactFst; see generic version in fst.h // for sample usage (but use the CompactFst type!). 
This version diff --git a/src/include/fst/compat.h b/src/include/fst/compat.h index 034b57e..3b5275d 100644 --- a/src/include/fst/compat.h +++ b/src/include/fst/compat.h @@ -39,10 +39,7 @@ #include <fst/lock.h> #include <fst/flags.h> #include <fst/log.h> - -#ifdef HAVE_ICU #include <fst/icu.h> -#endif using std::cin; using std::cout; @@ -87,7 +84,7 @@ class CheckSummer { void Reset() { count_ = 0; for (int i = 0; i < kCheckSumLength; ++i) - check_sum_[0] = '\0'; + check_sum_[i] = '\0'; } void Update(void const *data, int size) { @@ -113,24 +110,6 @@ class CheckSummer { DISALLOW_COPY_AND_ASSIGN(CheckSummer); }; -// Define the UTF8 string conversion function to throw an error -// when the ICU Library is missing or disabled. -#ifndef HAVE_ICU - -template <class Label> -bool UTF8StringToLabels(const string&, std::vector<Label>*) { - LOG(ERROR) << "UTF8StringToLabels: ICU Library required for UTF8 handling"; - return false; -} - -template <class Label> -bool LabelsToUTF8String(const std::vector<Label>&, string*) { - LOG(ERROR) << "LabelsToUTF8String: ICU Library required for UTF8 handling"; - return false; -} - -#endif // HAVE_ICU - } // namespace fst diff --git a/src/include/fst/compose.h b/src/include/fst/compose.h index c0bf4b1..dfdff0a 100644 --- a/src/include/fst/compose.h +++ b/src/include/fst/compose.h @@ -360,6 +360,10 @@ class ComposeFstImpl : public ComposeFstImplBase<typename M1::Arc> { return Times(final1, final2); } + // Identifies and verifies the capabilities of the matcher to be used for + // composition. + void SetMatchType(); + F *filter_; Matcher1 *matcher1_; Matcher2 *matcher2_; @@ -385,14 +389,63 @@ ComposeFstImpl<M1, M2, F, T>::ComposeFstImpl( fst2_(matcher2_->GetFst()), state_table_(opts.state_table ? opts.state_table : new T(fst1_, fst2_)) { + SetMatchType(); + if (match_type_ == MATCH_NONE) + SetProperties(kError, kError); + VLOG(2) << "ComposeFst(" << this << "): Match type: " + << (match_type_ == MATCH_OUTPUT ? 
"output" : + (match_type_ == MATCH_INPUT ? "input" : + (match_type_ == MATCH_BOTH ? "both" : + (match_type_ == MATCH_NONE ? "none" : "unknown")))); + + uint64 fprops1 = fst1.Properties(kFstProperties, false); + uint64 fprops2 = fst2.Properties(kFstProperties, false); + uint64 mprops1 = matcher1_->Properties(fprops1); + uint64 mprops2 = matcher2_->Properties(fprops2); + uint64 cprops = ComposeProperties(mprops1, mprops2); + SetProperties(filter_->Properties(cprops), kCopyProperties); + if (state_table_->Error()) SetProperties(kError, kError); + VLOG(2) << "ComposeFst(" << this << "): Initialized"; +} + +template <class M1, class M2, class F, class T> +void ComposeFstImpl<M1, M2, F, T>::SetMatchType() { MatchType type1 = matcher1_->Type(false); MatchType type2 = matcher2_->Type(false); - if (type1 == MATCH_OUTPUT && type2 == MATCH_INPUT) { + uint32 flags1 = matcher1_->Flags(); + uint32 flags2 = matcher2_->Flags(); + if (flags1 & flags2 & kRequireMatch) { + FSTERROR() << "ComposeFst: only one argument can require matching."; + match_type_ = MATCH_NONE; + } else if (flags1 & kRequireMatch) { + if (matcher1_->Type(true) != MATCH_OUTPUT) { + FSTERROR() << "ComposeFst: 1st argument requires matching but cannot."; + match_type_ = MATCH_NONE; + } + match_type_ = MATCH_OUTPUT; + } else if (flags2 & kRequireMatch) { + if (matcher2_->Type(true) != MATCH_INPUT) { + FSTERROR() << "ComposeFst: 2nd argument requires matching but cannot."; + match_type_ = MATCH_NONE; + } + match_type_ = MATCH_INPUT; + } else if (flags1 & flags2 & kPreferMatch && + type1 == MATCH_OUTPUT && type2 == MATCH_INPUT) { + match_type_ = MATCH_BOTH; + } else if (flags1 & kPreferMatch && type1 == MATCH_OUTPUT) { + match_type_ = MATCH_OUTPUT; + } else if (flags2 & kPreferMatch && type2 == MATCH_INPUT) { + match_type_ = MATCH_INPUT; + } else if (type1 == MATCH_OUTPUT && type2 == MATCH_INPUT) { match_type_ = MATCH_BOTH; } else if (type1 == MATCH_OUTPUT) { match_type_ = MATCH_OUTPUT; } else if (type2 == 
MATCH_INPUT) { match_type_ = MATCH_INPUT; + } else if (flags1 & kPreferMatch && matcher1_->Type(true) == MATCH_OUTPUT) { + match_type_ = MATCH_OUTPUT; + } else if (flags2 & kPreferMatch && matcher2_->Type(true) == MATCH_INPUT) { + match_type_ = MATCH_INPUT; } else if (matcher1_->Type(true) == MATCH_OUTPUT) { match_type_ = MATCH_OUTPUT; } else if (matcher2_->Type(true) == MATCH_INPUT) { @@ -400,16 +453,8 @@ ComposeFstImpl<M1, M2, F, T>::ComposeFstImpl( } else { FSTERROR() << "ComposeFst: 1st argument cannot match on output labels " << "and 2nd argument cannot match on input labels (sort?)."; - SetProperties(kError, kError); + match_type_ = MATCH_NONE; } - uint64 fprops1 = fst1.Properties(kFstProperties, false); - uint64 fprops2 = fst2.Properties(kFstProperties, false); - uint64 mprops1 = matcher1_->Properties(fprops1); - uint64 mprops2 = matcher2_->Properties(fprops2); - uint64 cprops = ComposeProperties(mprops1, mprops2); - SetProperties(filter_->Properties(cprops), kCopyProperties); - if (state_table_->Error()) SetProperties(kError, kError); - VLOG(2) << "ComposeFst(" << this << "): Initialized"; } @@ -539,16 +584,19 @@ class ComposeFst : public ImplToFst< ComposeFstImplBase<A> > { switch (LookAheadMatchType(fst1, fst2)) { // Check for lookahead matchers default: case MATCH_NONE: { // Default composition (no look-ahead) + VLOG(2) << "ComposeFst: Default composition (no look-ahead)"; ComposeFstOptions<Arc> nopts(opts); return CreateBase1(fst1, fst2, nopts); } case MATCH_OUTPUT: { // Lookahead on fst1 + VLOG(2) << "ComposeFst: Lookahead on fst1"; typedef typename DefaultLookAhead<Arc, MATCH_OUTPUT>::FstMatcher M; typedef typename DefaultLookAhead<Arc, MATCH_OUTPUT>::ComposeFilter F; ComposeFstOptions<Arc, M, F> nopts(opts); return CreateBase1(fst1, fst2, nopts); } case MATCH_INPUT: { // Lookahead on fst2 + VLOG(2) << "ComposeFst: Lookahead on fst2"; typedef typename DefaultLookAhead<Arc, MATCH_INPUT>::FstMatcher M; typedef typename DefaultLookAhead<Arc, 
MATCH_INPUT>::ComposeFilter F; ComposeFstOptions<Arc, M, F> nopts(opts); diff --git a/src/include/fst/config.h b/src/include/fst/config.h index 046b49c..47e472e 100644 --- a/src/include/fst/config.h +++ b/src/include/fst/config.h @@ -4,7 +4,7 @@ /* Define to 1 if you have the ICU library. */ /* #undef HAVE_ICU */ -/* Define to 1 if the system has the type `std::tr1::hash<long long +/* Define to 1 if the system has the type `std::hash<long long unsigned>'. */ #define HAVE_STD__TR1__HASH_LONG_LONG_UNSIGNED_ 1 diff --git a/src/include/fst/const-fst.h b/src/include/fst/const-fst.h index f68e8ed..80efc8d 100644 --- a/src/include/fst/const-fst.h +++ b/src/include/fst/const-fst.h @@ -87,8 +87,6 @@ class ConstFstImpl : public FstImpl<A> { static ConstFstImpl<A, U> *Read(istream &strm, const FstReadOptions &opts); - bool Write(ostream &strm, const FstWriteOptions &opts) const; - A *Arcs(StateId s) { return arcs_ + states_[s].pos; } // Provide information needed for generic state iterator @@ -330,23 +328,24 @@ template <class A, class U> template <class F> bool ConstFst<A, U>::WriteFst(const F &fst, ostream &strm, const FstWriteOptions &opts) { - static const int kFileVersion = 2; - static const int kAlignedFileVersion = 1; - static const int kFileAlign = 16; - int file_version = opts.align ? kAlignedFileVersion : kFileVersion; + int file_version = opts.align ? ConstFstImpl<A, U>::kAlignedFileVersion : + ConstFstImpl<A, U>::kFileVersion; size_t num_arcs = -1, num_states = -1; size_t start_offset = 0; bool update_header = true; if (fst.Type() == ConstFst<A, U>().Type()) { - const ConstFst<A, U> *const_fst = static_cast<const ConstFst<A, U> *>(&fst); + const ConstFst<A, U> *const_fst = + reinterpret_cast<const ConstFst<A, U> *>(&fst); num_arcs = const_fst->GetImpl()->narcs_; num_states = const_fst->GetImpl()->nstates_; update_header = false; } else if ((start_offset = strm.tellp()) == -1) { // precompute values needed for header when we cannot seek to rewrite it. 
+ num_arcs = 0; + num_states = 0; for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { num_arcs += fst.NumArcs(siter.Value()); - num_states++; + ++num_states; } update_header = false; } @@ -360,8 +359,11 @@ bool ConstFst<A, U>::WriteFst(const F &fst, ostream &strm, Int64ToStr(8 * sizeof(U), &size); type += size; } - FstImpl<A>::WriteFstHeader(fst, strm, opts, file_version, type, &hdr); - if (opts.align && !AlignOutput(strm, kFileAlign)) { + uint64 properties = fst.Properties(kCopyProperties, true) | + ConstFstImpl<A, U>::kStaticProperties; + FstImpl<A>::WriteFstHeader(fst, strm, opts, file_version, type, properties, + &hdr); + if (opts.align && !AlignOutput(strm, ConstFstImpl<A, U>::kFileAlign)) { LOG(ERROR) << "Could not align file during write after header"; return false; } @@ -375,11 +377,11 @@ bool ConstFst<A, U>::WriteFst(const F &fst, ostream &strm, state.noepsilons = fst.NumOutputEpsilons(siter.Value()); strm.write(reinterpret_cast<const char *>(&state), sizeof(state)); pos += state.narcs; - states++; + ++states; } hdr.SetNumStates(states); hdr.SetNumArcs(pos); - if (opts.align && !AlignOutput(strm, kFileAlign)) { + if (opts.align && !AlignOutput(strm, ConstFstImpl<A, U>::kFileAlign)) { LOG(ERROR) << "Could not align file during write after writing states"; } for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { @@ -391,12 +393,12 @@ bool ConstFst<A, U>::WriteFst(const F &fst, ostream &strm, } strm.flush(); if (!strm) { - LOG(ERROR) << "WriteAsVectorFst write failed: " << opts.source; + LOG(ERROR) << "ConstFst Write write failed: " << opts.source; return false; } if (update_header) { return FstImpl<A>::UpdateFstHeader(fst, strm, opts, file_version, type, - &hdr, start_offset); + properties, &hdr, start_offset); } else { if (hdr.NumStates() != num_states) { LOG(ERROR) << "Inconsistent number of states observed during write"; diff --git a/src/include/fst/dfs-visit.h b/src/include/fst/dfs-visit.h index b47c78d..4d93a39 100644 --- 
a/src/include/fst/dfs-visit.h +++ b/src/include/fst/dfs-visit.h @@ -177,7 +177,8 @@ void DfsVisit(const Fst<Arc> &fst, V *visitor, ArcFilter filter) { // Find next tree root for (root = root == start ? 0 : root + 1; root < nstates && state_color[root] != kDfsWhite; - ++root); + ++root) { + } // Check for a state beyond the largest known state if (!expanded && root == nstates) { diff --git a/src/include/fst/expanded-fst.h b/src/include/fst/expanded-fst.h index b44b81c..676ceb3 100644 --- a/src/include/fst/expanded-fst.h +++ b/src/include/fst/expanded-fst.h @@ -82,7 +82,7 @@ class ExpandedFst : public Fst<A> { } return Read(strm, FstReadOptions(filename)); } else { - return Read(std::cin, FstReadOptions("standard input")); + return Read(cin, FstReadOptions("standard input")); } } }; @@ -154,7 +154,7 @@ class ImplToExpandedFst : public ImplToFst<I, F> { } return I::Read(strm, FstReadOptions(filename)); } else { - return I::Read(std::cin, FstReadOptions("standard input")); + return I::Read(cin, FstReadOptions("standard input")); } } diff --git a/src/include/fst/extensions/far/compile-strings.h b/src/include/fst/extensions/far/compile-strings.h index d7f4d6b..ca247db 100644 --- a/src/include/fst/extensions/far/compile-strings.h +++ b/src/include/fst/extensions/far/compile-strings.h @@ -56,7 +56,7 @@ class StringReader { const SymbolTable *syms = 0, Label unknown_label = kNoStateId) : nline_(0), strm_(istrm), source_(source), entry_type_(entry_type), - token_type_(token_type), done_(false), + token_type_(token_type), symbols_(syms), done_(false), compiler_(token_type, syms, unknown_label, allow_negative_labels) { Next(); // Initialize the reader to the first input. } @@ -87,8 +87,12 @@ class StringReader { done_ = true; // whitespace at the end of a file. 
} - VectorFst<A> *GetVectorFst() { + VectorFst<A> *GetVectorFst(bool keep_symbols = false) { VectorFst<A> *fst = new VectorFst<A>; + if (keep_symbols) { + fst->SetInputSymbols(symbols_); + fst->SetOutputSymbols(symbols_); + } if (compiler_(content_, fst)) { return fst; } else { @@ -97,9 +101,16 @@ class StringReader { } } - CompactFst<A, StringCompactor<A> > *GetCompactFst() { - CompactFst<A, StringCompactor<A> > *fst = - new CompactFst<A, StringCompactor<A> >; + CompactFst<A, StringCompactor<A> > *GetCompactFst(bool keep_symbols = false) { + CompactFst<A, StringCompactor<A> > *fst; + if (keep_symbols) { + VectorFst<A> tmp; + tmp.SetInputSymbols(symbols_); + tmp.SetOutputSymbols(symbols_); + fst = new CompactFst<A, StringCompactor<A> >(tmp); + } else { + fst = new CompactFst<A, StringCompactor<A> >; + } if (compiler_(content_, fst)) { return fst; } else { @@ -114,6 +125,7 @@ class StringReader { string source_; EntryType entry_type_; TokenType token_type_; + const SymbolTable *symbols_; bool done_; StringCompiler<A> compiler_; string content_; // The actual content of the input stream's next FST. 
@@ -135,6 +147,8 @@ void FarCompileStrings(const vector<string> &in_fnames, FarTokenType tt, const string &symbols_fname, const string &unknown_symbol, + bool keep_symbols, + bool initial_symbols, bool allow_negative_labels, bool file_list_input, const string &key_prefix, @@ -175,8 +189,9 @@ void FarCompileStrings(const vector<string> &in_fnames, const SymbolTable *syms = 0; typename Arc::Label unknown_label = kNoLabel; if (!symbols_fname.empty()) { - syms = SymbolTable::ReadText(symbols_fname, - allow_negative_labels); + SymbolTableTextOptions opts; + opts.allow_negative = allow_negative_labels; + syms = SymbolTable::ReadText(symbols_fname, opts); if (!syms) { FSTERROR() << "FarCompileStrings: error reading symbol table: " << symbols_fname; @@ -199,32 +214,47 @@ void FarCompileStrings(const vector<string> &in_fnames, vector<string> inputs; if (file_list_input) { for (int i = 1; i < in_fnames.size(); ++i) { - ifstream istrm(in_fnames[i].c_str()); + istream *istrm = in_fnames.empty() ? &cin : + new ifstream(in_fnames[i].c_str()); string str; - while (getline(istrm, str)) + while (getline(*istrm, str)) inputs.push_back(str); + if (!in_fnames.empty()) + delete istrm; } } else { inputs = in_fnames; } for (int i = 0, n = 0; i < inputs.size(); ++i) { + if (generate_keys == 0 && inputs[i].empty()) { + FSTERROR() << "FarCompileStrings: read from a file instead of stdin or" + << " set the --generate_keys flags."; + delete far_writer; + delete syms; + return; + } int key_size = generate_keys ? generate_keys : (entry_type == StringReader<Arc>::FILE ? 1 : KeySize(inputs[i].c_str())); - ifstream istrm(inputs[i].c_str()); + istream *istrm = inputs[i].empty() ? &cin : + new ifstream(inputs[i].c_str()); + bool keep_syms = keep_symbols; for (StringReader<Arc> reader( - istrm, inputs[i], entry_type, token_type, - allow_negative_labels, syms, unknown_label); + *istrm, inputs[i].empty() ? 
"stdin" : inputs[i], + entry_type, token_type, allow_negative_labels, + syms, unknown_label); !reader.Done(); reader.Next()) { ++n; const Fst<Arc> *fst; if (compact) - fst = reader.GetCompactFst(); + fst = reader.GetCompactFst(keep_syms); else - fst = reader.GetVectorFst(); + fst = reader.GetVectorFst(keep_syms); + if (initial_symbols) + keep_syms = false; if (!fst) { FSTERROR() << "FarCompileStrings: compiling string number " << n << " in file " << inputs[i] << " failed with token_type = " @@ -236,6 +266,7 @@ void FarCompileStrings(const vector<string> &in_fnames, (fet == FET_FILE ? "file" : "unknown")); delete far_writer; delete syms; + if (!inputs[i].empty()) delete istrm; return; } ostringstream keybuf; @@ -260,6 +291,8 @@ void FarCompileStrings(const vector<string> &in_fnames, } if (generate_keys == 0) n = 0; + if (!inputs[i].empty()) + delete istrm; } delete far_writer; diff --git a/src/include/fst/extensions/far/equal.h b/src/include/fst/extensions/far/equal.h new file mode 100644 index 0000000..be82e2d --- /dev/null +++ b/src/include/fst/extensions/far/equal.h @@ -0,0 +1,99 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. 
+// Author: allauzen@google.com (Cyril Allauzen) + +#ifndef FST_EXTENSIONS_FAR_EQUAL_H_ +#define FST_EXTENSIONS_FAR_EQUAL_H_ + +#include <string> + +#include <fst/extensions/far/far.h> +#include <fst/equal.h> + +namespace fst { + +template <class Arc> +bool FarEqual(const string &filename1, + const string &filename2, + float delta = kDelta, + const string &begin_key = string(), + const string &end_key = string()) { + + FarReader<Arc> *reader1 = FarReader<Arc>::Open(filename1); + FarReader<Arc> *reader2 = FarReader<Arc>::Open(filename2); + if (!reader1 || !reader2) { + delete reader1; + delete reader2; + VLOG(1) << "FarEqual: cannot open input Far file(s)"; + return false; + } + + if (!begin_key.empty()) { + bool find_begin1 = reader1->Find(begin_key); + bool find_begin2 = reader2->Find(begin_key); + if (!find_begin1 || !find_begin2) { + bool ret = !find_begin1 && !find_begin2; + if (!ret) { + VLOG(1) << "FarEqual: key \"" << begin_key << "\" missing from " + << (find_begin1 ? "second" : "first") << " archive."; + } + delete reader1; + delete reader2; + return ret; + } + } + + for(; !reader1->Done() && !reader2->Done(); + reader1->Next(), reader2->Next()) { + const string key1 = reader1->GetKey(); + const string key2 = reader2->GetKey(); + if (!end_key.empty() && end_key < key1 && end_key < key2) { + delete reader1; + delete reader2; + return true; + } + if (key1 != key2) { + VLOG(1) << "FarEqual: mismatched keys \"" + << key1 << "\" <> \"" << key2 << "\"."; + delete reader1; + delete reader2; + return false; + } + if (!Equal(reader1->GetFst(), reader2->GetFst(), delta)) { + VLOG(1) << "FarEqual: Fsts for key \"" << key1 << "\" are not equal."; + delete reader1; + delete reader2; + return false; + } + } + + if (!reader1->Done() || !reader2->Done()) { + VLOG(1) << "FarEqual: key \"" + << (reader1->Done() ? reader2->GetKey() : reader1->GetKey()) + << "\" missing form " << (reader2->Done() ? 
"first" : "second") + << " archive."; + delete reader1; + delete reader2; + return false; + } + + delete reader1; + delete reader2; + return true; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_EQUAL_H_ diff --git a/src/include/fst/extensions/far/extract.h b/src/include/fst/extensions/far/extract.h index 022ca60..d6f92ff 100644 --- a/src/include/fst/extensions/far/extract.h +++ b/src/include/fst/extensions/far/extract.h @@ -70,7 +70,7 @@ void FarExtract(const vector<string> &ifilenames, if (nrep > 0) { ostringstream tmp; tmp << '.' << nrep; - key += tmp.str(); + key.append(tmp.str().data(), tmp.str().size()); } ofilename = key; } diff --git a/src/include/fst/extensions/far/far.h b/src/include/fst/extensions/far/far.h index 82b9e5c..acce76e 100644 --- a/src/include/fst/extensions/far/far.h +++ b/src/include/fst/extensions/far/far.h @@ -32,6 +32,13 @@ namespace fst { enum FarEntryType { FET_LINE, FET_FILE }; enum FarTokenType { FTT_SYMBOL, FTT_BYTE, FTT_UTF8 }; +inline bool IsFst(const string &filename) { + ifstream strm(filename.c_str()); + if (!strm) + return false; + return IsFstHeader(strm, filename); +} + // FST archive header class class FarHeader { public: @@ -40,8 +47,11 @@ class FarHeader { bool Read(const string &filename) { FstHeader fsthdr; - if (filename.empty()) { // Header reading unsupported on stdin. - return false; + if (filename.empty()) { + // Header reading unsupported on stdin. Assumes STList and StdArc. + fartype_ = "stlist"; + arctype_ = "standard"; + return true; } else if (IsSTTable(filename)) { // Check if STTable ReadSTTableHeader(filename, &fsthdr); fartype_ = "sttable"; @@ -52,6 +62,12 @@ class FarHeader { fartype_ = "sttable"; arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); return true; + } else if (IsFst(filename)) { // Check if Fst + ifstream istrm(filename.c_str()); + fsthdr.Read(istrm, filename); + fartype_ = "fst"; + arctype_ = fsthdr.ArcType().empty() ? 
"unknown" : fsthdr.ArcType(); + return true; } return false; } @@ -61,8 +77,12 @@ class FarHeader { string arctype_; }; -enum FarType { FAR_DEFAULT = 0, FAR_STTABLE = 1, FAR_STLIST = 2, - FAR_SSTABLE = 3 }; +enum FarType { + FAR_DEFAULT = 0, + FAR_STTABLE = 1, + FAR_STLIST = 2, + FAR_FST = 3, +}; // This class creates an archive of FSTs. template <class A> @@ -153,7 +173,7 @@ class STTableFarWriter : public FarWriter<A> { public: typedef A Arc; - static STTableFarWriter *Create(const string filename) { + static STTableFarWriter *Create(const string &filename) { STTableWriter<Fst<A>, FstWriter<A> > *writer = STTableWriter<Fst<A>, FstWriter<A> >::Create(filename); return new STTableFarWriter(writer); @@ -183,7 +203,7 @@ class STListFarWriter : public FarWriter<A> { public: typedef A Arc; - static STListFarWriter *Create(const string filename) { + static STListFarWriter *Create(const string &filename) { STListWriter<Fst<A>, FstWriter<A> > *writer = STListWriter<Fst<A>, FstWriter<A> >::Create(filename); return new STListFarWriter(writer); @@ -209,6 +229,43 @@ class STListFarWriter : public FarWriter<A> { template <class A> +class FstFarWriter : public FarWriter<A> { + public: + typedef A Arc; + + explicit FstFarWriter(const string &filename) + : filename_(filename), error_(false), written_(false) {} + + static FstFarWriter *Create(const string &filename) { + return new FstFarWriter(filename); + } + + void Add(const string &key, const Fst<A> &fst) { + if (written_) { + LOG(WARNING) << "FstFarWriter::Add: only one Fst supported," + << " subsequent entries discarded."; + } else { + error_ = !fst.Write(filename_); + written_ = true; + } + } + + FarType Type() const { return FAR_FST; } + + bool Error() const { return error_; } + + ~FstFarWriter() {} + + private: + string filename_; + bool error_; + bool written_; + + DISALLOW_COPY_AND_ASSIGN(FstFarWriter); +}; + + +template <class A> FarWriter<A> *FarWriter<A>::Create(const string &filename, FarType type) { switch(type) { 
case FAR_DEFAULT: @@ -220,6 +277,9 @@ FarWriter<A> *FarWriter<A>::Create(const string &filename, FarType type) { case FAR_STLIST: return STListFarWriter<A>::Create(filename); break; + case FAR_FST: + return FstFarWriter<A>::Create(filename); + break; default: LOG(ERROR) << "FarWriter::Create: unknown far type"; return 0; @@ -331,6 +391,114 @@ class STListFarReader : public FarReader<A> { DISALLOW_COPY_AND_ASSIGN(STListFarReader); }; +template <class A> +class FstFarReader : public FarReader<A> { + public: + typedef A Arc; + + static FstFarReader *Open(const string &filename) { + vector<string> filenames; + filenames.push_back(filename); + return new FstFarReader<A>(filenames); + } + + static FstFarReader *Open(const vector<string> &filenames) { + return new FstFarReader<A>(filenames); + } + + FstFarReader(const vector<string> &filenames) + : keys_(filenames), has_stdin_(false), pos_(0), fst_(0), error_(false) { + sort(keys_.begin(), keys_.end()); + streams_.resize(keys_.size(), 0); + for (size_t i = 0; i < keys_.size(); ++i) { + if (keys_[i].empty()) { + if (!has_stdin_) { + streams_[i] = &cin; + //sources_[i] = "stdin"; + has_stdin_ = true; + } else { + FSTERROR() << "FstFarReader::FstFarReader: stdin should only " + << "appear once in the input file list."; + error_ = true; + return; + } + } else { + streams_[i] = new ifstream( + keys_[i].c_str(), ifstream::in | ifstream::binary); + } + } + if (pos_ >= keys_.size()) return; + ReadFst(); + } + + void Reset() { + if (has_stdin_) { + FSTERROR() << "FstFarReader::Reset: operation not supported on stdin"; + error_ = true; + return; + } + pos_ = 0; + ReadFst(); + } + + bool Find(const string &key) { + if (has_stdin_) { + FSTERROR() << "FstFarReader::Find: operation not supported on stdin"; + error_ = true; + return false; + } + pos_ = 0;//TODO + ReadFst(); + return true; + } + + bool Done() const { return error_ || pos_ >= keys_.size(); } + + void Next() { + ++pos_; + ReadFst(); + } + + const string &GetKey() const { + 
return keys_[pos_]; + } + + const Fst<A> &GetFst() const { + return *fst_; + } + + FarType Type() const { return FAR_FST; } + + bool Error() const { return error_; } + + ~FstFarReader() { + if (fst_) delete fst_; + for (size_t i = 0; i < keys_.size(); ++i) + delete streams_[i]; + } + + private: + void ReadFst() { + if (fst_) delete fst_; + if (pos_ >= keys_.size()) return; + streams_[pos_]->seekg(0); + fst_ = Fst<A>::Read(*streams_[pos_], FstReadOptions()); + if (!fst_) { + FSTERROR() << "FstFarReader: error reading Fst from: " << keys_[pos_]; + error_ = true; + } + } + + private: + vector<string> keys_; + vector<istream*> streams_; + bool has_stdin_; + size_t pos_; + mutable Fst<A> *fst_; + mutable bool error_; + + DISALLOW_COPY_AND_ASSIGN(FstFarReader); +}; template <class A> FarReader<A> *FarReader<A>::Open(const string &filename) { @@ -340,6 +508,8 @@ FarReader<A> *FarReader<A>::Open(const string &filename) { return STTableFarReader<A>::Open(filename); else if (IsSTList(filename)) return STListFarReader<A>::Open(filename); + else if (IsFst(filename)) + return FstFarReader<A>::Open(filename); return 0; } @@ -352,6 +522,8 @@ FarReader<A> *FarReader<A>::Open(const vector<string> &filenames) { return STTableFarReader<A>::Open(filenames); else if (!filenames.empty() && IsSTList(filenames[0])) return STListFarReader<A>::Open(filenames); + else if (!filenames.empty() && IsFst(filenames[0])) + return FstFarReader<A>::Open(filenames); return 0; } diff --git a/src/include/fst/extensions/far/farscript.h b/src/include/fst/extensions/far/farscript.h index 9c3b1ca..3a9c145 100644 --- a/src/include/fst/extensions/far/farscript.h +++ b/src/include/fst/extensions/far/farscript.h @@ -27,6 +27,7 @@ using std::vector; #include <fst/script/arg-packs.h> #include <fst/extensions/far/compile-strings.h> #include <fst/extensions/far/create.h> +#include <fst/extensions/far/equal.h> #include <fst/extensions/far/extract.h> #include <fst/extensions/far/info.h> #include 
<fst/extensions/far/print-strings.h> @@ -51,6 +52,8 @@ struct FarCompileStringsArgs { const FarTokenType tt; const string &symbols_fname; const string &unknown_symbol; + const bool keep_symbols; + const bool initial_symbols; const bool allow_negative_labels; const bool file_list_input; const string &key_prefix; @@ -65,6 +68,8 @@ struct FarCompileStringsArgs { FarTokenType tt, const string &symbols_fname, const string &unknown_symbol, + bool keep_symbols, + bool initial_symbols, bool allow_negative_labels, bool file_list_input, const string &key_prefix, @@ -72,6 +77,7 @@ struct FarCompileStringsArgs { in_fnames(in_fnames), out_fname(out_fname), fst_type(fst_type), far_type(far_type), generate_keys(generate_keys), fet(fet), tt(tt), symbols_fname(symbols_fname), unknown_symbol(unknown_symbol), + keep_symbols(keep_symbols), initial_symbols(initial_symbols), allow_negative_labels(allow_negative_labels), file_list_input(file_list_input), key_prefix(key_prefix), key_suffix(key_suffix) { } @@ -82,7 +88,8 @@ void FarCompileStrings(FarCompileStringsArgs *args) { fst::FarCompileStrings<Arc>( args->in_fnames, args->out_fname, args->fst_type, args->far_type, args->generate_keys, args->fet, args->tt, args->symbols_fname, - args->unknown_symbol, args->allow_negative_labels, args->file_list_input, + args->unknown_symbol, args->keep_symbols, args->initial_symbols, + args->allow_negative_labels, args->file_list_input, args->key_prefix, args->key_suffix); } @@ -97,6 +104,8 @@ void FarCompileStrings( FarTokenType tt, const string &symbols_fname, const string &unknown_symbol, + bool keep_symbols, + bool initial_symbols, bool allow_negative_labels, bool file_list_input, const string &key_prefix, @@ -143,6 +152,25 @@ void FarCreate(const vector<string> &in_fnames, const string &key_suffix); +typedef args::Package<const string &, const string &, float, + const string &, const string &> FarEqualInnerArgs; +typedef args::WithReturnValue<bool, FarEqualInnerArgs> FarEqualArgs; + +template 
<class Arc> +void FarEqual(FarEqualArgs *args) { + args->retval = fst::FarEqual<Arc>( + args->args.arg1, args->args.arg2, args->args.arg3, + args->args.arg4, args->args.arg5); +} + +bool FarEqual(const string &filename1, + const string &filename2, + const string &arc_type, + float delta = kDelta, + const string &begin_key = string(), + const string &end_key = string()); + + typedef args::Package<const vector<string> &, int32, const string&, const string&, const string&, const string&> FarExtractArgs; @@ -180,7 +208,9 @@ struct FarPrintStringsArgs { const string &begin_key; const string &end_key; const bool print_key; + const bool print_weight; const string &symbols_fname; + const bool initial_symbols; const int32 generate_filenames; const string &filename_prefix; const string &filename_suffix; @@ -188,12 +218,14 @@ struct FarPrintStringsArgs { FarPrintStringsArgs( const vector<string> &ifilenames, const FarEntryType entry_type, const FarTokenType token_type, const string &begin_key, - const string &end_key, const bool print_key, - const string &symbols_fname, const int32 generate_filenames, + const string &end_key, const bool print_key, const bool print_weight, + const string &symbols_fname, const bool initial_symbols, + const int32 generate_filenames, const string &filename_prefix, const string &filename_suffix) : ifilenames(ifilenames), entry_type(entry_type), token_type(token_type), - begin_key(begin_key), end_key(end_key), print_key(print_key), - symbols_fname(symbols_fname), + begin_key(begin_key), end_key(end_key), + print_key(print_key), print_weight(print_weight), + symbols_fname(symbols_fname), initial_symbols(initial_symbols), generate_filenames(generate_filenames), filename_prefix(filename_prefix), filename_suffix(filename_suffix) { } }; @@ -202,9 +234,9 @@ template <class Arc> void FarPrintStrings(FarPrintStringsArgs *args) { fst::FarPrintStrings<Arc>( args->ifilenames, args->entry_type, args->token_type, - args->begin_key, args->end_key, 
args->print_key, - args->symbols_fname, args->generate_filenames, args->filename_prefix, - args->filename_suffix); + args->begin_key, args->end_key, args->print_key, args->print_weight, + args->symbols_fname, args->initial_symbols, args->generate_filenames, + args->filename_prefix, args->filename_suffix); } @@ -215,7 +247,9 @@ void FarPrintStrings(const vector<string> &ifilenames, const string &begin_key, const string &end_key, const bool print_key, + const bool print_weight, const string &symbols_fname, + const bool initial_symbols, const int32 generate_filenames, const string &filename_prefix, const string &filename_suffix); @@ -227,6 +261,7 @@ void FarPrintStrings(const vector<string> &ifilenames, #define REGISTER_FST_FAR_OPERATIONS(ArcType) \ REGISTER_FST_OPERATION(FarCompileStrings, ArcType, FarCompileStringsArgs); \ REGISTER_FST_OPERATION(FarCreate, ArcType, FarCreateArgs); \ + REGISTER_FST_OPERATION(FarEqual, ArcType, FarEqualArgs); \ REGISTER_FST_OPERATION(FarExtract, ArcType, FarExtractArgs); \ REGISTER_FST_OPERATION(FarInfo, ArcType, FarInfoArgs); \ REGISTER_FST_OPERATION(FarPrintStrings, ArcType, FarPrintStringsArgs) diff --git a/src/include/fst/extensions/far/info.h b/src/include/fst/extensions/far/info.h index f010546..100fe68 100644 --- a/src/include/fst/extensions/far/info.h +++ b/src/include/fst/extensions/far/info.h @@ -34,7 +34,7 @@ void CountStatesAndArcs(const Fst<Arc> &fst, size_t *nstate, size_t *narc) { StateIterator<Fst<Arc> > siter(fst); for (; !siter.Done(); siter.Next(), ++(*nstate)) { ArcIterator<Fst<Arc> > aiter(fst, siter.Value()); - for (; !aiter.Done(); aiter.Next(), ++(*narc)); + for (; !aiter.Done(); aiter.Next(), ++(*narc)) {} } } diff --git a/src/include/fst/extensions/far/print-strings.h b/src/include/fst/extensions/far/print-strings.h index aff1e51..dcc7351 100644 --- a/src/include/fst/extensions/far/print-strings.h +++ b/src/include/fst/extensions/far/print-strings.h @@ -27,17 +27,21 @@ using std::vector; #include 
<fst/extensions/far/far.h> +#include <fst/shortest-distance.h> #include <fst/string.h> +DECLARE_string(far_field_separator); + namespace fst { template <class Arc> void FarPrintStrings( const vector<string> &ifilenames, const FarEntryType entry_type, const FarTokenType far_token_type, const string &begin_key, - const string &end_key, const bool print_key, const string &symbols_fname, - const int32 generate_filenames, const string &filename_prefix, - const string &filename_suffix) { + const string &end_key, const bool print_key, const bool print_weight, + const string &symbols_fname, const bool initial_symbols, + const int32 generate_filenames, + const string &filename_prefix, const string &filename_suffix) { typename StringPrinter<Arc>::TokenType token_type; if (far_token_type == FTT_SYMBOL) { @@ -54,7 +58,9 @@ void FarPrintStrings( const SymbolTable *syms = 0; if (!symbols_fname.empty()) { // allow negative flag? - syms = SymbolTable::ReadText(symbols_fname, true); + SymbolTableTextOptions opts; + opts.allow_negative = true; + syms = SymbolTable::ReadText(symbols_fname, opts); if (!syms) { FSTERROR() << "FarPrintStrings: error reading symbol table: " << symbols_fname; @@ -62,8 +68,6 @@ void FarPrintStrings( } } - StringPrinter<Arc> string_printer(token_type, syms); - FarReader<Arc> *far_reader = FarReader<Arc>::Open(ifilenames); if (!far_reader) return; @@ -83,14 +87,21 @@ void FarPrintStrings( okey = key; const Fst<Arc> &fst = far_reader->GetFst(); + if (i == 1 && initial_symbols && syms == 0 && fst.InputSymbols() != 0) + syms = fst.InputSymbols()->Copy(); string str; VLOG(2) << "Handling key: " << key; + StringPrinter<Arc> string_printer( + token_type, syms ? 
syms : fst.InputSymbols()); string_printer(fst, &str); if (entry_type == FET_LINE) { if (print_key) - cout << key << "\t"; - cout << str << endl; + cout << key << FLAGS_far_field_separator[0]; + cout << str; + if (print_weight) + cout << FLAGS_far_field_separator[0] << ShortestDistance(fst); + cout << endl; } else if (entry_type == FET_FILE) { stringstream sstrm; if (generate_filenames) { @@ -117,6 +128,7 @@ void FarPrintStrings( ostrm << "\n"; } } + delete syms; } diff --git a/src/include/fst/extensions/far/stlist.h b/src/include/fst/extensions/far/stlist.h index 4738181..1cdc80c 100644 --- a/src/include/fst/extensions/far/stlist.h +++ b/src/include/fst/extensions/far/stlist.h @@ -26,6 +26,7 @@ #include <iostream> #include <fstream> +#include <sstream> #include <fst/util.h> #include <algorithm> @@ -58,7 +59,7 @@ class STListWriter { explicit STListWriter(const string filename) : stream_( - filename.empty() ? &std::cout : + filename.empty() ? &cout : new ofstream(filename.c_str(), ofstream::out | ofstream::binary)), error_(false) { WriteType(*stream_, kSTListMagicNumber); @@ -92,7 +93,7 @@ class STListWriter { ~STListWriter() { WriteType(*stream_, string()); - if (stream_ != &std::cout) + if (stream_ != &cout) delete stream_; } @@ -127,7 +128,7 @@ class STListReader { for (size_t i = 0; i < filenames.size(); ++i) { if (filenames[i].empty()) { if (!has_stdin) { - streams_[i] = &std::cin; + streams_[i] = &cin; sources_[i] = "stdin"; has_stdin = true; } else { @@ -177,7 +178,7 @@ class STListReader { ~STListReader() { for (size_t i = 0; i < streams_.size(); ++i) { - if (streams_[i] != &std::cin) + if (streams_[i] != &cin) delete streams_[i]; } if (entry_) diff --git a/src/include/fst/extensions/far/sttable.h b/src/include/fst/extensions/far/sttable.h index 3a03133..3ce0a4b 100644 --- a/src/include/fst/extensions/far/sttable.h +++ b/src/include/fst/extensions/far/sttable.h @@ -29,6 +29,7 @@ #include <algorithm> #include <iostream> #include <fstream> +#include <sstream> 
#include <fst/util.h> namespace fst { diff --git a/src/include/fst/extensions/ngram/bitmap-index.h b/src/include/fst/extensions/ngram/bitmap-index.h new file mode 100644 index 0000000..f5a5ba7 --- /dev/null +++ b/src/include/fst/extensions/ngram/bitmap-index.h @@ -0,0 +1,183 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: sorenj@google.com (Jeffrey Sorensen) + +#ifndef FST_EXTENSIONS_NGRAM_BITMAP_INDEX_H_ +#define FST_EXTENSIONS_NGRAM_BITMAP_INDEX_H_ + +#include <vector> +using std::vector; + +#include <fst/compat.h> + +// This class is a bitstring storage class with an index that allows +// seeking to the Nth set or clear bit in time O(Log(N)) where N is +// the length of the bit vector. In addition, it allows counting set or +// clear bits over ranges in constant time. +// +// This is accomplished by maintaining an "secondary" index of limited +// size in bits that maintains a running count of the number of bits set +// in each block of bitmap data. A block is defined as the number of +// uint64 values that can fit in the secondary index before an overflow +// occurs. +// +// To handle overflows, a "primary" index containing a running count of +// bits set in each block is created using the type uint64. 
+ +namespace fst { + +class BitmapIndex { + public: + static size_t StorageSize(size_t size) { + return ((size + kStorageBlockMask) >> kStorageLogBitSize); + } + + BitmapIndex() : bits_(NULL), size_(0) { } + + bool Get(size_t index) const { + return (bits_[index >> kStorageLogBitSize] & + (kOne << (index & kStorageBlockMask))) != 0; + } + + static void Set(uint64* bits, size_t index) { + bits[index >> kStorageLogBitSize] |= (kOne << (index & kStorageBlockMask)); + } + + static void Clear(uint64* bits, size_t index) { + bits[index >> kStorageLogBitSize] &= ~(kOne << (index & kStorageBlockMask)); + } + + size_t Bits() const { + return size_; + } + + size_t ArraySize() const { + return StorageSize(size_); + } + + // Returns the number of one bits in the bitmap + size_t GetOnesCount() const { + return primary_index_[primary_index_size() - 1]; + } + + // Returns the number of one bits in positions 0 to end - 1. + // REQUIRES: end <= Bits() + size_t Rank1(size_t end) const; + + // Returns the number of one bits in the range start to end - 1. + // REQUIRES: start <= end <= Bits() + size_t GetOnesCountInRange(size_t start, size_t end) const { + return Rank1(end) - Rank1(start); + } + + // Returns the number of zero bits in positions 0 to end - 1. + // REQUIRES: end <= Bits() + size_t Rank0(size_t end) const { + return end - Rank1(end); + } + + // Returns the number of zero bits in the range start to end - 1. + // REQUIRES: start <= end <= Bits() + size_t GetZeroesCountInRange(size_t start, size_t end) const { + return end - start - GetOnesCountInRange(start, end); + } + + // Return true if any bit between start inclusive and end exclusive + // is set. 0 <= start <= end <= Bits() is required.
+ // + bool TestRange(size_t start, size_t end) const { + return Rank1(end) > Rank1(start); + } + + // Returns the offset to the nth set bit (zero based) + // or Bits() if index >= number of ones + size_t Select1(size_t bit_index) const; + + // Returns the offset to the nth clear bit (zero based) + // or Bits() if index > number of zeros + size_t Select0(size_t bit_index) const; + + // Rebuilds from index for the associated Bitmap, should be called + // whenever changes have been made to the Bitmap or else behavior + // of the indexed bitmap methods will be undefined. + void BuildIndex(const uint64 *bits, size_t size); + + // the secondary index accumulates counts until it can possibly overflow + // this constant computes the number of uint64 units that can fit into + // units the size of uint16. + static const uint64 kOne = 1; + static const uint32 kStorageBitSize = 64; + static const uint32 kStorageLogBitSize = 6; + static const uint32 kSecondaryBlockSize = ((1 << 16) - 1) + >> kStorageLogBitSize; + + private: + static const uint32 kStorageBlockMask = kStorageBitSize - 1; + + // returns, from the index, the count of ones up to array_index + size_t get_index_ones_count(size_t array_index) const; + + // because the indexes, both primary and secondary, contain a running + // count of the population of one bits contained in [0,i), there is + // no reason to have an element in the zeroth position as this value would + // necessarily be zero. (The bits are indexed in a zero based way.) Thus + // we don't store the 0th element in either index. Both of the following + // functions, if greater than 0, must be decremented by one before retrieving + // the value from the corresponding array.
+ // returns 1 + the block that contains the bitindex in question + // the inverted version works the same but looks for zeros using an inverted + // view of the index + size_t find_primary_block(size_t bit_index) const; + + size_t find_inverted_primary_block(size_t bit_index) const; + + // similarly, the secondary index (which resets its count to zero at + // the end of every kSecondaryBlockSize entries) does not store the element + // at 0. Note that the rem_bit_index parameter is the number of bits + // within the secondary block, after the bits accounted for by the primary + // block have been removed (i.e. the remaining bits). And, because we + // reset to zero with each new block, there is no need to store those + // actual zeros. + // returns 1 + the secondary block that contains the bitindex in question + size_t find_secondary_block(size_t block, size_t rem_bit_index) const; + + size_t find_inverted_secondary_block(size_t block, size_t rem_bit_index) + const; + + // We create a primary index based upon the number of secondary index + // blocks. The primary index uses fields wide enough to accommodate any + // index of the bitarray so cannot overflow + // The primary index is the actual running + // count of one bits set for all blocks (and, thus, all uint64s). + size_t primary_index_size() const { + return (ArraySize() + kSecondaryBlockSize - 1) / kSecondaryBlockSize; + } + + const uint64* bits_; + size_t size_; + + // The primary index contains the running popcount of all blocks + // which means the nth value contains the popcounts of + // [0,n*kSecondaryBlockSize], however, the 0th element is omitted. + vector<uint32> primary_index_; + // The secondary index contains the running popcount of the associated + // bitmap. It is the same length (in units of uint16) as the + // bitmap's map is in units of uint64s.
+ vector<uint16> secondary_index_; +}; + +} // end namespace fst + +#endif // FST_EXTENSIONS_NGRAM_BITMAP_INDEX_H_ diff --git a/src/include/fst/extensions/ngram/ngram-fst.h b/src/include/fst/extensions/ngram/ngram-fst.h new file mode 100644 index 0000000..eee664a --- /dev/null +++ b/src/include/fst/extensions/ngram/ngram-fst.h @@ -0,0 +1,912 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: sorenj@google.com (Jeffrey Sorensen) +// +#ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ +#define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ + +#include <stddef.h> +#include <string.h> +#include <algorithm> +#include <string> +#include <vector> +using std::vector; + +#include <fst/compat.h> +#include <fst/fstlib.h> +#include <fst/extensions/ngram/bitmap-index.h> + +// NGramFst implements an n-gram language model based upon the LOUDS data +// structure. Please refer to "Unary Data Structures for Language Models" +// http://research.google.com/pubs/archive/37218.pdf + +namespace fst { +template <class A> class NGramFst; +template <class A> class NGramFstMatcher; + +// Instance data containing mutable state for bookkeeping repeated access to +// the same state.
+template <class A> +struct NGramFstInst { + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + StateId state_; + size_t num_futures_; + size_t offset_; + size_t node_; + StateId node_state_; + vector<Label> context_; + StateId context_state_; + NGramFstInst() + : state_(kNoStateId), node_state_(kNoStateId), + context_state_(kNoStateId) { } +}; + +// Implementation class for LOUDS based NgramFst interface +template <class A> +class NGramFstImpl : public FstImpl<A> { + using FstImpl<A>::SetInputSymbols; + using FstImpl<A>::SetOutputSymbols; + using FstImpl<A>::SetType; + using FstImpl<A>::WriteHeader; + + friend class ArcIterator<NGramFst<A> >; + friend class NGramFstMatcher<A>; + + public: + using FstImpl<A>::InputSymbols; + using FstImpl<A>::SetProperties; + using FstImpl<A>::Properties; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + NGramFstImpl() : data_(0), owned_(false) { + SetType("ngram"); + SetInputSymbols(NULL); + SetOutputSymbols(NULL); + SetProperties(kStaticProperties); + } + + NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out); + + ~NGramFstImpl() { + if (owned_) { + delete [] data_; + } + } + + // Reads a serialized NGramFst from strm. Returns a newly allocated impl that + // owns its data buffer, or NULL if the header or the payload cannot be read; + // nothing is leaked on the failure paths. + static NGramFstImpl<A>* Read(istream &strm, // NOLINT + const FstReadOptions &opts) { + NGramFstImpl<A>* impl = new NGramFstImpl(); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) { + // Release the freshly allocated impl instead of leaking it. + delete impl; + return 0; + } + uint64 num_states, num_futures, num_final; + const size_t offset = sizeof(num_states) + sizeof(num_futures) + + sizeof(num_final); + // Peek at num_states and num_futures to see how much more needs to be read. 
+ strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states)); + strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures)); + strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final)); + // A short read leaves the counts indeterminate; bail out before they are + // used to size the buffer. + if (!strm) { + delete impl; + return NULL; + } + size_t size = Storage(num_states, num_futures, num_final); + char* data = new char[size]; + // Copy num_states, num_futures and num_final back into data. + memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states)); + memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures), + sizeof(num_futures)); + memcpy(data + sizeof(num_states) + sizeof(num_futures), + reinterpret_cast<char *>(&num_final), sizeof(num_final)); + strm.read(data + offset, size - offset); + if (!strm) { + // Init() has not adopted the buffer yet, so free it here. + delete [] data; + delete impl; + return NULL; + } + impl->Init(data, true /* owned */); + return impl; + } + + bool Write(ostream &strm, // NOLINT + const FstWriteOptions &opts) const { + FstHeader hdr; + hdr.SetStart(Start()); + hdr.SetNumStates(num_states_); + WriteHeader(strm, opts, kFileVersion, &hdr); + strm.write(data_, Storage(num_states_, num_futures_, num_final_)); + return strm; + } + + StateId Start() const { + return 1; + } + + Weight Final(StateId state) const { + if (final_index_.Get(state)) { + return final_probs_[final_index_.Rank1(state)]; + } else { + return Weight::Zero(); + } + } + + size_t NumArcs(StateId state, NGramFstInst<A> *inst = NULL) const { + if (inst == NULL) { + const size_t next_zero = future_index_.Select0(state + 1); + const size_t this_zero = future_index_.Select0(state); + return next_zero - this_zero - 1; + } + SetInstFuture(state, inst); + return inst->num_futures_ + ((state == 0) ? 0 : 1); + } + + size_t NumInputEpsilons(StateId state) const { + // State 0 has no parent, thus no backoff. 
+ if (state == 0) return 0; + return 1; + } + + size_t NumOutputEpsilons(StateId state) const { + return NumInputEpsilons(state); + } + + StateId NumStates() const { + return num_states_; + } + + void InitStateIterator(StateIteratorData<A>* data) const { + data->base = 0; + data->nstates = num_states_; + } + + static size_t Storage(uint64 num_states, uint64 num_futures, + uint64 num_final) { + uint64 b64; + Weight weight; + Label label; + size_t offset = sizeof(num_states) + sizeof(num_futures) + + sizeof(num_final); + offset += sizeof(b64) * ( + BitmapIndex::StorageSize(num_states * 2 + 1) + + BitmapIndex::StorageSize(num_futures + num_states + 1) + + BitmapIndex::StorageSize(num_states)); + offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label); + // Pad for alignemnt, see + // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding + offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1); + offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) + + (num_futures + 1) * sizeof(weight); + return offset; + } + + void SetInstFuture(StateId state, NGramFstInst<A> *inst) const { + if (inst->state_ != state) { + inst->state_ = state; + const size_t next_zero = future_index_.Select0(state + 1); + const size_t this_zero = future_index_.Select0(state); + inst->num_futures_ = next_zero - this_zero - 1; + inst->offset_ = future_index_.Rank1(future_index_.Select0(state) + 1); + } + } + + void SetInstNode(NGramFstInst<A> *inst) const { + if (inst->node_state_ != inst->state_) { + inst->node_state_ = inst->state_; + inst->node_ = context_index_.Select1(inst->state_); + } + } + + void SetInstContext(NGramFstInst<A> *inst) const { + SetInstNode(inst); + if (inst->context_state_ != inst->state_) { + inst->context_state_ = inst->state_; + inst->context_.clear(); + size_t node = inst->node_; + while (node != 0) { + inst->context_.push_back(context_words_[context_index_.Rank1(node)]); + node = 
context_index_.Select1(context_index_.Rank0(node) - 1); + } + } + } + + // Access to the underlying representation + const char* GetData(size_t* data_size) const { + *data_size = Storage(num_states_, num_futures_, num_final_); + return data_; + } + + void Init(const char* data, bool owned); + + private: + StateId Transition(const vector<Label> &context, Label future) const; + + // Properties always true for this Fst class. + static const uint64 kStaticProperties = kAcceptor | kIDeterministic | + kODeterministic | kEpsilons | kIEpsilons | kOEpsilons | kILabelSorted | + kOLabelSorted | kWeighted | kCyclic | kInitialAcyclic | kNotTopSorted | + kAccessible | kCoAccessible | kNotString | kExpanded; + // Current file format version. + static const int kFileVersion = 4; + // Minimum file format version supported. + static const int kMinFileVersion = 4; + + const char* data_; + bool owned_; // True if we own data_ + uint64 num_states_, num_futures_, num_final_; + size_t root_num_children_; + const Label *root_children_; + size_t root_first_child_; + // borrowed references + const uint64 *context_, *future_, *final_; + const Label *context_words_, *future_words_; + const Weight *backoff_, *final_probs_, *future_probs_; + BitmapIndex context_index_; + BitmapIndex future_index_; + BitmapIndex final_index_; + + void operator=(const NGramFstImpl<A> &); // Disallow +}; + +template<typename A> +NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) + : data_(0), owned_(false) { + typedef A Arc; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + SetType("ngram"); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + SetProperties(kStaticProperties); + + // Check basic requirements for an OpenGRM language model Fst. 
+ int64 props = kAcceptor | kIDeterministic | kIEpsilons | kILabelSorted; + if (fst.Properties(props, true) != props) { + FSTERROR() << "NGramFst only accepts OpenGRM langauge models as input"; + SetProperties(kError, kError); + return; + } + + int64 num_states = CountStates(fst); + Label* context = new Label[num_states]; + + // Find the unigram state by starting from the start state, following + // epsilons. + StateId unigram = fst.Start(); + while (1) { + ArcIterator<Fst<A> > aiter(fst, unigram); + if (aiter.Done()) { + FSTERROR() << "Start state has no arcs"; + SetProperties(kError, kError); + return; + } + if (aiter.Value().ilabel != 0) break; + unigram = aiter.Value().nextstate; + } + + // Each state's context is determined by the subtree it is under from the + // unigram state. + queue<pair<StateId, Label> > label_queue; + vector<bool> visited(num_states); + // Force an epsilon link to the start state. + label_queue.push(make_pair(fst.Start(), 0)); + for (ArcIterator<Fst<A> > aiter(fst, unigram); + !aiter.Done(); aiter.Next()) { + label_queue.push(make_pair(aiter.Value().nextstate, aiter.Value().ilabel)); + } + // investigate states in breadth first fashion to assign context words. + while (!label_queue.empty()) { + pair<StateId, Label> &now = label_queue.front(); + if (!visited[now.first]) { + context[now.first] = now.second; + visited[now.first] = true; + for (ArcIterator<Fst<A> > aiter(fst, now.first); + !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { + label_queue.push(make_pair(arc.nextstate, now.second)); + } + } + } + label_queue.pop(); + } + visited.clear(); + + // The arc from the start state should be assigned an epsilon to put it + // in front of the all other labels (which makes Start state 1 after + // unigram which is state 0). + context[fst.Start()] = 0; + + // Build the tree of contexts fst by reversing the epsilon arcs from fst. 
+ VectorFst<Arc> context_fst; + uint64 num_final = 0; + for (int i = 0; i < num_states; ++i) { + if (fst.Final(i) != Weight::Zero()) { + ++num_final; + } + context_fst.SetFinal(context_fst.AddState(), fst.Final(i)); + } + context_fst.SetStart(unigram); + context_fst.SetInputSymbols(fst.InputSymbols()); + context_fst.SetOutputSymbols(fst.OutputSymbols()); + int64 num_context_arcs = 0; + int64 num_futures = 0; + for (StateIterator<Fst<A> > siter(fst); !siter.Done(); siter.Next()) { + const StateId &state = siter.Value(); + num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state); + ArcIterator<Fst<A> > aiter(fst, state); + if (!aiter.Done()) { + const Arc &arc = aiter.Value(); + // this arc goes from state to arc.nextstate, so create an arc from + // arc.nextstate to state to reverse it. + if (arc.ilabel == 0) { + context_fst.AddArc(arc.nextstate, Arc(context[state], context[state], + arc.weight, state)); + num_context_arcs++; + } + } + } + if (num_context_arcs != context_fst.NumStates() - 1) { + FSTERROR() << "Number of contexts arcs != number of states - 1"; + SetProperties(kError, kError); + return; + } + if (context_fst.NumStates() != num_states) { + FSTERROR() << "Number of contexts != number of states"; + SetProperties(kError, kError); + return; + } + int64 context_props = context_fst.Properties(kIDeterministic | + kILabelSorted, true); + if (!(context_props & kIDeterministic)) { + FSTERROR() << "Input fst is not structured properly"; + SetProperties(kError, kError); + return; + } + if (!(context_props & kILabelSorted)) { + ArcSort(&context_fst, ILabelCompare<Arc>()); + } + + delete [] context; + + uint64 b64; + Weight weight; + Label label = kNoLabel; + const size_t storage = Storage(num_states, num_futures, num_final); + char* data = new char[storage]; + memset(data, 0, storage); + size_t offset = 0; + memcpy(data + offset, reinterpret_cast<char *>(&num_states), + sizeof(num_states)); + offset += sizeof(num_states); + memcpy(data + offset, 
reinterpret_cast<char *>(&num_futures), + sizeof(num_futures)); + offset += sizeof(num_futures); + memcpy(data + offset, reinterpret_cast<char *>(&num_final), + sizeof(num_final)); + offset += sizeof(num_final); + uint64* context_bits = reinterpret_cast<uint64*>(data + offset); + offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64); + uint64* future_bits = reinterpret_cast<uint64*>(data + offset); + offset += + BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64); + uint64* final_bits = reinterpret_cast<uint64*>(data + offset); + offset += BitmapIndex::StorageSize(num_states) * sizeof(b64); + Label* context_words = reinterpret_cast<Label*>(data + offset); + offset += (num_states + 1) * sizeof(label); + Label* future_words = reinterpret_cast<Label*>(data + offset); + offset += num_futures * sizeof(label); + offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1); + Weight* backoff = reinterpret_cast<Weight*>(data + offset); + offset += (num_states + 1) * sizeof(weight); + Weight* final_probs = reinterpret_cast<Weight*>(data + offset); + offset += num_final * sizeof(weight); + Weight* future_probs = reinterpret_cast<Weight*>(data + offset); + int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0, + final_bit = 0; + + // pseudo-root bits + BitmapIndex::Set(context_bits, context_bit++); + ++context_bit; + context_words[context_arc] = label; + backoff[context_arc] = Weight::Zero(); + context_arc++; + + ++future_bit; + if (order_out) { + order_out->clear(); + order_out->resize(num_states); + } + + queue<StateId> context_q; + context_q.push(context_fst.Start()); + StateId state_number = 0; + while (!context_q.empty()) { + const StateId &state = context_q.front(); + if (order_out) { + (*order_out)[state] = state_number; + } + + const Weight &final = context_fst.Final(state); + if (final != Weight::Zero()) { + BitmapIndex::Set(final_bits, state_number); + final_probs[final_bit] = final; + ++final_bit; + } + + for 
(ArcIterator<VectorFst<A> > aiter(context_fst, state); + !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + context_words[context_arc] = arc.ilabel; + backoff[context_arc] = arc.weight; + ++context_arc; + BitmapIndex::Set(context_bits, context_bit++); + context_q.push(arc.nextstate); + } + ++context_bit; + + for (ArcIterator<Fst<A> > aiter(fst, state); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { + future_words[future_arc] = arc.ilabel; + future_probs[future_arc] = arc.weight; + ++future_arc; + BitmapIndex::Set(future_bits, future_bit++); + } + } + ++future_bit; + ++state_number; + context_q.pop(); + } + + if ((state_number != num_states) || + (context_bit != num_states * 2 + 1) || + (context_arc != num_states) || + (future_arc != num_futures) || + (future_bit != num_futures + num_states + 1) || + (final_bit != num_final)) { + FSTERROR() << "Structure problems detected during construction"; + SetProperties(kError, kError); + return; + } + + Init(data, true /* owned */); +} + +template<typename A> +inline void NGramFstImpl<A>::Init(const char* data, bool owned) { + if (owned_) { + delete [] data_; + } + owned_ = owned; + data_ = data; + size_t offset = 0; + num_states_ = *(reinterpret_cast<const uint64*>(data_ + offset)); + offset += sizeof(num_states_); + num_futures_ = *(reinterpret_cast<const uint64*>(data_ + offset)); + offset += sizeof(num_futures_); + num_final_ = *(reinterpret_cast<const uint64*>(data_ + offset)); + offset += sizeof(num_final_); + uint64 bits; + size_t context_bits = num_states_ * 2 + 1; + size_t future_bits = num_futures_ + num_states_ + 1; + context_ = reinterpret_cast<const uint64*>(data_ + offset); + offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits); + future_ = reinterpret_cast<const uint64*>(data_ + offset); + offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits); + final_ = reinterpret_cast<const uint64*>(data_ + offset); + offset += 
BitmapIndex::StorageSize(num_states_) * sizeof(bits); + // The final bitmap is sized from num_states_, exactly as Storage() and the + // constructor lay it out; skipping StorageSize(num_states_ + 1) here would + // shift every following section whenever the two sizes round differently. + context_words_ = reinterpret_cast<const Label*>(data_ + offset); + offset += (num_states_ + 1) * sizeof(*context_words_); + future_words_ = reinterpret_cast<const Label*>(data_ + offset); + offset += num_futures_ * sizeof(*future_words_); + offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1); + backoff_ = reinterpret_cast<const Weight*>(data_ + offset); + offset += (num_states_ + 1) * sizeof(*backoff_); + final_probs_ = reinterpret_cast<const Weight*>(data_ + offset); + offset += num_final_ * sizeof(*final_probs_); + future_probs_ = reinterpret_cast<const Weight*>(data_ + offset); + + context_index_.BuildIndex(context_, context_bits); + future_index_.BuildIndex(future_, future_bits); + final_index_.BuildIndex(final_, num_states_); + + const size_t node_rank = context_index_.Rank1(0); + root_first_child_ = context_index_.Select0(node_rank) + 1; + if (context_index_.Get(root_first_child_) == false) { + FSTERROR() << "Missing unigrams"; + SetProperties(kError, kError); + return; + } + const size_t last_child = context_index_.Select0(node_rank + 1) - 1; + root_num_children_ = last_child - root_first_child_ + 1; + root_children_ = context_words_ + context_index_.Rank1(root_first_child_); +} + +template<typename A> +inline typename A::StateId NGramFstImpl<A>::Transition( + const vector<Label> &context, Label future) const { + size_t num_children = root_num_children_; + const Label *children = root_children_; + const Label *loc = lower_bound(children, children + num_children, future); + if (loc == children + num_children || *loc != future) { + return context_index_.Rank1(0); + } + size_t node = root_first_child_ + loc - children; + size_t node_rank = context_index_.Rank1(node); + size_t first_child = context_index_.Select0(node_rank) + 1; + if (context_index_.Get(first_child) == false) { + return context_index_.Rank1(node); + } + size_t last_child = context_index_.Select0(node_rank + 1) - 1; + 
num_children = last_child - first_child + 1; + for (int word = context.size() - 1; word >= 0; --word) { + children = context_words_ + context_index_.Rank1(first_child); + loc = lower_bound(children, children + last_child - first_child + 1, + context[word]); + if (loc == children + last_child - first_child + 1 || + *loc != context[word]) { + break; + } + node = first_child + loc - children; + node_rank = context_index_.Rank1(node); + first_child = context_index_.Select0(node_rank) + 1; + if (context_index_.Get(first_child) == false) break; + last_child = context_index_.Select0(node_rank + 1) - 1; + } + return context_index_.Rank1(node); +} + +/*****************************************************************************/ +template<class A> +class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { + friend class ArcIterator<NGramFst<A> >; + friend class NGramFstMatcher<A>; + + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef NGramFstImpl<A> Impl; + + explicit NGramFst(const Fst<A> &dst) + : ImplToExpandedFst<Impl>(new Impl(dst, NULL)) {} + + NGramFst(const Fst<A> &fst, vector<StateId>* order_out) + : ImplToExpandedFst<Impl>(new Impl(fst, order_out)) {} + + // Because the NGramFstImpl is a const stateless data structure, there + // is never a need to do anything beside copy the reference. + NGramFst(const NGramFst<A> &fst, bool safe = false) + : ImplToExpandedFst<Impl>(fst, false) {} + + NGramFst() : ImplToExpandedFst<Impl>(new Impl()) {} + + // Non-standard constructor to initialize NGramFst directly from data. + NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) { + GetImpl()->Init(data, owned); + } + + // Get method that gets the data associated with Init(). 
+ const char* GetData(size_t* data_size) const { + return GetImpl()->GetData(data_size); + } + + virtual size_t NumArcs(StateId s) const { + return GetImpl()->NumArcs(s, &inst_); + } + + virtual NGramFst<A>* Copy(bool safe = false) const { + return new NGramFst(*this, safe); + } + + static NGramFst<A>* Read(istream &strm, const FstReadOptions &opts) { + Impl* impl = Impl::Read(strm, opts); + return impl ? new NGramFst<A>(impl) : 0; + } + + static NGramFst<A>* Read(const string &filename) { + if (!filename.empty()) { + ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); + if (!strm) { + LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename; + return 0; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(cin, FstReadOptions("standard input")); + } + } + + virtual bool Write(ostream &strm, const FstWriteOptions &opts) const { + return GetImpl()->Write(strm, opts); + } + + virtual bool Write(const string &filename) const { + return Fst<A>::WriteFile(filename); + } + + virtual inline void InitStateIterator(StateIteratorData<A>* data) const { + GetImpl()->InitStateIterator(data); + } + + virtual inline void InitArcIterator( + StateId s, ArcIteratorData<A>* data) const; + + virtual MatcherBase<A>* InitMatcher(MatchType match_type) const { + return new NGramFstMatcher<A>(*this, match_type); + } + + private: + explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {} + + Impl* GetImpl() const { + return + ImplToExpandedFst<Impl, ExpandedFst<A> >::GetImpl(); + } + + void SetImpl(Impl* impl, bool own_impl = true) { + ImplToExpandedFst<Impl, Fst<A> >::SetImpl(impl, own_impl); + } + + mutable NGramFstInst<A> inst_; +}; + +template <class A> inline void +NGramFst<A>::InitArcIterator(StateId s, ArcIteratorData<A>* data) const { + GetImpl()->SetInstFuture(s, &inst_); + GetImpl()->SetInstNode(&inst_); + data->base = new ArcIterator<NGramFst<A> >(*this, s); +} + 
+/*****************************************************************************/ +// Matcher over an NGramFst. Keeps its own copy of the per-state bookkeeping +// (inst_) so lookups do not disturb the Fst's cached instance data. +template <class A> +class NGramFstMatcher : public MatcherBase<A> { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + // done_ starts out true (no current match) so that Done_() never reads an + // uninitialized member before the first Find_(). + NGramFstMatcher(const NGramFst<A> &fst, MatchType match_type) + : fst_(fst), inst_(fst.inst_), match_type_(match_type), done_(true), + current_loop_(false), + loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) { + swap(loop_.ilabel, loop_.olabel); + } + } + + NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false) + : fst_(matcher.fst_), inst_(matcher.inst_), + match_type_(matcher.match_type_), done_(true), current_loop_(false), + loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) { + swap(loop_.ilabel, loop_.olabel); + } + } + + virtual NGramFstMatcher<A>* Copy(bool safe = false) const { + return new NGramFstMatcher<A>(*this, safe); + } + + virtual MatchType Type(bool test) const { + return match_type_; + } + + virtual const Fst<A> &GetFst() const { + return fst_; + } + + virtual uint64 Properties(uint64 props) const { + return props; + } + + private: + virtual void SetState_(StateId s) { + fst_.GetImpl()->SetInstFuture(s, &inst_); + current_loop_ = false; + } + + virtual bool Find_(Label label) { + const Label nolabel = kNoLabel; + done_ = true; + if (label == 0 || label == nolabel) { + if (label == 0) { + current_loop_ = true; + loop_.nextstate = inst_.state_; + } + // The unigram state has no epsilon arc. 
+ if (inst_.state_ != 0) { + arc_.ilabel = arc_.olabel = 0; + fst_.GetImpl()->SetInstNode(&inst_); + arc_.nextstate = fst_.GetImpl()->context_index_.Rank1( + fst_.GetImpl()->context_index_.Select1( + fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1)); + arc_.weight = fst_.GetImpl()->backoff_[inst_.state_]; + done_ = false; + } + } else { + const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_; + const Label *end = start + inst_.num_futures_; + const Label* search = lower_bound(start, end, label); + if (search != end && *search == label) { + size_t state = search - start; + arc_.ilabel = arc_.olabel = label; + arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state]; + fst_.GetImpl()->SetInstContext(&inst_); + arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label); + done_ = false; + } + } + return !Done_(); + } + + virtual bool Done_() const { + return !current_loop_ && done_; + } + + virtual const Arc& Value_() const { + return (current_loop_) ? loop_ : arc_; + } + + virtual void Next_() { + if (current_loop_) { + current_loop_ = false; + } else { + done_ = true; + } + } + + const NGramFst<A>& fst_; + NGramFstInst<A> inst_; + MatchType match_type_; // Supplied by caller + bool done_; + Arc arc_; + bool current_loop_; // Current arc is the implicit loop + Arc loop_; +}; + +/*****************************************************************************/ +template<class A> +class ArcIterator<NGramFst<A> > : public ArcIteratorBase<A> { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + ArcIterator(const NGramFst<A> &fst, StateId state) + : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) { + inst_ = fst.inst_; + impl_->SetInstFuture(state, &inst_); + impl_->SetInstNode(&inst_); + } + + bool Done() const { + return i_ >= ((inst_.node_ == 0) ? 
inst_.num_futures_ : + inst_.num_futures_ + 1); + } + + const Arc &Value() const { + bool eps = (inst_.node_ != 0 && i_ == 0); + StateId state = (inst_.node_ == 0) ? i_ : i_ - 1; + if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) { + arc_.ilabel = + arc_.olabel = eps ? 0 : impl_->future_words_[inst_.offset_ + state]; + lazy_ &= ~(kArcILabelValue | kArcOLabelValue); + } + if (flags_ & lazy_ & kArcNextStateValue) { + if (eps) { + arc_.nextstate = impl_->context_index_.Rank1( + impl_->context_index_.Select1( + impl_->context_index_.Rank0(inst_.node_) - 1)); + } else { + if (lazy_ & kArcNextStateValue) { + impl_->SetInstContext(&inst_); // first time only. + } + arc_.nextstate = + impl_->Transition(inst_.context_, + impl_->future_words_[inst_.offset_ + state]); + } + lazy_ &= ~kArcNextStateValue; + } + if (flags_ & lazy_ & kArcWeightValue) { + arc_.weight = eps ? impl_->backoff_[inst_.state_] : + impl_->future_probs_[inst_.offset_ + state]; + lazy_ &= ~kArcWeightValue; + } + return arc_; + } + + void Next() { + ++i_; + lazy_ = ~0; + } + + size_t Position() const { return i_; } + + void Reset() { + i_ = 0; + lazy_ = ~0; + } + + void Seek(size_t a) { + if (i_ != a) { + i_ = a; + lazy_ = ~0; + } + } + + uint32 Flags() const { + return flags_; + } + + void SetFlags(uint32 f, uint32 m) { + flags_ &= ~m; + flags_ |= (f & kArcValueFlags); + } + + private: + virtual bool Done_() const { return Done(); } + virtual const Arc& Value_() const { return Value(); } + virtual void Next_() { Next(); } + virtual size_t Position_() const { return Position(); } + virtual void Reset_() { Reset(); } + virtual void Seek_(size_t a) { Seek(a); } + uint32 Flags_() const { return Flags(); } + void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); } + + mutable Arc arc_; + mutable uint32 lazy_; + const NGramFstImpl<A> *impl_; + mutable NGramFstInst<A> inst_; + + size_t i_; + uint32 flags_; + + DISALLOW_COPY_AND_ASSIGN(ArcIterator); +}; + 
+/*****************************************************************************/ +// Specialization for NGramFst; see generic version in fst.h +// for sample usage (but use the ProdLmFst type!). This version +// should inline. +template <class A> +class StateIterator<NGramFst<A> > : public StateIteratorBase<A> { + public: + typedef typename A::StateId StateId; + + explicit StateIterator(const NGramFst<A> &fst) + : s_(0), num_states_(fst.NumStates()) { } + + bool Done() const { return s_ >= num_states_; } + StateId Value() const { return s_; } + void Next() { ++s_; } + void Reset() { s_ = 0; } + + private: + virtual bool Done_() const { return Done(); } + virtual StateId Value_() const { return Value(); } + virtual void Next_() { Next(); } + virtual void Reset_() { Reset(); } + + StateId s_, num_states_; + + DISALLOW_COPY_AND_ASSIGN(StateIterator); +}; +} // namespace fst +#endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ diff --git a/src/include/fst/extensions/ngram/nthbit.h b/src/include/fst/extensions/ngram/nthbit.h new file mode 100644 index 0000000..d4a9a5a --- /dev/null +++ b/src/include/fst/extensions/ngram/nthbit.h @@ -0,0 +1,46 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. 
+// Author: sorenj@google.com (Jeffrey Sorensen) +// dr@google.com (Doug Rohde) + +#ifndef FST_EXTENSIONS_NGRAM_NTHBIT_H_ +#define FST_EXTENSIONS_NGRAM_NTHBIT_H_ + +#include <fst/types.h> + +extern uint32 nth_bit_bit_offset[]; + +inline uint32 nth_bit(uint64 v, uint32 r) { + uint32 shift = 0; + uint32 c = __builtin_popcount(v & 0xffffffff); + uint32 mask = -(r > c); + r -= c & mask; + shift += (32 & mask); + + c = __builtin_popcount((v >> shift) & 0xffff); + mask = -(r > c); + r -= c & mask; + shift += (16 & mask); + + c = __builtin_popcount((v >> shift) & 0xff); + mask = -(r > c); + r -= c & mask; + shift += (8 & mask); + + return shift + ((nth_bit_bit_offset[(v >> shift) & 0xff] >> + ((r - 1) << 2)) & 0xf); +} + +#endif // FST_EXTENSIONS_NGRAM_NTHBIT_H_ diff --git a/src/include/fst/extensions/pdt/collection.h b/src/include/fst/extensions/pdt/collection.h index 26be504..24a443f 100644 --- a/src/include/fst/extensions/pdt/collection.h +++ b/src/include/fst/extensions/pdt/collection.h @@ -16,7 +16,7 @@ // Author: riley@google.com (Michael Riley) // // \file -// Class to store a collection of sets with elements of type T. +// Class to store a collection of ordered (multi-)sets with elements of type T. #ifndef FST_EXTENSIONS_PDT_COLLECTION_H__ #define FST_EXTENSIONS_PDT_COLLECTION_H__ @@ -29,11 +29,11 @@ using std::vector; namespace fst { -// Stores a collection of non-empty sets with elements of type T. A -// default constructor, equality ==, a total order <, and an STL-style -// hash class must be defined on the elements. Provides signed -// integer ID (of type I) of each unique set. The IDs are allocated -// starting from 0 in order. +// Stores a collection of non-empty, ordered (multi-)sets with elements +// of type T. A default constructor, equality ==, and an STL-style +// hash class must be defined on the elements. Provides signed integer +// ID (of type I) of each unique set. The IDs are allocated starting +// from 0 in order. 
template <class I, class T> class Collection { public: @@ -80,31 +80,34 @@ class Collection { Collection() {} - // Lookups integer ID from set. If it doesn't exist, then adds it. - // Set elements should be in strict order (and therefore unique). - I FindId(const vector<T> &set) { + // Lookups integer ID from ordered multi-set. If it doesn't exist + // and 'insert' is true, then adds it. Otherwise returns -1. + I FindId(const vector<T> &set, bool insert = true) { I node_id = kNoNodeId; for (ssize_t i = set.size() - 1; i >= 0; --i) { Node node(node_id, set[i]); - node_id = node_table_.FindId(node); + node_id = node_table_.FindId(node, insert); + if (node_id == -1) break; } return node_id; } - // Finds set given integer ID. Returns true if ID corresponds - // to set. Use iterators below to traverse result. + // Finds ordered (multi-)set given integer ID. Returns set iterator + // to traverse result. SetIterator FindSet(I id) { - if (id < 0 && id >= node_table_.Size()) { + if (id < 0 || id >= node_table_.Size()) { return SetIterator(kNoNodeId, Node(kNoNodeId, T()), &node_table_); } else { return SetIterator(id, node_table_.FindEntry(id), &node_table_); } } + I Size() const { return node_table_.Size(); } + private: static const I kNoNodeId; static const size_t kPrime; - static std::tr1::hash<T> hash_; + static std::hash<T> hash_; NodeTable node_table_; @@ -115,7 +118,7 @@ template<class I, class T> const I Collection<I, T>::kNoNodeId = -1; template <class I, class T> const size_t Collection<I, T>::kPrime = 7853; -template <class I, class T> std::tr1::hash<T> Collection<I, T>::hash_; +template <class I, class T> std::hash<T> Collection<I, T>::hash_; } // namespace fst diff --git a/src/include/fst/extensions/pdt/info.h b/src/include/fst/extensions/pdt/info.h index ef9a860..55e76c4 100644 --- a/src/include/fst/extensions/pdt/info.h +++ b/src/include/fst/extensions/pdt/info.h @@ -24,7 +24,7 @@ #include <unordered_map> using std::tr1::unordered_map; using 
std::tr1::unordered_multimap; -#include <tr1/unordered_set> +#include <unordered_set> using std::tr1::unordered_set; using std::tr1::unordered_multiset; #include <vector> diff --git a/src/include/fst/extensions/pdt/paren.h b/src/include/fst/extensions/pdt/paren.h index 7b9887f..a9d30c5 100644 --- a/src/include/fst/extensions/pdt/paren.h +++ b/src/include/fst/extensions/pdt/paren.h @@ -26,7 +26,7 @@ #include <unordered_map> using std::tr1::unordered_map; using std::tr1::unordered_multimap; -#include <tr1/unordered_set> +#include <unordered_set> using std::tr1::unordered_set; using std::tr1::unordered_multiset; #include <set> @@ -144,7 +144,8 @@ class PdtParenReachable { const vector<pair<Label, Label> > &parens, bool close) : fst_(fst), parens_(parens), - close_(close) { + close_(close), + error_(false) { for (Label i = 0; i < parens.size(); ++i) { const pair<Label, Label> &p = parens[i]; paren_id_map_[p.first] = i; @@ -155,12 +156,18 @@ class PdtParenReachable { StateId start = fst.Start(); if (start == kNoStateId) return; - DFSearch(start, start); + if (!DFSearch(start)) { + FSTERROR() << "PdtReachable: Underlying cyclicity not supported"; + error_ = true; + } } else { FSTERROR() << "PdtParenReachable: open paren info not implemented"; + error_ = true; } } + bool const Error() { return error_; } + // Given a state ID, returns an iterator over paren IDs // for close (open) parens reachable from that state along balanced // paths. @@ -194,7 +201,7 @@ class PdtParenReachable { private: // DFS that gathers paren and state set information. // Bool returns false when cycle detected. - bool DFSearch(StateId s, StateId start); + bool DFSearch(StateId s); // Unions state sets together gathered by the DFS. 
void ComputeStateSet(StateId s); @@ -212,12 +219,13 @@ class PdtParenReachable { vector<char> state_color_; // DFS state mutable Collection<ssize_t, StateId> state_sets_; // Reachable states -> ID StateSetMap set_map_; // ID -> Reachable states + bool error_; DISALLOW_COPY_AND_ASSIGN(PdtParenReachable); }; // DFS that gathers paren and state set information. template <class A> -bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) { +bool PdtParenReachable<A>::DFSearch(StateId s) { if (s >= state_color_.size()) state_color_.resize(s + 1, kDfsWhite); @@ -239,7 +247,8 @@ bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) { if (pit != paren_id_map_.end()) { // paren? Label paren_id = pit->second; if (arc.ilabel == parens_[paren_id].first) { // open paren - DFSearch(arc.nextstate, arc.nextstate); + if (!DFSearch(arc.nextstate)) + return false; for (SetIterator set_iter = FindStates(paren_id, arc.nextstate); !set_iter.Done(); set_iter.Next()) { for (ParenArcIterator paren_arc_iter = @@ -247,15 +256,14 @@ bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) { !paren_arc_iter.Done(); paren_arc_iter.Next()) { const A &cparc = paren_arc_iter.Value(); - DFSearch(cparc.nextstate, start); + if (!DFSearch(cparc.nextstate)) + return false; } } } } else { // non-paren - if(!DFSearch(arc.nextstate, start)) { - FSTERROR() << "PdtReachable: Underlying cyclicity not supported"; - return true; - } + if(!DFSearch(arc.nextstate)) + return false; } } ComputeStateSet(s); diff --git a/src/include/fst/extensions/pdt/shortest-path.h b/src/include/fst/extensions/pdt/shortest-path.h index e90471b..85f94b8 100644 --- a/src/include/fst/extensions/pdt/shortest-path.h +++ b/src/include/fst/extensions/pdt/shortest-path.h @@ -28,7 +28,7 @@ #include <unordered_map> using std::tr1::unordered_map; using std::tr1::unordered_multimap; -#include <tr1/unordered_set> +#include <unordered_set> using std::tr1::unordered_set; using std::tr1::unordered_multiset; #include <stack> @@ 
-387,7 +387,6 @@ class PdtShortestPath { typedef typename SpData::SearchState SearchState; typedef typename SpData::ParenSpec ParenSpec; - typedef typename PdtParenReachable<Arc>::SetIterator StateSetIterator; typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator; PdtShortestPath(const Fst<Arc> &ifst, @@ -403,7 +402,7 @@ class PdtShortestPath { if ((Weight::Properties() & (kPath | kRightSemiring)) != (kPath | kRightSemiring)) { - FSTERROR() << "SingleShortestPath: Weight needs to have the path" + FSTERROR() << "PdtShortestPath: Weight needs to have the path" << " property and be right distributive: " << Weight::Type(); error_ = true; } @@ -440,6 +439,7 @@ class PdtShortestPath { static const Arc kNoArc; static const uint8 kEnqueued; static const uint8 kExpanded; + static const uint8 kFinished; const uint8 kFinal; public: @@ -543,6 +543,7 @@ void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) { ProcArcs(s); sp_data_.SetFlags(s, kExpanded, kExpanded); } + sp_data_.SetFlags(q, kFinished, kFinished); balance_data_.FinishInsert(start); sp_data_.GC(start); } @@ -607,7 +608,11 @@ void PdtShortestPath<Arc, Queue>::ProcOpenParen( Queue *state_queue = state_queue_; GetDistance(d.start); state_queue_ = state_queue; + } else if (!(sp_data_.Flags(d) & kFinished)) { + FSTERROR() << "PdtShortestPath: open parenthesis recursion: not bounded stack"; + error_ = true; } + for (CloseSourceIterator set_iter = balance_data_.Find(paren_id, arc.nextstate); !set_iter.Done(); set_iter.Next()) { @@ -765,6 +770,9 @@ template<class Arc, class Queue> const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20; template<class Arc, class Queue> +const uint8 PdtShortestPath<Arc, Queue>::kFinished = 0x40; + +template<class Arc, class Queue> void ShortestPath(const Fst<Arc> &ifst, const vector<pair<typename Arc::Label, typename Arc::Label> > &parens, diff --git a/src/include/fst/flags.h b/src/include/fst/flags.h index ec3d301..b3bb66c 100644 --- a/src/include/fst/flags.h +++ 
b/src/include/fst/flags.h @@ -22,6 +22,8 @@ #include <iostream> #include <map> +#include <set> +#include <sstream> #include <string> #include <fst/types.h> @@ -42,7 +44,7 @@ using std::string; // // DECLARE_int32(length); // -// SetFlags() can be used to set flags from the command line +// SET_FLAGS() can be used to set flags from the command line // using, for example, '--length=2'. // // ShowUsage() can be used to print out command and flag usage. @@ -56,12 +58,18 @@ using std::string; template <typename T> struct FlagDescription { - FlagDescription(T *addr, const char *doc, const char *type, const T val) - : address(addr), doc_string(doc), type_name(type), default_value(val) {} + FlagDescription(T *addr, const char *doc, const char *type, + const char *file, const T val) + : address(addr), + doc_string(doc), + type_name(type), + file_name(file), + default_value(val) {} T *address; const char *doc_string; const char *type_name; + const char *file_name; const T default_value; }; @@ -118,8 +126,7 @@ class FlagRegister { } bool SetFlag(const string &arg, const string &val) const { - for (typename std::map< string, - FlagDescription<T> >::const_iterator it = + for (typename std::map< string, FlagDescription<T> >::const_iterator it = flag_table_.begin(); it != flag_table_.end(); ++it) { @@ -131,19 +138,7 @@ class FlagRegister { return false; } - void ShowDefault(bool default_value) const { - std::cout << ", default = "; - std::cout << (default_value ? 
"true" : "false"); - } - void ShowDefault(const string &default_value) const { - std::cout << ", default = "; - std::cout << "\"" << default_value << "\""; - } - template<typename V> void ShowDefault(const V& default_value) const { - std::cout << ", default = "; - std::cout << default_value; - } - void ShowUsage() const { + void GetUsage(std::set< std::pair<string, string> > *usage_set) const { for (typename std::map< string, FlagDescription<T> >::const_iterator it = flag_table_.begin(); @@ -151,10 +146,13 @@ class FlagRegister { ++it) { const string &name = it->first; const FlagDescription<T> &desc = it->second; - std::cout << " --" << name - << ": type = " << desc.type_name; - ShowDefault(desc.default_value); - std::cout << "\n " << desc.doc_string << "\n"; + string usage = " --" + name; + usage += ": type = "; + usage += desc.type_name; + usage += ", default = "; + usage += GetDefault(desc.default_value) + "\n "; + usage += desc.doc_string; + usage_set->insert(make_pair(desc.file_name, usage)); } } @@ -163,11 +161,26 @@ class FlagRegister { register_lock_ = new fst::Mutex; register_ = new FlagRegister<T>; } + + std::map< string, FlagDescription<T> > flag_table_; + + string GetDefault(bool default_value) const { + return default_value ? 
"true" : "false"; + } + + string GetDefault(const string &default_value) const { + return "\"" + default_value + "\""; + } + + template<typename V> string GetDefault(const V& default_value) const { + std::ostringstream strm; + strm << default_value; + return strm.str(); + } + static fst::FstOnceType register_init_; // ensures only called once static fst::Mutex* register_lock_; // multithreading lock static FlagRegister<T> *register_; - - std::map< string, FlagDescription<T> > flag_table_; }; template <class T> @@ -199,6 +212,7 @@ class FlagRegisterer { name ## _flags_registerer(#name, FlagDescription<type>(&FLAGS_ ## name, \ doc, \ #type, \ + __FILE__, \ value)) #define DEFINE_bool(name, value, doc) DEFINE_VAR(bool, name, value, doc) @@ -212,13 +226,17 @@ class FlagRegisterer { // Temporary directory DECLARE_string(tmpdir); -void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags); +void SetFlags(const char *usage, int *argc, char ***argv, bool remove_flags, + const char *src = ""); + +#define SET_FLAGS(usage, argc, argv, rmflags) \ +SetFlags(usage, argc, argv, rmflags, __FILE__) // Deprecated - for backward compatibility inline void InitFst(const char *usage, int *argc, char ***argv, bool rmflags) { return SetFlags(usage, argc, argv, rmflags); } -void ShowUsage(); +void ShowUsage(bool long_usage = true); #endif // FST_LIB_FLAGS_H__ diff --git a/src/include/fst/float-weight.h b/src/include/fst/float-weight.h index 530cbdd..eb22638 100644 --- a/src/include/fst/float-weight.h +++ b/src/include/fst/float-weight.h @@ -37,19 +37,22 @@ namespace fst { template <class T> class FloatLimits { public: - static const T kPosInfinity; - static const T kNegInfinity; - static const T kNumberBad; -}; + static const T PosInfinity() { + static const T pos_infinity = numeric_limits<T>::infinity(); + return pos_infinity; + } -template <class T> -const T FloatLimits<T>::kPosInfinity = numeric_limits<T>::infinity(); + static const T NegInfinity() { + static const T 
neg_infinity = -PosInfinity(); + return neg_infinity; + } -template <class T> -const T FloatLimits<T>::kNegInfinity = -FloatLimits<T>::kPosInfinity; + static const T NumberBad() { + static const T number_bad = numeric_limits<T>::quiet_NaN(); + return number_bad; + } -template <class T> -const T FloatLimits<T>::kNumberBad = numeric_limits<T>::quiet_NaN(); +}; // weight class to be templated on floating-points types template <class T = float> @@ -151,9 +154,9 @@ inline bool ApproxEqual(const FloatWeightTpl<T> &w1, template <class T> inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) { - if (w.Value() == FloatLimits<T>::kPosInfinity) + if (w.Value() == FloatLimits<T>::PosInfinity()) return strm << "Infinity"; - else if (w.Value() == FloatLimits<T>::kNegInfinity) + else if (w.Value() == FloatLimits<T>::NegInfinity()) return strm << "-Infinity"; else if (w.Value() != w.Value()) // Fails for NaN return strm << "BadNumber"; @@ -166,9 +169,9 @@ inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) { string s; strm >> s; if (s == "Infinity") { - w = FloatWeightTpl<T>(FloatLimits<T>::kPosInfinity); + w = FloatWeightTpl<T>(FloatLimits<T>::PosInfinity()); } else if (s == "-Infinity") { - w = FloatWeightTpl<T>(FloatLimits<T>::kNegInfinity); + w = FloatWeightTpl<T>(FloatLimits<T>::NegInfinity()); } else { char *p; T f = strtod(s.c_str(), &p); @@ -196,13 +199,13 @@ class TropicalWeightTpl : public FloatWeightTpl<T> { TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} static const TropicalWeightTpl<T> Zero() { - return TropicalWeightTpl<T>(FloatLimits<T>::kPosInfinity); } + return TropicalWeightTpl<T>(FloatLimits<T>::PosInfinity()); } static const TropicalWeightTpl<T> One() { return TropicalWeightTpl<T>(0.0F); } static const TropicalWeightTpl<T> NoWeight() { - return TropicalWeightTpl<T>(FloatLimits<T>::kNumberBad); } + return TropicalWeightTpl<T>(FloatLimits<T>::NumberBad()); } static const string &Type() { static const string 
type = "tropical" + @@ -212,12 +215,12 @@ class TropicalWeightTpl : public FloatWeightTpl<T> { bool Member() const { // First part fails for IEEE NaN - return Value() == Value() && Value() != FloatLimits<T>::kNegInfinity; + return Value() == Value() && Value() != FloatLimits<T>::NegInfinity(); } TropicalWeightTpl<T> Quantize(float delta = kDelta) const { - if (Value() == FloatLimits<T>::kNegInfinity || - Value() == FloatLimits<T>::kPosInfinity || + if (Value() == FloatLimits<T>::NegInfinity() || + Value() == FloatLimits<T>::PosInfinity() || Value() != Value()) return *this; else @@ -259,9 +262,9 @@ inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1, if (!w1.Member() || !w2.Member()) return TropicalWeightTpl<T>::NoWeight(); T f1 = w1.Value(), f2 = w2.Value(); - if (f1 == FloatLimits<T>::kPosInfinity) + if (f1 == FloatLimits<T>::PosInfinity()) return w1; - else if (f2 == FloatLimits<T>::kPosInfinity) + else if (f2 == FloatLimits<T>::PosInfinity()) return w2; else return TropicalWeightTpl<T>(f1 + f2); @@ -284,10 +287,10 @@ inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1, if (!w1.Member() || !w2.Member()) return TropicalWeightTpl<T>::NoWeight(); T f1 = w1.Value(), f2 = w2.Value(); - if (f2 == FloatLimits<T>::kPosInfinity) - return FloatLimits<T>::kNumberBad; - else if (f1 == FloatLimits<T>::kPosInfinity) - return FloatLimits<T>::kPosInfinity; + if (f2 == FloatLimits<T>::PosInfinity()) + return FloatLimits<T>::NumberBad(); + else if (f1 == FloatLimits<T>::PosInfinity()) + return FloatLimits<T>::PosInfinity(); else return TropicalWeightTpl<T>(f1 - f2); } @@ -320,7 +323,7 @@ class LogWeightTpl : public FloatWeightTpl<T> { LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} static const LogWeightTpl<T> Zero() { - return LogWeightTpl<T>(FloatLimits<T>::kPosInfinity); + return LogWeightTpl<T>(FloatLimits<T>::PosInfinity()); } static const LogWeightTpl<T> One() { @@ -328,7 +331,7 @@ class LogWeightTpl : public FloatWeightTpl<T> { } 
static const LogWeightTpl<T> NoWeight() { - return LogWeightTpl<T>(FloatLimits<T>::kNumberBad); } + return LogWeightTpl<T>(FloatLimits<T>::NumberBad()); } static const string &Type() { static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString(); @@ -337,12 +340,12 @@ class LogWeightTpl : public FloatWeightTpl<T> { bool Member() const { // First part fails for IEEE NaN - return Value() == Value() && Value() != FloatLimits<T>::kNegInfinity; + return Value() == Value() && Value() != FloatLimits<T>::NegInfinity(); } LogWeightTpl<T> Quantize(float delta = kDelta) const { - if (Value() == FloatLimits<T>::kNegInfinity || - Value() == FloatLimits<T>::kPosInfinity || + if (Value() == FloatLimits<T>::NegInfinity() || + Value() == FloatLimits<T>::PosInfinity() || Value() != Value()) return *this; else @@ -368,9 +371,9 @@ template <class T> inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1, const LogWeightTpl<T> &w2) { T f1 = w1.Value(), f2 = w2.Value(); - if (f1 == FloatLimits<T>::kPosInfinity) + if (f1 == FloatLimits<T>::PosInfinity()) return w2; - else if (f2 == FloatLimits<T>::kPosInfinity) + else if (f2 == FloatLimits<T>::PosInfinity()) return w1; else if (f1 > f2) return LogWeightTpl<T>(f2 - LogExp(f1 - f2)); @@ -394,9 +397,9 @@ inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1, if (!w1.Member() || !w2.Member()) return LogWeightTpl<T>::NoWeight(); T f1 = w1.Value(), f2 = w2.Value(); - if (f1 == FloatLimits<T>::kPosInfinity) + if (f1 == FloatLimits<T>::PosInfinity()) return w1; - else if (f2 == FloatLimits<T>::kPosInfinity) + else if (f2 == FloatLimits<T>::PosInfinity()) return w2; else return LogWeightTpl<T>(f1 + f2); @@ -419,10 +422,10 @@ inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1, if (!w1.Member() || !w2.Member()) return LogWeightTpl<T>::NoWeight(); T f1 = w1.Value(), f2 = w2.Value(); - if (f2 == FloatLimits<T>::kPosInfinity) - return FloatLimits<T>::kNumberBad; - else if (f1 == FloatLimits<T>::kPosInfinity) - return 
FloatLimits<T>::kPosInfinity; + if (f2 == FloatLimits<T>::PosInfinity()) + return FloatLimits<T>::NumberBad(); + else if (f1 == FloatLimits<T>::PosInfinity()) + return FloatLimits<T>::PosInfinity(); else return LogWeightTpl<T>(f1 - f2); } @@ -454,15 +457,15 @@ class MinMaxWeightTpl : public FloatWeightTpl<T> { MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} static const MinMaxWeightTpl<T> Zero() { - return MinMaxWeightTpl<T>(FloatLimits<T>::kPosInfinity); + return MinMaxWeightTpl<T>(FloatLimits<T>::PosInfinity()); } static const MinMaxWeightTpl<T> One() { - return MinMaxWeightTpl<T>(FloatLimits<T>::kNegInfinity); + return MinMaxWeightTpl<T>(FloatLimits<T>::NegInfinity()); } static const MinMaxWeightTpl<T> NoWeight() { - return MinMaxWeightTpl<T>(FloatLimits<T>::kNumberBad); } + return MinMaxWeightTpl<T>(FloatLimits<T>::NumberBad()); } static const string &Type() { static const string type = "minmax" + @@ -477,8 +480,8 @@ class MinMaxWeightTpl : public FloatWeightTpl<T> { MinMaxWeightTpl<T> Quantize(float delta = kDelta) const { // If one of infinities, or a NaN - if (Value() == FloatLimits<T>::kNegInfinity || - Value() == FloatLimits<T>::kPosInfinity || + if (Value() == FloatLimits<T>::NegInfinity() || + Value() == FloatLimits<T>::PosInfinity() || Value() != Value()) return *this; else @@ -541,7 +544,7 @@ inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1, if (!w1.Member() || !w2.Member()) return MinMaxWeightTpl<T>::NoWeight(); // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2 - return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::kNumberBad; + return w1.Value() >= w2.Value() ? 
w1 : FloatLimits<T>::NumberBad(); } inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1, diff --git a/src/include/fst/fst.h b/src/include/fst/fst.h index 9c4d0db..dd11e4f 100644 --- a/src/include/fst/fst.h +++ b/src/include/fst/fst.h @@ -36,6 +36,7 @@ #include <fst/register.h> #include <iostream> #include <fstream> +#include <sstream> #include <fst/symbol-table.h> #include <fst/util.h> @@ -229,7 +230,7 @@ class Fst { } return Read(strm, FstReadOptions(filename)); } else { - return Read(std::cin, FstReadOptions("standard input")); + return Read(cin, FstReadOptions("standard input")); } } @@ -267,7 +268,6 @@ class Fst { virtual MatcherBase<A> *InitMatcher(MatchType match_type) const; protected: - bool WriteFile(const string &filename) const { if (!filename.empty()) { ofstream strm(filename.c_str(), ofstream::out | ofstream::binary); @@ -277,7 +277,7 @@ class Fst { } return Write(strm, FstWriteOptions(filename)); } else { - return Write(std::cout, FstWriteOptions("standard output")); + return Write(cout, FstWriteOptions("standard output")); } } }; @@ -706,12 +706,13 @@ template <class A> class FstImpl { // This method is used in the cross-type serialization methods Fst::WriteFst. static void WriteFstHeader(const Fst<A> &fst, ostream &strm, const FstWriteOptions& opts, int version, - const string &type, FstHeader *hdr) { + const string &type, uint64 properties, + FstHeader *hdr) { if (opts.write_header) { hdr->SetFstType(type); hdr->SetArcType(A::Type()); hdr->SetVersion(version); - hdr->SetProperties(fst.Properties(kFstProperties, false)); + hdr->SetProperties(properties); int32 file_flags = 0; if (fst.InputSymbols() && opts.write_isymbols) file_flags |= FstHeader::HAS_ISYMBOLS; @@ -737,14 +738,14 @@ template <class A> class FstImpl { // returns true on success, false on failure. 
static bool UpdateFstHeader(const Fst<A> &fst, ostream &strm, const FstWriteOptions& opts, int version, - const string &type, FstHeader *hdr, - size_t header_offset) { + const string &type, uint64 properties, + FstHeader *hdr, size_t header_offset) { strm.seekp(header_offset); if (!strm) { LOG(ERROR) << "Fst::UpdateFstHeader: write failed: " << opts.source; return false; } - WriteFstHeader(fst, strm, opts, version, type, hdr); + WriteFstHeader(fst, strm, opts, version, type, properties, hdr); if (!strm) { LOG(ERROR) << "Fst::UpdateFstHeader: write failed: " << opts.source; return false; diff --git a/src/include/fst/fstlib.h b/src/include/fst/fstlib.h index c05c775..de5976d 100644 --- a/src/include/fst/fstlib.h +++ b/src/include/fst/fstlib.h @@ -142,6 +142,8 @@ #include <fst/state-reachable.h> #include <iostream> #include <fstream> +#include <sstream> +#include <fst/string.h> #include <fst/symbol-table.h> #include <fst/symbol-table-ops.h> #include <fst/test-properties.h> diff --git a/src/include/fst/icu.h b/src/include/fst/icu.h index 6b74c2e..3947716 100644 --- a/src/include/fst/icu.h +++ b/src/include/fst/icu.h @@ -13,88 +13,101 @@ // limitations under the License. // // Copyright 2005-2010 Google, Inc. -// Author: roubert@google.com (Fredrik Roubert) - -// Wrapper class for UErrorCode, with conversion operators for direct use in -// ICU C and C++ APIs. -// -// Features: -// - The constructor initializes the internal UErrorCode to U_ZERO_ERROR, -// removing one common source of errors. -// - Same use in C APIs taking a UErrorCode* (pointer) and C++ taking -// UErrorCode& (reference), via conversion operators. -// - Automatic checking for success when it goes out of scope. On failure, -// the destructor will FSTERROR() an error message. -// -// Most of ICU will handle errors gracefully and provide sensible fallbacks. 
-// Using IcuErrorCode, it is therefore possible to write very compact code -// that does sensible things on failure and provides logging for debugging. +// Author: sorenj@google.com (Jeffrey Sorensen) +// roubert@google.com (Fredrik Roubert) // -// Example: -// -// IcuErrorCode icuerrorcode; -// return collator.compareUTF8(a, b, icuerrorcode) == UCOL_EQUAL; +// This library implements an unrestricted Thompson/Pike UTF-8 parser and +// serializer. UTF-8 is a restricted subset of this byte stream encoding. See +// http://en.wikipedia.org/wiki/UTF-8 for a good description of the encoding +// details. #ifndef FST_LIB_ICU_H_ #define FST_LIB_ICU_H_ -#include <unicode/errorcode.h> -#include <unicode/unistr.h> -#include <unicode/ustring.h> -#include <unicode/utf8.h> - -class IcuErrorCode : public icu::ErrorCode { - public: - IcuErrorCode() {} - virtual ~IcuErrorCode() { if (isFailure()) handleFailure(); } - - // Redefine 'errorName()' in order to be compatible with ICU version 4.2 - const char* errorName() const { - return u_errorName(errorCode); - } - - protected: - virtual void handleFailure() const { - FSTERROR() << errorName(); -} - - private: - DISALLOW_COPY_AND_ASSIGN(IcuErrorCode); -}; +#include <iostream> +#include <fstream> +#include <sstream> namespace fst { template <class Label> bool UTF8StringToLabels(const string &str, vector<Label> *labels) { - const char *c_str = str.c_str(); - int32_t length = str.size(); - UChar32 c; - for (int32_t i = 0; i < length; /* no update */) { - U8_NEXT(c_str, i, length, c); - if (c < 0) { - LOG(ERROR) << "UTF8StringToLabels: Invalid character found: " << c; - return false; + const char *data = str.data(); + size_t length = str.size(); + for (int i = 0; i < length; /* no update */) { + int c = data[i++] & 0xff; + if ((c & 0x80) == 0) { + labels->push_back(c); + } else { + if ((c & 0xc0) == 0x80) { + LOG(ERROR) << "UTF8StringToLabels: continuation byte as lead byte"; + return false; + } + int count = (c >= 0xc0) + (c >= 0xe0) + 
(c >= 0xf0) + (c >= 0xf8) + + (c >= 0xfc); + int code = c & ((1 << (6 - count)) - 1); + while (count != 0) { + if (i == length) { + LOG(ERROR) << "UTF8StringToLabels: truncated utf-8 byte sequence"; + return false; + } + char cb = data[i++]; + if ((cb & 0xc0) != 0x80) { + LOG(ERROR) << "UTF8StringToLabels: missing/invalid continuation byte"; + return false; + } + code = (code << 6) | (cb & 0x3f); + count--; + } + if (code < 0) { + // This should not be able to happen. + LOG(ERROR) << "UTF8StringToLabels: Invalid character found: " << c; + return false; + } + labels->push_back(code); } - labels->push_back(c); } return true; } template <class Label> bool LabelsToUTF8String(const vector<Label> &labels, string *str) { - icu::UnicodeString u_str; - char c_str[5]; + ostringstream ostr; for (size_t i = 0; i < labels.size(); ++i) { - u_str.setTo(labels[i]); - IcuErrorCode error; - u_strToUTF8(c_str, 5, NULL, u_str.getTerminatedBuffer(), -1, error); - if (error.isFailure()) { - LOG(ERROR) << "LabelsToUTF8String: Bad encoding: " - << error.errorName(); + int32_t code = labels[i]; + if (code < 0) { + LOG(ERROR) << "LabelsToUTF8String: Invalid character found: " << code; return false; + } else if (code < 0x80) { + ostr << static_cast<char>(code); + } else if (code < 0x800) { + ostr << static_cast<char>((code >> 6) | 0xc0); + ostr << static_cast<char>((code & 0x3f) | 0x80); + } else if (code < 0x10000) { + ostr << static_cast<char>((code >> 12) | 0xe0); + ostr << static_cast<char>(((code >> 6) & 0x3f) | 0x80); + ostr << static_cast<char>((code & 0x3f) | 0x80); + } else if (code < 0x200000) { + ostr << static_cast<char>((code >> 18) | 0xf0); + ostr << static_cast<char>(((code >> 12) & 0x3f) | 0x80); + ostr << static_cast<char>(((code >> 6) & 0x3f) | 0x80); + ostr << static_cast<char>((code & 0x3f) | 0x80); + } else if (code < 0x4000000) { + ostr << static_cast<char>((code >> 24) | 0xf8); + ostr << static_cast<char>(((code >> 18) & 0x3f) | 0x80); + ostr << 
static_cast<char>(((code >> 12) & 0x3f) | 0x80); + ostr << static_cast<char>(((code >> 6) & 0x3f) | 0x80); + ostr << static_cast<char>((code & 0x3f) | 0x80); + } else { + ostr << static_cast<char>((code >> 30) | 0xfc); + ostr << static_cast<char>(((code >> 24) & 0x3f) | 0x80); + ostr << static_cast<char>(((code >> 18) & 0x3f) | 0x80); + ostr << static_cast<char>(((code >> 12) & 0x3f) | 0x80); + ostr << static_cast<char>(((code >> 6) & 0x3f) | 0x80); + ostr << static_cast<char>((code & 0x3f) | 0x80); } - *str += c_str; } + *str = ostr.str(); return true; } diff --git a/src/include/fst/lock.h b/src/include/fst/lock.h index 3adf7df..329015d 100644 --- a/src/include/fst/lock.h +++ b/src/include/fst/lock.h @@ -16,6 +16,9 @@ // // \file // Google-compatibility locking declarations and inline definitions +// +// Classes and functions here are no-ops (by design); proper locking requires +// actual implementation. #ifndef FST_LIB_LOCK_H__ #define FST_LIB_LOCK_H__ @@ -61,6 +64,14 @@ class MutexLock { DISALLOW_COPY_AND_ASSIGN(MutexLock); }; +class ReaderMutexLock { + public: + ReaderMutexLock(Mutex *) {} + + private: + DISALLOW_COPY_AND_ASSIGN(ReaderMutexLock); +}; + // Reference counting - single-thread implementation class RefCounter { public: diff --git a/src/include/fst/lookahead-matcher.h b/src/include/fst/lookahead-matcher.h index 10d9c01..f927d65 100644 --- a/src/include/fst/lookahead-matcher.h +++ b/src/include/fst/lookahead-matcher.h @@ -96,35 +96,35 @@ namespace fst { // LOOK-AHEAD FLAGS (see also kMatcherFlags in matcher.h): // // Matcher is a lookahead matcher when 'match_type' is MATCH_INPUT. -const uint32 kInputLookAheadMatcher = 0x00000001; +const uint32 kInputLookAheadMatcher = 0x00000010; // Matcher is a lookahead matcher when 'match_type' is MATCH_OUTPUT. -const uint32 kOutputLookAheadMatcher = 0x00000002; +const uint32 kOutputLookAheadMatcher = 0x00000020; // A non-trivial implementation of LookAheadWeight() method defined and // should be used? 
-const uint32 kLookAheadWeight = 0x00000004; +const uint32 kLookAheadWeight = 0x00000040; // A non-trivial implementation of LookAheadPrefix() method defined and // should be used? -const uint32 kLookAheadPrefix = 0x00000008; +const uint32 kLookAheadPrefix = 0x00000080; // Look-ahead of matcher FST non-epsilon arcs? -const uint32 kLookAheadNonEpsilons = 0x00000010; +const uint32 kLookAheadNonEpsilons = 0x00000100; // Look-ahead of matcher FST epsilon arcs? -const uint32 kLookAheadEpsilons = 0x00000020; +const uint32 kLookAheadEpsilons = 0x00000200; // Ignore epsilon paths for the lookahead prefix? Note this gives // correct results in composition only with an appropriate composition // filter since it depends on the filter blocking the ignored paths. -const uint32 kLookAheadNonEpsilonPrefix = 0x00000040; +const uint32 kLookAheadNonEpsilonPrefix = 0x00000400; // For LabelLookAheadMatcher, save relabeling data to file -const uint32 kLookAheadKeepRelabelData = 0x00000080; +const uint32 kLookAheadKeepRelabelData = 0x00000800; // Flags used for lookahead matchers. 
-const uint32 kLookAheadFlags = 0x000000ff; +const uint32 kLookAheadFlags = 0x00000ff0; // LookAhead Matcher interface, templated on the Arc definition; used // for lookahead matcher specializations that are returned by the @@ -601,10 +601,12 @@ bool LabelLookAheadMatcher<M, F, S>::LookAheadFst(const L &fst, StateId s) { bool reach_arc = label_reachable_->Reach(&aiter, 0, internal::NumArcs(*lfst_, s), reach_input, compute_weight); + Weight lfinal = internal::Final(*lfst_, s); + bool reach_final = lfinal != Weight::Zero() && label_reachable_->ReachFinal(); if (reach_arc) { ssize_t begin = label_reachable_->ReachBegin(); ssize_t end = label_reachable_->ReachEnd(); - if (compute_prefix && end - begin == 1) { + if (compute_prefix && end - begin == 1 && !reach_final) { aiter.Seek(begin); SetLookAheadPrefix(aiter.Value()); compute_weight = false; @@ -612,9 +614,6 @@ bool LabelLookAheadMatcher<M, F, S>::LookAheadFst(const L &fst, StateId s) { SetLookAheadWeight(label_reachable_->ReachWeight()); } } - Weight lfinal = internal::Final(*lfst_, s); - bool reach_final = lfinal != Weight::Zero() && - label_reachable_->ReachFinal(); if (reach_final && compute_weight) SetLookAheadWeight(reach_arc ? Plus(LookAheadWeight(), lfinal) : lfinal); diff --git a/src/include/fst/matcher.h b/src/include/fst/matcher.h index a89325b..5ab3d26 100644 --- a/src/include/fst/matcher.h +++ b/src/include/fst/matcher.h @@ -83,8 +83,17 @@ namespace fst { // uint64 Properties(uint64 props) const; // }; +// +// MATCHER FLAGS (see also kLookAheadFlags in lookahead-matcher.h) +// +// Matcher prefers being used as the matching side in composition. +const uint32 kPreferMatch = 0x00000001; + +// Matcher needs to be used as the matching side in composition. +const uint32 kRequireMatch = 0x00000002; + // Flags used for basic matchers (see also lookahead.h). 
-const uint32 kMatcherFlags = 0x00000000; +const uint32 kMatcherFlags = kPreferMatch | kRequireMatch; // Matcher interface, templated on the Arc definition; used // for matcher specializations that are returned by the @@ -452,6 +461,12 @@ class RhoMatcher : public MatcherBase<typename M::Arc> { virtual uint64 Properties(uint64 props) const; + virtual uint32 Flags() const { + if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE) + return matcher_->Flags(); + return matcher_->Flags() | kRequireMatch; + } + private: virtual void SetState_(StateId s) { SetState(s); } virtual bool Find_(Label label) { return Find(label); } @@ -631,6 +646,15 @@ class SigmaMatcher : public MatcherBase<typename M::Arc> { virtual uint64 Properties(uint64 props) const; + virtual uint32 Flags() const { + if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE) + return matcher_->Flags(); + // kRequireMatch temporarily disabled until issues + // in //speech/gaudi/annotation/util/denorm are resolved. + // return matcher_->Flags() | kRequireMatch; + return matcher_->Flags(); + } + private: virtual void SetState_(StateId s) { SetState(s); } virtual bool Find_(Label label) { return Find(label); } @@ -722,11 +746,6 @@ class PhiMatcher : public MatcherBase<typename M::Arc> { match_type_ = MATCH_NONE; error_ = true; } - if (phi_label == 0) { - FSTERROR() << "PhiMatcher: 0 cannot be used as phi_label"; - phi_label_ = kNoLabel; - error_ = true; - } if (rewrite_mode == MATCHER_REWRITE_AUTO) rewrite_both_ = fst.Properties(kAcceptor, true); @@ -768,10 +787,15 @@ class PhiMatcher : public MatcherBase<typename M::Arc> { const Arc& Value() const { if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) { return matcher_->Value(); + } else if (phi_match_ == 0) { // Virtual epsilon loop + phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_); + if (match_type_ == MATCH_OUTPUT) + swap(phi_arc_.ilabel, phi_arc_.olabel); + return phi_arc_; } else { phi_arc_ = matcher_->Value(); phi_arc_.weight = 
Times(phi_weight_, phi_arc_.weight); - if (phi_match_ != kNoLabel) { + if (phi_match_ != kNoLabel) { // Phi loop match if (rewrite_both_) { if (phi_arc_.ilabel == phi_label_) phi_arc_.ilabel = phi_match_; @@ -793,6 +817,12 @@ class PhiMatcher : public MatcherBase<typename M::Arc> { virtual uint64 Properties(uint64 props) const; + virtual uint32 Flags() const { + if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE) + return matcher_->Flags(); + return matcher_->Flags() | kRequireMatch; + } + private: virtual void SetState_(StateId s) { SetState(s); } virtual bool Find_(Label label) { return Find(label); } @@ -818,19 +848,33 @@ private: template <class M> inline bool PhiMatcher<M>::Find(Label match_label) { - if (match_label == phi_label_ && phi_label_ != kNoLabel) { - FSTERROR() << "PhiMatcher::Find: bad label (phi)"; + if (match_label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) { + FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_; error_ = true; return false; } matcher_->SetState(state_); phi_match_ = kNoLabel; phi_weight_ = Weight::One(); + if (phi_label_ == 0) { // When 'phi_label_ == 0', + if (match_label == kNoLabel) // there are no more true epsilon arcs, + return false; + if (match_label == 0) { // but virtual eps loop need to be returned + if (!matcher_->Find(kNoLabel)) { + return matcher_->Find(0); + } else { + phi_match_ = 0; + return true; + } + } + } if (!has_phi_ || match_label == 0 || match_label == kNoLabel) return matcher_->Find(match_label); StateId state = state_; while (!matcher_->Find(match_label)) { - if (!matcher_->Find(phi_label_)) + // Look for phi transition (if phi_label_ == 0, we need to look + // for -1 to avoid getting the virtual self-loop) + if (!matcher_->Find(phi_label_ == 0 ? 
-1 : phi_label_)) return false; if (phi_loop_ && matcher_->Value().nextstate == state) { phi_match_ = match_label; @@ -856,6 +900,10 @@ uint64 PhiMatcher<M>::Properties(uint64 inprops) const { if (match_type_ == MATCH_NONE) { return outprops; } else if (match_type_ == MATCH_INPUT) { + if (phi_label_ == 0) { + outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons; + outprops |= kNoEpsilons | kNoIEpsilons; + } if (rewrite_both_) { return outprops & ~(kODeterministic | kNonODeterministic | kString | kILabelSorted | kNotILabelSorted | @@ -866,6 +914,10 @@ uint64 PhiMatcher<M>::Properties(uint64 inprops) const { kOLabelSorted | kNotOLabelSorted); } } else if (match_type_ == MATCH_OUTPUT) { + if (phi_label_ == 0) { + outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons; + outprops |= kNoEpsilons | kNoOEpsilons; + } if (rewrite_both_) { return outprops & ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted | kNotILabelSorted | diff --git a/src/include/fst/mutable-fst.h b/src/include/fst/mutable-fst.h index 9afcab3..09eb237 100644 --- a/src/include/fst/mutable-fst.h +++ b/src/include/fst/mutable-fst.h @@ -128,7 +128,7 @@ class MutableFst : public ExpandedFst<A> { } return Read(strm, FstReadOptions(filename)); } else { - return Read(std::cin, FstReadOptions("standard input")); + return Read(cin, FstReadOptions("standard input")); } } else { // Converts to 'convert_type' if not mutable. 
Fst<A> *ifst = Fst<A>::Read(filename); diff --git a/src/include/fst/queue.h b/src/include/fst/queue.h index 707dffc..e31f087 100644 --- a/src/include/fst/queue.h +++ b/src/include/fst/queue.h @@ -451,7 +451,7 @@ class SccQueue : public QueueBase<S> { while ((front_ <= back_) && (((*queue_)[front_] && (*queue_)[front_]->Empty()) || (((*queue_)[front_] == 0) && - ((front_ > trivial_queue_.size()) + ((front_ >= trivial_queue_.size()) || (trivial_queue_[front_] == kNoStateId))))) ++front_; if ((*queue_)[front_]) @@ -493,7 +493,7 @@ class SccQueue : public QueueBase<S> { else if ((*queue_)[front_]) return (*queue_)[front_]->Empty(); else - return (front_ > trivial_queue_.size()) + return (front_ >= trivial_queue_.size()) || (trivial_queue_[front_] == kNoStateId); } diff --git a/src/include/fst/register.h b/src/include/fst/register.h index 55651cd..ea3f4d8 100644 --- a/src/include/fst/register.h +++ b/src/include/fst/register.h @@ -28,6 +28,7 @@ #include <fst/compat.h> #include <iostream> #include <fstream> +#include <sstream> #include <fst/util.h> #include <fst/generic-register.h> diff --git a/src/include/fst/rmepsilon.h b/src/include/fst/rmepsilon.h index ee9753e..32e64de 100644 --- a/src/include/fst/rmepsilon.h +++ b/src/include/fst/rmepsilon.h @@ -110,7 +110,6 @@ class RmEpsilonState { class ElementKey { public: size_t operator()(const Element& e) const { - return static_cast<size_t>(e.nextstate); return static_cast<size_t>(e.nextstate + e.ilabel * kPrime0 + e.olabel * kPrime1); diff --git a/src/include/fst/script/compile-impl.h b/src/include/fst/script/compile-impl.h index 4aab15b..1743452 100644 --- a/src/include/fst/script/compile-impl.h +++ b/src/include/fst/script/compile-impl.h @@ -31,6 +31,7 @@ using std::vector; #include <iostream> #include <fstream> +#include <sstream> #include <fst/fst.h> #include <fst/util.h> #include <fst/vector-fst.h> diff --git a/src/include/fst/script/draw-impl.h b/src/include/fst/script/draw-impl.h index e346649..893e258 100644 --- 
a/src/include/fst/script/draw-impl.h +++ b/src/include/fst/script/draw-impl.h @@ -139,9 +139,9 @@ template <class A> class FstDrawer { EscapeChars(symbol, &nsymbol); PrintString(nsymbol); } else { - ostringstream sid; - sid << id; - PrintString(sid.str()); + string idstr; + Int64ToStr(id, &idstr); + PrintString(idstr); } } diff --git a/src/include/fst/script/draw.h b/src/include/fst/script/draw.h index 1611ad1..2b66373 100644 --- a/src/include/fst/script/draw.h +++ b/src/include/fst/script/draw.h @@ -22,6 +22,7 @@ #include <fst/script/draw-impl.h> #include <iostream> #include <fstream> +#include <sstream> namespace fst { namespace script { diff --git a/src/include/fst/script/fst-class.h b/src/include/fst/script/fst-class.h index 3eacab4..a820c1c 100644 --- a/src/include/fst/script/fst-class.h +++ b/src/include/fst/script/fst-class.h @@ -24,6 +24,7 @@ #include <fst/vector-fst.h> #include <iostream> #include <fstream> +#include <sstream> // Classes to support "boxing" all existing types of FST arcs in a single // FstClass which hides the arc types. 
This allows clients to load @@ -52,6 +53,7 @@ class FstClassBase { virtual const SymbolTable *InputSymbols() const = 0; virtual const SymbolTable *OutputSymbols() const = 0; virtual void Write(const string& fname) const = 0; + virtual void Write(ostream &ostr, const FstWriteOptions &opts) const = 0; virtual uint64 Properties(uint64 mask, bool test) const = 0; virtual ~FstClassBase() { } }; @@ -114,12 +116,18 @@ class FstClassImpl : public FstClassImplBase { impl_->Write(fname); } + virtual void Write(ostream &ostr, const FstWriteOptions &opts) const { + impl_->Write(ostr, opts); + } + virtual uint64 Properties(uint64 mask, bool test) const { return impl_->Properties(mask, test); } virtual ~FstClassImpl() { delete impl_; } + Fst<Arc> *GetImpl() const { return impl_; } + Fst<Arc> *GetImpl() { return impl_; } virtual FstClassImpl *Copy() { @@ -154,13 +162,25 @@ class FstClass : public FstClassBase { } } + FstClass() : impl_(NULL) { + } + template<class Arc> - explicit FstClass(Fst<Arc> *fst) : impl_(new FstClassImpl<Arc>(fst)) { } + explicit FstClass(Fst<Arc> *fst) : impl_(new FstClassImpl<Arc>(fst)) { + } explicit FstClass(const FstClass &other) : impl_(other.impl_->Copy()) { } + FstClass &operator=(const FstClass &other) { + delete impl_; + impl_ = other.impl_->Copy(); + return *this; + } + static FstClass *Read(const string &fname); + static FstClass *Read(istream &istr, const string &source); + virtual const string &ArcType() const { return impl_->ArcType(); } @@ -185,6 +205,10 @@ class FstClass : public FstClassBase { impl_->Write(fname); } + virtual void Write(ostream &ostr, const FstWriteOptions &opts) const { + impl_->Write(ostr, opts); + } + virtual uint64 Properties(uint64 mask, bool test) const { return impl_->Properties(mask, test); } @@ -214,6 +238,8 @@ class FstClass : public FstClassBase { << "particular arc type."; return 0; } + + protected: explicit FstClass(FstClassImplBase *impl) : impl_(impl) { } @@ -233,7 +259,12 @@ class FstClass : public 
FstClassBase { } } + FstClassImplBase *GetImpl() const { return impl_; } + FstClassImplBase *GetImpl() { return impl_; } + +// friend ostream &operator<<(ostream&, const FstClass&); + private: FstClassImplBase *impl_; }; @@ -269,6 +300,14 @@ class MutableFstClass : public FstClass { } } + virtual void Write(const string &fname) const { + GetImpl()->Write(fname); + } + + virtual void Write(ostream &ostr, const FstWriteOptions &opts) const { + GetImpl()->Write(ostr, opts); + } + static MutableFstClass *Read(const string &fname, bool convert = false); virtual void SetInputSymbols(SymbolTable *is) { @@ -338,6 +377,4 @@ class VectorFstClass : public MutableFstClass { } // namespace script } // namespace fst - - #endif // FST_SCRIPT_FST_CLASS_H_ diff --git a/src/include/fst/script/text-io.h b/src/include/fst/script/text-io.h index 95cc182..d97a007 100644 --- a/src/include/fst/script/text-io.h +++ b/src/include/fst/script/text-io.h @@ -32,6 +32,7 @@ using std::vector; #include <iostream> #include <fstream> +#include <sstream> #include <fst/script/weight-class.h> namespace fst { diff --git a/src/include/fst/script/weight-class.h b/src/include/fst/script/weight-class.h index 5a4890f..228216d 100644 --- a/src/include/fst/script/weight-class.h +++ b/src/include/fst/script/weight-class.h @@ -56,9 +56,9 @@ struct WeightClassImpl : public WeightImplBase { } virtual string to_string() const { - ostringstream s; - s << weight; - return s.str(); + string str; + WeightToStr(weight, &str); + return str; } virtual bool operator == (const WeightImplBase &other) const { diff --git a/src/include/fst/shortest-distance.h b/src/include/fst/shortest-distance.h index 5d38409..9320c4c 100644 --- a/src/include/fst/shortest-distance.h +++ b/src/include/fst/shortest-distance.h @@ -178,7 +178,7 @@ void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance( !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); - if (!arc_filter_(arc) || arc.weight == Weight::Zero()) + if 
(!arc_filter_(arc)) continue; while (distance_->size() <= arc.nextstate) { distance_->push_back(Weight::Zero()); diff --git a/src/include/fst/signed-log-weight.h b/src/include/fst/signed-log-weight.h index da96479..61adefb 100644 --- a/src/include/fst/signed-log-weight.h +++ b/src/include/fst/signed-log-weight.h @@ -113,9 +113,9 @@ inline SignedLogWeightTpl<T> Plus(const SignedLogWeightTpl<T> &w1, bool s2 = w2.Value1().Value() > 0.0; T f1 = w1.Value2().Value(); T f2 = w2.Value2().Value(); - if (f1 == FloatLimits<T>::kPosInfinity) + if (f1 == FloatLimits<T>::PosInfinity()) return w2; - else if (f2 == FloatLimits<T>::kPosInfinity) + else if (f2 == FloatLimits<T>::PosInfinity()) return w1; else if (f1 == f2) { if (s1 == s2) @@ -173,12 +173,12 @@ inline SignedLogWeightTpl<T> Divide(const SignedLogWeightTpl<T> &w1, bool s2 = w2.Value1().Value() > 0.0; T f1 = w1.Value2().Value(); T f2 = w2.Value2().Value(); - if (f2 == FloatLimits<T>::kPosInfinity) + if (f2 == FloatLimits<T>::PosInfinity()) return SignedLogWeightTpl<T>(TropicalWeight(1.0), - FloatLimits<T>::kNumberBad); - else if (f1 == FloatLimits<T>::kPosInfinity) + FloatLimits<T>::NumberBad()); + else if (f1 == FloatLimits<T>::PosInfinity()) return SignedLogWeightTpl<T>(TropicalWeight(1.0), - FloatLimits<T>::kPosInfinity); + FloatLimits<T>::PosInfinity()); else if (s1 == s2) return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 - f2)); else diff --git a/src/include/fst/slist.h b/src/include/fst/slist.h index 9f94027..d061ebe 100644 --- a/src/include/fst/slist.h +++ b/src/include/fst/slist.h @@ -22,7 +22,7 @@ #include <fst/config.h> -#if !defined(__ANDROID__) && defined(HAVE___GNU_CXX__SLIST_INT_) +#if !defined(__ANDROID__) && HAVE___GNU_CXX__SLIST_INT_ #include <slist> diff --git a/src/include/fst/state-map.h b/src/include/fst/state-map.h index ace4a3c..454db24 100644 --- a/src/include/fst/state-map.h +++ b/src/include/fst/state-map.h @@ -295,6 +295,10 @@ class StateMapFstImpl : public CacheImpl<B> { SetArcs(s); } + 
const Fst<A> &GetFst() const { + return *fst_; + } + private: void Init() { SetType("statemap"); @@ -364,10 +368,10 @@ class StateMapFst : public ImplToFst< StateMapFstImpl<A, B, C> > { GetImpl()->InitArcIterator(s, data); } - private: - // Makes visible to friends. + protected: Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } + private: void operator=(const StateMapFst<A, B, C> &fst); // disallow }; diff --git a/src/include/fst/string.h b/src/include/fst/string.h index 3099b87..d51182e 100644 --- a/src/include/fst/string.h +++ b/src/include/fst/string.h @@ -24,6 +24,7 @@ #define FST_LIB_STRING_H_ #include <fst/compact-fst.h> +#include <fst/icu.h> #include <fst/mutable-fst.h> DECLARE_string(fst_field_separator); @@ -48,7 +49,7 @@ class StringCompiler { // Compile string 's' into FST 'fst'. template <class F> - bool operator()(const string &s, F *fst) { + bool operator()(const string &s, F *fst) const { vector<Label> labels; if (!ConvertStringToLabels(s, &labels)) return false; diff --git a/src/include/fst/symbol-table.h b/src/include/fst/symbol-table.h index 93ebe76..6eb6c2d 100644 --- a/src/include/fst/symbol-table.h +++ b/src/include/fst/symbol-table.h @@ -33,6 +33,7 @@ using std::vector; #include <fst/compat.h> #include <iostream> #include <fstream> +#include <sstream> #include <map> @@ -56,6 +57,13 @@ struct SymbolTableReadOptions { string source; }; +struct SymbolTableTextOptions { + SymbolTableTextOptions(); + + bool allow_negative; + string fst_field_separator; +}; + class SymbolTableImpl { public: SymbolTableImpl(const string &name) @@ -88,9 +96,9 @@ class SymbolTableImpl { return (key == -1) ? 
AddSymbol(symbol, available_key_++) : key; } - static SymbolTableImpl* ReadText(istream &strm, - const string &name, - bool allow_negative = false); + static SymbolTableImpl* ReadText( + istream &strm, const string &name, + const SymbolTableTextOptions &opts = SymbolTableTextOptions()); static SymbolTableImpl* Read(istream &strm, const SymbolTableReadOptions& opts); @@ -149,13 +157,11 @@ class SymbolTableImpl { } string CheckSum() const { - MutexLock check_sum_lock(&check_sum_mutex_); MaybeRecomputeCheckSum(); return check_sum_string_; } string LabeledCheckSum() const { - MutexLock check_sum_lock(&check_sum_mutex_); MaybeRecomputeCheckSum(); return labeled_check_sum_string_; } @@ -171,6 +177,8 @@ class SymbolTableImpl { private: // Recomputes the checksums (both of them) if we've had changes since the last // computation (i.e., if check_sum_finalized_ is false). + // Takes ~2.5 microseconds (dbg) or ~230 nanoseconds (opt) on a 2.67GHz Xeon + // if the checksum is up-to-date (requiring no recomputation). void MaybeRecomputeCheckSum() const; struct StrCmp { @@ -188,8 +196,6 @@ class SymbolTableImpl { mutable RefCounter ref_count_; mutable bool check_sum_finalized_; - mutable CheckSummer check_sum_; - mutable CheckSummer labeled_check_sum_; mutable string check_sum_string_; mutable string labeled_check_sum_string_; mutable Mutex check_sum_mutex_; @@ -212,6 +218,9 @@ class SymbolTable { public: static const int64 kNoSymbol = -1; + // Construct symbol table with an unspecified name. + SymbolTable() : impl_(new SymbolTableImpl("<unspecified>")) {} + // Construct symbol table with a unique name. SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {} @@ -226,14 +235,21 @@ class SymbolTable { if (!impl_->DecrRefCount()) delete impl_; } + // Copys the implemenation from one symbol table to another. 
+ void operator=(const SymbolTable &st) { + if (impl_ != st.impl_) { + st.impl_->IncrRefCount(); + if (!impl_->DecrRefCount()) delete impl_; + impl_ = st.impl_; + } + } + // Read an ascii representation of the symbol table from an istream. Pass a // name to give the resulting SymbolTable. - static SymbolTable* ReadText(istream &strm, - const string& name, - bool allow_negative = false) { - SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm, - name, - allow_negative); + static SymbolTable* ReadText( + istream &strm, const string& name, + const SymbolTableTextOptions &opts = SymbolTableTextOptions()) { + SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm, name, opts); if (!impl) return 0; else @@ -242,13 +258,13 @@ class SymbolTable { // read an ascii representation of the symbol table static SymbolTable* ReadText(const string& filename, - bool allow_negative = false) { + const SymbolTableTextOptions &opts = SymbolTableTextOptions()) { ifstream strm(filename.c_str(), ifstream::in); if (!strm) { LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename; return 0; } - return ReadText(strm, filename, allow_negative); + return ReadText(strm, filename, opts); } @@ -341,7 +357,9 @@ class SymbolTable { } // Dump an ascii text representation of the symbol table via a stream - virtual bool WriteText(ostream &strm) const; + virtual bool WriteText( + ostream &strm, + const SymbolTableTextOptions &opts = SymbolTableTextOptions()) const; // Dump an ascii text representation of the symbol table bool WriteText(const string& filename) const { @@ -404,8 +422,6 @@ class SymbolTable { private: SymbolTableImpl* impl_; - - void operator=(const SymbolTable &table); // disallow }; @@ -502,6 +518,20 @@ SymbolTable *RelabelSymbolTable(const SymbolTable *table, return new_table; } +// Symbol Table Serialization +inline void SymbolTableToString(const SymbolTable *table, string *result) { + ostringstream ostrm; + table->Write(ostrm); + *result = ostrm.str(); +} + +inline 
SymbolTable *StringToSymbolTable(const string &s) { + istringstream istrm(s); + return SymbolTable::Read(istrm, SymbolTableReadOptions()); +} + + + } // namespace fst #endif // FST_LIB_SYMBOL_TABLE_H__ diff --git a/src/include/fst/test-properties.h b/src/include/fst/test-properties.h index db1ddcc..12bcba7 100644 --- a/src/include/fst/test-properties.h +++ b/src/include/fst/test-properties.h @@ -125,13 +125,14 @@ uint64 ComputeProperties(const Fst<Arc> &fst, uint64 mask, uint64 *known, siter.Next()) { StateId s = siter.Value(); - Arc prev_arc(kNoLabel, kNoLabel, Weight::One(), 0); + Arc prev_arc; // Create these only if we need to if (mask & (kIDeterministic | kNonIDeterministic)) ilabels = new unordered_set<Label>; if (mask & (kODeterministic | kNonODeterministic)) olabels = new unordered_set<Label>; + bool first_arc = true; for (ArcIterator< Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) { @@ -161,13 +162,15 @@ uint64 ComputeProperties(const Fst<Arc> &fst, uint64 mask, uint64 *known, comp_props |= kOEpsilons; comp_props &= ~kNoOEpsilons; } - if (prev_arc.ilabel != kNoLabel && arc.ilabel < prev_arc.ilabel) { - comp_props |= kNotILabelSorted; - comp_props &= ~kILabelSorted; - } - if (prev_arc.olabel != kNoLabel && arc.olabel < prev_arc.olabel) { - comp_props |= kNotOLabelSorted; - comp_props &= ~kOLabelSorted; + if (!first_arc) { + if (arc.ilabel < prev_arc.ilabel) { + comp_props |= kNotILabelSorted; + comp_props &= ~kILabelSorted; + } + if (arc.olabel < prev_arc.olabel) { + comp_props |= kNotOLabelSorted; + comp_props &= ~kOLabelSorted; + } } if (arc.weight != Weight::One() && arc.weight != Weight::Zero()) { comp_props |= kWeighted; @@ -182,6 +185,7 @@ uint64 ComputeProperties(const Fst<Arc> &fst, uint64 mask, uint64 *known, comp_props &= ~kString; } prev_arc = arc; + first_arc = false; if (ilabels) ilabels->insert(arc.ilabel); if (olabels) diff --git a/src/include/fst/util.h b/src/include/fst/util.h index 87231e1..a325beb 100644 --- 
a/src/include/fst/util.h +++ b/src/include/fst/util.h @@ -41,6 +41,7 @@ using std::vector; #include <iostream> #include <fstream> +#include <sstream> // // UTILITY FOR ERROR HANDLING @@ -264,7 +265,7 @@ void WeightToStr(Weight w, string *s) { ostringstream strm; strm.precision(9); strm << w; - *s += strm.str(); + s->append(strm.str().data(), strm.str().size()); } // Utilities for reading/writing label pairs @@ -312,7 +313,7 @@ bool ReadLabelPairs(const string& filename, template <typename Label> bool WriteLabelPairs(const string& filename, const vector<pair<Label, Label> >& pairs) { - ostream *strm = &std::cout; + ostream *strm = &cout; if (!filename.empty()) { strm = new ofstream(filename.c_str()); if (!*strm) { @@ -329,7 +330,7 @@ bool WriteLabelPairs(const string& filename, << (filename.empty() ? "standard output" : filename); return false; } - if (strm != &std::cout) + if (strm != &cout) delete strm; return true; } diff --git a/src/include/fst/vector-fst.h b/src/include/fst/vector-fst.h index f6d8a6d..8b80876 100644 --- a/src/include/fst/vector-fst.h +++ b/src/include/fst/vector-fst.h @@ -273,9 +273,10 @@ class VectorFstImpl : public VectorFstBaseImpl< VectorState<A> > { SetProperties(DeleteArcsProperties(Properties())); } - private: // Properties always true of this Fst class static const uint64 kStaticProperties = kExpanded | kMutable; + + private: // Current file format version static const int kFileVersion = 2; // Minimum file format version supported @@ -542,7 +543,10 @@ bool VectorFst<A>::WriteFst(const F &fst, ostream &strm, hdr.SetNumStates(CountStates(fst)); update_header = false; } - FstImpl<A>::WriteFstHeader(fst, strm, opts, kFileVersion, "vector", &hdr); + uint64 properties = fst.Properties(kCopyProperties, false) | + VectorFstImpl<A>::kStaticProperties; + FstImpl<A>::WriteFstHeader(fst, strm, opts, kFileVersion, "vector", + properties, &hdr); StateId num_states = 0; for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { typename 
A::StateId s = siter.Value(); @@ -566,7 +570,7 @@ bool VectorFst<A>::WriteFst(const F &fst, ostream &strm, if (update_header) { hdr.SetNumStates(num_states); return FstImpl<A>::UpdateFstHeader(fst, strm, opts, kFileVersion, "vector", - &hdr, start_offset); + properties, &hdr, start_offset); } else { if (num_states != hdr.NumStates()) { LOG(ERROR) << "Inconsistent number of states observed during write"; diff --git a/src/include/fst/visit.h b/src/include/fst/visit.h index 31a00a8..a02d86a 100644 --- a/src/include/fst/visit.h +++ b/src/include/fst/visit.h @@ -166,7 +166,8 @@ void Visit(const Fst<Arc> &fst, V *visitor, Q *queue, ArcFilter filter) { // Finds next tree root for (root = root == start ? 0 : root + 1; root < nstates && state_status[root] != kWhiteState; - ++root); + ++root) { + } // Check for a state beyond the largest known state if (!expanded && root == nstates) { diff --git a/src/include/fst/weight.h b/src/include/fst/weight.h index 72f5a22..7eb4bb1 100644 --- a/src/include/fst/weight.h +++ b/src/include/fst/weight.h @@ -28,16 +28,18 @@ // A left semiring distributes on the left; a right semiring is // similarly defined. // -// A Weight class is required to be (at least) a left or right semiring. +// A Weight class must have binary functions =Plus= and =Times= and +// static member functions =Zero()= and =One()= and these must form +// (at least) a left or right semiring. // // In addition, the following should be defined for a Weight: // Member: predicate on set membership. -// NoWeight: returns an element that is not a member, should only be -// used to signal an error. -// >>: reads weight. -// <<: prints weight. -// Read(istream &strm): reads from an input stream. -// Write(ostream &strm): writes to an output stream. +// NoWeight: static member function that returns an element that is +// not a set member; used to signal an error. +// >>: reads textual representation of a weight. +// <<: prints textual representation of a weight. 
+// Read(istream &strm): reads binary representation of a weight. +// Write(ostream &strm): writes binary representation of a weight. // Hash: maps weight to size_t. // ApproxEqual: approximate equality (for inexact weights) // Quantize: quantizes wrt delta (for inexact weights) @@ -46,11 +48,9 @@ // and Times(a, b') == c // --> a' = Divide(c, b, DIVIDE_RIGHT) if a right semiring, a'.Member() // and Times(a', b) == c -// --> b' = Divide(c, a) -// = Divide(c, a, DIVIDE_ANY) -// = Divide(c, a, DIVIDE_LEFT) -// = Divide(c, a, DIVIDE_RIGHT) if a commutative semiring, -// b'.Member() and Times(a, b') == Times(b', a) == c +// --> b' = Divide(c, a) = Divide(c, a, DIVIDE_ANY) = +// Divide(c, a, DIVIDE_LEFT) = Divide(c, a, DIVIDE_RIGHT) if a +// commutative semiring, b'.Member() and Times(a, b') = Times(b', a) = c // ReverseWeight: the type of the corresponding reverse weight. // Typically the same type as Weight for a (both left and right) semiring. // For the left string semiring, it is the right string semiring. @@ -66,7 +66,7 @@ // RightSemiring: indicates weights form a right semiring. // Commutative: for all a,b: Times(a,b) == Times(b,a) // Idempotent: for all a: Plus(a, a) == a. -// Path Property: for all a, b: Plus(a, b) == a or Plus(a, b) == b. +// Path: for all a, b: Plus(a, b) == a or Plus(a, b) == b. #ifndef FST_LIB_WEIGHT_H__ |