aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/extensions/pdt/replace.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/extensions/pdt/replace.h')
-rw-r--r--src/include/fst/extensions/pdt/replace.h192
1 files changed, 192 insertions, 0 deletions
diff --git a/src/include/fst/extensions/pdt/replace.h b/src/include/fst/extensions/pdt/replace.h
new file mode 100644
index 0000000..a85d0fe
--- /dev/null
+++ b/src/include/fst/extensions/pdt/replace.h
@@ -0,0 +1,192 @@
+// replace.h
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: riley@google.com (Michael Riley)
+//
+// \file
+// Recursively replace Fst arcs with other Fst(s) returning a PDT.
+
+#ifndef FST_EXTENSIONS_PDT_REPLACE_H__
+#define FST_EXTENSIONS_PDT_REPLACE_H__
+
+#include <fst/replace.h>
+
+namespace fst {
+
+// Hash to paren IDs
+template <typename S>
+struct ReplaceParenHash {
+ size_t operator()(const pair<size_t, S> &p) const {
+ return p.first + p.second * kPrime;
+ }
+ private:
+ static const size_t kPrime = 7853;
+};
+
+template <typename S> const size_t ReplaceParenHash<S>::kPrime;
+
+// Builds a pushdown transducer (PDT) from an RTN specification
+// identical to that in fst/lib/replace.h. The result is a PDT
+// encoded as the FST 'ofst' where some transitions are labeled with
+// open or close parentheses. To be interpreted as a PDT, the parens
+// must balance on a path (see PdtExpand()). The open/close
+// parenthesis label pairs are returned in 'parens'.
+template <class Arc>
+void Replace(const vector<pair<typename Arc::Label,
+ const Fst<Arc>* > >& ifst_array,
+ MutableFst<Arc> *ofst,
+ vector<pair<typename Arc::Label,
+ typename Arc::Label> > *parens,
+ typename Arc::Label root) {
+ typedef typename Arc::Label Label;
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Weight Weight;
+
+ ofst->DeleteStates();
+ parens->clear();
+
+ unordered_map<Label, size_t> label2id;
+ for (size_t i = 0; i < ifst_array.size(); ++i)
+ label2id[ifst_array[i].first] = i;
+
+ Label max_label = kNoLabel;
+
+ deque<size_t> non_term_queue; // Queue of non-terminals to replace
+ unordered_set<Label> non_term_set; // Set of non-terminals to replace
+ non_term_queue.push_back(root);
+ non_term_set.insert(root);
+
+ // PDT state corr. to ith replace FST start state.
+ vector<StateId> fst_start(ifst_array.size(), kNoLabel);
+ // PDT state, weight pairs corr. to ith replace FST final state & weights.
+ vector< vector<pair<StateId, Weight> > > fst_final(ifst_array.size());
+
+ // Builds single Fst combining all referenced input Fsts. Leaves in the
+ // non-termnals for now. Tabulate the PDT states that correspond to
+ // the start and final states of the input Fsts.
+ for (StateId soff = 0; !non_term_queue.empty(); soff = ofst->NumStates()) {
+ Label label = non_term_queue.front();
+ non_term_queue.pop_front();
+ size_t fst_id = label2id[label];
+
+ const Fst<Arc> *ifst = ifst_array[fst_id].second;
+ for (StateIterator< Fst<Arc> > siter(*ifst);
+ !siter.Done(); siter.Next()) {
+ StateId is = siter.Value();
+ StateId os = ofst->AddState();
+ if (is == ifst->Start()) {
+ fst_start[fst_id] = os;
+ if (label == root)
+ ofst->SetStart(os);
+ }
+ if (ifst->Final(is) != Weight::Zero()) {
+ if (label == root)
+ ofst->SetFinal(os, ifst->Final(is));
+ fst_final[fst_id].push_back(make_pair(os, ifst->Final(is)));
+ }
+ for (ArcIterator< Fst<Arc> > aiter(*ifst, is);
+ !aiter.Done(); aiter.Next()) {
+ Arc arc = aiter.Value();
+ if (max_label == kNoLabel || arc.olabel > max_label)
+ max_label = arc.olabel;
+ typename unordered_map<Label, size_t>::const_iterator it =
+ label2id.find(arc.olabel);
+ if (it != label2id.end()) {
+ size_t nfst_id = it->second;
+ if (ifst_array[nfst_id].second->Start() == -1)
+ continue;
+ if (non_term_set.count(arc.olabel) == 0) {
+ non_term_queue.push_back(arc.olabel);
+ non_term_set.insert(arc.olabel);
+ }
+ }
+ arc.nextstate += soff;
+ ofst->AddArc(os, arc);
+ }
+ }
+ }
+
+ // Changes each non-terminal transition to an open parenthesis
+ // transition redirected to the PDT state that corresponds to the
+ // start state of the input FST for the non-terminal. Adds close parenthesis
+ // transitions from the PDT states corr. to the final states of the
+ // input FST for the non-terminal to the former destination state of the
+ // non-terminal transition.
+
+ typedef MutableArcIterator< MutableFst<Arc> > MIter;
+ typedef unordered_map<pair<size_t, StateId >, size_t,
+ ReplaceParenHash<StateId> > ParenMap;
+
+ // Parenthesis pair ID per fst, state pair.
+ ParenMap paren_map;
+ // # of parenthesis pairs per fst.
+ vector<size_t> nparens(ifst_array.size(), 0);
+ // Initial open parenthesis label
+ Label first_paren = max_label + 1;
+
+ for (StateIterator< Fst<Arc> > siter(*ofst);
+ !siter.Done(); siter.Next()) {
+ StateId os = siter.Value();
+ MIter *aiter = new MIter(ofst, os);
+ for (size_t n = 0; !aiter->Done(); aiter->Next(), ++n) {
+ Arc arc = aiter->Value();
+ typename unordered_map<Label, size_t>::const_iterator lit =
+ label2id.find(arc.olabel);
+ if (lit != label2id.end()) {
+ size_t nfst_id = lit->second;
+
+ // Get parentheses. Ensures distinct parenthesis pair per
+ // non-terminal and destination state but otherwise reuses them.
+ Label open_paren = kNoLabel, close_paren = kNoLabel;
+ pair<size_t, StateId> paren_key(nfst_id, arc.nextstate);
+ typename ParenMap::const_iterator pit = paren_map.find(paren_key);
+ if (pit != paren_map.end()) {
+ size_t paren_id = pit->second;
+ open_paren = (*parens)[paren_id].first;
+ close_paren = (*parens)[paren_id].second;
+ } else {
+ size_t paren_id = nparens[nfst_id]++;
+ open_paren = first_paren + 2 * paren_id;
+ close_paren = open_paren + 1;
+ paren_map[paren_key] = paren_id;
+ if (paren_id >= parens->size())
+ parens->push_back(make_pair(open_paren, close_paren));
+ }
+
+ // Sets open parenthesis.
+ Arc sarc(open_paren, open_paren, arc.weight, fst_start[nfst_id]);
+ aiter->SetValue(sarc);
+
+ // Adds close parentheses.
+ for (size_t i = 0; i < fst_final[nfst_id].size(); ++i) {
+ pair<StateId, Weight> &p = fst_final[nfst_id][i];
+ Arc farc(close_paren, close_paren, p.second, arc.nextstate);
+
+ ofst->AddArc(p.first, farc);
+ if (os == p.first) { // Invalidated iterator
+ delete aiter;
+ aiter = new MIter(ofst, os);
+ aiter->Seek(n);
+ }
+ }
+ }
+ }
+ delete aiter;
+ }
+}
+
+} // namespace fst
+
+#endif // FST_EXTENSIONS_PDT_REPLACE_H__