aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/float-weight.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/float-weight.h')
-rw-r--r--src/include/fst/float-weight.h598
1 files changed, 598 insertions, 0 deletions
diff --git a/src/include/fst/float-weight.h b/src/include/fst/float-weight.h
new file mode 100644
index 0000000..530cbdd
--- /dev/null
+++ b/src/include/fst/float-weight.h
@@ -0,0 +1,598 @@
+// float-weight.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Float weight set and associated semiring operation definitions.
+//
+
+#ifndef FST_LIB_FLOAT_WEIGHT_H__
+#define FST_LIB_FLOAT_WEIGHT_H__
+
+#include <limits>
+#include <climits>
+#include <sstream>
+#include <string>
+
+#include <fst/util.h>
+#include <fst/weight.h>
+
+
+namespace fst {
+
+// numeric limits class
+template <class T>
+class FloatLimits {
+ public:
+ static const T kPosInfinity;
+ static const T kNegInfinity;
+ static const T kNumberBad;
+};
+
+template <class T>
+const T FloatLimits<T>::kPosInfinity = numeric_limits<T>::infinity();
+
+template <class T>
+const T FloatLimits<T>::kNegInfinity = -FloatLimits<T>::kPosInfinity;
+
+template <class T>
+const T FloatLimits<T>::kNumberBad = numeric_limits<T>::quiet_NaN();
+
+// weight class to be templated on floating-points types
+template <class T = float>
+class FloatWeightTpl {
+ public:
+ FloatWeightTpl() {}
+
+ FloatWeightTpl(T f) : value_(f) {}
+
+ FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {}
+
+ FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) {
+ value_ = w.value_;
+ return *this;
+ }
+
+ istream &Read(istream &strm) {
+ return ReadType(strm, &value_);
+ }
+
+ ostream &Write(ostream &strm) const {
+ return WriteType(strm, value_);
+ }
+
+ size_t Hash() const {
+ union {
+ T f;
+ size_t s;
+ } u;
+ u.s = 0;
+ u.f = value_;
+ return u.s;
+ }
+
+ const T &Value() const { return value_; }
+
+ protected:
+ void SetValue(const T &f) { value_ = f; }
+
+ inline static string GetPrecisionString() {
+ int64 size = sizeof(T);
+ if (size == sizeof(float)) return "";
+ size *= CHAR_BIT;
+
+ string result;
+ Int64ToStr(size, &result);
+ return result;
+ }
+
+ private:
+ T value_;
+};
+
+// Single-precision float weight
+typedef FloatWeightTpl<float> FloatWeight;
+
+template <class T>
+inline bool operator==(const FloatWeightTpl<T> &w1,
+ const FloatWeightTpl<T> &w2) {
+ // Volatile qualifier thwarts over-aggressive compiler optimizations
+ // that lead to problems esp. with NaturalLess().
+ volatile T v1 = w1.Value();
+ volatile T v2 = w2.Value();
+ return v1 == v2;
+}
+
+inline bool operator==(const FloatWeightTpl<double> &w1,
+ const FloatWeightTpl<double> &w2) {
+ return operator==<double>(w1, w2);
+}
+
+inline bool operator==(const FloatWeightTpl<float> &w1,
+ const FloatWeightTpl<float> &w2) {
+ return operator==<float>(w1, w2);
+}
+
+template <class T>
+inline bool operator!=(const FloatWeightTpl<T> &w1,
+ const FloatWeightTpl<T> &w2) {
+ return !(w1 == w2);
+}
+
+inline bool operator!=(const FloatWeightTpl<double> &w1,
+ const FloatWeightTpl<double> &w2) {
+ return operator!=<double>(w1, w2);
+}
+
+inline bool operator!=(const FloatWeightTpl<float> &w1,
+ const FloatWeightTpl<float> &w2) {
+ return operator!=<float>(w1, w2);
+}
+
+template <class T>
+inline bool ApproxEqual(const FloatWeightTpl<T> &w1,
+ const FloatWeightTpl<T> &w2,
+ float delta = kDelta) {
+ return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta;
+}
+
+template <class T>
+inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) {
+ if (w.Value() == FloatLimits<T>::kPosInfinity)
+ return strm << "Infinity";
+ else if (w.Value() == FloatLimits<T>::kNegInfinity)
+ return strm << "-Infinity";
+ else if (w.Value() != w.Value()) // Fails for NaN
+ return strm << "BadNumber";
+ else
+ return strm << w.Value();
+}
+
+template <class T>
+inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) {
+ string s;
+ strm >> s;
+ if (s == "Infinity") {
+ w = FloatWeightTpl<T>(FloatLimits<T>::kPosInfinity);
+ } else if (s == "-Infinity") {
+ w = FloatWeightTpl<T>(FloatLimits<T>::kNegInfinity);
+ } else {
+ char *p;
+ T f = strtod(s.c_str(), &p);
+ if (p < s.c_str() + s.size())
+ strm.clear(std::ios::badbit);
+ else
+ w = FloatWeightTpl<T>(f);
+ }
+ return strm;
+}
+
+
+// Tropical semiring: (min, +, inf, 0)
+template <class T>
+class TropicalWeightTpl : public FloatWeightTpl<T> {
+ public:
+ using FloatWeightTpl<T>::Value;
+
+ typedef TropicalWeightTpl<T> ReverseWeight;
+
+ TropicalWeightTpl() : FloatWeightTpl<T>() {}
+
+ TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}
+
+ TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
+
+ static const TropicalWeightTpl<T> Zero() {
+ return TropicalWeightTpl<T>(FloatLimits<T>::kPosInfinity); }
+
+ static const TropicalWeightTpl<T> One() {
+ return TropicalWeightTpl<T>(0.0F); }
+
+ static const TropicalWeightTpl<T> NoWeight() {
+ return TropicalWeightTpl<T>(FloatLimits<T>::kNumberBad); }
+
+ static const string &Type() {
+ static const string type = "tropical" +
+ FloatWeightTpl<T>::GetPrecisionString();
+ return type;
+ }
+
+ bool Member() const {
+ // First part fails for IEEE NaN
+ return Value() == Value() && Value() != FloatLimits<T>::kNegInfinity;
+ }
+
+ TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
+ if (Value() == FloatLimits<T>::kNegInfinity ||
+ Value() == FloatLimits<T>::kPosInfinity ||
+ Value() != Value())
+ return *this;
+ else
+ return TropicalWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
+ }
+
+ TropicalWeightTpl<T> Reverse() const { return *this; }
+
+ static uint64 Properties() {
+ return kLeftSemiring | kRightSemiring | kCommutative |
+ kPath | kIdempotent;
+ }
+};
+
+// Single precision tropical weight
+typedef TropicalWeightTpl<float> TropicalWeight;
+
+template <class T>
+inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1,
+ const TropicalWeightTpl<T> &w2) {
+ if (!w1.Member() || !w2.Member())
+ return TropicalWeightTpl<T>::NoWeight();
+ return w1.Value() < w2.Value() ? w1 : w2;
+}
+
+inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1,
+ const TropicalWeightTpl<float> &w2) {
+ return Plus<float>(w1, w2);
+}
+
+inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1,
+ const TropicalWeightTpl<double> &w2) {
+ return Plus<double>(w1, w2);
+}
+
+template <class T>
+inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1,
+ const TropicalWeightTpl<T> &w2) {
+ if (!w1.Member() || !w2.Member())
+ return TropicalWeightTpl<T>::NoWeight();
+ T f1 = w1.Value(), f2 = w2.Value();
+ if (f1 == FloatLimits<T>::kPosInfinity)
+ return w1;
+ else if (f2 == FloatLimits<T>::kPosInfinity)
+ return w2;
+ else
+ return TropicalWeightTpl<T>(f1 + f2);
+}
+
+inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1,
+ const TropicalWeightTpl<float> &w2) {
+ return Times<float>(w1, w2);
+}
+
+inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1,
+ const TropicalWeightTpl<double> &w2) {
+ return Times<double>(w1, w2);
+}
+
+template <class T>
+inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1,
+ const TropicalWeightTpl<T> &w2,
+ DivideType typ = DIVIDE_ANY) {
+ if (!w1.Member() || !w2.Member())
+ return TropicalWeightTpl<T>::NoWeight();
+ T f1 = w1.Value(), f2 = w2.Value();
+ if (f2 == FloatLimits<T>::kPosInfinity)
+ return FloatLimits<T>::kNumberBad;
+ else if (f1 == FloatLimits<T>::kPosInfinity)
+ return FloatLimits<T>::kPosInfinity;
+ else
+ return TropicalWeightTpl<T>(f1 - f2);
+}
+
+inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1,
+ const TropicalWeightTpl<float> &w2,
+ DivideType typ = DIVIDE_ANY) {
+ return Divide<float>(w1, w2, typ);
+}
+
+inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1,
+ const TropicalWeightTpl<double> &w2,
+ DivideType typ = DIVIDE_ANY) {
+ return Divide<double>(w1, w2, typ);
+}
+
+
+// Log semiring: (log(e^-x + e^y), +, inf, 0)
+template <class T>
+class LogWeightTpl : public FloatWeightTpl<T> {
+ public:
+ using FloatWeightTpl<T>::Value;
+
+ typedef LogWeightTpl ReverseWeight;
+
+ LogWeightTpl() : FloatWeightTpl<T>() {}
+
+ LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}
+
+ LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
+
+ static const LogWeightTpl<T> Zero() {
+ return LogWeightTpl<T>(FloatLimits<T>::kPosInfinity);
+ }
+
+ static const LogWeightTpl<T> One() {
+ return LogWeightTpl<T>(0.0F);
+ }
+
+ static const LogWeightTpl<T> NoWeight() {
+ return LogWeightTpl<T>(FloatLimits<T>::kNumberBad); }
+
+ static const string &Type() {
+ static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString();
+ return type;
+ }
+
+ bool Member() const {
+ // First part fails for IEEE NaN
+ return Value() == Value() && Value() != FloatLimits<T>::kNegInfinity;
+ }
+
+ LogWeightTpl<T> Quantize(float delta = kDelta) const {
+ if (Value() == FloatLimits<T>::kNegInfinity ||
+ Value() == FloatLimits<T>::kPosInfinity ||
+ Value() != Value())
+ return *this;
+ else
+ return LogWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
+ }
+
+ LogWeightTpl<T> Reverse() const { return *this; }
+
+ static uint64 Properties() {
+ return kLeftSemiring | kRightSemiring | kCommutative;
+ }
+};
+
+// Single-precision log weight
+typedef LogWeightTpl<float> LogWeight;
+// Double-precision log weight
+typedef LogWeightTpl<double> Log64Weight;
+
+template <class T>
+inline T LogExp(T x) { return log(1.0F + exp(-x)); }
+
+template <class T>
+inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
+ const LogWeightTpl<T> &w2) {
+ T f1 = w1.Value(), f2 = w2.Value();
+ if (f1 == FloatLimits<T>::kPosInfinity)
+ return w2;
+ else if (f2 == FloatLimits<T>::kPosInfinity)
+ return w1;
+ else if (f1 > f2)
+ return LogWeightTpl<T>(f2 - LogExp(f1 - f2));
+ else
+ return LogWeightTpl<T>(f1 - LogExp(f2 - f1));
+}
+
+inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1,
+ const LogWeightTpl<float> &w2) {
+ return Plus<float>(w1, w2);
+}
+
+inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1,
+ const LogWeightTpl<double> &w2) {
+ return Plus<double>(w1, w2);
+}
+
+template <class T>
+inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
+ const LogWeightTpl<T> &w2) {
+ if (!w1.Member() || !w2.Member())
+ return LogWeightTpl<T>::NoWeight();
+ T f1 = w1.Value(), f2 = w2.Value();
+ if (f1 == FloatLimits<T>::kPosInfinity)
+ return w1;
+ else if (f2 == FloatLimits<T>::kPosInfinity)
+ return w2;
+ else
+ return LogWeightTpl<T>(f1 + f2);
+}
+
+inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1,
+ const LogWeightTpl<float> &w2) {
+ return Times<float>(w1, w2);
+}
+
+inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1,
+ const LogWeightTpl<double> &w2) {
+ return Times<double>(w1, w2);
+}
+
+template <class T>
+inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
+ const LogWeightTpl<T> &w2,
+ DivideType typ = DIVIDE_ANY) {
+ if (!w1.Member() || !w2.Member())
+ return LogWeightTpl<T>::NoWeight();
+ T f1 = w1.Value(), f2 = w2.Value();
+ if (f2 == FloatLimits<T>::kPosInfinity)
+ return FloatLimits<T>::kNumberBad;
+ else if (f1 == FloatLimits<T>::kPosInfinity)
+ return FloatLimits<T>::kPosInfinity;
+ else
+ return LogWeightTpl<T>(f1 - f2);
+}
+
+inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1,
+ const LogWeightTpl<float> &w2,
+ DivideType typ = DIVIDE_ANY) {
+ return Divide<float>(w1, w2, typ);
+}
+
+inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1,
+ const LogWeightTpl<double> &w2,
+ DivideType typ = DIVIDE_ANY) {
+ return Divide<double>(w1, w2, typ);
+}
+
+// MinMax semiring: (min, max, inf, -inf)
+template <class T>
+class MinMaxWeightTpl : public FloatWeightTpl<T> {
+ public:
+ using FloatWeightTpl<T>::Value;
+
+ typedef MinMaxWeightTpl<T> ReverseWeight;
+
+ MinMaxWeightTpl() : FloatWeightTpl<T>() {}
+
+ MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {}
+
+ MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
+
+ static const MinMaxWeightTpl<T> Zero() {
+ return MinMaxWeightTpl<T>(FloatLimits<T>::kPosInfinity);
+ }
+
+ static const MinMaxWeightTpl<T> One() {
+ return MinMaxWeightTpl<T>(FloatLimits<T>::kNegInfinity);
+ }
+
+ static const MinMaxWeightTpl<T> NoWeight() {
+ return MinMaxWeightTpl<T>(FloatLimits<T>::kNumberBad); }
+
+ static const string &Type() {
+ static const string type = "minmax" +
+ FloatWeightTpl<T>::GetPrecisionString();
+ return type;
+ }
+
+ bool Member() const {
+ // Fails for IEEE NaN
+ return Value() == Value();
+ }
+
+ MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
+ // If one of infinities, or a NaN
+ if (Value() == FloatLimits<T>::kNegInfinity ||
+ Value() == FloatLimits<T>::kPosInfinity ||
+ Value() != Value())
+ return *this;
+ else
+ return MinMaxWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
+ }
+
+ MinMaxWeightTpl<T> Reverse() const { return *this; }
+
+ static uint64 Properties() {
+ return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath;
+ }
+};
+
+// Single-precision min-max weight
+typedef MinMaxWeightTpl<float> MinMaxWeight;
+
+// Min
+template <class T>
+inline MinMaxWeightTpl<T> Plus(
+ const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
+ if (!w1.Member() || !w2.Member())
+ return MinMaxWeightTpl<T>::NoWeight();
+ return w1.Value() < w2.Value() ? w1 : w2;
+}
+
+inline MinMaxWeightTpl<float> Plus(
+ const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
+ return Plus<float>(w1, w2);
+}
+
+inline MinMaxWeightTpl<double> Plus(
+ const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
+ return Plus<double>(w1, w2);
+}
+
+// Max
+template <class T>
+inline MinMaxWeightTpl<T> Times(
+ const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
+ if (!w1.Member() || !w2.Member())
+ return MinMaxWeightTpl<T>::NoWeight();
+ return w1.Value() >= w2.Value() ? w1 : w2;
+}
+
+inline MinMaxWeightTpl<float> Times(
+ const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
+ return Times<float>(w1, w2);
+}
+
+inline MinMaxWeightTpl<double> Times(
+ const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
+ return Times<double>(w1, w2);
+}
+
+// Defined only for special cases
+template <class T>
+inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1,
+ const MinMaxWeightTpl<T> &w2,
+ DivideType typ = DIVIDE_ANY) {
+ if (!w1.Member() || !w2.Member())
+ return MinMaxWeightTpl<T>::NoWeight();
+ // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2
+ return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::kNumberBad;
+}
+
+inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1,
+ const MinMaxWeightTpl<float> &w2,
+ DivideType typ = DIVIDE_ANY) {
+ return Divide<float>(w1, w2, typ);
+}
+
+inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1,
+ const MinMaxWeightTpl<double> &w2,
+ DivideType typ = DIVIDE_ANY) {
+ return Divide<double>(w1, w2, typ);
+}
+
+//
+// WEIGHT CONVERTER SPECIALIZATIONS.
+//
+
+// Convert to tropical
+template <>
+struct WeightConvert<LogWeight, TropicalWeight> {
+ TropicalWeight operator()(LogWeight w) const { return w.Value(); }
+};
+
+template <>
+struct WeightConvert<Log64Weight, TropicalWeight> {
+ TropicalWeight operator()(Log64Weight w) const { return w.Value(); }
+};
+
+// Convert to log
+template <>
+struct WeightConvert<TropicalWeight, LogWeight> {
+ LogWeight operator()(TropicalWeight w) const { return w.Value(); }
+};
+
+template <>
+struct WeightConvert<Log64Weight, LogWeight> {
+ LogWeight operator()(Log64Weight w) const { return w.Value(); }
+};
+
+// Convert to log64
+template <>
+struct WeightConvert<TropicalWeight, Log64Weight> {
+ Log64Weight operator()(TropicalWeight w) const { return w.Value(); }
+};
+
+template <>
+struct WeightConvert<LogWeight, Log64Weight> {
+ Log64Weight operator()(LogWeight w) const { return w.Value(); }
+};
+
+} // namespace fst
+
+#endif // FST_LIB_FLOAT_WEIGHT_H__