aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/script/weight-class.h
diff options
context:
space:
mode:
authorIan Hodson <idh@google.com>2012-05-30 21:27:06 +0100
committerIan Hodson <idh@google.com>2012-05-30 22:47:36 +0100
commitf4c12fce1ee58e670f9c3fce46c40296ba9ee8a2 (patch)
treeb131ed907f9b2d5af09c0983b651e9e69bc6aab9 /src/include/fst/script/weight-class.h
parenta92766f0a6ba4fac46cd6fd3856ef20c3b204f0d (diff)
downloadopenfst-f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2.tar.gz
Moved from GoogleTTS Change-Id: I6bc6bdadaa53bd0f810b88443339f6d899502cc8
Diffstat (limited to 'src/include/fst/script/weight-class.h')
-rw-r--r--src/include/fst/script/weight-class.h216
1 files changed, 216 insertions, 0 deletions
diff --git a/src/include/fst/script/weight-class.h b/src/include/fst/script/weight-class.h
new file mode 100644
index 0000000..5a4890f
--- /dev/null
+++ b/src/include/fst/script/weight-class.h
@@ -0,0 +1,216 @@
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: jpr@google.com (Jake Ratkiewicz)
+
+// Represents a generic weight in an FST -- that is, represents a specific
+// type of weight underneath while hiding that type from a client.
+
+
+#ifndef FST_SCRIPT_WEIGHT_CLASS_H_
+#define FST_SCRIPT_WEIGHT_CLASS_H_
+
+#include <string>
+
+#include <fst/generic-register.h>
+#include <fst/util.h>
+
+namespace fst {
+namespace script {
+
+class WeightImplBase {
+ public:
+ virtual WeightImplBase *Copy() const = 0;
+ virtual void Print(ostream *o) const = 0;
+ virtual const string &Type() const = 0;
+ virtual string to_string() const = 0;
+ virtual bool operator == (const WeightImplBase &other) const = 0;
+ virtual ~WeightImplBase() { }
+};
+
+template<class W>
+struct WeightClassImpl : public WeightImplBase {
+ W weight;
+
+ explicit WeightClassImpl(const W& weight) : weight(weight) { }
+
+ virtual WeightClassImpl<W> *Copy() const {
+ return new WeightClassImpl<W>(weight);
+ }
+
+ virtual const string &Type() const { return W::Type(); }
+
+ virtual void Print(ostream *o) const {
+ *o << weight;
+ }
+
+ virtual string to_string() const {
+ ostringstream s;
+ s << weight;
+ return s.str();
+ }
+
+ virtual bool operator == (const WeightImplBase &other) const {
+ if (Type() != other.Type()) {
+ return false;
+ } else {
+ const WeightClassImpl<W> *typed_other =
+ static_cast<const WeightClassImpl<W> *>(&other);
+
+ return typed_other->weight == weight;
+ }
+ }
+};
+
+
+class WeightClass {
+ public:
+ WeightClass() : element_type_(ZERO), impl_(0) { }
+
+ template<class W>
+ explicit WeightClass(const W& weight)
+ : element_type_(OTHER), impl_(new WeightClassImpl<W>(weight)) { }
+
+ WeightClass(const string &weight_type, const string &weight_str);
+
+ WeightClass(const WeightClass &other) :
+ element_type_(other.element_type_),
+ impl_(other.impl_ ? other.impl_->Copy() : 0) { }
+
+ WeightClass &operator = (const WeightClass &other) {
+ if (impl_) delete impl_;
+ impl_ = other.impl_ ? other.impl_->Copy() : 0;
+ element_type_ = other.element_type_;
+ return *this;
+ }
+
+ template<class W>
+ const W* GetWeight() const;
+
+ string to_string() const {
+ switch (element_type_) {
+ case ZERO:
+ return "ZERO";
+ case ONE:
+ return "ONE";
+ default:
+ case OTHER:
+ return impl_->to_string();
+ }
+ }
+
+ bool operator == (const WeightClass &other) const {
+ return element_type_ == other.element_type_ &&
+ ((impl_ && other.impl_ && (*impl_ == *other.impl_)) ||
+ (impl_ == 0 && other.impl_ == 0));
+ }
+
+ static const WeightClass &Zero() {
+ static WeightClass w(ZERO);
+
+ return w;
+ }
+
+ static const WeightClass &One() {
+ static WeightClass w(ONE);
+
+ return w;
+ }
+
+ ~WeightClass() { if (impl_) delete impl_; }
+ private:
+ enum ElementType { ZERO, ONE, OTHER };
+ ElementType element_type_;
+
+ WeightImplBase *impl_;
+
+ explicit WeightClass(ElementType et) : element_type_(et), impl_(0) { }
+
+ friend ostream &operator << (ostream &o, const WeightClass &c);
+};
+
+template<class W>
+const W* WeightClass::GetWeight() const {
+ // We need to store zero and one as statics, because the weight type
+ // W might return them as temporaries. We're returning a pointer,
+ // and it won't do to get the address of a temporary.
+ static const W zero = W::Zero();
+ static const W one = W::One();
+
+ if (element_type_ == ZERO) {
+ return &zero;
+ } else if (element_type_ == ONE) {
+ return &one;
+ } else {
+ if (W::Type() != impl_->Type()) {
+ return NULL;
+ } else {
+ WeightClassImpl<W> *typed_impl =
+ static_cast<WeightClassImpl<W> *>(impl_);
+ return &typed_impl->weight;
+ }
+ }
+}
+
+//
+// Registration for generic weight types.
+//
+
+typedef WeightImplBase* (*StrToWeightImplBaseT)(const string &str,
+ const string &src,
+ size_t nline);
+
+template<class W>
+WeightImplBase* StrToWeightImplBase(const string &str,
+ const string &src, size_t nline) {
+ return new WeightClassImpl<W>(StrToWeight<W>(str, src, nline));
+}
+
+// The following confuses swig, and doesn't need to be wrapped anyway.
+#ifndef SWIG
+ostream& operator << (ostream &o, const WeightClass &c);
+
+class WeightClassRegister : public GenericRegister<string,
+ StrToWeightImplBaseT,
+ WeightClassRegister> {
+ protected:
+ virtual string ConvertKeyToSoFilename(const string &key) const {
+ return key + ".so";
+ }
+};
+
+typedef GenericRegisterer<WeightClassRegister> WeightClassRegisterer;
+#endif
+
+// internal version, needs to be called by wrapper in order for
+// macro args to expand
+#define REGISTER_FST_WEIGHT__(Weight, line) \
+ static WeightClassRegisterer weight_registerer ## _ ## line( \
+ Weight::Type(), \
+ StrToWeightImplBase<Weight>)
+
+// This layer is where __FILE__ and __LINE__ are expanded
+#define REGISTER_FST_WEIGHT_EXPANDER(Weight, line) \
+ REGISTER_FST_WEIGHT__(Weight, line)
+
+//
+// Macro for registering new weight types. Clients call this.
+//
+#define REGISTER_FST_WEIGHT(Weight) \
+ REGISTER_FST_WEIGHT_EXPANDER(Weight, __LINE__)
+
+} // namespace script
+} // namespace fst
+
+#endif // FST_SCRIPT_WEIGHT_CLASS_H_