aboutsummaryrefslogtreecommitdiff
path: root/fcp/tensorflow/task_eligibility_info_ops.cc
blob: 3ca968c0fa9222dd74ea476ef7c586c3b1b304b7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <string>
#include <vector>

#include "fcp/protos/federated_api.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"

namespace fcp {

using ::google::internal::federatedml::v2::TaskEligibilityInfo;
using ::google::internal::federatedml::v2::TaskWeight;

/**
 * CreateTaskEligibilityInfo op-kernel. Converts a set of input tensors into a
 * `TaskEligibilityInfo` proto serialized into a string tensor.
 *
 * This op is used to generate `TaskEligibilityInfo` protos from a model at
 * runtime, since TF Mobile does not support the standard TensorFlow ops for
 * encoding/decoding protos.
 *
 * Inputs:
 *   - `version`: int64 scalar, stored in the proto's `version` field.
 *   - `task_names`: string vector of task names.
 *   - `task_weights`: float32 vector of weights. It must have the same length
 *     as `task_names`; the i-th weight is paired with the i-th name.
 * Output:
 *   - `output`: string scalar holding the serialized `TaskEligibilityInfo`,
 *     containing one `TaskWeight` entry per input pair, in input order.
 *
 * Fails with `InvalidArgument` if `task_names` and `task_weights` differ in
 * length.
 */
class CreateTaskEligibilityInfoOp : public tensorflow::OpKernel {
 public:
  explicit CreateTaskEligibilityInfoOp(
      tensorflow::OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(tensorflow::OpKernelContext* ctx) override {
    // Note: We use the tensorflow::data::ParseScalar/VectorArgument helpers
    // here, even though this op isn't strictly related to our tf.Dataset
    // integration. The helpers are public though, and we already use them in
    // our ExternalDataset implementation, so we might as well use them here
    // too.

    // Parse/validate the input arguments.
    tensorflow::int64 version = 0;
    OP_REQUIRES_OK(
        ctx, tensorflow::data::ParseScalarArgument(ctx, "version", &version));
    std::vector<tensorflow::tstring> task_names;
    OP_REQUIRES_OK(ctx, tensorflow::data::ParseVectorArgument(ctx, "task_names",
                                                              &task_names));
    std::vector<float> task_weights;
    OP_REQUIRES_OK(ctx, tensorflow::data::ParseVectorArgument(
                            ctx, "task_weights", &task_weights));
    // errors::InvalidArgument concatenates its arguments itself, so no direct
    // dependency on absl::StrCat is needed to build the message.
    OP_REQUIRES(ctx, task_names.size() == task_weights.size(),
                tensorflow::errors::InvalidArgument(
                    "task_names length must match task_weights length: ",
                    task_names.size(), " vs. ", task_weights.size()));

    // Create the output proto, based on the inputs.
    TaskEligibilityInfo eligibility_info;
    eligibility_info.set_version(version);
    // Create one `TaskWeight` message for each (task_names[i], task_weights[i])
    // pair. The length check above guarantees both vectors can be indexed by i.
    eligibility_info.mutable_task_weights()->Reserve(
        static_cast<int>(task_names.size()));
    for (size_t i = 0; i < task_names.size(); ++i) {
      TaskWeight* task_weight_proto = eligibility_info.add_task_weights();
      task_weight_proto->set_task_name(std::string(task_names[i]));
      task_weight_proto->set_weight(task_weights[i]);
    }

    // Place the serialized output proto into the (scalar) output tensor.
    tensorflow::Tensor* output_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output("output", tensorflow::TensorShape({}),
                                        &output_tensor));
    output_tensor->scalar<tensorflow::tstring>()() =
        eligibility_info.SerializeAsString();
  }
};

REGISTER_OP("CreateTaskEligibilityInfo")
    .Input("version: int64")
    .Input("task_names: string")
    .Input("task_weights: float32")
    .Output("output: string")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(
    Name("CreateTaskEligibilityInfo").Device(tensorflow::DEVICE_CPU),
    CreateTaskEligibilityInfoOp);

}  // namespace fcp