aboutsummaryrefslogtreecommitdiff
path: root/fcp/aggregation/tensorflow/python/aggregation_protocols_test.py
blob: e32a6f5a790eb6c9763c9782d43f1a7b96ffcb78 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for aggregation_protocols."""

import tempfile
from typing import Any
from unittest import mock

from absl.testing import absltest
import tensorflow as tf

from fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2
from fcp.aggregation.protocol import configuration_pb2
from fcp.aggregation.protocol.python import aggregation_protocol
from fcp.aggregation.tensorflow.python import aggregation_protocols
from pybind11_abseil import status


def create_client_input(tensors: dict[str, Any]) -> apm_pb2.ClientMessage:
  """Packs `tensors` into a checkpoint and wraps it as a client message.

  The tensors are written with `tf.raw_ops.Save` to a temporary file. The raw
  file contents are then embedded inline in a `SimpleAggregation` client
  message.

  Args:
    tensors: Mapping from tensor name to tensor value to serialize.

  Returns:
    A `ClientMessage` whose simple-aggregation input carries the checkpoint
    bytes inline.
  """
  names = list(tensors)
  values = list(tensors.values())
  with tempfile.NamedTemporaryFile() as checkpoint_file:
    tf.raw_ops.Save(
        filename=checkpoint_file.name, tensor_names=names, data=values)
    # Re-open by name so we read what the Save op wrote to disk.
    with open(checkpoint_file.name, 'rb') as reader:
      checkpoint_bytes = reader.read()
  resource = apm_pb2.ClientResource(inline_bytes=checkpoint_bytes)
  return apm_pb2.ClientMessage(
      simple_aggregation=apm_pb2.ClientMessage.SimpleAggregation(
          input=resource))


class CallbackProxy(aggregation_protocol.AggregationProtocol.Callback):
  """Forwards every Callback invocation to a wrapped delegate object.

  pybind11 does not treat `mock.Mock` objects as subclasses of the bound
  Callback type. Tests therefore hand the protocol this proxy, which is a
  genuine subclass, and record the calls on the wrapped mock instead.
  """

  def __init__(self,
               callback: aggregation_protocol.AggregationProtocol.Callback):
    # pybind11 requires the bound base class's __init__ to run before the
    # object is handed to C++.
    super().__init__()
    self._callback = callback

  def OnAcceptClients(self, start_client_id: int, num_clients: int,
                      message: apm_pb2.AcceptanceMessage):
    """Relays client acceptance to the delegate."""
    delegate = self._callback
    delegate.OnAcceptClients(start_client_id, num_clients, message)

  def OnSendServerMessage(self, client_id: int, message: apm_pb2.ServerMessage):
    """Relays an outgoing server message to the delegate."""
    delegate = self._callback
    delegate.OnSendServerMessage(client_id, message)

  def OnCloseClient(self, client_id: int, diagnostic_status: status.Status):
    """Relays a client-close notification to the delegate."""
    delegate = self._callback
    delegate.OnCloseClient(client_id, diagnostic_status)

  def OnComplete(self, result: bytes):
    """Relays the completed aggregation result to the delegate."""
    delegate = self._callback
    delegate.OnComplete(result)

  def OnAbort(self, diagnostic_status: status.Status):
    """Relays an abort notification to the delegate."""
    delegate = self._callback
    delegate.OnAbort(diagnostic_status)


class AggregationProtocolsTest(absltest.TestCase):
  """Tests for protocols created via `aggregation_protocols`."""

  def test_simple_aggregation_protocol(self):
    """Sums one scalar int32 from each of two clients and checks the output."""
    input_tensor = tf.TensorSpec((), tf.int32, 'in')
    output_tensor = tf.TensorSpec((), tf.int32, 'out')
    config = configuration_pb2.Configuration(aggregation_configs=[
        configuration_pb2.Configuration.ServerAggregationConfig(
            intrinsic_uri='federated_sum',
            intrinsic_args=[
                configuration_pb2.Configuration.ServerAggregationConfig.
                IntrinsicArg(input_tensor=input_tensor.experimental_as_proto()),
            ],
            output_tensors=[output_tensor.experimental_as_proto()],
        ),
    ])
    # The mock records calls; CallbackProxy makes it acceptable to pybind11.
    callback = mock.create_autospec(
        aggregation_protocol.AggregationProtocol.Callback, instance=True)

    agg_protocol = aggregation_protocols.create_simple_aggregation_protocol(
        config, CallbackProxy(callback))
    self.assertIsNotNone(agg_protocol)

    agg_protocol.Start(2)
    callback.OnAcceptClients.assert_called_once_with(mock.ANY, 2, mock.ANY)
    # The protocol picks the first client id; read it from the callback.
    start_client_id = callback.OnAcceptClients.call_args.args[0]

    agg_protocol.ReceiveClientMessage(
        start_client_id, create_client_input({input_tensor.name: 3}))
    agg_protocol.ReceiveClientMessage(
        start_client_id + 1, create_client_input({input_tensor.name: 5}))
    callback.OnCloseClient.assert_has_calls([
        mock.call(start_client_id, status.Status.OkStatus()),
        mock.call(start_client_id + 1, status.Status.OkStatus()),
    ])

    agg_protocol.Complete()
    callback.OnComplete.assert_called_once()
    # A successful run must never report an abort; checking this explicitly
    # surfaces the abort itself rather than only a missing OnComplete.
    callback.OnAbort.assert_not_called()
    with tempfile.NamedTemporaryFile('wb') as tmpfile:
      tmpfile.write(callback.OnComplete.call_args.args[0])
      tmpfile.flush()
      restored = tf.raw_ops.Restore(
          file_pattern=tmpfile.name,
          tensor_name=output_tensor.name,
          dt=output_tensor.dtype)
      # Compare the plain value rather than relying on Tensor.__eq__/__bool__,
      # which also gives a readable failure message.
      self.assertEqual(restored.numpy(), 8)


if __name__ == '__main__':
  # Run the test cases defined in this module via the absltest runner.
  absltest.main()