// factor-weight.h // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Author: allauzen@cs.nyu.edu (Cyril Allauzen) // // \file // Classes to factor weights in an FST. #ifndef FST_LIB_FACTOR_WEIGHT_H__ #define FST_LIB_FACTOR_WEIGHT_H__ #include #include #include #include "fst/lib/cache.h" #include "fst/lib/test-properties.h" namespace fst { struct FactorWeightOptions : CacheOptions { float delta; bool final_only; // only factor final weights when true FactorWeightOptions(const CacheOptions &opts, float d, bool of) : CacheOptions(opts), delta(d), final_only(of) {} explicit FactorWeightOptions(float d, bool of = false) : delta(d), final_only(of) {} FactorWeightOptions(bool of = false) : delta(kDelta), final_only(of) {} }; // A factor iterator takes as argument a weight w and returns a // sequence of pairs of weights (xi,yi) such that the sum of the // products xi times yi is equal to w. If w is fully factored, // the iterator should return nothing. // // template // class FactorIterator { // public: // FactorIterator(W w); // bool Done() const; // void Next(); // pair Value() const; // void Reset(); // } // Factor trivially. template class IdentityFactor { public: IdentityFactor(const W &w) {} bool Done() const { return true; } void Next() {} pair Value() const { return make_pair(W::One(), W::One()); } // unused void Reset() {} }; // Factor a StringWeight w as 'ab' where 'a' is a label. template class StringFactor { public: StringFactor(const StringWeight &w) : weight_(w), done_(w.Size() <= 1) {} bool Done() const { return done_; } void Next() { done_ = true; } pair< StringWeight, StringWeight > Value() const { StringWeightIterator iter(weight_); StringWeight w1(iter.Value()); StringWeight w2; for (iter.Next(); !iter.Done(); iter.Next()) w2.PushBack(iter.Value()); return make_pair(w1, w2); } void Reset() { done_ = weight_.Size() <= 1; } private: StringWeight weight_; bool done_; }; // Factor a GallicWeight using StringFactor. template class GallicFactor { public: GallicFactor(const GallicWeight &w) : weight_(w), done_(w.Value1().Size() <= 1) {} bool Done() const { return done_; } void Next() { done_ = true; } pair< GallicWeight, GallicWeight > Value() const { StringFactor iter(weight_.Value1()); GallicWeight w1(iter.Value().first, weight_.Value2()); GallicWeight w2(iter.Value().second, W::One()); return make_pair(w1, w2); } void Reset() { done_ = weight_.Value1().Size() <= 1; } private: GallicWeight weight_; bool done_; }; // Implementation class for FactorWeight template class FactorWeightFstImpl : public CacheImpl { public: using FstImpl::SetType; using FstImpl::SetProperties; using FstImpl::Properties; using FstImpl::SetInputSymbols; using FstImpl::SetOutputSymbols; using CacheBaseImpl< CacheState >::HasStart; using CacheBaseImpl< CacheState >::HasFinal; using CacheBaseImpl< CacheState >::HasArcs; typedef A Arc; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef F FactorIterator; struct Element { Element() {} Element(StateId s, Weight w) : state(s), weight(w) {} StateId state; // Input state Id Weight weight; // Residual weight }; FactorWeightFstImpl(const Fst &fst, const FactorWeightOptions &opts) : CacheImpl(opts), fst_(fst.Copy()), delta_(opts.delta), final_only_(opts.final_only) { SetType("factor-weight"); uint64 props = fst.Properties(kFstProperties, false); SetProperties(FactorWeightProperties(props), kCopyProperties); SetInputSymbols(fst.InputSymbols()); SetOutputSymbols(fst.OutputSymbols()); } ~FactorWeightFstImpl() { delete fst_; } StateId Start() { if (!HasStart()) { StateId s = fst_->Start(); if (s == kNoStateId) return kNoStateId; StateId start = FindState(Element(fst_->Start(), Weight::One())); this->SetStart(start); } return CacheImpl::Start(); } Weight Final(StateId s) { if (!HasFinal(s)) { const Element &e = elements_[s]; // TODO: fix so cast is unnecessary Weight w = e.state == kNoStateId ? e.weight : (Weight) Times(e.weight, fst_->Final(e.state)); FactorIterator f(w); if (w != Weight::Zero() && f.Done()) this->SetFinal(s, w); else this->SetFinal(s, Weight::Zero()); } return CacheImpl::Final(s); } size_t NumArcs(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumArcs(s); } size_t NumInputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumInputEpsilons(s); } size_t NumOutputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl::NumOutputEpsilons(s); } void InitArcIterator(StateId s, ArcIteratorData *data) { if (!HasArcs(s)) Expand(s); CacheImpl::InitArcIterator(s, data); } // Find state corresponding to an element. Create new state // if element not found. StateId FindState(const Element &e) { if (final_only_ && e.weight == Weight::One()) { while (unfactored_.size() <= (unsigned int)e.state) unfactored_.push_back(kNoStateId); if (unfactored_[e.state] == kNoStateId) { unfactored_[e.state] = elements_.size(); elements_.push_back(e); } return unfactored_[e.state]; } else { typename ElementMap::iterator eit = element_map_.find(e); if (eit != element_map_.end()) { return (*eit).second; } else { StateId s = elements_.size(); elements_.push_back(e); element_map_.insert(pair(e, s)); return s; } } } // Computes the outgoing transitions from a state, creating new destination // states as needed. void Expand(StateId s) { Element e = elements_[s]; if (e.state != kNoStateId) { for (ArcIterator< Fst > ait(*fst_, e.state); !ait.Done(); ait.Next()) { const A &arc = ait.Value(); Weight w = Times(e.weight, arc.weight); FactorIterator fit(w); if (final_only_ || fit.Done()) { StateId d = FindState(Element(arc.nextstate, Weight::One())); this->AddArc(s, Arc(arc.ilabel, arc.olabel, w, d)); } else { for (; !fit.Done(); fit.Next()) { const pair &p = fit.Value(); StateId d = FindState(Element(arc.nextstate, p.second.Quantize(delta_))); this->AddArc(s, Arc(arc.ilabel, arc.olabel, p.first, d)); } } } } if ((e.state == kNoStateId) || (fst_->Final(e.state) != Weight::Zero())) { Weight w = e.state == kNoStateId ? e.weight : Times(e.weight, fst_->Final(e.state)); for (FactorIterator fit(w); !fit.Done(); fit.Next()) { const pair &p = fit.Value(); StateId d = FindState(Element(kNoStateId, p.second.Quantize(delta_))); this->AddArc(s, Arc(0, 0, p.first, d)); } } this->SetArcs(s); } private: // Equality function for Elements, assume weights have been quantized. class ElementEqual { public: bool operator()(const Element &x, const Element &y) const { return x.state == y.state && x.weight == y.weight; } }; // Hash function for Elements to Fst states. class ElementKey { public: size_t operator()(const Element &x) const { return static_cast(x.state * kPrime + x.weight.Hash()); } private: static const int kPrime = 7853; }; typedef std::unordered_map ElementMap; const Fst *fst_; float delta_; bool final_only_; vector elements_; // mapping Fst state to Elements ElementMap element_map_; // mapping Elements to Fst state // mapping between old/new 'StateId' for states that do not need to // be factored when 'final_only_' is true vector unfactored_; DISALLOW_EVIL_CONSTRUCTORS(FactorWeightFstImpl); }; // FactorWeightFst takes as template parameter a FactorIterator as // defined above. The result of weight factoring is a transducer // equivalent to the input whose path weights have been factored // according to the FactorIterator. States and transitions will be // added as necessary. The algorithm is a generalization to arbitrary // weights of the second step of the input epsilon-normalization // algorithm due to Mohri, "Generic epsilon-removal and input // epsilon-normalization algorithms for weighted transducers", // International Journal of Computer Science 13(1): 129-143 (2002). template class FactorWeightFst : public Fst { public: friend class ArcIterator< FactorWeightFst >; friend class CacheStateIterator< FactorWeightFst >; friend class CacheArcIterator< FactorWeightFst >; typedef A Arc; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef CacheState State; FactorWeightFst(const Fst &fst) : impl_(new FactorWeightFstImpl(fst, FactorWeightOptions())) {} FactorWeightFst(const Fst &fst, const FactorWeightOptions &opts) : impl_(new FactorWeightFstImpl(fst, opts)) {} FactorWeightFst(const FactorWeightFst &fst) : Fst(fst), impl_(fst.impl_) { impl_->IncrRefCount(); } virtual ~FactorWeightFst() { if (!impl_->DecrRefCount()) delete impl_; } virtual StateId Start() const { return impl_->Start(); } virtual Weight Final(StateId s) const { return impl_->Final(s); } virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } virtual size_t NumInputEpsilons(StateId s) const { return impl_->NumInputEpsilons(s); } virtual size_t NumOutputEpsilons(StateId s) const { return impl_->NumOutputEpsilons(s); } virtual uint64 Properties(uint64 mask, bool test) const { if (test) { uint64 known, test = TestProperties(*this, mask, &known); impl_->SetProperties(test, known); return test & mask; } else { return impl_->Properties(mask); } } virtual const string& Type() const { return impl_->Type(); } virtual FactorWeightFst *Copy() const { return new FactorWeightFst(*this); } virtual const SymbolTable* InputSymbols() const { return impl_->InputSymbols(); } virtual const SymbolTable* OutputSymbols() const { return impl_->OutputSymbols(); } virtual inline void InitStateIterator(StateIteratorData *data) const; virtual void InitArcIterator(StateId s, ArcIteratorData *data) const { impl_->InitArcIterator(s, data); } private: FactorWeightFstImpl *Impl() { return impl_; } FactorWeightFstImpl *impl_; void operator=(const FactorWeightFst &fst); // Disallow }; // Specialization for FactorWeightFst. template class StateIterator< FactorWeightFst > : public CacheStateIterator< FactorWeightFst > { public: explicit StateIterator(const FactorWeightFst &fst) : CacheStateIterator< FactorWeightFst >(fst) {} }; // Specialization for FactorWeightFst. template class ArcIterator< FactorWeightFst > : public CacheArcIterator< FactorWeightFst > { public: typedef typename A::StateId StateId; ArcIterator(const FactorWeightFst &fst, StateId s) : CacheArcIterator< FactorWeightFst >(fst, s) { if (!fst.impl_->HasArcs(s)) fst.impl_->Expand(s); } private: DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); }; template inline void FactorWeightFst::InitStateIterator(StateIteratorData *data) const { data->base = new StateIterator< FactorWeightFst >(*this); } } // namespace fst #endif // FST_LIB_FACTOR_WEIGHT_H__