aboutsummaryrefslogtreecommitdiff
path: root/src/include/fst/extensions/far
diff options
context:
space:
mode:
Diffstat (limited to 'src/include/fst/extensions/far')
-rw-r--r-- src/include/fst/extensions/far/extract.h 119
-rw-r--r-- src/include/fst/extensions/far/far.h 3
-rw-r--r-- src/include/fst/extensions/far/farscript.h 12
-rw-r--r-- src/include/fst/extensions/far/stlist.h 22
4 files changed, 106 insertions, 50 deletions
diff --git a/src/include/fst/extensions/far/extract.h b/src/include/fst/extensions/far/extract.h
index d6f92ff..95866de 100644
--- a/src/include/fst/extensions/far/extract.h
+++ b/src/include/fst/extensions/far/extract.h
@@ -32,51 +32,106 @@ using std::vector;
namespace fst {
+// Writes one FST taken from a FAR to its own file.
+//
+// The output path is filename_prefix + <name> + filename_suffix, where <name>
+// is either the index `i`, zero-padded to `generate_filenames` characters
+// (when generate_filenames is non-zero), or the FST's key. When the same key
+// is passed on consecutive calls, the key-based name gets a ".N" suffix, N
+// being the number of consecutive repeats seen so far.
+//
+// fst:   the FST to write.
+// key:   key of the FST (taken by value; a suffix may be appended locally).
+// okey:  in/out, key passed on the previous call; set to `key` on return.
+// nrep:  in/out, consecutive-repeat counter for `key`; reset to 0 whenever
+//        `key` differs from *okey.
template<class Arc>
+inline void FarWriteFst(const Fst<Arc>* fst, string key,
+                        string* okey, int* nrep,
+                        const int32 &generate_filenames, int i,
+                        const string &filename_prefix,
+                        const string &filename_suffix) {
+  if (key == *okey)
+    ++*nrep;
+  else
+    *nrep = 0;
+
+  *okey = key;
+
+  string ofilename;
+  if (generate_filenames) {
+    ostringstream tmp;
+    tmp.width(generate_filenames);
+    tmp.fill('0');
+    tmp << i;
+    ofilename = tmp.str();
+  } else {
+    if (*nrep > 0) {
+      ostringstream tmp;
+      // Stream the counter value; streaming the int* itself would print an
+      // address instead of the repeat count.
+      tmp << '.' << *nrep;
+      key += tmp.str();
+    }
+    ofilename = key;
+  }
+  fst->Write(filename_prefix + ofilename + filename_suffix);
+}
+
+// Extracts FSTs from the FAR(s) named by `ifilenames` and writes each one to
+// its own file through FarWriteFst().
+//
+// keys:            entries to extract, separated by `key_separator`. An entry
+//                  is either a single key or a range written as
+//                  begin<range_delimiter>end; a range writes every FST from
+//                  the position found for `begin` until the reader is done or
+//                  a key greater than `end` is reached. If `keys` is empty,
+//                  every FST in the archive is written.
+// generate_filenames, filename_prefix, filename_suffix:
+//                  forwarded to FarWriteFst() to build output file names.
+//                  The running index passed to FarWriteFst() starts at 0 when
+//                  `keys` is given and at 1 when extracting everything.
+//
+// A key that cannot be found, or a malformed range, is logged with LOG(ERROR)
+// and ends the extraction at that point.
+//
+// NOTE(review): far_reader is never released in this function; presumably the
+// pointer returned by FarReader::Open() should be deleted here -- confirm
+// ownership against far.h.
+template<class Arc>
void FarExtract(const vector<string> &ifilenames,
const int32 &generate_filenames,
- const string &begin_key,
- const string &end_key,
+                const string &keys,
+                const string &key_separator,
+                const string &range_delimiter,
const string &filename_prefix,
const string &filename_suffix) {
FarReader<Arc> *far_reader = FarReader<Arc>::Open(ifilenames);
if (!far_reader) return;
- if (!begin_key.empty())
- far_reader->Find(begin_key);
-
string okey;
int nrep = 0;
- for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) {
- string key = far_reader->GetKey();
- if (!end_key.empty() && end_key < key)
- break;
- const Fst<Arc> &fst = far_reader->GetFst();
-
- if (key == okey)
- ++nrep;
- else
- nrep = 0;
- okey = key;
-
- string ofilename;
- if (generate_filenames) {
- ostringstream tmp;
- tmp.width(generate_filenames);
- tmp.fill('0');
- tmp << i;
- ofilename = tmp.str();
- } else {
- if (nrep > 0) {
- ostringstream tmp;
- tmp << '.' << nrep;
- key.append(tmp.str().data(), tmp.str().size());
+  vector<char *> key_vector;
+  // User has specified a set of fsts to extract, where some of the "fsts" could
+  // be ranges.
+  if (!keys.empty()) {
+    // SplitToVector() takes a writable, NUL-terminated buffer and fills the
+    // vector with char pointers (presumably pointing into that buffer). The
+    // buffers are held in vector<char> so they are released on every exit
+    // path, including the early error returns below.
+    vector<char> keys_buf(keys.begin(), keys.end());
+    keys_buf.push_back('\0');
+    SplitToVector(&keys_buf[0], key_separator.c_str(), &key_vector, true);
+    int i = 0;
+    for (int k = 0; k < key_vector.size(); ++k, ++i) {
+      string key = string(key_vector[k]);
+      // Split a private copy of the entry so `key` itself stays intact for
+      // Find() and for the error messages.
+      vector<char> key_buf(key.begin(), key.end());
+      key_buf.push_back('\0');
+      vector<char *> range_vector;
+      SplitToVector(&key_buf[0], range_delimiter.c_str(), &range_vector, false);
+      if (range_vector.size() == 1) {  // Not a range
+        if (!far_reader->Find(key)) {
+          LOG(ERROR) << "FarExtract: Cannot find key: " << key;
+          return;
+        }
+        const Fst<Arc> &fst = far_reader->GetFst();
+        FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i,
+                    filename_prefix, filename_suffix);
+      } else if (range_vector.size() == 2) {  // A legal range
+        string begin_key = string(range_vector[0]);
+        string end_key = string(range_vector[1]);
+        if (begin_key.empty() || end_key.empty()) {
+          LOG(ERROR) << "FarExtract: Illegal range specification: " << key;
+          return;
+        }
+        if (!far_reader->Find(begin_key)) {
+          LOG(ERROR) << "FarExtract: Cannot find key: " << begin_key;
+          return;
+        }
+        for ( ; !far_reader->Done(); far_reader->Next(), ++i) {
+          string ikey = far_reader->GetKey();
+          if (end_key < ikey) break;
+          const Fst<Arc> &fst = far_reader->GetFst();
+          FarWriteFst(&fst, ikey, &okey, &nrep, generate_filenames, i,
+                      filename_prefix, filename_suffix);
+        }
+      } else {
+        LOG(ERROR) << "FarExtract: Illegal range specification: " << key;
+        return;
}
- ofilename = key;
}
- fst.Write(filename_prefix + ofilename + filename_suffix);
+    return;
+  }
+  // Nothing specified: extract everything.
+  for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) {
+    string key = far_reader->GetKey();
+    const Fst<Arc> &fst = far_reader->GetFst();
+    FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i,
+                filename_prefix, filename_suffix);
}
-
return;
}
diff --git a/src/include/fst/extensions/far/far.h b/src/include/fst/extensions/far/far.h
index acce76e..737f1b8 100644
--- a/src/include/fst/extensions/far/far.h
+++ b/src/include/fst/extensions/far/far.h
@@ -273,13 +273,10 @@ FarWriter<A> *FarWriter<A>::Create(const string &filename, FarType type) {
return STListFarWriter<A>::Create(filename);
case FAR_STTABLE:
return STTableFarWriter<A>::Create(filename);
- break;
case FAR_STLIST:
return STListFarWriter<A>::Create(filename);
- break;
case FAR_FST:
return FstFarWriter<A>::Create(filename);
- break;
default:
LOG(ERROR) << "FarWriter::Create: unknown far type";
return 0;
diff --git a/src/include/fst/extensions/far/farscript.h b/src/include/fst/extensions/far/farscript.h
index 3a9c145..cfd9167 100644
--- a/src/include/fst/extensions/far/farscript.h
+++ b/src/include/fst/extensions/far/farscript.h
@@ -173,18 +173,22 @@ bool FarEqual(const string &filename1,
typedef args::Package<const vector<string> &, int32,
const string&, const string&, const string&,
- const string&> FarExtractArgs;
+ const string&, const string&> FarExtractArgs;
template<class Arc>
void FarExtract(FarExtractArgs *args) {
fst::FarExtract<Arc>(
- args->arg1, args->arg2, args->arg3, args->arg4, args->arg5, args->arg6);
+ args->arg1, args->arg2, args->arg3, args->arg4, args->arg5, args->arg6,
+ args->arg7);
}
void FarExtract(const vector<string> &ifilenames,
const string &arc_type,
- int32 generate_filenames, const string &begin_key,
- const string &end_key, const string &filename_prefix,
+ int32 generate_filenames,
+ const string &keys,
+ const string &key_separator,
+ const string &range_delimiter,
+ const string &filename_prefix,
const string &filename_suffix);
typedef args::Package<const vector<string> &, const string &,
diff --git a/src/include/fst/extensions/far/stlist.h b/src/include/fst/extensions/far/stlist.h
index 1cdc80c..ff3d98b 100644
--- a/src/include/fst/extensions/far/stlist.h
+++ b/src/include/fst/extensions/far/stlist.h
@@ -145,13 +145,13 @@ class STListReader {
ReadType(*streams_[i], &magic_number);
ReadType(*streams_[i], &file_version);
if (magic_number != kSTListMagicNumber) {
- FSTERROR() << "STListReader::STTableReader: wrong file type: "
+ FSTERROR() << "STListReader::STListReader: wrong file type: "
<< filenames[i];
error_ = true;
return;
}
if (file_version != kSTListFileVersion) {
- FSTERROR() << "STListReader::STTableReader: wrong file version: "
+ FSTERROR() << "STListReader::STListReader: wrong file version: "
<< filenames[i];
error_ = true;
return;
@@ -161,7 +161,7 @@ class STListReader {
if (!key.empty())
heap_.push(make_pair(key, i));
if (!*streams_[i]) {
- FSTERROR() << "STTableReader: error reading file: " << sources_[i];
+ FSTERROR() << "STListReader: error reading file: " << sources_[i];
error_ = true;
return;
}
@@ -170,7 +170,7 @@ class STListReader {
size_t current = heap_.top().second;
entry_ = entry_reader_(*streams_[current]);
if (!entry_ || !*streams_[current]) {
- FSTERROR() << "STTableReader: error reading entry for key: "
+ FSTERROR() << "STListReader: error reading entry for key: "
<< heap_.top().first << ", file: " << sources_[current];
error_ = true;
}
@@ -219,7 +219,7 @@ class STListReader {
heap_.pop();
ReadType(*(streams_[current]), &key);
if (!*streams_[current]) {
- FSTERROR() << "STTableReader: error reading file: "
+ FSTERROR() << "STListReader: error reading file: "
<< sources_[current];
error_ = true;
return;
@@ -233,7 +233,7 @@ class STListReader {
delete entry_;
entry_ = entry_reader_(*streams_[current]);
if (!entry_ || !*streams_[current]) {
- FSTERROR() << "STTableReader: error reading entry for key: "
+ FSTERROR() << "STListReader: error reading entry for key: "
<< heap_.top().first << ", file: " << sources_[current];
error_ = true;
}
@@ -267,8 +267,8 @@ class STListReader {
// String-type list header reading function template on the entry header
// type 'H' having a member function:
// Read(istream &strm, const string &filename);
-// Checks that 'filename' is an STTable and call the H::Read() on the last
-// entry in the STTable.
+// Checks that 'filename' is an STList and call the H::Read() on the last
+// entry in the STList.
// Does not support reading from stdin.
template <class H>
bool ReadSTListHeader(const string &filename, H *header) {
@@ -281,18 +281,18 @@ bool ReadSTListHeader(const string &filename, H *header) {
ReadType(strm, &magic_number);
ReadType(strm, &file_version);
if (magic_number != kSTListMagicNumber) {
- LOG(ERROR) << "ReadSTTableHeader: wrong file type: " << filename;
+ LOG(ERROR) << "ReadSTListHeader: wrong file type: " << filename;
return false;
}
if (file_version != kSTListFileVersion) {
- LOG(ERROR) << "ReadSTTableHeader: wrong file version: " << filename;
+ LOG(ERROR) << "ReadSTListHeader: wrong file version: " << filename;
return false;
}
string key;
ReadType(strm, &key);
header->Read(strm, filename + ":" + key);
if (!strm) {
- LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename;
+ LOG(ERROR) << "ReadSTListHeader: error reading file: " << filename;
return false;
}
return true;