aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/minimize.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/minimize.h')
-rw-r--r--src/include/fst/minimize.h584
1 files changed, 584 insertions, 0 deletions
diff --git a/src/include/fst/minimize.h b/src/include/fst/minimize.h
new file mode 100644
index 0000000..3fbe3ba
--- /dev/null
+++ b/src/include/fst/minimize.h
@@ -0,0 +1,584 @@
+// minimize.h
+// minimize.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: johans@google.com (Johan Schalkwyk)
+//
+// \file Functions and classes to minimize a finite state acceptor
+//
+
+#ifndef FST_LIB_MINIMIZE_H__
+#define FST_LIB_MINIMIZE_H__
+
+#include <cmath>
+
+#include <algorithm>
+#include <map>
+#include <queue>
+#include <vector>
+using std::vector;
+
+#include <fst/arcsort.h>
+#include <fst/connect.h>
+#include <fst/dfs-visit.h>
+#include <fst/encode.h>
+#include <fst/factor-weight.h>
+#include <fst/fst.h>
+#include <fst/mutable-fst.h>
+#include <fst/partition.h>
+#include <fst/push.h>
+#include <fst/queue.h>
+#include <fst/reverse.h>
+#include <fst/state-map.h>
+
+
+namespace fst {
+
+// comparator for creating partition based on sorting on
+// - states
+// - final weight
+// - out degree,
+// - (input label, output label, weight, destination_block)
+template <class A>
+class StateComparator {
+ public:
+ typedef typename A::StateId StateId;
+ typedef typename A::Weight Weight;
+
+ static const uint32 kCompareFinal = 0x00000001;
+ static const uint32 kCompareOutDegree = 0x00000002;
+ static const uint32 kCompareArcs = 0x00000004;
+ static const uint32 kCompareAll = 0x00000007;
+
+ StateComparator(const Fst<A>& fst,
+ const Partition<typename A::StateId>& partition,
+ uint32 flags = kCompareAll)
+ : fst_(fst), partition_(partition), flags_(flags) {}
+
+ // compare state x with state y based on sort criteria
+ bool operator()(const StateId x, const StateId y) const {
+ // check for final state equivalence
+ if (flags_ & kCompareFinal) {
+ const size_t xfinal = fst_.Final(x).Hash();
+ const size_t yfinal = fst_.Final(y).Hash();
+ if (xfinal < yfinal) return true;
+ else if (xfinal > yfinal) return false;
+ }
+
+ if (flags_ & kCompareOutDegree) {
+ // check for # arcs
+ if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true;
+ if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false;
+
+ if (flags_ & kCompareArcs) {
+ // # arcs are equal, check for arc match
+ for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y);
+ !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
+ const A& arc1 = aiter1.Value();
+ const A& arc2 = aiter2.Value();
+ if (arc1.ilabel < arc2.ilabel) return true;
+ if (arc1.ilabel > arc2.ilabel) return false;
+
+ if (partition_.class_id(arc1.nextstate) <
+ partition_.class_id(arc2.nextstate)) return true;
+ if (partition_.class_id(arc1.nextstate) >
+ partition_.class_id(arc2.nextstate)) return false;
+ }
+ }
+ }
+
+ return false;
+ }
+
+ private:
+ const Fst<A>& fst_;
+ const Partition<typename A::StateId>& partition_;
+ const uint32 flags_;
+};
+
+template <class A> const uint32 StateComparator<A>::kCompareFinal;
+template <class A> const uint32 StateComparator<A>::kCompareOutDegree;
+template <class A> const uint32 StateComparator<A>::kCompareArcs;
+template <class A> const uint32 StateComparator<A>::kCompareAll;
+
+
+// Computes equivalence classes for cyclic Fsts. For cyclic minimization
+// we use the classic HopCroft minimization algorithm, which is of
+//
+// O(E)log(N),
+//
+// where E is the number of edges in the machine and N is number of states.
+//
+// The following paper describes the original algorithm
+// An N Log N algorithm for minimizing states in a finite automaton
+// by John HopCroft, January 1971
+//
+template <class A, class Queue>
+class CyclicMinimizer {
+ public:
+ typedef typename A::Label Label;
+ typedef typename A::StateId StateId;
+ typedef typename A::StateId ClassId;
+ typedef typename A::Weight Weight;
+ typedef ReverseArc<A> RevA;
+
+ CyclicMinimizer(const ExpandedFst<A>& fst) {
+ Initialize(fst);
+ Compute(fst);
+ }
+
+ ~CyclicMinimizer() {
+ delete aiter_queue_;
+ }
+
+ const Partition<StateId>& partition() const {
+ return P_;
+ }
+
+ // helper classes
+ private:
+ typedef ArcIterator<Fst<RevA> > ArcIter;
+ class ArcIterCompare {
+ public:
+ ArcIterCompare(const Partition<StateId>& partition)
+ : partition_(partition) {}
+
+ ArcIterCompare(const ArcIterCompare& comp)
+ : partition_(comp.partition_) {}
+
+ // compare two iterators based on there input labels, and proto state
+ // (partition class Ids)
+ bool operator()(const ArcIter* x, const ArcIter* y) const {
+ const RevA& xarc = x->Value();
+ const RevA& yarc = y->Value();
+ return (xarc.ilabel > yarc.ilabel);
+ }
+
+ private:
+ const Partition<StateId>& partition_;
+ };
+
+ typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare>
+ ArcIterQueue;
+
+ // helper methods
+ private:
+ // prepartitions the space into equivalence classes with
+ // same final weight
+ // same # arcs per state
+ // same outgoing arcs
+ void PrePartition(const Fst<A>& fst) {
+ VLOG(5) << "PrePartition";
+
+ typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
+ StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal);
+ EquivalenceMap equiv_map(comp);
+
+ StateIterator<Fst<A> > siter(fst);
+ StateId class_id = P_.AddClass();
+ P_.Add(siter.Value(), class_id);
+ equiv_map[siter.Value()] = class_id;
+ L_.Enqueue(class_id);
+ for (siter.Next(); !siter.Done(); siter.Next()) {
+ StateId s = siter.Value();
+ typename EquivalenceMap::const_iterator it = equiv_map.find(s);
+ if (it == equiv_map.end()) {
+ class_id = P_.AddClass();
+ P_.Add(s, class_id);
+ equiv_map[s] = class_id;
+ L_.Enqueue(class_id);
+ } else {
+ P_.Add(s, it->second);
+ equiv_map[s] = it->second;
+ }
+ }
+
+ VLOG(5) << "Initial Partition: " << P_.num_classes();
+ }
+
+ // - Create inverse transition Tr_ = rev(fst)
+ // - loop over states in fst and split on final, creating two blocks
+ // in the partition corresponding to final, non-final
+ void Initialize(const Fst<A>& fst) {
+ // construct Tr
+ Reverse(fst, &Tr_);
+ ILabelCompare<RevA> ilabel_comp;
+ ArcSort(&Tr_, ilabel_comp);
+
+ // initial split (F, S - F)
+ P_.Initialize(Tr_.NumStates() - 1);
+
+ // prep partition
+ PrePartition(fst);
+
+ // allocate arc iterator queue
+ ArcIterCompare comp(P_);
+ aiter_queue_ = new ArcIterQueue(comp);
+ }
+
+ // partition all classes with destination C
+ void Split(ClassId C) {
+ // Prep priority queue. Open arc iterator for each state in C, and
+ // insert into priority queue.
+ for (PartitionIterator<StateId> siter(P_, C);
+ !siter.Done(); siter.Next()) {
+ StateId s = siter.Value();
+ if (Tr_.NumArcs(s + 1))
+ aiter_queue_->push(new ArcIterator<Fst<RevA> >(Tr_, s + 1));
+ }
+
+ // Now pop arc iterator from queue, split entering equivalence class
+ // re-insert updated iterator into queue.
+ Label prev_label = -1;
+ while (!aiter_queue_->empty()) {
+ ArcIterator<Fst<RevA> >* aiter = aiter_queue_->top();
+ aiter_queue_->pop();
+ if (aiter->Done()) {
+ delete aiter;
+ continue;
+ }
+
+ const RevA& arc = aiter->Value();
+ StateId from_state = aiter->Value().nextstate - 1;
+ Label from_label = arc.ilabel;
+ if (prev_label != from_label)
+ P_.FinalizeSplit(&L_);
+
+ StateId from_class = P_.class_id(from_state);
+ if (P_.class_size(from_class) > 1)
+ P_.SplitOn(from_state);
+
+ prev_label = from_label;
+ aiter->Next();
+ if (aiter->Done())
+ delete aiter;
+ else
+ aiter_queue_->push(aiter);
+ }
+ P_.FinalizeSplit(&L_);
+ }
+
+ // Main loop for hopcroft minimization.
+ void Compute(const Fst<A>& fst) {
+ // process active classes (FIFO, or FILO)
+ while (!L_.Empty()) {
+ ClassId C = L_.Head();
+ L_.Dequeue();
+
+ // split on C, all labels in C
+ Split(C);
+ }
+ }
+
+ // helper data
+ private:
+ // Partioning of states into equivalence classes
+ Partition<StateId> P_;
+
+ // L = set of active classes to be processed in partition P
+ Queue L_;
+
+ // reverse transition function
+ VectorFst<RevA> Tr_;
+
+ // Priority queue of open arc iterators for all states in the 'splitter'
+ // equivalence class
+ ArcIterQueue* aiter_queue_;
+};
+
+
+// Computes equivalence classes for acyclic Fsts. The implementation details
+// for this algorithms is documented by the following paper.
+//
+// Minimization of acyclic deterministic automata in linear time
+// Dominque Revuz
+//
+// Complexity O(|E|)
+//
+template <class A>
+class AcyclicMinimizer {
+ public:
+ typedef typename A::Label Label;
+ typedef typename A::StateId StateId;
+ typedef typename A::StateId ClassId;
+ typedef typename A::Weight Weight;
+
+ AcyclicMinimizer(const ExpandedFst<A>& fst) {
+ Initialize(fst);
+ Refine(fst);
+ }
+
+ const Partition<StateId>& partition() {
+ return partition_;
+ }
+
+ // helper classes
+ private:
+ // DFS visitor to compute the height (distance) to final state.
+ class HeightVisitor {
+ public:
+ HeightVisitor() : max_height_(0), num_states_(0) { }
+
+ // invoked before dfs visit
+ void InitVisit(const Fst<A>& fst) {}
+
+ // invoked when state is discovered (2nd arg is DFS tree root)
+ bool InitState(StateId s, StateId root) {
+ // extend height array and initialize height (distance) to 0
+ for (size_t i = height_.size(); i <= s; ++i)
+ height_.push_back(-1);
+
+ if (s >= num_states_) num_states_ = s + 1;
+ return true;
+ }
+
+ // invoked when tree arc examined (to undiscoverted state)
+ bool TreeArc(StateId s, const A& arc) {
+ return true;
+ }
+
+ // invoked when back arc examined (to unfinished state)
+ bool BackArc(StateId s, const A& arc) {
+ return true;
+ }
+
+ // invoked when forward or cross arc examined (to finished state)
+ bool ForwardOrCrossArc(StateId s, const A& arc) {
+ if (height_[arc.nextstate] + 1 > height_[s])
+ height_[s] = height_[arc.nextstate] + 1;
+ return true;
+ }
+
+ // invoked when state finished (parent is kNoStateId for tree root)
+ void FinishState(StateId s, StateId parent, const A* parent_arc) {
+ if (height_[s] == -1) height_[s] = 0;
+ StateId h = height_[s] + 1;
+ if (parent >= 0) {
+ if (h > height_[parent]) height_[parent] = h;
+ if (h > max_height_) max_height_ = h;
+ }
+ }
+
+ // invoked after DFS visit
+ void FinishVisit() {}
+
+ size_t max_height() const { return max_height_; }
+
+ const vector<StateId>& height() const { return height_; }
+
+ const size_t num_states() const { return num_states_; }
+
+ private:
+ vector<StateId> height_;
+ size_t max_height_;
+ size_t num_states_;
+ };
+
+ // helper methods
+ private:
+ // cluster states according to height (distance to final state)
+ void Initialize(const Fst<A>& fst) {
+ // compute height (distance to final state)
+ HeightVisitor hvisitor;
+ DfsVisit(fst, &hvisitor);
+
+ // create initial partition based on height
+ partition_.Initialize(hvisitor.num_states());
+ partition_.AllocateClasses(hvisitor.max_height() + 1);
+ const vector<StateId>& hstates = hvisitor.height();
+ for (size_t s = 0; s < hstates.size(); ++s)
+ partition_.Add(s, hstates[s]);
+ }
+
+ // refine states based on arc sort (out degree, arc equivalence)
+ void Refine(const Fst<A>& fst) {
+ typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
+ StateComparator<A> comp(fst, partition_);
+
+ // start with tail (height = 0)
+ size_t height = partition_.num_classes();
+ for (size_t h = 0; h < height; ++h) {
+ EquivalenceMap equiv_classes(comp);
+
+ // sort states within equivalence class
+ PartitionIterator<StateId> siter(partition_, h);
+ equiv_classes[siter.Value()] = h;
+ for (siter.Next(); !siter.Done(); siter.Next()) {
+ const StateId s = siter.Value();
+ typename EquivalenceMap::const_iterator it = equiv_classes.find(s);
+ if (it == equiv_classes.end())
+ equiv_classes[s] = partition_.AddClass();
+ else
+ equiv_classes[s] = it->second;
+ }
+
+ // create refined partition
+ for (siter.Reset(); !siter.Done();) {
+ const StateId s = siter.Value();
+ const StateId old_class = partition_.class_id(s);
+ const StateId new_class = equiv_classes[s];
+
+ // a move operation can invalidate the iterator, so
+ // we first update the iterator to the next element
+ // before we move the current element out of the list
+ siter.Next();
+ if (old_class != new_class)
+ partition_.Move(s, new_class);
+ }
+ }
+ }
+
+ private:
+ Partition<StateId> partition_;
+};
+
+
+// Given a partition and a mutable fst, merge states of Fst inplace
+// (i.e. destructively). Merging works by taking the first state in
+// a class of the partition to be the representative state for the class.
+// Each arc is then reconnected to this state. All states in the class
+// are merged by adding there arcs to the representative state.
+template <class A>
+void MergeStates(
+ const Partition<typename A::StateId>& partition, MutableFst<A>* fst) {
+ typedef typename A::StateId StateId;
+
+ vector<StateId> state_map(partition.num_classes());
+ for (size_t i = 0; i < partition.num_classes(); ++i) {
+ PartitionIterator<StateId> siter(partition, i);
+ state_map[i] = siter.Value(); // first state in partition;
+ }
+
+ // relabel destination states
+ for (size_t c = 0; c < partition.num_classes(); ++c) {
+ for (PartitionIterator<StateId> siter(partition, c);
+ !siter.Done(); siter.Next()) {
+ StateId s = siter.Value();
+ for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
+ !aiter.Done(); aiter.Next()) {
+ A arc = aiter.Value();
+ arc.nextstate = state_map[partition.class_id(arc.nextstate)];
+
+ if (s == state_map[c]) // first state just set destination
+ aiter.SetValue(arc);
+ else
+ fst->AddArc(state_map[c], arc);
+ }
+ }
+ }
+ fst->SetStart(state_map[partition.class_id(fst->Start())]);
+
+ Connect(fst);
+}
+
+template <class A>
+void AcceptorMinimize(MutableFst<A>* fst) {
+ typedef typename A::StateId StateId;
+ if (!(fst->Properties(kAcceptor | kUnweighted, true))) {
+ FSTERROR() << "FST is not an unweighted acceptor";
+ fst->SetProperties(kError, kError);
+ return;
+ }
+
+ // connect fst before minimization, handles disconnected states
+ Connect(fst);
+ if (fst->NumStates() == 0) return;
+
+ if (fst->Properties(kAcyclic, true)) {
+ // Acyclic minimization (revuz)
+ VLOG(2) << "Acyclic Minimization";
+ ArcSort(fst, ILabelCompare<A>());
+ AcyclicMinimizer<A> minimizer(*fst);
+ MergeStates(minimizer.partition(), fst);
+
+ } else {
+ // Cyclic minimizaton (hopcroft)
+ VLOG(2) << "Cyclic Minimization";
+ CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst);
+ MergeStates(minimizer.partition(), fst);
+ }
+
+ // Merge in appropriate semiring
+ ArcUniqueMapper<A> mapper(*fst);
+ StateMap(fst, mapper);
+}
+
+
+// In place minimization of deterministic weighted automata and transducers.
+// For transducers, then the 'sfst' argument is not null, the algorithm
+// produces a compact factorization of the minimal transducer.
+//
+// In the acyclic case, we use an algorithm from Dominique Revuz that
+// is linear in the number of arcs (edges) in the machine.
+// Complexity = O(E)
+//
+// In the cyclic case, we use the classical hopcroft minimization.
+// Complexity = O(|E|log(|N|)
+//
+template <class A>
+void Minimize(MutableFst<A>* fst,
+ MutableFst<A>* sfst = 0,
+ float delta = kDelta) {
+ uint64 props = fst->Properties(kAcceptor | kIDeterministic|
+ kWeighted | kUnweighted, true);
+ if (!(props & kIDeterministic)) {
+ FSTERROR() << "FST is not deterministic";
+ fst->SetProperties(kError, kError);
+ return;
+ }
+
+ if (!(props & kAcceptor)) { // weighted transducer
+ VectorFst< GallicArc<A, STRING_LEFT> > gfst;
+ ArcMap(*fst, &gfst, ToGallicMapper<A, STRING_LEFT>());
+ fst->DeleteStates();
+ gfst.SetProperties(kAcceptor, kAcceptor);
+ Push(&gfst, REWEIGHT_TO_INITIAL, delta);
+ ArcMap(&gfst, QuantizeMapper< GallicArc<A, STRING_LEFT> >(delta));
+ EncodeMapper< GallicArc<A, STRING_LEFT> >
+ encoder(kEncodeLabels | kEncodeWeights, ENCODE);
+ Encode(&gfst, &encoder);
+ AcceptorMinimize(&gfst);
+ Decode(&gfst, encoder);
+
+ if (sfst == 0) {
+ FactorWeightFst< GallicArc<A, STRING_LEFT>,
+ GallicFactor<typename A::Label,
+ typename A::Weight, STRING_LEFT> > fwfst(gfst);
+ SymbolTable *osyms = fst->OutputSymbols() ?
+ fst->OutputSymbols()->Copy() : 0;
+ ArcMap(fwfst, fst, FromGallicMapper<A, STRING_LEFT>());
+ fst->SetOutputSymbols(osyms);
+ delete osyms;
+ } else {
+ sfst->SetOutputSymbols(fst->OutputSymbols());
+ GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst);
+ ArcMap(gfst, fst, &mapper);
+ fst->SetOutputSymbols(sfst->InputSymbols());
+ }
+ } else if (props & kWeighted) { // weighted acceptor
+ Push(fst, REWEIGHT_TO_INITIAL, delta);
+ ArcMap(fst, QuantizeMapper<A>(delta));
+ EncodeMapper<A> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
+ Encode(fst, &encoder);
+ AcceptorMinimize(fst);
+ Decode(fst, encoder);
+ } else { // unweighted acceptor
+ AcceptorMinimize(fst);
+ }
+}
+
+} // namespace fst
+
+#endif // FST_LIB_MINIMIZE_H__