aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/randgen.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/randgen.h')
-rw-r--r--src/include/fst/randgen.h712
1 files changed, 712 insertions, 0 deletions
diff --git a/src/include/fst/randgen.h b/src/include/fst/randgen.h
new file mode 100644
index 0000000..82ddffa
--- /dev/null
+++ b/src/include/fst/randgen.h
@@ -0,0 +1,712 @@
+// randgen.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Classes and functions to generate random paths through an FST.
+
+#ifndef FST_LIB_RANDGEN_H__
+#define FST_LIB_RANDGEN_H__
+
+#include <cmath>
+#include <cstdlib>
+#include <ctime>
+#include <map>
+
+#include <fst/accumulator.h>
+#include <fst/cache.h>
+#include <fst/dfs-visit.h>
+#include <fst/mutable-fst.h>
+
+namespace fst {
+
+//
+// ARC SELECTORS - these function objects are used to select a random
+// transition to take from an FST's state. They should return a number
+// N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th
+// transition is selected. If N == NumArcs(), then the final weight at
+// that state is selected (i.e., the 'super-final' transition is selected).
+// It can be assumed these will not be called unless either there
+// are transitions leaving the state and/or the state is final.
+//
+
+// Randomly selects a transition using the uniform distribution.
+template <class A>
+struct UniformArcSelector {
+ typedef typename A::StateId StateId;
+ typedef typename A::Weight Weight;
+
+ UniformArcSelector(int seed = time(0)) { srand(seed); }
+
+ size_t operator()(const Fst<A> &fst, StateId s) const {
+ double r = rand()/(RAND_MAX + 1.0);
+ size_t n = fst.NumArcs(s);
+ if (fst.Final(s) != Weight::Zero())
+ ++n;
+ return static_cast<size_t>(r * n);
+ }
+};
+
+
+// Randomly selects a transition w.r.t. the weights treated as negative
+// log probabilities after normalizing for the total weight leaving
+// the state. Weight::zero transitions are disregarded.
+// Assumes Weight::Value() accesses the floating point
+// representation of the weight.
+template <class A>
+class LogProbArcSelector {
+ public:
+ typedef typename A::StateId StateId;
+ typedef typename A::Weight Weight;
+
+ LogProbArcSelector(int seed = time(0)) { srand(seed); }
+
+ size_t operator()(const Fst<A> &fst, StateId s) const {
+ // Find total weight leaving state
+ double sum = 0.0;
+ for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
+ aiter.Next()) {
+ const A &arc = aiter.Value();
+ sum += exp(-to_log_weight_(arc.weight).Value());
+ }
+ sum += exp(-to_log_weight_(fst.Final(s)).Value());
+
+ double r = rand()/(RAND_MAX + 1.0);
+ double p = 0.0;
+ int n = 0;
+ for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
+ aiter.Next(), ++n) {
+ const A &arc = aiter.Value();
+ p += exp(-to_log_weight_(arc.weight).Value());
+ if (p > r * sum) return n;
+ }
+ return n;
+ }
+
+ private:
+ WeightConvert<Weight, Log64Weight> to_log_weight_;
+};
+
+// Convenience definitions
+typedef LogProbArcSelector<StdArc> StdArcSelector;
+typedef LogProbArcSelector<LogArc> LogArcSelector;
+
+
+// Same as LogProbArcSelector but use CacheLogAccumulator to cache
+// the cummulative weight computations.
+template <class A>
+class FastLogProbArcSelector : public LogProbArcSelector<A> {
+ public:
+ typedef typename A::StateId StateId;
+ typedef typename A::Weight Weight;
+ using LogProbArcSelector<A>::operator();
+
+ FastLogProbArcSelector(int seed = time(0))
+ : LogProbArcSelector<A>(seed),
+ seed_(seed) {}
+
+ size_t operator()(const Fst<A> &fst, StateId s,
+ CacheLogAccumulator<A> *accumulator) const {
+ accumulator->SetState(s);
+ ArcIterator< Fst<A> > aiter(fst, s);
+ // Find total weight leaving state
+ double sum = to_log_weight_(accumulator->Sum(fst.Final(s), &aiter, 0,
+ fst.NumArcs(s))).Value();
+ double r = -log(rand()/(RAND_MAX + 1.0));
+ return accumulator->LowerBound(r + sum, &aiter);
+ }
+
+ int Seed() const { return seed_; }
+ private:
+ int seed_;
+ WeightConvert<Weight, Log64Weight> to_log_weight_;
+};
+
+// Random path state info maintained by RandGenFst and passed to samplers.
+template <typename A>
+struct RandState {
+ typedef typename A::StateId StateId;
+
+ StateId state_id; // current input FST state
+ size_t nsamples; // # of samples to be sampled at this state
+ size_t length; // length of path to this random state
+ size_t select; // previous sample arc selection
+ const RandState<A> *parent; // previous random state on this path
+
+ RandState(StateId s, size_t n, size_t l, size_t k, const RandState<A> *p)
+ : state_id(s), nsamples(n), length(l), select(k), parent(p) {}
+
+ RandState()
+ : state_id(kNoStateId), nsamples(0), length(0), select(0), parent(0) {}
+};
+
+// This class, given an arc selector, samples, with raplacement,
+// multiple random transitions from an FST's state. This is a generic
+// version with a straight-forward use of the arc selector.
+// Specializations may be defined for arc selectors for greater
+// efficiency or special behavior.
+template <class A, class S>
+class ArcSampler {
+ public:
+ typedef typename A::StateId StateId;
+ typedef typename A::Weight Weight;
+
+ // The 'max_length' may be interpreted (including ignored) by a
+ // sampler as it chooses. This generic version interprets this literally.
+ ArcSampler(const Fst<A> &fst, const S &arc_selector,
+ int max_length = INT_MAX)
+ : fst_(fst),
+ arc_selector_(arc_selector),
+ max_length_(max_length) {}
+
+ // Allow updating Fst argument; pass only if changed.
+ ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0)
+ : fst_(fst ? *fst : sampler.fst_),
+ arc_selector_(sampler.arc_selector_),
+ max_length_(sampler.max_length_) {
+ Reset();
+ }
+
+ // Samples 'rstate.nsamples' from state 'state_id'. The 'rstate.length' is
+ // the length of the path to 'rstate'. Returns true if samples were
+ // collected. No samples may be collected if either there are no (including
+ // 'super-final') transitions leaving that state or if the
+ // 'max_length' has been deemed reached. Use the iterator members to
+ // read the samples. The samples will be in their original order.
+ bool Sample(const RandState<A> &rstate) {
+ sample_map_.clear();
+ if ((fst_.NumArcs(rstate.state_id) == 0 &&
+ fst_.Final(rstate.state_id) == Weight::Zero()) ||
+ rstate.length == max_length_) {
+ Reset();
+ return false;
+ }
+
+ for (size_t i = 0; i < rstate.nsamples; ++i)
+ ++sample_map_[arc_selector_(fst_, rstate.state_id)];
+ Reset();
+ return true;
+ }
+
+ // More samples?
+ bool Done() const { return sample_iter_ == sample_map_.end(); }
+
+ // Gets the next sample.
+ void Next() { ++sample_iter_; }
+
+ // Returns a pair (N, K) where 0 <= N <= NumArcs(s) and 0 < K <= nsamples.
+ // If N < NumArcs(s), then the N-th transition is specified.
+ // If N == NumArcs(s), then the final weight at that state is
+ // specified (i.e., the 'super-final' transition is specified).
+ // For the specified transition, K repetitions have been sampled.
+ pair<size_t, size_t> Value() const { return *sample_iter_; }
+
+ void Reset() { sample_iter_ = sample_map_.begin(); }
+
+ bool Error() const { return false; }
+
+ private:
+ const Fst<A> &fst_;
+ const S &arc_selector_;
+ int max_length_;
+
+ // Stores (N, K) as described for Value().
+ map<size_t, size_t> sample_map_;
+ map<size_t, size_t>::const_iterator sample_iter_;
+
+ // disallow
+ ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s);
+};
+
+
+// Specialization for FastLogProbArcSelector.
+template <class A>
+class ArcSampler<A, FastLogProbArcSelector<A> > {
+ public:
+ typedef FastLogProbArcSelector<A> S;
+ typedef typename A::StateId StateId;
+ typedef typename A::Weight Weight;
+ typedef CacheLogAccumulator<A> C;
+
+ ArcSampler(const Fst<A> &fst, const S &arc_selector, int max_length = INT_MAX)
+ : fst_(fst),
+ arc_selector_(arc_selector),
+ max_length_(max_length),
+ accumulator_(new C()) {
+ accumulator_->Init(fst);
+ }
+
+ ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0)
+ : fst_(fst ? *fst : sampler.fst_),
+ arc_selector_(sampler.arc_selector_),
+ max_length_(sampler.max_length_) {
+ if (fst) {
+ accumulator_ = new C();
+ accumulator_->Init(*fst);
+ } else { // shallow copy
+ accumulator_ = new C(*sampler.accumulator_);
+ }
+ }
+
+ ~ArcSampler() {
+ delete accumulator_;
+ }
+
+ bool Sample(const RandState<A> &rstate) {
+ sample_map_.clear();
+ if ((fst_.NumArcs(rstate.state_id) == 0 &&
+ fst_.Final(rstate.state_id) == Weight::Zero()) ||
+ rstate.length == max_length_) {
+ Reset();
+ return false;
+ }
+
+ for (size_t i = 0; i < rstate.nsamples; ++i)
+ ++sample_map_[arc_selector_(fst_, rstate.state_id, accumulator_)];
+ Reset();
+ return true;
+ }
+
+ bool Done() const { return sample_iter_ == sample_map_.end(); }
+ void Next() { ++sample_iter_; }
+ pair<size_t, size_t> Value() const { return *sample_iter_; }
+ void Reset() { sample_iter_ = sample_map_.begin(); }
+
+ bool Error() const { return accumulator_->Error(); }
+
+ private:
+ const Fst<A> &fst_;
+ const S &arc_selector_;
+ int max_length_;
+
+ // Stores (N, K) as described for Value().
+ map<size_t, size_t> sample_map_;
+ map<size_t, size_t>::const_iterator sample_iter_;
+ C *accumulator_;
+
+ // disallow
+ ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s);
+};
+
+
+// Options for random path generation with RandGenFst. The template argument
+// is an arc sampler, typically class 'ArcSampler' above. Ownership of
+// the sampler is taken by RandGenFst.
+template <class S>
+struct RandGenFstOptions : public CacheOptions {
+ S *arc_sampler; // How to sample transitions at a state
+ size_t npath; // # of paths to generate
+ bool weighted; // Output tree weighted by path count; o.w.
+ // output unweighted DAG
+ bool remove_total_weight; // Remove total weight when output is weighted.
+
+ RandGenFstOptions(const CacheOptions &copts, S *samp,
+ size_t n = 1, bool w = true, bool rw = false)
+ : CacheOptions(copts),
+ arc_sampler(samp),
+ npath(n),
+ weighted(w),
+ remove_total_weight(rw) {}
+};
+
+
+// Implementation of RandGenFst.
+template <class A, class B, class S>
+class RandGenFstImpl : public CacheImpl<B> {
+ public:
+ using FstImpl<B>::SetType;
+ using FstImpl<B>::SetProperties;
+ using FstImpl<B>::SetInputSymbols;
+ using FstImpl<B>::SetOutputSymbols;
+
+ using CacheBaseImpl< CacheState<B> >::AddArc;
+ using CacheBaseImpl< CacheState<B> >::HasArcs;
+ using CacheBaseImpl< CacheState<B> >::HasFinal;
+ using CacheBaseImpl< CacheState<B> >::HasStart;
+ using CacheBaseImpl< CacheState<B> >::SetArcs;
+ using CacheBaseImpl< CacheState<B> >::SetFinal;
+ using CacheBaseImpl< CacheState<B> >::SetStart;
+
+ typedef B Arc;
+ typedef typename A::Label Label;
+ typedef typename A::Weight Weight;
+ typedef typename A::StateId StateId;
+
+ RandGenFstImpl(const Fst<A> &fst, const RandGenFstOptions<S> &opts)
+ : CacheImpl<B>(opts),
+ fst_(fst.Copy()),
+ arc_sampler_(opts.arc_sampler),
+ npath_(opts.npath),
+ weighted_(opts.weighted),
+ remove_total_weight_(opts.remove_total_weight),
+ superfinal_(kNoLabel) {
+ SetType("randgen");
+
+ uint64 props = fst.Properties(kFstProperties, false);
+ SetProperties(RandGenProperties(props, weighted_), kCopyProperties);
+
+ SetInputSymbols(fst.InputSymbols());
+ SetOutputSymbols(fst.OutputSymbols());
+ }
+
+ RandGenFstImpl(const RandGenFstImpl &impl)
+ : CacheImpl<B>(impl),
+ fst_(impl.fst_->Copy(true)),
+ arc_sampler_(new S(*impl.arc_sampler_, fst_)),
+ npath_(impl.npath_),
+ weighted_(impl.weighted_),
+ superfinal_(kNoLabel) {
+ SetType("randgen");
+ SetProperties(impl.Properties(), kCopyProperties);
+ SetInputSymbols(impl.InputSymbols());
+ SetOutputSymbols(impl.OutputSymbols());
+ }
+
+ ~RandGenFstImpl() {
+ for (int i = 0; i < state_table_.size(); ++i)
+ delete state_table_[i];
+ delete fst_;
+ delete arc_sampler_;
+ }
+
+ StateId Start() {
+ if (!HasStart()) {
+ StateId s = fst_->Start();
+ if (s == kNoStateId)
+ return kNoStateId;
+ StateId start = state_table_.size();
+ SetStart(start);
+ RandState<A> *rstate = new RandState<A>(s, npath_, 0, 0, 0);
+ state_table_.push_back(rstate);
+ }
+ return CacheImpl<B>::Start();
+ }
+
+ Weight Final(StateId s) {
+ if (!HasFinal(s)) {
+ Expand(s);
+ }
+ return CacheImpl<B>::Final(s);
+ }
+
+ size_t NumArcs(StateId s) {
+ if (!HasArcs(s)) {
+ Expand(s);
+ }
+ return CacheImpl<B>::NumArcs(s);
+ }
+
+ size_t NumInputEpsilons(StateId s) {
+ if (!HasArcs(s))
+ Expand(s);
+ return CacheImpl<B>::NumInputEpsilons(s);
+ }
+
+ size_t NumOutputEpsilons(StateId s) {
+ if (!HasArcs(s))
+ Expand(s);
+ return CacheImpl<B>::NumOutputEpsilons(s);
+ }
+
+ uint64 Properties() const { return Properties(kFstProperties); }
+
+ // Set error if found; return FST impl properties.
+ uint64 Properties(uint64 mask) const {
+ if ((mask & kError) &&
+ (fst_->Properties(kError, false) || arc_sampler_->Error())) {
+ SetProperties(kError, kError);
+ }
+ return FstImpl<Arc>::Properties(mask);
+ }
+
+ void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
+ if (!HasArcs(s))
+ Expand(s);
+ CacheImpl<B>::InitArcIterator(s, data);
+ }
+
+ // Computes the outgoing transitions from a state, creating new destination
+ // states as needed.
+ void Expand(StateId s) {
+ if (s == superfinal_) {
+ SetFinal(s, Weight::One());
+ SetArcs(s);
+ return;
+ }
+
+ SetFinal(s, Weight::Zero());
+ const RandState<A> &rstate = *state_table_[s];
+ arc_sampler_->Sample(rstate);
+ ArcIterator< Fst<A> > aiter(*fst_, rstate.state_id);
+ size_t narcs = fst_->NumArcs(rstate.state_id);
+ for (;!arc_sampler_->Done(); arc_sampler_->Next()) {
+ const pair<size_t, size_t> &sample_pair = arc_sampler_->Value();
+ size_t pos = sample_pair.first;
+ size_t count = sample_pair.second;
+ double prob = static_cast<double>(count)/rstate.nsamples;
+ if (pos < narcs) { // regular transition
+ aiter.Seek(sample_pair.first);
+ const A &aarc = aiter.Value();
+ Weight weight = weighted_ ? to_weight_(-log(prob)) : Weight::One();
+ B barc(aarc.ilabel, aarc.olabel, weight, state_table_.size());
+ AddArc(s, barc);
+ RandState<A> *nrstate =
+ new RandState<A>(aarc.nextstate, count, rstate.length + 1,
+ pos, &rstate);
+ state_table_.push_back(nrstate);
+ } else { // super-final transition
+ if (weighted_) {
+ Weight weight = remove_total_weight_ ?
+ to_weight_(-log(prob)) : to_weight_(-log(prob * npath_));
+ SetFinal(s, weight);
+ } else {
+ if (superfinal_ == kNoLabel) {
+ superfinal_ = state_table_.size();
+ RandState<A> *nrstate = new RandState<A>(kNoStateId, 0, 0, 0, 0);
+ state_table_.push_back(nrstate);
+ }
+ for (size_t n = 0; n < count; ++n) {
+ B barc(0, 0, Weight::One(), superfinal_);
+ AddArc(s, barc);
+ }
+ }
+ }
+ }
+ SetArcs(s);
+ }
+
+ private:
+ Fst<A> *fst_;
+ S *arc_sampler_;
+ size_t npath_;
+ vector<RandState<A> *> state_table_;
+ bool weighted_;
+ bool remove_total_weight_;
+ StateId superfinal_;
+ WeightConvert<Log64Weight, Weight> to_weight_;
+
+ void operator=(const RandGenFstImpl<A, B, S> &); // disallow
+};
+
+
+// Fst class to randomly generate paths through an FST; details controlled
+// by RandGenOptionsFst. Output format is a tree weighted by the
+// path count.
+template <class A, class B, class S>
+class RandGenFst : public ImplToFst< RandGenFstImpl<A, B, S> > {
+ public:
+ friend class ArcIterator< RandGenFst<A, B, S> >;
+ friend class StateIterator< RandGenFst<A, B, S> >;
+ typedef B Arc;
+ typedef S Sampler;
+ typedef typename A::Label Label;
+ typedef typename A::Weight Weight;
+ typedef typename A::StateId StateId;
+ typedef CacheState<B> State;
+ typedef RandGenFstImpl<A, B, S> Impl;
+
+ RandGenFst(const Fst<A> &fst, const RandGenFstOptions<S> &opts)
+ : ImplToFst<Impl>(new Impl(fst, opts)) {}
+
+ // See Fst<>::Copy() for doc.
+ RandGenFst(const RandGenFst<A, B, S> &fst, bool safe = false)
+ : ImplToFst<Impl>(fst, safe) {}
+
+ // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc.
+ virtual RandGenFst<A, B, S> *Copy(bool safe = false) const {
+ return new RandGenFst<A, B, S>(*this, safe);
+ }
+
+ virtual inline void InitStateIterator(StateIteratorData<B> *data) const;
+
+ virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const {
+ GetImpl()->InitArcIterator(s, data);
+ }
+
+ private:
+ // Makes visible to friends.
+ Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
+
+ void operator=(const RandGenFst<A, B, S> &fst); // Disallow
+};
+
+
+
+// Specialization for RandGenFst.
+template <class A, class B, class S>
+class StateIterator< RandGenFst<A, B, S> >
+ : public CacheStateIterator< RandGenFst<A, B, S> > {
+ public:
+ explicit StateIterator(const RandGenFst<A, B, S> &fst)
+ : CacheStateIterator< RandGenFst<A, B, S> >(fst, fst.GetImpl()) {}
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(StateIterator);
+};
+
+
+// Specialization for RandGenFst.
+template <class A, class B, class S>
+class ArcIterator< RandGenFst<A, B, S> >
+ : public CacheArcIterator< RandGenFst<A, B, S> > {
+ public:
+ typedef typename A::StateId StateId;
+
+ ArcIterator(const RandGenFst<A, B, S> &fst, StateId s)
+ : CacheArcIterator< RandGenFst<A, B, S> >(fst.GetImpl(), s) {
+ if (!fst.GetImpl()->HasArcs(s))
+ fst.GetImpl()->Expand(s);
+ }
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(ArcIterator);
+};
+
+
+template <class A, class B, class S> inline
+void RandGenFst<A, B, S>::InitStateIterator(StateIteratorData<B> *data) const
+{
+ data->base = new StateIterator< RandGenFst<A, B, S> >(*this);
+}
+
+// Options for random path generation.
+template <class S>
+struct RandGenOptions {
+ const S &arc_selector; // How an arc is selected at a state
+ int max_length; // Maximum path length
+ size_t npath; // # of paths to generate
+ bool weighted; // Output is tree weighted by path count; o.w.
+ // output unweighted union of paths.
+ bool remove_total_weight; // Remove total weight when output is weighted.
+
+ RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1,
+ bool w = false, bool rw = false)
+ : arc_selector(sel),
+ max_length(len),
+ npath(n),
+ weighted(w),
+ remove_total_weight(rw) {}
+};
+
+
+template <class IArc, class OArc>
+class RandGenVisitor {
+ public:
+ typedef typename IArc::Weight Weight;
+ typedef typename IArc::StateId StateId;
+
+ RandGenVisitor(MutableFst<OArc> *ofst) : ofst_(ofst) {}
+
+ void InitVisit(const Fst<IArc> &ifst) {
+ ifst_ = &ifst;
+
+ ofst_->DeleteStates();
+ ofst_->SetInputSymbols(ifst.InputSymbols());
+ ofst_->SetOutputSymbols(ifst.OutputSymbols());
+ if (ifst.Properties(kError, false))
+ ofst_->SetProperties(kError, kError);
+ path_.clear();
+ }
+
+ bool InitState(StateId s, StateId root) { return true; }
+
+ bool TreeArc(StateId s, const IArc &arc) {
+ if (ifst_->Final(arc.nextstate) == Weight::Zero()) {
+ path_.push_back(arc);
+ } else {
+ OutputPath();
+ }
+ return true;
+ }
+
+ bool BackArc(StateId s, const IArc &arc) {
+ FSTERROR() << "RandGenVisitor: cyclic input";
+ ofst_->SetProperties(kError, kError);
+ return false;
+ }
+
+ bool ForwardOrCrossArc(StateId s, const IArc &arc) {
+ OutputPath();
+ return true;
+ }
+
+ void FinishState(StateId s, StateId p, const IArc *) {
+ if (p != kNoStateId && ifst_->Final(s) == Weight::Zero())
+ path_.pop_back();
+ }
+
+ void FinishVisit() {}
+
+ private:
+ void OutputPath() {
+ if (ofst_->Start() == kNoStateId) {
+ StateId start = ofst_->AddState();
+ ofst_->SetStart(start);
+ }
+
+ StateId src = ofst_->Start();
+ for (size_t i = 0; i < path_.size(); ++i) {
+ StateId dest = ofst_->AddState();
+ OArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest);
+ ofst_->AddArc(src, arc);
+ src = dest;
+ }
+ ofst_->SetFinal(src, Weight::One());
+ }
+
+ const Fst<IArc> *ifst_;
+ MutableFst<OArc> *ofst_;
+ vector<OArc> path_;
+
+ DISALLOW_COPY_AND_ASSIGN(RandGenVisitor);
+};
+
+
+// Randomly generate paths through an FST; details controlled by
+// RandGenOptions.
+template<class IArc, class OArc, class Selector>
+void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst,
+ const RandGenOptions<Selector> &opts) {
+ typedef ArcSampler<IArc, Selector> Sampler;
+ typedef RandGenFst<IArc, OArc, Sampler> RandFst;
+ typedef typename OArc::StateId StateId;
+ typedef typename OArc::Weight Weight;
+
+ Sampler* arc_sampler = new Sampler(ifst, opts.arc_selector, opts.max_length);
+ RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), arc_sampler,
+ opts.npath, opts.weighted,
+ opts.remove_total_weight);
+ RandFst rfst(ifst, fopts);
+ if (opts.weighted) {
+ *ofst = rfst;
+ } else {
+ RandGenVisitor<IArc, OArc> rand_visitor(ofst);
+ DfsVisit(rfst, &rand_visitor);
+ }
+}
+
+// Randomly generate a path through an FST with the uniform distribution
+// over the transitions.
+template<class IArc, class OArc>
+void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst) {
+ UniformArcSelector<IArc> uniform_selector;
+ RandGenOptions< UniformArcSelector<IArc> > opts(uniform_selector);
+ RandGen(ifst, ofst, opts);
+}
+
+} // namespace fst
+
+#endif // FST_LIB_RANDGEN_H__