diff options
Diffstat (limited to 'src/include/fst/script')
-rw-r--r-- | src/include/fst/script/convert.h | 2 | ||||
-rw-r--r-- | src/include/fst/script/disambiguate.h | 68 | ||||
-rw-r--r-- | src/include/fst/script/fst-class.h | 44 | ||||
-rw-r--r-- | src/include/fst/script/map.h | 54 | ||||
-rw-r--r-- | src/include/fst/script/shortest-distance.h | 8 | ||||
-rw-r--r-- | src/include/fst/script/weight-class.h | 7 |
6 files changed, 134 insertions, 49 deletions
diff --git a/src/include/fst/script/convert.h b/src/include/fst/script/convert.h index 2c70a70..4a3ce6b 100644 --- a/src/include/fst/script/convert.h +++ b/src/include/fst/script/convert.h @@ -34,7 +34,7 @@ void Convert(ConvertArgs *args) { const string &new_type = args->args.arg2; Fst<Arc> *result = Convert(fst, new_type); - args->retval = new FstClass(result); + args->retval = new FstClass(*result); delete result; } diff --git a/src/include/fst/script/disambiguate.h b/src/include/fst/script/disambiguate.h new file mode 100644 index 0000000..e42a9c2 --- /dev/null +++ b/src/include/fst/script/disambiguate.h @@ -0,0 +1,68 @@ + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Jake Ratkiewicz) + +#ifndef FST_SCRIPT_DISAMBIGUATE_H_ +#define FST_SCRIPT_DISAMBIGUATE_H_ + +#include <fst/disambiguate.h> +#include <fst/script/arg-packs.h> +#include <fst/script/fst-class.h> +#include <fst/script/weight-class.h> + +namespace fst { +namespace script { + +struct DisambiguateOptions { + float delta; + WeightClass weight_threshold; + int64 state_threshold; + int64 subsequential_label; + + explicit DisambiguateOptions(float d = fst::kDelta, + WeightClass w = + fst::script::WeightClass::Zero(), + int64 n = fst::kNoStateId, int64 l = 0) + : delta(d), weight_threshold(w), state_threshold(n), + subsequential_label(l) {} +}; + +typedef args::Package<const FstClass&, MutableFstClass*, + const DisambiguateOptions &> DisambiguateArgs; + +template<class Arc> +void Disambiguate(DisambiguateArgs *args) { + const Fst<Arc> &ifst = *(args->arg1.GetFst<Arc>()); + MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); + const DisambiguateOptions &opts = args->arg3; + + fst::DisambiguateOptions<Arc> detargs; + detargs.delta = opts.delta; + detargs.weight_threshold = + *(opts.weight_threshold.GetWeight<typename Arc::Weight>()); + detargs.state_threshold = opts.state_threshold; + detargs.subsequential_label = opts.subsequential_label; + + Disambiguate(ifst, ofst, detargs); +} + +void Disambiguate(const FstClass &ifst, MutableFstClass *ofst, + const DisambiguateOptions &opts = + fst::script::DisambiguateOptions()); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_DISAMBIGUATE_H_ diff --git a/src/include/fst/script/fst-class.h b/src/include/fst/script/fst-class.h index a820c1c..fe2cf53 100644 --- a/src/include/fst/script/fst-class.h +++ b/src/include/fst/script/fst-class.h @@ -52,8 +52,8 @@ class FstClassBase { virtual const string &WeightType() const = 0; virtual const SymbolTable *InputSymbols() const = 0; virtual const SymbolTable *OutputSymbols() const = 0; - virtual void Write(const string& fname) const = 0; - virtual void Write(ostream &ostr, const FstWriteOptions &opts) const = 0; + virtual bool Write(const string& fname) const = 0; + virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const = 0; virtual uint64 Properties(uint64 mask, bool test) const = 0; virtual ~FstClassBase() { } }; @@ -82,6 +82,8 @@ class FstClassImpl : public FstClassImplBase { bool should_own = false) : impl_(should_own ? impl : impl->Copy()) { } + explicit FstClassImpl(const Fst<Arc> &impl) : impl_(impl.Copy()) { } + virtual const string &ArcType() const { return Arc::Type(); } @@ -112,12 +114,12 @@ class FstClassImpl : public FstClassImplBase { static_cast<MutableFst<Arc> *>(impl_)->SetOutputSymbols(os); } - virtual void Write(const string &fname) const { - impl_->Write(fname); + virtual bool Write(const string &fname) const { + return impl_->Write(fname); } - virtual void Write(ostream &ostr, const FstWriteOptions &opts) const { - impl_->Write(ostr, opts); + virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const { + return impl_->Write(ostr, opts); } virtual uint64 Properties(uint64 mask, bool test) const { @@ -166,10 +168,10 @@ class FstClass : public FstClassBase { } template<class Arc> - explicit FstClass(Fst<Arc> *fst) : impl_(new FstClassImpl<Arc>(fst)) { + explicit FstClass(const Fst<Arc> &fst) : impl_(new FstClassImpl<Arc>(fst)) { } - explicit FstClass(const FstClass &other) : impl_(other.impl_->Copy()) { } + FstClass(const FstClass &other) : impl_(other.impl_->Copy()) { } FstClass &operator=(const FstClass &other) { delete impl_; @@ -201,12 +203,12 @@ class FstClass : public FstClassBase { return impl_->WeightType(); } - virtual void Write(const string &fname) const { - impl_->Write(fname); + virtual bool Write(const string &fname) const { + return impl_->Write(fname); } - virtual void Write(ostream &ostr, const FstWriteOptions &opts) const { - impl_->Write(ostr, opts); + virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const { + return impl_->Write(ostr, opts); } virtual uint64 Properties(uint64 mask, bool test) const { @@ -253,7 +255,7 @@ class FstClass : public FstClassBase { if (!u) { return 0; } else { - FstClassT *r = new FstClassT(u); + FstClassT *r = new FstClassT(*u); delete u; return r; } @@ -276,7 +278,7 @@ class FstClass : public FstClassBase { class MutableFstClass : public FstClass { public: template<class Arc> - explicit MutableFstClass(MutableFst<Arc> *fst) : + explicit MutableFstClass(const MutableFst<Arc> &fst) : FstClass(fst) { } template<class Arc> @@ -294,18 +296,18 @@ class MutableFstClass : public FstClass { if (!mfst) { return 0; } else { - MutableFstClass *retval = new MutableFstClass(mfst); + MutableFstClass *retval = new MutableFstClass(*mfst); delete mfst; return retval; } } - virtual void Write(const string &fname) const { - GetImpl()->Write(fname); + virtual bool Write(const string &fname) const { + return GetImpl()->Write(fname); } - virtual void Write(ostream &ostr, const FstWriteOptions &opts) const { - GetImpl()->Write(ostr, opts); + virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const { + return GetImpl()->Write(ostr, opts); } static MutableFstClass *Read(const string &fname, bool convert = false); @@ -344,7 +346,7 @@ class VectorFstClass : public MutableFstClass { explicit VectorFstClass(const string &arc_type); template<class Arc> - explicit VectorFstClass(VectorFst<Arc> *fst) : + explicit VectorFstClass(const VectorFst<Arc> &fst) : MutableFstClass(fst) { } template<class Arc> @@ -354,7 +356,7 @@ class VectorFstClass : public MutableFstClass { if (!vfst) { return 0; } else { - VectorFstClass *retval = new VectorFstClass(vfst); + VectorFstClass *retval = new VectorFstClass(*vfst); delete vfst; return retval; } diff --git a/src/include/fst/script/map.h b/src/include/fst/script/map.h index 2332074..3caaa9f 100644 --- a/src/include/fst/script/map.h +++ b/src/include/fst/script/map.h @@ -59,46 +59,54 @@ void Map(MapArgs *args) { float delta = args->args.arg3; typename Arc::Weight w = *(args->args.arg4.GetWeight<typename Arc::Weight>()); + Fst<Arc> *fst = NULL; + Fst<LogArc> *lfst = NULL; + Fst<Log64Arc> *l64fst = NULL; + Fst<StdArc> *sfst = NULL; if (map_type == ARC_SUM_MAPPER) { - args->retval = new FstClass( - script::StateMap(ifst, ArcSumMapper<Arc>(ifst))); + args->retval = new FstClass(*(fst = + script::StateMap(ifst, ArcSumMapper<Arc>(ifst)))); } else if (map_type == IDENTITY_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, IdentityArcMapper<Arc>())); + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, IdentityArcMapper<Arc>()))); } else if (map_type == INVERT_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, InvertWeightMapper<Arc>())); + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, InvertWeightMapper<Arc>()))); } else if (map_type == PLUS_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, PlusMapper<Arc>(w))); + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, PlusMapper<Arc>(w)))); } else if (map_type == QUANTIZE_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, QuantizeMapper<Arc>(delta))); + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, QuantizeMapper<Arc>(delta)))); } else if (map_type == RMWEIGHT_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, RmWeightMapper<Arc>())); + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, RmWeightMapper<Arc>()))); } else if (map_type == SUPERFINAL_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, SuperFinalMapper<Arc>())); + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, SuperFinalMapper<Arc>()))); } else if (map_type == TIMES_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, TimesMapper<Arc>(w))); + args->retval = new FstClass(*(fst = + script::ArcMap(ifst, TimesMapper<Arc>(w)))); } else if (map_type == TO_LOG_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, WeightConvertMapper<Arc, LogArc>())); + args->retval = new FstClass(*(lfst = + script::ArcMap(ifst, WeightConvertMapper<Arc, LogArc>()))); } else if (map_type == TO_LOG64_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, WeightConvertMapper<Arc, Log64Arc>())); + args->retval = new FstClass(*(l64fst = + script::ArcMap(ifst, WeightConvertMapper<Arc, Log64Arc>()))); } else if (map_type == TO_STD_MAPPER) { - args->retval = new FstClass( - script::ArcMap(ifst, WeightConvertMapper<Arc, StdArc>())); + args->retval = new FstClass(*(sfst = + script::ArcMap(ifst, WeightConvertMapper<Arc, StdArc>()))); } else { FSTERROR() << "Error: unknown/unsupported mapper type: " << map_type; VectorFst<Arc> *ofst = new VectorFst<Arc>; ofst->SetProperties(kError, kError); - args->retval = new FstClass(ofst); + args->retval = new FstClass(*(fst =ofst)); } + delete sfst; + delete l64fst; + delete lfst; + delete fst; } diff --git a/src/include/fst/script/shortest-distance.h b/src/include/fst/script/shortest-distance.h index 5fc2976..39c5045 100644 --- a/src/include/fst/script/shortest-distance.h +++ b/src/include/fst/script/shortest-distance.h @@ -175,11 +175,11 @@ void ShortestDistance(ShortestDistanceArgs1 *args) { return; case FIFO_QUEUE: - ShortestDistanceHelper<Arc, FifoQueue<StateId> >(args); + ShortestDistanceHelper<Arc, FifoQueue<StateId> >(args); return; case LIFO_QUEUE: - ShortestDistanceHelper<Arc, LifoQueue<StateId> >(args); + ShortestDistanceHelper<Arc, LifoQueue<StateId> >(args); return; case SHORTEST_FIRST_QUEUE: @@ -188,11 +188,11 @@ void ShortestDistance(ShortestDistanceArgs1 *args) { return; case STATE_ORDER_QUEUE: - ShortestDistanceHelper<Arc, StateOrderQueue<StateId> >(args); + ShortestDistanceHelper<Arc, StateOrderQueue<StateId> >(args); return; case TOP_ORDER_QUEUE: - ShortestDistanceHelper<Arc, TopOrderQueue<StateId> >(args); + ShortestDistanceHelper<Arc, TopOrderQueue<StateId> >(args); return; } } diff --git a/src/include/fst/script/weight-class.h b/src/include/fst/script/weight-class.h index 228216d..b9f7ddf 100644 --- a/src/include/fst/script/weight-class.h +++ b/src/include/fst/script/weight-class.h @@ -128,6 +128,13 @@ class WeightClass { return w; } + const string &Type() const { + if (impl_) return impl_->Type(); + static const string no_type = "none"; + return no_type; + } + + ~WeightClass() { if (impl_) delete impl_; } private: enum ElementType { ZERO, ONE, OTHER }; |