diff options
Diffstat (limited to 'src/include/fst/extensions/far')
-rw-r--r-- | src/include/fst/extensions/far/extract.h | 119 | ||||
-rw-r--r-- | src/include/fst/extensions/far/far.h | 3 | ||||
-rw-r--r-- | src/include/fst/extensions/far/farscript.h | 12 | ||||
-rw-r--r-- | src/include/fst/extensions/far/stlist.h | 22 |
4 files changed, 106 insertions, 50 deletions
diff --git a/src/include/fst/extensions/far/extract.h b/src/include/fst/extensions/far/extract.h
index d6f92ff..95866de 100644
--- a/src/include/fst/extensions/far/extract.h
+++ b/src/include/fst/extensions/far/extract.h
@@ -32,51 +32,106 @@ using std::vector;
 namespace fst {
 
 template<class Arc>
+inline void FarWriteFst(const Fst<Arc>* fst, string key,
+                        string* okey, int* nrep,
+                        const int32 &generate_filenames, int i,
+                        const string &filename_prefix,
+                        const string &filename_suffix) {
+  if (key == *okey)
+    ++*nrep;
+  else
+    *nrep = 0;
+
+  *okey = key;
+
+  string ofilename;
+  if (generate_filenames) {
+    ostringstream tmp;
+    tmp.width(generate_filenames);
+    tmp.fill('0');
+    tmp << i;
+    ofilename = tmp.str();
+  } else {
+    if (*nrep > 0) {
+      ostringstream tmp;
+      tmp << '.' << *nrep;  // Stream the repeat count, not the pointer.
+      key.append(tmp.str().data(), tmp.str().size());
+    }
+    ofilename = key;
+  }
+  fst->Write(filename_prefix + ofilename + filename_suffix);
+}
+
+template<class Arc>
 void FarExtract(const vector<string> &ifilenames,
                 const int32 &generate_filenames,
-                const string &begin_key,
-                const string &end_key,
+                const string &keys,
+                const string &key_separator,
+                const string &range_delimiter,
                 const string &filename_prefix,
                 const string &filename_suffix) {
   FarReader<Arc> *far_reader = FarReader<Arc>::Open(ifilenames);
   if (!far_reader) return;
 
-  if (!begin_key.empty())
-    far_reader->Find(begin_key);
-
   string okey;
   int nrep = 0;
 
-  for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) {
-    string key = far_reader->GetKey();
-    if (!end_key.empty() && end_key < key)
-      break;
-    const Fst<Arc> &fst = far_reader->GetFst();
-
-    if (key == okey)
-      ++nrep;
-    else
-      nrep = 0;
-    okey = key;
-
-    string ofilename;
-    if (generate_filenames) {
-      ostringstream tmp;
-      tmp.width(generate_filenames);
-      tmp.fill('0');
-      tmp << i;
-      ofilename = tmp.str();
-    } else {
-      if (nrep > 0) {
-        ostringstream tmp;
-        tmp << '.' << nrep;
-        key.append(tmp.str().data(), tmp.str().size());
+  vector<char *> key_vector;
+  // User has specified a set of fsts to extract, where some of the "fsts" could
+  // be ranges.
+  if (!keys.empty()) {
+    char *keys_cstr = new char[keys.size()+1];
+    strcpy(keys_cstr, keys.c_str());
+    SplitToVector(keys_cstr, key_separator.c_str(), &key_vector, true);
+    int i = 0;
+    for (int k = 0; k < key_vector.size(); ++k, ++i) {
+      string key = string(key_vector[k]);
+      char *key_cstr = new char[key.size()+1];
+      strcpy(key_cstr, key.c_str());
+      vector<char *> range_vector;
+      SplitToVector(key_cstr, range_delimiter.c_str(), &range_vector, false);
+      if (range_vector.size() == 1) {  // Not a range
+        if (!far_reader->Find(key)) {
+          LOG(ERROR) << "FarExtract: Cannot find key: " << key;
+          return;
+        }
+        const Fst<Arc> &fst = far_reader->GetFst();
+        FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i,
+                    filename_prefix, filename_suffix);
+      } else if (range_vector.size() == 2) {  // A legal range
+        string begin_key = string(range_vector[0]);
+        string end_key = string(range_vector[1]);
+        if (begin_key.empty() || end_key.empty()) {
+          LOG(ERROR) << "FarExtract: Illegal range specification: " << key;
+          return;
+        }
+        if (!far_reader->Find(begin_key)) {
+          LOG(ERROR) << "FarExtract: Cannot find key: " << begin_key;
+          return;
+        }
+        for ( ; !far_reader->Done(); far_reader->Next(), ++i) {
+          string ikey = far_reader->GetKey();
+          if (end_key < ikey) break;
+          const Fst<Arc> &fst = far_reader->GetFst();
+          FarWriteFst(&fst, ikey, &okey, &nrep, generate_filenames, i,
+                      filename_prefix, filename_suffix);
+        }
+      } else {
+        LOG(ERROR) << "FarExtract: Illegal range specification: " << key;
+        return;
       }
-      ofilename = key;
+      delete[] key_cstr;  // Allocated with new[], so release with delete[].
     }
-    fst.Write(filename_prefix + ofilename + filename_suffix);
+    delete[] keys_cstr;  // Allocated with new[], so release with delete[].
+    return;
+  }
+  // Nothing specified: extract everything.
+  for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) {
+    string key = far_reader->GetKey();
+    const Fst<Arc> &fst = far_reader->GetFst();
+    FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i,
+                filename_prefix, filename_suffix);
   }
-
   return;
 }
 
diff --git a/src/include/fst/extensions/far/far.h b/src/include/fst/extensions/far/far.h
index acce76e..737f1b8 100644
--- a/src/include/fst/extensions/far/far.h
+++ b/src/include/fst/extensions/far/far.h
@@ -273,13 +273,10 @@ FarWriter<A> *FarWriter<A>::Create(const string &filename, FarType type) {
       return STListFarWriter<A>::Create(filename);
     case FAR_STTABLE:
       return STTableFarWriter<A>::Create(filename);
-      break;
     case FAR_STLIST:
       return STListFarWriter<A>::Create(filename);
-      break;
     case FAR_FST:
       return FstFarWriter<A>::Create(filename);
-      break;
     default:
       LOG(ERROR) << "FarWriter::Create: unknown far type";
       return 0;
diff --git a/src/include/fst/extensions/far/farscript.h b/src/include/fst/extensions/far/farscript.h
index 3a9c145..cfd9167 100644
--- a/src/include/fst/extensions/far/farscript.h
+++ b/src/include/fst/extensions/far/farscript.h
@@ -173,18 +173,22 @@ bool FarEqual(const string &filename1,
 
 typedef args::Package<const vector<string> &, int32, const string&,
                       const string&, const string&,
-                      const string&> FarExtractArgs;
+                      const string&, const string&> FarExtractArgs;
 
 template<class Arc>
 void FarExtract(FarExtractArgs *args) {
   fst::FarExtract<Arc>(
-      args->arg1, args->arg2, args->arg3, args->arg4, args->arg5, args->arg6);
+      args->arg1, args->arg2, args->arg3, args->arg4, args->arg5, args->arg6,
+      args->arg7);
 }
 
 void FarExtract(const vector<string> &ifilenames,
                 const string &arc_type,
-                int32 generate_filenames, const string &begin_key,
-                const string &end_key, const string &filename_prefix,
+                int32 generate_filenames,
+                const string &keys,
+                const string &key_separator,
+                const string &range_delimiter,
+                const string &filename_prefix,
                 const string &filename_suffix);
 
 typedef args::Package<const vector<string> &, const string &,
diff --git a/src/include/fst/extensions/far/stlist.h b/src/include/fst/extensions/far/stlist.h
index 1cdc80c..ff3d98b 100644
--- a/src/include/fst/extensions/far/stlist.h
+++ b/src/include/fst/extensions/far/stlist.h
@@ -145,13 +145,13 @@ class STListReader {
       ReadType(*streams_[i], &magic_number);
       ReadType(*streams_[i], &file_version);
       if (magic_number != kSTListMagicNumber) {
-        FSTERROR() << "STListReader::STTableReader: wrong file type: "
+        FSTERROR() << "STListReader::STListReader: wrong file type: "
                    << filenames[i];
         error_ = true;
         return;
       }
       if (file_version != kSTListFileVersion) {
-        FSTERROR() << "STListReader::STTableReader: wrong file version: "
+        FSTERROR() << "STListReader::STListReader: wrong file version: "
                    << filenames[i];
         error_ = true;
         return;
@@ -161,7 +161,7 @@ class STListReader {
       if (!key.empty())
         heap_.push(make_pair(key, i));
       if (!*streams_[i]) {
-        FSTERROR() << "STTableReader: error reading file: " << sources_[i];
+        FSTERROR() << "STListReader: error reading file: " << sources_[i];
         error_ = true;
         return;
       }
@@ -170,7 +170,7 @@ class STListReader {
     size_t current = heap_.top().second;
     entry_ = entry_reader_(*streams_[current]);
     if (!entry_ || !*streams_[current]) {
-      FSTERROR() << "STTableReader: error reading entry for key: "
+      FSTERROR() << "STListReader: error reading entry for key: "
                  << heap_.top().first << ", file: " << sources_[current];
       error_ = true;
     }
@@ -219,7 +219,7 @@ class STListReader {
     heap_.pop();
     ReadType(*(streams_[current]), &key);
     if (!*streams_[current]) {
-      FSTERROR() << "STTableReader: error reading file: "
+      FSTERROR() << "STListReader: error reading file: "
                  << sources_[current];
       error_ = true;
       return;
@@ -233,7 +233,7 @@ class STListReader {
     delete entry_;
     entry_ = entry_reader_(*streams_[current]);
     if (!entry_ || !*streams_[current]) {
-      FSTERROR() << "STTableReader: error reading entry for key: "
+      FSTERROR() << "STListReader: error reading entry for key: "
                  << heap_.top().first << ", file: " << sources_[current];
       error_ = true;
     }
@@ -267,8 +267,8 @@ class STListReader {
 // String-type list header reading function template on the entry header
 // type 'H' having a member function:
 //   Read(istream &strm, const string &filename);
-// Checks that 'filename' is an STTable and call the H::Read() on the last
-// entry in the STTable.
+// Checks that 'filename' is an STList and call the H::Read() on the last
+// entry in the STList.
 // Does not support reading from stdin.
 template <class H>
 bool ReadSTListHeader(const string &filename, H *header) {
@@ -281,18 +281,18 @@ bool ReadSTListHeader(const string &filename, H *header) {
   ReadType(strm, &magic_number);
   ReadType(strm, &file_version);
   if (magic_number != kSTListMagicNumber) {
-    LOG(ERROR) << "ReadSTTableHeader: wrong file type: " << filename;
+    LOG(ERROR) << "ReadSTListHeader: wrong file type: " << filename;
     return false;
   }
   if (file_version != kSTListFileVersion) {
-    LOG(ERROR) << "ReadSTTableHeader: wrong file version: " << filename;
+    LOG(ERROR) << "ReadSTListHeader: wrong file version: " << filename;
     return false;
   }
   string key;
   ReadType(strm, &key);
   header->Read(strm, filename + ":" + key);
   if (!strm) {
-    LOG(ERROR) << "ReadSTTableHeader: error reading file: " << filename;
+    LOG(ERROR) << "ReadSTListHeader: error reading file: " << filename;
     return false;
   }
   return true;