aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/push.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/push.h')
-rw-r--r--src/include/fst/push.h175
1 files changed, 175 insertions, 0 deletions
diff --git a/src/include/fst/push.h b/src/include/fst/push.h
new file mode 100644
index 0000000..1f7a8fa
--- /dev/null
+++ b/src/include/fst/push.h
@@ -0,0 +1,175 @@
+// push.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: allauzen@google.com (Cyril Allauzen)
+//
+// \file
+// Class to reweight/push an FST.
+
+#ifndef FST_LIB_PUSH_H__
+#define FST_LIB_PUSH_H__
+
+#include <vector>
+using std::vector;
+
+#include <fst/factor-weight.h>
+#include <fst/fst.h>
+#include <fst/arc-map.h>
+#include <fst/reweight.h>
+#include <fst/shortest-distance.h>
+
+
+namespace fst {
+
+// Private helper functions for Push
+namespace internal {
+
+// Compute the total weight (sum of the weights of all accepting paths) from
+// the output of ShortestDistance. 'distance' is the shortest distance from the
+// initial state when 'reverse == false' and to the final states when
+// 'reverse == true'.
+template <class Arc>
+typename Arc::Weight ComputeTotalWeight(
+ const Fst<Arc> &fst,
+ const vector<typename Arc::Weight> &distance,
+ bool reverse) {
+ if (reverse)
+ return fst.Start() < distance.size() ?
+ distance[fst.Start()] : Arc::Weight::Zero();
+
+ typename Arc::Weight sum = Arc::Weight::Zero();
+ for (typename Arc::StateId s = 0; s < distance.size(); ++s)
+ sum = Plus(sum, Times(distance[s], fst.Final(s)));
+ return sum;
+}
+
+// Divide the weight of every accepting path by 'w'. The weight 'w' is
+// divided at the final states if 'at_final == true' and at the
+// initial state otherwise.
+template <class Arc>
+void RemoveWeight(MutableFst<Arc> *fst, typename Arc::Weight w, bool at_final) {
+ if ((w == Arc::Weight::One()) || (w == Arc::Weight::Zero()))
+ return;
+
+ if (at_final) {
+ // Remove 'w' from the final states
+ for (StateIterator< MutableFst<Arc> > sit(*fst);
+ !sit.Done();
+ sit.Next())
+ fst->SetFinal(sit.Value(),
+ Divide(fst->Final(sit.Value()), w, DIVIDE_RIGHT));
+ } else { // at_final == false
+ // Remove 'w' from the initial state
+ typename Arc::StateId start = fst->Start();
+ for (MutableArcIterator<MutableFst<Arc> > ait(fst, start);
+ !ait.Done();
+ ait.Next()) {
+ Arc arc = ait.Value();
+ arc.weight = Divide(arc.weight, w, DIVIDE_LEFT);
+ ait.SetValue(arc);
+ }
+ fst->SetFinal(start, Divide(fst->Final(start), w, DIVIDE_LEFT));
+ }
+}
+} // namespace internal
+
+// Pushes the weights in FST in the direction defined by TYPE. If
+// pushing towards the initial state, the sum of the weight of the
+// outgoing transitions and final weight at a non-initial state is
+// equal to One() in the resulting machine. If pushing towards the
+// final state, the same property holds on the reverse machine.
+//
+// Weight needs to be left distributive when pushing towards the
+// initial state and right distributive when pushing towards the final
+// states.
+template <class Arc>
+void Push(MutableFst<Arc> *fst,
+ ReweightType type,
+ float delta = kDelta,
+ bool remove_total_weight = false) {
+ vector<typename Arc::Weight> distance;
+ ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL, delta);
+ typename Arc::Weight total_weight = Arc::Weight::One();
+ if (remove_total_weight)
+ total_weight = internal::ComputeTotalWeight(*fst, distance,
+ type == REWEIGHT_TO_INITIAL);
+ Reweight(fst, distance, type);
+ if (remove_total_weight)
+ internal::RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL);
+}
+
+const uint32 kPushWeights = 0x0001;
+const uint32 kPushLabels = 0x0002;
+const uint32 kPushRemoveTotalWeight = 0x0004;
+const uint32 kPushRemoveCommonAffix = 0x0008;
+
+// OFST obtained from IFST by pushing weights and/or labels according
+// to PTYPE in the direction defined by RTYPE. Weight needs to be
+// left distributive when pushing weights towards the initial state
+// and right distributive when pushing weights towards the final
+// states.
+template <class Arc, ReweightType rtype>
+void Push(const Fst<Arc> &ifst,
+ MutableFst<Arc> *ofst,
+ uint32 ptype,
+ float delta = kDelta) {
+
+ if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) {
+ *ofst = ifst;
+ Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight);
+ } else if (ptype & kPushLabels) {
+ const StringType stype = rtype == REWEIGHT_TO_INITIAL
+ ? STRING_LEFT
+ : STRING_RIGHT;
+ vector<typename GallicArc<Arc, stype>::Weight> gdistance;
+ VectorFst<GallicArc<Arc, stype> > gfst;
+ ArcMap(ifst, &gfst, ToGallicMapper<Arc, stype>());
+ if (ptype & kPushWeights ) {
+ ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
+ } else {
+ ArcMapFst<Arc, Arc, RmWeightMapper<Arc> >
+ uwfst(ifst, RmWeightMapper<Arc>());
+ ArcMapFst<Arc, GallicArc<Arc, stype>, ToGallicMapper<Arc, stype> >
+ guwfst(uwfst, ToGallicMapper<Arc, stype>());
+ ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
+ }
+ typename GallicArc<Arc, stype>::Weight total_weight =
+ GallicArc<Arc, stype>::Weight::One();
+ if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) {
+ total_weight = internal::ComputeTotalWeight(
+ gfst, gdistance, rtype == REWEIGHT_TO_INITIAL);
+ total_weight = typename GallicArc<Arc, stype>::Weight(
+ ptype & kPushRemoveCommonAffix ? total_weight.Value1()
+ : StringWeight<typename Arc::Label, stype>::One(),
+ ptype & kPushRemoveTotalWeight ? total_weight.Value2()
+ : Arc::Weight::One());
+ }
+ Reweight(&gfst, gdistance, rtype);
+ if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix))
+ internal::RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL);
+ FactorWeightFst< GallicArc<Arc, stype>, GallicFactor<typename Arc::Label,
+ typename Arc::Weight, stype> > fwfst(gfst);
+ ArcMap(fwfst, ofst, FromGallicMapper<Arc, stype>());
+ ofst->SetOutputSymbols(ifst.OutputSymbols());
+ } else {
+ LOG(WARNING) << "Push: pushing type is set to 0: "
+ << "pushing neither labels nor weights.";
+ *ofst = ifst;
+ }
+}
+
+} // namespace fst
+
+#endif /* FST_LIB_PUSH_H_ */