aboutsummaryrefslogtreecommitdiff
path: root/platform/impl/stream_socket_posix.cc
blob: b60e82ae0d9f49eb91059d05afb9cff8ee7909f0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "platform/impl/stream_socket_posix.h"

#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

namespace openscreen {
namespace platform {

namespace {
constexpr int kDefaultMaxBacklogSize = 64;

// Call Select with no timeout, so that it doesn't block. Then use the result
// to determine if any connection is pending.
bool IsConnectionPending(int fd) {
  fd_set handle_set;
  FD_ZERO(&handle_set);
  FD_SET(fd, &handle_set);
  struct timeval tv {
    0
  };
  return select(fd + 1, &handle_set, nullptr, nullptr, &tv) > 0;
}
}  // namespace

StreamSocketPosix::StreamSocketPosix(IPAddress::Version version)
    : version_(version) {}

StreamSocketPosix::StreamSocketPosix(const IPEndpoint& local_endpoint)
    : version_(local_endpoint.address.version()),
      local_address_(local_endpoint) {}

StreamSocketPosix::StreamSocketPosix(SocketAddressPosix local_address,
                                     int file_descriptor)
    : handle_(file_descriptor),
      version_(local_address.version()),
      local_address_(local_address) {}

StreamSocketPosix::~StreamSocketPosix() {
  if (state_ == SocketState::kConnected) {
    Close();
  }
}

WeakPtr<StreamSocketPosix> StreamSocketPosix::GetWeakPtr() const {
  return weak_factory_.GetWeakPtr();
}

ErrorOr<std::unique_ptr<StreamSocket>> StreamSocketPosix::Accept() {
  if (!EnsureInitialized()) {
    return ReportSocketClosedError();
  }

  if (!is_bound_) {
    return CloseOnError(Error::Code::kSocketInvalidState);
  }

  // Check if any connection is pending, and return a special error code if not.
  if (!IsConnectionPending(handle_.fd)) {
    return Error::Code::kAgain;
  }

  // We copy our address to new_remote_address since it should be in the same
  // family. The accept call will overwrite it.
  SocketAddressPosix new_remote_address = local_address_.value();
  socklen_t remote_address_size = new_remote_address.size();
  const int new_file_descriptor =
      accept(handle_.fd, new_remote_address.address(), &remote_address_size);
  if (new_file_descriptor == kUnsetHandleFd) {
    return CloseOnError(Error::Code::kSocketAcceptFailure);
  }

  return ErrorOr<std::unique_ptr<StreamSocket>>(
      std::make_unique<StreamSocketPosix>(new_remote_address,
                                          new_file_descriptor));
}

Error StreamSocketPosix::Bind() {
  if (!local_address_.has_value()) {
    return CloseOnError(Error::Code::kSocketInvalidState);
  }

  if (!EnsureInitialized()) {
    return ReportSocketClosedError();
  }

  if (is_bound_) {
    return CloseOnError(Error::Code::kSocketInvalidState);
  }

  if (bind(handle_.fd, local_address_.value().address(),
           local_address_.value().size()) != 0) {
    return CloseOnError(Error::Code::kSocketBindFailure);
  }

  is_bound_ = true;
  return Error::None();
}

Error StreamSocketPosix::Close() {
  if (!EnsureInitialized()) {
    return ReportSocketClosedError();
  }

  if (state_ == SocketState::kClosed) {
    last_error_code_ = Error::Code::kSocketInvalidState;
    return Error::Code::kSocketInvalidState;
  }

  const int file_descriptor_to_close = handle_.fd;
  if (close(file_descriptor_to_close) != 0) {
    last_error_code_ = Error::Code::kSocketInvalidState;
    return Error::Code::kSocketInvalidState;
  }
  handle_.fd = kUnsetHandleFd;

  return Error::None();
}

Error StreamSocketPosix::Connect(const IPEndpoint& remote_endpoint) {
  if (!EnsureInitialized()) {
    return ReportSocketClosedError();
  }

  if (!is_initialized_ && !is_bound_) {
    return CloseOnError(Error::Code::kSocketInvalidState);
  }

  SocketAddressPosix address(remote_endpoint);
  if (connect(handle_.fd, address.address(), address.size()) != 0) {
    return CloseOnError(Error::Code::kSocketConnectFailure);
  }

  if (!is_bound_) {
    if (local_address_.has_value()) {
      return CloseOnError(Error::Code::kSocketInvalidState);
    }

    struct sockaddr address;
    socklen_t size = sizeof(address);
    if (getsockname(handle_.fd, &address, &size) != 0) {
      return CloseOnError(Error::Code::kSocketConnectFailure);
    }

    local_address_.emplace(address);
    is_bound_ = true;
  }

  remote_address_ = remote_endpoint;
  state_ = SocketState::kConnected;
  return Error::None();
}

Error StreamSocketPosix::Listen() {
  return Listen(kDefaultMaxBacklogSize);
}

Error StreamSocketPosix::Listen(int max_backlog_size) {
  if (!EnsureInitialized()) {
    return ReportSocketClosedError();
  }

  if (listen(handle_.fd, max_backlog_size) != 0) {
    return CloseOnError(Error::Code::kSocketListenFailure);
  }

  return Error::None();
}

absl::optional<IPEndpoint> StreamSocketPosix::remote_address() const {
  if ((state_ != SocketState::kConnected) || !remote_address_) {
    return absl::nullopt;
  }
  return remote_address_.value();
}

absl::optional<IPEndpoint> StreamSocketPosix::local_address() const {
  if (!local_address_.has_value()) {
    return absl::nullopt;
  }
  return local_address_.value().endpoint();
}

SocketState StreamSocketPosix::state() const {
  return state_;
}

IPAddress::Version StreamSocketPosix::version() const {
  return version_;
}

bool StreamSocketPosix::EnsureInitialized() {
  if (!is_initialized_ && (last_error_code_ == Error::Code::kNone)) {
    return Initialize() == Error::None();
  }

  return false;
}

Error StreamSocketPosix::Initialize() {
  if (is_initialized_) {
    return Error::Code::kItemAlreadyExists;
  }

  int domain;
  switch (version_) {
    case IPAddress::Version::kV4:
      domain = AF_INET;
      break;
    case IPAddress::Version::kV6:
      domain = AF_INET6;
      break;
  }

  const int file_descriptor = socket(domain, SOCK_STREAM, 0);
  if (file_descriptor == kUnsetHandleFd) {
    last_error_code_ = Error::Code::kSocketInvalidState;
    return Error::Code::kSocketInvalidState;
  }

  const int current_flags = fcntl(file_descriptor, F_GETFL, 0);
  if (fcntl(file_descriptor, F_SETFL, current_flags | O_NONBLOCK) == -1) {
    close(file_descriptor);
    last_error_code_ = Error::Code::kSocketInvalidState;
    return Error::Code::kSocketInvalidState;
  }

  handle_.fd = file_descriptor;
  is_initialized_ = true;
  // last_error_code_ should still be Error::None().
  return Error::None();
}

Error StreamSocketPosix::CloseOnError(Error::Code error_code) {
  last_error_code_ = error_code;
  Close();
  state_ = SocketState::kClosed;
  return error_code;
}

// If is_open is false, the socket has either not been initialized
// or has been closed, either on purpose or due to error.
Error StreamSocketPosix::ReportSocketClosedError() {
  last_error_code_ = Error::Code::kSocketClosedFailure;
  return Error::Code::kSocketClosedFailure;
}
}  // namespace platform
}  // namespace openscreen