aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/extensions/pdt/pdtscript.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/extensions/pdt/pdtscript.h')
-rw-r--r--src/include/fst/extensions/pdt/pdtscript.h284
1 files changed, 284 insertions, 0 deletions
diff --git a/src/include/fst/extensions/pdt/pdtscript.h b/src/include/fst/extensions/pdt/pdtscript.h
new file mode 100644
index 0000000..c2a1cf4
--- /dev/null
+++ b/src/include/fst/extensions/pdt/pdtscript.h
@@ -0,0 +1,284 @@
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Copyright 2005-2010 Google, Inc.
+// Author: jpr@google.com (Jake Ratkiewicz)
+// Convenience file for including all PDT operations at once, and/or
+// registering them for new arc types.
+
+#ifndef FST_EXTENSIONS_PDT_PDTSCRIPT_H_
+#define FST_EXTENSIONS_PDT_PDTSCRIPT_H_
+
+#include <utility>
+using std::pair; using std::make_pair;
+#include <vector>
+using std::vector;
+
+#include <fst/compose.h> // for ComposeOptions
+#include <fst/util.h>
+
+#include <fst/script/fst-class.h>
+#include <fst/script/arg-packs.h>
+#include <fst/script/shortest-path.h>
+
+#include <fst/extensions/pdt/compose.h>
+#include <fst/extensions/pdt/expand.h>
+#include <fst/extensions/pdt/info.h>
+#include <fst/extensions/pdt/replace.h>
+#include <fst/extensions/pdt/reverse.h>
+#include <fst/extensions/pdt/shortest-path.h>
+
+
+namespace fst {
+namespace script {
+
+// PDT COMPOSE
+
+typedef args::Package<const FstClass &,
+ const FstClass &,
+ const vector<pair<int64, int64> >&,
+ MutableFstClass *,
+ const ComposeOptions &,
+ bool> PdtComposeArgs;
+
+template<class Arc>
+void PdtCompose(PdtComposeArgs *args) {
+ const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>());
+ const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>());
+ MutableFst<Arc> *ofst = args->arg4->GetMutableFst<Arc>();
+
+ vector<pair<typename Arc::Label, typename Arc::Label> > parens(
+ args->arg3.size());
+
+ for (size_t i = 0; i < parens.size(); ++i) {
+ parens[i].first = args->arg3[i].first;
+ parens[i].second = args->arg3[i].second;
+ }
+
+ if (args->arg6) {
+ Compose(ifst1, parens, ifst2, ofst, args->arg5);
+ } else {
+ Compose(ifst1, ifst2, parens, ofst, args->arg5);
+ }
+}
+
+void PdtCompose(const FstClass & ifst1,
+ const FstClass & ifst2,
+ const vector<pair<int64, int64> > &parens,
+ MutableFstClass *ofst,
+ const ComposeOptions &copts,
+ bool left_pdt);
+
+// PDT EXPAND
+
+struct PdtExpandOptions {
+ bool connect;
+ bool keep_parentheses;
+ WeightClass weight_threshold;
+
+ PdtExpandOptions(bool c = true, bool k = false,
+ WeightClass w = WeightClass::Zero())
+ : connect(c), keep_parentheses(k), weight_threshold(w) {}
+};
+
+typedef args::Package<const FstClass &,
+ const vector<pair<int64, int64> >&,
+ MutableFstClass *, PdtExpandOptions> PdtExpandArgs;
+
+template<class Arc>
+void PdtExpand(PdtExpandArgs *args) {
+ const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
+ MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
+
+ vector<pair<typename Arc::Label, typename Arc::Label> > parens(
+ args->arg2.size());
+ for (size_t i = 0; i < parens.size(); ++i) {
+ parens[i].first = args->arg2[i].first;
+ parens[i].second = args->arg2[i].second;
+ }
+ Expand(fst, parens, ofst,
+ ExpandOptions<Arc>(
+ args->arg4.connect, args->arg4.keep_parentheses,
+ *(args->arg4.weight_threshold.GetWeight<typename Arc::Weight>())));
+}
+
+void PdtExpand(const FstClass &ifst,
+ const vector<pair<int64, int64> > &parens,
+ MutableFstClass *ofst, const PdtExpandOptions &opts);
+
+void PdtExpand(const FstClass &ifst,
+ const vector<pair<int64, int64> > &parens,
+ MutableFstClass *ofst, bool connect);
+
+// PDT REPLACE
+
+typedef args::Package<const vector<pair<int64, const FstClass*> > &,
+ MutableFstClass *,
+ vector<pair<int64, int64> > *,
+ const int64 &> PdtReplaceArgs;
+template<class Arc>
+void PdtReplace(PdtReplaceArgs *args) {
+ vector<pair<typename Arc::Label, const Fst<Arc> *> > tuples(
+ args->arg1.size());
+ for (size_t i = 0; i < tuples.size(); ++i) {
+ tuples[i].first = args->arg1[i].first;
+ tuples[i].second = (args->arg1[i].second)->GetFst<Arc>();
+ }
+ MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>();
+ vector<pair<typename Arc::Label, typename Arc::Label> > parens(
+ args->arg3->size());
+
+ for (size_t i = 0; i < parens.size(); ++i) {
+ parens[i].first = args->arg3->at(i).first;
+ parens[i].second = args->arg3->at(i).second;
+ }
+ Replace(tuples, ofst, &parens, args->arg4);
+
+ // now copy parens back
+ args->arg3->resize(parens.size());
+ for (size_t i = 0; i < parens.size(); ++i) {
+ (*args->arg3)[i].first = parens[i].first;
+ (*args->arg3)[i].second = parens[i].second;
+ }
+}
+
+void PdtReplace(const vector<pair<int64, const FstClass*> > &fst_tuples,
+ MutableFstClass *ofst,
+ vector<pair<int64, int64> > *parens,
+ const int64 &root);
+
+// PDT REVERSE
+
+typedef args::Package<const FstClass &,
+ const vector<pair<int64, int64> >&,
+ MutableFstClass *> PdtReverseArgs;
+
+template<class Arc>
+void PdtReverse(PdtReverseArgs *args) {
+ const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
+ MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
+
+ vector<pair<typename Arc::Label, typename Arc::Label> > parens(
+ args->arg2.size());
+ for (size_t i = 0; i < parens.size(); ++i) {
+ parens[i].first = args->arg2[i].first;
+ parens[i].second = args->arg2[i].second;
+ }
+ Reverse(fst, parens, ofst);
+}
+
+void PdtReverse(const FstClass &ifst,
+ const vector<pair<int64, int64> > &parens,
+ MutableFstClass *ofst);
+
+
+// PDT SHORTESTPATH
+
+struct PdtShortestPathOptions {
+ QueueType queue_type;
+ bool keep_parentheses;
+ bool path_gc;
+
+ PdtShortestPathOptions(QueueType qt = FIFO_QUEUE,
+ bool kp = false, bool gc = true)
+ : queue_type(qt), keep_parentheses(kp), path_gc(gc) {}
+};
+
+typedef args::Package<const FstClass &,
+ const vector<pair<int64, int64> >&,
+ MutableFstClass *,
+ const PdtShortestPathOptions &> PdtShortestPathArgs;
+
+template<class Arc>
+void PdtShortestPath(PdtShortestPathArgs *args) {
+ typedef typename Arc::StateId StateId;
+ typedef typename Arc::Label Label;
+ typedef typename Arc::Weight Weight;
+
+ const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
+ MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
+ const PdtShortestPathOptions &opts = args->arg4;
+
+
+ vector<pair<Label, Label> > parens(args->arg2.size());
+ for (size_t i = 0; i < parens.size(); ++i) {
+ parens[i].first = args->arg2[i].first;
+ parens[i].second = args->arg2[i].second;
+ }
+
+ switch (opts.queue_type) {
+ default:
+ FSTERROR() << "Unknown queue type: " << opts.queue_type;
+ case FIFO_QUEUE: {
+ typedef FifoQueue<StateId> Queue;
+ fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
+ opts.path_gc);
+ ShortestPath(fst, parens, ofst, spopts);
+ return;
+ }
+ case LIFO_QUEUE: {
+ typedef LifoQueue<StateId> Queue;
+ fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
+ opts.path_gc);
+ ShortestPath(fst, parens, ofst, spopts);
+ return;
+ }
+ case STATE_ORDER_QUEUE: {
+ typedef StateOrderQueue<StateId> Queue;
+ fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
+ opts.path_gc);
+ ShortestPath(fst, parens, ofst, spopts);
+ return;
+ }
+ }
+}
+
+void PdtShortestPath(const FstClass &ifst,
+ const vector<pair<int64, int64> > &parens,
+ MutableFstClass *ofst,
+ const PdtShortestPathOptions &opts =
+ PdtShortestPathOptions());
+
+// PRINT INFO
+
+typedef args::Package<const FstClass &,
+ const vector<pair<int64, int64> > &> PrintPdtInfoArgs;
+
+template<class Arc>
+void PrintPdtInfo(PrintPdtInfoArgs *args) {
+ const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
+ vector<pair<typename Arc::Label, typename Arc::Label> > parens(
+ args->arg2.size());
+ for (size_t i = 0; i < parens.size(); ++i) {
+ parens[i].first = args->arg2[i].first;
+ parens[i].second = args->arg2[i].second;
+ }
+ PdtInfo<Arc> pdtinfo(fst, parens);
+ PrintPdtInfo(pdtinfo);
+}
+
+void PrintPdtInfo(const FstClass &ifst,
+ const vector<pair<int64, int64> > &parens);
+
+} // namespace script
+} // namespace fst
+
+
+#define REGISTER_FST_PDT_OPERATIONS(ArcType) \
+ REGISTER_FST_OPERATION(PdtCompose, ArcType, PdtComposeArgs); \
+ REGISTER_FST_OPERATION(PdtExpand, ArcType, PdtExpandArgs); \
+ REGISTER_FST_OPERATION(PdtReplace, ArcType, PdtReplaceArgs); \
+ REGISTER_FST_OPERATION(PdtReverse, ArcType, PdtReverseArgs); \
+ REGISTER_FST_OPERATION(PdtShortestPath, ArcType, PdtShortestPathArgs); \
+ REGISTER_FST_OPERATION(PrintPdtInfo, ArcType, PrintPdtInfoArgs)
+#endif // FST_EXTENSIONS_PDT_PDTSCRIPT_H_