diff options
Diffstat (limited to 'src/include/fst/extensions/pdt/replace.h')
-rw-r--r-- | src/include/fst/extensions/pdt/replace.h | 192 |
1 files changed, 192 insertions, 0 deletions
diff --git a/src/include/fst/extensions/pdt/replace.h b/src/include/fst/extensions/pdt/replace.h new file mode 100644 index 0000000..a85d0fe --- /dev/null +++ b/src/include/fst/extensions/pdt/replace.h @@ -0,0 +1,192 @@ +// replace.h + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright 2005-2010 Google, Inc. +// Author: riley@google.com (Michael Riley) +// +// \file +// Recursively replace Fst arcs with other Fst(s) returning a PDT. + +#ifndef FST_EXTENSIONS_PDT_REPLACE_H__ +#define FST_EXTENSIONS_PDT_REPLACE_H__ + +#include <fst/replace.h> + +namespace fst { + +// Hash to paren IDs +template <typename S> +struct ReplaceParenHash { + size_t operator()(const pair<size_t, S> &p) const { + return p.first + p.second * kPrime; + } + private: + static const size_t kPrime = 7853; +}; + +template <typename S> const size_t ReplaceParenHash<S>::kPrime; + +// Builds a pushdown transducer (PDT) from an RTN specification +// identical to that in fst/lib/replace.h. The result is a PDT +// encoded as the FST 'ofst' where some transitions are labeled with +// open or close parentheses. To be interpreted as a PDT, the parens +// must balance on a path (see PdtExpand()). The open/close +// parenthesis label pairs are returned in 'parens'. +template <class Arc> +void Replace(const vector<pair<typename Arc::Label, + const Fst<Arc>* > >& ifst_array, + MutableFst<Arc> *ofst, + vector<pair<typename Arc::Label, + typename Arc::Label> > *parens, + typename Arc::Label root) { + typedef typename Arc::Label Label; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + ofst->DeleteStates(); + parens->clear(); + + unordered_map<Label, size_t> label2id; + for (size_t i = 0; i < ifst_array.size(); ++i) + label2id[ifst_array[i].first] = i; + + Label max_label = kNoLabel; + + deque<size_t> non_term_queue; // Queue of non-terminals to replace + unordered_set<Label> non_term_set; // Set of non-terminals to replace + non_term_queue.push_back(root); + non_term_set.insert(root); + + // PDT state corr. to ith replace FST start state. + vector<StateId> fst_start(ifst_array.size(), kNoLabel); + // PDT state, weight pairs corr. to ith replace FST final state & weights. + vector< vector<pair<StateId, Weight> > > fst_final(ifst_array.size()); + + // Builds single Fst combining all referenced input Fsts. Leaves in the + // non-termnals for now. Tabulate the PDT states that correspond to + // the start and final states of the input Fsts. + for (StateId soff = 0; !non_term_queue.empty(); soff = ofst->NumStates()) { + Label label = non_term_queue.front(); + non_term_queue.pop_front(); + size_t fst_id = label2id[label]; + + const Fst<Arc> *ifst = ifst_array[fst_id].second; + for (StateIterator< Fst<Arc> > siter(*ifst); + !siter.Done(); siter.Next()) { + StateId is = siter.Value(); + StateId os = ofst->AddState(); + if (is == ifst->Start()) { + fst_start[fst_id] = os; + if (label == root) + ofst->SetStart(os); + } + if (ifst->Final(is) != Weight::Zero()) { + if (label == root) + ofst->SetFinal(os, ifst->Final(is)); + fst_final[fst_id].push_back(make_pair(os, ifst->Final(is))); + } + for (ArcIterator< Fst<Arc> > aiter(*ifst, is); + !aiter.Done(); aiter.Next()) { + Arc arc = aiter.Value(); + if (max_label == kNoLabel || arc.olabel > max_label) + max_label = arc.olabel; + typename unordered_map<Label, size_t>::const_iterator it = + label2id.find(arc.olabel); + if (it != label2id.end()) { + size_t nfst_id = it->second; + if (ifst_array[nfst_id].second->Start() == -1) + continue; + if (non_term_set.count(arc.olabel) == 0) { + non_term_queue.push_back(arc.olabel); + non_term_set.insert(arc.olabel); + } + } + arc.nextstate += soff; + ofst->AddArc(os, arc); + } + } + } + + // Changes each non-terminal transition to an open parenthesis + // transition redirected to the PDT state that corresponds to the + // start state of the input FST for the non-terminal. Adds close parenthesis + // transitions from the PDT states corr. to the final states of the + // input FST for the non-terminal to the former destination state of the + // non-terminal transition. + + typedef MutableArcIterator< MutableFst<Arc> > MIter; + typedef unordered_map<pair<size_t, StateId >, size_t, + ReplaceParenHash<StateId> > ParenMap; + + // Parenthesis pair ID per fst, state pair. + ParenMap paren_map; + // # of parenthesis pairs per fst. + vector<size_t> nparens(ifst_array.size(), 0); + // Initial open parenthesis label + Label first_paren = max_label + 1; + + for (StateIterator< Fst<Arc> > siter(*ofst); + !siter.Done(); siter.Next()) { + StateId os = siter.Value(); + MIter *aiter = new MIter(ofst, os); + for (size_t n = 0; !aiter->Done(); aiter->Next(), ++n) { + Arc arc = aiter->Value(); + typename unordered_map<Label, size_t>::const_iterator lit = + label2id.find(arc.olabel); + if (lit != label2id.end()) { + size_t nfst_id = lit->second; + + // Get parentheses. Ensures distinct parenthesis pair per + // non-terminal and destination state but otherwise reuses them. + Label open_paren = kNoLabel, close_paren = kNoLabel; + pair<size_t, StateId> paren_key(nfst_id, arc.nextstate); + typename ParenMap::const_iterator pit = paren_map.find(paren_key); + if (pit != paren_map.end()) { + size_t paren_id = pit->second; + open_paren = (*parens)[paren_id].first; + close_paren = (*parens)[paren_id].second; + } else { + size_t paren_id = nparens[nfst_id]++; + open_paren = first_paren + 2 * paren_id; + close_paren = open_paren + 1; + paren_map[paren_key] = paren_id; + if (paren_id >= parens->size()) + parens->push_back(make_pair(open_paren, close_paren)); + } + + // Sets open parenthesis. + Arc sarc(open_paren, open_paren, arc.weight, fst_start[nfst_id]); + aiter->SetValue(sarc); + + // Adds close parentheses. + for (size_t i = 0; i < fst_final[nfst_id].size(); ++i) { + pair<StateId, Weight> &p = fst_final[nfst_id][i]; + Arc farc(close_paren, close_paren, p.second, arc.nextstate); + + ofst->AddArc(p.first, farc); + if (os == p.first) { // Invalidated iterator + delete aiter; + aiter = new MIter(ofst, os); + aiter->Seek(n); + } + } + } + } + delete aiter; + } +} + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_REPLACE_H__ |