aboutsummaryrefslogtreecommitdiff
path: root/fcp/tensorflow/make_slices_selector_example_selector_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'fcp/tensorflow/make_slices_selector_example_selector_op.cc')
-rw-r--r--fcp/tensorflow/make_slices_selector_example_selector_op.cc85
1 files changed, 85 insertions, 0 deletions
diff --git a/fcp/tensorflow/make_slices_selector_example_selector_op.cc b/fcp/tensorflow/make_slices_selector_example_selector_op.cc
new file mode 100644
index 0000000..14cf627
--- /dev/null
+++ b/fcp/tensorflow/make_slices_selector_example_selector_op.cc
@@ -0,0 +1,85 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <string>
+#include <utility>
+
+#include "google/protobuf/any.pb.h"
+#include "absl/strings/str_format.h"
+#include "fcp/client/federated_select.h"
+#include "fcp/protos/plan.pb.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_requires.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/stringpiece.h"
+
+namespace fcp {
+
+namespace {
+
+REGISTER_OP("MakeSlicesSelectorExampleSelector")
+ .Input("served_at_id: string")
+ .Input("keys: int32")
+ .Output("serialized_proto: string")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape);
+
+class MakeSlicesSelectorExampleSelectorOp : public tensorflow::OpKernel {
+ public:
+ explicit MakeSlicesSelectorExampleSelectorOp(
+ tensorflow::OpKernelConstruction* context)
+ : OpKernel(context) {}
+ void Compute(tensorflow::OpKernelContext* context) override {
+ const tensorflow::Tensor* served_at_id_tensor;
+ OP_REQUIRES_OK(context,
+ context->input("served_at_id", &served_at_id_tensor));
+ std::string served_at_id =
+ served_at_id_tensor->scalar<tensorflow::tstring>()();
+
+ const tensorflow::Tensor* keys_tensor;
+ OP_REQUIRES_OK(context, context->input("keys", &keys_tensor));
+ tensorflow::TTypes<int32_t>::ConstFlat keys = keys_tensor->flat<int32_t>();
+
+ google::internal::federated::plan::SlicesSelector slices_selector;
+ slices_selector.set_served_at_id(std::move(served_at_id));
+ slices_selector.mutable_keys()->Reserve(keys.size());
+ for (size_t i = 0; i < keys.size(); i++) {
+ slices_selector.add_keys(keys(i));
+ }
+
+ google::internal::federated::plan::ExampleSelector example_selector;
+ example_selector.mutable_criteria()->PackFrom(slices_selector);
+ example_selector.set_collection_uri(
+ fcp::client::kFederatedSelectCollectionUri);
+ // `resumption_token` not set.
+
+ tensorflow::Tensor* output_tensor;
+ OP_REQUIRES_OK(context, context->allocate_output(0, {}, &output_tensor));
+ output_tensor->scalar<tensorflow::tstring>()() =
+ example_selector.SerializeAsString();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MakeSlicesSelectorExampleSelector").Device(tensorflow::DEVICE_CPU),
+ MakeSlicesSelectorExampleSelectorOp);
+
+} // namespace
+
+} // namespace fcp