aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/equal.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/equal.h')
-rw-r--r--src/include/fst/equal.h124
1 files changed, 124 insertions, 0 deletions
diff --git a/src/include/fst/equal.h b/src/include/fst/equal.h
new file mode 100644
index 0000000..33be198
--- /dev/null
+++ b/src/include/fst/equal.h
@@ -0,0 +1,124 @@
+// test.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Function to test equality of two Fsts.
+
+#ifndef FST_LIB_EQUAL_H__
+#define FST_LIB_EQUAL_H__
+
+#include <fst/fst.h>
+
+
+namespace fst {
+
+// Tests if two Fsts have the same states and arcs in the same order.
+template<class Arc>
+bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2, float delta = kDelta) {
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Weight Weight;
+
+ if (fst1.Start() != fst2.Start()) {
+ VLOG(1) << "Equal: mismatched start states";
+ return false;
+ }
+
+ StateIterator< Fst<Arc> > siter1(fst1);
+ StateIterator< Fst<Arc> > siter2(fst2);
+
+ while (!siter1.Done() || !siter2.Done()) {
+ if (siter1.Done() || siter2.Done()) {
+ VLOG(1) << "Equal: mismatched # of states";
+ return false;
+ }
+ StateId s1 = siter1.Value();
+ StateId s2 = siter2.Value();
+ if (s1 != s2) {
+ VLOG(1) << "Equal: mismatched states:"
+ << ", state1 = " << s1
+ << ", state2 = " << s2;
+ return false;
+ }
+ Weight final1 = fst1.Final(s1);
+ Weight final2 = fst2.Final(s2);
+ if (!ApproxEqual(final1, final2, delta)) {
+ VLOG(1) << "Equal: mismatched final weights:"
+ << " state = " << s1
+ << ", final1 = " << final1
+ << ", final2 = " << final2;
+ return false;
+ }
+ ArcIterator< Fst<Arc> > aiter1(fst1, s1);
+ ArcIterator< Fst<Arc> > aiter2(fst2, s2);
+ for (size_t a = 0; !aiter1.Done() || !aiter2.Done(); ++a) {
+ if (aiter1.Done() || aiter2.Done()) {
+ VLOG(1) << "Equal: mismatched # of arcs"
+ << " state = " << s1;
+ return false;
+ }
+ Arc arc1 = aiter1.Value();
+ Arc arc2 = aiter2.Value();
+ if (arc1.ilabel != arc2.ilabel) {
+ VLOG(1) << "Equal: mismatched arc input labels:"
+ << " state = " << s1
+ << ", arc = " << a
+ << ", ilabel1 = " << arc1.ilabel
+ << ", ilabel2 = " << arc2.ilabel;
+ return false;
+ } else if (arc1.olabel != arc2.olabel) {
+ VLOG(1) << "Equal: mismatched arc output labels:"
+ << " state = " << s1
+ << ", arc = " << a
+ << ", olabel1 = " << arc1.olabel
+ << ", olabel2 = " << arc2.olabel;
+ return false;
+ } else if (!ApproxEqual(arc1.weight, arc2.weight, delta)) {
+ VLOG(1) << "Equal: mismatched arc weights:"
+ << " state = " << s1
+ << ", arc = " << a
+ << ", weight1 = " << arc1.weight
+ << ", weight2 = " << arc2.weight;
+ return false;
+ } else if (arc1.nextstate != arc2.nextstate) {
+ VLOG(1) << "Equal: mismatched input label:"
+ << " state = " << s1
+ << ", arc = " << a
+ << ", nextstate1 = " << arc1.nextstate
+ << ", nextstate2 = " << arc2.nextstate;
+ return false;
+ }
+ aiter1.Next();
+ aiter2.Next();
+
+ }
+ // Sanity checks: should never fail
+ if (fst1.NumArcs(s1) != fst2.NumArcs(s2) ||
+ fst1.NumInputEpsilons(s1) != fst2.NumInputEpsilons(s2) ||
+ fst1.NumOutputEpsilons(s1) != fst2.NumOutputEpsilons(s2)) {
+ FSTERROR() << "Equal: inconsistent arc/epsilon counts";
+ }
+
+ siter1.Next();
+ siter2.Next();
+ }
+ return true;
+}
+
+} // namespace fst
+
+
+#endif // FST_LIB_EQUAL_H__