diff options
Diffstat (limited to 'fcp/tensorflow/external_dataset.h')
-rw-r--r-- | fcp/tensorflow/external_dataset.h | 179 |
1 files changed, 179 insertions, 0 deletions
diff --git a/fcp/tensorflow/external_dataset.h b/fcp/tensorflow/external_dataset.h new file mode 100644 index 0000000..0b2559f --- /dev/null +++ b/fcp/tensorflow/external_dataset.h @@ -0,0 +1,179 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef FCP_TENSORFLOW_EXTERNAL_DATASET_H_ +#define FCP_TENSORFLOW_EXTERNAL_DATASET_H_ + +#include <memory> +#include <string> + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "fcp/base/bounds.h" +#include "fcp/tensorflow/host_object.h" + +namespace fcp { + +/** + * Interface for an iterator, created from a particular dataset. A single + * dataset may be used to create multiple iterators. + */ +class ExternalDatasetIterator { + public: + virtual ~ExternalDatasetIterator() = default; + + /** + * Returns the next element, if possible. Indicates end-of-stream with + * OUT_OF_RANGE, even when repeatedly called. Corresponds to + * tensorflow::data::IteratorBase::GetNext. + * + * Implementations must be thread-safe. + */ + virtual absl::StatusOr<std::string> GetNext() = 0; +}; + +namespace external_dataset_internal { + +template <typename FuncType> +class DatasetFromFunction; + +} // namespace external_dataset_internal + +/** + * Interface for a particular dataset - created from an ExternalDatasetProvider + * (during dataset op execution), for a particular selector. A dataset may be + * used zero or more times to create an ExternalDatasetIterator. + * + * Dataset implementations are often trivial, just needing to capture some + * values (like the selector) for the iterator constructor. Consider using + * ExternalDataset::FromFunction. + */ +class ExternalDataset { + public: + virtual ~ExternalDataset() = default; + + /** + * Creates a new iterator. Corresponds to + * tensorflow::data::DatasetBase::MakeIterator. + */ + virtual std::unique_ptr<ExternalDatasetIterator> MakeIterator() = 0; + + /** + * Creates an ExternalDataset that wraps a callable object 'f', implementing + * MakeIterator(). The lifetime of 'f' is that of the dataset (so, + * by-reference lambda captures are almost always unsafe here). + */ + template <typename F> + static std::unique_ptr<ExternalDataset> FromFunction(F f) { + return std::make_unique<external_dataset_internal::DatasetFromFunction<F>>( + std::move(f)); + } +}; + +/** + * Interface for an ExternalDataset op's host object. + * + * An ExternalDatasetProvider is a function from Selector -> ExternalDataset. + * Here, 'Selector' is a string provided to the dataset op (typically, an + * encoded proto). The returned ExternalDataset may be used (perhaps multiple + * times) to create an iterator. + * + * When implementing a dataset provider and the selector is a proto message, + * consider inheritng from ExternalDatasetProvider::UsingProtoSelector<T> (for + * some message type T). + */ +class ExternalDatasetProvider { + public: + virtual ~ExternalDatasetProvider() = default; + + /** + * Creates a dataset for a given selector. + * + * This function can usually be implemented succinctly, using + * ExternalDataset::FromFunction. + * + * Corresponds to tensorflow::data::DatasetOpKernel::MakeDataset. + */ + virtual absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset( + absl::string_view selector) = 0; + + /** + * Base class for dataset providers that expect a selector of a particular + * proto message type. If inheriting from UsingProtoSelector<T>, then one + * implements MakeDataset(T) instead of MakeDataset(absl::string_view). + */ + template <typename T> + class UsingProtoSelector; +}; + +/** + * HostObjectRegistry for the ExternalDataset interface. + */ +using ExternalDatasetProviderRegistry = + HostObjectRegistry<ExternalDatasetProvider>; + +namespace external_dataset_internal { + +template <typename T> +absl::StatusOr<T> TryParseProtoSelector(absl::string_view selector) { + T msg; + if (!msg.ParseFromArray(selector.data(), + CastIntegerChecked<int>(selector.size()))) { + return absl::InvalidArgumentError(absl::StrCat( + "Failed to parse selector proto of type ", msg.GetTypeName())); + } + + return msg; +} + +template <typename FuncType> +class DatasetFromFunction : public ExternalDataset { + public: + explicit DatasetFromFunction(FuncType func) : func_(std::move(func)) {} + + std::unique_ptr<ExternalDatasetIterator> MakeIterator() final { + return func_(); + } + + private: + FuncType func_; +}; + +} // namespace external_dataset_internal + +template <typename T> +class ExternalDatasetProvider::UsingProtoSelector + : public ExternalDatasetProvider { + public: + absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset( + absl::string_view selector) final { + auto maybe_msg = + external_dataset_internal::TryParseProtoSelector<T>(selector); + if (!maybe_msg.ok()) { + return maybe_msg.status(); + } + + return MakeDataset(std::move(maybe_msg).value()); + } + + virtual absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset( + T selector) = 0; +}; + +} // namespace fcp + +#endif // FCP_TENSORFLOW_EXTERNAL_DATASET_H_ |