aboutsummaryrefslogtreecommitdiff
path: root/osp/public/request_response_handler.h
blob: de783efcc0bd0360f38058ee6266b507fe6c664b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_
#define OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "osp/public/message_demuxer.h"
#include "osp/public/network_service_manager.h"
#include "osp/public/protocol_connection.h"
#include "platform/base/error.h"
#include "platform/base/macros.h"
#include "util/osp_logging.h"

namespace openscreen {
namespace osp {

// Pointer-to-function type for a message decoder: it is handed a byte buffer
// (pointer + length) and a |T| to fill in, and returns an ssize_t.  Within
// this file a negative return value is treated as "could not decode"; any
// other value is forwarded as the result of OnStreamMessage.
template <typename T>
using MessageDecodingFunction = ssize_t (*)(const uint8_t*, size_t, T*);

// Provides a uniform way of accessing important properties of a
// request/response message pair from a template: request encode function,
// response decode function, request serializable data member.
template <typename T>
struct DefaultRequestCoderTraits {
  // Message type that is actually serialized for a request wrapper |T|.
  using RequestMsgType = typename T::RequestMsgType;

  // Encoder for the request message, forwarded unchanged from |T|.
  static constexpr MessageEncodingFunction<RequestMsgType> kEncoder =
      T::kEncoder;

  // Decoder for the matching response message, forwarded unchanged from |T|.
  static constexpr MessageDecodingFunction<typename T::ResponseMsgType>
      kDecoder = T::kDecoder;

  // Returns the address of the serializable |request| member of |data|.  The
  // const-ness of the result follows the const-ness of |data|.
  static RequestMsgType* serial_request(T& data) {
    return &data.request;
  }
  static const RequestMsgType* serial_request(const T& data) {
    return &data.request;
  }
};

// Provides a wrapper for the common pattern of sending a request message and
// waiting for a response message with a matching |request_id| field.  It also
// handles the business of queueing messages to be sent until a protocol
// connection is available.
//
// Messages are written using WriteMessage.  This will queue messages if there
// is no protocol connection or write them immediately if there is.  When a
// matching response is received via the MessageDemuxer (taken from the global
// ProtocolConnectionClient), OnMatchedResponse is called on the provided
// Delegate object along with the original request that it matches.
template <typename RequestT,
          typename RequestCoderTraits = DefaultRequestCoderTraits<RequestT>>
class RequestResponseHandler : public MessageDemuxer::MessageCallback {
 public:
  // Receives the outcome of every request handed to this handler.
  class Delegate {
   public:
    virtual ~Delegate() = default;

    // Called when a decoded response carries the same |request_id| as a
    // request that was previously written.  |request| and |response| are
    // destroyed when the call returns, so they must not be retained.
    virtual void OnMatchedResponse(RequestT* request,
                                   typename RequestT::ResponseMsgType* response,
                                   uint64_t endpoint_id) = 0;

    // Called when |request| will not be answered: either a queued write
    // failed in SetConnection (|error| is the write error) or the request was
    // dropped by Reset (|error| is kRequestCancelled).
    virtual void OnError(RequestT* request, Error error) = 0;
  };

  // |delegate| is dereferenced without a null check whenever a callback is
  // issued, including from the destructor, so it must stay valid for as long
  // as this handler can hold requests.
  explicit RequestResponseHandler(Delegate* delegate) : delegate_(delegate) {}
  ~RequestResponseHandler() { Reset(); }

  // Drops the connection and the response watch and reports every queued or
  // in-flight request to the delegate with kRequestCancelled.
  void Reset() {
    connection_ = nullptr;
    response_watch_ = MessageDemuxer::MessageWatch();

    // Detach both queues before notifying the delegate: OnError may re-enter
    // this handler (e.g. call WriteMessage, which appends to |to_send_|), and
    // that must not mutate a container that is being iterated.
    std::vector<RequestWithId> unsent;
    std::vector<RequestWithId> awaiting_response;
    unsent.swap(to_send_);
    awaiting_response.swap(sent_);

    for (auto& message : unsent) {
      delegate_->OnError(&message.request, Error::Code::kRequestCancelled);
    }
    for (auto& message : awaiting_response) {
      delegate_->OnError(&message.request, Error::Code::kRequestCancelled);
    }
  }

  // Write a message to the underlying protocol connection, or queue it until
  // one is provided via SetConnection.  If |id| is provided, it can be used to
  // cancel the message via CancelMessage.
  //
  // When a connection is present the message is assigned a fresh request id
  // and written immediately; if that write fails the error is returned and the
  // message is not retained.  In every other case Error::None() is returned.
  template <typename RequestTRval>
  typename std::enable_if<
      !std::is_lvalue_reference<RequestTRval>::value &&
          std::is_same<typename std::decay<RequestTRval>::type,
                       RequestT>::value,
      Error>::type
  WriteMessage(absl::optional<uint64_t> id, RequestTRval&& message) {
    auto* request_msg = RequestCoderTraits::serial_request(message);
    if (connection_) {
      request_msg->request_id = GetNextRequestId(connection_->endpoint_id());
      Error result =
          connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder);
      if (!result.ok()) {
        return result;
      }
      sent_.emplace_back(RequestWithId{id, std::move(message)});
      EnsureResponseWatch();
    } else {
      to_send_.emplace_back(RequestWithId{id, std::move(message)});
    }
    return Error::None();
  }

  // Same as above for a message that never needs to be cancelled by id.
  template <typename RequestTRval>
  typename std::enable_if<
      !std::is_lvalue_reference<RequestTRval>::value &&
          std::is_same<typename std::decay<RequestTRval>::type,
                       RequestT>::value,
      Error>::type
  WriteMessage(RequestTRval&& message) {
    return WriteMessage(absl::nullopt, std::move(message));
  }

  // Remove the message that was originally written with |id| from the send and
  // sent queues so that we are no longer looking for a response.  Messages
  // written without an id never match.  The delegate is not notified.
  void CancelMessage(uint64_t id) {
    const auto has_id = [id](const RequestWithId& msg) { return msg.id == id; };
    to_send_.erase(std::remove_if(to_send_.begin(), to_send_.end(), has_id),
                   to_send_.end());
    sent_.erase(std::remove_if(sent_.begin(), sent_.end(), has_id),
                sent_.end());
    if (sent_.empty()) {
      response_watch_ = MessageDemuxer::MessageWatch();
    }
  }

  // Assign a ProtocolConnection to this handler for writing messages and flush
  // everything that was queued while no connection was available.  Passing
  // nullptr only detaches the connection; queued messages stay queued.
  void SetConnection(ProtocolConnection* connection) {
    connection_ = connection;
    if (!connection_) {
      // Nothing to write to; dereferencing below would be a null access.
      return;
    }
    for (auto& message : to_send_) {
      auto* request_msg = RequestCoderTraits::serial_request(message.request);
      request_msg->request_id = GetNextRequestId(connection_->endpoint_id());
      Error result =
          connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder);
      if (result.ok()) {
        sent_.emplace_back(std::move(message));
      } else {
        delegate_->OnError(&message.request, result);
      }
    }
    to_send_.clear();
    // Only watch for responses if something is actually awaiting one; if every
    // queued write failed there is nothing to match against.
    if (!sent_.empty()) {
      EnsureResponseWatch();
    }
  }

  // MessageDemuxer::MessageCallback overrides.
  //
  // Decodes a response of type RequestT::kResponseType and hands it, together
  // with the request it answers, to the delegate.  Returns 0 if the message is
  // of another type or cannot be decoded; otherwise returns the decoder's
  // result, whether or not a matching request was found.
  ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id,
                                  uint64_t connection_id,
                                  msgs::Type message_type,
                                  const uint8_t* buffer,
                                  size_t buffer_size,
                                  Clock::time_point now) override {
    if (message_type != RequestT::kResponseType) {
      return 0;
    }
    typename RequestT::ResponseMsgType response;
    ssize_t result =
        RequestCoderTraits::kDecoder(buffer, buffer_size, &response);
    if (result < 0) {
      return 0;
    }
    auto it = std::find_if(
        sent_.begin(), sent_.end(), [&response](const RequestWithId& msg) {
          return RequestCoderTraits::serial_request(msg.request)->request_id ==
                 response.request_id;
        });
    if (it == sent_.end()) {
      OSP_LOG_WARN << "got response for unknown request id: "
                   << response.request_id;
      return result;
    }

    // Take the request out of |sent_| before running the delegate.  The
    // delegate may re-enter WriteMessage or CancelMessage, either of which
    // would invalidate |it| if the entry were still in the vector.
    RequestWithId matched = std::move(*it);
    sent_.erase(it);

    // |connection_| may have been detached while the request was in flight;
    // fall back to the endpoint the demuxer reported for this message.
    const uint64_t source_endpoint_id =
        connection_ ? connection_->endpoint_id() : endpoint_id;
    delegate_->OnMatchedResponse(&matched.request, &response,
                                 source_endpoint_id);

    // Checked after the callback so that a request written from inside the
    // callback keeps the existing watch alive.
    if (sent_.empty()) {
      response_watch_ = MessageDemuxer::MessageWatch();
    }
    return result;
  }

 private:
  // A request together with the optional caller-chosen id used by
  // CancelMessage.
  struct RequestWithId {
    absl::optional<uint64_t> id;
    RequestT request;
  };

  // Registers this handler with the global message demuxer for responses from
  // the current connection's endpoint, unless a watch is already active.
  // Requires |connection_| to be non-null.
  void EnsureResponseWatch() {
    if (!response_watch_) {
      response_watch_ = NetworkServiceManager::Get()
                            ->GetProtocolConnectionClient()
                            ->message_demuxer()
                            ->WatchMessageType(connection_->endpoint_id(),
                                               RequestT::kResponseType, this);
    }
  }

  // Fetches the next request id for |endpoint_id| from the global
  // ProtocolConnectionClient.
  uint64_t GetNextRequestId(uint64_t endpoint_id) {
    return NetworkServiceManager::Get()
        ->GetProtocolConnectionClient()
        ->endpoint_request_ids()
        ->GetNextRequestId(endpoint_id);
  }

  ProtocolConnection* connection_ = nullptr;
  Delegate* const delegate_;
  // Requests accepted while no connection was available.
  std::vector<RequestWithId> to_send_;
  // Requests written to the connection and still awaiting a response.
  std::vector<RequestWithId> sent_;
  MessageDemuxer::MessageWatch response_watch_;

  OSP_DISALLOW_COPY_AND_ASSIGN(RequestResponseHandler);
};

}  // namespace osp
}  // namespace openscreen

#endif  // OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_