aboutsummaryrefslogtreecommitdiff
path: root/src/test/weight-tester.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/test/weight-tester.h')
-rw-r--r--src/test/weight-tester.h225
1 files changed, 225 insertions, 0 deletions
diff --git a/src/test/weight-tester.h b/src/test/weight-tester.h
new file mode 100644
index 0000000..751e7d6
--- /dev/null
+++ b/src/test/weight-tester.h
@@ -0,0 +1,225 @@
+// weight-tester.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Utility class for regression testing of Fst weights.
+
+#ifndef FST_TEST_WEIGHT_TESTER_H_
+#define FST_TEST_WEIGHT_TESTER_H_
+
+#include <iostream>
+#include <sstream>
+
+#include <fst/random-weight.h>
+
+namespace fst {
+
+// This class tests a variety of identities and properties that must
+// hold for the Weight class to be well-defined. It calls function object
+// WEIGHT_GENERATOR to select weights that are used in the tests.
+template<class Weight, class WeightGenerator>
+class WeightTester {
+ public:
+ WeightTester(WeightGenerator generator) : weight_generator_(generator) {}
+
+ void Test(int iterations, bool test_division = true) {
+ for (int i = 0; i < iterations; ++i) {
+ // Selects the test weights.
+ Weight w1 = weight_generator_();
+ Weight w2 = weight_generator_();
+ Weight w3 = weight_generator_();
+
+ VLOG(1) << "weight type = " << Weight::Type();
+ VLOG(1) << "w1 = " << w1;
+ VLOG(1) << "w2 = " << w2;
+ VLOG(1) << "w3 = " << w3;
+
+ TestSemiring(w1, w2, w3);
+ if (test_division)
+ TestDivision(w1, w2);
+ TestReverse(w1, w2);
+ TestEquality(w1, w2, w3);
+ TestIO(w1);
+ TestCopy(w1);
+ }
+ }
+
+ private:
+ // Note in the tests below we use ApproxEqual rather than == and add
+ // kDelta to inequalities where the weights might be inexact.
+
+ // Tests (Plus, Times, Zero, One) defines a commutative semiring.
+ void TestSemiring(Weight w1, Weight w2, Weight w3) {
+ // Checks that the operations are closed.
+ CHECK(Plus(w1, w2).Member());
+ CHECK(Times(w1, w2).Member());
+
+ // Checks that the operations are associative.
+ CHECK(ApproxEqual(Plus(w1, Plus(w2, w3)), Plus(Plus(w1, w2), w3)));
+ CHECK(ApproxEqual(Times(w1, Times(w2, w3)), Times(Times(w1, w2), w3)));
+
+ // Checks the identity elements.
+ CHECK(Plus(w1, Weight::Zero()) == w1);
+ CHECK(Plus(Weight::Zero(), w1) == w1);
+ CHECK(Times(w1, Weight::One()) == w1);
+ CHECK(Times(Weight::One(), w1) == w1);
+
+ // Check the no weight element.
+ CHECK(!Weight::NoWeight().Member());
+ CHECK(!Plus(w1, Weight::NoWeight()).Member());
+ CHECK(!Plus(Weight::NoWeight(), w1).Member());
+ CHECK(!Times(w1, Weight::NoWeight()).Member());
+ CHECK(!Times(Weight::NoWeight(), w1).Member());
+
+ // Checks that the operations commute.
+ CHECK(ApproxEqual(Plus(w1, w2), Plus(w2, w1)));
+ if (Weight::Properties() & kCommutative)
+ CHECK(ApproxEqual(Times(w1, w2), Times(w2, w1)));
+
+ // Checks Zero() is the annihilator.
+ CHECK(Times(w1, Weight::Zero()) == Weight::Zero());
+ CHECK(Times(Weight::Zero(), w1) == Weight::Zero());
+
+ // Check Power(w, 0) is Weight::One()
+ CHECK(Power(w1, 0) == Weight::One());
+
+ // Check Power(w, 1) is w
+ CHECK(Power(w1, 1) == w1);
+
+ // Check Power(w, 3) is Times(w, Times(w, w))
+ CHECK(Power(w1, 3) == Times(w1, Times(w1, w1)));
+
+ // Checks distributivity.
+ if (Weight::Properties() & kLeftSemiring)
+ CHECK(ApproxEqual(Times(w1, Plus(w2, w3)),
+ Plus(Times(w1, w2), Times(w1, w3))));
+ if (Weight::Properties() & kRightSemiring)
+ CHECK(ApproxEqual(Times(Plus(w1, w2), w3),
+ Plus(Times(w1, w3), Times(w2, w3))));
+
+ if (Weight::Properties() & kIdempotent)
+ CHECK(Plus(w1, w1) == w1);
+
+ if (Weight::Properties() & kPath)
+ CHECK(Plus(w1, w2) == w1 || Plus(w1, w2) == w2);
+
+ // Ensure weights form a left or right semiring.
+ CHECK(Weight::Properties() & (kLeftSemiring | kRightSemiring));
+
+ // Check when Times() is commutative that it is marked as a semiring.
+ if (Weight::Properties() & kCommutative)
+ CHECK(Weight::Properties() & kSemiring);
+ }
+
+ // Tests division operation.
+ void TestDivision(Weight w1, Weight w2) {
+ Weight p = Times(w1, w2);
+
+ if (Weight::Properties() & kLeftSemiring) {
+ Weight d = Divide(p, w1, DIVIDE_LEFT);
+ if (d.Member())
+ CHECK(ApproxEqual(p, Times(w1, d)));
+ CHECK(!Divide(w1, Weight::NoWeight(), DIVIDE_LEFT).Member());
+ CHECK(!Divide(Weight::NoWeight(), w1, DIVIDE_LEFT).Member());
+ }
+
+ if (Weight::Properties() & kRightSemiring) {
+ Weight d = Divide(p, w2, DIVIDE_RIGHT);
+ if (d.Member())
+ CHECK(ApproxEqual(p, Times(d, w2)));
+ CHECK(!Divide(w1, Weight::NoWeight(), DIVIDE_RIGHT).Member());
+ CHECK(!Divide(Weight::NoWeight(), w1, DIVIDE_RIGHT).Member());
+ }
+
+ if (Weight::Properties() & kCommutative) {
+ Weight d = Divide(p, w1, DIVIDE_RIGHT);
+ if (d.Member())
+ CHECK(ApproxEqual(p, Times(d, w1)));
+ }
+ }
+
+ // Tests reverse operation.
+ void TestReverse(Weight w1, Weight w2) {
+ typedef typename Weight::ReverseWeight ReverseWeight;
+
+ ReverseWeight rw1 = w1.Reverse();
+ ReverseWeight rw2 = w2.Reverse();
+
+ CHECK(rw1.Reverse() == w1);
+ CHECK(Plus(w1, w2).Reverse() == Plus(rw1, rw2));
+ CHECK(Times(w1, w2).Reverse() == Times(rw2, rw1));
+ }
+
+ // Tests == is an equivalence relation.
+ void TestEquality(Weight w1, Weight w2, Weight w3) {
+ // Checks reflexivity.
+ CHECK(w1 == w1);
+
+ // Checks symmetry.
+ CHECK((w1 == w2) == (w2 == w1));
+
+ // Checks transitivity.
+ if (w1 == w2 && w2 == w3)
+ CHECK(w1 == w3);
+ }
+
+ // Tests binary serialization and textual I/O.
+ void TestIO(Weight w) {
+ // Tests binary I/O
+ {
+ ostringstream os;
+ w.Write(os);
+ os.flush();
+ istringstream is(os.str());
+ Weight v;
+ v.Read(is);
+ CHECK_EQ(w, v);
+ }
+
+ // Tests textual I/O.
+ {
+ ostringstream os;
+ os << w;
+ istringstream is(os.str());
+ Weight v(Weight::One());
+ is >> v;
+ CHECK(ApproxEqual(w, v));
+ }
+ }
+
+ // Tests copy constructor and assignment operator
+ void TestCopy(Weight w) {
+ Weight x = w;
+ CHECK(w == x);
+
+ x = Weight(w);
+ CHECK(w == x);
+
+ x.operator=(x);
+ CHECK(w == x);
+
+ }
+
+ // Generates weights used in testing.
+ WeightGenerator weight_generator_;
+
+ DISALLOW_COPY_AND_ASSIGN(WeightTester);
+};
+
+} // namespace fst
+
+#endif // FST_TEST_WEIGHT_TESTER_H_