aboutsummaryrefslogtreecommitdiff
path: root/cast/common/channel/virtual_connection_router.cc
blob: 85df78d9bace2dfa9ef7613fa52202d8adc1710c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "cast/common/channel/virtual_connection_router.h"

#include "cast/common/channel/cast_message_handler.h"
#include "cast/common/channel/cast_socket.h"
#include "cast/common/channel/message_util.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
#include "cast/common/channel/virtual_connection_manager.h"
#include "util/logging.h"

namespace cast {
namespace channel {

using openscreen::Error;

// Creates a router that consults |vc_manager| when deciding whether a message
// outside the transport namespace belongs to a known virtual connection.
// |vc_manager| must be non-null (checked in debug builds).
VirtualConnectionRouter::VirtualConnectionRouter(
    VirtualConnectionManager* vc_manager)
    : vc_manager_(vc_manager) {
  OSP_DCHECK(vc_manager_);
}

// Member cleanup (registered sockets and handler map) is left to the
// compiler-generated member destructors.
VirtualConnectionRouter::~VirtualConnectionRouter() = default;

// Registers |endpoint| to receive messages whose destination id equals
// |local_id|. Returns true when the registration was added and false when
// |endpoints_| already held an entry for |local_id| (emplace() does not
// replace an existing entry).
bool VirtualConnectionRouter::AddHandlerForLocalId(
    std::string local_id,
    CastMessageHandler* endpoint) {
  const auto insertion_result =
      endpoints_.emplace(std::move(local_id), endpoint);
  return insertion_result.second;
}

// Unregisters whatever handler is associated with |local_id|. Returns true
// only when exactly one entry was erased from |endpoints_|.
bool VirtualConnectionRouter::RemoveHandlerForLocalId(
    const std::string& local_id) {
  const auto erased_count = endpoints_.erase(local_id);
  return erased_count == 1u;
}

// Takes ownership of |socket|, installs this router as its client, and
// records |error_handler| so that it can be notified later by CloseSocket()
// or OnError(). Both pointers must be non-null. The socket's id must not
// already be registered: emplace() would leave the existing entry in place,
// silently destroying |socket| without |error_handler| ever being told.
// These preconditions are enforced in debug builds only.
void VirtualConnectionRouter::TakeSocket(SocketErrorHandler* error_handler,
                                         std::unique_ptr<CastSocket> socket) {
  OSP_DCHECK(socket);
  OSP_DCHECK(error_handler);
  uint32_t id = socket->socket_id();
  // Checked before insertion so no otherwise-unused local is needed when
  // DCHECKs are compiled out.
  OSP_DCHECK(sockets_.find(id) == sockets_.end());
  socket->SetClient(this);
  sockets_.emplace(id, SocketWithHandler{std::move(socket), error_handler});
}

void VirtualConnectionRouter::CloseSocket(uint32_t id) {
  auto it = sockets_.find(id);
  if (it != sockets_.end()) {
    std::unique_ptr<CastSocket> socket = std::move(it->second.socket);
    SocketErrorHandler* error_handler = it->second.error_handler;
    sockets_.erase(it);
    error_handler->OnClose(socket.get());
  }
}

// Sends |message| on the socket identified by |virtual_conn.socket_id|, after
// stamping it with the connection's local id (as source) and peer id (as
// destination). Returns kUnknownError when either:
//   - the message is outside the transport namespace and |vc_manager_| has no
//     connection data for |virtual_conn|, or
//   - no socket is registered under |virtual_conn.socket_id|.
// Otherwise returns whatever the socket's SendMessage() returns.
Error VirtualConnectionRouter::SendMessage(VirtualConnection virtual_conn,
                                           CastMessage&& message) {
  // TODO(btolsch): Check for broadcast message.
  const bool requires_connection = !IsTransportNamespace(message.namespace_());
  if (requires_connection && !vc_manager_->GetConnectionData(virtual_conn)) {
    return Error::Code::kUnknownError;
  }

  auto socket_entry = sockets_.find(virtual_conn.socket_id);
  if (socket_entry == sockets_.end()) {
    return Error::Code::kUnknownError;
  }

  // |virtual_conn| is a by-value copy, so its id strings can be moved from.
  message.set_source_id(std::move(virtual_conn.local_id));
  message.set_destination_id(std::move(virtual_conn.peer_id));
  return socket_entry->second.socket->SendMessage(message);
}

// Socket-client callback: a registered socket reported |error|. Drops the
// router's entry for that socket and forwards the error to the handler that
// was supplied with it. Sockets not found in |sockets_| are ignored.
void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) {
  auto entry = sockets_.find(socket->socket_id());
  if (entry == sockets_.end()) {
    return;
  }
  // Hold the owning pointer locally so |socket| remains valid while the
  // handler runs; it is destroyed when this function returns.
  SocketErrorHandler* const handler = entry->second.error_handler;
  std::unique_ptr<CastSocket> keep_alive = std::move(entry->second.socket);
  sockets_.erase(entry);
  handler->OnError(socket, error);
}

// Socket-client callback: |socket| delivered |message|. Messages outside the
// transport namespace are dropped unless |vc_manager_| has connection data for
// the (destination id, source id, socket id) triple. Surviving messages go to
// the handler registered for the message's destination id, if any; otherwise
// they are dropped.
void VirtualConnectionRouter::OnMessage(CastSocket* socket,
                                        CastMessage message) {
  // TODO(btolsch): Check for broadcast message.
  if (!IsTransportNamespace(message.namespace_())) {
    // Only this path needs the connection tuple, so it is built here.
    VirtualConnection virtual_conn{message.destination_id(),
                                   message.source_id(), socket->socket_id()};
    if (!vc_manager_->GetConnectionData(virtual_conn)) {
      return;
    }
  }

  auto handler_entry = endpoints_.find(message.destination_id());
  if (handler_entry == endpoints_.end()) {
    return;
  }
  handler_entry->second->OnMessage(this, socket, std::move(message));
}

}  // namespace channel
}  // namespace cast